test_finetuning.py

from unittest.mock import patch

from torch.utils.data.dataloader import DataLoader

from llama_recipes.finetuning import main
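
# Stacked @patch decorators are applied bottom-up: the mock for the
# bottom-most decorator (StepLR) is passed as the first test argument,
# and the mock for the top-most (train) as the last.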
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
    # With validation disabled, main should pass None as the eval dataloader.
    kwargs = {"run_validation": False}

    get_dataset.return_value = [1]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert eval_dataloader is None

    assert get_model.return_value.to.call_args.args[0] == "cuda"


@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
    # With validation enabled, main should build a second DataLoader for eval.
    kwargs = {"run_validation": True}

    get_dataset.return_value = [1]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert isinstance(eval_dataloader, DataLoader)

    assert get_model.return_value.to.call_args.args[0] == "cuda"


@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@patch('llama_recipes.finetuning.get_preprocessed_dataset')
@patch('llama_recipes.finetuning.generate_peft_config')
@patch('llama_recipes.finetuning.get_peft_model')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
    # With PEFT enabled, the PEFT-wrapped model (not the base model) should be
    # moved to the device and report its trainable parameters.
    kwargs = {"use_peft": True}

    get_dataset.return_value = [1]

    main(**kwargs)

    assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
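
# The tests above are standard pytest test functions and can be run with, e.g.:
#   pytest test_finetuning.py -v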