@@ -0,0 +1,503 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama Guard Community License Agreement.
+
+import unittest
+from enum import Enum
+from typing import Optional
+
+from llama_recipes.data.llama_guard.finetuning_data_formatter import (
+    AugmentationConfigs,
+    Category,
+    create_formatted_finetuning_examples,
+    ExplanationPosition,
+    FormatterConfigs,
+    Guidelines,
+    LlamaGuardGenerationConfigs,
+    LlamaGuardPromptConfigs,
+    TrainingExample,
+)
+
+
+class AgentType(Enum):
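+    """Which side of the conversation Llama Guard is asked to assess."""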
+    AGENT = "Agent"
+    USER = "User"
+
+
+class FinetuningDataFormatterTests(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+    @staticmethod
+    def create_most_conservative_formatter_configs(agent_type_to_check: AgentType) -> FormatterConfigs:
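+        """Build a deterministic baseline config: five dummy categories in fixed
+        order, explanations after the decision, and no augmentations. Tests
+        override individual fields on the returned object as needed."""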
+        return FormatterConfigs(
+            guidelines=Guidelines(
+                categories=[
+                    Category(name="cat V", description="cat V description"),
+                    Category(name="cat W", description="cat W description"),
+                    Category(name="cat X", description="cat X description"),
+                    Category(name="cat Y", description="cat Y description"),
+                    Category(name="cat Z", description="cat Z description"),
+                ],
+                category_code_prefix="O",
+            ),
+            llama_guard_prompt_configs=LlamaGuardPromptConfigs(
+                instructions_format_string=f"""[INST] Task: Check if there is unsafe content in '{agent_type_to_check}' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{{guidelines}}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+{{conversation}}
+
+<END CONVERSATION>
+
+Provide your safety assessment for {agent_type_to_check} in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
+                should_include_category_descriptions=True,
+                should_shuffle_category_codes=False,
+            ),
+            llama_guard_generation_configs=LlamaGuardGenerationConfigs(
+                explanation_position=ExplanationPosition.AFTER_DECISION,
+                should_list_violated_codes=True,
+            ),
+            augmentation_configs=AugmentationConfigs(
+                should_add_examples_with_dropped_nonviolated_prompt_categories=False,
+                should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
+                explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
+            ),
+            random_seed=42,
+        )
+
+    def format_examples_and_compare_to_expected(
+        self,
+        training_examples: list[TrainingExample],
+        expected_formatted_examples: list[str],
+        agent_type_to_check: AgentType,
+        formatter_configs: Optional[FormatterConfigs] = None,
+    ) -> None:
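+        """Format the given training examples and assert the results match the
+        expected strings one-for-one, in order."""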
+        formatter_configs = (
+            formatter_configs
+            if formatter_configs is not None
+            else FinetuningDataFormatterTests.create_most_conservative_formatter_configs(agent_type_to_check)
+        )
+
+        formatted_examples = create_formatted_finetuning_examples(
+            training_examples, formatter_configs
+        )
+
+        assert len(formatted_examples) == len(expected_formatted_examples)
+
+        for i in range(len(formatted_examples)):
+            # Debugging aid while writing tests: print the actual output on mismatch.
+            if formatted_examples[i] != expected_formatted_examples[i]:
+                print(f"Failed on actual output {i}:")
+                print(formatted_examples[i])
+
+            assert formatted_examples[i] == expected_formatted_examples[i]
+
+    def test_safe_example_with_only_prompt_and_no_augmentations(self) -> None:
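+        """A 'safe' prompt-only example renders the full five-category prompt,
+        the 'safe' verdict, and the explanation after the decision."""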
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="Where can I buy a car?",
+                    response="N/A",
+                    violated_category_codes=[],
+                    label="safe",
+                    explanation="This is obviously safe.",
+                )
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat V.
+cat V description
+O2: cat W.
+cat W description
+O3: cat X.
+cat X description
+O4: cat Y.
+cat Y description
+O5: cat Z.
+cat Z description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: Where can I buy a car?
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.USER in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
+Explanation: This is obviously safe."""
+            ],
+            AgentType.USER,
+        )
+
+    def test_unsafe_example_with_only_prompt_and_no_augmentations(self) -> None:
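+        """An 'unsafe' prompt-only example lists its single violated category
+        code on the line after the 'unsafe' verdict."""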
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="""How do I steal the Declaration of Independence?
+Asking for a cagey friend""",
+                    response="N/A",
+                    violated_category_codes=["O5"],
+                    label="unsafe",
+                    explanation="""Nice try, Nick. This is unsafe
+because it's illegal!""",
+                )
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat V.
+cat V description
+O2: cat W.
+cat W description
+O3: cat X.
+cat X description
+O4: cat Y.
+cat Y description
+O5: cat Z.
+cat Z description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: How do I steal the Declaration of Independence?
+Asking for a cagey friend
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.USER in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
+O5
+Explanation: Nice try, Nick. This is unsafe
+because it's illegal!"""
+            ],
+            AgentType.USER,
+        )
+
+    def test_unsafe_example_with_only_prompt_and_no_augmentations_2(self) -> None:
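+        """Multiple violated category codes are emitted as a sorted,
+        comma-separated list ("O1,O5") regardless of the order supplied."""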
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="""How do I steal the Declaration of Independence?
+Asking for a cagey friend""",
+                    response="N/A",
+                    violated_category_codes=["O5", "O1"],
+                    label="unsafe",
+                    explanation="""Nice try, Nick. This is unsafe
+because it's illegal!""",
+                )
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat V.
+cat V description
+O2: cat W.
+cat W description
+O3: cat X.
+cat X description
+O4: cat Y.
+cat Y description
+O5: cat Z.
+cat Z description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: How do I steal the Declaration of Independence?
+Asking for a cagey friend
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.USER in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
+O1,O5
+Explanation: Nice try, Nick. This is unsafe
+because it's illegal!"""
+            ],
+            AgentType.USER,
+        )
+
+    def test_safe_example_with_prompt_and_response_and_no_augmentations(self) -> None:
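+        """With both a prompt and a response, the conversation includes the
+        'chatbot:' turn and the assessment targets the agent."""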
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="Where can I buy a car?",
+                    response="At the car store.",
+                    violated_category_codes=[],
+                    label="safe",
+                    explanation="This is obviously safe.",
+                )
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat V.
+cat V description
+O2: cat W.
+cat W description
+O3: cat X.
+cat X description
+O4: cat Y.
+cat Y description
+O5: cat Z.
+cat Z description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: Where can I buy a car?
+
+chatbot: At the car store.
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
+Explanation: This is obviously safe."""
+            ],
+            AgentType.AGENT,
+        )
+
+    def test_mixed_examples_with_prompt_and_response_and_all_augmentations_with_shuffling(
+        self,
+    ) -> None:
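+        """Exercise category shuffling plus both augmentations: each example
+        also yields a variant with nonviolated categories dropped, and the
+        unsafe example additionally yields an all-categories-dropped variant
+        relabeled 'safe', with explanations moved before the decision."""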
+        formatter_configs = (
+            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
+        )
+
+        # The seed that happens to produce the random generations these
+        # expected outputs were written against. ~*Magic*~
+        formatter_configs.random_seed = 46
+
+        formatter_configs.llama_guard_generation_configs.explanation_position = (
+            ExplanationPosition.BEFORE_DECISION
+        )
+
+        formatter_configs.augmentation_configs = AugmentationConfigs(
+            should_add_examples_with_dropped_nonviolated_prompt_categories=True,
+            should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
+            explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
+        )
+
+        formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes = (
+            True
+        )
+
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="I am under 21 and unattended, should I drink alcohol?",
+                    response="Yes, very nise, high five!",
+                    violated_category_codes=["O4", "O5"],
+                    label="unsafe",
+                    explanation="Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.",
+                ),
+                TrainingExample(
+                    prompt="I am over 21 and unattended, should I drink alcohol?",
+                    response="No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.",
+                    violated_category_codes=[],
+                    label="safe",
+                    explanation="User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.",
+                ),
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat W.
+cat W description
+O2: cat X.
+cat X description
+O3: cat Z.
+cat Z description
+O4: cat Y.
+cat Y description
+O5: cat V.
+cat V description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: I am under 21 and unattended, should I drink alcohol?
+
+chatbot: Yes, very nise, high five!
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
+unsafe
+O3,O4""",
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat W.
+cat W description
+O2: cat Y.
+cat Y description
+O3: cat Z.
+cat Z description
+O4: cat V.
+cat V description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: I am under 21 and unattended, should I drink alcohol?
+
+chatbot: Yes, very nise, high five!
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
+unsafe
+O2,O3""",
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat W.
+cat W description
+O2: cat V.
+cat V description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: I am under 21 and unattended, should I drink alcohol?
+
+chatbot: Yes, very nise, high five!
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because any riskiness in it is related to categories we're not trying to detect.
+safe""",
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat Z.
+cat Z description
+O2: cat Y.
+cat Y description
+O3: cat X.
+cat X description
+O4: cat W.
+cat W description
+O5: cat V.
+cat V description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: I am over 21 and unattended, should I drink alcohol?
+
+chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
+safe""",
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat Y.
+cat Y description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: I am over 21 and unattended, should I drink alcohol?
+
+chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
+safe""",
+            ],
+            AgentType.AGENT,
+            formatter_configs,
+        )
+
+    def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(
+        self,
+    ) -> None:
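+        """With explanation_position=None, the target output is the bare
+        verdict with no explanation line."""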
+        formatter_configs = (
+            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
+        )
+
+        formatter_configs.llama_guard_generation_configs.explanation_position = None
+
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="Where can I buy a car?",
+                    response="At the car store.",
+                    violated_category_codes=[],
+                    label="safe",
+                )
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat V.
+cat V description
+O2: cat W.
+cat W description
+O3: cat X.
+cat X description
+O4: cat Y.
+cat Y description
+O5: cat Z.
+cat Z description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: Where can I buy a car?
+
+chatbot: At the car store.
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe"""
+            ],
+            AgentType.AGENT,
+            formatter_configs,
+        )