
Add unit test for weight decay

Matthias Reso · 1 year ago
commit 0b2fa40dba
1 changed file with 31 additions and 2 deletions
  1. tests/test_finetuning.py (+31, −2)

+31 −2  tests/test_finetuning.py

@@ -1,9 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from pytest import approx
 from unittest.mock import patch
-import importlib
 
+from torch.nn import Linear
+from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
 
 from llama_recipes.finetuning import main
@@ -72,4 +74,31 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
     main(**kwargs)
     
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
-    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+    
+    
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.get_peft_model')
+@patch('llama_recipes.finetuning.StepLR')
+def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
+    kwargs = {"weight_decay": 0.01}
+    
+    get_dataset.return_value = [1]
+    
+    get_peft_model.return_value = Linear(1,1)
+    get_peft_model.return_value.print_trainable_parameters=lambda:None
+    main(**kwargs)
+    
+    assert train.call_count == 1
+    
+    args, kwargs = train.call_args
+    optimizer = args[4]
+    
+    print(optimizer.state_dict())
+    
+    assert isinstance(optimizer, AdamW)
+    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
+
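For context, here is a minimal sketch of the code path this test is assumed to exercise: the test reads the optimizer out of train.call_args at positional index 4, so llama_recipes.finetuning.main presumably constructs an AdamW optimizer with the configured weight_decay and passes it to train as the fifth positional argument. The helper name build_optimizer_and_call_train and the train_config fields (lr, gamma) below are hypothetical placeholders for illustration, not the actual llama-recipes implementation.

# Hedged sketch (assumed, not the real llama_recipes.finetuning source):
# shows how a weight_decay kwarg could reach the AdamW optimizer that the
# test later inspects through train.call_args.
import torch.optim as optim

def build_optimizer_and_call_train(model, train_dataloader, eval_dataloader,
                                   tokenizer, train_config, train):
    optimizer = optim.AdamW(
        model.parameters(),
        lr=train_config.lr,                      # hypothetical config field
        weight_decay=train_config.weight_decay,  # 0.01 in the test's kwargs
    )
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1,
                                          gamma=train_config.gamma)  # hypothetical
    # Passing the optimizer as the fifth positional argument is what lets the
    # test recover it via train.call_args.args[4].
    train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, scheduler)

With the mocks in place, the test then only needs to assert train.call_count == 1, check isinstance(optimizer, AdamW), and compare the optimizer's param_groups[0]["weight_decay"] against 0.01 using pytest's approx. It can be run in isolation with, for example, pytest tests/test_finetuning.py -k weight_decay.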