
Update chat format for llama3

Matthias Reso 11 months ago
parent
commit
a414ca6a57

+ 4 - 4
recipes/inference/local_inference/chat_completion/chat_completion.py

@@ -8,7 +8,7 @@ import os
 import sys
 
 import torch
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
 from llama_recipes.inference.chat_utils import read_dialogs_from_file, format_tokens
 from llama_recipes.inference.model_utils import load_model, load_peft_model
@@ -65,14 +65,14 @@ def main(
     if peft_model:
         model = load_peft_model(model, peft_model)
 
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
-         
+
             "pad_token": "<PAD>",
         }
     )
-    
+
     chats = format_tokens(dialogs, tokenizer)
 
     with torch.no_grad():
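
The tokenizer change above is what makes the script work for both model families: LlamaTokenizer is the SentencePiece-backed class and can only load Llama 2 checkpoints, while Llama 3 ships a tiktoken-derived fast tokenizer, so AutoTokenizer is needed to resolve the correct class from the checkpoint's tokenizer_config.json. A minimal sketch of the resulting loading path (the checkpoint name here is illustrative, not taken from this commit):

from transformers import AutoTokenizer

# AutoTokenizer picks the backend from the checkpoint's tokenizer_config.json:
# a SentencePiece-based tokenizer for Llama 2, a tiktoken-derived fast
# tokenizer for Llama 3. A hard-coded LlamaTokenizer can only load the former.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Neither family defines a pad token, hence the add_special_tokens call above.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})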

+ 38 - 1
src/llama_recipes/inference/chat_utils.py

@@ -17,7 +17,44 @@ Dialog = List[Message]
 
 B_INST, E_INST = "[INST]", "[/INST]"
 B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
 def format_tokens(dialogs, tokenizer):
+    if tokenizer.vocab_size >= 128000:
+        return _format_tokens_llama3(dialogs, tokenizer)
+    else:
+        return _format_tokens_llama2(dialogs, tokenizer)
+
+
+def _encode_header(message, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|start_header_id|>"))
+    tokens.extend(tokenizer.encode(message["role"]))
+    tokens.extend(tokenizer.encode("<|end_header_id|>"))
+    tokens.extend(tokenizer.encode("\n\n"))
+    return tokens
+
+def _encode_message(message, tokenizer):
+    tokens = _encode_header(message, tokenizer)
+    tokens.extend(tokenizer.encode(message["content"].strip()))
+    tokens.extend(tokenizer.encode("<|eot_id|>"))
+    return tokens
+
+def _format_dialog(dialog, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
+    for msg in dialog:
+        tokens.extend(_encode_message(msg, tokenizer))
+    tokens.extend(
+        _encode_header({"role": "assistant", "content": ""}, tokenizer)
+        )
+    return tokens
+
+
+def _format_tokens_llama3(dialogs, tokenizer):
+    return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
+
+
+def _format_tokens_llama2(dialogs, tokenizer):
     prompt_tokens = []
     for dialog in dialogs:
         if dialog[0]["role"] == "system":
@@ -57,7 +94,7 @@ def format_tokens(dialogs, tokenizer):
         )
         prompt_tokens.append(dialog_tokens)
     return prompt_tokens
-        
+
 
 def read_dialogs_from_file(file_path):
     with open(file_path, 'r') as file:
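
The vocab_size >= 128000 check is a heuristic for telling the two families apart without an explicit flag: the Llama 3 tokenizer has a roughly 128K-token vocabulary, while Llama 2 uses 32K. For a one-turn dialog, the new helpers are meant to build the Llama 3 prompt layout sketched below, ending with an open assistant header so the model generates the reply. The example dialog is illustrative, and the layout assumes each tokenizer.encode call emits only the tokens for the given string, without injecting extra special tokens of its own:

dialogs = [[{"role": "user", "content": "What is the capital of France?"}]]
chats = format_tokens(dialogs, tokenizer)

# Intended prompt layout for chats[0] (one flat token list per dialog):
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
# What is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#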