# test_train_utils.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
from unittest.mock import patch
import pytest
import torch
import os
import shutil
from llama_recipes.utils.train_utils import train
TEMP_OUTPUT_DIR = os.getcwd() + "/tmp"
  10. @pytest.fixture(scope="session")
  11. def temp_output_dir():
  12. # Create the directory during the session-level setup
  13. temp_output_dir = "tmp"
  14. os.mkdir(os.path.join(os.getcwd(), temp_output_dir))
  15. yield temp_output_dir
  16. # Delete the directory during the session-level teardown
  17. shutil.rmtree(temp_output_dir)
  18. @patch("llama_recipes.utils.train_utils.MemoryTrace")
  19. @patch("llama_recipes.utils.train_utils.nullcontext")
  20. @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
  21. @patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
  22. def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
  23. model = mocker.MagicMock(name="model")
  24. model().loss.__truediv__().detach.return_value = torch.tensor(1)
  25. mock_tensor = mocker.MagicMock(name="tensor")
  26. batch = {"input": mock_tensor}
  27. train_dataloader = [batch, batch, batch, batch, batch]
  28. eval_dataloader = None
  29. tokenizer = mocker.MagicMock()
  30. optimizer = mocker.MagicMock()
  31. lr_scheduler = mocker.MagicMock()
  32. gradient_accumulation_steps = 1
  33. train_config = mocker.MagicMock()
  34. train_config.enable_fsdp = False
  35. train_config.use_fp16 = False
  36. train_config.run_validation = False
  37. train_config.gradient_clipping = False
  38. train_config.max_train_step = 0
  39. train_config.max_eval_step = 0
  40. train_config.save_metrics = False
  41. train(
  42. model,
  43. train_dataloader,
  44. eval_dataloader,
  45. tokenizer,
  46. optimizer,
  47. lr_scheduler,
  48. gradient_accumulation_steps,
  49. train_config,
  50. )
  51. assert optimizer.zero_grad.call_count == 5
  52. optimizer.zero_grad.reset_mock()
  53. assert nullcontext.call_count == 5
  54. nullcontext.reset_mock()
  55. assert autocast.call_count == 0
  56. gradient_accumulation_steps = 2
  57. train_config.use_fp16 = True
  58. train(
  59. model,
  60. train_dataloader,
  61. eval_dataloader,
  62. tokenizer,
  63. optimizer,
  64. lr_scheduler,
  65. gradient_accumulation_steps,
  66. train_config,
  67. )
  68. assert optimizer.zero_grad.call_count == 3
  69. assert nullcontext.call_count == 0
  70. assert autocast.call_count == 5
  71. def test_save_to_json(temp_output_dir, mocker):
  72. model = mocker.MagicMock(name="model")
  73. model().loss.__truediv__().detach.return_value = torch.tensor(1)
  74. mock_tensor = mocker.MagicMock(name="tensor")
  75. batch = {"input": mock_tensor}
  76. train_dataloader = [batch, batch, batch, batch, batch]
  77. eval_dataloader = None
  78. tokenizer = mocker.MagicMock()
  79. optimizer = mocker.MagicMock()
  80. lr_scheduler = mocker.MagicMock()
  81. gradient_accumulation_steps = 1
  82. train_config = mocker.MagicMock()
  83. train_config.enable_fsdp = False
  84. train_config.use_fp16 = False
  85. train_config.run_validation = False
  86. train_config.gradient_clipping = False
  87. train_config.save_metrics = True
  88. train_config.max_train_step = 0
  89. train_config.max_eval_step = 0
  90. train_config.output_dir = temp_output_dir
  91. results = train(
  92. model,
  93. train_dataloader,
  94. eval_dataloader,
  95. tokenizer,
  96. optimizer,
  97. lr_scheduler,
  98. gradient_accumulation_steps,
  99. train_config,
  100. local_rank=0
  101. )
  102. assert results["metrics_filename"] not in ["", None]
  103. assert os.path.isfile(results["metrics_filename"])