test_finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from unittest.mock import patch
import importlib

from torch.utils.data.dataloader import DataLoader

from llama_recipes.finetuning import main
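

# main() is invoked with validation disabled; train() should receive a real
# training DataLoader, None for the eval dataloader, and the model should be
# moved to "cuda".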
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
    kwargs = {"run_validation": False}

    get_dataset.return_value = [1]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert eval_dataloader is None

    assert get_model.return_value.to.call_args.args[0] == "cuda"
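

# Same setup with run_validation enabled; both the train and eval arguments
# passed to train() should now be DataLoader instances.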
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
    kwargs = {"run_validation": True}

    get_dataset.return_value = [1]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert isinstance(eval_dataloader, DataLoader)

    assert get_model.return_value.to.call_args.args[0] == "cuda"
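

# With use_peft enabled, the PEFT-wrapped model returned by get_peft_model
# (not the base model) should be moved to "cuda" and its trainable parameters
# printed exactly once.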
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.generate_peft_config')
@patch('llama_recipes.finetuning.get_peft_model')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
    kwargs = {"use_peft": True}

    get_dataset.return_value = [1]

    main(**kwargs)

    assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1