test_train_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from unittest.mock import patch

import torch

from llama_recipes.utils.train_utils import train


@patch("llama_recipes.utils.train_utils.MemoryTrace")
def test_gradient_accumulation(mem_trace, mocker):
    # Mock the model so train() runs without a real network; the divided loss
    # detaches to a plain scalar tensor.
    model = mocker.MagicMock(name="model")
    model().loss.__truediv__().detach.return_value = torch.tensor(1)

    # Dataloader of five identical mocked batches.
    mock_tensor = mocker.MagicMock(name="tensor")
    batch = {"input": mock_tensor}
    train_dataloader = [batch, batch, batch, batch, batch]
    eval_dataloader = None
    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()
    gradient_accumulation_steps = 1
    train_config = mocker.MagicMock()
    train_config.enable_fsdp = False
    train_config.use_fp16 = False
    train_config.run_validation = False

    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )

    # With gradient_accumulation_steps == 1 the optimizer is zeroed once per batch.
    assert optimizer.zero_grad.call_count == 5

    optimizer.zero_grad.reset_mock()

    gradient_accumulation_steps = 2

    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )

    # With accumulation of 2 over 5 batches the test expects 3 calls: after
    # batches 2 and 4, plus the trailing partial accumulation on the last batch.
    assert optimizer.zero_grad.call_count == 3
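

# Illustrative sketch (not part of the original test): the expected call counts
# above assume the train loop steps/zeroes the optimizer whenever
# (step + 1) % gradient_accumulation_steps == 0, or on the final batch of the
# epoch so a trailing partial accumulation is not dropped. The hypothetical
# helper below just reproduces that arithmetic for the two cases asserted.
def _expected_zero_grad_calls(num_batches, gradient_accumulation_steps):
    return sum(
        1
        for step in range(num_batches)
        if (step + 1) % gradient_accumulation_steps == 0 or step == num_batches - 1
    )


def test_expected_zero_grad_call_counts():
    # 5 batches, no accumulation: one optimizer update per batch.
    assert _expected_zero_grad_calls(5, 1) == 5
    # 5 batches, accumulate over 2: updates after batches 2 and 4, plus the final batch.
    assert _expected_zero_grad_calls(5, 2) == 3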