finetuning.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import dataclasses
  5. import fire
  6. import random
  7. import torch
  8. import torch.optim as optim
  9. from peft import get_peft_model, prepare_model_for_kbit_training
  10. from torch.distributed.fsdp import (
  11. FullyShardedDataParallel as FSDP,
  12. ShardingStrategy
  13. )
  14. from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
  15. from torch.optim.lr_scheduler import StepLR
  16. from transformers import (
  17. AutoTokenizer,
  18. LlamaForCausalLM,
  19. LlamaConfig,
  20. )
  21. from transformers.models.llama.modeling_llama import LlamaDecoderLayer
  22. from llama_recipes.configs import fsdp_config as FSDP_CONFIG
  23. from llama_recipes.configs import train_config as TRAIN_CONFIG
  24. from llama_recipes.data.concatenator import ConcatDataset
  25. from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
  26. from llama_recipes.utils import fsdp_auto_wrap_policy
  27. from llama_recipes.utils.config_utils import (
  28. update_config,
  29. generate_peft_config,
  30. generate_dataset_config,
  31. get_dataloader_kwargs,
  32. )
  33. from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
  34. from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
  35. from llama_recipes.utils.train_utils import (
  36. train,
  37. freeze_transformer_layers,
  38. setup,
  39. setup_environ_flags,
  40. clear_gpu_cache,
  41. print_model_size,
  42. get_policies,
  43. )
  44. from accelerate.utils import is_xpu_available
  45. def setup_wandb(train_config, fsdp_config, **kwargs):
  46. try:
  47. import wandb
  48. except ImportError:
  49. raise ImportError(
  50. "You are trying to use wandb which is not currently installed. "
  51. "Please install it using pip install wandb"
  52. )
  53. from llama_recipes.configs import wandb_config as WANDB_CONFIG
  54. wandb_config = WANDB_CONFIG()
  55. update_config(wandb_config, **kwargs)
  56. init_dict = dataclasses.asdict(wandb_config)
  57. run = wandb.init(**init_dict)
  58. run.config.update(train_config)
  59. run.config.update(fsdp_config, allow_val_change=True)
  60. return run
  61. def main(**kwargs):
  62. # Update the configuration for the training and sharding process
  63. train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
  64. update_config((train_config, fsdp_config), **kwargs)
  65. # Set the seeds for reproducibility
  66. if is_xpu_available():
  67. torch.xpu.manual_seed(train_config.seed)
  68. torch.manual_seed(train_config.seed)
  69. random.seed(train_config.seed)
  70. if train_config.enable_fsdp:
  71. setup()
  72. # torchrun specific
  73. local_rank = int(os.environ["LOCAL_RANK"])
  74. rank = int(os.environ["RANK"])
  75. world_size = int(os.environ["WORLD_SIZE"])
  76. if torch.distributed.is_initialized():
  77. if is_xpu_available():
  78. torch.xpu.set_device(local_rank)
  79. elif torch.cuda.is_available():
  80. torch.cuda.set_device(local_rank)
  81. clear_gpu_cache(local_rank)
  82. setup_environ_flags(rank)
  83. wandb_run = None
  84. if train_config.use_wandb:
  85. if not train_config.enable_fsdp or rank==0:
  86. wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
  87. # Load the pre-trained model and setup its configuration
  88. use_cache = False if train_config.enable_fsdp else None
  89. if train_config.enable_fsdp and train_config.low_cpu_fsdp:
  90. """
  91. for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
  92. this avoids cpu oom when loading large models like llama 70B, in which case
  93. model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
  94. overhead and currently requires latest nightly.
  95. """
  96. if rank == 0:
  97. model = LlamaForCausalLM.from_pretrained(
  98. train_config.model_name,
  99. load_in_8bit=True if train_config.quantization else None,
  100. device_map="auto" if train_config.quantization else None,
  101. use_cache=use_cache,
  102. attn_implementation="sdpa" if train_config.use_fast_kernels else None,
  103. )
  104. else:
  105. llama_config = LlamaConfig.from_pretrained(train_config.model_name)
  106. llama_config.use_cache = use_cache
  107. with torch.device("meta"):
  108. model = LlamaForCausalLM(llama_config)
  109. else:
  110. model = LlamaForCausalLM.from_pretrained(
  111. train_config.model_name,
  112. load_in_8bit=True if train_config.quantization else None,
  113. device_map="auto" if train_config.quantization else None,
  114. use_cache=use_cache,
  115. attn_implementation="sdpa" if train_config.use_fast_kernels else None,
  116. )
  117. # Load the tokenizer and add special tokens
  118. tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
  119. tokenizer.pad_token_id = tokenizer.eos_token_id
  120. # If there is a mismatch between tokenizer vocab size and embedding matrix,
  121. # throw a warning and then expand the embedding matrix
  122. if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
  123. print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
  124. model.resize_token_embeddings(len(tokenizer))
  125. print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
  126. # Prepare the model for int8 training if quantization is enabled
  127. if train_config.quantization:
  128. model = prepare_model_for_kbit_training(model)
  129. # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
  130. if train_config.enable_fsdp and fsdp_config.pure_bf16:
  131. model.to(torch.bfloat16)
  132. if train_config.use_peft:
  133. peft_config = generate_peft_config(train_config, kwargs)
  134. model = get_peft_model(model, peft_config)
  135. model.print_trainable_parameters()
  136. if wandb_run:
  137. wandb_run.config.update(peft_config)
  138. hsdp_device_mesh = None
  139. if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
  140. hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
  141. print("HSDP device mesh is ready")
  142. #setting up FSDP if enable_fsdp is enabled
  143. if train_config.enable_fsdp:
  144. if not train_config.use_peft and train_config.freeze_layers:
  145. freeze_transformer_layers(train_config.num_freeze_layers)
  146. mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
  147. my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
  148. device_id = 0
  149. if is_xpu_available():
  150. device_id = torch.xpu.current_device()
  151. elif torch.cuda.is_available():
  152. device_id = torch.cuda.current_device()
  153. model = FSDP(
  154. model,
  155. auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
  156. cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
  157. mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
  158. sharding_strategy=fsdp_config.sharding_strategy,
  159. device_mesh=hsdp_device_mesh,
  160. device_id=device_id,
  161. limit_all_gathers=True,
  162. sync_module_states=train_config.low_cpu_fsdp,
  163. param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
  164. if train_config.low_cpu_fsdp and rank != 0 else None,
  165. )
  166. if fsdp_config.fsdp_activation_checkpointing:
  167. apply_fsdp_checkpointing(model)
  168. elif not train_config.quantization and not train_config.enable_fsdp:
  169. if is_xpu_available():
  170. model.to("xpu:0")
  171. elif torch.cuda.is_available():
  172. model.to("cuda")
  173. dataset_config = generate_dataset_config(train_config, kwargs)
  174. # Load and preprocess the dataset for training and validation
  175. dataset_train = get_preprocessed_dataset(
  176. tokenizer,
  177. dataset_config,
  178. split="train",
  179. )
  180. if not train_config.enable_fsdp or rank == 0:
  181. print(f"--> Training Set Length = {len(dataset_train)}")
  182. dataset_val = get_preprocessed_dataset(
  183. tokenizer,
  184. dataset_config,
  185. split="test",
  186. )
  187. if not train_config.enable_fsdp or rank == 0:
  188. print(f"--> Validation Set Length = {len(dataset_val)}")
  189. if train_config.batching_strategy == "packing":
  190. dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
  191. train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
  192. # Create DataLoaders for the training and validation dataset
  193. train_dataloader = torch.utils.data.DataLoader(
  194. dataset_train,
  195. num_workers=train_config.num_workers_dataloader,
  196. pin_memory=True,
  197. **train_dl_kwargs,
  198. )
  199. eval_dataloader = None
  200. if train_config.run_validation:
  201. if train_config.batching_strategy == "packing":
  202. dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
  203. val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
  204. eval_dataloader = torch.utils.data.DataLoader(
  205. dataset_val,
  206. num_workers=train_config.num_workers_dataloader,
  207. pin_memory=True,
  208. **val_dl_kwargs,
  209. )
  210. # Initialize the optimizer and learning rate scheduler
  211. if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
  212. optimizer = AnyPrecisionAdamW(
  213. model.parameters(),
  214. lr=train_config.lr,
  215. momentum_dtype=torch.bfloat16,
  216. variance_dtype=torch.bfloat16,
  217. use_kahan_summation=False,
  218. weight_decay=train_config.weight_decay,
  219. )
  220. else:
  221. optimizer = optim.AdamW(
  222. model.parameters(),
  223. lr=train_config.lr,
  224. weight_decay=train_config.weight_decay,
  225. )
  226. scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
  227. # Start the training process
  228. results = train(
  229. model,
  230. train_dataloader,
  231. eval_dataloader,
  232. tokenizer,
  233. optimizer,
  234. scheduler,
  235. train_config.gradient_accumulation_steps,
  236. train_config,
  237. fsdp_config if train_config.enable_fsdp else None,
  238. local_rank if train_config.enable_fsdp else None,
  239. rank if train_config.enable_fsdp else None,
  240. wandb_run,
  241. )
  242. if not train_config.enable_fsdp or rank==0:
  243. [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
  244. if train_config.use_wandb:
  245. for k,v in results.items():
  246. wandb_run.summary[k] = v
  247. if __name__ == "__main__":
  248. fire.Fire(main)