train_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import time
  5. import yaml
  6. from pathlib import Path
  7. from pkg_resources import packaging
  8. import torch
  9. import torch.cuda.nccl as nccl
  10. import torch.distributed as dist
  11. from torch.distributed.fsdp import StateDictType
  12. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  13. from tqdm import tqdm
  14. from transformers import LlamaTokenizer
  15. from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
  16. from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
  17. from llama_recipes.utils.memory_utils import MemoryTrace
  18. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  19. tokenizer.pad_token_id = 0
  20. tokenizer.padding_side = "left"
  21. # Converting Bytes to Megabytes
  22. def byte2mb(x):
  23. return int(x / 2**20)
  24. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  25. """
  26. Trains the model on the given dataloader
  27. Args:
  28. model: The model to be trained
  29. train_dataloader: The dataloader containing the training data
  30. optimizer: The optimizer used for training
  31. lr_scheduler: The learning rate scheduler
  32. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  33. num_epochs: The number of epochs to train for
  34. local_rank: The rank of the current node in a distributed setting
  35. train_config: The training configuration
  36. eval_dataloader: The dataloader containing the eval data
  37. tokenizer: tokenizer used in the eval for decoding the predicitons
  38. Returns: results dictionary containing average training and validation perplexity and loss
  39. """
  40. # Create a gradient scaler for fp16
  41. if train_config.use_fp16 and train_config.enable_fsdp:
  42. scaler = ShardedGradScaler()
  43. elif train_config.use_fp16 and not train_config.enable_fsdp:
  44. scaler = torch.cuda.amp.GradScaler()
  45. if train_config.enable_fsdp:
  46. world_size = int(os.environ["WORLD_SIZE"])
  47. train_prep = []
  48. train_loss = []
  49. val_prep = []
  50. val_loss =[]
  51. epoch_times = []
  52. checkpoint_times = []
  53. results = {}
  54. best_val_loss = float("inf")
  55. for epoch in range(train_config.num_epochs):
  56. epoch_start_time = time.perf_counter()
  57. with MemoryTrace() as memtrace: # track the memory usage
  58. model.train()
  59. total_loss = 0.0
  60. total_length = len(train_dataloader)//gradient_accumulation_steps
  61. pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
  62. for step, batch in enumerate(train_dataloader):
  63. for key in batch.keys():
  64. if train_config.enable_fsdp:
  65. batch[key] = batch[key].to(local_rank)
  66. else:
  67. batch[key] = batch[key].to('cuda:0')
  68. loss = model(**batch).loss
  69. loss = loss / gradient_accumulation_steps
  70. total_loss += loss.detach().float()
  71. if train_config.use_fp16:
  72. # if fp16 is enabled, use gradient scaler to handle gradient update
  73. scaler.scale(loss).backward()
  74. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  75. scaler.step(optimizer)
  76. scaler.update()
  77. optimizer.zero_grad()
  78. pbar.update(1)
  79. else:
  80. # regular backpropagation when fp16 is not used
  81. loss.backward()
  82. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  83. optimizer.step()
  84. optimizer.zero_grad()
  85. pbar.update(1)
  86. pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
  87. pbar.close()
  88. epoch_end_time = time.perf_counter()-epoch_start_time
  89. epoch_times.append(epoch_end_time)
  90. # Reducing total_loss across all devices if there's more than one CUDA device
  91. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  92. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  93. train_epoch_loss = total_loss / len(train_dataloader)
  94. if train_config.enable_fsdp:
  95. train_epoch_loss = train_epoch_loss/world_size
  96. train_perplexity = torch.exp(train_epoch_loss)
  97. train_prep.append(train_perplexity)
  98. train_loss.append(train_epoch_loss)
  99. if train_config.enable_fsdp:
  100. if rank==0:
  101. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  102. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  103. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  104. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  105. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  106. else:
  107. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  108. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  109. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  110. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  111. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  112. # Update the learning rate as needed
  113. lr_scheduler.step()
  114. if train_config.run_validation:
  115. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
  116. checkpoint_start_time = time.perf_counter()
  117. if train_config.save_model and eval_epoch_loss < best_val_loss:
  118. if train_config.enable_fsdp:
  119. dist.barrier()
  120. if train_config.use_peft:
  121. if train_config.enable_fsdp:
  122. if rank==0:
  123. print(f"we are about to save the PEFT modules")
  124. else:
  125. print(f"we are about to save the PEFT modules")
  126. model.save_pretrained(train_config.output_dir)
  127. if train_config.enable_fsdp:
  128. if rank==0:
  129. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  130. else:
  131. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  132. else:
  133. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  134. save_model_checkpoint(
  135. model, optimizer, rank, train_config, epoch=epoch
  136. )
  137. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  138. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  139. print("=====================================================")
  140. save_model_and_optimizer_sharded(model, rank, train_config)
  141. if train_config.save_optimizer:
  142. save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  143. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  144. print("=====================================================")
  145. if not train_config.use_peft and train_config.save_optimizer:
  146. save_optimizer_checkpoint(
  147. model, optimizer, rank, train_config, epoch=epoch
  148. )
  149. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  150. print("=====================================================")
  151. if train_config.enable_fsdp:
  152. dist.barrier()
  153. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  154. checkpoint_times.append(checkpoint_end_time)
  155. if eval_epoch_loss < best_val_loss:
  156. best_val_loss = eval_epoch_loss
  157. if train_config.enable_fsdp:
  158. if rank==0:
  159. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  160. else:
  161. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  162. val_loss.append(best_val_loss)
  163. val_prep.append(eval_ppl)
  164. if train_config.enable_fsdp:
  165. if rank==0:
  166. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  167. else:
  168. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  169. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  170. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
  171. avg_train_prep = sum(train_prep)/len(train_prep)
  172. avg_train_loss = sum(train_loss)/len(train_loss)
  173. if train_config.run_validation:
  174. avg_eval_prep = sum(val_prep)/len(val_prep)
  175. avg_eval_loss = sum(val_loss)/len(val_loss)
  176. results['avg_train_prep'] = avg_train_prep
  177. results['avg_train_loss'] = avg_train_loss
  178. if train_config.run_validation:
  179. results['avg_eval_prep'] = avg_eval_prep
  180. results['avg_eval_loss'] = avg_eval_loss
  181. results["avg_epoch_time"] = avg_epoch_time
  182. results["avg_checkpoint_time"] = avg_checkpoint_time
  183. #saving the training params including fsdp setting for reference.
  184. if train_config.enable_fsdp and not train_config.use_peft:
  185. save_train_params(train_config, fsdp_config, rank)
  186. return results
  187. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  188. """
  189. Evaluates the model on the given dataloader
  190. Args:
  191. model: The model to evaluate
  192. eval_dataloader: The dataloader containing the evaluation data
  193. local_rank: The rank of the current node in a distributed setting
  194. tokenizer: The tokenizer used to decode predictions
  195. Returns: eval_ppl, eval_epoch_loss
  196. """
  197. if train_config.enable_fsdp:
  198. world_size = int(os.environ["WORLD_SIZE"])
  199. model.eval()
  200. eval_preds = []
  201. eval_loss = 0.0 # Initialize evaluation loss
  202. with MemoryTrace() as memtrace:
  203. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
  204. for key in batch.keys():
  205. if train_config.enable_fsdp:
  206. batch[key] = batch[key].to(local_rank)
  207. else:
  208. batch[key] = batch[key].to('cuda:0')
  209. # Ensure no gradients are computed for this scope to save memory
  210. with torch.no_grad():
  211. # Forward pass and compute loss
  212. outputs = model(**batch)
  213. loss = outputs.loss
  214. eval_loss += loss.detach().float()
  215. # Decode predictions and add to evaluation predictions list
  216. preds = torch.argmax(outputs.logits, -1)
  217. eval_preds.extend(
  218. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  219. )
  220. # If there's more than one CUDA device, reduce evaluation loss across all devices
  221. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  222. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  223. # Compute average loss and perplexity
  224. eval_epoch_loss = eval_loss / len(eval_dataloader)
  225. if train_config.enable_fsdp:
  226. eval_epoch_loss = eval_epoch_loss/world_size
  227. eval_ppl = torch.exp(eval_epoch_loss)
  228. # Print evaluation metrics
  229. if train_config.enable_fsdp:
  230. if local_rank==0:
  231. print(f" {eval_ppl=} {eval_epoch_loss=}")
  232. else:
  233. print(f" {eval_ppl=} {eval_epoch_loss=}")
  234. return eval_ppl, eval_epoch_loss
  235. def freeze_transformer_layers(model, num_layer):
  236. for i, layer in enumerate(model.model.layers):
  237. if i < num_layer:
  238. for param in layer.parameters():
  239. param.requires_grad = False
  240. def check_frozen_layers_peft_model(model):
  241. for i, layer in enumerate(model.base_model.model.model.layers):
  242. for name, param in layer.named_parameters():
  243. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  244. def setup():
  245. """Initialize the process group for distributed training"""
  246. dist.init_process_group("nccl")
  247. def setup_environ_flags(rank):
  248. """Set environment flags for debugging purposes"""
  249. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  250. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  251. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  252. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  253. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  254. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  255. if rank == 0:
  256. print(f"--> Running with torch dist debug set to detail")
  257. def cleanup():
  258. """Clean up the process group after training"""
  259. dist.destroy_process_group()
  260. def clear_gpu_cache(rank=None):
  261. """Clear the GPU cache for all ranks"""
  262. if rank == 0:
  263. print(f"Clearing GPU cache for all ranks")
  264. torch.cuda.empty_cache()
  265. def get_parameter_dtypes(model):
  266. """Get the data types of model parameters"""
  267. parameter_dtypes = {}
  268. for name, parameter in model.named_parameters():
  269. parameter_dtypes[name] = parameter.dtype
  270. return parameter_dtypes
  271. def print_model_size(model, config, rank: int = 0) -> None:
  272. """
  273. Print model name, the number of trainable parameters and initialization time.
  274. Args:
  275. model: The PyTorch model.
  276. model_name (str): Name of the model.
  277. init_time_start (float): Initialization start time.
  278. init_time_end (float): Initialization end time.
  279. rank (int, optional): Current process's rank. Defaults to 0.
  280. """
  281. if rank == 0:
  282. print(f"--> Model {config.model_name}")
  283. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  284. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  285. def get_policies(cfg, rank):
  286. """Get the policies for mixed precision and fsdp wrapping"""
  287. verify_bfloat_support = (
  288. torch.version.cuda
  289. and torch.cuda.is_bf16_supported()
  290. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  291. and dist.is_nccl_available()
  292. and nccl.version() >= (2, 10)
  293. )
  294. mixed_precision_policy = None
  295. wrapping_policy = None
  296. # Mixed precision
  297. if cfg.mixed_precision:
  298. bf16_ready = verify_bfloat_support
  299. if bf16_ready and not cfg.use_fp16:
  300. mixed_precision_policy = bfSixteen_mixed
  301. if rank == 0:
  302. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  303. elif cfg.use_fp16:
  304. mixed_precision_policy = fpSixteen
  305. if rank == 0:
  306. print(f"FP16 enabled")
  307. else:
  308. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  309. wrapping_policy = get_llama_wrapper()
  310. return mixed_precision_policy, wrapping_policy
  311. def save_train_params(train_config, fsdp_config, rank):
  312. """
  313. This function saves the train_config and FSDP config into a train_params.yaml.
  314. This will be used by converter script in the inference folder to fetch the HF model name or path.
  315. It also would be hepful as a log for future references.
  316. """
  317. # Convert the train_config and fsdp_config objects to dictionaries,
  318. # converting all values to strings to ensure they can be serialized into a YAML file
  319. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  320. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  321. # Merge the two dictionaries into one
  322. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  323. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  324. folder_name = (
  325. train_config.dist_checkpoint_root_folder
  326. + "/"
  327. + train_config.dist_checkpoint_folder
  328. + "-"
  329. + train_config.model_name
  330. )
  331. save_dir = Path.cwd() / folder_name
  332. # If the directory does not exist, create it
  333. if not os.path.exists(save_dir):
  334. os.makedirs(save_dir)
  335. # Convert the dictionary to a YAML string
  336. config_yaml = yaml.dump(train_params_dict, indent=4)
  337. file_name = os.path.join(save_dir,'train_params.yaml')
  338. # Check if there's a directory with the same name as the file
  339. if os.path.isdir(file_name):
  340. print(f"Error: {file_name} is a directory, not a file.")
  341. else:
  342. # Write the YAML string to the file
  343. with open(file_name, 'w') as f:
  344. f.write(config_yaml)
  345. if rank==0:
  346. print(f"training params are saved in {file_name}")