Browse Source

changing list to List in the script and test, to attempt to fix issue reported during review

Beto 1 year ago
parent
commit
6266d31110

+ 10 - 10
src/llama_recipes/data/llama_guard/finetuning_data_formatter.py

@@ -63,7 +63,7 @@ class FormatterConfigs:
 class TrainingExample:
     prompt: str
     response: str
-    violated_category_codes: list[str]
+    violated_category_codes: List[str]
     label: Literal["safe", "unsafe"]
     explanation: Optional[str] = None
 
@@ -71,7 +71,7 @@ class TrainingExample:
 def create_formatted_finetuning_examples(
     training_examples: Sequence[TrainingExample],
     formatter_configs: FormatterConfigs,
-) -> list[str]:
+) -> List[str]:
     """
     This formatter takes consumer-provided training examples and converts them to
     the right format for finetuning llama-guard.
@@ -285,7 +285,7 @@ def _get_map_of_original_category_indices_to_rewritten_category_codes(
 
 def _maybe_add_data_augmentations_for_example(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
     formatter_configs: FormatterConfigs,
 ) -> None:
@@ -317,8 +317,8 @@ def _maybe_add_data_augmentations_for_example(
 
 
 def _convert_category_codes_to_indices(
-    codes: list[str], formatter_configs: FormatterConfigs
-) -> list[int]:
+    codes: List[str], formatter_configs: FormatterConfigs
+) -> List[int]:
     # Category codes start at 1, but indices start at 0, so we subtract 1
     return [
         int(code.lstrip(formatter_configs.guidelines.category_code_prefix)) - 1
@@ -328,9 +328,9 @@ def _convert_category_codes_to_indices(
 
 def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
-    nonviolated_category_indices: list[int],
+    nonviolated_category_indices: List[int],
     formatter_configs: FormatterConfigs,
 ) -> None:
     """
@@ -368,10 +368,10 @@ def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
 
 def _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
-    violated_category_indices: list[int],
-    nonviolated_category_indices: list[int],
+    violated_category_indices: List[int],
+    nonviolated_category_indices: List[int],
     formatter_configs: FormatterConfigs,
 ) -> None:
     """

+ 3 - 3
tests/test_finetuning_data_formatter.py

@@ -3,7 +3,7 @@
 
 from enum import Enum
 import unittest
-from typing import Optional
+from typing import Optional, List
 
 from llama_recipes.data.llama_guard.finetuning_data_formatter import (
     AugmentationConfigs,
@@ -72,8 +72,8 @@ Provide your safety assessment for {agent_type_to_check} in the above conversati
 
     def format_examples_and_compare_to_expected(
         self,
-        training_examples: list[TrainingExample],
-        expected_formatted_examples: list[str],
+        training_examples: List[TrainingExample],
+        expected_formatted_examples: List[str],
         agent_type_to_check: AgentType,
         formatter_configs: Optional[FormatterConfigs] = None,
     ) -> None: