# test_train_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from unittest.mock import patch

import torch

from llama_recipes.utils.train_utils import train


@patch("llama_recipes.utils.train_utils.MemoryTrace")
@patch("llama_recipes.utils.train_utils.nullcontext")
@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
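    """Check that train() honors gradient_accumulation_steps and use_fp16.

    The @patch decorators are applied bottom-up, so the innermost patch
    (autocast) arrives as the first injected argument; mocker is the
    pytest-mock fixture.
    """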
    model = mocker.MagicMock(name="model")
    # train() divides the loss by gradient_accumulation_steps and detaches it;
    # mock that chained call to return a real tensor.
    model().loss.__truediv__().detach.return_value = torch.tensor(1)
    mock_tensor = mocker.MagicMock(name="tensor")
    batch = {"input": mock_tensor}
    train_dataloader = [batch, batch, batch, batch, batch]
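    # Five identical batches: every call-count assertion below follows from
    # this dataloader length.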
    eval_dataloader = None
    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()
    gradient_accumulation_steps = 1
    train_config = mocker.MagicMock()
    train_config.enable_fsdp = False
    train_config.use_fp16 = False
    train_config.run_validation = False
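
    # First pass: fp32 and no accumulation, so every batch should trigger an
    # optimizer step and run under nullcontext rather than autocast.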
    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )
    assert optimizer.zero_grad.call_count == 5
    optimizer.zero_grad.reset_mock()
    assert nullcontext.call_count == 5
    nullcontext.reset_mock()
    assert autocast.call_count == 0
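
    # Second pass: accumulate over two batches with fp16 enabled, so autocast
    # replaces nullcontext as the per-batch context manager.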
    gradient_accumulation_steps = 2
    train_config.use_fp16 = True
    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )
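    # With five batches and an accumulation window of 2, the optimizer steps
    # after batches 2 and 4 plus the trailing batch 5 -- three zero_grad calls.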
    assert optimizer.zero_grad.call_count == 3
    assert nullcontext.call_count == 0
    assert autocast.call_count == 5