123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # This software may be used and distributed according to the terms of the Llama Guard License Agreement.
- import copy
- import random
- from dataclasses import dataclass
- from enum import Enum
- from typing import Dict, List, Literal, Optional, Sequence
@dataclass
class Category:
    """A single violation category taken from the consumer-provided guidelines."""

    # Short title; rendered right after the category code in the llama guard prompt.
    name: str
    # Longer text; rendered on its own line only when
    # should_include_category_descriptions is enabled.
    description: str
@dataclass
class Guidelines:
    """The categories to classify against, plus the prefix used to build category codes."""

    # A category's 0-based position in this sequence is its "original" index.
    categories: Sequence[Category]
    # A category code is this prefix followed by a 1-based number (e.g. "O1").
    category_code_prefix: str = "O"
class ExplanationPosition(Enum):
    """Where the "Explanation: ..." line goes relative to the safe/unsafe decision."""

    # The explanation line is emitted first, followed by the decision.
    BEFORE_DECISION = 0
    # The decision (and any violated codes) is emitted first, then the explanation.
    AFTER_DECISION = 1
@dataclass
class LlamaGuardPromptConfigs:
    """Options controlling how the llama guard prompt part of an example is built."""

    # Template filled in with str.format_map using the keys "guidelines" and
    # "conversation".
    instructions_format_string: str
    # When True, each category's description is added on the line after its name.
    should_include_category_descriptions: bool
    # When True, the category order in each prompt is randomly shuffled;
    # otherwise categories are listed in ascending original-index order.
    should_shuffle_category_codes: bool = True
@dataclass
class LlamaGuardGenerationConfigs:
    """Options controlling the expected generation part of an example."""

    # When True, "unsafe" examples list their violated category codes
    # (comma-separated) on the line after the label.
    should_list_violated_codes: bool
    # Placement of the explanation line; None means no explanation is emitted.
    explanation_position: Optional[ExplanationPosition]
@dataclass
class AugmentationConfigs:
    """Switches and explanation texts for the optional data augmentations."""

    # Compared against random.random() for each example that has a non-empty
    # response; when the draw is lower, a "safe" copy with an empty response is added.
    probability_to_add_safe_examples_with_empty_responses: float = 0
    # Explanation used for those empty-response copies. Config verification
    # requires it when the probability above is > 0 and explanations are emitted.
    explanation_for_augmentation_with_safe_example_with_empty_response: Optional[
        str
    ] = None
    # When True, each example is duplicated with a random subset of its
    # non-violated categories removed from the prompt.
    should_add_examples_with_dropped_nonviolated_prompt_categories: bool = True
    # When True, each non-"safe" example is duplicated as a "safe" example whose
    # prompt omits all violated categories and a random subset of the others.
    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories: bool = (
        False
    )
    # Explanation used for those relabelled copies. Config verification requires
    # it when the flag above is set and explanations are emitted.
    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories: Optional[
        str
    ] = None
@dataclass
class FormatterConfigs:
    """Top-level configuration bundle passed to create_formatted_finetuning_examples."""

    # Categories and code prefix used to build every prompt.
    guidelines: Guidelines
    # How the prompt half of each example is rendered.
    llama_guard_prompt_configs: LlamaGuardPromptConfigs
    # How the expected-generation half of each example is rendered.
    llama_guard_generation_configs: LlamaGuardGenerationConfigs
    # Which augmented examples are added alongside the originals.
    augmentation_configs: AugmentationConfigs
    # Allows subsequent reruns to reuse a stable seed for reproducibility
    random_seed: int = 42
@dataclass
class TrainingExample:
    """One consumer-provided example to be converted into finetuning text."""

    # The human turn of the conversation.
    prompt: str
    # The chatbot turn; the literal "N/A" marks a prompt-only example and the
    # chatbot turn is then left out of the conversation.
    response: str
    # Codes (category_code_prefix + 1-based number) of the violated categories,
    # numbered as in the consumer-provided guidelines.
    violated_category_codes: list[str]
    # The decision the model should produce for this example.
    label: Literal["safe", "unsafe"]
    # Text placed after "Explanation: " when an explanation position is configured.
    explanation: str
