|
@@ -5,6 +5,7 @@ import pytest
|
|
|
from pytest import approx
|
|
|
from unittest.mock import patch
|
|
|
|
|
|
+import torch
|
|
|
from torch.nn import Linear
|
|
|
from torch.optim import AdamW
|
|
|
from torch.utils.data.dataloader import DataLoader
|
|
@@ -100,8 +101,11 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
|
|
|
kwargs = {"weight_decay": 0.01}
|
|
|
|
|
|
get_dataset.return_value = get_fake_dataset()
|
|
|
+
|
|
|
+ model = mocker.MagicMock(name="Model")
|
|
|
+ model.parameters.return_value = [torch.ones(1,1)]
|
|
|
|
|
|
- get_model.return_value = Linear(1,1)
|
|
|
+ get_model.return_value = model
|
|
|
|
|
|
main(**kwargs)
|
|
|
|