@@ -3,7 +3,8 @@
 
 import unittest
 from typing import Optional
-from examples.llama_guard.prompt_format import PROMPT_TEMPLATE, AgentType
+
+from examples.llama_guard.prompt_format import AgentType
 
 from llama_recipes.data.llama_guard.finetuning_data_formatter import (
     AugmentationConfigs,
@@ -23,7 +24,7 @@ class FinetuningDataFormatterTests(unittest.TestCase):
         super().setUp()
 
     @staticmethod
-    def create_most_conservative_formatter_configs() -> FormatterConfigs:
+    def create_most_conservative_formatter_configs(agent_type_to_check: AgentType) -> FormatterConfigs:
         return FormatterConfigs(
             guidelines=Guidelines(
                 categories=[
@@ -36,18 +37,18 @@ class FinetuningDataFormatterTests(unittest.TestCase):
                 category_code_prefix="O",
             ),
             llama_guard_prompt_configs=LlamaGuardPromptConfigs(
-                instructions_format_string="""[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                instructions_format_string=f"""[INST] Task: Check if there is unsafe content in '{agent_type_to_check}' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
-{guidelines}
+{{guidelines}}
 <END UNSAFE CONTENT CATEGORIES>
 
 <BEGIN CONVERSATION>
 
-{conversation}
+{{conversation}}
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for {agent_type_to_check} in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
                 should_include_category_descriptions=True,
@@ -69,12 +70,13 @@ Provide your safety assessment for $agent_type in the above conversation:
         self,
         training_examples: list[TrainingExample],
         expected_formatted_examples: list[str],
+        agent_type_to_check: AgentType,
         formatter_configs: Optional[FormatterConfigs] = None,
     ) -> None:
         formatter_configs = (
             formatter_configs
             if formatter_configs is not None
-            else FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
+            else FinetuningDataFormatterTests.create_most_conservative_formatter_configs(agent_type_to_check)
         )
 
         formatted_examples = create_formatted_finetuning_examples(
@@ -104,7 +106,7 @@ Provide your safety assessment for $agent_type in the above conversation:
                 )
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat V.
 cat V description
@@ -124,11 +126,12 @@ human: Where can I buy a car?
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.USER in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
 Explanation: This is obviously safe."""
             ],
+            AgentType.USER,
         )
 
     def test_unsafe_example_with_only_prompt_and_no_augmentations(self) -> None:
@@ -145,7 +148,7 @@ because it's illegal!""",
                 )
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat V.
 cat V description
@@ -166,13 +169,14 @@ Asking for a cagey friend
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.USER in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
 O5
 Explanation: Nice try, Nick. This is unsafe
 because it's illegal!"""
             ],
+            AgentType.USER,
         )
 
     def test_unsafe_example_with_only_prompt_and_no_augmentations_2(self) -> None:
@@ -189,7 +193,7 @@ because it's illegal!""",
                 )
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat V.
 cat V description
@@ -210,13 +214,14 @@ Asking for a cagey friend
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.USER in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
 O1,O5
 Explanation: Nice try, Nick. This is unsafe
 because it's illegal!"""
             ],
+            AgentType.USER,
         )
 
     def test_safe_example_with_prompt_and_response_and_no_augmentations(self) -> None:
@@ -231,7 +236,7 @@ because it's illegal!"""
                 )
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat V.
 cat V description
@@ -253,24 +258,27 @@ chatbot: At the car store.
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
 Explanation: This is obviously safe."""
             ],
+            AgentType.AGENT,
         )
 
     def test_mixed_examples_with_prompt_and_response_and_all_augmentations_with_shuffling(
         self,
     ) -> None:
         formatter_configs = (
-            FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
+            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
         )
 
         # The right seed to get the random generations for testing the functionality of this test. ~*Magic*~
         formatter_configs.random_seed = 46
 
-        formatter_configs.llama_guard_generation_configs.explanation_position = ExplanationPosition.BEFORE_DECISION
+        formatter_configs.llama_guard_generation_configs.explanation_position = (
+            ExplanationPosition.BEFORE_DECISION
+        )
 
         formatter_configs.augmentation_configs = AugmentationConfigs(
             should_add_examples_with_dropped_nonviolated_prompt_categories=True,
@@ -300,7 +308,7 @@ Explanation: This is obviously safe."""
                 ),
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat W.
 cat W description
@@ -322,12 +330,12 @@ chatbot: Yes, very nise, high five!
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
 unsafe
-O4,O3""",
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+O3,O4""",
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat W.
 cat W description
@@ -347,12 +355,12 @@ chatbot: Yes, very nise, high five!
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
 unsafe
 O2,O3""",
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat W.
 cat W description
@@ -368,11 +376,11 @@ chatbot: Yes, very nise, high five!
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because any riskiness in it is related to categories we're not trying to detect.
 safe""",
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat Z.
 cat Z description
@@ -394,11 +402,11 @@ chatbot: No, that would be very sad. Please become attended first. I could join
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
 safe""",
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat Y.
 cat Y description
@@ -412,17 +420,20 @@ chatbot: No, that would be very sad. Please become attended first. I could join
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
 safe""",
             ],
+            AgentType.AGENT,
             formatter_configs,
         )
 
-    def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(self) -> None:
+    def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(
+        self,
+    ) -> None:
         formatter_configs = (
-            FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
+            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
         )
 
         formatter_configs.llama_guard_generation_configs.explanation_position = None
@@ -437,7 +448,7 @@ safe""",
                 )
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat V.
 cat V description
@@ -459,9 +470,10 @@ chatbot: At the car store.
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe"""
             ],
-            formatter_configs
-        )
+            AgentType.AGENT,
+            formatter_configs,
+        )