Przeglądaj źródła

Inference updates (#12)

Geeta Chauhan 1 rok temu
rodzic
commit
f6c3ffd4f7
3 zmienionych plików z 15 dodań i 19 usunięć
  1. 3 5
      inference/chat_completion.py
  2. 7 7
      inference/inference.py
  3. 5 7
      llama_finetuning.py

+ 3 - 5
inference/chat_completion.py

@@ -62,13 +62,11 @@ def main(
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
-            "eos_token": "</s>",
-            "bos_token": "</s>",
-            "unk_token": "</s>",
-            "pad_token": "[PAD]",
+         
+            "pad_token": "<PAD>",
         }
     )
-
+    
     chats = format_tokens(dialogs, tokenizer)
 
     with torch.no_grad():

+ 7 - 7
inference/inference.py

@@ -7,6 +7,7 @@ import fire
 import torch
 import os
 import sys
+import time
 from typing import List
 
 from transformers import LlamaTokenizer
@@ -49,15 +50,13 @@ def main(
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
+    
     model = load_model(model_name, quantization)
-
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
-            "eos_token": "</s>",
-            "bos_token": "</s>",
-            "unk_token": "</s>",
-            "pad_token": "[PAD]",
+         
+            "pad_token": "<PAD>",
         }
     )
     
@@ -88,7 +87,7 @@ def main(
 
     batch = tokenizer(user_prompt, return_tensors="pt")
     batch = {k: v.to("cuda") for k, v in batch.items()}
-    
+    start = time.perf_counter()
     with torch.no_grad():
         outputs = model.generate(
             **batch,
@@ -103,7 +102,8 @@ def main(
             length_penalty=length_penalty,
             **kwargs 
         )
-
+    e2e_inference_time = (time.perf_counter()-start)*1000
+    print(f"the inference time is {e2e_inference_time} ms")
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     
     # Safety check of the model output

+ 5 - 7
llama_finetuning.py

@@ -109,13 +109,11 @@ def main(**kwargs):
     # Load the tokenizer and add special tokens
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
     tokenizer.add_special_tokens(
-        {
-            "eos_token": "</s>",
-            "bos_token": "</s>",
-            "unk_token": "</s>",
-            "pad_token": '[PAD]',
-        }
-    )
+            {
+            
+                "pad_token": "<PAD>",
+            }
+        )
     if train_config.use_peft:
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)