finetuning_data_formatter.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama Guard License Agreement.
  3. import copy
  4. import random
  5. from dataclasses import dataclass
  6. from enum import Enum
  7. from typing import Dict, List, Literal, Optional, Sequence
@dataclass
class Category:
    # Category name rendered after the code in the llama guard prompt
    # (see _create_llama_guard_prompt).
    name: str
    # Longer description, appended below the name in the prompt when
    # should_include_category_descriptions is enabled.
    description: str
@dataclass
class Guidelines:
    # Ordered violation categories; list position defines each category's
    # "original" index used throughout the formatter.
    categories: Sequence[Category]
    # Prefix used when rendering category codes (e.g. "O1", "O2") in prompts
    # and generations.
    category_code_prefix: str = "O"
class ExplanationPosition(Enum):
    """Where the explanation text is placed relative to the safe/unsafe decision
    in the llama guard generation (see _create_llama_guard_generation)."""

    BEFORE_DECISION = 0
    AFTER_DECISION = 1
@dataclass
class LlamaGuardPromptConfigs:
    # Format string filled via str.format_map with the keys "guidelines" and
    # "conversation" (see _create_llama_guard_prompt).
    instructions_format_string: str
    # If True, each category's description is appended under its name in the prompt.
    should_include_category_descriptions: bool
    # If True, category order (and therefore the rewritten codes) is shuffled
    # per formatted example; otherwise indices are sorted.
    should_shuffle_category_codes: bool = True
@dataclass
class LlamaGuardGenerationConfigs:
    # If True, "unsafe" generations also list the violated (rewritten) category codes.
    should_list_violated_codes: bool
    # Where to place the explanation in the generation; None omits it entirely.
    explanation_position: Optional[ExplanationPosition]
@dataclass
class AugmentationConfigs:
    # Probability (0-1) of duplicating a prompt+response example as a safe
    # example with an empty response.
    probability_to_add_safe_examples_with_empty_responses: float = 0
    # Explanation attached to those empty-response examples; required (verified
    # at runtime) whenever explanations are being generated.
    explanation_for_augmentation_with_safe_example_with_empty_response: Optional[
        str
    ] = None
    # If True, duplicate examples with a random subset of non-violated
    # categories dropped from the llama guard prompt.
    should_add_examples_with_dropped_nonviolated_prompt_categories: bool = True
    # If True, duplicate unsafe examples with ALL violated (plus some
    # non-violated) categories dropped, relabeled as safe.
    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories: bool = (
        False
    )
    # Explanation attached to those relabeled-safe examples; required (verified
    # at runtime) whenever explanations are being generated.
    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories: Optional[
        str
    ] = None
@dataclass
class FormatterConfigs:
    """Top-level bundle of every setting consumed by the finetuning formatter."""

    guidelines: Guidelines
    llama_guard_prompt_configs: LlamaGuardPromptConfigs
    llama_guard_generation_configs: LlamaGuardGenerationConfigs
    augmentation_configs: AugmentationConfigs
    # Allows subsequent reruns to reuse a stable seed for reproducibility
    random_seed: int = 42
