# chat_utils.py (2.0 KB)
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. from typing import List, Literal, Optional, Tuple, TypedDict, Union
  4. import json
  5. Role = Literal["user", "assistant"]
  6. class Message(TypedDict):
  7. role: Role
  8. content: str
  9. Dialog = List[Message]
  10. B_INST, E_INST = "[INST]", "[/INST]"
  11. B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
  12. def format_tokens(dialogs, tokenizer):
  13. prompt_tokens = []
  14. for dialog in dialogs:
  15. if dialog[0]["role"] == "system":
  16. dialog = [
  17. {
  18. "role": dialog[1]["role"],
  19. "content": B_SYS
  20. + dialog[0]["content"]
  21. + E_SYS
  22. + dialog[1]["content"],
  23. }
  24. ] + dialog[2:]
  25. assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
  26. [msg["role"] == "assistant" for msg in dialog[1::2]]
  27. ), (
  28. "model only supports 'system','user' and 'assistant' roles, "
  29. "starting with user and alternating (u/a/u/a/u...)"
  30. )
  31. """
  32. Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
  33. Here, we are adding it manually.
  34. """
  35. dialog_tokens: List[int] = sum(
  36. [
  37. tokenizer.encode(
  38. f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
  39. )
  40. for prompt, answer in zip(dialog[::2], dialog[1::2])
  41. ],
  42. [],
  43. )
  44. assert (
  45. dialog[-1]["role"] == "user"
  46. ), f"Last message must be from user, got {dialog[-1]['role']}"
  47. dialog_tokens += tokenizer.encode(
  48. f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
  49. )
  50. prompt_tokens.append(dialog_tokens)
  51. return prompt_tokens
  52. def read_dialogs_from_file(file_path):
  53. with open(file_path, 'r') as file:
  54. dialogs = json.load(file)
  55. return dialogs