Browse Source

Llama guard data formatter example (#337)

Sample script to format data in the expected format by Llama Guard.
albertodepaola 1 year ago
parent
commit
aaa769c91b

+ 1 - 1
README.md

@@ -1,6 +1,6 @@
 # Llama 2 Fine-tuning / Inference Recipes, Examples and Demo Apps
 
-**[Update Dec. 15, 2023] We added support for Llama Guard as a safety checker for our example inference script and also with standalone inference with an example script and prompt formatting. More details [here](./examples/llama_guard/README.md).**
+**[Update Dec. 28, 2023] We added support for Llama Guard as a safety checker for our example inference script, as well as standalone Llama Guard inference with an example script and prompt formatting. More details [here](./examples/llama_guard/README.md). For details on formatting data for fine-tuning Llama Guard, we provide a script and sample usage [here](./src/llama_recipes/data/llama_guard/README.md).**
 
 **[Update Dec 14, 2023] We recently released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama (locally, in the cloud, or on-prem),  how to use Azure Llama 2 API (Model-as-a-Service), how to ask Llama questions in general or about custom data (PDF, DB, or live), how to integrate Llama with WhatsApp and Messenger, and how to implement an end-to-end chatbot with RAG (Retrieval Augmented Generation).**
 

+ 21 - 2
examples/llama_guard/README.md

@@ -6,7 +6,7 @@ This folder contains an example file to run Llama Guard inference directly.
 
 ## Requirements
 1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
-2. Llama recipes dependencies installed 
+2. Llama recipes package and its dependencies [installed](https://github.com/albertodepaola/llama-recipes/blob/llama-guard-data-formatter-example/README.md#installation)
 3. A GPU with at least 21 GB of free RAM to load both 7B models quantized.
 
 ## Llama Guard inference script
@@ -34,8 +34,27 @@ To run the samples, with all the dependencies installed, execute this command:
 
 `python examples/llama_guard/inference.py`
 
+This is the output:
+
+```
+['<Sample user prompt>']
+> safe
+
+==================================
+
+['<Sample user prompt>', '<Sample agent response>']
+> safe
+
+==================================
+
+['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']
+> safe
+
+==================================
+```
+
 ## Inference Safety Checker
-When running the regular inference script with prompts, Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be show, else a message with the error will be show, with the word unsafe and a comma separated list of categories infringed. Llama Guard is always loaded quantized using Hugging Face Transformers library.
+When running the regular inference script with prompts, Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown; otherwise, an error message will be shown containing the word unsafe followed by a comma-separated list of the violated categories. Llama Guard is always loaded quantized using the Hugging Face Transformers library.
 
 In this case, the default categories are applied by the tokenizer, using the `apply_chat_template` method.
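+
+As a rough sketch, this is how the default categories get applied when calling Llama Guard directly through the Transformers library and building the prompt with `apply_chat_template` (the model id and generation settings below are illustrative, not the exact code used by the safety checker):
+
+```
+# Sketch only: adjust the model id, dtype and generation settings to your setup.
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "meta-llama/LlamaGuard-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+
+# apply_chat_template builds the Llama Guard prompt with the default categories.
+chat = [{"role": "user", "content": "<Sample user prompt>"}]
+input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt").to(model.device)
+
+output = model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
+# Decode only the generated assessment ('safe', or 'unsafe' plus violated categories).
+print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
+```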
 

+ 6 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1216,4 +1216,9 @@ Anyscale
 ADDR
 ckpt
 HuggingFace
-llamaguard
+llamaguard
+AugmentationConfigs
+FormatterConfigs
+LlamaGuardGenerationConfigs
+LlamaGuardPromptConfigs
+TrainingExample

+ 119 - 0
src/llama_recipes/data/llama_guard/README.md

@@ -0,0 +1,119 @@
+# Finetuning Data Formatter
+
+The `finetuning_data_formatter` script provides classes and functions for formatting training data for finetuning Llama Guard with a specific set of categories. The main classes are:
+* `TrainingExample`: Represents a single example in the training data, consisting of a prompt, response, label (safe or unsafe), violated category codes, and an explanation.
+* `Guidelines`: Defines the categories and their descriptions that will be used to evaluate the safety of the responses.
+* `LlamaGuardPromptConfigs`: Configures how the prompt that will be given to Llama Guard during finetuning should be formatted.
+* `LlamaGuardGenerationConfigs`: Configures how Llama Guard's response should be formatted.
+* `AugmentationConfigs`: Configures how additional examples will be generated from the original training examples to augment the training data.
+* `FormatterConfigs`: Combines all of the above configs into a single object that can be passed to the `create_formatted_finetuning_examples` function.
+
+## Running the script
+
+1. Clone the llama-recipes repo
+2. Install the dependencies
+3. Run the script with the following command: `python src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py > sample.json`
+
+## Code overview
+To use the `finetuning_data_formatter`, you first need to define your training examples as instances of the `TrainingExample` class. For example:
+
+```
+training_examples = [
+    TrainingExample(
+        prompt="Can you give me the phone number of Jane Doe?",
+        response="Jane Doe's phone number is 555-555-5555.",
+        violated_category_codes=["O1"],
+        label="unsafe",
+        explanation="The response contains personal information."
+    ),
+    # Add more training examples here...
+]
+```
+You can also define examples that cover the user prompt only, to train the model on classifying prompts without an agent response; in that case, set `response` to `"N/A"`. For example:
+```
+TrainingExample(
+    prompt="What is the home address of John Smith?",
+    response="N/A",
+    violated_category_codes=["O2"],
+    label="unsafe",
+    explanation="The prompt asks for a home address."
+)
+```
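+Note that the snippets in this overview omit imports for brevity. Assuming the llama-recipes package is installed (see "Running the script" above), the classes and the formatting function can be imported as follows:
+```
+# Assumes the llama-recipes package is installed; adjust the import path if
+# you are running directly from the source tree instead.
+from llama_recipes.data.llama_guard.finetuning_data_formatter import (
+    AugmentationConfigs,
+    Category,
+    ExplanationPosition,
+    FormatterConfigs,
+    Guidelines,
+    LlamaGuardGenerationConfigs,
+    LlamaGuardPromptConfigs,
+    TrainingExample,
+    create_formatted_finetuning_examples,
+)
+```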
+
+Next, you need to define the guidelines that will be used to evaluate the safety of the responses. You do this by creating an instance of the `Guidelines` class and defining the categories and their descriptions. Each category's code is the `category_code_prefix` followed by its 1-based position in the list, so in the example below "Personal Information" is `O1` and "Location" is `O2`, matching the `violated_category_codes` used in the training examples above. For example:
+
+```
+guidelines = Guidelines(
+    categories=[
+        Category(name="Personal Information",
+                 description="Information that can identify an individual"),
+        Category(name="Location",
+                 description="Information about where someone lives or works")
+    ],
+    category_code_prefix="O"
+)
+```
+
+Then, you need to configure the prompt that will be given to Llama Guard during finetuning. You do this by creating an instance of the `LlamaGuardPromptConfigs` class and specifying the format string and other options. The format string must contain the `{guidelines}` and `{conversation}` placeholders, which the formatter fills in with the rendered category list and the example conversation.
+
+**Note**: For best performance, the prompt structure used for finetuning should match the one used at inference time.
+
+For example:
+
+```
+llama_guard_prompt_configs = LlamaGuardPromptConfigs(
+    instructions_format_string="""Task: Using the following guidelines, decide whether the example conversation that comes after is safe or unsafe.
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{guidelines}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+{conversation}
+
+<END CONVERSATION>
+
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories.  """,
+    should_include_category_descriptions=True,
+    should_shuffle_category_codes=True
+)
+```
+You also need to configure how Llama Guard's response will be generated. You do this by creating an instance of the `LlamaGuardGenerationConfigs` class and specifying the options; `explanation_position` controls where each example's explanation is placed relative to the safe/unsafe decision (here, after it). For example:
+
+```
+llama_guard_generation_configs = LlamaGuardGenerationConfigs(
+    should_list_violated_codes=True,
+    explanation_position=ExplanationPosition.AFTER_DECISION
+)
+```
+The script also provides data augmentation capabilities, configured by creating an instance of the `AugmentationConfigs` class. The augmentations add extra copies of each example with non-violated categories dropped from the prompt, and copies with the violated categories dropped as well, which are labeled safe using the explanation supplied below. For example:
+
+```
+augmentation_configs = AugmentationConfigs(
+    should_add_examples_with_dropped_nonviolated_prompt_categories=True,
+    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
+    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
+)
+```
+
+Finally, you can combine all of these configs into a single `FormatterConfigs` object and pass it to the `create_formatted_finetuning_examples` function to generate the formatted training data. The function returns a list of formatted strings, one per training example plus any augmented variants. For example:
+
+```
+formatter_configs = FormatterConfigs(
+    guidelines=guidelines,
+    llama_guard_prompt_configs=llama_guard_prompt_configs,
+    llama_guard_generation_configs=llama_guard_generation_configs,
+    augmentation_configs=augmentation_configs,
+    random_seed=42
+)
+
+# Call the create_formatted_finetuning_examples function
+formatted_examples = create_formatted_finetuning_examples(
+    training_examples, formatter_configs)
+# Print the formatted examples
+print(formatted_examples)
+
+```
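+
+If you want to write the formatted examples to a file instead of printing them (the example script simply prints the returned list), one option is to dump them as JSON:
+
+```
+import json
+
+# Illustrative only: persist the formatted strings so they can be inspected or
+# converted into whatever format your finetuning pipeline expects.
+with open("formatted_examples.json", "w") as f:
+    json.dump(formatted_examples, f, indent=2)
+```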

+ 10 - 10
src/llama_recipes/data/llama_guard/finetuning_data_formatter.py

@@ -63,7 +63,7 @@ class FormatterConfigs:
 class TrainingExample:
     prompt: str
     response: str
-    violated_category_codes: list[str]
+    violated_category_codes: List[str]
     label: Literal["safe", "unsafe"]
     explanation: Optional[str] = None
 
@@ -71,7 +71,7 @@ class TrainingExample:
 def create_formatted_finetuning_examples(
     training_examples: Sequence[TrainingExample],
     formatter_configs: FormatterConfigs,
-) -> list[str]:
+) -> List[str]:
     """
     This formatter takes consumer-provided training examples and converts them to
     the right format for finetuning llama-guard.
@@ -285,7 +285,7 @@ def _get_map_of_original_category_indices_to_rewritten_category_codes(
 
 def _maybe_add_data_augmentations_for_example(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
     formatter_configs: FormatterConfigs,
 ) -> None:
@@ -317,8 +317,8 @@ def _maybe_add_data_augmentations_for_example(
 
 
 def _convert_category_codes_to_indices(
-    codes: list[str], formatter_configs: FormatterConfigs
-) -> list[int]:
+    codes: List[str], formatter_configs: FormatterConfigs
+) -> List[int]:
     # Category codes start at 1, but indices start at 0, so we subtract 1
     return [
         int(code.lstrip(formatter_configs.guidelines.category_code_prefix)) - 1
@@ -328,9 +328,9 @@ def _convert_category_codes_to_indices(
 
 def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
-    nonviolated_category_indices: list[int],
+    nonviolated_category_indices: List[int],
     formatter_configs: FormatterConfigs,
 ) -> None:
     """
@@ -368,10 +368,10 @@ def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
 
 def _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
-    violated_category_indices: list[int],
-    nonviolated_category_indices: list[int],
+    violated_category_indices: List[int],
+    nonviolated_category_indices: List[int],
     formatter_configs: FormatterConfigs,
 ) -> None:
     """

+ 90 - 0
src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py

@@ -0,0 +1,90 @@
+from finetuning_data_formatter import (
+    AugmentationConfigs,
+    Category,
+    ExplanationPosition,
+    FormatterConfigs,
+    Guidelines,
+    LlamaGuardGenerationConfigs,
+    LlamaGuardPromptConfigs,
+    TrainingExample,
+    create_formatted_finetuning_examples,
+)
+
+# Define a list of guidelines and training examples to match
+guidelines = Guidelines(
+    categories=[
+        Category(name="Personal Information",
+                 description="Information that can identify an individual"),
+        Category(name="Location",
+                 description="Information about where someone lives or works")
+    ],
+    category_code_prefix="O"
+)
+
+training_examples = [
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="Jane Doe's phone number is 555-555-5555.",
+        violated_category_codes=["O1"],
+        label="unsafe",
+        explanation="The response contains personal information."
+    ),
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="The home address of John Smith is 123 Main St, Anytown USA",
+        violated_category_codes=["O2"],
+        label="unsafe",
+        explanation="The response reveals the user's location."
+    ),
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="N/A",
+        violated_category_codes=["O2"],
+        label="unsafe",
+        explanation="The prompt asks for a home address."
+    ),
+    TrainingExample(
+        prompt="What is your favorite color?",
+        response="My favorite color is blue.",
+        violated_category_codes=[],
+        label="safe",
+        explanation="The response does not contain any sensitive information."
+    )
+]
+
+# Setting up configs
+llama_guard_prompt_configs = LlamaGuardPromptConfigs(
+    instructions_format_string="""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{guidelines}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+{conversation}
+
+<END CONVERSATION>
+
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
+    should_include_category_descriptions=True,
+    should_shuffle_category_codes=True
+)
+
+llama_guard_generation_configs = LlamaGuardGenerationConfigs(
+    should_list_violated_codes=True,
+    explanation_position=ExplanationPosition.AFTER_DECISION
+)
+
+augmentation_configs = AugmentationConfigs(
+    should_add_examples_with_dropped_nonviolated_prompt_categories=True,
+    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
+    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
+)
+
+formatter_configs = FormatterConfigs(
+    guidelines=guidelines,
+    llama_guard_prompt_configs=llama_guard_prompt_configs,
+    llama_guard_generation_configs=llama_guard_generation_configs,
+    augmentation_configs=augmentation_configs,
+    random_seed=42
+)
+
+# Call the create_formatted_finetuning_examples function
+formatted_examples = create_formatted_finetuning_examples(
+    training_examples, formatter_configs)
+
+# Print the formatted examples
+print(formatted_examples)

+ 3 - 3
tests/test_finetuning_data_formatter.py

@@ -3,7 +3,7 @@
 
 from enum import Enum
 import unittest
-from typing import Optional
+from typing import Optional, List
 
 from llama_recipes.data.llama_guard.finetuning_data_formatter import (
     AugmentationConfigs,
@@ -72,8 +72,8 @@ Provide your safety assessment for {agent_type_to_check} in the above conversati
 
     def format_examples_and_compare_to_expected(
         self,
-        training_examples: list[TrainingExample],
-        expected_formatted_examples: list[str],
+        training_examples: List[TrainingExample],
+        expected_formatted_examples: List[str],
         agent_type_to_check: AgentType,
         formatter_configs: Optional[FormatterConfigs] = None,
     ) -> None: