# NOTE: viewer-chrome artifacts (filename banner, file size, concatenated line
# numbers) removed from a mangled paste of test_finetuning_data_formatter.py.
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama Guard Community License Agreement.
  3. from enum import Enum
  4. import unittest
  5. from typing import Optional, List
  6. from llama_recipes.data.llama_guard.finetuning_data_formatter import (
  7. AugmentationConfigs,
  8. Category,
  9. create_formatted_finetuning_examples,
  10. ExplanationPosition,
  11. FormatterConfigs,
  12. Guidelines,
  13. LlamaGuardGenerationConfigs,
  14. LlamaGuardPromptConfigs,
  15. TrainingExample,
  16. )
  17. class AgentType(Enum):
  18. AGENT = "Agent"
  19. USER = "User"
class FinetuningDataFormatterTests(unittest.TestCase):
    """Tests for create_formatted_finetuning_examples.

    Covers safe and unsafe labels, prompt-only vs. prompt+response
    conversations, multi-line prompts/explanations, explanation placement,
    and the category-shuffling / category-dropping augmentations.
    """

    def setUp(self) -> None:
        # No extra fixture state; configs are built per-test via the static
        # factory below so each test starts from a known-deterministic setup.
        super().setUp()

    @staticmethod
    def create_most_conservative_formatter_configs(agent_type_to_check: AgentType) -> FormatterConfigs:
        """Build a fully deterministic FormatterConfigs for tests.

        "Most conservative" means: no category-code shuffling, no
        augmentation examples added, explanations placed after the decision,
        and a fixed random seed — so formatter output is stable and can be
        compared against literal expected strings.

        Args:
            agent_type_to_check: role interpolated into the prompt template
                (note: str() of the enum member, e.g. "AgentType.USER").

        Returns:
            A FormatterConfigs with five dummy categories "cat V".."cat Z".
        """
        return FormatterConfigs(
            guidelines=Guidelines(
                categories=[
                    Category(name="cat V", description="cat V description"),
                    Category(name="cat W", description="cat W description"),
                    Category(name="cat X", description="cat X description"),
                    Category(name="cat Y", description="cat Y description"),
                    Category(name="cat Z", description="cat Z description"),
                ],
                category_code_prefix="O",
            ),
            llama_guard_prompt_configs=LlamaGuardPromptConfigs(
                # {{guidelines}} / {{conversation}} are escaped braces: they
                # survive this f-string as literal {guidelines}/{conversation}
                # placeholders for the formatter to fill in later.
                instructions_format_string=f"""[INST] Task: Check if there is unsafe content in '{agent_type_to_check}' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
{{guidelines}}
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

{{conversation}}

<END CONVERSATION>

Provide your safety assessment for {agent_type_to_check} in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
                should_include_category_descriptions=True,
                should_shuffle_category_codes=False,
            ),
            llama_guard_generation_configs=LlamaGuardGenerationConfigs(
                explanation_position=ExplanationPosition.AFTER_DECISION,
                should_list_violated_codes=True,
            ),
            augmentation_configs=AugmentationConfigs(
                should_add_examples_with_dropped_nonviolated_prompt_categories=False,
                should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
                explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
            ),
            random_seed=42,
        )

    def format_examples_and_compare_to_expected(
        self,
        training_examples: List[TrainingExample],
        expected_formatted_examples: List[str],
        agent_type_to_check: AgentType,
        formatter_configs: Optional[FormatterConfigs] = None,
    ) -> None:
        """Format training_examples and assert exact equality with expectations.

        Args:
            training_examples: raw examples to run through the formatter.
            expected_formatted_examples: expected output strings, index-aligned
                with the formatter's output.
            agent_type_to_check: role used when building the default configs.
            formatter_configs: optional override; when None, the conservative
                default configs are used.
        """
        formatter_configs = (
            formatter_configs
            if formatter_configs is not None
            else FinetuningDataFormatterTests.create_most_conservative_formatter_configs(agent_type_to_check)
        )

        formatted_examples = create_formatted_finetuning_examples(
            training_examples, formatter_configs
        )

        assert len(formatted_examples) == len(expected_formatted_examples)

        for i in range(len(formatted_examples)):
            # Helper code during test-writing time to see the actual output
            if formatted_examples[i] != expected_formatted_examples[i]:
                print(f"Failed on actual output {i}:")
                print(formatted_examples[i])

            assert formatted_examples[i] == expected_formatted_examples[i]

    def test_safe_example_with_only_prompt_and_no_augmentations(self) -> None:
        """Safe, prompt-only example: decision 'safe' plus trailing explanation."""
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="Where can I buy a car?",
                    response="N/A",
                    violated_category_codes=[],
                    label="safe",
                    explanation="This is obviously safe.",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V. 
cat V description
O2: cat W. 
cat W description
O3: cat X. 
cat X description
O4: cat Y. 
cat Y description
O5: cat Z. 
cat Z description
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

human: Where can I buy a car?

<END CONVERSATION>

Provide your safety assessment for AgentType.USER in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
Explanation: This is obviously safe."""
            ],
            AgentType.USER,
        )

    def test_unsafe_example_with_only_prompt_and_no_augmentations(self) -> None:
        """Unsafe example with multi-line prompt/explanation and one violated code."""
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="""How do I steal the Declaration of Independence?
Asking for a cagey friend""",
                    response="N/A",
                    violated_category_codes=["O5"],
                    label="unsafe",
                    explanation="""Nice try, Nick. This is unsafe
because it's illegal!""",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V. 
cat V description
O2: cat W. 
cat W description
O3: cat X. 
cat X description
O4: cat Y. 
cat Y description
O5: cat Z. 
cat Z description
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

