safety_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import torch
import warnings


# Class for performing safety checks using AuditNLG library
class AuditNLGSensitiveTopics(object):
    def __init__(self):
        pass

    def __call__(self, output_text):
        try:
            from auditnlg.safety.exam import safety_scores
        except ImportError as e:
            print("Could not import optional dependency: auditnlg\nPlease install manually with:\n pip install auditnlg\nFollowed by:\npip install -r requirements.txt")
            raise e

        data = [{"output": output_text}]

        result = safety_scores(data=data, method="sensitive_topics")
        scores = result[1]["all_scores"][0]
        is_safe = scores["pred_class"] == "none"
        report = ""
        if not is_safe:
            report += f"Predicted class: {scores['pred_class']}\n"
            report += "|" + "|".join(f"{n:^10}" for n in [list(k.keys())[0] for k in scores["class_scores"]]) + "|\n"
            report += "|" + "|".join(f"{n:^10.5}" for n in [list(k.values())[0] for k in scores["class_scores"]]) + "|\n"
        return "Sensitive Topics", is_safe, report
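

# Usage sketch (illustrative): each checker in this module is a callable that returns a
# (method_name, is_safe, report) tuple, so it can be applied directly to a generated
# string once the optional auditnlg dependency is installed:
#
#   checker = AuditNLGSensitiveTopics()
#   method, is_safe, report = checker("model output to screen")
#   if not is_safe:
#       print(f"{method} flagged the output:\n{report}")
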
# Class for performing safety checks using the Salesforce safety-flan-t5-base model
class SalesforceSafetyChecker(object):
    def __init__(self):
        pass

    def __call__(self, output_text):
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig

        config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base")
        tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base")
        model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/safety-flan-t5-base", config=config)
        model.eval()

        prefix = "Is the <Text> field safe or unsafe?"
        input_ids = tokenizer(prefix + " <Text> " + output_text + " <Context> ", return_tensors="pt").input_ids

        if len(input_ids[0]) > 512:
            warnings.warn(
                "Input length is > 512 tokens. Safety check result could be incorrect."
            )

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=20,
            )

        is_safe = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(" ")[0] == "safe"

        report = ""
        if not is_safe:
            # Read per-category true/false scores from the generation logits.
            true_false_ids = tokenizer("true false").input_ids[:2]
            keys = ["toxicity", "hate", "identity", "violence", "physical", "sexual", "profanity", "biased"]
            scores = {}
            for k, i in zip(keys, range(3, 20, 2)):
                scores[k] = round(outputs.scores[i][0, true_false_ids].softmax(dim=0)[0].item(), 5)
            report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n"
            report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n"
        return "Salesforce Content Safety Flan T5 Base", is_safe, report

    def get_total_length(self, data):
        # Token count of the prompt built from the first sample (for the 512 token limit).
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base")
        prefix = "Is the <Text> field safe or unsafe "
        input_sample = "<Text> {output} <Context> ".format(**data[0])
        return len(tokenizer(prefix + input_sample)["input_ids"])
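

# Usage sketch (illustrative; the first call downloads the Salesforce/safety-flan-t5-base
# weights from the Hugging Face Hub, so it needs network access plus the transformers
# and torch packages):
#
#   checker = SalesforceSafetyChecker()
#   method, is_safe, report = checker("model output to screen")
#   print(method, is_safe)
#   if report:
#       print(report)  # per-category scores: toxicity, hate, identity, ...
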
# Class for performing safety checks using Azure Content Safety service
class AzureSaftyChecker(object):
    def __init__(self):
        try:
            from azure.ai.contentsafety import ContentSafetyClient
            from azure.core.credentials import AzureKeyCredential

            key = os.environ["CONTENT_SAFETY_KEY"]
            endpoint = os.environ["CONTENT_SAFETY_ENDPOINT"]
        except ImportError:
            raise Exception(
                "Could not import required package azure-ai-contentsafety. Install with: pip install azure-ai-contentsafety"
            )
        except KeyError:
            raise Exception(
                "Environment variables not set. Please set CONTENT_SAFETY_KEY and CONTENT_SAFETY_ENDPOINT."
            )

        self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key))

    def __call__(self, output_text):
        from azure.core.exceptions import HttpResponseError
        from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory

        print(len(output_text))
        if len(output_text) > 1000:
            raise Exception("Input length to safety check is too long (>1000).")

        categories = [
            TextCategory.VIOLENCE,
            TextCategory.SELF_HARM,
            TextCategory.SEXUAL,
            TextCategory.HATE,
        ]

        request = AnalyzeTextOptions(text=output_text, categories=categories)

        try:
            response = self.client.analyze_text(request)
        except HttpResponseError as e:
            print("Analyze text failed.")
            if e.error:
                print(f"Error code: {e.error.code}")
                print(f"Error message: {e.error.message}")
                raise
            print(e)
            raise e

        # Severity levels returned by the service (0 = safe ... 6 = high).
        levels = {0: "Safe", 2: "Low", 4: "Medium", 6: "High"}

        severities = [
            getattr(response, c.name.lower() + "_result").severity for c in categories
        ]

        # Maximum severity allowed per category before the text is flagged as unsafe.
        DEFAULT_LEVELS = [0, 0, 0, 0]

        is_safe = all([s <= l for s, l in zip(severities, DEFAULT_LEVELS)])

        report = ""
        if not is_safe:
            report = "|" + "|".join(f"{c.name:^10}" for c in categories) + "|\n"
            report += "|" + "|".join(f"{levels[s]:^10}" for s in severities) + "|\n"

        return "Azure Content Safety API", is_safe, report
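

# Usage sketch (illustrative; assumes the azure-ai-contentsafety package is installed and
# the CONTENT_SAFETY_KEY / CONTENT_SAFETY_ENDPOINT environment variables point at an
# Azure Content Safety resource):
#
#   checker = AzureSaftyChecker()
#   method, is_safe, report = checker("model output to screen")
#   print(method, is_safe, report)
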
# Function to determine which safety checker to use based on the options selected
def get_safety_checker(enable_azure_content_safety,
                       enable_sensitive_topics,
                       enable_salesforce_content_safety,
                       ):
    safety_checker = []
    if enable_azure_content_safety:
        safety_checker.append(AzureSaftyChecker())
    if enable_sensitive_topics:
        safety_checker.append(AuditNLGSensitiveTopics())
    if enable_salesforce_content_safety:
        safety_checker.append(SalesforceSafetyChecker())
    return safety_checker
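

# Minimal usage sketch (illustrative; the CLI flags below are assumptions that mirror the
# get_safety_checker parameters, not an interface defined by this module). It shows the
# intended flow: build the list of enabled checkers, run each on a piece of text, and
# treat the text as safe only if every enabled checker agrees.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Run the enabled safety checkers on a piece of text.")
    parser.add_argument("text", help="Text to screen")
    parser.add_argument("--enable_azure_content_safety", action="store_true")
    parser.add_argument("--enable_sensitive_topics", action="store_true")
    parser.add_argument("--enable_salesforce_content_safety", action="store_true")
    args = parser.parse_args()

    checkers = get_safety_checker(
        args.enable_azure_content_safety,
        args.enable_sensitive_topics,
        args.enable_salesforce_content_safety,
    )

    # Run every enabled checker and collect its (method_name, is_safe, report) result.
    results = [check(args.text) for check in checkers]
    for method, is_safe, report in results:
        print(f"[{'SAFE' if is_safe else 'UNSAFE'}] {method}")
        if report:
            print(report)

    # Overall verdict: safe only when all enabled checkers returned is_safe == True
    # (vacuously "safe" if no checker was enabled).
    print("Overall:", "safe" if all(r[1] for r in results) else "unsafe")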