@@ -8,7 +8,10 @@ import torch
 from llama_recipes.utils.train_utils import train
 
 @patch("llama_recipes.utils.train_utils.MemoryTrace")
-def test_gradient_accumulation(mem_trace, mocker):
+@patch("llama_recipes.utils.train_utils.nullcontext")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
+def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
 
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
@@ -39,7 +42,13 @@ def test_gradient_accumulation(mem_trace, mocker):
     assert optimizer.zero_grad.call_count == 5
     optimizer.zero_grad.reset_mock()
 
+    assert nullcontext.call_count == 5
+    nullcontext.reset_mock()
+
+    assert autocast.call_count == 0
+
     gradient_accumulation_steps = 2
+    train_config.use_fp16 = True
     train(
         model,
         train_dataloader,
@@ -50,4 +59,6 @@ def test_gradient_accumulation(mem_trace, mocker):
         gradient_accumulation_steps,
         train_config,
     )
-    assert optimizer.zero_grad.call_count == 3
+    assert optimizer.zero_grad.call_count == 3
+    assert nullcontext.call_count == 0
+    assert autocast.call_count == 5
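
For readers skimming the patch: the new assertions encode how `train` is expected to choose its precision context per batch, `nullcontext` while `train_config.use_fp16` is off and `torch.cuda.amp.autocast` once it is on. Below is a minimal sketch of that selection pattern; `run_batches` is a hypothetical helper used only for illustration, not the actual `llama_recipes` training loop.

```python
from contextlib import nullcontext

import torch


def run_batches(model, batches, use_fp16):
    """Sketch of the per-batch context selection the test asserts on."""
    # Pick the context factory once: autocast for fp16, a no-op context otherwise.
    precision_context = torch.cuda.amp.autocast if use_fp16 else nullcontext
    losses = []
    for batch in batches:
        with precision_context():  # entered once per batch, hence call_count == len(batches)
            loss = model(**batch).loss
        losses.append(loss)
    return losses
```

With five batches in the mocked dataloader, that is exactly what the assertions count: five `nullcontext` calls in the first `train` run, and five `autocast` calls after setting `train_config.use_fp16 = True`.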