Browse Source

Replaced ClassVar config param with field

tim-a-davis 1 year ago
parent
commit
3038020aa4
2 changed files with 5 additions and 6 deletions
  1. 3 3
      src/llama_recipes/configs/peft.py
  2. 2 3
      src/llama_recipes/utils/config_utils.py

+ 3 - 3
src/llama_recipes/configs/peft.py

@@ -1,14 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from dataclasses import dataclass
-from typing import ClassVar, List
+from dataclasses import dataclass, field
+from typing import List
 
 @dataclass
 class lora_config:
      r: int=8
      lora_alpha: int=32
-     target_modules: ClassVar[List[str]]= ["q_proj", "v_proj"]
+     target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
      bias= "none"
      task_type: str= "CAUSAL_LM"
      lora_dropout: float=0.05

+ 2 - 3
src/llama_recipes/utils/config_utils.py

@@ -2,8 +2,7 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import inspect
-from dataclasses import fields
-
+from dataclasses import asdict
 from peft import (
     LoraConfig,
     AdaptionPromptConfig,
@@ -45,7 +44,7 @@ def generate_peft_config(train_config, kwargs):
     config = configs[names.index(train_config.peft_method)]()
     
     update_config(config, **kwargs)
-    params = {k.name: getattr(config, k.name) for k in fields(config)}
+    params = asdict(config)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
     
     return peft_config