training.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass


@dataclass
class train_config:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1
    dataset: str = "samsum_dataset"
    peft_method: str = "lora"  # None, llama_adapter, prefix
    use_peft: bool = False
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # will be used if using FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = False  # Enable SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and xFormers memory-efficient kernels
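For context, a minimal usage sketch (not part of the file above): it assumes the dataclass is importable as configs.training.train_config, as in the usual llama-recipes layout, and shows the defaults being copied and overridden for a hypothetical single-GPU LoRA run using only standard-library dataclass helpers.

# Usage sketch (assumption: the dataclass above lives at configs/training.py;
# adjust the import to match your repo layout).
from dataclasses import asdict, replace

from configs.training import train_config

cfg = train_config()  # all defaults as defined above

# dataclasses.replace returns a copy with selected fields overridden,
# e.g. for a hypothetical quantized LoRA fine-tuning run.
lora_cfg = replace(
    cfg,
    use_peft=True,
    peft_method="lora",
    quantization=True,
    one_gpu=True,
    batch_size_training=2,
    num_epochs=1,
)

print(asdict(lora_cfg)["output_dir"])  # fields are inspectable as a plain dict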