@@ -1,4 +1,5 @@
 from unittest.mock import patch
+import importlib

 from torch.utils.data.dataloader import DataLoader

@@ -11,7 +12,7 @@ from llama_recipes.finetuning import main
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {"run_validation": True}
+    kwargs = {"run_validation": False}

     get_dataset.return_value = [1]

@@ -36,8 +37,7 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {"run_validation": False}
-
+    kwargs = {"run_validation": True}
     get_dataset.return_value = [1]

     main(**kwargs)
@@ -47,7 +47,6 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
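For context (not part of the diff): a minimal, self-contained sketch of the two mock idioms these tests rely on. Stacked unittest.mock @patch decorators inject mocks bottom-up, which is why the innermost patch (StepLR) binds to the first parameter (step_lr) while the outer patch (optim.AdamW) binds to the second (optimizer); the remaining parameters (get_dataset, tokenizer, get_model, train) are assumed to come from patches defined elsewhere in the file. The call_args unpacking at the bottom mirrors how the tests recover the dataloaders passed to train(). The patch targets below are stand-ins, not llama_recipes APIs.

    # Sketch only; patches stdlib targets so it runs without llama_recipes.
    from unittest.mock import MagicMock, patch

    @patch('os.getcwd')   # outer patch -> second parameter
    @patch('os.getpid')   # inner patch -> first parameter
    def demo_patch_order(getpid_mock, getcwd_mock):
        # Mocks are injected bottom-up, matching the tests' ordering of
        # step_lr (inner StepLR patch) before optimizer (outer AdamW patch).
        assert getpid_mock is not getcwd_mock

    demo_patch_order()

    # call_args records the arguments of the mock's last call; the tests
    # unpack it to retrieve the dataloaders handed to train().
    train = MagicMock()
    train("model", "train_dl", "eval_dl")
    args, kwargs = train.call_args
    assert args[1] == "train_dl"   # train_dataloader = args[1]
    assert args[2] == "eval_dl"    # eval_dataloader  = args[2]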