|
@@ -1,6 +1,16 @@
|
|
from finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples
|
|
from finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples
|
|
|
|
|
|
-# Define a list of training examples
|
|
|
|
|
|
+# Define a list of guidelines and training examples to match
|
|
|
|
+guidelines = Guidelines(
|
|
|
|
+ categories=[
|
|
|
|
+ Category(name="Personal Information",
|
|
|
|
+ description="Information that can identify an individual"),
|
|
|
|
+ Category(name="Location",
|
|
|
|
+ description="Information about where someone lives or works")
|
|
|
|
+ ],
|
|
|
|
+ category_code_prefix="O"
|
|
|
|
+)
|
|
|
|
+
|
|
training_examples = [
|
|
training_examples = [
|
|
TrainingExample(
|
|
TrainingExample(
|
|
prompt="What is the home address of John Smith?",
|
|
prompt="What is the home address of John Smith?",
|
|
@@ -31,16 +41,8 @@ training_examples = [
|
|
explanation="The response does not contain any sensitive information."
|
|
explanation="The response does not contain any sensitive information."
|
|
)
|
|
)
|
|
]
|
|
]
|
|
-# Define formatter configs
|
|
|
|
-guidelines = Guidelines(
|
|
|
|
- categories=[
|
|
|
|
- Category(name="Personal Information",
|
|
|
|
- description="Information that can identify an individual"),
|
|
|
|
- Category(name="Location",
|
|
|
|
- description="Information about where someone lives or works")
|
|
|
|
- ],
|
|
|
|
- category_code_prefix="O"
|
|
|
|
-)
|
|
|
|
|
|
+
|
|
|
|
+# Setting up configs
|
|
llama_guard_prompt_configs = LlamaGuardPromptConfigs(
|
|
llama_guard_prompt_configs = LlamaGuardPromptConfigs(
|
|
instructions_format_string="""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
|
|
instructions_format_string="""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
|
|
|
|
|
|
@@ -58,17 +60,20 @@ Provide your safety assessment for $agent_type in the above conversation:
|
|
- First line must read 'safe' or 'unsafe'.
|
|
- First line must read 'safe' or 'unsafe'.
|
|
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
|
|
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
|
|
should_include_category_descriptions=True,
|
|
should_include_category_descriptions=True,
|
|
- should_shuffle_category_codes=False
|
|
|
|
|
|
+ should_shuffle_category_codes=True
|
|
)
|
|
)
|
|
|
|
+
|
|
llama_guard_generation_configs = LlamaGuardGenerationConfigs(
|
|
llama_guard_generation_configs = LlamaGuardGenerationConfigs(
|
|
should_list_violated_codes=True,
|
|
should_list_violated_codes=True,
|
|
explanation_position=ExplanationPosition.AFTER_DECISION
|
|
explanation_position=ExplanationPosition.AFTER_DECISION
|
|
)
|
|
)
|
|
|
|
+
|
|
augmentation_configs = AugmentationConfigs(
|
|
augmentation_configs = AugmentationConfigs(
|
|
should_add_examples_with_dropped_nonviolated_prompt_categories=True,
|
|
should_add_examples_with_dropped_nonviolated_prompt_categories=True,
|
|
should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
|
|
should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
|
|
explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
|
|
explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
|
|
)
|
|
)
|
|
|
|
+
|
|
formatter_configs = FormatterConfigs(
|
|
formatter_configs = FormatterConfigs(
|
|
guidelines=guidelines,
|
|
guidelines=guidelines,
|
|
llama_guard_prompt_configs=llama_guard_prompt_configs,
|
|
llama_guard_prompt_configs=llama_guard_prompt_configs,
|
|
@@ -76,8 +81,10 @@ formatter_configs = FormatterConfigs(
|
|
augmentation_configs=augmentation_configs,
|
|
augmentation_configs=augmentation_configs,
|
|
random_seed=42
|
|
random_seed=42
|
|
)
|
|
)
|
|
|
|
+
|
|
# Call the create_formatted_finetuning_examples function
|
|
# Call the create_formatted_finetuning_examples function
|
|
formatted_examples = create_formatted_finetuning_examples(
|
|
formatted_examples = create_formatted_finetuning_examples(
|
|
training_examples, formatter_configs)
|
|
training_examples, formatter_configs)
|
|
|
|
+
|
|
# Print the formatted examples
|
|
# Print the formatted examples
|
|
print(formatted_examples)
|
|
print(formatted_examples)
|