Browse Source

More PR comments

Michael Tontchev 1 year ago
parent
commit
2e0d3ddd09

+ 11 - 5
src/llama_recipes/data/llama_guard/finetuning_data_formatter.py

@@ -179,7 +179,9 @@ def _create_llama_guard_prompt(
     ) in enumerate(category_indices_to_include):
         category = formatter_configs.guidelines.categories[original_category_index]
 
-        newline_for_every_category_after_first = f"\n" if rewritten_category_index_for_current_prompt > 0 else ""
+        newline_for_every_category_after_first = (
+            f"\n" if rewritten_category_index_for_current_prompt > 0 else ""
+        )
 
         # Indices start at 0, but categories start at 1, so we add 1
         full_guidelines_text += f"{newline_for_every_category_after_first}{formatter_configs.guidelines.category_code_prefix}{rewritten_category_index_for_current_prompt + 1}: {category.name}. "
@@ -239,10 +241,14 @@ def _create_llama_guard_generation(
             )
         )
 
-        rewritten_violated_category_codes = [
-            map_of_original_category_indices_to_rewritten_category_codes[violated_index]
-            for violated_index in violated_category_indices
-        ]
+        rewritten_violated_category_codes = sorted(
+            [
+                map_of_original_category_indices_to_rewritten_category_codes[
+                    violated_index
+                ]
+                for violated_index in violated_category_indices
+            ]
+        )
 
         to_return += "\n"
         to_return += ",".join(rewritten_violated_category_codes)

+ 46 - 34
tests/test_finetuning_data_formatter.py

@@ -3,7 +3,8 @@
 
 import unittest
 from typing import Optional
-from examples.llama_guard.prompt_format import PROMPT_TEMPLATE, AgentType
+
+from examples.llama_guard.prompt_format import AgentType
 
 from llama_recipes.data.llama_guard.finetuning_data_formatter import (
     AugmentationConfigs,
@@ -23,7 +24,7 @@ class FinetuningDataFormatterTests(unittest.TestCase):
         super().setUp()
 
     @staticmethod
-    def create_most_conservative_formatter_configs() -> FormatterConfigs:
+    def create_most_conservative_formatter_configs(agent_type_to_check: AgentType) -> FormatterConfigs:
         return FormatterConfigs(
             guidelines=Guidelines(
                 categories=[
@@ -36,18 +37,18 @@ class FinetuningDataFormatterTests(unittest.TestCase):
                 category_code_prefix="O",
             ),
             llama_guard_prompt_configs=LlamaGuardPromptConfigs(
-                instructions_format_string="""[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                instructions_format_string=f"""[INST] Task: Check if there is unsafe content in '{agent_type_to_check}' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
-{guidelines}
+{{guidelines}}
 <END UNSAFE CONTENT CATEGORIES>
 
 <BEGIN CONVERSATION>
 
-{conversation}
+{{conversation}}
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for {agent_type_to_check} in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
                 should_include_category_descriptions=True,
@@ -69,12 +70,13 @@ Provide your safety assessment for $agent_type in the above conversation:
         self,
         training_examples: list[TrainingExample],
         expected_formatted_examples: list[str],
+        agent_type_to_check: AgentType,
         formatter_configs: Optional[FormatterConfigs] = None,
     ) -> None:
         formatter_configs = (
             formatter_configs
             if formatter_configs is not None
-            else FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
+            else FinetuningDataFormatterTests.create_most_conservative_formatter_configs(agent_type_to_check)
         )
 
         formatted_examples = create_formatted_finetuning_examples(
@@ -104,7 +106,7 @@ Provide your safety assessment for $agent_type in the above conversation:
                 )
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat V. 
 cat V description
@@ -124,11 +126,12 @@ human: Where can I buy a car?
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.USER in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
 Explanation: This is obviously safe."""
             ],
+            AgentType.USER,
         )
 
     def test_unsafe_example_with_only_prompt_and_no_augmentations(self) -> None:
@@ -145,7 +148,7 @@ because it's illegal!""",
                 )
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat V. 
 cat V description
@@ -166,13 +169,14 @@ Asking for a cagey friend
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.USER in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
 O5
 Explanation: Nice try, Nick. This is unsafe
 because it's illegal!"""
             ],
+            AgentType.USER,
         )
 
     def test_unsafe_example_with_only_prompt_and_no_augmentations_2(self) -> None:
@@ -189,7 +193,7 @@ because it's illegal!""",
                 )
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat V. 
 cat V description
@@ -210,13 +214,14 @@ Asking for a cagey friend
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.USER in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
 O1,O5
 Explanation: Nice try, Nick. This is unsafe
 because it's illegal!"""
             ],
+            AgentType.USER,
         )
 
     def test_safe_example_with_prompt_and_response_and_no_augmentations(self) -> None:
@@ -231,7 +236,7 @@ because it's illegal!"""
                 )
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat V. 
 cat V description
@@ -253,24 +258,27 @@ chatbot: At the car store.
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
 Explanation: This is obviously safe."""
             ],
+            AgentType.AGENT,
         )
 
     def test_mixed_examples_with_prompt_and_response_and_all_augmentations_with_shuffling(
         self,
     ) -> None:
         formatter_configs = (
-            FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
+            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
         )
 
         # The right seed to get the random generations for testing the functionality of this test. ~*Magic*~
         formatter_configs.random_seed = 46
 
-        formatter_configs.llama_guard_generation_configs.explanation_position = ExplanationPosition.BEFORE_DECISION
+        formatter_configs.llama_guard_generation_configs.explanation_position = (
+            ExplanationPosition.BEFORE_DECISION
+        )
 
         formatter_configs.augmentation_configs = AugmentationConfigs(
             should_add_examples_with_dropped_nonviolated_prompt_categories=True,
@@ -300,7 +308,7 @@ Explanation: This is obviously safe."""
                 ),
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat W. 
 cat W description
@@ -322,12 +330,12 @@ chatbot: Yes, very nise, high five!
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
 unsafe
-O4,O3""",
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+O3,O4""",
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat W. 
 cat W description
@@ -347,12 +355,12 @@ chatbot: Yes, very nise, high five!
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
 unsafe
 O2,O3""",
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat W. 
 cat W description
@@ -368,11 +376,11 @@ chatbot: Yes, very nise, high five!
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because any riskiness in it is related to categories we're not trying to detect.
 safe""",
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat Z. 
 cat Z description
@@ -394,11 +402,11 @@ chatbot: No, that would be very sad. Please become attended first. I could join
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
 safe""",
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat Y. 
 cat Y description
@@ -412,17 +420,20 @@ chatbot: No, that would be very sad. Please become attended first. I could join
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
 safe""",
             ],
+            AgentType.AGENT,
             formatter_configs,
         )
 
-    def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(self) -> None:
+    def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(
+        self,
+    ) -> None:
         formatter_configs = (
-            FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
+            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
         )
 
         formatter_configs.llama_guard_generation_configs.explanation_position = None
@@ -437,7 +448,7 @@ safe""",
                 )
             ],
             [
-                """[INST] Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
 <BEGIN UNSAFE CONTENT CATEGORIES>
 O1: cat V. 
 cat V description
@@ -459,9 +470,10 @@ chatbot: At the car store.
 
 <END CONVERSATION>
 
-Provide your safety assessment for $agent_type in the above conversation:
+Provide your safety assessment for AgentType.AGENT in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe"""
             ],
-            formatter_configs
-        )
+            AgentType.AGENT,
+            formatter_configs,
+        )