train_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import sys
  5. from typing import List
  6. import fire
  7. import torch
  8. import transformers
  9. from datasets import load_dataset
  10. from tqdm import tqdm
  11. """
  12. Unused imports:
  13. import torch.nn as nn
  14. import bitsandbytes as bnb
  15. """
  16. from torch.nn import functional as F
  17. from peft import (
  18. LoraConfig,
  19. get_peft_model,
  20. get_peft_model_state_dict,
  21. prepare_model_for_int8_training,
  22. set_peft_model_state_dict,
  23. )
  24. from transformers import LlamaForCausalLM, LlamaTokenizer
  25. from torch.distributed.fsdp import StateDictType
  26. import torch.distributed as dist
  27. from pkg_resources import packaging
  28. from .memory_utils import MemoryTrace
  29. import model_checkpointing
  30. import torch.cuda.nccl as nccl
  31. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  32. from pathlib import Path
  33. sys.path.append(str(Path(__file__).resolve().parent.parent))
  34. from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
  35. scaler = ShardedGradScaler()
  36. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  37. tokenizer.pad_token_id = 0
  38. tokenizer.padding_side = "left"
  39. # Converting Bytes to Megabytes
  40. def byte2mb(x):
  41. return int(x / 2**20)
  42. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  43. """
  44. Trains the model on the given dataloader
  45. Args:
  46. model: The model to be trained
  47. train_dataloader: The dataloader containing the training data
  48. optimizer: The optimizer used for training
  49. lr_scheduler: The learning rate scheduler
  50. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  51. num_epochs: The number of epochs to train for
  52. local_rank: The rank of the current node in a distributed setting
  53. train_config: The training configuration
  54. eval_dataloader: The dataloader containing the eval data
  55. tokenizer: tokenizer used in the eval for decoding the predicitons
  56. Returns: results dictionary containing average training and validation perplexity and loss
  57. """
  58. # Create a gradient scaler for fp16
  59. scaler = torch.cuda.amp.GradScaler() if train_config.use_fp16 else None
  60. train_prep = []
  61. train_loss = []
  62. val_prep = []
  63. val_loss =[]
  64. results = {}
  65. best_val_loss = float("inf")
  66. for epoch in range(train_config.num_epochs):
  67. with MemoryTrace() as memtrace: # track the memory usage
  68. model.train()
  69. total_loss = 0.0
  70. data_set_len = 0
  71. for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
  72. for key in batch.keys():
  73. if train_config.enable_fsdp:
  74. batch[key] = batch[key].to(local_rank)
  75. elif not train_config.quantization:
  76. batch[key] = batch[key].to('cuda')
  77. outputs = model(**batch)
  78. loss = outputs.loss
  79. loss = loss / gradient_accumulation_steps
  80. total_loss += loss.detach().float()
  81. first_key = next(iter(batch))
  82. data_set_len += len(batch[first_key])
  83. if train_config.use_fp16:
  84. # if fp16 is enabled, use gradient scaler to handle gradient update
  85. scaler.scale(loss).backward()
  86. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  87. scaler.step(optimizer)
  88. scaler.update()
  89. optimizer.zero_grad()
  90. else:
  91. # regular backpropagation when fp16 is not used
  92. loss.backward()
  93. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  94. optimizer.step()
  95. lr_scheduler.step()
  96. optimizer.zero_grad()
  97. print(f"\n step {step} is completed and loss is {loss.detach().float()}")
  98. # Reducing total_loss across all devices if there's more than one CUDA device
  99. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  100. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  101. train_epoch_loss = total_loss / data_set_len
  102. train_perplexity = torch.exp(train_epoch_loss)
  103. train_prep.append(train_perplexity)
  104. train_loss.append(train_epoch_loss)
  105. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  106. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  107. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  108. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  109. if train_config.run_validation:
  110. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
  111. if train_config.save_model and eval_epoch_loss < best_val_loss:
  112. if train_config.use_peft:
  113. print(f"we are in the saving the PEFT modules")
  114. model.save_pretrained(train_config.output_dir)
  115. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  116. else:
  117. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  118. model_checkpointing.save_model_checkpoint(
  119. model, optimizer, rank, train_config, epoch=1
  120. )
  121. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  122. print(" we are about to save the models *******")
  123. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
  124. if train_config.save_optimizer:
  125. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  126. if not train_config.use_peft and train_config.save_optimizer:
  127. model_checkpointing.save_optimizer_checkpoint(
  128. model, optimizer, rank, train_config, epoch=1
  129. )
  130. if local_rank == 0 and eval_epoch_loss < best_val_loss:
  131. best_val_loss = eval_epoch_loss
  132. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  133. val_loss.append(best_val_loss)
  134. val_prep.append(eval_ppl)
  135. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
  136. avg_train_prep = sum(train_prep)/len(train_prep)
  137. avg_train_loss = sum(train_loss)/len(train_loss)
  138. if train_config.run_validation:
  139. avg_eval_prep = sum(val_prep)/len(val_prep)
  140. avg_eval_loss = sum(val_loss)/len(val_loss)
  141. results['avg_train_prep'] = avg_train_prep
  142. results['avg_train_loss'] = avg_train_loss
  143. if train_config.run_validation:
  144. results['avg_eval_prep'] = avg_eval_prep
  145. results['avg_eval_loss'] = avg_eval_loss
  146. return results
  147. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  148. """
  149. Evaluates the model on the given dataloader
  150. Args:
  151. model: The model to evaluate
  152. eval_dataloader: The dataloader containing the evaluation data
  153. local_rank: The rank of the current node in a distributed setting
  154. tokenizer: The tokenizer used to decode predictions
  155. Returns: eval_ppl, eval_epoch_loss
  156. """
  157. model.eval()
  158. eval_preds = []
  159. eval_loss = 0.0 # Initialize evaluation loss
  160. eval_dataset_len = 0
  161. with MemoryTrace() as memtrace:
  162. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
  163. for key in batch.keys():
  164. if train_config.enable_fsdp:
  165. batch[key] = batch[key].to(local_rank)
  166. else:
  167. batch[key] = batch[key].to('cuda')
  168. # Ensure no gradients are computed for this scope to save memory
  169. with torch.no_grad():
  170. # Forward pass and compute loss
  171. outputs = model(**batch)
  172. loss = outputs.loss
  173. eval_loss += loss.detach().float()
  174. first_key = next(iter(batch))
  175. eval_dataset_len+= len(batch[first_key])
  176. # Decode predictions and add to evaluation predictions list
  177. preds = torch.argmax(outputs.logits, -1)
  178. eval_preds.extend(
  179. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  180. )
  181. # If there's more than one CUDA device, reduce evaluation loss across all devices
  182. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  183. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  184. # Compute average loss and perplexity
  185. eval_epoch_loss = eval_loss / eval_dataset_len
  186. eval_ppl = torch.exp(eval_epoch_loss)
  187. # Print evaluation metrics
  188. print(f" {eval_ppl=} {eval_epoch_loss=}")
  189. return eval_ppl, eval_epoch_loss
  190. def freeze_transformer_layers(model, num_layer):
  191. for i, layer in enumerate(model.model.layers):
  192. if i < num_layer:
  193. for param in layer.parameters():
  194. param.requires_grad = False
  195. def check_frozen_layers_peft_model(model):
  196. for i, layer in enumerate(model.base_model.model.model.layers):
  197. for name, param in layer.named_parameters():
  198. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  199. def setup():
  200. """Initialize the process group for distributed training"""
  201. dist.init_process_group("nccl")
  202. def setup_environ_flags(rank):
  203. """Set environment flags for debugging purposes"""
  204. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  205. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  206. os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  207. if rank == 0:
  208. print(f"--> Running with torch dist debug set to detail")
  209. def cleanup():
  210. """Clean up the process group after training"""
  211. dist.destroy_process_group()
  212. def clear_gpu_cache(rank=None):
  213. """Clear the GPU cache for all ranks"""
  214. if rank == 0:
  215. print(f"Clearing GPU cache for all ranks")
  216. torch.cuda.empty_cache()
  217. def get_parameter_dtypes(model):
  218. """Get the data types of model parameters"""
  219. parameter_dtypes = {}
  220. for name, parameter in model.named_parameters():
  221. parameter_dtypes[name] = parameter.dtype
  222. return parameter_dtypes
  223. def print_model_size(model, config, rank: int = 0) -> None:
  224. """
  225. Print model name, the number of trainable parameters and initialization time.
  226. Args:
  227. model: The PyTorch model.
  228. model_name (str): Name of the model.
  229. init_time_start (float): Initialization start time.
  230. init_time_end (float): Initialization end time.
  231. rank (int, optional): Current process's rank. Defaults to 0.
  232. """
  233. if rank == 0:
  234. print(f"--> Model {config.model_name}")
  235. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  236. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  237. def get_policies(cfg, rank):
  238. """Get the policies for mixed precision and fsdp wrapping"""
  239. verify_bfloat_support = (
  240. torch.version.cuda
  241. and torch.cuda.is_bf16_supported()
  242. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  243. and dist.is_nccl_available()
  244. and nccl.version() >= (2, 10)
  245. )
  246. mixed_precision_policy = None
  247. wrapping_policy = None
  248. # Mixed precision
  249. if cfg.mixed_precision:
  250. bf16_ready = verify_bfloat_support
  251. if bf16_ready and not cfg.use_fp16:
  252. mixed_precision_policy = bfSixteen_mixed
  253. if rank == 0:
  254. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  255. elif cfg.use_fp16:
  256. mixed_precision_policy = fpSixteen
  257. if rank == 0:
  258. print(f"FP16 enabled")
  259. else:
  260. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  261. wrapping_policy = get_llama_wrapper()
  262. return mixed_precision_policy, wrapping_policy