
Fix usage of dataclass for train_config and fsdp_config

Matthias Reso 1 year ago
parent
commit
5da84b2913
2 files changed, with 6 additions and 7 deletions
  1. +3 -1
      src/llama_recipes/finetuning.py
  2. +3 -6
      tests/test_finetuning.py

+ 3 - 1
src/llama_recipes/finetuning.py

@@ -21,7 +21,8 @@ from transformers import (
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
@@ -47,6 +48,7 @@ from llama_recipes.utils.train_utils import (
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
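Note: the sketch below is a simplified illustration of why the configs are now instantiated inside main(). The train_config class and update_config helper here are hypothetical stand-ins, not the actual llama_recipes implementations; the point is only that passing the dataclass *class* mutates shared class-level defaults, while passing a fresh instance keeps each call isolated.

```python
from dataclasses import dataclass

@dataclass
class train_config:            # hypothetical, simplified config
    lr: float = 1e-4
    use_peft: bool = False

def update_config(configs, **kwargs):
    # Assumed behavior: copy matching kwargs onto each config object.
    for config in configs:
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)

# Before the fix: the class object itself is updated, so overrides
# leak into every later use of train_config.
update_config((train_config,), lr=3e-4)
assert train_config.lr == 3e-4          # global side effect

# After the fix: each main() call works on a fresh instance,
# leaving the class defaults untouched.
cfg = train_config()
update_config((cfg,), lr=5e-4)
assert cfg.lr == 5e-4 and train_config.lr == 3e-4
```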

+ 3 - 6
tests/test_finetuning.py

@@ -101,8 +101,8 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 
     get_dataset.return_value = get_fake_dataset()
 
-    get_peft_model.return_value = Linear(1,1)
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+    get_model.return_value = Linear(1,1)
+
     main(**kwargs)
 
     assert train.call_count == 1
@@ -123,10 +123,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
-    kwargs = {
-        "batching_strategy": "packing",
-        "use_peft": False,
-        }
+    kwargs = {"batching_strategy": "packing"}
 
     get_dataset.return_value = get_fake_dataset()
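The trimmed kwargs in test_batching_strategy rely on the dataclass defaults filling in anything not overridden. A minimal self-contained sketch of that behavior, again with a hypothetical simplified config and update_config (the field names and defaults here are illustrative, not the library's actual values):

```python
from dataclasses import dataclass

@dataclass
class train_config:                     # hypothetical, simplified config
    batching_strategy: str = "padding"  # illustrative default only
    use_peft: bool = False

def update_config(configs, **kwargs):
    # Assumed behavior: copy matching kwargs onto each config object.
    for config in configs:
        for key, value in kwargs.items():
            if hasattr(config, key):
                setattr(config, key, value)

cfg = train_config()
update_config((cfg,), batching_strategy="packing")
assert cfg.batching_strategy == "packing"
assert cfg.use_peft is False   # omitted keys keep their defaults
```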