Bladeren bron

Make tests run on cpu instance (#212)

Geeta Chauhan 1 jaar geleden
bovenliggende
commit
18e0198626
2 gewijzigde bestanden met toevoegingen van 10 en 6 verwijderingen
  1. 4 2
      tests/test_finetuning.py
  2. 6 4
      tests/test_train_utils.py

+ 4 - 2
tests/test_finetuning.py

@@ -83,12 +83,14 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
+def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
     kwargs = {"weight_decay": 0.01}
     
     get_dataset.return_value = [1]
     
-    get_peft_model.return_value = Linear(1,1)
+    model = mocker.MagicMock(name="model")
+    model.parameters.return_value = Linear(1,1).parameters()
+    get_peft_model.return_value = model 
     get_peft_model.return_value.print_trainable_parameters=lambda:None
     main(**kwargs)
     

+ 6 - 4
tests/test_train_utils.py

@@ -1,17 +1,19 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from unittest.mock import patch
+
 import torch
 
 from llama_recipes.utils.train_utils import train
 
-def test_gradient_accumulation(mocker):
-    # import sys
-    # sys.path.append('/home/ubuntu/llama-recipes/')
+@patch("llama_recipes.utils.train_utils.MemoryTrace")
+def test_gradient_accumulation(mem_trace, mocker):
     
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
-    batch = {"input": torch.zeros(1)}
+    mock_tensor = mocker.MagicMock(name="tensor")
+    batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
     eval_dataloader = None
     tokenizer = mocker.MagicMock()