training.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass


@dataclass
class train_config:
    model_name: str = "PATH/to/LLAMA/7B"
    enable_fsdp: bool = False
    low_cpu_fsdp: bool = False
    run_validation: bool = True
    batch_size_training: int = 4
    batching_strategy: str = "packing"  # alternative: padding
    context_length: int = 4096
    gradient_accumulation_steps: int = 1
    num_epochs: int = 3
    num_workers_dataloader: int = 1
    lr: float = 1e-4
    weight_decay: float = 0.0
    gamma: float = 0.85  # learning rate decay factor
    seed: int = 42  # random seed for reproducibility
    use_fp16: bool = False
    mixed_precision: bool = True
    val_batch_size: int = 1
    dataset: str = "samsum_dataset"
    peft_method: str = "lora"  # None, llama_adapter, prefix
    use_peft: bool = False
    output_dir: str = "PATH/to/save/PEFT/model"
    freeze_layers: bool = False
    num_freeze_layers: int = 1
    quantization: bool = False
    one_gpu: bool = False
    save_model: bool = True
    dist_checkpoint_root_folder: str = "PATH/to/save/FSDP/model"  # will be used if using FSDP
    dist_checkpoint_folder: str = "fine-tuned"  # will be used if using FSDP
    save_optimizer: bool = False  # will be used if using FSDP
    use_fast_kernels: bool = False  # Enable SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformers memory-efficient kernels
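

# --- Usage sketch (illustrative, not part of the original config file) ---
# A minimal example of how train_config might be consumed by a fine-tuning
# script: instantiate with the defaults above, override a few fields for a
# particular run, and inspect the result. The override values below are
# arbitrary placeholders, not recommended settings.
if __name__ == "__main__":
    from dataclasses import asdict
    from pprint import pprint

    # Start from the defaults, then override a few fields for a short PEFT run.
    cfg = train_config(num_epochs=1, batch_size_training=2, use_peft=True)
    pprint(asdict(cfg))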