training.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass


@dataclass
class train_config:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = "packing"  # alternative: padding
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    gradient_clipping: bool = False
    gradient_clipping_threshold: float = 1.0
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85  # learning rate decay factor
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1
    dataset: str = "samsum_dataset"
    peft_method: str = "lora"  # None, llama_adapter, prefix
    use_peft: bool = False
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # used only when FSDP is enabled
    dist_checkpoint_folder: str = "fine-tuned"  # used only when FSDP is enabled
    save_optimizer: bool = False  # used only when FSDP is enabled
    use_fast_kernels: bool = False  # enable SDPA from PyTorch Accelerated Transformers, which uses Flash Attention and xFormers memory-efficient kernels
    save_metrics: bool = False  # save training metrics to a JSON file for later plotting
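
# --- Illustrative usage (a minimal sketch, not part of the original file) ---
# Because train_config is a plain dataclass, downstream code can construct it
# with keyword overrides. The field names below come from the class above; the
# specific override values are arbitrary examples, not recommended settings.
if __name__ == "__main__":
    cfg = train_config(lr=2e-5, num_epochs=1, use_peft=True, batch_size_training=2)
    print(cfg.model_name, cfg.lr, cfg.num_epochs)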