finetuning.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. from pkg_resources import packaging
  5. import fire
  6. import torch
  7. import torch.distributed as dist
  8. import torch.optim as optim
  9. from peft import get_peft_model, prepare_model_for_int8_training
  10. from torch.distributed.fsdp import (
  11. FullyShardedDataParallel as FSDP,
  12. )
  13. from torch.optim.lr_scheduler import StepLR
  14. from torch.utils.data import DistributedSampler
  15. from transformers import (
  16. LlamaForCausalLM,
  17. LlamaTokenizer,
  18. LlamaConfig,
  19. default_data_collator,
  20. )
  21. from transformers.models.llama.modeling_llama import LlamaDecoderLayer
  22. from llama_recipes.configs import fsdp_config, train_config
  23. from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
  24. from llama_recipes.utils import fsdp_auto_wrap_policy
  25. from llama_recipes.utils.config_utils import (
  26. update_config,
  27. generate_peft_config,
  28. generate_dataset_config,
  29. )
  30. from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
  31. from llama_recipes.utils.train_utils import (
  32. train,
  33. freeze_transformer_layers,
  34. setup,
  35. setup_environ_flags,
  36. clear_gpu_cache,
  37. print_model_size,
  38. get_policies
  39. )
  40. from accelerate.utils import is_xpu_available
  41. def main(**kwargs):
  42. # Update the configuration for the training and sharding process
  43. update_config((train_config, fsdp_config), **kwargs)
  44. # Set the seeds for reproducibility
  45. if is_xpu_available():
  46. torch.xpu.manual_seed(train_config.seed)
  47. else:
  48. torch.cuda.manual_seed(train_config.seed)
  49. torch.manual_seed(train_config.seed)
  50. if train_config.enable_fsdp:
  51. setup()
  52. # torchrun specific
  53. local_rank = int(os.environ["LOCAL_RANK"])
  54. rank = int(os.environ["RANK"])
  55. world_size = int(os.environ["WORLD_SIZE"])
  56. if torch.distributed.is_initialized():
  57. if is_xpu_available():
  58. torch.xpu.set_device(local_rank)
  59. else:
  60. torch.cuda.set_device(local_rank)
  61. clear_gpu_cache(local_rank)
  62. setup_environ_flags(rank)
  63. # Load the pre-trained model and setup its configuration
  64. use_cache = False if train_config.enable_fsdp else None
  65. if train_config.enable_fsdp and train_config.low_cpu_fsdp:
  66. """
  67. for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
  68. this avoids cpu oom when loading large models like llama 70B, in which case
  69. model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
  70. overhead and currently requires latest nightly.
  71. """
  72. v = packaging.version.parse(torch.__version__)
  73. verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
  74. if not verify_latest_nightly:
  75. raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
  76. "please install latest nightly.")
  77. if rank == 0:
  78. model = LlamaForCausalLM.from_pretrained(
  79. train_config.model_name,
  80. load_in_8bit=True if train_config.quantization else None,
  81. device_map="auto" if train_config.quantization else None,
  82. use_cache=use_cache,
  83. )
  84. else:
  85. llama_config = LlamaConfig.from_pretrained(train_config.model_name)
  86. llama_config.use_cache = use_cache
  87. with torch.device("meta"):
  88. model = LlamaForCausalLM(llama_config)
  89. else:
  90. model = LlamaForCausalLM.from_pretrained(
  91. train_config.model_name,
  92. load_in_8bit=True if train_config.quantization else None,
  93. device_map="auto" if train_config.quantization else None,
  94. use_cache=use_cache,
  95. )
  96. if train_config.enable_fsdp and train_config.use_fast_kernels:
  97. """
  98. For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
  99. using of Flash Attention or Xformer memory-efficient kernels
  100. based on the hardware being used. This would speed up fine-tuning.
  101. """
  102. try:
  103. from optimum.bettertransformer import BetterTransformer
  104. model = BetterTransformer.transform(model)
  105. except ImportError:
  106. print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
  107. print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
  108. # Prepare the model for int8 training if quantization is enabled
  109. if train_config.quantization:
  110. model = prepare_model_for_int8_training(model)
  111. # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
  112. if train_config.enable_fsdp and fsdp_config.pure_bf16:
  113. model.to(torch.bfloat16)
  114. # Load the tokenizer and add special tokens
  115. tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
  116. tokenizer.add_special_tokens(
  117. {
  118. "pad_token": "<PAD>",
  119. }
  120. )
  121. if train_config.use_peft:
  122. peft_config = generate_peft_config(train_config, kwargs)
  123. model = get_peft_model(model, peft_config)
  124. model.print_trainable_parameters()
  125. #setting up FSDP if enable_fsdp is enabled
  126. if train_config.enable_fsdp:
  127. if not train_config.use_peft and train_config.freeze_layers:
  128. freeze_transformer_layers(train_config.num_freeze_layers)
  129. mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
  130. my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
  131. model = FSDP(
  132. model,
  133. auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
  134. mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
  135. sharding_strategy=fsdp_config.sharding_strategy,
  136. device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
  137. limit_all_gathers=True,
  138. sync_module_states=train_config.low_cpu_fsdp,
  139. param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
  140. if train_config.low_cpu_fsdp and rank != 0 else None,
  141. )
  142. if fsdp_config.fsdp_activation_checkpointing:
  143. apply_fsdp_checkpointing(model)
  144. elif not train_config.quantization and not train_config.enable_fsdp:
  145. if is_xpu_available():
  146. model.to("xpu:0")
  147. else:
  148. model.to("cuda")
  149. dataset_config = generate_dataset_config(train_config, kwargs)
  150. # Load and preprocess the dataset for training and validation
  151. dataset_train = get_preprocessed_dataset(
  152. tokenizer,
  153. dataset_config,
  154. split="train",
  155. )
  156. if not train_config.enable_fsdp or rank == 0:
  157. print(f"--> Training Set Length = {len(dataset_train)}")
  158. dataset_val = get_preprocessed_dataset(
  159. tokenizer,
  160. dataset_config,
  161. split="test",
  162. )
  163. if not train_config.enable_fsdp or rank == 0:
  164. print(f"--> Validation Set Length = {len(dataset_val)}")
  165. train_sampler = None
  166. val_sampler = None
  167. if train_config.enable_fsdp:
  168. train_sampler = DistributedSampler(
  169. dataset_train,
  170. rank=dist.get_rank(),
  171. num_replicas=dist.get_world_size(),
  172. shuffle=True,
  173. )
  174. if train_config.run_validation:
  175. val_sampler = DistributedSampler(
  176. dataset_val,
  177. rank=dist.get_rank(),
  178. num_replicas=dist.get_world_size(),
  179. )
  180. # Create DataLoaders for the training and validation dataset
  181. train_dataloader = torch.utils.data.DataLoader(
  182. dataset_train,
  183. batch_size=train_config.batch_size_training,
  184. num_workers=train_config.num_workers_dataloader,
  185. pin_memory=True,
  186. sampler=train_sampler if train_sampler else None,
  187. drop_last=True,
  188. collate_fn=default_data_collator,
  189. )
  190. eval_dataloader = None
  191. if train_config.run_validation:
  192. eval_dataloader = torch.utils.data.DataLoader(
  193. dataset_val,
  194. batch_size=train_config.val_batch_size,
  195. num_workers=train_config.num_workers_dataloader,
  196. pin_memory=True,
  197. sampler=val_sampler if val_sampler else None,
  198. drop_last=True,
  199. collate_fn=default_data_collator,
  200. )
  201. # Initialize the optimizer and learning rate scheduler
  202. if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
  203. optimizer = AnyPrecisionAdamW(
  204. model.parameters(),
  205. lr=train_config.lr,
  206. momentum_dtype=torch.bfloat16,
  207. variance_dtype=torch.bfloat16,
  208. use_kahan_summation=False,
  209. )
  210. else:
  211. optimizer = optim.AdamW(
  212. model.parameters(),
  213. lr=train_config.lr,
  214. weight_decay=0.0,
  215. )
  216. scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
  217. # Start the training process
  218. results = train(
  219. model,
  220. train_dataloader,
  221. eval_dataloader,
  222. tokenizer,
  223. optimizer,
  224. scheduler,
  225. train_config.gradient_accumulation_steps,
  226. train_config,
  227. fsdp_config if train_config.enable_fsdp else None,
  228. local_rank if train_config.enable_fsdp else None,
  229. rank if train_config.enable_fsdp else None,
  230. )
  231. if not train_config.enable_fsdp or rank==0:
  232. [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
  233. if __name__ == "__main__":
  234. fire.Fire(main)