test_finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import pytest
from pytest import approx
from unittest.mock import patch

import torch
from torch.optim import AdamW
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import BatchSampler

from llama_recipes.finetuning import main
from llama_recipes.data.sampler import LengthBasedBatchSampler
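

# Minimal single-sample dataset stub returned by the mocked get_preprocessed_dataset.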
def get_fake_dataset():
    return [{
        "input_ids": [1],
        "attention_mask": [1],
        "labels": [1],
    }]
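

# With validation disabled, main() should hand train() a training DataLoader and no
# eval DataLoader, and move the model to "cuda" only when CUDA is reported available.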
@patch('llama_recipes.finetuning.torch.cuda.is_available')
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
    kwargs = {"run_validation": False}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert eval_dataloader is None

    if cuda_is_available:
        assert get_model.return_value.to.call_count == 1
        assert get_model.return_value.to.call_args.args[0] == "cuda"
    else:
        assert get_model.return_value.to.call_count == 0
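

# With validation enabled, both the train and eval DataLoaders should be built.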
@patch('llama_recipes.finetuning.torch.cuda.is_available')
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
    kwargs = {"run_validation": True}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert isinstance(eval_dataloader, DataLoader)

    if cuda_is_available:
        assert get_model.return_value.to.call_count == 1
        assert get_model.return_value.to.call_args.args[0] == "cuda"
    else:
        assert get_model.return_value.to.call_count == 0
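

# With use_peft=True it is the PEFT-wrapped model, not the base model, that gets moved
# to the device, and its trainable parameters are printed exactly once.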
@patch('llama_recipes.finetuning.torch.cuda.is_available')
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.generate_peft_config')
@patch('llama_recipes.finetuning.get_peft_model')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
    kwargs = {"use_peft": True}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available

    main(**kwargs)

    if cuda_is_available:
        assert get_peft_model.return_value.to.call_count == 1
        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
    else:
        assert get_peft_model.return_value.to.call_count == 0

    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
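

# The weight_decay argument should be forwarded to the real AdamW optimizer that
# main() constructs and passes to train().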
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.get_peft_model')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
    kwargs = {"weight_decay": 0.01}

    get_dataset.return_value = get_fake_dataset()

    model = mocker.MagicMock(name="Model")
    model.parameters.return_value = [torch.ones(1, 1)]
    get_model.return_value = model

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    optimizer = args[4]

    print(optimizer.state_dict())

    assert isinstance(optimizer, AdamW)
    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
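

# "packing" should produce plain BatchSamplers, "padding" LengthBasedBatchSamplers,
# and an unknown strategy should raise a ValueError.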
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
    kwargs = {"batching_strategy": "packing"}

    get_dataset.return_value = get_fake_dataset()

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader, eval_dataloader = args[1:3]
    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)

    kwargs["batching_strategy"] = "padding"
    train.reset_mock()
    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader, eval_dataloader = args[1:3]
    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)

    kwargs["batching_strategy"] = "none"

    with pytest.raises(ValueError):
        main(**kwargs)