train_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import time
  5. import yaml
  6. from pathlib import Path
  7. from pkg_resources import packaging
  8. import torch
  9. import torch.cuda.nccl as nccl
  10. import torch.distributed as dist
  11. from torch.distributed.fsdp import StateDictType
  12. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  13. from tqdm import tqdm
  14. from transformers import LlamaTokenizer
  15. from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
  16. from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
  17. from llama_recipes.utils.memory_utils import MemoryTrace
  18. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  19. tokenizer.pad_token_id = 0
  20. tokenizer.padding_side = "left"
  21. # Converting Bytes to Megabytes
  22. def byte2mb(x):
  23. return int(x / 2**20)
  24. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  25. """
  26. Trains the model on the given dataloader
  27. Args:
  28. model: The model to be trained
  29. train_dataloader: The dataloader containing the training data
  30. optimizer: The optimizer used for training
  31. lr_scheduler: The learning rate scheduler
  32. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  33. num_epochs: The number of epochs to train for
  34. local_rank: The rank of the current node in a distributed setting
  35. train_config: The training configuration
  36. eval_dataloader: The dataloader containing the eval data
  37. tokenizer: tokenizer used in the eval for decoding the predicitons
  38. Returns: results dictionary containing average training and validation perplexity and loss
  39. """
  40. # Create a gradient scaler for fp16
  41. if train_config.use_fp16 and train_config.enable_fsdp:
  42. scaler = ShardedGradScaler()
  43. elif train_config.use_fp16 and not train_config.enable_fsdp:
  44. scaler = torch.cuda.amp.GradScaler()
  45. if train_config.enable_fsdp:
  46. world_size = int(os.environ["WORLD_SIZE"])
  47. train_prep = []
  48. train_loss = []
  49. val_prep = []
  50. val_loss =[]
  51. epoch_times = []
  52. checkpoint_times = []
  53. results = {}
  54. best_val_loss = float("inf")
  55. for epoch in range(train_config.num_epochs):
  56. epoch_start_time = time.perf_counter()
  57. with MemoryTrace() as memtrace: # track the memory usage
  58. model.train()
  59. total_loss = 0.0
  60. for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
  61. for key in batch.keys():
  62. if train_config.enable_fsdp:
  63. batch[key] = batch[key].to(local_rank)
  64. else:
  65. batch[key] = batch[key].to('cuda:0')
  66. loss = model(**batch).loss
  67. loss = loss / gradient_accumulation_steps
  68. total_loss += loss.detach().float()
  69. if train_config.use_fp16:
  70. # if fp16 is enabled, use gradient scaler to handle gradient update
  71. scaler.scale(loss).backward()
  72. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  73. scaler.step(optimizer)
  74. scaler.update()
  75. optimizer.zero_grad()
  76. else:
  77. # regular backpropagation when fp16 is not used
  78. loss.backward()
  79. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  80. optimizer.step()
  81. optimizer.zero_grad()
  82. if train_config.enable_fsdp:
  83. if rank==0:
  84. print(f"\n step {step} is completed and loss is {loss.detach().float()}")
  85. else:
  86. print(f"\n step {step} is completed and loss is {loss.detach().float()}")
  87. epoch_end_time = time.perf_counter()-epoch_start_time
  88. epoch_times.append(epoch_end_time)
  89. # Reducing total_loss across all devices if there's more than one CUDA device
  90. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  91. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  92. train_epoch_loss = total_loss / len(train_dataloader)
  93. if train_config.enable_fsdp:
  94. train_epoch_loss = train_epoch_loss/world_size
  95. train_perplexity = torch.exp(train_epoch_loss)
  96. train_prep.append(train_perplexity)
  97. train_loss.append(train_epoch_loss)
  98. if train_config.enable_fsdp:
  99. if rank==0:
  100. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  101. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  102. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  103. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  104. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  105. else:
  106. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  107. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  108. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  109. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  110. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  111. # Update the learning rate as needed
  112. lr_scheduler.step()
  113. if train_config.run_validation:
  114. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
  115. checkpoint_start_time = time.perf_counter()
  116. if train_config.save_model and eval_epoch_loss < best_val_loss:
  117. if train_config.enable_fsdp:
  118. dist.barrier()
  119. if train_config.use_peft:
  120. if train_config.enable_fsdp:
  121. if rank==0:
  122. print(f"we are about to save the PEFT modules")
  123. else:
  124. print(f"we are about to save the PEFT modules")
  125. model.save_pretrained(train_config.output_dir)
  126. if train_config.enable_fsdp:
  127. if rank==0:
  128. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  129. else:
  130. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  131. else:
  132. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  133. save_model_checkpoint(
  134. model, optimizer, rank, train_config, epoch=epoch
  135. )
  136. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  137. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  138. print("=====================================================")
  139. save_model_and_optimizer_sharded(model, rank, train_config)
  140. if train_config.save_optimizer:
  141. save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  142. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  143. print("=====================================================")
  144. if not train_config.use_peft and train_config.save_optimizer:
  145. save_optimizer_checkpoint(
  146. model, optimizer, rank, train_config, epoch=epoch
  147. )
  148. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  149. print("=====================================================")
  150. if train_config.enable_fsdp:
  151. dist.barrier()
  152. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  153. checkpoint_times.append(checkpoint_end_time)
  154. if eval_epoch_loss < best_val_loss:
  155. best_val_loss = eval_epoch_loss
  156. if train_config.enable_fsdp:
  157. if rank==0:
  158. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  159. else:
  160. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  161. val_loss.append(best_val_loss)
  162. val_prep.append(eval_ppl)
  163. if train_config.enable_fsdp:
  164. if rank==0:
  165. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
  166. else:
  167. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
  168. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  169. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
  170. avg_train_prep = sum(train_prep)/len(train_prep)
  171. avg_train_loss = sum(train_loss)/len(train_loss)
  172. if train_config.run_validation:
  173. avg_eval_prep = sum(val_prep)/len(val_prep)
  174. avg_eval_loss = sum(val_loss)/len(val_loss)
  175. results['avg_train_prep'] = avg_train_prep
  176. results['avg_train_loss'] = avg_train_loss
  177. if train_config.run_validation:
  178. results['avg_eval_prep'] = avg_eval_prep
  179. results['avg_eval_loss'] = avg_eval_loss
  180. results["avg_epoch_time"] = avg_epoch_time
  181. results["avg_checkpoint_time"] = avg_checkpoint_time
  182. #saving the training params including fsdp setting for reference.
  183. if train_config.enable_fsdp and not train_config.use_peft:
  184. save_train_params(train_config, fsdp_config, rank)
  185. return results
  186. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  187. """
  188. Evaluates the model on the given dataloader
  189. Args:
  190. model: The model to evaluate
  191. eval_dataloader: The dataloader containing the evaluation data
  192. local_rank: The rank of the current node in a distributed setting
  193. tokenizer: The tokenizer used to decode predictions
  194. Returns: eval_ppl, eval_epoch_loss
  195. """
  196. if train_config.enable_fsdp:
  197. world_size = int(os.environ["WORLD_SIZE"])
  198. model.eval()
  199. eval_preds = []
  200. eval_loss = 0.0 # Initialize evaluation loss
  201. with MemoryTrace() as memtrace:
  202. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
  203. for key in batch.keys():
  204. if train_config.enable_fsdp:
  205. batch[key] = batch[key].to(local_rank)
  206. else:
  207. batch[key] = batch[key].to('cuda:0')
  208. # Ensure no gradients are computed for this scope to save memory
  209. with torch.no_grad():
  210. # Forward pass and compute loss
  211. outputs = model(**batch)
  212. loss = outputs.loss
  213. eval_loss += loss.detach().float()
  214. # Decode predictions and add to evaluation predictions list
  215. preds = torch.argmax(outputs.logits, -1)
  216. eval_preds.extend(
  217. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  218. )
  219. # If there's more than one CUDA device, reduce evaluation loss across all devices
  220. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  221. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  222. # Compute average loss and perplexity
  223. eval_epoch_loss = eval_loss / len(eval_dataloader)
  224. if train_config.enable_fsdp:
  225. eval_epoch_loss = eval_epoch_loss/world_size
  226. eval_ppl = torch.exp(eval_epoch_loss)
  227. # Print evaluation metrics
  228. if train_config.enable_fsdp:
  229. if local_rank==0:
  230. print(f" {eval_ppl=} {eval_epoch_loss=}")
  231. else:
  232. print(f" {eval_ppl=} {eval_epoch_loss=}")
  233. return eval_ppl, eval_epoch_loss
  234. def freeze_transformer_layers(model, num_layer):
  235. for i, layer in enumerate(model.model.layers):
  236. if i < num_layer:
  237. for param in layer.parameters():
  238. param.requires_grad = False
  239. def check_frozen_layers_peft_model(model):
  240. for i, layer in enumerate(model.base_model.model.model.layers):
  241. for name, param in layer.named_parameters():
  242. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  243. def setup():
  244. """Initialize the process group for distributed training"""
  245. dist.init_process_group("nccl")
  246. def setup_environ_flags(rank):
  247. """Set environment flags for debugging purposes"""
  248. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  249. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  250. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  251. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  252. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  253. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  254. if rank == 0:
  255. print(f"--> Running with torch dist debug set to detail")
  256. def cleanup():
  257. """Clean up the process group after training"""
  258. dist.destroy_process_group()
  259. def clear_gpu_cache(rank=None):
  260. """Clear the GPU cache for all ranks"""
  261. if rank == 0:
  262. print(f"Clearing GPU cache for all ranks")
  263. torch.cuda.empty_cache()
  264. def get_parameter_dtypes(model):
  265. """Get the data types of model parameters"""
  266. parameter_dtypes = {}
  267. for name, parameter in model.named_parameters():
  268. parameter_dtypes[name] = parameter.dtype
  269. return parameter_dtypes
  270. def print_model_size(model, config, rank: int = 0) -> None:
  271. """
  272. Print model name, the number of trainable parameters and initialization time.
  273. Args:
  274. model: The PyTorch model.
  275. model_name (str): Name of the model.
  276. init_time_start (float): Initialization start time.
  277. init_time_end (float): Initialization end time.
  278. rank (int, optional): Current process's rank. Defaults to 0.
  279. """
  280. if rank == 0:
  281. print(f"--> Model {config.model_name}")
  282. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  283. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  284. def get_policies(cfg, rank):
  285. """Get the policies for mixed precision and fsdp wrapping"""
  286. verify_bfloat_support = (
  287. torch.version.cuda
  288. and torch.cuda.is_bf16_supported()
  289. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  290. and dist.is_nccl_available()
  291. and nccl.version() >= (2, 10)
  292. )
  293. mixed_precision_policy = None
  294. wrapping_policy = None
  295. # Mixed precision
  296. if cfg.mixed_precision:
  297. bf16_ready = verify_bfloat_support
  298. if bf16_ready and not cfg.use_fp16:
  299. mixed_precision_policy = bfSixteen_mixed
  300. if rank == 0:
  301. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  302. elif cfg.use_fp16:
  303. mixed_precision_policy = fpSixteen
  304. if rank == 0:
  305. print(f"FP16 enabled")
  306. else:
  307. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  308. wrapping_policy = get_llama_wrapper()
  309. return mixed_precision_policy, wrapping_policy
  310. def save_train_params(train_config, fsdp_config, rank):
  311. """
  312. This function saves the train_config and FSDP config into a train_params.yaml.
  313. This will be used by converter script in the inference folder to fetch the HF model name or path.
  314. It also would be hepful as a log for future references.
  315. """
  316. # Convert the train_config and fsdp_config objects to dictionaries,
  317. # converting all values to strings to ensure they can be serialized into a YAML file
  318. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  319. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  320. # Merge the two dictionaries into one
  321. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  322. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  323. folder_name = (
  324. train_config.dist_checkpoint_root_folder
  325. + "/"
  326. + train_config.dist_checkpoint_folder
  327. + "-"
  328. + train_config.model_name
  329. )
  330. save_dir = Path.cwd() / folder_name
  331. # If the directory does not exist, create it
  332. if not os.path.exists(save_dir):
  333. os.makedirs(save_dir)
  334. # Convert the dictionary to a YAML string
  335. config_yaml = yaml.dump(train_params_dict, indent=4)
  336. file_name = os.path.join(save_dir,'train_params.yaml')
  337. # Check if there's a directory with the same name as the file
  338. if os.path.isdir(file_name):
  339. print(f"Error: {file_name} is a directory, not a file.")
  340. else:
  341. # Write the YAML string to the file
  342. with open(file_name, 'w') as f:
  343. f.write(config_yaml)
  344. if rank==0:
  345. print(f"training params are saved in {file_name}")