
Fix unit test to reflect batch packing

Matthias Reso 1 year ago
parent
commit
aa5dee241a
1 changed file with 19 additions and 6 deletions
  1. tests/test_finetuning.py  +19 -6

tests/test_finetuning.py  +19 -6

@@ -13,6 +13,15 @@ from torch.utils.data.sampler import BatchSampler
 from llama_recipes.finetuning import main
 from llama_recipes.data.sampler import LengthBasedBatchSampler
 
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+        }]
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -22,7 +31,7 @@ from llama_recipes.data.sampler import LengthBasedBatchSampler
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -46,7 +55,8 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [[1]]
+
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -72,7 +82,7 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -89,7 +99,7 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
     kwargs = {"weight_decay": 0.01}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     get_peft_model.return_value = Linear(1,1)
     get_peft_model.return_value.print_trainable_parameters=lambda:None
@@ -113,9 +123,12 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {"batching_strategy": "packing"}
+    kwargs = {
+        "batching_strategy": "packing",
+        "use_peft": False,
+        }
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
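
For context, the bare `[[1]]` placeholder no longer works once batch packing is in play, because the packing path iterates over dict-shaped samples and concatenates their `input_ids`, `attention_mask`, and `labels` fields into fixed-length chunks, which is why `get_fake_dataset()` now returns a list of dicts. The snippet below is a minimal, self-contained sketch of that idea only; the `pack_samples` helper and the chunk size of 4 are made up for illustration and are not the actual llama_recipes implementation.

# Illustrative sketch only: shows why a packed dataset needs dict samples
# with "input_ids", "attention_mask" and "labels" keys. The helper name
# and chunk size are hypothetical, not part of llama_recipes.

def pack_samples(samples, chunk_size=4):
    """Concatenate per-sample token lists and re-split them into fixed-size chunks."""
    buffer = {"input_ids": [], "attention_mask": [], "labels": []}
    for sample in samples:
        for key in buffer:
            buffer[key].extend(sample[key])  # would fail if a sample were just [1]

    packed = []
    n_chunks = len(buffer["input_ids"]) // chunk_size
    for i in range(n_chunks):
        packed.append({
            key: values[i * chunk_size:(i + 1) * chunk_size]
            for key, values in buffer.items()
        })
    return packed


if __name__ == "__main__":
    fake = [{"input_ids": [1], "attention_mask": [1], "labels": [1]}] * 8
    print(pack_samples(fake))  # yields two packed samples of length 4 per field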