# test_train_utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from unittest.mock import patch

import torch

from llama_recipes.utils.train_utils import train
  6. @patch("llama_recipes.utils.train_utils.MemoryTrace")
  7. @patch("llama_recipes.utils.train_utils.nullcontext")
  8. @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
  9. @patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
  10. def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
  11. model = mocker.MagicMock(name="model")
  12. model().loss.__truediv__().detach.return_value = torch.tensor(1)
  13. mock_tensor = mocker.MagicMock(name="tensor")
  14. batch = {"input": mock_tensor}
  15. train_dataloader = [batch, batch, batch, batch, batch]
  16. eval_dataloader = None
  17. tokenizer = mocker.MagicMock()
  18. optimizer = mocker.MagicMock()
  19. lr_scheduler = mocker.MagicMock()
  20. gradient_accumulation_steps = 1
  21. train_config = mocker.MagicMock()
  22. train_config.enable_fsdp = False
  23. train_config.use_fp16 = False
  24. train_config.run_validation = False
  25. train_config.gradient_clipping = False
  26. train(
  27. model,
  28. train_dataloader,
  29. eval_dataloader,
  30. tokenizer,
  31. optimizer,
  32. lr_scheduler,
  33. gradient_accumulation_steps,
  34. train_config,
  35. )
  36. assert optimizer.zero_grad.call_count == 5
  37. optimizer.zero_grad.reset_mock()
  38. assert nullcontext.call_count == 5
  39. nullcontext.reset_mock()
  40. assert autocast.call_count == 0
  41. gradient_accumulation_steps = 2
  42. train_config.use_fp16 = True
  43. train(
  44. model,
  45. train_dataloader,
  46. eval_dataloader,
  47. tokenizer,
  48. optimizer,
  49. lr_scheduler,
  50. gradient_accumulation_steps,
  51. train_config,
  52. )
  53. assert optimizer.zero_grad.call_count == 3
  54. assert nullcontext.call_count == 0
  55. assert autocast.call_count == 5