llama_finetuning.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import sys
  5. from typing import List, Union
  6. import fire
  7. import torch
  8. import transformers
  9. from datasets import load_dataset
  10. import os.path as osp
  11. from tqdm import tqdm
  12. # Unused imports removed
  13. from utils import fsdp_auto_wrap_policy
  14. from transformers import (
  15. LlamaForCausalLM,
  16. LlamaTokenizer,
  17. AutoModelForCausalLM,
  18. AutoModelForSeq2SeqLM,
  19. AutoTokenizer,
  20. default_data_collator,
  21. BitsAndBytesConfig
  22. )
  23. import torch.distributed as dist
  24. # Unused imports removed
  25. from utils.train_utils import (
  26. set_tokenizer_params,
  27. train,
  28. evaluation,
  29. freeze_transformer_layers,
  30. check_frozen_layers_peft_model,
  31. setup,
  32. setup_environ_flags,
  33. cleanup,
  34. clear_gpu_cache,
  35. get_parameter_dtypes,
  36. print_model_size,
  37. get_policies
  38. )
  39. from utils.dataset_utils import get_preprocessed_dataset
  40. from utils.config_utils import (
  41. update_config,
  42. generate_peft_config,
  43. generate_dataset_config,
  44. )
  45. from peft import get_peft_model, TaskType, prepare_model_for_int8_training
  46. import configs
  47. from torch.distributed.fsdp import (
  48. FullyShardedDataParallel as FSDP,
  49. MixedPrecision,
  50. )
  51. from torch.utils.data import DistributedSampler
  52. import policies
  53. from policies import AnyPrecisionAdamW
  54. from configs import fsdp_config, train_config
  55. import torch.optim as optim
  56. from torch.optim.lr_scheduler import StepLR
  57. from pkg_resources import packaging
  58. import torch
  59. import torch.cuda.nccl as nccl
  60. import torch.distributed as dist
  61. from transformers.models.llama.modeling_llama import LlamaDecoderLayer
  62. from accelerate.utils import is_xpu_available
  63. def main(**kwargs):
  64. # Update the configuration for the training and sharding process
  65. update_config((train_config, fsdp_config), **kwargs)
  66. # Set the seeds for reproducibility
  67. if is_xpu_available():
  68. torch.xpu.manual_seed(train_config.seed)
  69. else:
  70. torch.cuda.manual_seed(train_config.seed)
  71. torch.manual_seed(train_config.seed)
  72. if train_config.enable_fsdp:
  73. setup()
  74. # torchrun specific
  75. local_rank = int(os.environ["LOCAL_RANK"])
  76. rank = int(os.environ["RANK"])
  77. world_size = int(os.environ["WORLD_SIZE"])
  78. if torch.distributed.is_initialized():
  79. if is_xpu_available():
  80. torch.xpu.set_device(rank)
  81. else:
  82. torch.cuda.set_device(rank)
  83. setup_environ_flags(rank)
  84. # Calculate gradient accumulation steps
  85. gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
  86. # Load the pre-trained model and setup its configuration
  87. model = LlamaForCausalLM.from_pretrained(
  88. train_config.model_name,
  89. load_in_8bit=True if train_config.quantization else None,
  90. device_map="auto" if train_config.quantization else None,
  91. )
  92. if train_config.enable_fsdp and train_config.use_fast_kernels:
  93. """
  94. For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
  95. using of Flash Attention or Xformer memory-efficient kernels
  96. based on the hardware being used. This would speed up fine-tuning.
  97. """
  98. try:
  99. from optimum.bettertransformer import BetterTransformer
  100. model = BetterTransformer.transform(model)
  101. except ImportError:
  102. print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
  103. print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
  104. # Prepare the model for int8 training if quantization is enabled
  105. if train_config.quantization:
  106. model = prepare_model_for_int8_training(model)
  107. # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
  108. if train_config.enable_fsdp and fsdp_config.pure_bf16:
  109. model.to(torch.bfloat16)
  110. # Load the tokenizer and add special tokens
  111. tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
  112. tokenizer.add_special_tokens(
  113. {
  114. "pad_token": "<PAD>",
  115. }
  116. )
  117. if train_config.use_peft:
  118. peft_config = generate_peft_config(train_config, kwargs)
  119. model = get_peft_model(model, peft_config)
  120. model.print_trainable_parameters()
  121. #setting up FSDP if enable_fsdp is enabled
  122. if train_config.enable_fsdp:
  123. if not train_config.use_peft and train_config.freeze_layers:
  124. freeze_transformer_layers(train_config.num_freeze_layers)
  125. mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
  126. my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
  127. model = FSDP(
  128. model,
  129. auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
  130. mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
  131. sharding_strategy=fsdp_config.sharding_strategy,
  132. device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
  133. limit_all_gathers=True,
  134. )
  135. if fsdp_config.fsdp_activation_checkpointing:
  136. policies.apply_fsdp_checkpointing(model)
  137. elif not train_config.quantization and not train_config.enable_fsdp:
  138. if is_xpu_available():
  139. model.to("xpu:0")
  140. else:
  141. model.to("cuda")
  142. dataset_config = generate_dataset_config(train_config, kwargs)
  143. # Load and preprocess the dataset for training and validation
  144. dataset_train = get_preprocessed_dataset(
  145. tokenizer,
  146. dataset_config,
  147. split="train",
  148. )
  149. if not train_config.enable_fsdp or rank == 0:
  150. print(f"--> Training Set Length = {len(dataset_train)}")
  151. dataset_val = get_preprocessed_dataset(
  152. tokenizer,
  153. dataset_config,
  154. split="test",
  155. )
  156. if not train_config.enable_fsdp or rank == 0:
  157. print(f"--> Validation Set Length = {len(dataset_val)}")
  158. train_sampler = None
  159. val_sampler = None
  160. if train_config.enable_fsdp:
  161. train_sampler = DistributedSampler(
  162. dataset_train,
  163. rank=dist.get_rank(),
  164. num_replicas=dist.get_world_size(),
  165. shuffle=True,
  166. )
  167. if train_config.run_validation:
  168. val_sampler = DistributedSampler(
  169. dataset_val,
  170. rank=dist.get_rank(),
  171. num_replicas=dist.get_world_size(),
  172. )
  173. # Create DataLoaders for the training and validation dataset
  174. train_dataloader = torch.utils.data.DataLoader(
  175. dataset_train,
  176. batch_size=train_config.batch_size_training,
  177. num_workers=train_config.num_workers_dataloader,
  178. pin_memory=True,
  179. sampler=train_sampler if train_sampler else None,
  180. drop_last=True,
  181. collate_fn=default_data_collator,
  182. )
  183. if train_config.run_validation:
  184. eval_dataloader = torch.utils.data.DataLoader(
  185. dataset_val,
  186. batch_size=train_config.val_batch_size,
  187. num_workers=train_config.num_workers_dataloader,
  188. pin_memory=True,
  189. sampler=val_sampler if val_sampler else None,
  190. drop_last=True,
  191. collate_fn=default_data_collator,
  192. )
  193. # Initialize the optimizer and learning rate scheduler
  194. if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
  195. optimizer = AnyPrecisionAdamW(
  196. model.parameters(),
  197. lr=train_config.lr,
  198. momentum_dtype=torch.bfloat16,
  199. variance_dtype=torch.bfloat16,
  200. use_kahan_summation=False,
  201. )
  202. else:
  203. optimizer = optim.AdamW(
  204. model.parameters(),
  205. lr=train_config.lr,
  206. weight_decay=0.0,
  207. )
  208. scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
  209. # Start the training process
  210. results = train(
  211. model,
  212. train_dataloader,
  213. eval_dataloader,
  214. tokenizer,
  215. optimizer,
  216. scheduler,
  217. gradient_accumulation_steps,
  218. train_config,
  219. fsdp_config if train_config.enable_fsdp else None,
  220. local_rank if train_config.enable_fsdp else None,
  221. rank if train_config.enable_fsdp else None,
  222. )
  223. if not train_config.enable_fsdp or rank==0:
  224. [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
  225. if __name__ == "__main__":
  226. fire.Fire(main)