Browse Source

Add option to enable Llamaguard content safety check in chat_completion

Joone Hur 1 year ago
parent
commit
ed3e11e9a8
1 changed files with 2 additions and 0 deletions
  1. 2 0
      examples/chat_completion/chat_completion.py

+ 2 - 0
examples/chat_completion/chat_completion.py

@@ -35,6 +35,7 @@ def main(
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
     enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
     use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    enable_llamaguard_content_safety: bool = False,
     **kwargs
     **kwargs
 ):
 ):
     if prompt_file is not None:
     if prompt_file is not None:
@@ -90,6 +91,7 @@ def main(
             safety_checker = get_safety_checker(enable_azure_content_safety,
             safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
                                         enable_sensitive_topics,
                                         enable_saleforce_content_safety,
                                         enable_saleforce_content_safety,
+                                        enable_llamaguard_content_safety,
                                         )
                                         )
             # Safety check of the user prompt
             # Safety check of the user prompt
             safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker]
             safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker]