Browse Source

Adapt readme + check_completion.py to reflect that no manual change is needed to support eot_id

Matthias Reso 5 months ago
parent
commit
79aa70442e
2 changed files with 5 additions and 22 deletions
  1. 5 16
      README.md
  2. 0 6
      recipes/inference/local_inference/chat_completion/chat_completion.py

+ 5 - 16
README.md

@@ -7,9 +7,9 @@ The 'llama-recipes' repository is a companion to the [Meta Llama 2](https://gith
 > | Token | Description |
 > |---|---|
 > `<\|begin_of_text\|>` | This is equivalent to the BOS token. |
-> `<\|eot_id\|>` | This signifies the end of the message in a turn. The generate function needs to be set up as shown below or in [this example](./recipes/inference/local_inference/chat_completion/chat_completion.py) to terminate the generation after the turn.|
+> `<\|end_of_text\|>` | This is equivalent to the EOS token. For multiturn-conversations it's usually unused. Instead, every message is terminated with `<\|eot_id\|>` instead.|
+> `<\|eot_id\|>` | This token signifies the end of the message in a turn i.e. the end of a single message by a system, user or assistant role as shown below.|
 > `<\|start_header_id\|>{role}<\|end_header_id\|>` | These tokens enclose the role for a particular message. The possible roles can be: system, user, assistant. |
-> `<\|end_of_text\|>` | This is equivalent to the EOS token. Its usually not used during multiturn-conversations. Instead, each message is terminated with `<\|eot_id\|>` |
 >
 > A multiturn-conversation with Llama 3 follows this prompt template:
 > ```
@@ -23,21 +23,10 @@ The 'llama-recipes' repository is a companion to the [Meta Llama 2](https://gith
 >
 > {{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
 > ```
+> Each message gets trailed by an `<|eot_id|>` token before a new header is started, signaling a role change.
 >
-> To signal the end of the current message the model emits the `<\|eot_id\|>` token. To terminate the generation we need to call the model's generate function as follows:
-> ```
-> terminators = [
->     tokenizer.eos_token_id,
->     tokenizer.convert_tokens_to_ids("<|eot_id|>")
->     ]
->  ...
-> outputs = model.generate(
->     ...
->     eos_token_id=terminators,
->     )
-> ```
+> More details on the new tokenizer and prompt template can be found [here](https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3#special-tokens-used-with-meta-llama-3).
 >
-> More details on the new tokenizer and prompt template: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3#special-tokens-used-with-meta-llama-3
 > [!NOTE]
 > The llama-recipes repository was recently refactored to promote a better developer experience of using the examples. Some files have been moved to new locations. The `src/` folder has NOT been modified, so the functionality of this repo and package is not impacted.
 >
@@ -69,7 +58,7 @@ These instructions will get you a copy of the project up and running on your loc
 ### Prerequisites
 
 #### PyTorch Nightlies
-I you want to use PyTorch nightlies instead of the stable release, go to [this guide](https://pytorch.org/get-started/locally/) to retrieve the right `--extra-index-url URL` parameter for the `pip install` commands on your platform.
+If you want to use PyTorch nightlies instead of the stable release, go to [this guide](https://pytorch.org/get-started/locally/) to retrieve the right `--extra-index-url URL` parameter for the `pip install` commands on your platform.
 
 ### Installing
 Llama-recipes provides a pip distribution for easy install and usage in other projects. Alternatively, it can be installed from source.

+ 0 - 6
recipes/inference/local_inference/chat_completion/chat_completion.py

@@ -75,11 +75,6 @@ def main(
 
     chats = tokenizer.apply_chat_template(dialogs)
 
-    terminators = [
-        tokenizer.eos_token_id,
-        tokenizer.convert_tokens_to_ids("<|eot_id|>")
-        ]
-
     with torch.no_grad():
         for idx, chat in enumerate(chats):
             safety_checker = get_safety_checker(enable_azure_content_safety,
@@ -118,7 +113,6 @@ def main(
                 top_k=top_k,
                 repetition_penalty=repetition_penalty,
                 length_penalty=length_penalty,
-                eos_token_id=terminators,
                 **kwargs
             )