@@ -13,6 +13,15 @@ from torch.utils.data.sampler import BatchSampler
 from llama_recipes.finetuning import main
 from llama_recipes.data.sampler import LengthBasedBatchSampler
 
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+        }]
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -22,7 +31,7 @@ from llama_recipes.data.sampler import LengthBasedBatchSampler
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -46,7 +55,8 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [[1]]
+
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -72,7 +82,7 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
@@ -89,7 +99,7 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
     kwargs = {"weight_decay": 0.01}
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     get_peft_model.return_value = Linear(1,1)
     get_peft_model.return_value.print_trainable_parameters=lambda:None
@@ -113,9 +123,12 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {"batching_strategy": "packing"}
+    kwargs = {
+        "batching_strategy": "packing",
+        "use_peft": False,
+        }
 
-    get_dataset.return_value = [[1]]
+    get_dataset.return_value = get_fake_dataset()
 
     main(**kwargs)
 
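Note on the fixture shape: get_fake_dataset() returns dict samples keyed like tokenizer output ("input_ids", "attention_mask", "labels") instead of the old bare [[1]]. The packing path exercised by test_batching_strategy concatenates samples key by key, so a list-of-lists fixture breaks as soon as a sample is indexed with a string key. Below is a minimal sketch of that failure mode; pack() is a hypothetical toy helper standing in for the library's packing logic, not the real implementation.

def get_fake_dataset():
    # Same fixture as added in this diff: one tokenized-looking sample.
    return [{
        "input_ids": [1],
        "attention_mask": [1],
        "labels": [1],
    }]

def pack(dataset, chunk_size=4):
    # Toy stand-in for key-wise packing; illustration only.
    buffer = {"input_ids": [], "attention_mask": [], "labels": []}
    chunks = []
    for sample in dataset:
        # sample[k] requires dict samples; with the old [[1]] fixture this
        # line raises TypeError: list indices must be integers or slices.
        buffer = {k: v + sample[k] for k, v in buffer.items()}
        while len(buffer["input_ids"]) >= chunk_size:
            chunks.append({k: v[:chunk_size] for k, v in buffer.items()})
            buffer = {k: v[chunk_size:] for k, v in buffer.items()}
    return chunks

print(len(pack(get_fake_dataset() * 8)))  # 8 one-token samples -> 2 chunks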