training.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass


@dataclass
class train_config:
    model_name: str="PATH/to/LLAMA/7B"
    tokenizer_name: str=None
    enable_fsdp: bool=False
    low_cpu_fsdp: bool=False
    run_validation: bool=True
    batch_size_training: int=4
    batching_strategy: str="packing" # alternative: padding
    context_length: int=4096
    gradient_accumulation_steps: int=1
    gradient_clipping: bool = False
    gradient_clipping_threshold: float = 1.0
    num_epochs: int=3
    max_train_step: int=0
    max_eval_step: int=0
    num_workers_dataloader: int=1
    lr: float=1e-4
    weight_decay: float=0.0
    gamma: float=0.85 # learning rate decay factor for the scheduler
    seed: int=42
    use_fp16: bool=False
    mixed_precision: bool=True
    val_batch_size: int=1
    dataset: str = "school_math_dataset"
    peft_method: str = "lora" # None, llama_adapter, prefix
    use_peft: bool=False
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str="PATH/to/save/FSDP/model" # will be used if using FSDP
    dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
    save_optimizer: bool=False # will be used if using FSDP
    use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformers memory-efficient kernels
    use_wandb: bool = False # Enable wandb for experiment tracking
    save_metrics: bool = False # saves training metrics to a json file for later plotting
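
The dataclass above only declares defaults. Below is a minimal sketch of how it might be instantiated and overridden for a run; the import path (`training`) is an assumption based on the file name, and the overridden values are illustrative only, not part of this file.

# Minimal usage sketch (assumptions: module is importable as `training`;
# override values are illustrative only).
from dataclasses import asdict, replace

from training import train_config

# Start from the defaults defined in train_config and override a few fields.
cfg = train_config(
    batch_size_training=8,
    num_epochs=1,
    use_peft=True,
    peft_method="lora",
)

# `replace` returns a modified copy, leaving `cfg` untouched.
eval_cfg = replace(cfg, run_validation=True, val_batch_size=2)

print(asdict(cfg)["lr"])        # 0.0001
print(eval_cfg.val_batch_size)  # 2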