test_train_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from unittest.mock import patch

import pytest
import torch
import os
import shutil

from llama_recipes.utils.train_utils import train

TEMP_OUTPUT_DIR = os.getcwd() + "/tmp"
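# TEMP_OUTPUT_DIR resolves to ./tmp, the same path the session-scoped fixture
# below creates before the tests run and removes again at teardown.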

@pytest.fixture(scope="session")
def temp_output_dir():
    # Create the directory during the session-level setup
    temp_output_dir = "tmp"
    os.mkdir(os.path.join(os.getcwd(), temp_output_dir))

    yield temp_output_dir

    # Delete the directory during the session-level teardown
    shutil.rmtree(temp_output_dir)
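
# The patch decorators below replace MemoryTrace, nullcontext, GradScaler and
# autocast inside llama_recipes.utils.train_utils with mocks, so the test can
# count how often each context manager is entered instead of running real
# CUDA/AMP code. patch decorators are applied bottom-up, which is why the mock
# arguments arrive in the order autocast, scaler, nullcontext, mem_trace.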
@patch("llama_recipes.utils.train_utils.MemoryTrace")
@patch("llama_recipes.utils.train_utils.nullcontext")
@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
    model = mocker.MagicMock(name="model")
    model().loss.__truediv__().detach.return_value = torch.tensor(1)
    mock_tensor = mocker.MagicMock(name="tensor")
    batch = {"input": mock_tensor}
    train_dataloader = [batch, batch, batch, batch, batch]
    eval_dataloader = None
    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()
    gradient_accumulation_steps = 1
    train_config = mocker.MagicMock()
    train_config.enable_fsdp = False
    train_config.use_fp16 = False
    train_config.run_validation = False
    train_config.gradient_clipping = False
    train_config.save_metrics = False
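
    # First run: 5 batches with gradient_accumulation_steps=1 and fp16 disabled,
    # so the optimizer should step (and zero_grad) once per batch and every
    # batch should be wrapped in nullcontext rather than autocast.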
    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )
    assert optimizer.zero_grad.call_count == 5
    optimizer.zero_grad.reset_mock()

    assert nullcontext.call_count == 5
    nullcontext.reset_mock()

    assert autocast.call_count == 0
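
    # Second run: with gradient_accumulation_steps=2 the optimizer steps on
    # every second batch plus the final one (3 times for 5 batches), and with
    # use_fp16=True every batch is wrapped in autocast instead of nullcontext.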
    gradient_accumulation_steps = 2
    train_config.use_fp16 = True

    train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
    )

    assert optimizer.zero_grad.call_count == 3
    assert nullcontext.call_count == 0
    assert autocast.call_count == 5

def test_save_to_json(temp_output_dir, mocker):
    model = mocker.MagicMock(name="model")
    model().loss.__truediv__().detach.return_value = torch.tensor(1)
    mock_tensor = mocker.MagicMock(name="tensor")
    batch = {"input": mock_tensor}
    train_dataloader = [batch, batch, batch, batch, batch]
    eval_dataloader = None
    tokenizer = mocker.MagicMock()
    optimizer = mocker.MagicMock()
    lr_scheduler = mocker.MagicMock()
    gradient_accumulation_steps = 1
    train_config = mocker.MagicMock()
    train_config.enable_fsdp = False
    train_config.use_fp16 = False
    train_config.run_validation = False
    train_config.gradient_clipping = False
    train_config.save_metrics = True
    train_config.output_dir = temp_output_dir
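
    # With save_metrics=True and output_dir pointing at the session tmp
    # directory, train() is expected to write a metrics JSON file and report
    # its path in the returned results under "metrics_filename".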
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        lr_scheduler,
        gradient_accumulation_steps,
        train_config,
        local_rank=0
    )

    assert results["metrics_filename"] not in ["", None]
    assert os.path.isfile(results["metrics_filename"])