train_utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import sys
  5. from typing import List
  6. import fire
  7. import torch
  8. import transformers
  9. from datasets import load_dataset
  10. from tqdm import tqdm
  11. """
  12. Unused imports:
  13. import torch.nn as nn
  14. import bitsandbytes as bnb
  15. """
  16. from torch.nn import functional as F
  17. from peft import (
  18. LoraConfig,
  19. get_peft_model,
  20. get_peft_model_state_dict,
  21. prepare_model_for_int8_training,
  22. set_peft_model_state_dict,
  23. )
  24. from transformers import LlamaForCausalLM, LlamaTokenizer
  25. from torch.distributed.fsdp import StateDictType
  26. import torch.distributed as dist
  27. from pkg_resources import packaging
  28. from .memory_utils import MemoryTrace
  29. import model_checkpointing
  30. import torch.cuda.nccl as nccl
  31. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  32. from pathlib import Path
  33. sys.path.append(str(Path(__file__).resolve().parent.parent))
  34. from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
  def set_tokenizer_params(tokenizer: LlamaTokenizer):
      """Configure ``tokenizer`` for batched causal-LM training.

      Sets the padding side to "left" and the pad token id to 0. The tokenizer
      is mutated in place; nothing is returned.
      """
      tokenizer.padding_side = "left"
      tokenizer.pad_token_id = 0
  # Byte count -> mebibyte (2**20 bytes) conversion helper.
  def byte2mb(x):
      """Return ``x`` bytes expressed in whole MiB, truncated toward zero."""
      bytes_per_mib = 1 << 20
      return int(x / bytes_per_mib)
  def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler,
            gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
      """
      Trains the model on the given dataloader.

      Args:
          model: The model to be trained.
          train_dataloader: The dataloader containing the training data.
          eval_dataloader: The dataloader containing the eval data.
          tokenizer: Tokenizer forwarded to ``evaluation``.
          optimizer: The optimizer used for training.
          lr_scheduler: The learning rate scheduler; stepped exactly once per epoch.
          gradient_accumulation_steps: Number of micro-batches whose gradients are
              accumulated before each optimizer step.
          train_config: The training configuration. Reads num_epochs, use_fp16,
              enable_fsdp, run_validation, save_model, use_peft, save_optimizer
              and output_dir.
          fsdp_config: FSDP configuration. Only ``checkpoint_type`` is read, when
              saving a non-PEFT model.
          local_rank: Device index of this process on its node (used with FSDP).
          rank: Global rank of this process (used with FSDP).

      Returns: results dictionary with ``avg_train_prep`` and ``avg_train_loss``.
          When validation runs, it also holds ``avg_eval_prep`` and ``avg_eval_loss``.
          Each is the mean over epochs of the per-epoch, per-sample loss/perplexity.
      """
      # Gradient scaler for fp16; the sharded variant is required under FSDP.
      scaler = None
      if train_config.use_fp16:
          scaler = ShardedGradScaler() if train_config.enable_fsdp else torch.cuda.amp.GradScaler()
      device = local_rank if train_config.enable_fsdp else "cuda"
      reduce_across_ranks = torch.cuda.device_count() > 1 and train_config.enable_fsdp
      # Non-FSDP runs are single-process; under FSDP only global rank 0 logs the best loss.
      is_main_process = not train_config.enable_fsdp or rank == 0

      train_prep = []
      train_loss = []
      val_prep = []
      val_loss = []
      results = {}
      best_val_loss = float("inf")
      num_steps = len(train_dataloader)
      for epoch in range(train_config.num_epochs):
          with MemoryTrace() as memtrace:  # track the memory usage
              model.train()
              total_loss = 0.0  # sample-weighted sum of per-batch (unscaled) losses
              data_set_len = 0  # number of samples seen this epoch on this rank
              for step, batch in enumerate(tqdm(train_dataloader, colour="blue", desc=f"Training Epoch{epoch}")):
                  for key in batch.keys():
                      batch[key] = batch[key].to(device)
                  outputs = model(**batch)
                  batch_size = len(next(iter(batch.values())))
                  # outputs.loss is a mean over the batch: weight it by the batch size so that
                  # dividing by the sample count below yields a true per-sample mean.
                  total_loss += outputs.loss.detach().float() * batch_size
                  data_set_len += batch_size
                  # Scale only the loss used for backward, so accumulated gradients average out.
                  loss = outputs.loss / gradient_accumulation_steps
                  is_update_step = (step + 1) % gradient_accumulation_steps == 0 or step == num_steps - 1
                  if scaler is not None:
                      # if fp16 is enabled, use gradient scaler to handle gradient update
                      scaler.scale(loss).backward()
                      if is_update_step:
                          scaler.step(optimizer)
                          scaler.update()
                          optimizer.zero_grad()
                  else:
                      # regular backpropagation when fp16 is not used
                      loss.backward()
                      if is_update_step:
                          optimizer.step()
                          optimizer.zero_grad()
                  print(f"\n step {step} is completed and loss is {loss.detach().float()}")

          # Reduce the loss sum and the sample count together, so that the epoch loss is
          # a mean over every rank's samples. An empty dataloader yields NaN (0 / 0).
          epoch_stats = torch.tensor([float(total_loss), float(data_set_len)], device=device)
          if reduce_across_ranks:
              dist.all_reduce(epoch_stats, op=dist.ReduceOp.SUM)
          train_epoch_loss = epoch_stats[0] / epoch_stats[1]
          train_perplexity = torch.exp(train_epoch_loss)
          train_prep.append(train_perplexity)
          train_loss.append(train_epoch_loss)

          print(f"Max CUDA memory allocated was {memtrace.peak} GB")
          print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
          print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
          print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")

          # Update the learning rate as needed (exactly once per epoch).
          lr_scheduler.step()

          if train_config.run_validation:
              # Pass the node-local device index, not the global rank: evaluation uses it
              # as the CUDA device to move batches to.
              eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
              if train_config.save_model and eval_epoch_loss < best_val_loss:
                  if train_config.use_peft:
                      print(f"we are in the saving the PEFT modules")
                      model.save_pretrained(train_config.output_dir)
                      print(f"PEFT modules are saved in {train_config.output_dir} directory")
                  else:
                      # NOTE(review): fsdp_config must be provided here; this path reads
                      # fsdp_config.checkpoint_type unconditionally.
                      if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                          model_checkpointing.save_model_checkpoint(
                              model, optimizer, rank, train_config, epoch=1
                          )
                      elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                          print(" we are about to save the models *******")
                          model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
                          if train_config.save_optimizer:
                              model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
                      if train_config.save_optimizer:
                          model_checkpointing.save_optimizer_checkpoint(
                              model, optimizer, rank, train_config, epoch=1
                          )
              # Update on every rank: eval_epoch_loss is reduced across ranks, so every rank
              # takes the same save decision next epoch. Gating this update on a rank makes
              # the ranks disagree, and means it never updates in non-FSDP runs.
              if eval_epoch_loss < best_val_loss:
                  best_val_loss = eval_epoch_loss
                  if is_main_process:
                      print(f"best eval loss on epoch {epoch} is {best_val_loss}")
              # Record this epoch's loss (not the running best) so it pairs with eval_ppl.
              val_loss.append(eval_epoch_loss)
              val_prep.append(eval_ppl)

          print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")

      avg_train_prep = sum(train_prep) / len(train_prep)
      avg_train_loss = sum(train_loss) / len(train_loss)
      results['avg_train_prep'] = avg_train_prep
      results['avg_train_loss'] = avg_train_loss
      if train_config.run_validation:
          results['avg_eval_prep'] = sum(val_prep) / len(val_prep)
          results['avg_eval_loss'] = sum(val_loss) / len(val_loss)
      return results
  def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer):
      """
      Evaluates the model on the given dataloader.

      Args:
          model: The model to evaluate.
          train_config: Training configuration. ``enable_fsdp`` selects the device
              placement and whether the loss is reduced across ranks.
          eval_dataloader: The dataloader containing the evaluation data.
          local_rank: Device index of this process on its node (used when FSDP is enabled).
          tokenizer: Currently unused; accepted for compatibility with existing callers.

      Returns: eval_ppl, eval_epoch_loss. eval_epoch_loss is the per-sample mean loss,
          aggregated over all ranks under multi-GPU FSDP; eval_ppl is its exponential.
          An empty dataloader yields NaN (0 / 0).
      """
      model.eval()
      device = local_rank if train_config.enable_fsdp else "cuda"
      eval_loss = 0.0  # sample-weighted sum of per-batch losses
      eval_dataset_len = 0  # number of samples evaluated on this rank
      with MemoryTrace():
          for batch in tqdm(eval_dataloader, colour="green", desc="evaluating Epoch"):
              for key in batch.keys():
                  batch[key] = batch[key].to(device)
              # Ensure no gradients are computed for this scope to save memory
              with torch.no_grad():
                  outputs = model(**batch)
              batch_size = len(next(iter(batch.values())))
              # outputs.loss is a batch mean; weight by batch size for a per-sample average.
              eval_loss += outputs.loss.detach().float() * batch_size
              eval_dataset_len += batch_size
      # Reduce the loss sum and the sample count together, so the mean covers every rank's samples.
      eval_stats = torch.tensor([float(eval_loss), float(eval_dataset_len)], device=device)
      if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
          dist.all_reduce(eval_stats, op=dist.ReduceOp.SUM)
      # Compute average loss and perplexity
      eval_epoch_loss = eval_stats[0] / eval_stats[1]
      eval_ppl = torch.exp(eval_epoch_loss)
      # Print evaluation metrics
      print(f" {eval_ppl=} {eval_epoch_loss=}")
      return eval_ppl, eval_epoch_loss
  def freeze_transformer_layers(model, num_layer):
      """Turn off gradients for all parameters of the first ``num_layer`` layers.

      Layers are taken from ``model.model.layers``. Layers at index >= num_layer
      keep their current ``requires_grad`` flags.
      """
      leading_layers = (blk for idx, blk in enumerate(model.model.layers) if idx < num_layer)
      for blk in leading_layers:
          for weight in blk.parameters():
              weight.requires_grad = False
  def check_frozen_layers_peft_model(model):
      """Print the ``requires_grad`` flag of every named parameter in each layer.

      Layers are read from ``model.base_model.model.model.layers``, i.e. the
      decoder stack underneath a PEFT wrapper.
      """
      decoder_layers = model.base_model.model.model.layers
      for layer_idx, layer in enumerate(decoder_layers):
          for param_name, tensor in layer.named_parameters():
              print(f"Layer {layer_idx}, parameter {param_name}: requires_grad = {tensor.requires_grad}")
  def setup():
      """Create the default torch.distributed process group using the NCCL backend."""
      dist.init_process_group(backend="nccl")
  def setup_environ_flags(rank):
      """Export debugging-oriented environment variables for torch/NCCL.

      Sets C++ stack traces, NCCL async error handling and detailed
      torch.distributed debug output. Rank 0 prints a notice.
      """
      debug_env = {
          "TORCH_SHOW_CPP_STACKTRACES": "1",
          "NCCL_ASYNC_ERROR_HANDLING": "1",
          "TORCH_DISTRIBUTED_DEBUG": "DETAIL",
      }
      for var_name, var_value in debug_env.items():
          os.environ[var_name] = var_value
      if rank == 0:
          print(f"--> Running with torch dist debug set to detail")
  def cleanup():
      """Tear down the default process group once distributed training is finished."""
      dist.destroy_process_group(group=None)
  def clear_gpu_cache(rank=None):
      """Release this process's cached, unused CUDA allocator memory; rank 0 logs it."""
      should_log = rank == 0
      if should_log:
          print(f"Clearing GPU cache for all ranks")
      torch.cuda.empty_cache()
  def get_parameter_dtypes(model):
      """Return a dict mapping each name from ``model.named_parameters()`` to its dtype."""
      return {pname: tensor.dtype for pname, tensor in model.named_parameters()}
  def print_model_size(model, config, rank: int = 0) -> None:
      """
      Print the model name and its number of trainable parameters (in millions).

      Only parameters with ``requires_grad=True`` are counted. Nothing is printed
      unless ``rank`` is 0.

      Args:
          model: The PyTorch model.
          config: Object with a ``model_name`` attribute, used in the printed messages.
          rank (int, optional): Current process's rank. Defaults to 0.
      """
      if rank == 0:
          print(f"--> Model {config.model_name}")
          total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
          print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  def get_policies(cfg, rank):
      """Get the policies for mixed precision and FSDP wrapping.

      Args:
          cfg: Config with ``mixed_precision`` and ``use_fp16`` flags.
          rank: Process rank; only rank 0 prints policy messages.

      Returns: (mixed_precision_policy, wrapping_policy).
          mixed_precision_policy depends on ``cfg``:
          - ``bfSixteen_mixed`` when ``cfg.mixed_precision`` is set, ``cfg.use_fp16``
            is not set, and the bf16 checks below pass;
          - ``fpSixteen`` when ``cfg.mixed_precision`` and ``cfg.use_fp16`` are set;
          - ``None`` (FP32) otherwise.
          wrapping_policy is always ``get_llama_wrapper()``.
      """
      mixed_precision_policy = None
      if cfg.mixed_precision:
          # bf16 needs CUDA >= 11.0, hardware support, and NCCL >= 2.10 for bf16 collectives.
          # Checked only when mixed precision is requested and fp16 is not forced.
          bf16_ready = (
              not cfg.use_fp16
              and torch.version.cuda
              and torch.cuda.is_bf16_supported()
              and packaging.version.parse(torch.version.cuda).release >= (11, 0)
              and dist.is_nccl_available()
              and nccl.version() >= (2, 10)
          )
          if bf16_ready:
              mixed_precision_policy = bfSixteen_mixed
              if rank == 0:
                  print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
          elif cfg.use_fp16:
              mixed_precision_policy = fpSixteen
              if rank == 0:
                  print(f"FP16 enabled")
          elif rank == 0:
              # Log once, consistent with the other policy messages.
              print(f"bFloat16 support not present. Using FP32, and not mixed precision")
      wrapping_policy = get_llama_wrapper()
      return mixed_precision_policy, wrapping_policy