train_utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import time
  5. import yaml
  6. from pathlib import Path
  7. from pkg_resources import packaging
  8. import contextlib
  9. import gc
  10. import torch
  11. import torch.cuda.nccl as nccl
  12. import torch.distributed as dist
  13. from torch.distributed.fsdp import StateDictType
  14. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  15. # from torch.utils.flop_counter import FlopCounterMode
  16. from tqdm import tqdm
  17. from transformers import LlamaTokenizer
  18. from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
  19. from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
  20. from llama_recipes.utils.memory_utils import MemoryTrace
  21. from llama_recipes.utils.tflop_counter import FlopCounterMode
  22. @contextlib.contextmanager
  23. def maybe_run_profiler(cfg, *args, **kwargs):
  24. use_profiler: bool = cfg.profiler
  25. if use_profiler:
  26. with torch.profiler.profile(
  27. activities=[
  28. torch.profiler.ProfilerActivity.CPU,
  29. torch.profiler.ProfilerActivity.CUDA,
  30. ],
  31. schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
  32. on_trace_ready=torch.profiler.tensorboard_trace_handler(
  33. cfg.profile_output_dir
  34. ),
  35. profile_memory=True,
  36. with_stack=False,
  37. record_shapes=True,
  38. ) as torch_profiler:
  39. yield torch_profiler
  40. else:
  41. torch_profiler = contextlib.nullcontext()
  42. yield None
  43. def get_total_flops(model):
  44. return (sum([v for _, v in model.flop_counts["Global"].items()]))
  45. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  46. tokenizer.pad_token_id = 0
  47. tokenizer.padding_side = "left"
  48. # Converting Bytes to Megabytes
  49. def byte2mb(x):
  50. return int(x / 2**20)
  51. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  52. """
  53. Trains the model on the given dataloader
  54. Args:
  55. model: The model to be trained
  56. train_dataloader: The dataloader containing the training data
  57. optimizer: The optimizer used for training
  58. lr_scheduler: The learning rate scheduler
  59. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  60. num_epochs: The number of epochs to train for
  61. local_rank: The rank of the current node in a distributed setting
  62. train_config: The training configuration
  63. eval_dataloader: The dataloader containing the eval data
  64. tokenizer: tokenizer used in the eval for decoding the predicitons
  65. Returns: results dictionary containing average training and validation perplexity and loss
  66. """
  67. # Create a gradient scaler for fp16
  68. if train_config.use_fp16 and train_config.enable_fsdp:
  69. scaler = ShardedGradScaler()
  70. elif train_config.use_fp16 and not train_config.enable_fsdp:
  71. scaler = torch.cuda.amp.GradScaler()
  72. if train_config.enable_fsdp:
  73. world_size = int(os.environ["WORLD_SIZE"])
  74. train_prep = []
  75. train_loss = []
  76. val_prep = []
  77. val_loss =[]
  78. epoch_times = []
  79. checkpoint_times = []
  80. results = {}
  81. best_val_loss = float("inf")
  82. for epoch in range(train_config.num_epochs):
  83. epoch_start_time = time.perf_counter()
  84. with MemoryTrace() as memtrace: # track the memory usage
  85. model.train()
  86. total_loss = 0.0
  87. total_length = len(train_dataloader)//gradient_accumulation_steps
  88. pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
  89. with maybe_run_profiler(train_config) as torch_profiler:
  90. for step, batch in enumerate(train_dataloader):
  91. if step > 5:
  92. break
  93. gc.collect(1)
  94. for key in batch.keys():
  95. if train_config.enable_fsdp:
  96. batch[key] = batch[key].to(local_rank)
  97. else:
  98. batch[key] = batch[key].to('cuda:0')
  99. flop_check_done = False
  100. if train_config.flop_counter and step == 3 and not flop_check_done:
  101. flop_counter = FlopCounterMode(rank=local_rank)
  102. with flop_counter:
  103. loss = model(**batch).loss
  104. loss = loss / gradient_accumulation_steps
  105. total_loss += loss.detach().float()
  106. if train_config.use_fp16:
  107. # if fp16 is enabled, use gradient scaler to handle gradient update
  108. scaler.scale(loss).backward()
  109. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  110. scaler.step(optimizer)
  111. scaler.update()
  112. optimizer.zero_grad()
  113. pbar.update(1)
  114. else:
  115. # regular backpropagation when fp16 is not used
  116. loss.backward()
  117. TFlops = get_total_flops(flop_counter) / 1e12
  118. flop_check_done = True
  119. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  120. optimizer.step()
  121. optimizer.zero_grad()
  122. pbar.update(1)
  123. else:
  124. loss = model(**batch).loss
  125. loss = loss / gradient_accumulation_steps
  126. total_loss += loss.detach().float()
  127. if train_config.use_fp16:
  128. # if fp16 is enabled, use gradient scaler to handle gradient update
  129. scaler.scale(loss).backward()
  130. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  131. scaler.step(optimizer)
  132. scaler.update()
  133. optimizer.zero_grad()
  134. pbar.update(1)
  135. else:
  136. # regular backpropagation when fp16 is not used
  137. loss.backward()
  138. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  139. optimizer.step()
  140. optimizer.zero_grad()
  141. pbar.update(1)
  142. pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
  143. pbar.close()
  144. epoch_end_time = time.perf_counter()-epoch_start_time
  145. epoch_times.append(epoch_end_time)
  146. # Reducing total_loss across all devices if there's more than one CUDA device
  147. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  148. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  149. train_epoch_loss = total_loss / len(train_dataloader)
  150. if train_config.enable_fsdp:
  151. train_epoch_loss = train_epoch_loss/world_size
  152. train_perplexity = torch.exp(train_epoch_loss)
  153. train_prep.append(train_perplexity)
  154. train_loss.append(train_epoch_loss)
  155. if train_config.enable_fsdp:
  156. if rank==0:
  157. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  158. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  159. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  160. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  161. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  162. else:
  163. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  164. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  165. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  166. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  167. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  168. # Update the learning rate as needed
  169. lr_scheduler.step()
  170. if train_config.run_validation:
  171. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
  172. checkpoint_start_time = time.perf_counter()
  173. if train_config.save_model and eval_epoch_loss < best_val_loss:
  174. if train_config.enable_fsdp:
  175. dist.barrier()
  176. if train_config.use_peft:
  177. if train_config.enable_fsdp:
  178. if rank==0:
  179. print(f"we are about to save the PEFT modules")
  180. else:
  181. print(f"we are about to save the PEFT modules")
  182. model.save_pretrained(train_config.output_dir)
  183. if train_config.enable_fsdp:
  184. if rank==0:
  185. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  186. else:
  187. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  188. else:
  189. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  190. save_model_checkpoint(
  191. model, optimizer, rank, train_config, epoch=epoch
  192. )
  193. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  194. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  195. print("=====================================================")
  196. save_model_and_optimizer_sharded(model, rank, train_config)
  197. if train_config.save_optimizer:
  198. save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  199. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  200. print("=====================================================")
  201. if not train_config.use_peft and train_config.save_optimizer:
  202. save_optimizer_checkpoint(
  203. model, optimizer, rank, train_config, epoch=epoch
  204. )
  205. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  206. print("=====================================================")
  207. if train_config.enable_fsdp:
  208. dist.barrier()
  209. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  210. checkpoint_times.append(checkpoint_end_time)
  211. if eval_epoch_loss < best_val_loss:
  212. best_val_loss = eval_epoch_loss
  213. if train_config.enable_fsdp:
  214. if rank==0:
  215. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  216. else:
  217. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  218. val_loss.append(best_val_loss)
  219. val_prep.append(eval_ppl)
  220. if train_config.enable_fsdp:
  221. if rank==0:
  222. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  223. else:
  224. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  225. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  226. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
  227. avg_train_prep = sum(train_prep)/len(train_prep)
  228. avg_train_loss = sum(train_loss)/len(train_loss)
  229. if train_config.run_validation:
  230. avg_eval_prep = sum(val_prep)/len(val_prep)
  231. avg_eval_loss = sum(val_loss)/len(val_loss)
  232. results['avg_train_prep'] = avg_train_prep
  233. results['avg_train_loss'] = avg_train_loss
  234. if train_config.run_validation:
  235. results['avg_eval_prep'] = avg_eval_prep
  236. results['avg_eval_loss'] = avg_eval_loss
  237. results["avg_epoch_time"] = avg_epoch_time
  238. results["avg_checkpoint_time"] = avg_checkpoint_time
  239. if train_config.flop_counter:
  240. results["model_flops"]= TFlops
  241. #saving the training params including fsdp setting for reference.
  242. if train_config.enable_fsdp and not train_config.use_peft:
  243. save_train_params(train_config, fsdp_config, rank)
  244. return results
  245. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  246. """
  247. Evaluates the model on the given dataloader
  248. Args:
  249. model: The model to evaluate
  250. eval_dataloader: The dataloader containing the evaluation data
  251. local_rank: The rank of the current node in a distributed setting
  252. tokenizer: The tokenizer used to decode predictions
  253. Returns: eval_ppl, eval_epoch_loss
  254. """
  255. if train_config.enable_fsdp:
  256. world_size = int(os.environ["WORLD_SIZE"])
  257. model.eval()
  258. eval_preds = []
  259. eval_loss = 0.0 # Initialize evaluation loss
  260. with MemoryTrace() as memtrace:
  261. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
  262. if step > 5:
  263. break
  264. gc.collect(1)
  265. for key in batch.keys():
  266. if train_config.enable_fsdp:
  267. batch[key] = batch[key].to(local_rank)
  268. else:
  269. batch[key] = batch[key].to('cuda:0')
  270. # Ensure no gradients are computed for this scope to save memory
  271. with torch.no_grad():
  272. # Forward pass and compute loss
  273. outputs = model(**batch)
  274. loss = outputs.loss
  275. eval_loss += loss.detach().float()
  276. # Decode predictions and add to evaluation predictions list
  277. preds = torch.argmax(outputs.logits, -1)
  278. eval_preds.extend(
  279. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  280. )
  281. # If there's more than one CUDA device, reduce evaluation loss across all devices
  282. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  283. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  284. # Compute average loss and perplexity
  285. eval_epoch_loss = eval_loss / len(eval_dataloader)
  286. if train_config.enable_fsdp:
  287. eval_epoch_loss = eval_epoch_loss/world_size
  288. eval_ppl = torch.exp(eval_epoch_loss)
  289. # Print evaluation metrics
  290. if train_config.enable_fsdp:
  291. if local_rank==0:
  292. print(f" {eval_ppl=} {eval_epoch_loss=}")
  293. else:
  294. print(f" {eval_ppl=} {eval_epoch_loss=}")
  295. return eval_ppl, eval_epoch_loss
  296. def freeze_transformer_layers(model, num_layer):
  297. for i, layer in enumerate(model.model.layers):
  298. if i < num_layer:
  299. for param in layer.parameters():
  300. param.requires_grad = False
  301. def check_frozen_layers_peft_model(model):
  302. for i, layer in enumerate(model.base_model.model.model.layers):
  303. for name, param in layer.named_parameters():
  304. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  305. def setup():
  306. """Initialize the process group for distributed training"""
  307. dist.init_process_group("nccl")
  308. def setup_environ_flags(rank):
  309. """Set environment flags for debugging purposes"""
  310. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  311. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  312. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  313. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  314. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  315. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  316. if rank == 0:
  317. print(f"--> Running with torch dist debug set to detail")
  318. def cleanup():
  319. """Clean up the process group after training"""
  320. dist.destroy_process_group()
  321. def clear_gpu_cache(rank=None):
  322. """Clear the GPU cache for all ranks"""
  323. if rank == 0:
  324. print(f"Clearing GPU cache for all ranks")
  325. torch.cuda.empty_cache()
  326. def get_parameter_dtypes(model):
  327. """Get the data types of model parameters"""
  328. parameter_dtypes = {}
  329. for name, parameter in model.named_parameters():
  330. parameter_dtypes[name] = parameter.dtype
  331. return parameter_dtypes
  332. def print_model_size(model, config, rank: int = 0) -> None:
  333. """
  334. Print model name, the number of trainable parameters and initialization time.
  335. Args:
  336. model: The PyTorch model.
  337. model_name (str): Name of the model.
  338. init_time_start (float): Initialization start time.
  339. init_time_end (float): Initialization end time.
  340. rank (int, optional): Current process's rank. Defaults to 0.
  341. """
  342. if rank == 0:
  343. print(f"--> Model {config.model_name}")
  344. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  345. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  346. def get_policies(cfg, rank):
  347. """Get the policies for mixed precision and fsdp wrapping"""
  348. verify_bfloat_support = (
  349. torch.version.cuda
  350. and torch.cuda.is_bf16_supported()
  351. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  352. and dist.is_nccl_available()
  353. and nccl.version() >= (2, 10)
  354. )
  355. mixed_precision_policy = None
  356. wrapping_policy = None
  357. # Mixed precision
  358. if cfg.mixed_precision:
  359. bf16_ready = verify_bfloat_support
  360. if bf16_ready and not cfg.use_fp16:
  361. mixed_precision_policy = bfSixteen_mixed
  362. if rank == 0:
  363. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  364. elif cfg.use_fp16:
  365. mixed_precision_policy = fpSixteen
  366. if rank == 0:
  367. print(f"FP16 enabled")
  368. else:
  369. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  370. wrapping_policy = get_llama_wrapper()
  371. return mixed_precision_policy, wrapping_policy
  372. def save_train_params(train_config, fsdp_config, rank):
  373. """
  374. This function saves the train_config and FSDP config into a train_params.yaml.
  375. This will be used by converter script in the inference folder to fetch the HF model name or path.
  376. It also would be hepful as a log for future references.
  377. """
  378. # Convert the train_config and fsdp_config objects to dictionaries,
  379. # converting all values to strings to ensure they can be serialized into a YAML file
  380. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  381. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  382. # Merge the two dictionaries into one
  383. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  384. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  385. folder_name = (
  386. train_config.dist_checkpoint_root_folder
  387. + "/"
  388. + train_config.dist_checkpoint_folder
  389. + "-"
  390. + train_config.model_name
  391. )
  392. save_dir = Path.cwd() / folder_name
  393. # If the directory does not exist, create it
  394. if not os.path.exists(save_dir):
  395. os.makedirs(save_dir)
  396. # Convert the dictionary to a YAML string
  397. config_yaml = yaml.dump(train_params_dict, indent=4)
  398. file_name = os.path.join(save_dir,'train_params.yaml')
  399. # Check if there's a directory with the same name as the file
  400. if os.path.isdir(file_name):
  401. print(f"Error: {file_name} is a directory, not a file.")
  402. else:
  403. # Write the YAML string to the file
  404. with open(file_name, 'w') as f:
  405. f.write(config_yaml)
  406. if rank==0:
  407. print(f"training params are saved in {file_name}")