123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
- import sys
- from pathlib import Path
- from typing import List, Literal, TypedDict
- from unittest.mock import patch
- import pytest
- import torch
- from llama_recipes.inference.chat_utils import read_dialogs_from_file
# Directory one level above the folder that contains this file.
ROOT_DIR = Path(__file__).parents[1]
# Folder holding the chat_completion example script and its chats.json.
CHAT_COMPLETION_DIR = ROOT_DIR.joinpath(
    "recipes/inference/local_inference/chat_completion/"
)
# Put the example folder first on the import path so that the test below can
# `import chat_completion` (and patch names inside it) directly.
sys.path = [CHAT_COMPLETION_DIR.as_posix(), *sys.path]
# Speaker labels a chat message may carry.  "system" is part of the set
# because _format_tokens_llama2 explicitly handles a leading system message
# and names it among the supported roles in its assertion text.
Role = Literal["system", "user", "assistant"]


class Message(TypedDict):
    """One chat turn: the speaker's role and the text that was said."""

    role: Role
    content: str


# A dialog is an ordered list of messages.
Dialog = List[Message]
# Markers _format_tokens_llama2 places around each user instruction.
B_INST = "[INST]"
E_INST = "[/INST]"
# Markers _format_tokens_llama2 places around a system prompt before folding
# it into the following message.
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"
def _encode_header(message, tokenizer):
    """Return the token ids of the header that introduces ``message``.

    The header is built from four pieces, each passed to
    ``tokenizer.encode`` on its own and concatenated in order: the
    start-of-header marker, the message's role, the end-of-header marker and
    a blank line.  Only ``message["role"]`` is read.
    """
    encode = tokenizer.encode
    header_ids = list(encode("<|start_header_id|>"))
    header_ids += encode(message["role"])
    header_ids += encode("<|end_header_id|>")
    header_ids += encode("\n\n")
    return header_ids
def _encode_message(message, tokenizer):
    """Return the token ids for one complete message.

    The result is the message header (see ``_encode_header``), followed by
    the encoded content with surrounding whitespace stripped, followed by the
    encoded ``<|eot_id|>`` marker.
    """
    message_ids = _encode_header(message, tokenizer)
    stripped_content = message["content"].strip()
    message_ids += tokenizer.encode(stripped_content)
    message_ids += tokenizer.encode("<|eot_id|>")
    return message_ids
def _format_dialog(dialog, tokenizer):
    """Return the prompt token ids for a whole dialog.

    The prompt starts with the encoded ``<|begin_of_text|>`` marker, then
    every message of ``dialog`` encoded via ``_encode_message``, and ends
    with a bare assistant header (no content and no end-of-turn marker).
    """
    prompt_ids = list(tokenizer.encode("<|begin_of_text|>"))
    for turn in dialog:
        prompt_ids += _encode_message(turn, tokenizer)
    # Only the role of this placeholder is used by _encode_header.
    assistant_stub = {"role": "assistant", "content": ""}
    prompt_ids += _encode_header(assistant_stub, tokenizer)
    return prompt_ids
def _format_tokens_llama3(dialogs, tokenizer):
    """Return one prompt token list per dialog, built by ``_format_dialog``."""
    formatted = []
    for dialog in dialogs:
        formatted.append(_format_dialog(dialog, tokenizer))
    return formatted
