Browse Source

set end of text id as termination sign in gernerate

Matthias Reso 5 months ago
parent
commit
362c866aaa

+ 6 - 0
recipes/inference/local_inference/chat_completion/chat_completion.py

@@ -75,6 +75,11 @@ def main(
 
     chats = tokenizer.apply_chat_template(dialogs)
 
+    terminators = [
+        tokenizer.eos_token_id,
+        tokenizer.convert_tokens_to_ids("<|eot_id|>")
+        ]
+
     with torch.no_grad():
         for idx, chat in enumerate(chats):
             safety_checker = get_safety_checker(enable_azure_content_safety,
@@ -113,6 +118,7 @@ def main(
                 top_k=top_k,
                 repetition_penalty=repetition_penalty,
                 length_penalty=length_penalty,
+                eos_token_id=terminators,
                 **kwargs
             )