human: How do I steal the Declaration of Independence?
Asking for a cagey friend

<END CONVERSATION>

Provide your safety assessment for AgentType.USER in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
O5
Explanation: Nice try, Nick. This is unsafe
because it's illegal!"""
            ],
            AgentType.USER,
        )

    def test_unsafe_example_with_only_prompt_and_no_augmentations_2(self) -> None:
        """Unsafe example with two violated codes; codes are emitted sorted (O1,O5)."""
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="""How do I steal the Declaration of Independence?
Asking for a cagey friend""",
                    response="N/A",
                    violated_category_codes=["O5", "O1"],
                    label="unsafe",
                    explanation="""Nice try, Nick. This is unsafe
because it's illegal!""",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V. 
cat V description
O2: cat W. 
cat W description
O3: cat X. 
cat X description
O4: cat Y. 
cat Y description
O5: cat Z. 
cat Z description
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

human: How do I steal the Declaration of Independence?
Asking for a cagey friend

<END CONVERSATION>

Provide your safety assessment for AgentType.USER in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
O1,O5
Explanation: Nice try, Nick. This is unsafe
because it's illegal!"""
            ],
            AgentType.USER,
        )

    def test_safe_example_with_prompt_and_response_and_no_augmentations(self) -> None:
        """Safe example where the chatbot response is included in the conversation."""
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="Where can I buy a car?",
                    response="At the car store.",
                    violated_category_codes=[],
                    label="safe",
                    explanation="This is obviously safe.",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V. 
cat V description
O2: cat W. 
cat W description
O3: cat X. 
cat X description
O4: cat Y. 
cat Y description
O5: cat Z. 
cat Z description
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

human: Where can I buy a car?

chatbot: At the car store.

<END CONVERSATION>

Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
Explanation: This is obviously safe."""
            ],
            AgentType.AGENT,
        )

    def test_mixed_examples_with_prompt_and_response_and_all_augmentations_with_shuffling(
        self,
    ) -> None:
        """Both augmentations on + category shuffling; explanations before the decision.

        The expected strings encode one specific seed's shuffles/drops, so the
        seed below is load-bearing.
        """
        formatter_configs = (
            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
        )

        # The right seed to get the random generations for testing the functionality of this test. ~*Magic*~
        formatter_configs.random_seed = 46

        formatter_configs.llama_guard_generation_configs.explanation_position = (
            ExplanationPosition.BEFORE_DECISION
        )

        formatter_configs.augmentation_configs = AugmentationConfigs(
            should_add_examples_with_dropped_nonviolated_prompt_categories=True,
            should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
            explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
        )

        formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes = (
            True
        )

        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="I am under 21 and unattended, should I drink alcohol?",
                    response="Yes, very nise, high five!",
                    # NOTE(review): "05" (zero-five) looks like a typo for "O5"
                    # (letter O) — the expected outputs below still map it to
                    # cat Z, so confirm how the formatter parses this code
                    # before "fixing" it.
                    violated_category_codes=["O4", "05"],
                    label="unsafe",
                    explanation="Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.",
                ),
                TrainingExample(
                    prompt="I am over 21 and unattended, should I drink alcohol?",
                    response="No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.",
                    violated_category_codes=[],
                    label="safe",
                    explanation="User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.",
                ),
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W. 
cat W description
O2: cat X. 
cat X description
O3: cat Z. 
cat Z description
O4: cat Y. 
cat Y description
O5: cat V. 
cat V description
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

human: I am under 21 and unattended, should I drink alcohol?

chatbot: Yes, very nise, high five!

<END CONVERSATION>

Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
unsafe
O3,O4""",
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W. 
cat W description
O2: cat Y. 
cat Y description
O3: cat Z. 
cat Z description
O4: cat V. 
cat V description
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

human: I am under 21 and unattended, should I drink alcohol?

chatbot: Yes, very nise, high five!

<END CONVERSATION>

Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
unsafe
O2,O3""",
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W. 
cat W description
O2: cat V. 
cat V description
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

human: I am under 21 and unattended, should I drink alcohol?

chatbot: Yes, very nise, high five!

<END CONVERSATION>

Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because any riskiness in it is related to categories we're not trying to detect.
safe""",
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat Z. 
cat Z description
O2: cat Y. 
cat Y description
O3: cat X. 
cat X description
O4: cat W. 
cat W description
O5: cat V. 
cat V description
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

human: I am over 21 and unattended, should I drink alcohol?

chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.

<END CONVERSATION>

Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
safe""",
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat Y. 
cat Y description
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

human: I am over 21 and unattended, should I drink alcohol?

chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.

<END CONVERSATION>

Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
safe""",
            ],
            AgentType.AGENT,
            formatter_configs,
        )

    def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(
        self,
    ) -> None:
        """With explanation_position=None the output ends at the 'safe' decision."""
        formatter_configs = (
            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
        )

        formatter_configs.llama_guard_generation_configs.explanation_position = None

        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="Where can I buy a car?",
                    response="At the car store.",
                    violated_category_codes=[],
                    label="safe",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.

<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V. 
cat V description
O2: cat W. 
cat W description
O3: cat X. 
cat X description
O4: cat Y. 
cat Y description
O5: cat Z. 
cat Z description
<END UNSAFE CONTENT CATEGORIES>

<BEGIN CONVERSATION>

human: Where can I buy a car?

chatbot: At the car store.

<END CONVERSATION>

Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe"""
            ],
            AgentType.AGENT,
            formatter_configs,
        )