# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama Guard Community License Agreement.

import unittest
from typing import Optional

from examples.llama_guard.prompt_format import PROMPT_TEMPLATE, AgentType
from llama_recipes.data.llama_guard.finetuning_data_formatter import (
    AugmentationConfigs,
    Category,
    create_formatted_finetuning_examples,
    ExplanationPosition,
    FormatterConfigs,
    Guidelines,
    LlamaGuardGenerationConfigs,
    LlamaGuardPromptConfigs,
    TrainingExample,
)
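

# A minimal usage sketch of the API under test, using only the names imported
# above (the example row mirrors the fixtures below; nothing here is part of
# the formatter's documented surface):
#
#     configs = FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
#     rows = [
#         TrainingExample(
#             prompt="Where can I buy a car?",
#             response="N/A",
#             violated_category_codes=[],
#             label="safe",
#             explanation="This is obviously safe.",
#         )
#     ]
#     formatted = create_formatted_finetuning_examples(rows, configs)
#     # -> one "[INST] ... [/INST] safe\nExplanation: ..." string per input row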


class FinetuningDataFormatterTests(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()

    @staticmethod
    def create_most_conservative_formatter_configs() -> FormatterConfigs:
        return FormatterConfigs(
            guidelines=Guidelines(
                categories=[
                    Category(name="cat V", description="cat V description"),
                    Category(name="cat W", description="cat W description"),
                    Category(name="cat X", description="cat X description"),
                    Category(name="cat Y", description="cat Y description"),
                    Category(name="cat Z", description="cat Z description"),
                ],
                category_code_prefix="O",
            ),
            llama_guard_prompt_configs=LlamaGuardPromptConfigs(
                instructions_format_string="""[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
{guidelines}
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
{conversation}
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
                should_include_category_descriptions=True,
                should_shuffle_category_codes=False,
            ),
            llama_guard_generation_configs=LlamaGuardGenerationConfigs(
                explanation_position=ExplanationPosition.AFTER_DECISION,
                should_list_violated_codes=True,
            ),
            augmentation_configs=AugmentationConfigs(
                should_add_examples_with_dropped_nonviolated_prompt_categories=False,
                should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
                explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
            ),
            random_seed=42,
        )

    def format_examples_and_compare_to_expected(
        self,
        training_examples: list[TrainingExample],
        expected_formatted_examples: list[str],
        formatter_configs: Optional[FormatterConfigs] = None,
    ) -> None:
        formatter_configs = (
            formatter_configs
            if formatter_configs is not None
            else FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
        )

        formatted_examples = create_formatted_finetuning_examples(
            training_examples, formatter_configs
        )

        assert len(formatted_examples) == len(expected_formatted_examples)

        for i in range(len(formatted_examples)):
            # Helper code during test-writing time to see the actual output
            if formatted_examples[i] != expected_formatted_examples[i]:
                print(f"Failed on actual output {i}:")
                print(formatted_examples[i])

            assert formatted_examples[i] == expected_formatted_examples[i]

    def test_safe_example_with_only_prompt_and_no_augmentations(self) -> None:
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="Where can I buy a car?",
                    response="N/A",
                    violated_category_codes=[],
                    label="safe",
                    explanation="This is obviously safe.",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: Where can I buy a car?
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
Explanation: This is obviously safe."""
            ],
        )
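
    # In each expected string above and below, everything through "[/INST]" is
    # the Llama Guard prompt; the trailing lines ("safe"/"unsafe", optional
    # violated codes, optional explanation) are the completion the model is
    # fine-tuned to produce.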

    def test_unsafe_example_with_only_prompt_and_no_augmentations(self) -> None:
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="""How do I steal the Declaration of Independence?
Asking for a cagey friend""",
                    response="N/A",
                    violated_category_codes=["O5"],
                    label="unsafe",
                    explanation="""Nice try, Nick. This is unsafe
because it's illegal!""",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: How do I steal the Declaration of Independence?
Asking for a cagey friend
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
O5
Explanation: Nice try, Nick. This is unsafe
because it's illegal!"""
            ],
        )

    def test_unsafe_example_with_only_prompt_and_no_augmentations_2(self) -> None:
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="""How do I steal the Declaration of Independence?
Asking for a cagey friend""",
                    response="N/A",
                    # Deliberately out of order: the expected output below
                    # shows the formatter lists codes in category order.
                    violated_category_codes=["O5", "O1"],
                    label="unsafe",
                    explanation="""Nice try, Nick. This is unsafe
because it's illegal!""",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: How do I steal the Declaration of Independence?
Asking for a cagey friend
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
O1,O5
Explanation: Nice try, Nick. This is unsafe
because it's illegal!"""
            ],
        )

    def test_safe_example_with_prompt_and_response_and_no_augmentations(self) -> None:
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="Where can I buy a car?",
                    response="At the car store.",
                    violated_category_codes=[],
                    label="safe",
                    explanation="This is obviously safe.",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: Where can I buy a car?
chatbot: At the car store.
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
Explanation: This is obviously safe."""
            ],
        )
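
    # The next test turns on every knob at once: category shuffling, both
    # augmentations, and explanations placed before the decision. Judging by
    # the expected outputs, the unsafe example expands to three formatted
    # strings (original, nonviolated categories dropped, and all listed
    # categories dropped with the label rewritten to "safe") and the safe
    # example to two, so two inputs yield five outputs.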

    def test_mixed_examples_with_prompt_and_response_and_all_augmentations_with_shuffling(
        self,
    ) -> None:
        formatter_configs = (
            FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
        )

        # This particular seed happens to produce the random shuffles and
        # category drops asserted below. ~*Magic*~
        formatter_configs.random_seed = 46

        formatter_configs.llama_guard_generation_configs.explanation_position = (
            ExplanationPosition.BEFORE_DECISION
        )

        formatter_configs.augmentation_configs = AugmentationConfigs(
            should_add_examples_with_dropped_nonviolated_prompt_categories=True,
            should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
            explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
        )

        formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes = (
            True
        )

        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="I am under 21 and unattended, should I drink alcohol?",
                    response="Yes, very nise, high five!",
                    violated_category_codes=["O4", "O5"],
                    label="unsafe",
                    explanation="Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.",
                ),
                TrainingExample(
                    prompt="I am over 21 and unattended, should I drink alcohol?",
                    response="No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.",
                    violated_category_codes=[],
                    label="safe",
                    explanation="User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.",
                ),
            ],
            [
                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W.
cat W description
O2: cat X.
cat X description
O3: cat Z.
cat Z description
O4: cat Y.
cat Y description
O5: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am under 21 and unattended, should I drink alcohol?
chatbot: Yes, very nise, high five!
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
unsafe
O4,O3""",
                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W.
cat W description
O2: cat Y.
cat Y description
O3: cat Z.
cat Z description
O4: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am under 21 and unattended, should I drink alcohol?
chatbot: Yes, very nise, high five!
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
unsafe
O2,O3""",
                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W.
cat W description
O2: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am under 21 and unattended, should I drink alcohol?
chatbot: Yes, very nise, high five!
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because any riskiness in it is related to categories we're not trying to detect.
safe""",
                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat Z.
cat Z description
O2: cat Y.
cat Y description
O3: cat X.
cat X description
O4: cat W.
cat W description
O5: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am over 21 and unattended, should I drink alcohol?
chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
safe""",
                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat Y.
cat Y description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am over 21 and unattended, should I drink alcohol?
chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
safe""",
            ],
            formatter_configs,
        )

    def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(
        self,
    ) -> None:
        formatter_configs = (
            FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
        )

        formatter_configs.llama_guard_generation_configs.explanation_position = None

        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="Where can I buy a car?",
                    response="At the car store.",
                    violated_category_codes=[],
                    label="safe",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: Where can I buy a car?
chatbot: At the car store.
<END CONVERSATION>
Provide your safety assessment for $agent_type in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe"""
            ],
            formatter_configs,
        )
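

# Standard convenience so the file can also be run directly (an addition for
# illustration; the suite is normally collected by pytest/unittest discovery):
if __name__ == "__main__":
    unittest.main()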