Browse Source

comment fixes, type fixes

Michael Tontchev 1 year ago
parent
commit
fc19074085

+ 11 - 11
src/llama_recipes/data/llama_guard/finetuning_data_formatter.py

@@ -91,7 +91,7 @@ def create_formatted_finetuning_examples(
     Some of these augmented transformations are made by duplicating training
     examples and safely removing some violation categories from the llama
     guard prompts. Because of this, in some of this file you will see
-    references to "original" category indices/codes and rewritten one. The originals
+    references to "original" category indices/codes and rewritten ones. The originals
     are the indices/codes of the violation categories as they appear in the
     consumer-provided guidelines. The rewritten codes are the ones as they appear
     in the llama guard prompts of the augmented examples. We occasionally need to
@@ -143,7 +143,7 @@ def _verify_formatter_configs(
 
     if (
         formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories
-        > 0
+        == True
         and formatter_configs.llama_guard_generation_configs.explanation_position
         is not None
         and formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories
@@ -185,10 +185,6 @@ def _create_formatted_finetuning_example(
     return f"{llama_guard_prompt} {llama_guard_generation}"
 
 
-def _is_a_prompt_only_example(training_example: TrainingExample) -> bool:
-    return training_example.response == "N/A"
-
-
 def _create_llama_guard_prompt(
     training_example: TrainingExample,
     category_indices_to_include: List[int],
@@ -223,6 +219,10 @@ def _create_llama_guard_prompt(
     )
 
 
+def _is_a_prompt_only_example(training_example: TrainingExample) -> bool:
+    return training_example.response == "N/A"
+
+
 def _serialize_conversation(conversation: Dict[str, str]) -> str:
     conversation_as_list = []
 
@@ -296,7 +296,7 @@ def _get_map_of_original_category_indices_to_rewritten_category_codes(
 
 def _maybe_add_data_augmentations_for_example(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[dict[str, str]],
+    formatted_examples_being_built: list[str],
     indices_of_all_categories: range,
     formatter_configs: FormatterConfigs,
 ) -> None:
@@ -317,7 +317,7 @@ def _maybe_add_data_augmentations_for_example(
 
 def _maybe_add_safe_example_with_empty_response(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[dict[str, str]],
+    formatted_examples_being_built: list[str],
     indices_of_all_categories: range,
     formatter_configs: FormatterConfigs,
 ) -> None:
@@ -353,7 +353,7 @@ def _maybe_add_safe_example_with_empty_response(
 
 def _maybe_add_examples_with_dropped_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[dict[str, str]],
+    formatted_examples_being_built: list[str],
     indices_of_all_categories: range,
     formatter_configs: FormatterConfigs,
 ) -> None:
@@ -396,7 +396,7 @@ def _convert_category_codes_to_indices(
 
 def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[dict[str, str]],
+    formatted_examples_being_built: list[str],
     indices_of_all_categories: range,
     nonviolated_category_indices: list[int],
     formatter_configs: FormatterConfigs,
@@ -436,7 +436,7 @@ def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
 
 def _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[dict[str, str]],
+    formatted_examples_being_built: list[str],
     indices_of_all_categories: range,
     violated_category_indices: list[int],
     nonviolated_category_indices: list[int],

+ 1 - 0
tests/test_finetuning_data_formatter.py

@@ -241,6 +241,7 @@ Explanation: This is obviously safe."""
             FinetuningDataFormatterTests.create_most_conservative_formatter_configs()
         )
 
+        # The right seed to get the random generations for testing the functionality of this test. ~*Magic*~
         formatter_configs.random_seed = 46
 
         formatter_configs.llama_guard_generation_configs.explanation_position = ExplanationPosition.BEFORE_DECISION