finetuning.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. from pkg_resources import packaging
  5. import fire
  6. import random
  7. import torch
  8. import torch.optim as optim
  9. from peft import get_peft_model, prepare_model_for_int8_training
  10. from torch.distributed.fsdp import (
  11. FullyShardedDataParallel as FSDP,
  12. )
  13. from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
  14. from torch.optim.lr_scheduler import StepLR
  15. from transformers import (
  16. LlamaForCausalLM,
  17. LlamaTokenizer,
  18. LlamaConfig,
  19. )
  20. from transformers.models.llama.modeling_llama import LlamaDecoderLayer
  21. from llama_recipes.configs import fsdp_config as FSDP_CONFIG
  22. from llama_recipes.configs import train_config as TRAIN_CONFIG
  23. from llama_recipes.data.concatenator import ConcatDataset
  24. from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
  25. from llama_recipes.utils import fsdp_auto_wrap_policy
  26. from llama_recipes.utils.config_utils import (
  27. update_config,
  28. generate_peft_config,
  29. generate_dataset_config,
  30. get_dataloader_kwargs,
  31. )
  32. from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
  33. from llama_recipes.utils.train_utils import (
  34. train,
  35. freeze_transformer_layers,
  36. setup,
  37. setup_environ_flags,
  38. clear_gpu_cache,
  39. print_model_size,
  40. get_policies
  41. )
  42. def setup_wandb(train_config, fsdp_config, **kwargs):
  43. try:
  44. import wandb
  45. except ImportError:
  46. raise ImportError(
  47. "You are trying to use wandb which is not currently installed"
  48. "Please install it using pip install wandb"
  49. )
  50. from llama_recipes.configs import wandb_config as WANDB_CONFIG
  51. wandb_config = WANDB_CONFIG()
  52. wandb_entity = None if wandb_config.wandb_entity == 'none' else wandb_config.wandb_entity
  53. update_config(wandb_config, **kwargs)
  54. run = wandb.init(project=wandb_config.wandb_project, entity=wandb_entity)
  55. run.config.update(train_config)
  56. run.config.update(fsdp_config)
  57. return run
  58. def main(**kwargs):
  59. # Update the configuration for the training and sharding process
  60. train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
  61. update_config((train_config, fsdp_config), **kwargs)
  62. # Set the seeds for reproducibility
  63. torch.cuda.manual_seed(train_config.seed)
  64. torch.manual_seed(train_config.seed)
  65. random.seed(train_config.seed)
  66. if train_config.enable_fsdp:
  67. setup()
  68. # torchrun specific
  69. local_rank = int(os.environ["LOCAL_RANK"])
  70. rank = int(os.environ["RANK"])
  71. world_size = int(os.environ["WORLD_SIZE"])
  72. if torch.distributed.is_initialized():
  73. torch.cuda.set_device(local_rank)
  74. clear_gpu_cache(local_rank)
  75. setup_environ_flags(rank)
  76. if train_config.enable_wandb:
  77. if not train_config.enable_fsdp or rank==0:
  78. wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
  79. # Load the pre-trained model and setup its configuration
  80. use_cache = False if train_config.enable_fsdp else None
  81. if train_config.enable_fsdp and train_config.low_cpu_fsdp:
  82. """
  83. for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
  84. this avoids cpu oom when loading large models like llama 70B, in which case
  85. model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
  86. overhead and currently requires latest nightly.
  87. """
  88. v = packaging.version.parse(torch.__version__)
  89. verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
  90. if not verify_latest_nightly:
  91. raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
  92. "please install latest nightly.")
  93. if rank == 0:
  94. model = LlamaForCausalLM.from_pretrained(
  95. train_config.model_name,
  96. load_in_8bit=True if train_config.quantization else None,
  97. device_map="auto" if train_config.quantization else None,
  98. use_cache=use_cache,
  99. )
  100. else:
  101. llama_config = LlamaConfig.from_pretrained(train_config.model_name)
  102. llama_config.use_cache = use_cache
  103. with torch.device("meta"):
  104. model = LlamaForCausalLM(llama_config)
  105. else:
  106. model = LlamaForCausalLM.from_pretrained(
  107. train_config.model_name,
  108. load_in_8bit=True if train_config.quantization else None,
  109. device_map="auto" if train_config.quantization else None,
  110. use_cache=use_cache,
  111. )
  112. if train_config.enable_fsdp and train_config.use_fast_kernels:
  113. """
  114. For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
  115. using of Flash Attention or Xformer memory-efficient kernels
  116. based on the hardware being used. This would speed up fine-tuning.
  117. """
  118. try:
  119. from optimum.bettertransformer import BetterTransformer
  120. model = BetterTransformer.transform(model)
  121. except ImportError:
  122. print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
  123. # Load the tokenizer and add special tokens
  124. tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
  125. tokenizer.pad_token_id = tokenizer.eos_token_id
  126. print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
  127. # Prepare the model for int8 training if quantization is enabled
  128. if train_config.quantization:
  129. model = prepare_model_for_int8_training(model)
  130. # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
  131. if train_config.enable_fsdp and fsdp_config.pure_bf16:
  132. model.to(torch.bfloat16)
  133. if train_config.use_peft:
  134. peft_config = generate_peft_config(train_config, kwargs)
  135. model = get_peft_model(model, peft_config)
  136. model.print_trainable_parameters()
  137. if train_config.enable_wandb:
  138. if not train_config.enable_fsdp or rank==0:
  139. wandb_run.config.update(peft_config)
  140. #setting up FSDP if enable_fsdp is enabled
  141. if train_config.enable_fsdp:
  142. if not train_config.use_peft and train_config.freeze_layers:
  143. freeze_transformer_layers(train_config.num_freeze_layers)
  144. mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
  145. my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
  146. model = FSDP(
  147. model,
  148. auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
  149. cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
  150. mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
  151. sharding_strategy=fsdp_config.sharding_strategy,
  152. device_id=torch.cuda.current_device(),
  153. limit_all_gathers=True,
  154. sync_module_states=train_config.low_cpu_fsdp,
  155. param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
  156. if train_config.low_cpu_fsdp and rank != 0 else None,
  157. )
  158. if fsdp_config.fsdp_activation_checkpointing:
  159. apply_fsdp_checkpointing(model)
  160. elif not train_config.quantization and not train_config.enable_fsdp:
  161. model.to("cuda")
  162. dataset_config = generate_dataset_config(train_config, kwargs)
  163. # Load and preprocess the dataset for training and validation
  164. dataset_train = get_preprocessed_dataset(
  165. tokenizer,
  166. dataset_config,
  167. split="train",
  168. )
  169. if not train_config.enable_fsdp or rank == 0:
  170. print(f"--> Training Set Length = {len(dataset_train)}")
  171. dataset_val = get_preprocessed_dataset(
  172. tokenizer,
  173. dataset_config,
  174. split="test",
  175. )
  176. if not train_config.enable_fsdp or rank == 0:
  177. print(f"--> Validation Set Length = {len(dataset_val)}")
  178. if train_config.batching_strategy == "packing":
  179. dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
  180. train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
  181. # Create DataLoaders for the training and validation dataset
  182. train_dataloader = torch.utils.data.DataLoader(
  183. dataset_train,
  184. num_workers=train_config.num_workers_dataloader,
  185. pin_memory=True,
  186. **train_dl_kwargs,
  187. )
  188. eval_dataloader = None
  189. if train_config.run_validation:
  190. if train_config.batching_strategy == "packing":
  191. dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
  192. val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
  193. eval_dataloader = torch.utils.data.DataLoader(
  194. dataset_val,
  195. num_workers=train_config.num_workers_dataloader,
  196. pin_memory=True,
  197. **val_dl_kwargs,
  198. )
  199. # Initialize the optimizer and learning rate scheduler
  200. if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
  201. optimizer = AnyPrecisionAdamW(
  202. model.parameters(),
  203. lr=train_config.lr,
  204. momentum_dtype=torch.bfloat16,
  205. variance_dtype=torch.bfloat16,
  206. use_kahan_summation=False,
  207. weight_decay=train_config.weight_decay,
  208. )
  209. else:
  210. optimizer = optim.AdamW(
  211. model.parameters(),
  212. lr=train_config.lr,
  213. weight_decay=train_config.weight_decay,
  214. )
  215. scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
  216. # Start the training process
  217. results = train(
  218. model,
  219. train_dataloader,
  220. eval_dataloader,
  221. tokenizer,
  222. optimizer,
  223. scheduler,
  224. train_config.gradient_accumulation_steps,
  225. train_config,
  226. fsdp_config if train_config.enable_fsdp else None,
  227. local_rank if train_config.enable_fsdp else None,
  228. rank if train_config.enable_fsdp else None,
  229. wandb_run if train_config.enable_wandb else None,
  230. )
  231. if not train_config.enable_fsdp or rank==0:
  232. [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
  233. if train_config.enable_wandb:
  234. for k,v in results.items():
  235. wandb_run.summary[k] = v
  236. if __name__ == "__main__":
  237. fire.Fire(main)