llama_finetuning.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import fire
  5. import torch
  6. import torch.distributed as dist
  7. import torch.distributed as dist
  8. import torch.optim as optim
  9. from peft import get_peft_model, prepare_model_for_int8_training
  10. from pkg_resources import packaging
  11. from torch.distributed.fsdp import (
  12. FullyShardedDataParallel as FSDP,
  13. )
  14. from torch.optim.lr_scheduler import StepLR
  15. from torch.utils.data import DistributedSampler
  16. from transformers import (
  17. LlamaForCausalLM,
  18. LlamaTokenizer,
  19. LlamaConfig,
  20. default_data_collator,
  21. )
  22. from transformers.models.llama.modeling_llama import LlamaDecoderLayer
  23. import policies
  24. from configs import fsdp_config, train_config
  25. from policies import AnyPrecisionAdamW
  26. from utils import fsdp_auto_wrap_policy
  27. from utils.config_utils import (
  28. update_config,
  29. generate_peft_config,
  30. generate_dataset_config,
  31. )
  32. from utils.dataset_utils import get_preprocessed_dataset
  33. from utils.train_utils import (
  34. train,
  35. freeze_transformer_layers,
  36. setup,
  37. setup_environ_flags,
  38. clear_gpu_cache,
  39. print_model_size,
  40. get_policies
  41. )
  42. def main(**kwargs):
  43. # Update the configuration for the training and sharding process
  44. update_config((train_config, fsdp_config), **kwargs)
  45. # Set the seeds for reproducibility
  46. torch.cuda.manual_seed(train_config.seed)
  47. torch.manual_seed(train_config.seed)
  48. if train_config.enable_fsdp:
  49. setup()
  50. # torchrun specific
  51. local_rank = int(os.environ["LOCAL_RANK"])
  52. rank = int(os.environ["RANK"])
  53. world_size = int(os.environ["WORLD_SIZE"])
  54. if torch.distributed.is_initialized():
  55. torch.cuda.set_device(local_rank)
  56. clear_gpu_cache(local_rank)
  57. setup_environ_flags(rank)
  58. # Calculate gradient accumulation steps
  59. gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
  60. # Load the pre-trained model and setup its configuration
  61. if train_config.enable_fsdp and train_config.low_cpu_fsdp:
  62. # for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
  63. # this avoids cpu oom when loading large models like llama 70B, in which case
  64. # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
  65. # overhead and currently requires latest nightly.
  66. v = packaging.version.parse(torch.__version__)
  67. verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
  68. if not verify_latest_nightly:
  69. raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
  70. "please install latest nightly.")
  71. if rank == 0:
  72. model = LlamaForCausalLM.from_pretrained(
  73. train_config.model_name,
  74. load_in_8bit=True if train_config.quantization else None,
  75. device_map="auto" if train_config.quantization else None,
  76. )
  77. else:
  78. llama_config = LlamaConfig.from_pretrained(train_config.model_name)
  79. with torch.device("meta"):
  80. model = LlamaForCausalLM(llama_config)
  81. else:
  82. model = LlamaForCausalLM.from_pretrained(
  83. train_config.model_name,
  84. load_in_8bit=True if train_config.quantization else None,
  85. device_map="auto" if train_config.quantization else None,
  86. )
  87. print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
  88. # Prepare the model for int8 training if quantization is enabled
  89. if train_config.quantization:
  90. model = prepare_model_for_int8_training(model)
  91. # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
  92. if train_config.enable_fsdp and fsdp_config.pure_bf16:
  93. model.to(torch.bfloat16)
  94. # Load the tokenizer and add special tokens
  95. tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
  96. tokenizer.add_special_tokens(
  97. {
  98. "pad_token": "<PAD>",
  99. }
  100. )
  101. if train_config.use_peft:
  102. peft_config = generate_peft_config(train_config, kwargs)
  103. model = get_peft_model(model, peft_config)
  104. model.print_trainable_parameters()
  105. #setting up FSDP if enable_fsdp is enabled
  106. if train_config.enable_fsdp:
  107. if not train_config.use_peft and train_config.freeze_layers:
  108. freeze_transformer_layers(train_config.num_freeze_layers)
  109. mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
  110. my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
  111. model = FSDP(
  112. model,
  113. auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
  114. mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
  115. sharding_strategy=fsdp_config.sharding_strategy,
  116. device_id=torch.cuda.current_device(),
  117. limit_all_gathers=True,
  118. sync_module_states=train_config.low_cpu_fsdp,
  119. param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
  120. if train_config.low_cpu_fsdp and rank != 0 else None,
  121. )
  122. if fsdp_config.fsdp_activation_checkpointing:
  123. policies.apply_fsdp_checkpointing(model)
  124. elif not train_config.quantization and not train_config.enable_fsdp:
  125. model.to("cuda")
  126. dataset_config = generate_dataset_config(train_config, kwargs)
  127. # Load and preprocess the dataset for training and validation
  128. dataset_train = get_preprocessed_dataset(
  129. tokenizer,
  130. dataset_config,
  131. split="train",
  132. )
  133. if not train_config.enable_fsdp or rank == 0:
  134. print(f"--> Training Set Length = {len(dataset_train)}")
  135. dataset_val = get_preprocessed_dataset(
  136. tokenizer,
  137. dataset_config,
  138. split="test",
  139. )
  140. if not train_config.enable_fsdp or rank == 0:
  141. print(f"--> Validation Set Length = {len(dataset_val)}")
  142. train_sampler = None
  143. val_sampler = None
  144. if train_config.enable_fsdp:
  145. train_sampler = DistributedSampler(
  146. dataset_train,
  147. rank=dist.get_rank(),
  148. num_replicas=dist.get_world_size(),
  149. shuffle=True,
  150. )
  151. if train_config.run_validation:
  152. val_sampler = DistributedSampler(
  153. dataset_val,
  154. rank=dist.get_rank(),
  155. num_replicas=dist.get_world_size(),
  156. )
  157. # Create DataLoaders for the training and validation dataset
  158. train_dataloader = torch.utils.data.DataLoader(
  159. dataset_train,
  160. batch_size=train_config.batch_size_training,
  161. num_workers=train_config.num_workers_dataloader,
  162. pin_memory=True,
  163. sampler=train_sampler if train_sampler else None,
  164. drop_last=True,
  165. collate_fn=default_data_collator,
  166. )
  167. if train_config.run_validation:
  168. eval_dataloader = torch.utils.data.DataLoader(
  169. dataset_val,
  170. batch_size=train_config.val_batch_size,
  171. num_workers=train_config.num_workers_dataloader,
  172. pin_memory=True,
  173. sampler=val_sampler if val_sampler else None,
  174. drop_last=True,
  175. collate_fn=default_data_collator,
  176. )
  177. # Initialize the optimizer and learning rate scheduler
  178. if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
  179. optimizer = AnyPrecisionAdamW(
  180. model.parameters(),
  181. lr=train_config.lr,
  182. momentum_dtype=torch.bfloat16,
  183. variance_dtype=torch.bfloat16,
  184. use_kahan_summation=False,
  185. )
  186. else:
  187. optimizer = optim.AdamW(
  188. model.parameters(),
  189. lr=train_config.lr,
  190. weight_decay=0.0,
  191. )
  192. scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
  193. # Start the training process
  194. results = train(
  195. model,
  196. train_dataloader,
  197. eval_dataloader,
  198. tokenizer,
  199. optimizer,
  200. scheduler,
  201. gradient_accumulation_steps,
  202. train_config,
  203. fsdp_config if train_config.enable_fsdp else None,
  204. local_rank if train_config.enable_fsdp else None,
  205. rank if train_config.enable_fsdp else None,
  206. )
  207. if not train_config.enable_fsdp or rank==0:
  208. [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
  209. if __name__ == "__main__":
  210. fire.Fire(main)