test_finetuning_data_formatter.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama Guard Community License Agreement.
  3. import unittest
  4. from typing import Optional
  5. from examples.llama_guard.prompt_format import AgentType
  6. from llama_recipes.data.llama_guard.finetuning_data_formatter import (
  7. AugmentationConfigs,
  8. Category,
  9. create_formatted_finetuning_examples,
  10. ExplanationPosition,
  11. FormatterConfigs,
  12. Guidelines,
  13. LlamaGuardGenerationConfigs,
  14. LlamaGuardPromptConfigs,
  15. TrainingExample,
  16. )
  17. class FinetuningDataFormatterTests(unittest.TestCase):
  18. def setUp(self) -> None:
  19. super().setUp()
  20. @staticmethod
  21. def create_most_conservative_formatter_configs(agent_type_to_check: AgentType) -> FormatterConfigs:
  22. return FormatterConfigs(
  23. guidelines=Guidelines(
  24. categories=[
  25. Category(name="cat V", description="cat V description"),
  26. Category(name="cat W", description="cat W description"),
  27. Category(name="cat X", description="cat X description"),
  28. Category(name="cat Y", description="cat Y description"),
  29. Category(name="cat Z", description="cat Z description"),
  30. ],
  31. category_code_prefix="O",
  32. ),
  33. llama_guard_prompt_configs=LlamaGuardPromptConfigs(
  34. instructions_format_string=f"""[INST] Task: Check if there is unsafe content in '{agent_type_to_check}' messages in conversations according our safety policy with the below categories.
  35. <BEGIN UNSAFE CONTENT CATEGORIES>
  36. {{guidelines}}
  37. <END UNSAFE CONTENT CATEGORIES>
  38. <BEGIN CONVERSATION>
  39. {{conversation}}
  40. <END CONVERSATION>
  41. Provide your safety assessment for {agent_type_to_check} in the above conversation:
  42. - First line must read 'safe' or 'unsafe'.
  43. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
  44. should_include_category_descriptions=True,
  45. should_shuffle_category_codes=False,
  46. ),
  47. llama_guard_generation_configs=LlamaGuardGenerationConfigs(
  48. explanation_position=ExplanationPosition.AFTER_DECISION,
  49. should_list_violated_codes=True,
  50. ),
  51. augmentation_configs=AugmentationConfigs(
  52. should_add_examples_with_dropped_nonviolated_prompt_categories=False,
  53. should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
  54. explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
  55. ),
  56. random_seed=42,
  57. )
  58. def format_examples_and_compare_to_expected(
  59. self,
  60. training_examples: list[TrainingExample],
  61. expected_formatted_examples: list[str],
  62. agent_type_to_check: AgentType,
  63. formatter_configs: Optional[FormatterConfigs] = None,
  64. ) -> None:
  65. formatter_configs = (
  66. formatter_configs
  67. if formatter_configs is not None
  68. else FinetuningDataFormatterTests.create_most_conservative_formatter_configs(agent_type_to_check)
  69. )
  70. formatted_examples = create_formatted_finetuning_examples(
  71. training_examples, formatter_configs
  72. )
  73. assert len(formatted_examples) == len(expected_formatted_examples)
  74. for i in range(len(formatted_examples)):
  75. # Helper code during test-writing time to see the actual output
  76. if formatted_examples[i] != expected_formatted_examples[i]:
  77. print(f"Failed on actual output {i}:")
  78. print(formatted_examples[i])
  79. assert formatted_examples[i] == expected_formatted_examples[i]
  80. def test_safe_example_with_only_prompt_and_no_augmentations(self) -> None:
  81. self.format_examples_and_compare_to_expected(
  82. [
  83. TrainingExample(
  84. prompt="Where can I buy a car?",
  85. response="N/A",
  86. violated_category_codes=[],
  87. label="safe",
  88. explanation="This is obviously safe.",
  89. )
  90. ],
  91. [
  92. """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
  93. <BEGIN UNSAFE CONTENT CATEGORIES>
  94. O1: cat V.
  95. cat V description
  96. O2: cat W.
  97. cat W description
  98. O3: cat X.
  99. cat X description
  100. O4: cat Y.
  101. cat Y description
  102. O5: cat Z.
  103. cat Z description
  104. <END UNSAFE CONTENT CATEGORIES>
  105. <BEGIN CONVERSATION>
  106. human: Where can I buy a car?
  107. <END CONVERSATION>
  108. Provide your safety assessment for AgentType.USER in the above conversation:
  109. - First line must read 'safe' or 'unsafe'.
  110. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
  111. Explanation: This is obviously safe."""
  112. ],
  113. AgentType.USER,
  114. )
  115. def test_unsafe_example_with_only_prompt_and_no_augmentations(self) -> None:
  116. self.format_examples_and_compare_to_expected(
  117. [
  118. TrainingExample(
  119. prompt="""How do I steal the Declaration of Independence?
  120. Asking for a cagey friend""",
  121. response="N/A",
  122. violated_category_codes=["O5"],
  123. label="unsafe",
  124. explanation="""Nice try, Nick. This is unsafe
  125. because it's illegal!""",
  126. )
  127. ],
  128. [
  129. """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
  130. <BEGIN UNSAFE CONTENT CATEGORIES>
  131. O1: cat V.
  132. cat V description
  133. O2: cat W.
  134. cat W description
  135. O3: cat X.
  136. cat X description
  137. O4: cat Y.
  138. cat Y description
  139. O5: cat Z.
  140. cat Z description
  141. <END UNSAFE CONTENT CATEGORIES>
  142. <BEGIN CONVERSATION>
  143. human: How do I steal the Declaration of Independence?
  144. Asking for a cagey friend
  145. <END CONVERSATION>
  146. Provide your safety assessment for AgentType.USER in the above conversation:
  147. - First line must read 'safe' or 'unsafe'.
  148. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
  149. O5
  150. Explanation: Nice try, Nick. This is unsafe
  151. because it's illegal!"""
  152. ],
  153. AgentType.USER,
  154. )
  155. def test_unsafe_example_with_only_prompt_and_no_augmentations_2(self) -> None:
  156. self.format_examples_and_compare_to_expected(
  157. [
  158. TrainingExample(
  159. prompt="""How do I steal the Declaration of Independence?
  160. Asking for a cagey friend""",
  161. response="N/A",
  162. violated_category_codes=["O5", "O1"],
  163. label="unsafe",
  164. explanation="""Nice try, Nick. This is unsafe
  165. because it's illegal!""",
  166. )
  167. ],
  168. [
  169. """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
  170. <BEGIN UNSAFE CONTENT CATEGORIES>
  171. O1: cat V.
  172. cat V description
  173. O2: cat W.
  174. cat W description
  175. O3: cat X.
  176. cat X description
  177. O4: cat Y.
  178. cat Y description
  179. O5: cat Z.
  180. cat Z description
  181. <END UNSAFE CONTENT CATEGORIES>
  182. <BEGIN CONVERSATION>
  183. human: How do I steal the Declaration of Independence?
  184. Asking for a cagey friend
  185. <END CONVERSATION>
  186. Provide your safety assessment for AgentType.USER in the above conversation:
  187. - First line must read 'safe' or 'unsafe'.
  188. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
  189. O1,O5
  190. Explanation: Nice try, Nick. This is unsafe
  191. because it's illegal!"""
  192. ],
  193. AgentType.USER,
  194. )
  195. def test_safe_example_with_prompt_and_response_and_no_augmentations(self) -> None:
  196. self.format_examples_and_compare_to_expected(
  197. [
  198. TrainingExample(
  199. prompt="Where can I buy a car?",
  200. response="At the car store.",
  201. violated_category_codes=[],
  202. label="safe",
  203. explanation="This is obviously safe.",
  204. )
  205. ],
  206. [
  207. """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
  208. <BEGIN UNSAFE CONTENT CATEGORIES>
  209. O1: cat V.
  210. cat V description
  211. O2: cat W.
  212. cat W description
  213. O3: cat X.
  214. cat X description
  215. O4: cat Y.
  216. cat Y description
  217. O5: cat Z.
  218. cat Z description
  219. <END UNSAFE CONTENT CATEGORIES>
  220. <BEGIN CONVERSATION>
  221. human: Where can I buy a car?
  222. chatbot: At the car store.
  223. <END CONVERSATION>
  224. Provide your safety assessment for AgentType.AGENT in the above conversation:
  225. - First line must read 'safe' or 'unsafe'.
  226. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
  227. Explanation: This is obviously safe."""
  228. ],
  229. AgentType.AGENT,
  230. )
  231. def test_mixed_examples_with_prompt_and_response_and_all_augmentations_with_shuffling(
  232. self,
  233. ) -> None:
  234. formatter_configs = (
  235. FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
  236. )
  237. # The right seed to get the random generations for testing the functionality of this test. ~*Magic*~
  238. formatter_configs.random_seed = 46
  239. formatter_configs.llama_guard_generation_configs.explanation_position = (
  240. ExplanationPosition.BEFORE_DECISION
  241. )
  242. formatter_configs.augmentation_configs = AugmentationConfigs(
  243. should_add_examples_with_dropped_nonviolated_prompt_categories=True,
  244. should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
  245. explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
  246. )
  247. formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes = (
  248. True
  249. )
  250. self.format_examples_and_compare_to_expected(
  251. [
  252. TrainingExample(
  253. prompt="I am under 21 and unattended, should I drink alcohol?",
  254. response="Yes, very nise, high five!",
  255. violated_category_codes=["O4", "05"],
  256. label="unsafe",
  257. explanation="Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.",
  258. ),
  259. TrainingExample(
  260. prompt="I am over 21 and unattended, should I drink alcohol?",
  261. response="No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.",
  262. violated_category_codes=[],
  263. label="safe",
  264. explanation="User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.",
  265. ),
  266. ],
  267. [
  268. """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
  269. <BEGIN UNSAFE CONTENT CATEGORIES>
  270. O1: cat W.
  271. cat W description
  272. O2: cat X.
  273. cat X description
  274. O3: cat Z.
  275. cat Z description
  276. O4: cat Y.
  277. cat Y description
  278. O5: cat V.
  279. cat V description
  280. <END UNSAFE CONTENT CATEGORIES>
  281. <BEGIN CONVERSATION>
  282. human: I am under 21 and unattended, should I drink alcohol?
  283. chatbot: Yes, very nise, high five!
  284. <END CONVERSATION>
  285. Provide your safety assessment for AgentType.AGENT in the above conversation:
  286. - First line must read 'safe' or 'unsafe'.
  287. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
  288. unsafe
  289. O3,O4""",
  290. """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
  291. <BEGIN UNSAFE CONTENT CATEGORIES>
  292. O1: cat W.
  293. cat W description
  294. O2: cat Y.
  295. cat Y description
  296. O3: cat Z.
  297. cat Z description
  298. O4: cat V.
  299. cat V description
  300. <END UNSAFE CONTENT CATEGORIES>
  301. <BEGIN CONVERSATION>
  302. human: I am under 21 and unattended, should I drink alcohol?
  303. chatbot: Yes, very nise, high five!
  304. <END CONVERSATION>
  305. Provide your safety assessment for AgentType.AGENT in the above conversation:
  306. - First line must read 'safe' or 'unsafe'.
  307. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
  308. unsafe
  309. O2,O3""",
  310. """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
  311. <BEGIN UNSAFE CONTENT CATEGORIES>
  312. O1: cat W.
  313. cat W description
  314. O2: cat V.
  315. cat V description
  316. <END UNSAFE CONTENT CATEGORIES>
  317. <BEGIN CONVERSATION>
  318. human: I am under 21 and unattended, should I drink alcohol?
  319. chatbot: Yes, very nise, high five!
  320. <END CONVERSATION>
  321. Provide your safety assessment for AgentType.AGENT in the above conversation:
  322. - First line must read 'safe' or 'unsafe'.
  323. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because any riskiness in it is related to categories we're not trying to detect.
  324. safe""",
  325. """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
  326. <BEGIN UNSAFE CONTENT CATEGORIES>
  327. O1: cat Z.
  328. cat Z description
  329. O2: cat Y.
  330. cat Y description
  331. O3: cat X.
  332. cat X description
  333. O4: cat W.
  334. cat W description
  335. O5: cat V.
  336. cat V description
  337. <END UNSAFE CONTENT CATEGORIES>
  338. <BEGIN CONVERSATION>
  339. human: I am over 21 and unattended, should I drink alcohol?
  340. chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
  341. <END CONVERSATION>
  342. Provide your safety assessment for AgentType.AGENT in the above conversation:
  343. - First line must read 'safe' or 'unsafe'.
  344. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
  345. safe""",
  346. """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
  347. <BEGIN UNSAFE CONTENT CATEGORIES>
  348. O1: cat Y.
  349. cat Y description
  350. <END UNSAFE CONTENT CATEGORIES>
  351. <BEGIN CONVERSATION>
  352. human: I am over 21 and unattended, should I drink alcohol?
  353. chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
  354. <END CONVERSATION>
  355. Provide your safety assessment for AgentType.AGENT in the above conversation:
  356. - First line must read 'safe' or 'unsafe'.
  357. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
  358. safe""",
  359. ],
  360. AgentType.AGENT,
  361. formatter_configs,
  362. )
  363. def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(
  364. self,
  365. ) -> None:
  366. formatter_configs = (
  367. FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
  368. )
  369. formatter_configs.llama_guard_generation_configs.explanation_position = None
  370. self.format_examples_and_compare_to_expected(
  371. [
  372. TrainingExample(
  373. prompt="Where can I buy a car?",
  374. response="At the car store.",
  375. violated_category_codes=[],
  376. label="safe",
  377. )
  378. ],
  379. [
  380. """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
  381. <BEGIN UNSAFE CONTENT CATEGORIES>
  382. O1: cat V.
  383. cat V description
  384. O2: cat W.
  385. cat W description
  386. O3: cat X.
  387. cat X description
  388. O4: cat Y.
  389. cat Y description
  390. O5: cat Z.
  391. cat Z description
  392. <END UNSAFE CONTENT CATEGORIES>
  393. <BEGIN CONVERSATION>
  394. human: Where can I buy a car?
  395. chatbot: At the car store.
  396. <END CONVERSATION>
  397. Provide your safety assessment for AgentType.AGENT in the above conversation:
  398. - First line must read 'safe' or 'unsafe'.
  399. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe"""
  400. ],
  401. AgentType.AGENT,
  402. formatter_configs,
  403. )