Browse Source

Adapt test_finetuning to new model

Matthias Reso 10 months ago
parent
commit
5efea160a2
1 changed files with 26 additions and 19 deletions
  1. 26 19
      tests/test_finetuning.py

+ 26 - 19
tests/test_finetuning.py

@@ -21,17 +21,19 @@ def get_fake_dataset():
         "labels":[1],
         }]
 
-
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"run_validation": False}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
@@ -44,23 +46,26 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
 
-    if torch.cuda.is_available():
+    if cuda_is_available:
         assert get_model.return_value.to.call_count == 1
         assert get_model.return_value.to.call_args.args[0] == "cuda"
     else:
         assert get_model.return_value.to.call_count == 0
 
 
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"run_validation": True}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
@@ -72,40 +77,42 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
 
-    if torch.cuda.is_available():
+    if cuda_is_available:
         assert get_model.return_value.to.call_count == 1
         assert get_model.return_value.to.call_args.args[0] == "cuda"
     else:
         assert get_model.return_value.to.call_count == 0
 
-
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.generate_peft_config')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"use_peft": True}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
-    if torch.cuda.is_available():
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if cuda_is_available:
+        assert get_peft_model.return_value.to.call_count == 1
+        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     else:
-        assert get_model.return_value.to.call_count == 0
-    
+        assert get_peft_model.return_value.to.call_count == 0
+
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.StepLR')
@@ -113,11 +120,11 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
     kwargs = {"weight_decay": 0.01}
 
     get_dataset.return_value = get_fake_dataset()
-    
+
     model = mocker.MagicMock(name="Model")
     model.parameters.return_value = [torch.ones(1,1)]
 
-    get_model.return_value = model 
+    get_model.return_value = model
 
     main(**kwargs)
 
@@ -134,7 +141,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')