@@ -83,12 +83,14 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
+def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
     kwargs = {"weight_decay": 0.01}
 
     get_dataset.return_value = [1]
 
-    get_peft_model.return_value = Linear(1,1)
+    model = mocker.MagicMock(name="model")
+    model.parameters.return_value = Linear(1,1).parameters()
+    get_peft_model.return_value = model
     get_peft_model.return_value.print_trainable_parameters=lambda:None
     main(**kwargs)
 