llama_finetuning.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import sys
  5. from typing import List, Union
  6. import fire
  7. import torch
  8. import transformers
  9. from datasets import load_dataset
  10. import os.path as osp
  11. from tqdm import tqdm
  12. # Unused imports removed
  13. from utils import fsdp_auto_wrap_policy
  14. from transformers import (
  15. LlamaForCausalLM,
  16. LlamaTokenizer,
  17. LlamaConfig,
  18. AutoModelForCausalLM,
  19. AutoModelForSeq2SeqLM,
  20. AutoTokenizer,
  21. default_data_collator,
  22. BitsAndBytesConfig
  23. )
  24. import torch.distributed as dist
  25. # Unused imports removed
  26. from utils.train_utils import (
  27. set_tokenizer_params,
  28. train,
  29. evaluation,
  30. freeze_transformer_layers,
  31. check_frozen_layers_peft_model,
  32. setup,
  33. setup_environ_flags,
  34. cleanup,
  35. clear_gpu_cache,
  36. get_parameter_dtypes,
  37. print_model_size,
  38. get_policies
  39. )
  40. from utils.dataset_utils import get_preprocessed_dataset
  41. from utils.config_utils import (
  42. update_config,
  43. generate_peft_config,
  44. generate_dataset_config,
  45. )
  46. from peft import get_peft_model, TaskType, prepare_model_for_int8_training
  47. import configs
  48. from torch.distributed.fsdp import (
  49. FullyShardedDataParallel as FSDP,
  50. MixedPrecision,
  51. )
  52. from torch.utils.data import DistributedSampler
  53. import policies
  54. from policies import AnyPrecisionAdamW
  55. from configs import fsdp_config, train_config
  56. import torch.optim as optim
  57. from torch.optim.lr_scheduler import StepLR
  58. from pkg_resources import packaging
  59. import torch
  60. import torch.nn as nn
  61. import torch.cuda.nccl as nccl
  62. import torch.distributed as dist
  63. from transformers.models.llama.modeling_llama import LlamaDecoderLayer
  64. def main(**kwargs):
  65. # Update the configuration for the training and sharding process
  66. update_config((train_config, fsdp_config), **kwargs)
  67. # Set the seeds for reproducibility
  68. torch.cuda.manual_seed(train_config.seed)
  69. torch.manual_seed(train_config.seed)
  70. if train_config.enable_fsdp:
  71. setup()
  72. # torchrun specific
  73. local_rank = int(os.environ["LOCAL_RANK"])
  74. rank = int(os.environ["RANK"])
  75. world_size = int(os.environ["WORLD_SIZE"])
  76. if torch.distributed.is_initialized():
  77. torch.cuda.set_device(local_rank)
  78. clear_gpu_cache(local_rank)
  79. setup_environ_flags(rank)
  80. # Calculate gradient accumulation steps
  81. gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
  82. # Load the pre-trained model and setup its configuration
  83. if train_config.enable_fsdp and train_config.low_cpu_fsdp:
  84. # for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
  85. # this avoids cpu oom when loading large models like llama 70B, in which case
  86. # model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
  87. # overhead and currently requires latest nightly.
  88. v = packaging.version.parse(torch.__version__)
  89. verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
  90. if not verify_latest_nightly:
  91. raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
  92. "please install latest nightly.")
  93. if rank == 0:
  94. model = LlamaForCausalLM.from_pretrained(
  95. train_config.model_name,
  96. load_in_8bit=True if train_config.quantization else None,
  97. device_map="auto" if train_config.quantization else None,
  98. )
  99. else:
  100. llama_config = LlamaConfig.from_pretrained(train_config.model_name)
  101. with torch.device("meta"):
  102. model = LlamaForCausalLM(llama_config)
  103. else:
  104. model = LlamaForCausalLM.from_pretrained(
  105. train_config.model_name,
  106. load_in_8bit=True if train_config.quantization else None,
  107. device_map="auto" if train_config.quantization else None,
  108. )
  109. print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
  110. # Prepare the model for int8 training if quantization is enabled
  111. if train_config.quantization:
  112. model = prepare_model_for_int8_training(model)
  113. # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
  114. if train_config.enable_fsdp and fsdp_config.pure_bf16:
  115. model.to(torch.bfloat16)
  116. # Load the tokenizer and add special tokens
  117. tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
  118. tokenizer.add_special_tokens(
  119. {
  120. "pad_token": "<PAD>",
  121. }
  122. )
  123. if train_config.use_peft:
  124. peft_config = generate_peft_config(train_config, kwargs)
  125. model = get_peft_model(model, peft_config)
  126. model.print_trainable_parameters()
  127. #setting up FSDP if enable_fsdp is enabled
  128. if train_config.enable_fsdp:
  129. if not train_config.use_peft and train_config.freeze_layers:
  130. freeze_transformer_layers(train_config.num_freeze_layers)
  131. mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
  132. my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
  133. model = FSDP(
  134. model,
  135. auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
  136. mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
  137. sharding_strategy=fsdp_config.sharding_strategy,
  138. device_id=torch.cuda.current_device(),
  139. limit_all_gathers=True,
  140. sync_module_states=True if train_config.low_cpu_fsdp else False,
  141. param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
  142. if train_config.low_cpu_fsdp and rank != 0 else None,
  143. )
  144. if fsdp_config.fsdp_activation_checkpointing:
  145. policies.apply_fsdp_checkpointing(model)
  146. elif not train_config.quantization and not train_config.enable_fsdp:
  147. model.to("cuda")
  148. dataset_config = generate_dataset_config(train_config, kwargs)
  149. # Load and preprocess the dataset for training and validation
  150. dataset_train = get_preprocessed_dataset(
  151. tokenizer,
  152. dataset_config,
  153. split="train",
  154. )
  155. if not train_config.enable_fsdp or rank == 0:
  156. print(f"--> Training Set Length = {len(dataset_train)}")
  157. dataset_val = get_preprocessed_dataset(
  158. tokenizer,
  159. dataset_config,
  160. split="test",
  161. )
  162. if not train_config.enable_fsdp or rank == 0:
  163. print(f"--> Validation Set Length = {len(dataset_val)}")
  164. train_sampler = None
  165. val_sampler = None
  166. if train_config.enable_fsdp:
  167. train_sampler = DistributedSampler(
  168. dataset_train,
  169. rank=dist.get_rank(),
  170. num_replicas=dist.get_world_size(),
  171. shuffle=True,
  172. )
  173. if train_config.run_validation:
  174. val_sampler = DistributedSampler(
  175. dataset_val,
  176. rank=dist.get_rank(),
  177. num_replicas=dist.get_world_size(),
  178. )
  179. # Create DataLoaders for the training and validation dataset
  180. train_dataloader = torch.utils.data.DataLoader(
  181. dataset_train,
  182. batch_size=train_config.batch_size_training,
  183. num_workers=train_config.num_workers_dataloader,
  184. pin_memory=True,
  185. sampler=train_sampler if train_sampler else None,
  186. drop_last=True,
  187. collate_fn=default_data_collator,
  188. )
  189. if train_config.run_validation:
  190. eval_dataloader = torch.utils.data.DataLoader(
  191. dataset_val,
  192. batch_size=train_config.val_batch_size,
  193. num_workers=train_config.num_workers_dataloader,
  194. pin_memory=True,
  195. sampler=val_sampler if val_sampler else None,
  196. drop_last=True,
  197. collate_fn=default_data_collator,
  198. )
  199. # Initialize the optimizer and learning rate scheduler
  200. if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
  201. optimizer = AnyPrecisionAdamW(
  202. model.parameters(),
  203. lr=train_config.lr,
  204. momentum_dtype=torch.bfloat16,
  205. variance_dtype=torch.bfloat16,
  206. use_kahan_summation=False,
  207. )
  208. else:
  209. optimizer = optim.AdamW(
  210. model.parameters(),
  211. lr=train_config.lr,
  212. weight_decay=0.0,
  213. )
  214. scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
  215. # Start the training process
  216. results = train(
  217. model,
  218. train_dataloader,
  219. eval_dataloader,
  220. tokenizer,
  221. optimizer,
  222. scheduler,
  223. gradient_accumulation_steps,
  224. train_config,
  225. fsdp_config if train_config.enable_fsdp else None,
  226. local_rank if train_config.enable_fsdp else None,
  227. rank if train_config.enable_fsdp else None,
  228. )
  229. if not train_config.enable_fsdp or rank==0:
  230. [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
  231. if __name__ == "__main__":
  232. fire.Fire(main)