def create_formatted_finetuning_examples(
    training_examples: Sequence[TrainingExample],
    formatter_configs: FormatterConfigs,
) -> list[str]:
    """
    Convert consumer-provided training examples into llama-guard finetuning strings.

    Besides formatting every example against the full set of guideline categories,
    this can also augment the finetuning data set with transformed copies of the
    original examples. Those augmentations help the finetuned classifier cope with
    prompts that, at inference time, list only a subset of the categories it was
    trained on - without any additional finetuning.

    Some augmented copies are produced by duplicating an example and safely
    removing categories from its llama guard prompt. That is why this file
    distinguishes "original" category indices/codes (as they appear in the
    consumer-provided guidelines) from "rewritten" ones (as they appear in the
    prompt of a particular formatted example), and occasionally converts
    between the two.

    The configs are verified first, and the module-level random generator is
    seeded with formatter_configs.random_seed before any example is formatted.

    Returns the formatted examples in input order; each original example is
    immediately followed by whatever augmented examples were generated for it.
    """
    _verify_formatter_configs(formatter_configs)
    random.seed(formatter_configs.random_seed)

    all_category_indices = range(len(formatter_configs.guidelines.categories))
    formatted_examples = []

    for example in training_examples:
        # The un-augmented example always sees the complete category list.
        formatted_example = _create_formatted_finetuning_example(
            example,
            formatter_configs,
            category_indeces_to_include_in_llama_guard_prompt=list(
                all_category_indices
            ),
        )
        formatted_examples.append(formatted_example)

        # Augmented variants (if any) are appended directly after it.
        _maybe_add_data_augmentations_for_example(
            example, formatted_examples, all_category_indices, formatter_configs
        )

    return formatted_examples
- def _verify_formatter_configs(
- formatter_configs: FormatterConfigs,
- ) -> None:
- if (
- formatter_configs.augmentation_configs.probability_to_add_safe_examples_with_empty_responses
- > 0
- and formatter_configs.llama_guard_generation_configs.explanation_position
- is not None
- and formatter_configs.augmentation_configs.explanation_for_augmentation_with_safe_example_with_empty_response
- is None
- ):
- raise ValueError(
- """The configuration setup requires you to specify
- explanation_for_augmentation_with_safe_example_with_empty_response. This is an
- explanation that we use for dynamically-created safe augmentation examples.
- Consider something like 'This interaction is safe because the response of the chatbot is empty.'"""
- )
- if (
- formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories
- > 0
- and formatter_configs.llama_guard_generation_configs.explanation_position
- is not None
- and formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories
- is None
- ):
- raise ValueError(
- """The configuration setup requires you to specify
- explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories.
- This is an explanation that we use for dynamically-created safe augmentation examples.
- Consider something like 'This interaction is safe because any riskiness it contains
- is related to violation categories that we're explicitly not trying to detect here.'"""
- )
def _create_formatted_finetuning_example(
    training_example: TrainingExample,
    formatter_configs: FormatterConfigs,
    category_indeces_to_include_in_llama_guard_prompt: List[int],
) -> str:
    """
    Build one finetuning string: the llama guard prompt, a space, then the
    expected generation.

    The given original category indices are first put into their prompt order:
    randomly shuffled when should_shuffle_category_codes is set, ascending
    otherwise. The same ordering is used for both the prompt and the generation
    so that the rewritten category codes agree between the two.

    The list passed in by the caller is never modified.
    """
    if formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes:
        # random.shuffle works in place, so shuffle a private copy instead of
        # reordering the caller's list. The resulting order and the amount of
        # randomness consumed are the same as shuffling the argument directly.
        ordered_category_indices = list(
            category_indeces_to_include_in_llama_guard_prompt
        )
        random.shuffle(ordered_category_indices)
    else:
        ordered_category_indices = sorted(
            category_indeces_to_include_in_llama_guard_prompt
        )

    llama_guard_prompt = _create_llama_guard_prompt(
        training_example,
        ordered_category_indices,
        formatter_configs,
    )
    llama_guard_generation = _create_llama_guard_generation(
        training_example,
        formatter_configs,
        ordered_category_indices,
    )
    return f"{llama_guard_prompt} {llama_guard_generation}"
- def _is_a_prompt_only_example(training_example: TrainingExample) -> bool:
- return training_example.response == "N/A"
def _create_llama_guard_prompt(
    training_example: TrainingExample,
    category_indices_to_include: List[int],
    formatter_configs: FormatterConfigs,
) -> str:
    """
    Render the llama guard prompt for one example.

    category_indices_to_include holds original category indices in the order
    they should appear; each category is given a rewritten code from its
    position in that list. The rendered guidelines and the serialized
    conversation are substituted into instructions_format_string under the
    keys "guidelines" and "conversation".
    """
    guidelines = formatter_configs.guidelines
    prompt_configs = formatter_configs.llama_guard_prompt_configs

    guideline_parts = []
    # Codes are 1-based, hence start=1 over the 0-based positions.
    for code_number, original_index in enumerate(category_indices_to_include, start=1):
        category = guidelines.categories[original_index]
        guideline_parts.append(
            f"\n{guidelines.category_code_prefix}{code_number}: {category.name}. "
        )
        if prompt_configs.should_include_category_descriptions:
            guideline_parts.append(f"\n{category.description}")

    # Prompt-only examples have no chatbot turn to show.
    turns = {"human": training_example.prompt}
    if not _is_a_prompt_only_example(training_example):
        turns["chatbot"] = training_example.response

    substitutions = {
        "guidelines": "".join(guideline_parts),
        "conversation": _serialize_conversation(turns),
    }
    return prompt_configs.instructions_format_string.format_map(substitutions)