@dataclass
class TrainingExample:
    # The user turn of the conversation.
    prompt: str
    # The chatbot turn; the sentinel "N/A" marks a prompt-only example
    # (see _is_a_prompt_only_example).
    response: str
    # Codes (e.g. "O1") of the guideline categories this example violates.
    violated_category_codes: list[str]
    label: Literal["safe", "unsafe"]
    # Free-text explanation of the label; emitted when explanation_position is set.
    explanation: str
  56. def create_formatted_finetuning_examples(
  57. training_examples: Sequence[TrainingExample],
  58. formatter_configs: FormatterConfigs,
  59. ) -> list[str]:
  60. """
  61. This formatter takes consumer-provided training examples and converts them to
  62. the right format for finetuning llama-guard.
  63. There are various configuration options available.
  64. A notable one is the ability to automagically augment the finetuning data set with some useful
  65. transformations of the original training examples. These augmentations make the
  66. classifier more flexible by improving its ability to be modified at inference time
  67. to include only a subset of the original categories it was trained on - without any
  68. additional finetuning.
  69. Some of these augmented transformations are made by duplicating training
  70. examples and safely removing some violation categories from the llama
  71. guard prompts. Because of this, in some of this file you will see
  72. references to "original" category indices/codes and rewritten ones. The originals
  73. are the indices/codes of the violation categories as they appear in the
  74. consumer-provided guidelines. The rewritten codes are the ones as they appear
  75. in the llama guard prompts of the augmented examples. We occasionally need to
  76. convert between the two.
  77. """
  78. _verify_formatter_configs(formatter_configs)
  79. random.seed(formatter_configs.random_seed)
  80. indices_of_all_categories = range(len(formatter_configs.guidelines.categories))
  81. to_return = []
  82. for training_example in training_examples:
  83. to_return.append(
  84. _create_formatted_finetuning_example(
  85. training_example,
  86. formatter_configs,
  87. category_indeces_to_include_in_llama_guard_prompt=list(
  88. indices_of_all_categories
  89. ),
  90. )
  91. )
  92. _maybe_add_data_augmentations_for_example(
  93. training_example, to_return, indices_of_all_categories, formatter_configs
  94. )
  95. return to_return
  96. def _verify_formatter_configs(
  97. formatter_configs: FormatterConfigs,
  98. ) -> None:
  99. if (
  100. formatter_configs.augmentation_configs.probability_to_add_safe_examples_with_empty_responses
  101. > 0
  102. and formatter_configs.llama_guard_generation_configs.explanation_position
  103. is not None
  104. and formatter_configs.augmentation_configs.explanation_for_augmentation_with_safe_example_with_empty_response
  105. is None
  106. ):
  107. raise ValueError(
  108. """The configuration setup requires you to specify
  109. explanation_for_augmentation_with_safe_example_with_empty_response. This is an
  110. explanation that we use for dynamically-created safe augmentation examples.
  111. Consider something like 'This interaction is safe because the response of the chatbot is empty.'"""
  112. )
  113. if (
  114. formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories
  115. == True
  116. and formatter_configs.llama_guard_generation_configs.explanation_position
  117. is not None
  118. and formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories
  119. is None
  120. ):
  121. raise ValueError(
  122. """The configuration setup requires you to specify
  123. explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories.
  124. This is an explanation that we use for dynamically-created safe augmentation examples.
  125. Consider something like 'This interaction is safe because any riskiness it contains
  126. is related to violation categories that we're explicitly not trying to detect here.'"""
  127. )
  128. def _create_formatted_finetuning_example(
  129. training_example: TrainingExample,
  130. formatter_configs: FormatterConfigs,
  131. category_indeces_to_include_in_llama_guard_prompt: List[int],
  132. ) -> str:
  133. if formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes:
  134. random.shuffle(category_indeces_to_include_in_llama_guard_prompt)
  135. else:
  136. category_indeces_to_include_in_llama_guard_prompt = sorted(
  137. category_indeces_to_include_in_llama_guard_prompt
  138. )
  139. llama_guard_prompt = _create_llama_guard_prompt(
  140. training_example,
  141. category_indeces_to_include_in_llama_guard_prompt,
  142. formatter_configs,
  143. )
  144. llama_guard_generation = _create_llama_guard_generation(
  145. training_example,
  146. formatter_configs,
  147. category_indeces_to_include_in_llama_guard_prompt,
  148. )
  149. return f"{llama_guard_prompt} {llama_guard_generation}"
  150. def _create_llama_guard_prompt(
  151. training_example: TrainingExample,
  152. category_indices_to_include: List[int],
  153. formatter_configs: FormatterConfigs,
  154. ) -> str:
  155. full_guidelines_text = ""
  156. for (
  157. rewritten_category_index_for_current_prompt,
  158. original_category_index,
  159. ) in enumerate(category_indices_to_include):
  160. category = formatter_configs.guidelines.categories[original_category_index]
  161. # Indices start at 0, but categories start at 1, so we add 1
  162. full_guidelines_text += f"\n{formatter_configs.guidelines.category_code_prefix}{rewritten_category_index_for_current_prompt + 1}: {category.name}. "
  163. if (
  164. formatter_configs.llama_guard_prompt_configs.should_include_category_descriptions
  165. ):
  166. full_guidelines_text += f"\n{category.description}"
  167. conversation = {"human": training_example.prompt}
  168. if not _is_a_prompt_only_example(training_example):
  169. conversation["chatbot"] = training_example.response
  170. return formatter_configs.llama_guard_prompt_configs.instructions_format_string.format_map(
  171. {
  172. "guidelines": full_guidelines_text,
  173. "conversation": _serialize_conversation(conversation),
  174. }
  175. )
  176. def _is_a_prompt_only_example(training_example: TrainingExample) -> bool:
  177. return training_example.response == "N/A"
  178. def _serialize_conversation(conversation: Dict[str, str]) -> str:
  179. conversation_as_list = []
  180. for speaker, message in conversation.items():
  181. conversation_as_list.append(f"{speaker}: {message}")
  182. return "\n\n".join(conversation_as_list)
  183. def _create_llama_guard_generation(
  184. training_example: TrainingExample,
  185. formatter_configs: FormatterConfigs,
  186. category_indices_included_in_llama_guard_prompt: List[int],
  187. ) -> str:
  188. to_return = training_example.label
  189. if (
  190. training_example.label == "unsafe"
  191. and formatter_configs.llama_guard_generation_configs.should_list_violated_codes
  192. ):
  193. violated_category_indices = set(
  194. _convert_category_codes_to_indices(
  195. training_example.violated_category_codes,
  196. formatter_configs,
  197. )
  198. )
  199. map_of_original_category_indices_to_rewritten_category_codes = (
  200. _get_map_of_original_category_indices_to_rewritten_category_codes(
  201. formatter_configs, category_indices_included_in_llama_guard_prompt
  202. )
  203. )
  204. rewritten_violated_category_codes = [
  205. map_of_original_category_indices_to_rewritten_category_codes[violated_index]
  206. for violated_index in violated_category_indices
  207. ]
  208. to_return += "\n"
  209. to_return += ",".join(rewritten_violated_category_codes)
  210. explanation_position = (
  211. formatter_configs.llama_guard_generation_configs.explanation_position
  212. )
  213. if explanation_position == ExplanationPosition.BEFORE_DECISION:
  214. to_return = f"Explanation: {training_example.explanation}\n{to_return}"
  215. elif explanation_position == ExplanationPosition.AFTER_DECISION:
  216. to_return = f"{to_return}\nExplanation: {training_example.explanation}"
  217. return to_return
  218. def _get_map_of_original_category_indices_to_rewritten_category_codes(
  219. formatter_configs: FormatterConfigs,
  220. category_indices_included_in_llama_guard_prompt: List[int],
  221. ) -> Dict[int, str]:
  222. to_return = {}
  223. for rewritten_category_index, original_category_index in enumerate(
  224. category_indices_included_in_llama_guard_prompt
  225. ):
  226. to_return[
  227. original_category_index
  228. ] = formatter_configs.guidelines.category_code_prefix + str(
  229. rewritten_category_index + 1
  230. )
  231. return to_return
  232. def _maybe_add_data_augmentations_for_example(
  233. training_example: TrainingExample,
  234. formatted_examples_being_built: list[str],
  235. indices_of_all_categories: range,
  236. formatter_configs: FormatterConfigs,
  237. ) -> None:
  238. _maybe_add_safe_example_with_empty_response(
  239. training_example,
  240. formatted_examples_being_built,
  241. indices_of_all_categories,
  242. formatter_configs,
  243. )
  244. _maybe_add_examples_with_dropped_prompt_categories(
  245. training_example,
  246. formatted_examples_being_built,
  247. indices_of_all_categories,
  248. formatter_configs,
  249. )
  250. def _maybe_add_safe_example_with_empty_response(
  251. training_example: TrainingExample,
  252. formatted_examples_being_built: list[str],
  253. indices_of_all_categories: range,
  254. formatter_configs: FormatterConfigs,
  255. ) -> None:
  256. """
  257. For any prompt+response pair, an empty response is a safe response,
  258. so we allow the data to be augmented by adding a safe example with the same
  259. prompt but an empty response.
  260. """
  261. if (
  262. not _is_a_prompt_only_example(training_example)
  263. and training_example.response != ""
  264. and random.random()
  265. < formatter_configs.augmentation_configs.probability_to_add_safe_examples_with_empty_responses
  266. ):
  267. training_example_copy = copy.deepcopy(training_example)
  268. training_example_copy.response = ""
  269. training_example_copy.label = "safe"
  270. training_example_copy.violated_category_codes = []
  271. training_example_copy.explanation = (
  272. formatter_configs.augmentation_configs.explanation_for_augmentation_with_safe_example_with_empty_response
  273. )
  274. formatted_examples_being_built.append(
  275. _create_formatted_finetuning_example(
  276. training_example_copy,
  277. formatter_configs,
  278. category_indeces_to_include_in_llama_guard_prompt=list(
  279. indices_of_all_categories
  280. ),
  281. )
  282. )
  283. def _maybe_add_examples_with_dropped_prompt_categories(
  284. training_example: TrainingExample,
  285. formatted_examples_being_built: list[str],
  286. indices_of_all_categories: range,
  287. formatter_configs: FormatterConfigs,
  288. ) -> None:
  289. violated_category_indices = _convert_category_codes_to_indices(
  290. training_example.violated_category_codes,
  291. formatter_configs,
  292. )
  293. nonviolated_category_indices = list(
  294. set(indices_of_all_categories) - set(violated_category_indices)
  295. )
  296. _maybe_add_example_with_dropped_nonviolated_prompt_categories(
  297. training_example,
  298. formatted_examples_being_built,
  299. indices_of_all_categories,
  300. nonviolated_category_indices,
  301. formatter_configs,
  302. )
  303. _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
  304. training_example,
  305. formatted_examples_being_built,
  306. indices_of_all_categories,
  307. violated_category_indices,
  308. nonviolated_category_indices,
  309. formatter_configs,
  310. )
  311. def _convert_category_codes_to_indices(
  312. codes: list[str], formatter_configs: FormatterConfigs
  313. ) -> list[int]:
  314. # Category codes start at 1, but indices start at 0, so we subtract 1
  315. return [
  316. int(code.lstrip(formatter_configs.guidelines.category_code_prefix)) - 1
  317. for code in codes
  318. ]
  319. def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
  320. training_example: TrainingExample,
  321. formatted_examples_being_built: list[str],
  322. indices_of_all_categories: range,
  323. nonviolated_category_indices: list[int],
  324. formatter_configs: FormatterConfigs,
  325. ) -> None:
  326. """
  327. If a prompt+response pair does not violate certain categories, we can augment
  328. the data by duplicating the training example but removing some of the non-violated
  329. categories from the llama guard prompt. This facilitates removing categories from
  330. the llama guard prompt at inference time without any additional finetuning.
  331. """
  332. if (
  333. not formatter_configs.augmentation_configs.should_add_examples_with_dropped_nonviolated_prompt_categories
  334. ):
  335. return
  336. number_of_categories_to_drop = random.randint(0, len(nonviolated_category_indices))
  337. if number_of_categories_to_drop == len(indices_of_all_categories):
  338. number_of_categories_to_drop -= 1
  339. dropped_category_indices = random.sample(
  340. nonviolated_category_indices, number_of_categories_to_drop
  341. )
  342. retained_category_indices = list(
  343. set(indices_of_all_categories) - (set(dropped_category_indices))
  344. )
  345. formatted_examples_being_built.append(
  346. _create_formatted_finetuning_example(
  347. training_example,
  348. formatter_configs,
  349. category_indeces_to_include_in_llama_guard_prompt=retained_category_indices,
  350. )
  351. )
  352. def _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
  353. training_example: TrainingExample,
  354. formatted_examples_being_built: list[str],
  355. indices_of_all_categories: range,
  356. violated_category_indices: list[int],
  357. nonviolated_category_indices: list[int],
  358. formatter_configs: FormatterConfigs,
  359. ) -> None:
  360. """
  361. Same as in _maybe_add_example_with_dropped_nonviolated_prompt_categories but we
  362. also drop all of the violated categories from the llama guard prompt.
  363. """
  364. if (
  365. training_example.label == "safe"
  366. or not formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories
  367. ):
  368. return
  369. random_nonviolated_category_indices_to_drop = random.sample(
  370. nonviolated_category_indices,
  371. random.randint(0, len(nonviolated_category_indices) - 1),
  372. )
  373. set_of_retained_category_indices = (
  374. set(indices_of_all_categories)
  375. - set(violated_category_indices)
  376. - set(random_nonviolated_category_indices_to_drop)
  377. )
  378. training_example_copy = copy.deepcopy(training_example)
  379. training_example_copy.label = "safe"
  380. training_example_copy.violated_category_codes = []
  381. training_example_copy.explanation = (
  382. formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories
  383. )
  384. formatted_examples_being_built.append(
  385. _create_formatted_finetuning_example(
  386. training_example_copy,
  387. formatter_configs,
  388. category_indeces_to_include_in_llama_guard_prompt=list(
  389. set_of_retained_category_indices
  390. ),
  391. )
  392. )