浏览代码

update the code with llama-guard safety check

Hamid Shojanazeri 1 年之前
父节点
当前提交
85ee7bcc24
共有 1 个文件被更改,包括 2 次插入0 次删除
  1. 2 0
      examples/code_llama/code_completion_example.py

+ 2 - 0
examples/code_llama/code_completion_example.py

@@ -33,6 +33,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard
     use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
     **kwargs
 ):
 ):
@@ -72,6 +73,7 @@ def main(
     safety_checker = get_safety_checker(enable_azure_content_safety,
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
                                         enable_salesforce_content_safety,
+                                        enable_llamaguard_content_safety,
                                         )
                                         )
 
 
     # Safety check of the user prompt
     # Safety check of the user prompt