
Added unit test coverage for amp

Matthias Reso, 1 year ago
Commit 2e768b1d1d
1 file changed, 13 insertions and 2 deletions
tests/test_train_utils.py (+13 −2)

@@ -8,7 +8,10 @@ import torch
 from llama_recipes.utils.train_utils import train
 
 @patch("llama_recipes.utils.train_utils.MemoryTrace")
-def test_gradient_accumulation(mem_trace, mocker):
+@patch("llama_recipes.utils.train_utils.nullcontext")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
+def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
     
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
@@ -39,7 +42,13 @@ def test_gradient_accumulation(mem_trace, mocker):
     assert optimizer.zero_grad.call_count == 5
     optimizer.zero_grad.reset_mock()
     
+    assert nullcontext.call_count == 5
+    nullcontext.reset_mock()
+    
+    assert autocast.call_count == 0
+    
     gradient_accumulation_steps = 2
+    train_config.use_fp16 = True
     train(
         model,
         train_dataloader,
@@ -50,4 +59,6 @@ def test_gradient_accumulation(mem_trace, mocker):
         gradient_accumulation_steps,
         train_config,
     )
-    assert optimizer.zero_grad.call_count == 3
+    assert optimizer.zero_grad.call_count == 3
+    assert nullcontext.call_count == 0
+    assert autocast.call_count == 5
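
For context, a minimal sketch of the behavior these assertions exercise, assuming train() selects its forward-pass context manager from train_config.use_fp16. Only autocast, GradScaler, and nullcontext come from the patched targets above; the helper name forward_with_optional_amp, the batch variable, and the model call signature are hypothetical.

    from contextlib import nullcontext

    import torch

    def forward_with_optional_amp(model, batch, train_config):
        # Hypothetical helper mirroring what the test checks: with use_fp16=True
        # the forward pass runs under torch.cuda.amp.autocast(); otherwise
        # nullcontext() is entered, so autocast.call_count stays at 0.
        ctx = torch.cuda.amp.autocast() if train_config.use_fp16 else nullcontext()
        with ctx:
            loss = model(**batch).loss
        return loss

With 5 training steps per run, this selection explains the expected call counts: nullcontext is entered 5 times in the fp32 run, and autocast 5 times once use_fp16 is set.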