- def _serialize_conversation(conversation: Dict[str, str]) -> str:
- conversation_as_list = []
- for speaker, message in conversation.items():
- conversation_as_list.append(f"{speaker}: {message}")
- return "\n\n".join(conversation_as_list)
def _create_llama_guard_generation(
    training_example: TrainingExample,
    formatter_configs: FormatterConfigs,
    category_indices_included_in_llama_guard_prompt: List[int],
) -> str:
    """
    Render the text the model is expected to generate for one example.

    The output starts with the example's label. For "unsafe" examples, when
    should_list_violated_codes is set, a second line lists the violated
    categories using the codes they were given in this example's prompt.
    Depending on explanation_position, an "Explanation: ..." line is placed
    before or after that; with no position configured it is omitted.
    """
    generation_configs = formatter_configs.llama_guard_generation_configs
    decision = training_example.label

    if decision == "unsafe" and generation_configs.should_list_violated_codes:
        # De-duplicate the violated original indices; codes are emitted in
        # this set's iteration order.
        violated_original_indices = set(
            _convert_category_codes_to_indices(
                training_example.violated_category_codes,
                formatter_configs,
            )
        )
        rewritten_code_by_original_index = (
            _get_map_of_original_category_indices_to_rewritten_category_codes(
                formatter_configs, category_indices_included_in_llama_guard_prompt
            )
        )
        codes_line = ",".join(
            rewritten_code_by_original_index[original_index]
            for original_index in violated_original_indices
        )
        decision = f"{decision}\n{codes_line}"

    position = generation_configs.explanation_position
    if position == ExplanationPosition.BEFORE_DECISION:
        return f"Explanation: {training_example.explanation}\n{decision}"
    if position == ExplanationPosition.AFTER_DECISION:
        return f"{decision}\nExplanation: {training_example.explanation}"
    return decision
- def _get_map_of_original_category_indices_to_rewritten_category_codes(
- formatter_configs: FormatterConfigs,
- category_indices_included_in_llama_guard_prompt: List[int],
- ) -> Dict[int, str]:
- to_return = {}
- for rewritten_category_index, original_category_index in enumerate(
- category_indices_included_in_llama_guard_prompt
- ):
- to_return[
- original_category_index
- ] = formatter_configs.guidelines.category_code_prefix + str(
- rewritten_category_index + 1
- )
- return to_return
def _maybe_add_data_augmentations_for_example(
    training_example: TrainingExample,
    formatted_examples_being_built: list[str],
    indices_of_all_categories: range,
    formatter_configs: FormatterConfigs,
) -> None:
    """
    Run every augmentation for one training example.

    Any augmented examples that get generated are appended, as formatted
    strings, to formatted_examples_being_built. The empty-response augmentation
    runs first, then the dropped-category augmentations.
    """
    _maybe_add_safe_example_with_empty_response(
        training_example,
        formatted_examples_being_built,
        indices_of_all_categories,
        formatter_configs,
    )
    _maybe_add_examples_with_dropped_prompt_categories(
        training_example,
        formatted_examples_being_built,
        indices_of_all_categories,
        formatter_configs,
    )
def _maybe_add_safe_example_with_empty_response(
    training_example: TrainingExample,
    formatted_examples_being_built: list[str],
    indices_of_all_categories: range,
    formatter_configs: FormatterConfigs,
) -> None:
    """
    For any prompt+response pair, an empty response is a safe response,
    so we allow the data to be augmented by adding a safe example with the same
    prompt but an empty response.

    Prompt-only examples and examples whose response is already empty are
    skipped. For the remaining ones the augmentation is applied with
    probability probability_to_add_safe_examples_with_empty_responses, and the
    formatted result is appended to formatted_examples_being_built.
    """
    if (
        _is_a_prompt_only_example(training_example)
        or training_example.response == ""
    ):
        return

    # The random draw happens only for eligible examples, so ineligible ones
    # consume no randomness.
    draw_is_below_probability = (
        random.random()
        < formatter_configs.augmentation_configs.probability_to_add_safe_examples_with_empty_responses
    )
    if not draw_is_below_probability:
        return

    # Work on a copy so the caller's example keeps its response and label.
    safe_example = copy.deepcopy(training_example)
    safe_example.response = ""
    safe_example.label = "safe"
    safe_example.violated_category_codes = []
    safe_example.explanation = (
        formatter_configs.augmentation_configs.explanation_for_augmentation_with_safe_example_with_empty_response
    )

    formatted_examples_being_built.append(
        _create_formatted_finetuning_example(
            safe_example,
            formatter_configs,
            category_indeces_to_include_in_llama_guard_prompt=list(
                indices_of_all_categories
            ),
        )
    )
