|
@@ -35,6 +35,7 @@ def main(
|
|
|
enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
|
|
|
enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
|
|
|
use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
|
|
|
+ enable_llamaguard_content_safety: bool = False,
|
|
|
**kwargs
|
|
|
):
|
|
|
if prompt_file is not None:
|
|
@@ -90,6 +91,7 @@ def main(
|
|
|
safety_checker = get_safety_checker(enable_azure_content_safety,
|
|
|
enable_sensitive_topics,
|
|
|
enable_saleforce_content_safety,
|
|
|
+ enable_llamaguard_content_safety,
|
|
|
)
|
|
|
# Safety check of the user prompt
|
|
|
safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker]
|