train_utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import time
  5. import yaml
  6. from contextlib import nullcontext
  7. from pathlib import Path
  8. from pkg_resources import packaging
  9. from datetime import datetime
  10. import torch
  11. import torch.cuda.nccl as nccl
  12. import torch.distributed as dist
  13. from torch.distributed.fsdp import StateDictType
  14. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  15. from tqdm import tqdm
  16. from transformers import LlamaTokenizer
  17. import json
  18. from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
  19. from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
  20. from llama_recipes.utils.memory_utils import MemoryTrace
  21. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  22. tokenizer.pad_token_id = 0
  23. tokenizer.padding_side = "left"
  24. # Converting Bytes to Megabytes
  25. def byte2mb(x):
  26. return int(x / 2**20)
  27. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  28. """
  29. Trains the model on the given dataloader
  30. Args:
  31. model: The model to be trained
  32. train_dataloader: The dataloader containing the training data
  33. optimizer: The optimizer used for training
  34. lr_scheduler: The learning rate scheduler
  35. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  36. num_epochs: The number of epochs to train for
  37. local_rank: The rank of the current node in a distributed setting
  38. train_config: The training configuration
  39. eval_dataloader: The dataloader containing the eval data
  40. tokenizer: tokenizer used in the eval for decoding the predicitons
  41. Returns: results dictionary containing average training and validation perplexity and loss
  42. """
  43. # Create a gradient scaler for fp16
  44. if train_config.use_fp16 and train_config.enable_fsdp:
  45. scaler = ShardedGradScaler()
  46. elif train_config.use_fp16 and not train_config.enable_fsdp:
  47. scaler = torch.cuda.amp.GradScaler()
  48. if train_config.enable_fsdp:
  49. world_size = int(os.environ["WORLD_SIZE"])
  50. autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
  51. train_prep = []
  52. train_loss = []
  53. val_prep = []
  54. val_loss =[]
  55. if train_config.save_metrics:
  56. metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
  57. train_step_perplexity = []
  58. train_step_loss = []
  59. val_step_loss = []
  60. val_step_perplexity = []
  61. epoch_times = []
  62. checkpoint_times = []
  63. results = {}
  64. best_val_loss = float("inf")
  65. for epoch in range(train_config.num_epochs):
  66. epoch_start_time = time.perf_counter()
  67. with MemoryTrace() as memtrace: # track the memory usage
  68. model.train()
  69. total_loss = 0.0
  70. total_length = len(train_dataloader)//gradient_accumulation_steps
  71. pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
  72. for step, batch in enumerate(train_dataloader):
  73. for key in batch.keys():
  74. if train_config.enable_fsdp:
  75. batch[key] = batch[key].to(local_rank)
  76. else:
  77. batch[key] = batch[key].to('cuda:0')
  78. with autocast():
  79. loss = model(**batch).loss
  80. loss = loss / gradient_accumulation_steps
  81. if train_config.save_metrics:
  82. train_step_loss.append(loss.detach().float().item())
  83. train_step_perplexity.append(float(torch.exp(loss.detach().float())))
  84. total_loss += loss.detach().float()
  85. if train_config.use_fp16:
  86. # if fp16 is enabled, use gradient scaler to handle gradient update
  87. scaler.scale(loss).backward()
  88. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  89. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  90. scaler.unscale_(optimizer)
  91. if train_config.enable_fsdp:
  92. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  93. else:
  94. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  95. scaler.step(optimizer)
  96. scaler.update()
  97. optimizer.zero_grad()
  98. pbar.update(1)
  99. else:
  100. # regular backpropagation when fp16 is not used
  101. loss.backward()
  102. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  103. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  104. if train_config.enable_fsdp:
  105. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  106. else:
  107. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  108. optimizer.step()
  109. optimizer.zero_grad()
  110. pbar.update(1)
  111. pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
  112. if train_config.save_metrics:
  113. save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
  114. pbar.close()
  115. epoch_end_time = time.perf_counter()-epoch_start_time
  116. epoch_times.append(epoch_end_time)
  117. # Reducing total_loss across all devices if there's more than one CUDA device
  118. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  119. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  120. train_epoch_loss = total_loss / len(train_dataloader)
  121. if train_config.enable_fsdp:
  122. train_epoch_loss = train_epoch_loss/world_size
  123. train_perplexity = torch.exp(train_epoch_loss)
  124. train_prep.append(float(train_perplexity))
  125. train_loss.append(float(train_epoch_loss))
  126. if train_config.enable_fsdp:
  127. if rank==0:
  128. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  129. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  130. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  131. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  132. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  133. else:
  134. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  135. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  136. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  137. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  138. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  139. # Update the learning rate as needed
  140. lr_scheduler.step()
  141. if train_config.run_validation:
  142. eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
  143. if train_config.save_metrics:
  144. val_step_loss.extend(temp_val_loss)
  145. val_step_perplexity.extend(temp_step_perplexity)
  146. checkpoint_start_time = time.perf_counter()
  147. if train_config.save_model and eval_epoch_loss < best_val_loss:
  148. if train_config.enable_fsdp:
  149. dist.barrier()
  150. if train_config.use_peft:
  151. if train_config.enable_fsdp:
  152. if rank==0:
  153. print(f"we are about to save the PEFT modules")
  154. else:
  155. print(f"we are about to save the PEFT modules")
  156. model.save_pretrained(train_config.output_dir)
  157. if train_config.enable_fsdp:
  158. if rank==0:
  159. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  160. else:
  161. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  162. else:
  163. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  164. save_model_checkpoint(
  165. model, optimizer, rank, train_config, epoch=epoch
  166. )
  167. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  168. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  169. print("=====================================================")
  170. save_model_and_optimizer_sharded(model, rank, train_config)
  171. if train_config.save_optimizer:
  172. save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  173. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  174. print("=====================================================")
  175. if not train_config.use_peft and train_config.save_optimizer:
  176. save_optimizer_checkpoint(
  177. model, optimizer, rank, train_config, epoch=epoch
  178. )
  179. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  180. print("=====================================================")
  181. if train_config.enable_fsdp:
  182. dist.barrier()
  183. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  184. checkpoint_times.append(checkpoint_end_time)
  185. if eval_epoch_loss < best_val_loss:
  186. best_val_loss = eval_epoch_loss
  187. if train_config.enable_fsdp:
  188. if rank==0:
  189. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  190. else:
  191. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  192. val_loss.append(float(best_val_loss))
  193. val_prep.append(float(eval_ppl))
  194. if train_config.enable_fsdp:
  195. if rank==0:
  196. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  197. else:
  198. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  199. # Saving the results every epoch to plot later
  200. if train_config.save_metrics:
  201. save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
  202. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  203. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
  204. avg_train_prep = sum(train_prep)/len(train_prep)
  205. avg_train_loss = sum(train_loss)/len(train_loss)
  206. if train_config.run_validation:
  207. avg_eval_prep = sum(val_prep)/len(val_prep)
  208. avg_eval_loss = sum(val_loss)/len(val_loss)
  209. results['avg_train_prep'] = avg_train_prep
  210. results['avg_train_loss'] = avg_train_loss
  211. if train_config.run_validation:
  212. results['avg_eval_prep'] = avg_eval_prep
  213. results['avg_eval_loss'] = avg_eval_loss
  214. results["avg_epoch_time"] = avg_epoch_time
  215. results["avg_checkpoint_time"] = avg_checkpoint_time
  216. if train_config.save_metrics:
  217. results["metrics_filename"] = metrics_filename
  218. #saving the training params including fsdp setting for reference.
  219. if train_config.enable_fsdp and not train_config.use_peft:
  220. save_train_params(train_config, fsdp_config, rank)
  221. return results
  222. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  223. """
  224. Evaluates the model on the given dataloader
  225. Args:
  226. model: The model to evaluate
  227. eval_dataloader: The dataloader containing the evaluation data
  228. local_rank: The rank of the current node in a distributed setting
  229. tokenizer: The tokenizer used to decode predictions
  230. Returns: eval_ppl, eval_epoch_loss
  231. """
  232. if train_config.enable_fsdp:
  233. world_size = int(os.environ["WORLD_SIZE"])
  234. model.eval()
  235. eval_preds = []
  236. val_step_loss = []
  237. val_step_perplexity = []
  238. eval_loss = 0.0 # Initialize evaluation loss
  239. with MemoryTrace() as memtrace:
  240. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
  241. for key in batch.keys():
  242. if train_config.enable_fsdp:
  243. batch[key] = batch[key].to(local_rank)
  244. else:
  245. batch[key] = batch[key].to('cuda:0')
  246. # Ensure no gradients are computed for this scope to save memory
  247. with torch.no_grad():
  248. # Forward pass and compute loss
  249. outputs = model(**batch)
  250. loss = outputs.loss
  251. if train_config.save_metrics:
  252. val_step_loss.append(loss.detach().float().item())
  253. val_step_perplexity.append(float(torch.exp(loss.detach().float())))
  254. eval_loss += loss.detach().float()
  255. # Decode predictions and add to evaluation predictions list
  256. preds = torch.argmax(outputs.logits, -1)
  257. eval_preds.extend(
  258. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  259. )
  260. # If there's more than one CUDA device, reduce evaluation loss across all devices
  261. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  262. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  263. # Compute average loss and perplexity
  264. eval_epoch_loss = eval_loss / len(eval_dataloader)
  265. if train_config.enable_fsdp:
  266. eval_epoch_loss = eval_epoch_loss/world_size
  267. eval_ppl = torch.exp(eval_epoch_loss)
  268. # Print evaluation metrics
  269. if train_config.enable_fsdp:
  270. if local_rank==0:
  271. print(f" {eval_ppl=} {eval_epoch_loss=}")
  272. else:
  273. print(f" {eval_ppl=} {eval_epoch_loss=}")
  274. return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
  275. def freeze_transformer_layers(model, num_layer):
  276. for i, layer in enumerate(model.model.layers):
  277. if i < num_layer:
  278. for param in layer.parameters():
  279. param.requires_grad = False
  280. def check_frozen_layers_peft_model(model):
  281. for i, layer in enumerate(model.base_model.model.model.layers):
  282. for name, param in layer.named_parameters():
  283. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  284. def setup():
  285. """Initialize the process group for distributed training"""
  286. dist.init_process_group("nccl")
  287. def setup_environ_flags(rank):
  288. """Set environment flags for debugging purposes"""
  289. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  290. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  291. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  292. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  293. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  294. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  295. if rank == 0:
  296. print(f"--> Running with torch dist debug set to detail")
  297. def cleanup():
  298. """Clean up the process group after training"""
  299. dist.destroy_process_group()
  300. def clear_gpu_cache(rank=None):
  301. """Clear the GPU cache for all ranks"""
  302. if rank == 0:
  303. print(f"Clearing GPU cache for all ranks")
  304. torch.cuda.empty_cache()
  305. def get_parameter_dtypes(model):
  306. """Get the data types of model parameters"""
  307. parameter_dtypes = {}
  308. for name, parameter in model.named_parameters():
  309. parameter_dtypes[name] = parameter.dtype
  310. return parameter_dtypes
  311. def print_model_size(model, config, rank: int = 0) -> None:
  312. """
  313. Print model name, the number of trainable parameters and initialization time.
  314. Args:
  315. model: The PyTorch model.
  316. model_name (str): Name of the model.
  317. init_time_start (float): Initialization start time.
  318. init_time_end (float): Initialization end time.
  319. rank (int, optional): Current process's rank. Defaults to 0.
  320. """
  321. if rank == 0:
  322. print(f"--> Model {config.model_name}")
  323. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  324. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  325. def get_policies(cfg, rank):
  326. """Get the policies for mixed precision and fsdp wrapping"""
  327. verify_bfloat_support = (
  328. torch.version.cuda
  329. and torch.cuda.is_bf16_supported()
  330. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  331. and dist.is_nccl_available()
  332. and nccl.version() >= (2, 10)
  333. )
  334. mixed_precision_policy = None
  335. wrapping_policy = None
  336. # Mixed precision
  337. if cfg.mixed_precision:
  338. bf16_ready = verify_bfloat_support
  339. if bf16_ready and not cfg.use_fp16:
  340. mixed_precision_policy = bfSixteen
  341. if rank == 0:
  342. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  343. elif cfg.use_fp16:
  344. mixed_precision_policy = fpSixteen
  345. if rank == 0:
  346. print(f"FP16 enabled")
  347. else:
  348. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  349. wrapping_policy = get_llama_wrapper()
  350. return mixed_precision_policy, wrapping_policy
  351. def save_train_params(train_config, fsdp_config, rank):
  352. """
  353. This function saves the train_config and FSDP config into a train_params.yaml.
  354. This will be used by converter script in the inference folder to fetch the HF model name or path.
  355. It also would be hepful as a log for future references.
  356. """
  357. # Convert the train_config and fsdp_config objects to dictionaries,
  358. # converting all values to strings to ensure they can be serialized into a YAML file
  359. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  360. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  361. # Merge the two dictionaries into one
  362. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  363. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  364. folder_name = (
  365. train_config.dist_checkpoint_root_folder
  366. + "/"
  367. + train_config.dist_checkpoint_folder
  368. + "-"
  369. + train_config.model_name
  370. )
  371. save_dir = Path.cwd() / folder_name
  372. # If the directory does not exist, create it
  373. if not os.path.exists(save_dir):
  374. os.makedirs(save_dir)
  375. # Convert the dictionary to a YAML string
  376. config_yaml = yaml.dump(train_params_dict, indent=4)
  377. file_name = os.path.join(save_dir,'train_params.yaml')
  378. # Check if there's a directory with the same name as the file
  379. if os.path.isdir(file_name):
  380. print(f"Error: {file_name} is a directory, not a file.")
  381. else:
  382. # Write the YAML string to the file
  383. with open(file_name, 'w') as f:
  384. f.write(config_yaml)
  385. if rank==0:
  386. print(f"training params are saved in {file_name}")
  387. def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
  388. metrics_data = {
  389. "train_step_loss": train_step_loss,
  390. "train_epoch_loss": train_epoch_loss,
  391. "train_step_perplexity": train_step_ppl,
  392. "train_epoch_perplexity": train_epoch_ppl,
  393. "val_step_loss": val_step_loss,
  394. "val_epoch_loss": val_epoch_loss,
  395. "val_step_perplexity": val_step_ppl,
  396. "val_epoch_perplexity": val_epoch_ppl
  397. }
  398. with open(output_filename, "w") as f:
  399. json.dump(metrics_data, f)