safety_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import torch
import warnings
from typing import List
from string import Template
from enum import Enum

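# Indicates whether the text being checked came from the user or from the agent (model) response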
class AgentType(Enum):
    AGENT = "Agent"
    USER = "User"

# Class for performing safety checks using AuditNLG library
class AuditNLGSensitiveTopics(object):
    def __init__(self):
        pass

    def __call__(self, output_text, **kwargs):
        try:
            from auditnlg.safety.exam import safety_scores
        except ImportError as e:
            print("Could not import optional dependency: auditnlg\nPlease install manually with:\n pip install auditnlg\nFollowed by:\npip install -r requirements.txt")
            raise e

        data = [{"output": output_text}]

        result = safety_scores(data=data, method="sensitive_topics")
        scores = result[1]["all_scores"][0]
        is_safe = scores["pred_class"] == "none"
        report = ""
        if not is_safe:
            report += f"Predicted class: {scores['pred_class']}\n"
            report += "|" + "|".join(f"{n:^10}" for n in [list(k.keys())[0] for k in scores["class_scores"]]) + "|\n"
            report += "|" + "|".join(f"{n:^10.5}" for n in [list(k.values())[0] for k in scores["class_scores"]]) + "|\n"
        return "Sensitive Topics", is_safe, report

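# Class for performing safety checks using the Salesforce safety-flan-t5-base model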
class SalesforceSafetyChecker(object):
    def __init__(self):
        pass

    def __call__(self, output_text, **kwargs):
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig

        config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base")
        tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base")
        model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/safety-flan-t5-base", config=config)
        model.eval()

        prefix = "Is the <Text> field safe or unsafe?"
        input_ids = tokenizer(prefix + " <Text> " + output_text + " <Context> ", return_tensors="pt").input_ids

        if len(input_ids[0]) > 512:
            warnings.warn(
                "Input length is > 512 tokens. Safety check result could be incorrect."
            )

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=20,
            )

        is_safe = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(" ")[0] == "safe"

        report = ""
        if not is_safe:
            # Probability of the "true" token (relative to "false") at the generation
            # step corresponding to each unsafe category
            true_false_ids = tokenizer("true false").input_ids[:2]
            keys = ["toxicity", "hate", "identity", "violence", "physical", "sexual", "profanity", "biased"]
            scores = {}
            for k, i in zip(keys, range(3, 20, 2)):
                scores[k] = round(outputs.scores[i][0, true_false_ids].softmax(dim=0)[0].item(), 5)

            report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n"
            report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n"
        return "Salesforce Content Safety Flan T5 Base", is_safe, report

    def get_total_length(self, data):
        # The tokenizer is loaded here because __init__ does not keep one around
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base")
        prefix = "Is the <Text> field safe or unsafe "
        input_sample = "<Text> {output} <Context> ".format(**data[0])

        return len(tokenizer(prefix + input_sample)["input_ids"])

# Class for performing safety checks using Azure Content Safety service
class AzureSaftyChecker(object):
    def __init__(self):
        try:
            from azure.ai.contentsafety import ContentSafetyClient
            from azure.core.credentials import AzureKeyCredential

            key = os.environ["CONTENT_SAFETY_KEY"]
            endpoint = os.environ["CONTENT_SAFETY_ENDPOINT"]
        except ImportError:
            raise Exception(
                "Could not import required package azure-ai-contentsafety. Install with: pip install azure-ai-contentsafety"
            )
        except KeyError:
            raise Exception(
                "Environment variables not set. Please set CONTENT_SAFETY_KEY and CONTENT_SAFETY_ENDPOINT."
            )

        self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key))

    def __call__(self, output_text, **kwargs):
        from azure.core.exceptions import HttpResponseError
        from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory

        if len(output_text) > 1000:
            raise Exception("Input length to safety check is too long (>1000).")

        categories = [
            TextCategory.VIOLENCE,
            TextCategory.SELF_HARM,
            TextCategory.SEXUAL,
            TextCategory.HATE,
        ]

        request = AnalyzeTextOptions(text=output_text, categories=categories)

        try:
            response = self.client.analyze_text(request)
        except HttpResponseError as e:
            print("Analyze text failed.")
            if e.error:
                print(f"Error code: {e.error.code}")
                print(f"Error message: {e.error.message}")
                raise
            print(e)
            raise e

        # Human-readable labels for the severity values returned by the service
        levels = {0: "Safe", 2: "Low", 4: "Medium", 6: "High"}

        severities = [
            getattr(response, c.name.lower() + "_result").severity for c in categories
        ]

        # Maximum severity allowed per category before the text is flagged
        DEFAULT_LEVELS = [0, 0, 0, 0]

        is_safe = all([s <= l for s, l in zip(severities, DEFAULT_LEVELS)])

        report = ""
        if not is_safe:
            report = "|" + "|".join(f"{c.name:^10}" for c in categories) + "|\n"
            report += "|" + "|".join(f"{levels[s]:^10}" for s in severities) + "|\n"

        return "Azure Content Safety API", is_safe, report
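
# Class for performing safety checks using the Llama Guard model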
class LlamaGuardSafetyChecker(object):
    def __init__(self):
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_id = "meta-llama/LlamaGuard-7b"

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

    def __call__(self, output_text, **kwargs):
        agent_type = kwargs.get('agent_type', AgentType.USER)
        user_prompt = kwargs.get('user_prompt', "")

        model_prompt = output_text.strip()
        if agent_type == AgentType.AGENT:
            if user_prompt == "":
                print("empty user prompt for agent check, returning unsafe")
                return "Llama Guard", False, "Missing user_prompt from Agent response check"
            else:
                model_prompt = model_prompt.replace(user_prompt, "")
                user_prompt = f"User: {user_prompt}"
                agent_prompt = f"Agent: {model_prompt}"
                chat = [
                    {"role": "user", "content": user_prompt},
                    {"role": "assistant", "content": agent_prompt},
                ]
        else:
            chat = [
                {"role": "user", "content": model_prompt},
            ]

        input_ids = self.tokenizer.apply_chat_template(chat, return_tensors="pt").to("cuda")
        prompt_len = input_ids.shape[-1]
        output = self.model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
        result = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        # Llama Guard replies with "safe" or "unsafe" on the first line of its output
        split_result = result.split("\n")[0]
        is_safe = split_result == "safe"

        report = result

        return "Llama Guard", is_safe, report


# Function to determine which safety checkers to use based on the options selected
def get_safety_checker(enable_azure_content_safety,
                       enable_sensitive_topics,
                       enable_salesforce_content_safety,
                       enable_llamaguard_content_safety):
    safety_checker = []
    if enable_azure_content_safety:
        safety_checker.append(AzureSaftyChecker())
    if enable_sensitive_topics:
        safety_checker.append(AuditNLGSensitiveTopics())
    if enable_salesforce_content_safety:
        safety_checker.append(SalesforceSafetyChecker())
    if enable_llamaguard_content_safety:
        safety_checker.append(LlamaGuardSafetyChecker())

    return safety_checker
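

# Example usage (a minimal sketch, not part of the original module): builds the list of
# enabled checkers and runs each one over a piece of text. The sample text and the choice
# of enabled checkers below are illustrative assumptions; each checker returns a
# (method, is_safe, report) tuple as implemented above.
if __name__ == "__main__":
    checkers = get_safety_checker(
        enable_azure_content_safety=False,
        enable_sensitive_topics=False,
        enable_salesforce_content_safety=True,
        enable_llamaguard_content_safety=False,
    )

    sample_text = "How do I bake a loaf of sourdough bread?"
    for checker in checkers:
        method, is_safe, report = checker(sample_text, agent_type=AgentType.USER)
        print(f"{method}: {'safe' if is_safe else 'unsafe'}")
        if not is_safe:
            print(report)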