test_finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from pytest import approx
from unittest.mock import patch

from torch.nn import Linear
from torch.optim import AdamW
from torch.utils.data.dataloader import DataLoader

from llama_recipes.finetuning import main
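

# Each test below patches out the heavy pieces of llama_recipes.finetuning
# (model and tokenizer loading, dataset preprocessing, the optimizer, the LR
# scheduler and the train loop) so main() can be exercised quickly without a
# GPU or real model weights; the mocks are then inspected to check how main()
# wired everything together.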
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
    """With run_validation=False, train() gets a training DataLoader and no eval dataloader."""
    kwargs = {"run_validation": False}

    get_dataset.return_value = [1]

    main(**kwargs)

    assert train.call_count == 1

    # Positional args of the patched train(): args[1] is the train dataloader,
    # args[2] the eval dataloader.
    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert eval_dataloader is None

    assert get_model.return_value.to.call_args.args[0] == "cuda"


@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
    """With run_validation=True, both dataloaders passed to train() are DataLoader instances."""
    kwargs = {"run_validation": True}

    get_dataset.return_value = [1]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert isinstance(eval_dataloader, DataLoader)

    assert get_model.return_value.to.call_args.args[0] == "cuda"


@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.generate_peft_config')
@patch('llama_recipes.finetuning.get_peft_model')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
    """With use_peft=True, the PEFT-wrapped model is moved to CUDA and its trainable parameters are reported."""
    kwargs = {"use_peft": True}

    get_dataset.return_value = [1]

    main(**kwargs)

    assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1


@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.get_peft_model')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
    """The weight_decay kwarg is propagated to the AdamW optimizer handed to train()."""
    kwargs = {"weight_decay": 0.01}

    get_dataset.return_value = [1]

    # Give the mocked model real parameters so main() can build a real AdamW instance.
    model = mocker.MagicMock(name="model")
    model.parameters.return_value = Linear(1, 1).parameters()
    get_peft_model.return_value = model
    get_peft_model.return_value.print_trainable_parameters = lambda: None

    main(**kwargs)

    assert train.call_count == 1

    # args[4] of the patched train() is the optimizer created by main().
    args, kwargs = train.call_args
    optimizer = args[4]

    assert isinstance(optimizer, AdamW)
    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
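
# These tests are written for pytest; test_finetuning_weight_decay additionally
# relies on the `mocker` fixture provided by the pytest-mock plugin.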