import sys from pathlib import Path from typing import List, Literal, TypedDict from unittest.mock import patch import pytest import torch from llama_recipes.inference.chat_utils import read_dialogs_from_file ROOT_DIR = Path(__file__).parents[1] CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/" sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path Role = Literal["user", "assistant"] class Message(TypedDict): role: Role content: str Dialog = List[Message] B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>\n", "\n<>\n\n" def _encode_header(message, tokenizer): tokens = [] tokens.extend(tokenizer.encode("<|start_header_id|>")) tokens.extend(tokenizer.encode(message["role"])) tokens.extend(tokenizer.encode("<|end_header_id|>")) tokens.extend(tokenizer.encode("\n\n")) return tokens def _encode_message(message, tokenizer): tokens = _encode_header(message, tokenizer) tokens.extend(tokenizer.encode(message["content"].strip())) tokens.extend(tokenizer.encode("<|eot_id|>")) return tokens def _format_dialog(dialog, tokenizer): tokens = [] tokens.extend(tokenizer.encode("<|begin_of_text|>")) for msg in dialog: tokens.extend(_encode_message(msg, tokenizer)) tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer)) return tokens def _format_tokens_llama3(dialogs, tokenizer): return [_format_dialog(dialog, tokenizer) for dialog in dialogs] def _format_tokens_llama2(dialogs, tokenizer): prompt_tokens = [] for dialog in dialogs: if dialog[0]["role"] == "system": dialog = [ { "role": dialog[1]["role"], "content": B_SYS + dialog[0]["content"] + E_SYS + dialog[1]["content"], } ] + dialog[2:] assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( [msg["role"] == "assistant" for msg in dialog[1::2]] ), ( "model only supports 'system','user' and 'assistant' roles, " "starting with user and alternating (u/a/u/a/u...)" ) """ Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs. Here, we are adding it manually. """ dialog_tokens: List[int] = sum( [ tokenizer.encode( f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", ) + [tokenizer.eos_token_id] for prompt, answer in zip(dialog[::2], dialog[1::2]) ], [], ) assert ( dialog[-1]["role"] == "user" ), f"Last message must be from user, got {dialog[-1]['role']}" dialog_tokens += tokenizer.encode( f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", ) prompt_tokens.append(dialog_tokens) return prompt_tokens @pytest.mark.skip_missing_tokenizer @patch("chat_completion.AutoTokenizer") @patch("chat_completion.load_model") def test_chat_completion( load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version ): from chat_completion import main setup_tokenizer(tokenizer) kwargs = { "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(), } main(llama_version, **kwargs) dialogs = read_dialogs_from_file(kwargs["prompt_file"]) format_tokens = ( _format_tokens_llama2 if llama_version == "meta-llama/Llama-2-7b-hf" else _format_tokens_llama3 ) REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version]) assert all( ( load_model.return_value.generate.mock_calls[0 * 4][2]["input_ids"].cpu() == torch.tensor(REF_RESULT[0]).long() ).tolist() ) assert all( ( load_model.return_value.generate.mock_calls[1 * 4][2]["input_ids"].cpu() == torch.tensor(REF_RESULT[1]).long() ).tolist() ) assert all( ( load_model.return_value.generate.mock_calls[2 * 4][2]["input_ids"].cpu() == torch.tensor(REF_RESULT[2]).long() ).tolist() ) assert all( ( load_model.return_value.generate.mock_calls[3 * 4][2]["input_ids"].cpu() == torch.tensor(REF_RESULT[3]).long() ).tolist() ) assert all( ( load_model.return_value.generate.mock_calls[4 * 4][2]["input_ids"].cpu() == torch.tensor(REF_RESULT[4]).long() ).tolist() )