@@ -0,0 +1,164 @@
+import sys
+from pathlib import Path
+from typing import List, Literal, TypedDict
+from unittest.mock import patch
+
+import pytest
+import torch
+from llama_recipes.inference.chat_utils import read_dialogs_from_file
+
+ROOT_DIR = Path(__file__).parents[1]
+CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"
+
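+# Make the chat_completion recipe importable so the test can patch and drive it directly.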
+sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path
+
+Role = Literal["user", "assistant"]
+
+
+class Message(TypedDict):
+    role: Role
+    content: str
+
+
+Dialog = List[Message]
+
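+# Prompt delimiters for the Llama 2 chat format.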
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+
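+# The helpers below rebuild the Llama 3 chat template token by token as a reference.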
+def _encode_header(message, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|start_header_id|>"))
+    tokens.extend(tokenizer.encode(message["role"]))
+    tokens.extend(tokenizer.encode("<|end_header_id|>"))
+    tokens.extend(tokenizer.encode("\n\n"))
+    return tokens
+
+
+def _encode_message(message, tokenizer):
+    tokens = _encode_header(message, tokenizer)
+    tokens.extend(tokenizer.encode(message["content"].strip()))
+    tokens.extend(tokenizer.encode("<|eot_id|>"))
+    return tokens
+
+
+def _format_dialog(dialog, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
+    for msg in dialog:
+        tokens.extend(_encode_message(msg, tokenizer))
+    tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
+    return tokens
+
+
+def _format_tokens_llama3(dialogs, tokenizer):
+    return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
+
+
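+# Reference implementation of the Llama 2 [INST]/<<SYS>> prompt format.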
+def _format_tokens_llama2(dialogs, tokenizer):
+    prompt_tokens = []
+    for dialog in dialogs:
+        if dialog[0]["role"] == "system":
+            dialog = [
+                {
+                    "role": dialog[1]["role"],
+                    "content": B_SYS
+                    + dialog[0]["content"]
+                    + E_SYS
+                    + dialog[1]["content"],
+                }
+            ] + dialog[2:]
+        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
+            [msg["role"] == "assistant" for msg in dialog[1::2]]
+        ), (
+            "model only supports 'system', 'user' and 'assistant' roles, "
+            "starting with user and alternating (u/a/u/a/u...)"
+        )
+        """
+        Please verify that your tokenizer supports adding "[INST]" and "[/INST]" to your inputs.
+        Here, we are adding them manually.
+        """
+        dialog_tokens: List[int] = sum(
+            [
+                tokenizer.encode(
+                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+                )
+                + [tokenizer.eos_token_id]
+                for prompt, answer in zip(dialog[::2], dialog[1::2])
+            ],
+            [],
+        )
+        assert (
+            dialog[-1]["role"] == "user"
+        ), f"Last message must be from user, got {dialog[-1]['role']}"
+        dialog_tokens += tokenizer.encode(
+            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+        )
+        prompt_tokens.append(dialog_tokens)
+    return prompt_tokens
+
+
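+# load_model and AutoTokenizer are patched inside the chat_completion module, so main()
+# runs without real weights and every generate() call is captured on the mock model.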
+@pytest.mark.skip_missing_tokenizer
+@patch("chat_completion.AutoTokenizer")
+@patch("chat_completion.load_model")
+def test_chat_completion(
+    load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
+):
+    from chat_completion import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
+    }
+
+    main(llama_version, **kwargs)
+
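+    # Rebuild the expected prompt tokens with the reference formatter for this model family.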
+    dialogs = read_dialogs_from_file(kwargs["prompt_file"])
+    format_tokens = (
+        _format_tokens_llama2
+        if llama_version == "meta-llama/Llama-2-7b-hf"
+        else _format_tokens_llama3
+    )
+
+    REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])
+
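+    # The test reads every 4th entry of generate's mock_calls to pick out the successive
+    # generate() invocations; [2] is the kwargs of that call, including its input_ids.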
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[0 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[0]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[1 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[1]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[2 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[2]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[3 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[3]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[4 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[4]).long()
+        ).tolist()
+    )