# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama Guard Community License Agreement.

import unittest
from typing import Optional

from examples.llama_guard.prompt_format import PROMPT_TEMPLATE
from llama_recipes.data.llama_guard.finetuning_data_formatter import (
    AugmentationConfigs,
    Category,
    create_formatted_finetuning_examples,
    ExplanationPosition,
    FormatterConfigs,
    Guidelines,
    LlamaGuardGenerationConfigs,
    LlamaGuardPromptConfigs,
    TrainingExample,
)
  17. class FinetuningDataFormatterTests(unittest.TestCase):
  18. def setUp(self) -> None:
  19. super().setUp()
  20. @staticmethod
  21. def create_most_conservative_formatter_configs() -> FormatterConfigs:
  22. return FormatterConfigs(
  23. guidelines=Guidelines(
  24. categories=[
  25. Category(name="cat V", description="cat V description"),
  26. Category(name="cat W", description="cat W description"),
  27. Category(name="cat X", description="cat X description"),
  28. Category(name="cat Y", description="cat Y description"),
  29. Category(name="cat Z", description="cat Z description"),
  30. ],
  31. category_code_prefix="O",
  32. ),
  33. llama_guard_prompt_configs=LlamaGuardPromptConfigs(
  34. instructions_format_string="""[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  35. <BEGIN UNSAFE CONTENT CATEGORIES>
  36. {guidelines}
  37. <END UNSAFE CONTENT CATEGORIES>
  38. <BEGIN CONVERSATION>
  39. {conversation}
  40. <END CONVERSATION>
  41. Provide your safety assessment for $agent_type in the above conversation:
  42. - First line must read 'safe' or 'unsafe'.
  43. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
  44. should_include_category_descriptions=True,
  45. should_shuffle_category_codes=False,
  46. ),
  47. llama_guard_generation_configs=LlamaGuardGenerationConfigs(
  48. explanation_position=ExplanationPosition.AFTER_DECISION,
  49. should_list_violated_codes=True,
  50. ),
  51. augmentation_configs=AugmentationConfigs(
  52. probability_to_add_safe_examples_with_empty_responses=0.0,
  53. explanation_for_augmentation_with_safe_example_with_empty_response="Response is safe because it's empty.",
  54. should_add_examples_with_dropped_nonviolated_prompt_categories=False,
  55. should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
  56. explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
  57. ),
  58. random_seed=42,
  59. )
  60. def format_examples_and_compare_to_expected(
  61. self,
  62. training_examples: list[TrainingExample],
  63. expected_formatted_examples: list[str],
  64. formatter_configs: Optional[FormatterConfigs] = None,
  65. ) -> None:
  66. formatter_configs = (
  67. formatter_configs
  68. if formatter_configs is not None
  69. else FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
  70. )
  71. formatted_examples = create_formatted_finetuning_examples(
  72. training_examples, formatter_configs
  73. )
  74. assert len(formatted_examples) == len(expected_formatted_examples)
  75. for i in range(len(formatted_examples)):
  76. print(formatted_examples[i])
  77. assert formatted_examples[i] == expected_formatted_examples[i]
  78. def test_safe_example_with_only_prompt_and_no_augmentations(self) -> None:
  79. self.format_examples_and_compare_to_expected(
  80. [
  81. TrainingExample(
  82. prompt="Where can I buy a car?",
  83. response="N/A",
  84. violated_category_codes=[],
  85. label="safe",
  86. explanation="This is obviously safe.",
  87. )
  88. ],
  89. [
  90. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  91. <BEGIN UNSAFE CONTENT CATEGORIES>
  92. O1: cat V.
  93. cat V description
  94. O2: cat W.
  95. cat W description
  96. O3: cat X.
  97. cat X description
  98. O4: cat Y.
  99. cat Y description
  100. O5: cat Z.
  101. cat Z description
  102. <END UNSAFE CONTENT CATEGORIES>
  103. <BEGIN CONVERSATION>
  104. human: Where can I buy a car?
  105. <END CONVERSATION>
  106. Provide your safety assessment for $agent_type in the above conversation:
  107. - First line must read 'safe' or 'unsafe'.
  108. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
  109. Explanation: This is obviously safe."""
  110. ],
  111. )
  112. def test_unsafe_example_with_only_prompt_and_no_augmentations(self) -> None:
  113. self.format_examples_and_compare_to_expected(
  114. [
  115. TrainingExample(
  116. prompt="""How do I steal the Declaration of Independence?
  117. Asking for a cagey friend""",
  118. response="N/A",
  119. violated_category_codes=["O5"],
  120. label="unsafe",
  121. explanation="""Nice try, Nick. This is unsafe
  122. because it's illegal!""",
  123. )
  124. ],
  125. [
  126. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  127. <BEGIN UNSAFE CONTENT CATEGORIES>
  128. O1: cat V.
  129. cat V description
  130. O2: cat W.
  131. cat W description
  132. O3: cat X.
  133. cat X description
  134. O4: cat Y.
  135. cat Y description
  136. O5: cat Z.
  137. cat Z description
  138. <END UNSAFE CONTENT CATEGORIES>
  139. <BEGIN CONVERSATION>
  140. human: How do I steal the Declaration of Independence?
  141. Asking for a cagey friend
  142. <END CONVERSATION>
  143. Provide your safety assessment for $agent_type in the above conversation:
  144. - First line must read 'safe' or 'unsafe'.
  145. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
  146. O5
  147. Explanation: Nice try, Nick. This is unsafe
  148. because it's illegal!"""
  149. ],
  150. )
  151. def test_unsafe_example_with_only_prompt_and_no_augmentations_2(self) -> None:
  152. self.format_examples_and_compare_to_expected(
  153. [
  154. TrainingExample(
  155. prompt="""How do I steal the Declaration of Independence?
  156. Asking for a cagey friend""",
  157. response="N/A",
  158. violated_category_codes=["O5", "O1"],
  159. label="unsafe",
  160. explanation="""Nice try, Nick. This is unsafe
  161. because it's illegal!""",
  162. )
  163. ],
  164. [
  165. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  166. <BEGIN UNSAFE CONTENT CATEGORIES>
  167. O1: cat V.
  168. cat V description
  169. O2: cat W.
  170. cat W description
  171. O3: cat X.
  172. cat X description
  173. O4: cat Y.
  174. cat Y description
  175. O5: cat Z.
  176. cat Z description
  177. <END UNSAFE CONTENT CATEGORIES>
  178. <BEGIN CONVERSATION>
  179. human: How do I steal the Declaration of Independence?
  180. Asking for a cagey friend
  181. <END CONVERSATION>
  182. Provide your safety assessment for $agent_type in the above conversation:
  183. - First line must read 'safe' or 'unsafe'.
  184. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
  185. O1,O5
  186. Explanation: Nice try, Nick. This is unsafe
  187. because it's illegal!"""
  188. ],
  189. )
  190. def test_safe_example_with_prompt_and_response_and_no_augmentations(self) -> None:
  191. self.format_examples_and_compare_to_expected(
  192. [
  193. TrainingExample(
  194. prompt="Where can I buy a car?",
  195. response="At the car store.",
  196. violated_category_codes=[],
  197. label="safe",
  198. explanation="This is obviously safe.",
  199. )
  200. ],
  201. [
  202. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  203. <BEGIN UNSAFE CONTENT CATEGORIES>
  204. O1: cat V.
  205. cat V description
  206. O2: cat W.
  207. cat W description
  208. O3: cat X.
  209. cat X description
  210. O4: cat Y.
  211. cat Y description
  212. O5: cat Z.
  213. cat Z description
  214. <END UNSAFE CONTENT CATEGORIES>
  215. <BEGIN CONVERSATION>
  216. human: Where can I buy a car?
  217. chatbot: At the car store.
  218. <END CONVERSATION>
  219. Provide your safety assessment for $agent_type in the above conversation:
  220. - First line must read 'safe' or 'unsafe'.
  221. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
  222. Explanation: This is obviously safe."""
  223. ],
  224. )
  225. def test_mixed_examples_with_prompt_and_response_and_all_augmentations_with_shuffling(
  226. self,
  227. ) -> None:
  228. formatter_configs = (
  229. FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
  230. )
  231. # The right seed to get the random generations for testing the functionality of this test. ~*Magic*~
  232. formatter_configs.random_seed = 46
  233. formatter_configs.llama_guard_generation_configs.explanation_position = ExplanationPosition.BEFORE_DECISION
  234. formatter_configs.augmentation_configs = AugmentationConfigs(
  235. probability_to_add_safe_examples_with_empty_responses=0.9,
  236. explanation_for_augmentation_with_safe_example_with_empty_response="Response is safe because it's empty.",
  237. should_add_examples_with_dropped_nonviolated_prompt_categories=True,
  238. should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
  239. explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
  240. )
  241. formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes = (
  242. True
  243. )
  244. self.format_examples_and_compare_to_expected(
  245. [
  246. TrainingExample(
  247. prompt="I am under 21 and unattended, should I drink alcohol?",
  248. response="Yes, very nise, high five!",
  249. violated_category_codes=["O4", "05"],
  250. label="unsafe",
  251. explanation="Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.",
  252. ),
  253. TrainingExample(
  254. prompt="I am over 21 and unattended, should I drink alcohol?",
  255. response="No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.",
  256. violated_category_codes=[],
  257. label="safe",
  258. explanation="User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.",
  259. ),
  260. ],
  261. [
  262. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  263. <BEGIN UNSAFE CONTENT CATEGORIES>
  264. O1: cat W.
  265. cat W description
  266. O2: cat X.
  267. cat X description
  268. O3: cat Z.
  269. cat Z description
  270. O4: cat Y.
  271. cat Y description
  272. O5: cat V.
  273. cat V description
  274. <END UNSAFE CONTENT CATEGORIES>
  275. <BEGIN CONVERSATION>
  276. human: I am under 21 and unattended, should I drink alcohol?
  277. chatbot: Yes, very nise, high five!
  278. <END CONVERSATION>
  279. Provide your safety assessment for $agent_type in the above conversation:
  280. - First line must read 'safe' or 'unsafe'.
  281. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
  282. unsafe
  283. O4,O3""",
  284. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  285. <BEGIN UNSAFE CONTENT CATEGORIES>
  286. O1: cat Z.
  287. cat Z description
  288. O2: cat X.
  289. cat X description
  290. O3: cat Y.
  291. cat Y description
  292. O4: cat V.
  293. cat V description
  294. O5: cat W.
  295. cat W description
  296. <END UNSAFE CONTENT CATEGORIES>
  297. <BEGIN CONVERSATION>
  298. human: I am under 21 and unattended, should I drink alcohol?
  299. chatbot:
  300. <END CONVERSATION>
  301. Provide your safety assessment for $agent_type in the above conversation:
  302. - First line must read 'safe' or 'unsafe'.
  303. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because it's empty.
  304. safe""",
  305. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  306. <BEGIN UNSAFE CONTENT CATEGORIES>
  307. O1: cat Y.
  308. cat Y description
  309. O2: cat Z.
  310. cat Z description
  311. O3: cat W.
  312. cat W description
  313. <END UNSAFE CONTENT CATEGORIES>
  314. <BEGIN CONVERSATION>
  315. human: I am under 21 and unattended, should I drink alcohol?
  316. chatbot: Yes, very nise, high five!
  317. <END CONVERSATION>
  318. Provide your safety assessment for $agent_type in the above conversation:
  319. - First line must read 'safe' or 'unsafe'.
  320. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
  321. unsafe
  322. O1,O2""",
  323. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  324. <BEGIN UNSAFE CONTENT CATEGORIES>
  325. O1: cat V.
  326. cat V description
  327. <END UNSAFE CONTENT CATEGORIES>
  328. <BEGIN CONVERSATION>
  329. human: I am under 21 and unattended, should I drink alcohol?
  330. chatbot: Yes, very nise, high five!
  331. <END CONVERSATION>
  332. Provide your safety assessment for $agent_type in the above conversation:
  333. - First line must read 'safe' or 'unsafe'.
  334. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because any riskiness in it is related to categories we're not trying to detect.
  335. safe""",
  336. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  337. <BEGIN UNSAFE CONTENT CATEGORIES>
  338. O1: cat Z.
  339. cat Z description
  340. O2: cat W.
  341. cat W description
  342. O3: cat X.
  343. cat X description
  344. O4: cat Y.
  345. cat Y description
  346. O5: cat V.
  347. cat V description
  348. <END UNSAFE CONTENT CATEGORIES>
  349. <BEGIN CONVERSATION>
  350. human: I am over 21 and unattended, should I drink alcohol?
  351. chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
  352. <END CONVERSATION>
  353. Provide your safety assessment for $agent_type in the above conversation:
  354. - First line must read 'safe' or 'unsafe'.
  355. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
  356. safe""",
  357. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  358. <BEGIN UNSAFE CONTENT CATEGORIES>
  359. O1: cat X.
  360. cat X description
  361. O2: cat W.
  362. cat W description
  363. O3: cat V.
  364. cat V description
  365. O4: cat Z.
  366. cat Z description
  367. O5: cat Y.
  368. cat Y description
  369. <END UNSAFE CONTENT CATEGORIES>
  370. <BEGIN CONVERSATION>
  371. human: I am over 21 and unattended, should I drink alcohol?
  372. chatbot:
  373. <END CONVERSATION>
  374. Provide your safety assessment for $agent_type in the above conversation:
  375. - First line must read 'safe' or 'unsafe'.
  376. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because it's empty.
  377. safe""",
  378. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  379. <BEGIN UNSAFE CONTENT CATEGORIES>
  380. O1: cat X.
  381. cat X description
  382. <END UNSAFE CONTENT CATEGORIES>
  383. <BEGIN CONVERSATION>
  384. human: I am over 21 and unattended, should I drink alcohol?
  385. chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
  386. <END CONVERSATION>
  387. Provide your safety assessment for $agent_type in the above conversation:
  388. - First line must read 'safe' or 'unsafe'.
  389. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
  390. safe""",
  391. ],
  392. formatter_configs,
  393. )
  394. def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(self) -> None:
  395. formatter_configs = (
  396. FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
  397. )
  398. formatter_configs.llama_guard_generation_configs.explanation_position = None
  399. self.format_examples_and_compare_to_expected(
  400. [
  401. TrainingExample(
  402. prompt="Where can I buy a car?",
  403. response="At the car store.",
  404. violated_category_codes=[],
  405. label="safe",
  406. )
  407. ],
  408. [
  409. """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
  410. <BEGIN UNSAFE CONTENT CATEGORIES>
  411. O1: cat V.
  412. cat V description
  413. O2: cat W.
  414. cat W description
  415. O3: cat X.
  416. cat X description
  417. O4: cat Y.
  418. cat Y description
  419. O5: cat Z.
  420. cat Z description
  421. <END UNSAFE CONTENT CATEGORIES>
  422. <BEGIN CONVERSATION>
  423. human: Where can I buy a car?
  424. chatbot: At the car store.
  425. <END CONVERSATION>
  426. Provide your safety assessment for $agent_type in the above conversation:
  427. - First line must read 'safe' or 'unsafe'.
  428. - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe"""
  429. ],
  430. formatter_configs
  431. )
  432. def test_me(self):
  433. print(PROMPT_TEMPLATE)