Переглянути джерело

Format the code for readability (also enables should_shuffle_category_codes)

Beto 1 рік тому
батько
коміт
8d02dae917

+ 19 - 12
src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py

@@ -1,6 +1,16 @@
 from finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples
 
-# Define a list of training examples
+# Define the guidelines and the list of training examples that go with them
+guidelines = Guidelines(
+    categories=[
+        Category(name="Personal Information",
+                 description="Information that can identify an individual"),
+        Category(name="Location",
+                 description="Information about where someone lives or works")
+    ],
+    category_code_prefix="O"
+)
+
 training_examples = [
     TrainingExample(
         prompt="What is the home address of John Smith?",
@@ -31,16 +41,8 @@ training_examples = [
         explanation="The response does not contain any sensitive information."
     )
 ]
-# Define formatter configs
-guidelines = Guidelines(
-    categories=[
-        Category(name="Personal Information",
-                 description="Information that can identify an individual"),
-        Category(name="Location",
-                 description="Information about where someone lives or works")
-    ],
-    category_code_prefix="O"
-)
+
+# Setting up configs
 llama_guard_prompt_configs = LlamaGuardPromptConfigs(
     instructions_format_string="""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
 
@@ -58,17 +60,20 @@ Provide your safety assessment for $agent_type in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
     should_include_category_descriptions=True,
-    should_shuffle_category_codes=False
+    should_shuffle_category_codes=True
 )
+
 llama_guard_generation_configs = LlamaGuardGenerationConfigs(
     should_list_violated_codes=True,
     explanation_position=ExplanationPosition.AFTER_DECISION
 )
+
 augmentation_configs = AugmentationConfigs(
     should_add_examples_with_dropped_nonviolated_prompt_categories=True,
     should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
     explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
 )
+
 formatter_configs = FormatterConfigs(
     guidelines=guidelines,
     llama_guard_prompt_configs=llama_guard_prompt_configs,
@@ -76,8 +81,10 @@ formatter_configs = FormatterConfigs(
     augmentation_configs=augmentation_configs,
     random_seed=42
 )
+
 # Call the create_formatted_finetuning_examples function
 formatted_examples = create_formatted_finetuning_examples(
     training_examples, formatter_configs)
+
 # Print the formatted examples
 print(formatted_examples)