|
@@ -75,6 +75,11 @@ def main(
|
|
|
|
|
|
chats = tokenizer.apply_chat_template(dialogs)
|
|
|
|
|
|
+ terminators = [
|
|
|
+ tokenizer.eos_token_id,
|
|
|
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
|
|
|
+ ]
|
|
|
+
|
|
|
with torch.no_grad():
|
|
|
for idx, chat in enumerate(chats):
|
|
|
safety_checker = get_safety_checker(enable_azure_content_safety,
|
|
@@ -113,6 +118,7 @@ def main(
|
|
|
top_k=top_k,
|
|
|
repetition_penalty=repetition_penalty,
|
|
|
length_penalty=length_penalty,
|
|
|
+ eos_token_id=terminators,
|
|
|
**kwargs
|
|
|
)
|
|
|
|