1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
- import inspect
- from dataclasses import asdict
- from peft import (
- LoraConfig,
- AdaptionPromptConfig,
- PrefixTuningConfig,
- )
- from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
- from llama_recipes.utils.dataset_utils import DATASET_PREPROC
- def update_config(config, **kwargs):
- if isinstance(config, (tuple, list)):
- for c in config:
- update_config(c, **kwargs)
- else:
- for k, v in kwargs.items():
- if hasattr(config, k):
- setattr(config, k, v)
- elif "." in k:
- # allow --some_config.some_param=True
- config_name, param_name = k.split(".")
- if type(config).__name__ == config_name:
- if hasattr(config, param_name):
- setattr(config, param_name, v)
- else:
- # In case of specialized config we can warm user
- print(f"Warning: {config_name} does not accept parameter: {k}")
- elif isinstance(config, train_config):
- print(f"Warning: unknown parameter {k}")
-
-
- def generate_peft_config(train_config, kwargs):
- configs = (lora_config, llama_adapter_config, prefix_config)
- peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
- names = tuple(c.__name__.rstrip("_config") for c in configs)
-
- assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
-
- config = configs[names.index(train_config.peft_method)]()
-
- update_config(config, **kwargs)
- params = asdict(config)
- peft_config = peft_configs[names.index(train_config.peft_method)](**params)
-
- return peft_config
- def generate_dataset_config(train_config, kwargs):
- names = tuple(DATASET_PREPROC.keys())
-
- assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-
- dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-
- update_config(dataset_config, **kwargs)
-
- return dataset_config
|