training.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass
from typing import Optional


@dataclass
class train_config:
    model_name: str = "PATH/to/LLAMA/7B"
    tokenizer_name: Optional[str] = None
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = "packing"  # alternative: padding
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    gradient_clipping: bool = False
    gradient_clipping_threshold: float = 1.0
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85  # multiplicative learning-rate decay factor for the scheduler
    seed: int = 42
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1
    dataset: str = "samsum_dataset"
    peft_method: str = "lora"  # None, llama_adapter, prefix
    use_peft: bool = False
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # will be used if using FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = False  # Enable SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
    use_wandb: bool = False  # Enable wandb for experiment tracking
    save_metrics: bool = False  # saves training metrics to a json file for later plotting
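
For reference, a minimal usage sketch follows. It is not part of training.py: the import path (`training`) and the override values are assumptions, shown only to illustrate how the dataclass defaults can be overridden at construction time.

# Minimal usage sketch -- NOT part of training.py. The import path and the
# override values below are illustrative assumptions.
from dataclasses import asdict, replace

from training import train_config  # hypothetical import path

# Override a few defaults via keyword arguments at construction time.
cfg = train_config(
    model_name="PATH/to/LLAMA/7B",  # placeholder; point at your own checkpoint
    batch_size_training=2,
    num_epochs=1,
    use_peft=True,
    peft_method="lora",
)

# dataclasses.replace returns a modified copy, leaving cfg untouched.
eval_cfg = replace(cfg, val_batch_size=4)

# asdict() gives a plain dict, convenient for logging or dumping the resolved config.
print(asdict(eval_cfg))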
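
The optimizer, scheduler, and gradient fields map onto a standard PyTorch training loop. The sketch below is a generic illustration of how lr, weight_decay, gamma, gradient_accumulation_steps, gradient_clipping, and gradient_clipping_threshold are typically consumed; it is an assumption for illustration, not the llama-recipes training loop, and build_optimizer_and_scheduler / training_step are hypothetical helpers.

# Generic sketch (an assumption, not the llama-recipes loop) showing how the
# optimizer/scheduler/clipping fields are commonly consumed in PyTorch.
import torch


def build_optimizer_and_scheduler(model, cfg: train_config):
    # lr and weight_decay configure the optimizer; gamma is a multiplicative LR decay.
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=cfg.gamma)
    return optimizer, scheduler


def training_step(model, batch, cfg: train_config, optimizer, step: int):
    # Assumes an HF-style model whose forward pass returns an object with a .loss.
    # Scale the loss so gradients average correctly across accumulation steps.
    loss = model(**batch).loss / cfg.gradient_accumulation_steps
    loss.backward()
    # Only update weights every gradient_accumulation_steps micro-batches.
    if (step + 1) % cfg.gradient_accumulation_steps == 0:
        if cfg.gradient_clipping:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.gradient_clipping_threshold)
        optimizer.step()
        optimizer.zero_grad()
    return loss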