import sys
from pathlib import Path
from typing import List, Literal, TypedDict
from unittest.mock import patch

import pytest
import torch

from llama_recipes.inference.chat_utils import read_dialogs_from_file

ROOT_DIR = Path(__file__).parents[1]
CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"

sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path

Role = Literal["user", "assistant"]


class Message(TypedDict):
    role: Role
    content: str


Dialog = List[Message]

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


def _encode_header(message, tokenizer):
    tokens = []
    tokens.extend(tokenizer.encode("<|start_header_id|>"))
    tokens.extend(tokenizer.encode(message["role"]))
    tokens.extend(tokenizer.encode("<|end_header_id|>"))
    tokens.extend(tokenizer.encode("\n\n"))
    return tokens


def _encode_message(message, tokenizer):
    tokens = _encode_header(message, tokenizer)
    tokens.extend(tokenizer.encode(message["content"].strip()))
    tokens.extend(tokenizer.encode("<|eot_id|>"))
    return tokens


def _format_dialog(dialog, tokenizer):
    tokens = []
    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
    for msg in dialog:
        tokens.extend(_encode_message(msg, tokenizer))
    tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
    return tokens


def _format_tokens_llama3(dialogs, tokenizer):
    return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
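

# For reference, the helpers above reproduce the Llama 3 chat layout: each
# message is rendered as
#   <|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>
# after an initial <|begin_of_text|>, and the prompt ends with an empty
# assistant header so generation continues as the assistant's reply.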


def _format_tokens_llama2(dialogs, tokenizer):
    prompt_tokens = []
    for dialog in dialogs:
        if dialog[0]["role"] == "system":
            # Fold a leading system message into the first user turn, wrapped
            # in the <<SYS>> tags.
            dialog = [
                {
                    "role": dialog[1]["role"],
                    "content": B_SYS
                    + dialog[0]["content"]
                    + E_SYS
                    + dialog[1]["content"],
                }
            ] + dialog[2:]
        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
            [msg["role"] == "assistant" for msg in dialog[1::2]]
        ), (
            "model only supports 'system', 'user' and 'assistant' roles, "
            "starting with user and alternating (u/a/u/a/u...)"
        )
        # Please verify that your tokenizer supports adding "[INST]" and
        # "[/INST]" to your inputs; here we add them manually.
        dialog_tokens: List[int] = sum(
            [
                tokenizer.encode(
                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                )
                + [tokenizer.eos_token_id]
                for prompt, answer in zip(dialog[::2], dialog[1::2])
            ],
            [],
        )
        assert (
            dialog[-1]["role"] == "user"
        ), f"Last message must be from user, got {dialog[-1]['role']}"
        dialog_tokens += tokenizer.encode(
            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
        )
        prompt_tokens.append(dialog_tokens)
    return prompt_tokens
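

# Conceptually, the Llama 2 format renders each completed (user, assistant)
# pair as "[INST] {user} [/INST] {assistant} " followed by the EOS token, and
# the final unanswered user turn as "[INST] {user} [/INST]", which is the
# prompt the model is expected to continue.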


@pytest.mark.skip_missing_tokenizer
@patch("chat_completion.AutoTokenizer")
@patch("chat_completion.load_model")
def test_chat_completion(
    load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
):
    from chat_completion import main

    setup_tokenizer(tokenizer)

    kwargs = {
        "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
    }

    main(llama_version, **kwargs)

    dialogs = read_dialogs_from_file(kwargs["prompt_file"])
    format_tokens = (
        _format_tokens_llama2
        if llama_version == "meta-llama/Llama-2-7b-hf"
        else _format_tokens_llama3
    )
    REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])

    # chat_completion calls model.generate() once per dialog. The indexing
    # below relies on the mock recording four entries per dialog on the
    # `generate` mock and its return value, so entry i * 4 is the i-th
    # generate() call and [2] is its keyword-argument dict.
    for i, ref_tokens in enumerate(REF_RESULT):
        input_ids = load_model.return_value.generate.mock_calls[i * 4][2]["input_ids"]
        assert all((input_ids.cpu() == torch.tensor(ref_tokens).long()).tolist())
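

# A minimal sketch of how this test might be invoked (the `setup_tokenizer`,
# `llama_tokenizer`, and `llama_version` fixtures and the
# `skip_missing_tokenizer` marker are assumed to be provided by the
# repository's conftest.py):
#
#   pytest -k test_chat_completion
#
# When the gated Llama tokenizer files are unavailable, the marker is expected
# to skip the test rather than fail it.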