test_train_utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch

from llama_recipes.utils.train_utils import train


def test_gradient_accumulation(mocker):
    # import sys
    # sys.path.append('/home/ubuntu/llama-recipes/')

    # Mock the model so the training loop's loss bookkeeping
    # (loss / gradient_accumulation_steps, then .detach()) yields a real tensor.
    model = mocker.MagicMock(name="model")
    model().loss.__truediv__().detach.return_value = torch.tensor(1)

    # Five dummy batches; validation, FSDP, and fp16 are disabled so only the
    # plain gradient-accumulation path is exercised.
    batch = {"input": torch.zeros(1)}
    train_dataloader = [batch, batch, batch, batch, batch]
    eval_dataloader = None
    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()
    gradient_accumulation_steps = 1
    train_config = mocker.MagicMock()
    train_config.enable_fsdp = False
    train_config.use_fp16 = False
    train_config.run_validation = False

    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )

    # Without accumulation, zero_grad is expected once per batch.
    assert optimizer.zero_grad.call_count == 5

    optimizer.zero_grad.reset_mock()

    gradient_accumulation_steps = 2

    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )

    # Accumulating over 2 batches, zero_grad fires after batches 2 and 4 and
    # once more for the trailing 5th batch: 3 calls in total.
    assert optimizer.zero_grad.call_count == 3