safety_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import warnings

import torch

# Class for performing safety checks using the AuditNLG library
class AuditNLGSensitiveTopics(object):
    def __init__(self):
        pass

    def __call__(self, output_text):
        try:
            from auditnlg.safety.exam import safety_scores
        except ImportError as e:
            print(
                "Could not import optional dependency: auditnlg\n"
                "Please install manually with:\n pip install auditnlg\n"
                "Followed by:\n pip install -r requirements.txt"
            )
            raise e

        data = [{"output": output_text}]

        result = safety_scores(data=data, method="sensitive_topics")
        scores = result[1]["all_scores"][0]
        is_safe = scores["pred_class"] == "none"
        report = ""
        if not is_safe:
            report += f"Predicted class: {scores['pred_class']}\n"
            report += "|" + "|".join(f"{n:^10}" for n in [list(k.keys())[0] for k in scores["class_scores"]]) + "|\n"
            report += "|" + "|".join(f"{n:^10.5}" for n in [list(k.values())[0] for k in scores["class_scores"]]) + "|\n"
        return "Sensitive Topics", is_safe, report
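
# Illustrative stand-alone use (a minimal sketch; the sample text below is an
# assumption for the example, not part of the library). auditnlg is an optional
# dependency and must be installed first, e.g. `pip install auditnlg`.
#
#     checker = AuditNLGSensitiveTopics()
#     method, is_safe, report = checker("Some model output to screen")
#     if not is_safe:
#         print(f"{method} flagged the text:\n{report}")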

# Class for performing safety checks using the Salesforce safety-flan-t5-base model
class SalesforceSafetyChecker(object):
    def __init__(self):
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig

        # Load the model once at construction time and keep it on the instance so
        # that repeated calls (and get_total_length) reuse the same tokenizer.
        config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base")
        self.tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base")
        self.model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/safety-flan-t5-base", config=config)
        self.model.eval()

    def __call__(self, output_text):
        prefix = "Is the <Text> field safe or unsafe?"
        input_ids = self.tokenizer(prefix + " <Text> " + output_text + " <Context> ", return_tensors="pt").input_ids

        if len(input_ids[0]) > 512:
            warnings.warn(
                "Input length is > 512 tokens. The safety check result could be incorrect."
            )

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=20,
            )

        is_safe = self.tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(" ")[0] == "safe"

        report = ""
        if not is_safe:
            true_false_ids = self.tokenizer("true false").input_ids[:2]
            keys = ["toxicity", "hate", "identity", "violence", "physical", "sexual", "profanity", "biased"]
            scores = {}
            # Per-category true/false logits are read from alternating generation
            # steps (indices 3, 5, ..., 19 of the generated sequence).
            for k, i in zip(keys, range(3, 20, 2)):
                scores[k] = round(outputs.scores[i][0, true_false_ids].softmax(dim=0)[0].item(), 5)
            report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n"
            report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n"
        return "Salesforce Content Safety Flan T5 Base", is_safe, report

    def get_total_length(self, data):
        prefix = "Is the <Text> field safe or unsafe "
        input_sample = "<Text> {output} <Context> ".format(**data[0])
        return len(self.tokenizer(prefix + input_sample)["input_ids"])
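
# Illustrative stand-alone use (variable names below are assumptions for the example):
# constructing SalesforceSafetyChecker downloads the Salesforce/safety-flan-t5-base
# weights from the Hugging Face Hub. get_total_length expects a list of dicts with an
# "output" key and can be used to check the 512-token limit before classifying.
#
#     checker = SalesforceSafetyChecker()
#     model_output = "Some text produced by the model"
#     if checker.get_total_length([{"output": model_output}]) <= 512:
#         method, is_safe, report = checker(model_output)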

# Class for performing safety checks using the Azure Content Safety service
class AzureSafetyChecker(object):
    def __init__(self):
        try:
            from azure.ai.contentsafety import ContentSafetyClient
            from azure.core.credentials import AzureKeyCredential

            key = os.environ["CONTENT_SAFETY_KEY"]
            endpoint = os.environ["CONTENT_SAFETY_ENDPOINT"]
        except ImportError:
            raise Exception(
                "Could not import required package azure-ai-contentsafety. Install with: pip install azure-ai-contentsafety"
            )
        except KeyError:
            raise Exception(
                "Environment variables not set. Please set CONTENT_SAFETY_KEY and CONTENT_SAFETY_ENDPOINT."
            )

        self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key))

    def __call__(self, output_text):
        from azure.core.exceptions import HttpResponseError
        from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory

        if len(output_text) > 1000:
            raise Exception("Input length to safety check is too long (>1000 characters).")

        categories = [
            TextCategory.VIOLENCE,
            TextCategory.SELF_HARM,
            TextCategory.SEXUAL,
            TextCategory.HATE,
        ]

        request = AnalyzeTextOptions(text=output_text, categories=categories)

        try:
            response = self.client.analyze_text(request)
        except HttpResponseError as e:
            print("Analyze text failed.")
            if e.error:
                print(f"Error code: {e.error.code}")
                print(f"Error message: {e.error.message}")
                raise
            print(e)
            raise e

        # Map the severity values used by the service (0, 2, 4, 6) to labels.
        levels = {0: "Safe", 2: "Low", 4: "Medium", 6: "High"}

        severities = [
            getattr(response, c.name.lower() + "_result").severity for c in categories
        ]

        # Treat any non-zero severity in any category as unsafe.
        DEFAULT_LEVELS = [0, 0, 0, 0]

        is_safe = all([s <= l for s, l in zip(severities, DEFAULT_LEVELS)])

        report = ""
        if not is_safe:
            report = "|" + "|".join(f"{c.name:^10}" for c in categories) + "|\n"
            report += "|" + "|".join(f"{levels[s]:^10}" for s in severities) + "|\n"

        return "Azure Content Safety API", is_safe, report
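
# Configuration note (illustrative, assuming an existing Azure Content Safety resource
# and placeholder values below): AzureSafetyChecker reads its credentials from the
# environment, so CONTENT_SAFETY_KEY and CONTENT_SAFETY_ENDPOINT must be set before
# the checker is constructed, e.g.
#
#     os.environ["CONTENT_SAFETY_KEY"] = "<your-key>"
#     os.environ["CONTENT_SAFETY_ENDPOINT"] = "https://<your-resource>.cognitiveservices.azure.com/"
#     checker = AzureSafetyChecker()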

# Function to determine which safety checkers to use based on the options selected
def get_safety_checker(enable_azure_content_safety,
                       enable_sensitive_topics,
                       enable_salesforce_content_safety,
                       ):
    safety_checker = []
    if enable_azure_content_safety:
        safety_checker.append(AzureSafetyChecker())
    if enable_sensitive_topics:
        safety_checker.append(AuditNLGSensitiveTopics())
    if enable_salesforce_content_safety:
        safety_checker.append(SalesforceSafetyChecker())
    return safety_checker
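
# Example usage (a minimal, illustrative sketch; the flags and sample prompt below are
# assumptions chosen for the example): every checker returned by get_safety_checker is a
# callable that takes the text to screen and returns a (method_name, is_safe, report)
# tuple, so callers typically loop over the list and aggregate the verdicts.
if __name__ == "__main__":
    checkers = get_safety_checker(
        enable_azure_content_safety=False,
        enable_sensitive_topics=False,
        enable_salesforce_content_safety=True,
    )
    user_text = "How do I bake bread at home?"
    results = [checker(user_text) for checker in checkers]
    all_safe = all(is_safe for _, is_safe, _ in results)
    print(f"Text deemed safe by all enabled checkers: {all_safe}")
    for method, is_safe, report in results:
        if not is_safe:
            print(f"{method} flagged the text:\n{report}")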