
Update paddings (#85)

Geeta Chauhan 1 year ago
parent
commit
03faba661f
3 changed files with 28 additions and 6 deletions
  1. docs/inference.md (+15, -0)
  2. inference/inference.py (+11, -4)
  3. inference/safety_utils.py (+2, -2)

+ 15 - 0
docs/inference.md

@@ -27,6 +27,21 @@ inference/samsum_prompt.txt
 ...
 ```
 
+**Note**
+Currently, the pad token in the [HuggingFace Tokenizer is `None` by default](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L110). We add the padding token as a special token to the tokenizer, which in this case requires resizing the token_embeddings as shown below:
+
+```python
+tokenizer.add_special_tokens(
+    {
+        "pad_token": "<PAD>",
+    }
+)
+model.resize_token_embeddings(model.config.vocab_size + 1)
+```
+Padding would be required for batch inference. In this [example](../inference/inference.py), the batch size is 1, so padding is not strictly required; we added the code as a pointer for the batch-inference case (see the batched sketch after this file's diff).
+
+
 **Chat completion**
 The inference folder also includes a chat completion example, that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
 

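For context on the note above: once the batch size exceeds 1, prompts of different lengths must be padded to a common length before they can be stacked into one tensor. Below is a minimal, hypothetical sketch of batched generation; the checkpoint name and prompts are placeholders, not part of this commit.

```python
# Hypothetical batched-inference sketch (not part of this commit).
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

# Same pattern as the note above: register a pad token, then resize embeddings.
tokenizer.add_special_tokens({"pad_token": "<PAD>"})
model.resize_token_embeddings(model.config.vocab_size + 1)
tokenizer.padding_side = "left"  # left padding is the usual choice for decoder-only generation

prompts = [
    "Summarize the following dialogue in one sentence.",
    "Translate 'good morning' to French.",
]  # batch size > 1, so padding is needed
batch = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    outputs = model.generate(**batch, max_new_tokens=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```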
+ 11 - 4
inference/inference.py

@@ -31,7 +31,8 @@ def main(
     length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation. 
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
-    enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
+    enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    max_padding_length: int=None, # the max padding length used when the tokenizer pads the prompts
     use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
     **kwargs
 ):
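
To make the renamed and added flags concrete, here is a hedged example of calling `main()` directly with them; the other argument names and values are placeholders and are not shown in this diff.

```python
# Hypothetical call showing the arguments touched by this commit.
main(
    model_name="meta-llama/Llama-2-7b-hf",      # placeholder, not from this diff
    prompt_file="inference/samsum_prompt.txt",  # placeholder, not from this diff
    enable_salesforce_content_safety=True,      # renamed from enable_saleforce_content_safety
    max_padding_length=512,                     # new argument added by this commit
)
```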
@@ -76,10 +77,11 @@ def main(
             "pad_token": "<PAD>",
         }
     )
+    model.resize_token_embeddings(model.config.vocab_size + 1) 
     
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
-                                        enable_saleforce_content_safety,
+                                        enable_salesforce_content_safety,
                                         )
 
     # Safety check of the user prompt
@@ -94,10 +96,15 @@ def main(
             if not is_safe:
                 print(method)
                 print(report)
-        print("Skipping the inferece as the prompt is not safe.")
+        print("Skipping the inference as the prompt is not safe.")
         sys.exit(1)  # Exit the program with an error status
+        
+    if peft_model:
+        model = load_peft_model(model, peft_model)
+
+    model.eval()
+    batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
 
-    batch = tokenizer(user_prompt, return_tensors="pt")
     batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():

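The new tokenizer call above pads every prompt up to `max_padding_length`. A small standalone sketch of what `padding='max_length'` produces; the checkpoint name is an assumption:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # assumed checkpoint
tokenizer.add_special_tokens({"pad_token": "<PAD>"})

batch = tokenizer(
    "Summarize the dialogue below.",
    padding='max_length',    # pad up to max_length...
    truncation=True,         # ...and truncate anything longer
    max_length=64,           # stands in for max_padding_length
    return_tensors="pt",
)
print(batch["input_ids"].shape)       # torch.Size([1, 64])
print(batch["attention_mask"].sum())  # count of real (non-pad) tokens
```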
+ 2 - 2
inference/safety_utils.py

@@ -154,14 +154,14 @@ class AzureSaftyChecker(object):
 # Function to determine which safety checker to use based on the options selected
 def get_safety_checker(enable_azure_content_safety,
                        enable_sensitive_topics,
-                       enable_saleforce_content_safety,
+                       enable_salesforce_content_safety,
                        ):
     safety_checker = []
     if enable_azure_content_safety:
         safety_checker.append(AzureSaftyChecker())
     if enable_sensitive_topics:
         safety_checker.append(AuditNLGSensitiveTopics())
-    if enable_saleforce_content_safety:
+    if enable_salesforce_content_safety:
         safety_checker.append(SalesforceSafetyChecker())
     return safety_checker
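
For completeness, a hedged sketch of how the returned checkers are consumed; the exact return shape of each checker is not shown in this diff, so the `(method, is_safe, report)` tuple below is inferred from the loop in `inference/inference.py` above.

```python
# Hedged usage sketch; assumes each checker, when called on a prompt,
# returns a (method, is_safe, report) tuple as suggested by the loop above.
safety_checker = get_safety_checker(
    enable_azure_content_safety=False,
    enable_sensitive_topics=False,
    enable_salesforce_content_safety=True,
)

user_prompt = "Summarize the dialogue below."
safety_results = [check(user_prompt) for check in safety_checker]
are_safe = all(is_safe for _, is_safe, _ in safety_results)

if not are_safe:
    for method, is_safe, report in safety_results:
        if not is_safe:
            print(method)
            print(report)
```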