def _format_tokens_llama2(dialogs, tokenizer):
    """Build the reference ``[INST]``-style prompt token ids for each dialog.

    For every dialog:

    * a leading ``"system"`` message is folded into the message that follows
      it, wrapped in ``B_SYS`` / ``E_SYS``;
    * each completed (user, assistant) exchange is encoded as
      ``"[INST] <prompt> [/INST] <answer> "`` followed by the tokenizer's
      ``eos_token_id``;
    * the final user message is encoded as ``"[INST] <prompt> [/INST]"``
      with no trailing end-of-sequence id.

    Args:
        dialogs: Iterable of dialogs; each dialog is a non-empty list of
            ``{"role": ..., "content": ...}`` mappings.
        tokenizer: Object providing ``encode(text)`` and ``eos_token_id``.

    Returns:
        A list with one flat list of token ids per dialog.

    Raises:
        AssertionError: If, after folding the system prompt, roles do not
            alternate user/assistant starting with user, or the last message
            is not from the user.
    """
    prompt_tokens = []
    for dialog in dialogs:
        if dialog[0]["role"] == "system":
            # Merge the system prompt into the next message so the remaining
            # dialog consists of user/assistant turns only.
            dialog = [
                {
                    "role": dialog[1]["role"],
                    "content": B_SYS
                    + dialog[0]["content"]
                    + E_SYS
                    + dialog[1]["content"],
                }
            ] + dialog[2:]
        assert all(msg["role"] == "user" for msg in dialog[::2]) and all(
            msg["role"] == "assistant" for msg in dialog[1::2]
        ), (
            "model only supports 'system','user' and 'assistant' roles, "
            "starting with user and alternating (u/a/u/a/u...)"
        )
        # Please verify that your tokenizer supports adding "[INST]" and
        # "[/INST]" to your inputs.  Here, we are adding them manually.
        #
        # Tokens are accumulated with extend/append rather than
        # sum(list_of_lists, []), which re-copies the growing list for every
        # exchange.
        dialog_tokens: List[int] = []
        for prompt, answer in zip(dialog[::2], dialog[1::2]):
            dialog_tokens.extend(
                tokenizer.encode(
                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                )
            )
            dialog_tokens.append(tokenizer.eos_token_id)
        assert (
            dialog[-1]["role"] == "user"
        ), f"Last message must be from user, got {dialog[-1]['role']}"
        # The open user turn gets no answer text and no end-of-sequence id.
        dialog_tokens.extend(
            tokenizer.encode(
                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
            )
        )
        prompt_tokens.append(dialog_tokens)
    return prompt_tokens
@pytest.mark.skip_missing_tokenizer
@patch("chat_completion.AutoTokenizer")
@patch("chat_completion.load_model")
def test_chat_completion(
    load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
):
    """Check the prompts ``chat_completion.main`` passes to ``model.generate``.

    ``main`` is run on the example ``chats.json`` with the model loader and
    ``AutoTokenizer`` mocked out.  The ``input_ids`` recorded for each dialog
    are then compared with reference token ids computed in this file
    (``_format_tokens_llama2`` for "meta-llama/Llama-2-7b-hf", otherwise
    ``_format_tokens_llama3``).

    ``load_model`` and ``tokenizer`` are the mocks injected by the two
    ``@patch`` decorators (innermost decorator first); the remaining
    parameters are pytest fixtures defined outside this file.
    """
    from chat_completion import main

    setup_tokenizer(tokenizer)

    prompt_file = (CHAT_COMPLETION_DIR / "chats.json").as_posix()
    main(llama_version, prompt_file=prompt_file)

    dialogs = read_dialogs_from_file(prompt_file)
    format_tokens = (
        _format_tokens_llama2
        if llama_version == "meta-llama/Llama-2-7b-hf"
        else _format_tokens_llama3
    )
    expected_prompts = format_tokens(dialogs, llama_tokenizer[llama_version])

    # NOTE(review): the stride of 4 is kept from the original assertions; it
    # presumably skips follow-up calls that main() makes on the object
    # returned by generate(), which are recorded in the same mock_calls list.
    # Confirm against chat_completion.main if that function changes.
    calls_per_dialog = 4
    generate_calls = load_model.return_value.generate.mock_calls

    # Check every reference prompt instead of a hard-coded first five.
    for dialog_idx, expected_ids in enumerate(expected_prompts):
        # Each mock_calls entry is (name, args, kwargs); index 2 is kwargs.
        call_kwargs = generate_calls[dialog_idx * calls_per_dialog][2]
        actual = call_kwargs["input_ids"].cpu()
        expected = torch.tensor(expected_ids).long()
        # Reduce on the tensor itself: if input_ids carries a batch dimension,
        # `.tolist()` yields nested lists and the builtin all() would only
        # test that the inner list is non-empty, never its contents.
        assert bool(
            (actual == expected).all()
        ), f"input_ids passed to generate() differ for dialog {dialog_idx}"
|