train_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import time
  5. import yaml
  6. from contextlib import nullcontext
  7. from pathlib import Path
  8. from pkg_resources import packaging
  9. import torch
  10. import torch.cuda.nccl as nccl
  11. import torch.distributed as dist
  12. from torch.distributed.fsdp import StateDictType
  13. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  14. from tqdm import tqdm
  15. from transformers import LlamaTokenizer
  16. from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
  17. from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
  18. from llama_recipes.utils.memory_utils import MemoryTrace
  19. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  20. tokenizer.pad_token_id = 0
  21. tokenizer.padding_side = "left"
  22. # Converting Bytes to Megabytes
  23. def byte2mb(x):
  24. return int(x / 2**20)
  25. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  26. """
  27. Trains the model on the given dataloader
  28. Args:
  29. model: The model to be trained
  30. train_dataloader: The dataloader containing the training data
  31. optimizer: The optimizer used for training
  32. lr_scheduler: The learning rate scheduler
  33. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  34. num_epochs: The number of epochs to train for
  35. local_rank: The rank of the current node in a distributed setting
  36. train_config: The training configuration
  37. eval_dataloader: The dataloader containing the eval data
  38. tokenizer: tokenizer used in the eval for decoding the predicitons
  39. Returns: results dictionary containing average training and validation perplexity and loss
  40. """
  41. # Create a gradient scaler for fp16
  42. if train_config.use_fp16 and train_config.enable_fsdp:
  43. scaler = ShardedGradScaler()
  44. elif train_config.use_fp16 and not train_config.enable_fsdp:
  45. scaler = torch.cuda.amp.GradScaler()
  46. if train_config.enable_fsdp:
  47. world_size = int(os.environ["WORLD_SIZE"])
  48. autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
  49. train_prep = []
  50. train_loss = []
  51. val_prep = []
  52. val_loss =[]
  53. epoch_times = []
  54. checkpoint_times = []
  55. results = {}
  56. best_val_loss = float("inf")
  57. for epoch in range(train_config.num_epochs):
  58. epoch_start_time = time.perf_counter()
  59. with MemoryTrace() as memtrace: # track the memory usage
  60. model.train()
  61. total_loss = 0.0
  62. total_length = len(train_dataloader)//gradient_accumulation_steps
  63. pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
  64. for step, batch in enumerate(train_dataloader):
  65. for key in batch.keys():
  66. if train_config.enable_fsdp:
  67. batch[key] = batch[key].to(local_rank)
  68. else:
  69. batch[key] = batch[key].to('cuda:0')
  70. with autocast():
  71. loss = model(**batch).loss
  72. loss = loss / gradient_accumulation_steps
  73. total_loss += loss.detach().float()
  74. if train_config.use_fp16:
  75. # if fp16 is enabled, use gradient scaler to handle gradient update
  76. scaler.scale(loss).backward()
  77. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  78. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  79. scaler.unscale_(optimizer)
  80. if train_config.enable_fsdp:
  81. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  82. else:
  83. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  84. scaler.step(optimizer)
  85. scaler.update()
  86. optimizer.zero_grad()
  87. pbar.update(1)
  88. else:
  89. # regular backpropagation when fp16 is not used
  90. loss.backward()
  91. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  92. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  93. if train_config.enable_fsdp:
  94. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  95. else:
  96. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  97. optimizer.step()
  98. optimizer.zero_grad()
  99. pbar.update(1)
  100. pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
  101. pbar.close()
  102. epoch_end_time = time.perf_counter()-epoch_start_time
  103. epoch_times.append(epoch_end_time)
  104. # Reducing total_loss across all devices if there's more than one CUDA device
  105. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  106. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  107. train_epoch_loss = total_loss / len(train_dataloader)
  108. if train_config.enable_fsdp:
  109. train_epoch_loss = train_epoch_loss/world_size
  110. train_perplexity = torch.exp(train_epoch_loss)
  111. train_prep.append(train_perplexity)
  112. train_loss.append(train_epoch_loss)
  113. if train_config.enable_fsdp:
  114. if rank==0:
  115. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  116. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  117. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  118. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  119. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  120. else:
  121. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  122. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  123. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  124. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  125. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  126. # Update the learning rate as needed
  127. lr_scheduler.step()
  128. if train_config.run_validation:
  129. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
  130. checkpoint_start_time = time.perf_counter()
  131. if train_config.save_model and eval_epoch_loss < best_val_loss:
  132. if train_config.enable_fsdp:
  133. dist.barrier()
  134. if train_config.use_peft:
  135. if train_config.enable_fsdp:
  136. if rank==0:
  137. print(f"we are about to save the PEFT modules")
  138. else:
  139. print(f"we are about to save the PEFT modules")
  140. model.save_pretrained(train_config.output_dir)
  141. if train_config.enable_fsdp:
  142. if rank==0:
  143. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  144. else:
  145. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  146. else:
  147. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  148. save_model_checkpoint(
  149. model, optimizer, rank, train_config, epoch=epoch
  150. )
  151. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  152. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  153. print("=====================================================")
  154. save_model_and_optimizer_sharded(model, rank, train_config)
  155. if train_config.save_optimizer:
  156. save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  157. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  158. print("=====================================================")
  159. if not train_config.use_peft and train_config.save_optimizer:
  160. save_optimizer_checkpoint(
  161. model, optimizer, rank, train_config, epoch=epoch
  162. )
  163. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  164. print("=====================================================")
  165. if train_config.enable_fsdp:
  166. dist.barrier()
  167. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  168. checkpoint_times.append(checkpoint_end_time)
  169. if eval_epoch_loss < best_val_loss:
  170. best_val_loss = eval_epoch_loss
  171. if train_config.enable_fsdp:
  172. if rank==0:
  173. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  174. else:
  175. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  176. val_loss.append(best_val_loss)
  177. val_prep.append(eval_ppl)
  178. if train_config.enable_fsdp:
  179. if rank==0:
  180. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  181. else:
  182. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  183. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  184. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
  185. avg_train_prep = sum(train_prep)/len(train_prep)
  186. avg_train_loss = sum(train_loss)/len(train_loss)
  187. if train_config.run_validation:
  188. avg_eval_prep = sum(val_prep)/len(val_prep)
  189. avg_eval_loss = sum(val_loss)/len(val_loss)
  190. results['avg_train_prep'] = avg_train_prep
  191. results['avg_train_loss'] = avg_train_loss
  192. if train_config.run_validation:
  193. results['avg_eval_prep'] = avg_eval_prep
  194. results['avg_eval_loss'] = avg_eval_loss
  195. results["avg_epoch_time"] = avg_epoch_time
  196. results["avg_checkpoint_time"] = avg_checkpoint_time
  197. #saving the training params including fsdp setting for reference.
  198. if train_config.enable_fsdp and not train_config.use_peft:
  199. save_train_params(train_config, fsdp_config, rank)
  200. return results
  201. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  202. """
  203. Evaluates the model on the given dataloader
  204. Args:
  205. model: The model to evaluate
  206. eval_dataloader: The dataloader containing the evaluation data
  207. local_rank: The rank of the current node in a distributed setting
  208. tokenizer: The tokenizer used to decode predictions
  209. Returns: eval_ppl, eval_epoch_loss
  210. """
  211. if train_config.enable_fsdp:
  212. world_size = int(os.environ["WORLD_SIZE"])
  213. model.eval()
  214. eval_preds = []
  215. eval_loss = 0.0 # Initialize evaluation loss
  216. with MemoryTrace() as memtrace:
  217. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
  218. for key in batch.keys():
  219. if train_config.enable_fsdp:
  220. batch[key] = batch[key].to(local_rank)
  221. else:
  222. batch[key] = batch[key].to('cuda:0')
  223. # Ensure no gradients are computed for this scope to save memory
  224. with torch.no_grad():
  225. # Forward pass and compute loss
  226. outputs = model(**batch)
  227. loss = outputs.loss
  228. eval_loss += loss.detach().float()
  229. # Decode predictions and add to evaluation predictions list
  230. preds = torch.argmax(outputs.logits, -1)
  231. eval_preds.extend(
  232. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  233. )
  234. # If there's more than one CUDA device, reduce evaluation loss across all devices
  235. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  236. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  237. # Compute average loss and perplexity
  238. eval_epoch_loss = eval_loss / len(eval_dataloader)
  239. if train_config.enable_fsdp:
  240. eval_epoch_loss = eval_epoch_loss/world_size
  241. eval_ppl = torch.exp(eval_epoch_loss)
  242. # Print evaluation metrics
  243. if train_config.enable_fsdp:
  244. if local_rank==0:
  245. print(f" {eval_ppl=} {eval_epoch_loss=}")
  246. else:
  247. print(f" {eval_ppl=} {eval_epoch_loss=}")
  248. return eval_ppl, eval_epoch_loss
  249. def freeze_transformer_layers(model, num_layer):
  250. for i, layer in enumerate(model.model.layers):
  251. if i < num_layer:
  252. for param in layer.parameters():
  253. param.requires_grad = False
  254. def check_frozen_layers_peft_model(model):
  255. for i, layer in enumerate(model.base_model.model.model.layers):
  256. for name, param in layer.named_parameters():
  257. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  258. def setup():
  259. """Initialize the process group for distributed training"""
  260. dist.init_process_group("nccl")
  261. def setup_environ_flags(rank):
  262. """Set environment flags for debugging purposes"""
  263. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  264. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  265. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  266. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  267. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  268. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  269. if rank == 0:
  270. print(f"--> Running with torch dist debug set to detail")
  271. def cleanup():
  272. """Clean up the process group after training"""
  273. dist.destroy_process_group()
  274. def clear_gpu_cache(rank=None):
  275. """Clear the GPU cache for all ranks"""
  276. if rank == 0:
  277. print(f"Clearing GPU cache for all ranks")
  278. torch.cuda.empty_cache()
  279. def get_parameter_dtypes(model):
  280. """Get the data types of model parameters"""
  281. parameter_dtypes = {}
  282. for name, parameter in model.named_parameters():
  283. parameter_dtypes[name] = parameter.dtype
  284. return parameter_dtypes
  285. def print_model_size(model, config, rank: int = 0) -> None:
  286. """
  287. Print model name, the number of trainable parameters and initialization time.
  288. Args:
  289. model: The PyTorch model.
  290. model_name (str): Name of the model.
  291. init_time_start (float): Initialization start time.
  292. init_time_end (float): Initialization end time.
  293. rank (int, optional): Current process's rank. Defaults to 0.
  294. """
  295. if rank == 0:
  296. print(f"--> Model {config.model_name}")
  297. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  298. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  299. def get_policies(cfg, rank):
  300. """Get the policies for mixed precision and fsdp wrapping"""
  301. verify_bfloat_support = (
  302. torch.version.cuda
  303. and torch.cuda.is_bf16_supported()
  304. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  305. and dist.is_nccl_available()
  306. and nccl.version() >= (2, 10)
  307. )
  308. mixed_precision_policy = None
  309. wrapping_policy = None
  310. # Mixed precision
  311. if cfg.mixed_precision:
  312. bf16_ready = verify_bfloat_support
  313. if bf16_ready and not cfg.use_fp16:
  314. mixed_precision_policy = bfSixteen
  315. if rank == 0:
  316. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  317. elif cfg.use_fp16:
  318. mixed_precision_policy = fpSixteen
  319. if rank == 0:
  320. print(f"FP16 enabled")
  321. else:
  322. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  323. wrapping_policy = get_llama_wrapper()
  324. return mixed_precision_policy, wrapping_policy
  325. def save_train_params(train_config, fsdp_config, rank):
  326. """
  327. This function saves the train_config and FSDP config into a train_params.yaml.
  328. This will be used by converter script in the inference folder to fetch the HF model name or path.
  329. It also would be hepful as a log for future references.
  330. """
  331. # Convert the train_config and fsdp_config objects to dictionaries,
  332. # converting all values to strings to ensure they can be serialized into a YAML file
  333. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  334. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  335. # Merge the two dictionaries into one
  336. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  337. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  338. folder_name = (
  339. train_config.dist_checkpoint_root_folder
  340. + "/"
  341. + train_config.dist_checkpoint_folder
  342. + "-"
  343. + train_config.model_name
  344. )
  345. save_dir = Path.cwd() / folder_name
  346. # If the directory does not exist, create it
  347. if not os.path.exists(save_dir):
  348. os.makedirs(save_dir)
  349. # Convert the dictionary to a YAML string
  350. config_yaml = yaml.dump(train_params_dict, indent=4)
  351. file_name = os.path.join(save_dir,'train_params.yaml')
  352. # Check if there's a directory with the same name as the file
  353. if os.path.isdir(file_name):
  354. print(f"Error: {file_name} is a directory, not a file.")
  355. else:
  356. # Write the YAML string to the file
  357. with open(file_name, 'w') as f:
  358. f.write(config_yaml)
  359. if rank==0:
  360. print(f"training params are saved in {file_name}")