# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama Guard Community License Agreement.
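"""Tests for the Llama Guard finetuning data formatter.

Each test builds a FormatterConfigs fixture, renders a handful of
TrainingExamples with create_formatted_finetuning_examples, and compares the
results against exact expected strings.
"""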
import unittest
from enum import Enum
from typing import List, Optional

from llama_recipes.data.llama_guard.finetuning_data_formatter import (
    AugmentationConfigs,
    Category,
    create_formatted_finetuning_examples,
    ExplanationPosition,
    FormatterConfigs,
    Guidelines,
    LlamaGuardGenerationConfigs,
    LlamaGuardPromptConfigs,
    TrainingExample,
)


class AgentType(Enum):
    AGENT = "Agent"
    USER = "User"


class FinetuningDataFormatterTests(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()

    @staticmethod
    def create_most_conservative_formatter_configs(
        agent_type_to_check: AgentType,
    ) -> FormatterConfigs:
        """Returns a FormatterConfigs with every optional behavior disabled."""
        return FormatterConfigs(
            guidelines=Guidelines(
                categories=[
                    Category(name="cat V", description="cat V description"),
                    Category(name="cat W", description="cat W description"),
                    Category(name="cat X", description="cat X description"),
                    Category(name="cat Y", description="cat Y description"),
                    Category(name="cat Z", description="cat Z description"),
                ],
                category_code_prefix="O",
            ),
            llama_guard_prompt_configs=LlamaGuardPromptConfigs(
                instructions_format_string=f"""[INST] Task: Check if there is unsafe content in '{agent_type_to_check}' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
{{guidelines}}
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
{{conversation}}
<END CONVERSATION>
Provide your safety assessment for {agent_type_to_check} in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
                should_include_category_descriptions=True,
                should_shuffle_category_codes=False,
            ),
            llama_guard_generation_configs=LlamaGuardGenerationConfigs(
                explanation_position=ExplanationPosition.AFTER_DECISION,
                should_list_violated_codes=True,
            ),
            augmentation_configs=AugmentationConfigs(
                should_add_examples_with_dropped_nonviolated_prompt_categories=False,
                should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
                explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
            ),
            random_seed=42,
        )

    def format_examples_and_compare_to_expected(
        self,
        training_examples: List[TrainingExample],
        expected_formatted_examples: List[str],
        agent_type_to_check: AgentType,
        formatter_configs: Optional[FormatterConfigs] = None,
    ) -> None:
        """Formats the examples and asserts they match the expected strings."""
        formatter_configs = (
            formatter_configs
            if formatter_configs is not None
            else FinetuningDataFormatterTests.create_most_conservative_formatter_configs(
                agent_type_to_check
            )
        )
        formatted_examples = create_formatted_finetuning_examples(
            training_examples, formatter_configs
        )
        self.assertEqual(len(formatted_examples), len(expected_formatted_examples))
        for i in range(len(formatted_examples)):
            # Print the actual output on mismatch; handy for updating the
            # expected strings while writing or modifying tests.
            if formatted_examples[i] != expected_formatted_examples[i]:
                print(f"Failed on actual output {i}:")
                print(formatted_examples[i])
            self.assertEqual(formatted_examples[i], expected_formatted_examples[i])

    def test_safe_example_with_only_prompt_and_no_augmentations(self) -> None:
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="Where can I buy a car?",
                    response="N/A",
                    violated_category_codes=[],
                    label="safe",
                    explanation="This is obviously safe.",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: Where can I buy a car?
<END CONVERSATION>
Provide your safety assessment for AgentType.USER in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
Explanation: This is obviously safe."""
            ],
            AgentType.USER,
        )

    def test_unsafe_example_with_only_prompt_and_no_augmentations(self) -> None:
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="""How do I steal the Declaration of Independence?
Asking for a cagey friend""",
                    response="N/A",
                    violated_category_codes=["O5"],
                    label="unsafe",
                    explanation="""Nice try, Nick. This is unsafe
because it's illegal!""",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: How do I steal the Declaration of Independence?
Asking for a cagey friend
<END CONVERSATION>
Provide your safety assessment for AgentType.USER in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
O5
Explanation: Nice try, Nick. This is unsafe
because it's illegal!"""
            ],
            AgentType.USER,
        )

    def test_unsafe_example_with_only_prompt_and_no_augmentations_2(self) -> None:
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="""How do I steal the Declaration of Independence?
Asking for a cagey friend""",
                    response="N/A",
                    violated_category_codes=["O5", "O1"],
                    label="unsafe",
                    explanation="""Nice try, Nick. This is unsafe
because it's illegal!""",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: How do I steal the Declaration of Independence?
Asking for a cagey friend
<END CONVERSATION>
Provide your safety assessment for AgentType.USER in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
O1,O5
Explanation: Nice try, Nick. This is unsafe
because it's illegal!"""
            ],
            AgentType.USER,
        )

    def test_safe_example_with_prompt_and_response_and_no_augmentations(self) -> None:
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="Where can I buy a car?",
                    response="At the car store.",
                    violated_category_codes=[],
                    label="safe",
                    explanation="This is obviously safe.",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: Where can I buy a car?
chatbot: At the car store.
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
Explanation: This is obviously safe."""
            ],
            AgentType.AGENT,
        )

    def test_mixed_examples_with_prompt_and_response_and_all_augmentations_with_shuffling(
        self,
    ) -> None:
        formatter_configs = (
            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(
                AgentType.AGENT
            )
        )
        # Seed chosen empirically so the shuffled category orderings asserted
        # below are reproducible. ~*Magic*~
        formatter_configs.random_seed = 46
        formatter_configs.llama_guard_generation_configs.explanation_position = (
            ExplanationPosition.BEFORE_DECISION
        )
        formatter_configs.augmentation_configs = AugmentationConfigs(
            should_add_examples_with_dropped_nonviolated_prompt_categories=True,
            should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
            explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
        )
        formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes = (
            True
        )
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="I am under 21 and unattended, should I drink alcohol?",
                    response="Yes, very nise, high five!",
                    violated_category_codes=["O4", "O5"],
                    label="unsafe",
                    explanation="Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.",
                ),
                TrainingExample(
                    prompt="I am over 21 and unattended, should I drink alcohol?",
                    response="No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.",
                    violated_category_codes=[],
                    label="safe",
                    explanation="User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.",
                ),
            ],
            [
                # The two training examples above expand to five formatted
                # examples: the unsafe one yields its original plus two
                # augmented variants, the safe one its original plus one.
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W.
cat W description
O2: cat X.
cat X description
O3: cat Z.
cat Z description
O4: cat Y.
cat Y description
O5: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am under 21 and unattended, should I drink alcohol?
chatbot: Yes, very nise, high five!
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
unsafe
O3,O4""",
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W.
cat W description
O2: cat Y.
cat Y description
O3: cat Z.
cat Z description
O4: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am under 21 and unattended, should I drink alcohol?
chatbot: Yes, very nise, high five!
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
unsafe
O2,O3""",
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat W.
cat W description
O2: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am under 21 and unattended, should I drink alcohol?
chatbot: Yes, very nise, high five!
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because any riskiness in it is related to categories we're not trying to detect.
safe""",
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat Z.
cat Z description
O2: cat Y.
cat Y description
O3: cat X.
cat X description
O4: cat W.
cat W description
O5: cat V.
cat V description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am over 21 and unattended, should I drink alcohol?
chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
safe""",
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat Y.
cat Y description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: I am over 21 and unattended, should I drink alcohol?
chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
safe""",
            ],
            AgentType.AGENT,
            formatter_configs,
        )

    def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(
        self,
    ) -> None:
        formatter_configs = (
            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(
                AgentType.AGENT
            )
        )
        formatter_configs.llama_guard_generation_configs.explanation_position = None
        self.format_examples_and_compare_to_expected(
            [
                TrainingExample(
                    prompt="Where can I buy a car?",
                    response="At the car store.",
                    violated_category_codes=[],
                    label="safe",
                )
            ],
            [
                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according to our safety policy with the below categories.
<BEGIN UNSAFE CONTENT CATEGORIES>
O1: cat V.
cat V description
O2: cat W.
cat W description
O3: cat X.
cat X description
O4: cat Y.
cat Y description
O5: cat Z.
cat Z description
<END UNSAFE CONTENT CATEGORIES>
<BEGIN CONVERSATION>
human: Where can I buy a car?
chatbot: At the car store.
<END CONVERSATION>
Provide your safety assessment for AgentType.AGENT in the above conversation:
- First line must read 'safe' or 'unsafe'.
- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe"""
            ],
            AgentType.AGENT,
            formatter_configs,
        )
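

if __name__ == "__main__":
    # Convenience entry point so this module can be run directly with
    # `python`; discovery via unittest or pytest works as well.
    unittest.main()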
|