train_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import sys
from typing import List

import fire
import torch
import transformers
from datasets import load_dataset
from tqdm import tqdm
"""
Unused imports:
import torch.nn as nn
import bitsandbytes as bnb
"""
from torch.nn import functional as F
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer
from torch.distributed.fsdp import StateDictType
import torch.distributed as dist
from pkg_resources import packaging
from .memory_utils import MemoryTrace
import model_checkpointing
import torch.cuda.nccl as nccl
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from pathlib import Path

# Add the repository root to sys.path so that `policies` and `model_checkpointing` resolve.
sys.path.append(str(Path(__file__).resolve().parent.parent))
from policies import bfSixteen, fpSixteen, bfSixteen_mixed, get_llama_wrapper

def set_tokenizer_params(tokenizer: LlamaTokenizer):
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"


# Converting Bytes to Megabytes
def byte2mb(x):
    return int(x / 2**20)

def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
    """
    Trains the model on the given dataloader

    Args:
        model: The model to be trained
        train_dataloader: The dataloader containing the training data
        eval_dataloader: The dataloader containing the eval data
        tokenizer: The tokenizer used in the eval loop for decoding the predictions
        optimizer: The optimizer used for training
        lr_scheduler: The learning rate scheduler
        gradient_accumulation_steps: The number of steps to accumulate gradients over before performing a backward/update operation
        train_config: The training configuration (number of epochs, precision, checkpointing options, etc.)
        fsdp_config: The FSDP configuration, used when train_config.enable_fsdp is set
        local_rank: The local rank of the current process in a distributed setting
        rank: The global rank of the current process in a distributed setting

    Returns: results dictionary containing average training and validation perplexity and loss
    """
    # Create a gradient scaler for fp16
    if train_config.use_fp16 and train_config.enable_fsdp:
        scaler = ShardedGradScaler()
    elif train_config.use_fp16 and not train_config.enable_fsdp:
        scaler = torch.cuda.amp.GradScaler()
    if train_config.enable_fsdp:
        world_size = int(os.environ["WORLD_SIZE"])

    train_prep = []
    train_loss = []
    val_prep = []
    val_loss = []
    results = {}
    best_val_loss = float("inf")

    for epoch in range(train_config.num_epochs):
        with MemoryTrace() as memtrace:  # track the memory usage
            model.train()
            total_loss = 0.0
            for step, batch in enumerate(tqdm(train_dataloader, colour="blue", desc=f"Training Epoch {epoch}")):
                for key in batch.keys():
                    if train_config.enable_fsdp:
                        batch[key] = batch[key].to(local_rank)
                    else:
                        batch[key] = batch[key].to('cuda:0')
                loss = model(**batch).loss
                # Scale the loss so that gradients accumulated over
                # `gradient_accumulation_steps` mini-batches average out to one update
                loss = loss / gradient_accumulation_steps
                total_loss += loss.detach().float()
                if train_config.use_fp16:
                    # if fp16 is enabled, use gradient scaler to handle gradient update
                    scaler.scale(loss).backward()
                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                else:
                    # regular backpropagation when fp16 is not used
                    loss.backward()
                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                        optimizer.step()
                        optimizer.zero_grad()
                if train_config.enable_fsdp:
                    if rank == 0:
                        print(f"\n step {step} is completed and loss is {loss.detach().float()}")
                else:
                    print(f"\n step {step} is completed and loss is {loss.detach().float()}")

        # Reducing total_loss across all devices if there's more than one CUDA device
        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
        train_epoch_loss = total_loss / len(train_dataloader)
        if train_config.enable_fsdp:
            # The all-reduce above summed the per-rank losses, so average over the world size
            train_epoch_loss = train_epoch_loss / world_size
        train_perplexity = torch.exp(train_epoch_loss)

        train_prep.append(train_perplexity)
        train_loss.append(train_epoch_loss)

        if train_config.enable_fsdp:
            if rank == 0:
                print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                print(f"CUDA malloc retries : {memtrace.cuda_malloc_retires}")
                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
        else:
            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
            print(f"CUDA malloc retries : {memtrace.cuda_malloc_retires}")
            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")

        # Update the learning rate as needed
        lr_scheduler.step()

        if train_config.run_validation:
            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
            if train_config.save_model and eval_epoch_loss < best_val_loss:
                if train_config.enable_fsdp:
                    dist.barrier()
                if train_config.use_peft:
                    if train_config.enable_fsdp:
                        if rank == 0:
                            print(f"we are about to save the PEFT modules")
                    else:
                        print(f"we are about to save the PEFT modules")
                    model.save_pretrained(train_config.output_dir)
                    if train_config.enable_fsdp:
                        if rank == 0:
                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
                    else:
                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
                else:
                    if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                        model_checkpointing.save_model_checkpoint(
                            model, optimizer, rank, train_config, epoch=epoch
                        )
                    elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                        print("=====================================================")
                        model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
                        if train_config.save_optimizer:
                            model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                            print("=====================================================")
                    if not train_config.use_peft and train_config.save_optimizer:
                        model_checkpointing.save_optimizer_checkpoint(
                            model, optimizer, rank, train_config, epoch=epoch
                        )
                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
                        print("=====================================================")
                if train_config.enable_fsdp:
                    dist.barrier()
            if eval_epoch_loss < best_val_loss:
                best_val_loss = eval_epoch_loss
                if train_config.enable_fsdp:
                    if rank == 0:
                        print(f"best eval loss on epoch {epoch} is {best_val_loss}")
                else:
                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
            val_loss.append(best_val_loss)
            val_prep.append(eval_ppl)

        if train_config.enable_fsdp:
            if rank == 0:
                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
        else:
            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
        lr_scheduler.step()

    avg_train_prep = sum(train_prep) / len(train_prep)
    avg_train_loss = sum(train_loss) / len(train_loss)
    if train_config.run_validation:
        avg_eval_prep = sum(val_prep) / len(val_prep)
        avg_eval_loss = sum(val_loss) / len(val_loss)

    results['avg_train_prep'] = avg_train_prep
    results['avg_train_loss'] = avg_train_loss
    if train_config.run_validation:
        results['avg_eval_prep'] = avg_eval_prep
        results['avg_eval_loss'] = avg_eval_loss

    return results

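# Usage sketch (illustrative, not part of the original module): the fine-tuning entry
# point is expected to build the model, dataloaders, optimizer, scheduler and config
# objects itself and then call train() roughly as follows; all names below are
# placeholders supplied by the caller:
#
#   results = train(
#       model, train_dataloader, eval_dataloader, tokenizer,
#       optimizer, lr_scheduler, train_config.gradient_accumulation_steps,
#       train_config, fsdp_config, local_rank, rank,
#   )
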
def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer):
    """
    Evaluates the model on the given dataloader

    Args:
        model: The model to evaluate
        train_config: The training configuration
        eval_dataloader: The dataloader containing the evaluation data
        local_rank: The rank of the current node in a distributed setting
        tokenizer: The tokenizer used to decode predictions

    Returns: eval_ppl, eval_epoch_loss
    """
    if train_config.enable_fsdp:
        world_size = int(os.environ["WORLD_SIZE"])
    model.eval()
    eval_preds = []
    eval_loss = 0.0  # Initialize evaluation loss
    with MemoryTrace() as memtrace:
        for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch")):
            for key in batch.keys():
                if train_config.enable_fsdp:
                    batch[key] = batch[key].to(local_rank)
                else:
                    batch[key] = batch[key].to('cuda:0')
            # Ensure no gradients are computed for this scope to save memory
            with torch.no_grad():
                # Forward pass and compute loss
                outputs = model(**batch)
                loss = outputs.loss
                eval_loss += loss.detach().float()
            # Decode predictions and add to evaluation predictions list
            preds = torch.argmax(outputs.logits, -1)
            eval_preds.extend(
                tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
            )

    # If there's more than one CUDA device, reduce evaluation loss across all devices
    if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)

    # Compute average loss and perplexity
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    if train_config.enable_fsdp:
        eval_epoch_loss = eval_epoch_loss / world_size
    eval_ppl = torch.exp(eval_epoch_loss)

    # Print evaluation metrics
    if train_config.enable_fsdp:
        if local_rank == 0:
            print(f" {eval_ppl=} {eval_epoch_loss=}")
    else:
        print(f" {eval_ppl=} {eval_epoch_loss=}")

    return eval_ppl, eval_epoch_loss

def freeze_transformer_layers(model, num_layer):
    """Freeze the first `num_layer` decoder layers of the model"""
    for i, layer in enumerate(model.model.layers):
        if i < num_layer:
            for param in layer.parameters():
                param.requires_grad = False


def check_frozen_layers_peft_model(model):
    """Print the requires_grad status of every parameter in each decoder layer of a PEFT-wrapped model"""
    for i, layer in enumerate(model.base_model.model.model.layers):
        for name, param in layer.named_parameters():
            print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")

def setup():
    """Initialize the process group for distributed training"""
    dist.init_process_group("nccl")


def setup_environ_flags(rank):
    """Set environment flags for debugging purposes"""
    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    # This flag helps with CUDA memory fragmentation that can lead to OOM in some cases.
    # Note this is only available in PyTorch nightlies (as of July 30, 2023).
    # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    if rank == 0:
        print(f"--> Running with torch dist debug set to detail")


def cleanup():
    """Clean up the process group after training"""
    dist.destroy_process_group()

def clear_gpu_cache(rank=None):
    """Clear the GPU cache for all ranks"""
    if rank == 0:
        print(f"Clearing GPU cache for all ranks")
    torch.cuda.empty_cache()


def get_parameter_dtypes(model):
    """Get the data types of model parameters"""
    parameter_dtypes = {}
    for name, parameter in model.named_parameters():
        parameter_dtypes[name] = parameter.dtype
    return parameter_dtypes

def print_model_size(model, config, rank: int = 0) -> None:
    """
    Print the model name and the number of trainable parameters.

    Args:
        model: The PyTorch model.
        config: The training configuration; config.model_name is used for display.
        rank (int, optional): Current process's rank. Defaults to 0.
    """
    if rank == 0:
        print(f"--> Model {config.model_name}")
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")

def get_policies(cfg, rank):
    """Get the policies for mixed precision and fsdp wrapping"""
    verify_bfloat_support = (
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )

    mixed_precision_policy = None
    wrapping_policy = None

    # Mixed precision
    if cfg.mixed_precision:
        bf16_ready = verify_bfloat_support
        if bf16_ready and not cfg.use_fp16:
            mixed_precision_policy = bfSixteen_mixed
            if rank == 0:
                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
        elif cfg.use_fp16:
            mixed_precision_policy = fpSixteen
            if rank == 0:
                print(f"FP16 enabled")
        else:
            print(f"bFloat16 support not present. Using FP32, and not mixed precision")

    wrapping_policy = get_llama_wrapper()
    return mixed_precision_policy, wrapping_policy

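# Usage sketch (illustrative, not part of the original module): in a distributed run
# launched with torchrun, these helpers are typically combined along these lines
# (all variable names are placeholders):
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   setup()                                         # init the NCCL process group
#   local_rank = int(os.environ["LOCAL_RANK"])
#   rank = int(os.environ["RANK"])
#   setup_environ_flags(rank)
#   mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
#   model = FSDP(
#       model,
#       auto_wrap_policy=wrapping_policy,
#       mixed_precision=mixed_precision_policy,
#       device_id=torch.cuda.current_device(),
#   )
#   clear_gpu_cache(local_rank)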