llama_finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
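
"""
Fine-tune a Llama 2 model, optionally with int8 quantization + PEFT and/or FSDP sharding.

Defaults live in configs.train_config and configs.fsdp_config; any field can be overridden
on the command line via fire. Illustrative invocations (flag names taken from the config
objects used below):

    # single GPU, PEFT + int8 quantization
    python llama_finetuning.py --use_peft --quantization --model_name <path/to/model>

    # multi-GPU with FSDP
    torchrun --nnodes 1 --nproc_per_node 4 llama_finetuning.py --enable_fsdp --use_peft --model_name <path/to/model>
"""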

import os
import sys
from typing import List, Union

import fire
import torch
import transformers
from datasets import load_dataset
import os.path as osp
from tqdm import tqdm

# Unused imports removed
from utils import fsdp_auto_wrap_policy
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    LlamaConfig,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    default_data_collator,
    BitsAndBytesConfig,
)
import torch.distributed as dist

# Unused imports removed
from utils.train_utils import (
    set_tokenizer_params,
    train,
    evaluation,
    freeze_transformer_layers,
    check_frozen_layers_peft_model,
    setup,
    setup_environ_flags,
    cleanup,
    clear_gpu_cache,
    get_parameter_dtypes,
    print_model_size,
    get_policies,
)

from accelerate import init_empty_weights
from utils.dataset_utils import get_preprocessed_dataset
from utils.config_utils import (
    update_config,
    generate_peft_config,
    generate_dataset_config,
)
from peft import get_peft_model, TaskType, prepare_model_for_int8_training
import configs
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.utils.data import DistributedSampler
import policies
from policies import AnyPrecisionAdamW
from configs import fsdp_config, train_config
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pkg_resources import packaging
import torch.nn as nn
import torch.cuda.nccl as nccl
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

def main(**kwargs):
    # Update the configuration for the training and sharding process
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    torch.cuda.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        torch.cuda.set_device(rank)
        setup_environ_flags(rank)

    # Calculate gradient accumulation steps
    gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
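    # (Illustrative: with batch_size_training=4 and micro_batch_size=1, gradients would be
    # accumulated over 4 micro-batches before each optimizer step.)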

    # Load the pre-trained model and setup its configuration
    if train_config.enable_fsdp:
        # for FSDP, we save cpu memory by loading pretrained model on rank0 only.
        # this avoids cpu oom when loading large models like llama 70B, in which case
        # model alone would consume 2+TB cpu mem (70 * 4 * 8)
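        # (rough illustration: ~70e9 params * 4 bytes (fp32) * 8 ranks ≈ 2.2 TB of host RAM)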
        if rank == 0:
            model = LlamaForCausalLM.from_pretrained(
                train_config.model_name,
                load_in_8bit=True if train_config.quantization else None,
                device_map="auto" if train_config.quantization else None,
            )
        else:
            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
            with init_empty_weights():
                model = LlamaForCausalLM(llama_config)
    else:
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            load_in_8bit=True if train_config.quantization else None,
            device_map="auto" if train_config.quantization else None,
        )

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Prepare the model for int8 training if quantization is enabled
    if train_config.quantization:
        model = prepare_model_for_int8_training(model)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if train_config.enable_fsdp and fsdp_config.pure_bf16:
        model.to(torch.bfloat16)

    # Load the tokenizer and add special tokens
    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
    tokenizer.add_special_tokens(
        {
            "pad_token": "<PAD>",
        }
    )

    if train_config.use_peft:
        peft_config = generate_peft_config(train_config, kwargs)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
        # given the fast evolving PRs around meta device init, I am not sure
        # what is the best param_init_fn atm, maybe we can switch to simple to_empty.
        def _param_init_fn(module: nn.Module):
            torch.manual_seed(0)
            for submodule in module.modules():
                for param_name, param in submodule.named_parameters(recurse=False):
                    if not _is_fsdp_flattened(param) and param.is_meta:
                        materialized_param = nn.Parameter(
                            torch.empty_like(param, device=torch.device("cuda"))
                        )
                        nn.init.uniform_(materialized_param)
                        setattr(submodule, param_name, materialized_param)
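
        # The uniform init above only serves to move meta-device params onto the GPU on
        # non-zero ranks; with sync_module_states=True below, FSDP broadcasts the real
        # weights from rank 0, so these placeholder values are discarded.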
        model = FSDP(
            model,
            auto_wrap_policy=my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
            sharding_strategy=fsdp_config.sharding_strategy,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=True,
            sync_module_states=True,
            param_init_fn=None if rank == 0 else _param_init_fn,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            policies.apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="train",
    )

    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        tokenizer,
        dataset_config,
        split="test",
    )

    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    train_sampler = None
    val_sampler = None
    if train_config.enable_fsdp:
        train_sampler = DistributedSampler(
            dataset_train,
            rank=dist.get_rank(),
            num_replicas=dist.get_world_size(),
            shuffle=True,
        )
        if train_config.run_validation:
            val_sampler = DistributedSampler(
                dataset_val,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
            )

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=train_config.batch_size_training,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        sampler=train_sampler if train_sampler else None,
        drop_last=True,
        collate_fn=default_data_collator,
    )

    if train_config.run_validation:
        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            batch_size=train_config.val_batch_size,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            sampler=val_sampler if val_sampler else None,
            drop_last=True,
            collate_fn=default_data_collator,
        )

    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=0.0,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
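    # StepLR with step_size=1 decays the learning rate by a factor of train_config.gamma
    # each time scheduler.step() is called in the training loop.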

    # Start the training process
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
    )

    if not train_config.enable_fsdp or rank == 0:
        [print(f'Key: {k}, Value: {v}') for k, v in results.items()]


if __name__ == "__main__":
    fire.Fire(main)