Browse Source

clean up and typo fixes

Hamid Shojanazeri 1 year ago
parent
commit
754b5d22c7
2 changed files with 6 additions and 6 deletions
  1. 4 4
      inference/inference.py
  2. 2 2
      inference/safety_utils.py

+ 4 - 4
inference/inference.py

@@ -31,7 +31,7 @@ def main(
     length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation. 
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
-    enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
+    enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
     max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
     **kwargs
 ):
@@ -59,10 +59,11 @@ def main(
             "pad_token": "<PAD>",
         }
     )
+    model.resize_token_embeddings(model.config.vocab_size + 1) 
     
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
-                                        enable_saleforce_content_safety,
+                                        enable_salesforce_content_safety,
                                         )
 
     # Safety check of the user prompt
@@ -77,7 +78,7 @@ def main(
             if not is_safe:
                 print(method)
                 print(report)
-        print("Skipping the inferece as the prompt is not safe.")
+        print("Skipping the inference as the prompt is not safe.")
         sys.exit(1)  # Exit the program with an error status
 
     if peft_model:
@@ -85,7 +86,6 @@ def main(
 
     model.eval()
     batch = tokenizer(user_prompt, padding='max_length', truncation=True,max_length=max_padding_length,return_tensors="pt")
-    model.resize_token_embeddings(model.config.vocab_size + 1) 
     batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():

+ 2 - 2
inference/safety_utils.py

@@ -154,14 +154,14 @@ class AzureSaftyChecker(object):
 # Function to determine which safety checker to use based on the options selected
 def get_safety_checker(enable_azure_content_safety,
                        enable_sensitive_topics,
-                       enable_saleforce_content_safety,
+                       enable_salesforce_content_safety,
                        ):
     safety_checker = []
     if enable_azure_content_safety:
         safety_checker.append(AzureSaftyChecker())
     if enable_sensitive_topics:
         safety_checker.append(AuditNLGSensitiveTopics())
-    if enable_saleforce_content_safety:
+    if enable_salesforce_content_safety:
         safety_checker.append(SalesforceSafetyChecker())
     return safety_checker