def _maybe_add_examples_with_dropped_prompt_categories(
    training_example: TrainingExample,
    formatted_examples_being_built: list[str],
    indices_of_all_categories: range,
    formatter_configs: FormatterConfigs,
) -> None:
    """
    Run both dropped-category augmentations for one training example.

    The example's violated category codes are converted to original indices,
    the remaining indices are treated as non-violated, and both groups are
    handed to the two augmentation helpers. Generated examples are appended,
    as formatted strings, to formatted_examples_being_built.
    """
    violated_category_indices = _convert_category_codes_to_indices(
        training_example.violated_category_codes,
        formatter_configs,
    )
    nonviolated_category_indices = list(
        set(indices_of_all_categories) - set(violated_category_indices)
    )

    _maybe_add_example_with_dropped_nonviolated_prompt_categories(
        training_example,
        formatted_examples_being_built,
        indices_of_all_categories,
        nonviolated_category_indices,
        formatter_configs,
    )
    _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
        training_example,
        formatted_examples_being_built,
        indices_of_all_categories,
        violated_category_indices,
        nonviolated_category_indices,
        formatter_configs,
    )
- def _convert_category_codes_to_indices(
- codes: list[str], formatter_configs: FormatterConfigs
- ) -> list[int]:
- # Category codes start at 1, but indices start at 0, so we subtract 1
- return [
- int(code.lstrip(formatter_configs.guidelines.category_code_prefix)) - 1
- for code in codes
- ]
- def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
- training_example: TrainingExample,
- formatted_examples_being_built: list[dict[str, str]],
- indices_of_all_categories: range,
- nonviolated_category_indices: list[int],
- formatter_configs: FormatterConfigs,
- ) -> None:
- """
- If a prompt+response pair does not violate certain categories, we can augment
- the data by duplicating the training example but removing some of the non-violated
- categories from the llama guard prompt. This facilitates removing categories from
- the llama guard prompt at inference time without any additional finetuning.
- """
- if (
- not formatter_configs.augmentation_configs.should_add_examples_with_dropped_nonviolated_prompt_categories
- ):
- return
- number_of_categories_to_drop = random.randint(0, len(nonviolated_category_indices))
- if number_of_categories_to_drop == len(indices_of_all_categories):
- number_of_categories_to_drop -= 1
- dropped_category_indices = random.sample(
- nonviolated_category_indices, number_of_categories_to_drop
- )
- retained_category_indices = list(
- set(indices_of_all_categories) - (set(dropped_category_indices))
- )
- formatted_examples_being_built.append(
- _create_formatted_finetuning_example(
- training_example,
- formatter_configs,
- category_indeces_to_include_in_llama_guard_prompt=retained_category_indices,
- )
- )
- def _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
- training_example: TrainingExample,
- formatted_examples_being_built: list[dict[str, str]],
- indices_of_all_categories: range,
- violated_category_indices: list[int],
- nonviolated_category_indices: list[int],
- formatter_configs: FormatterConfigs,
- ) -> None:
- """
- Same as in _maybe_add_example_with_dropped_nonviolated_prompt_categories but we
- also drop all of the violated categories from the llama guard prompt.
- """
- if (
- training_example.label == "safe"
- or not formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories
- ):
- return
- random_nonviolated_category_indices_to_drop = random.sample(
- nonviolated_category_indices,
- random.randint(0, len(nonviolated_category_indices) - 1),
- )
- set_of_retained_category_indices = (
- set(indices_of_all_categories)
- - set(violated_category_indices)
- - set(random_nonviolated_category_indices_to_drop)
- )
- training_example_copy = copy.deepcopy(training_example)
- training_example_copy.label = "safe"
- training_example_copy.violated_category_codes = []
- training_example_copy.explanation = (
- formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories
- )
- formatted_examples_being_built.append(
- _create_formatted_finetuning_example(
- training_example_copy,
- formatter_configs,
- category_indeces_to_include_in_llama_guard_prompt=list(
- set_of_retained_category_indices
- ),
- )
- )
|