config_utils.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import inspect
  4. from dataclasses import fields
  5. from peft import (
  6. LoraConfig,
  7. AdaptionPromptConfig,
  8. PrefixTuningConfig,
  9. )
  10. from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
  11. from llama_recipes.utils.dataset_utils import DATASET_PREPROC
  12. def update_config(config, **kwargs):
  13. if isinstance(config, (tuple, list)):
  14. for c in config:
  15. update_config(c, **kwargs)
  16. else:
  17. for k, v in kwargs.items():
  18. if hasattr(config, k):
  19. setattr(config, k, v)
  20. elif "." in k:
  21. # allow --some_config.some_param=True
  22. config_name, param_name = k.split(".")
  23. if type(config).__name__ == config_name:
  24. if hasattr(config, param_name):
  25. setattr(config, param_name, v)
  26. else:
  27. # In case of specialized config we can warm user
  28. print(f"Warning: {config_name} does not accept parameter: {k}")
  29. elif isinstance(config, train_config):
  30. print(f"Warning: unknown parameter {k}")
  31. def generate_peft_config(train_config, kwargs):
  32. configs = (lora_config, llama_adapter_config, prefix_config)
  33. peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
  34. names = tuple(c.__name__.rstrip("_config") for c in configs)
  35. assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
  36. config = configs[names.index(train_config.peft_method)]
  37. update_config(config, **kwargs)
  38. params = {k.name: getattr(config, k.name) for k in fields(config)}
  39. peft_config = peft_configs[names.index(train_config.peft_method)](**params)
  40. return peft_config
  41. def generate_dataset_config(train_config, kwargs):
  42. names = tuple(DATASET_PREPROC.keys())
  43. assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
  44. dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]
  45. update_config(dataset_config, **kwargs)
  46. return dataset_config