test_train_utils.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import torch
  2. from llama_recipes.utils.train_utils import train
  3. def test_gradient_accumulation(mocker):
  4. # import sys
  5. # sys.path.append('/home/ubuntu/llama-recipes/')
  6. model = mocker.MagicMock(name="model")
  7. model().loss.__truediv__().detach.return_value = torch.tensor(1)
  8. batch = {"input": torch.zeros(1)}
  9. train_dataloader = [batch, batch, batch, batch, batch]
  10. eval_dataloader = None
  11. tokenizer = mocker.MagicMock()
  12. optimizer = mocker.MagicMock()
  13. lr_scheduler = mocker.MagicMock()
  14. gradient_accumulation_steps = 1
  15. train_config = mocker.MagicMock()
  16. train_config.enable_fsdp = False
  17. train_config.use_fp16 = False
  18. train_config.run_validation = False
  19. train(
  20. model,
  21. train_dataloader,
  22. eval_dataloader,
  23. tokenizer,
  24. optimizer,
  25. lr_scheduler,
  26. gradient_accumulation_steps,
  27. train_config,
  28. )
  29. assert optimizer.zero_grad.call_count == 5
  30. optimizer.zero_grad.reset_mock()
  31. gradient_accumulation_steps = 2
  32. train(
  33. model,
  34. train_dataloader,
  35. eval_dataloader,
  36. tokenizer,
  37. optimizer,
  38. lr_scheduler,
  39. gradient_accumulation_steps,
  40. train_config,
  41. )
  42. assert optimizer.zero_grad.call_count == 3