train_utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import sys
  5. from typing import List
  6. import yaml
  7. import time
  8. import fire
  9. import torch
  10. import transformers
  11. from datasets import load_dataset
  12. from tqdm import tqdm
  13. """
  14. Unused imports:
  15. import torch.nn as nn
  16. import bitsandbytes as bnb
  17. """
  18. from torch.nn import functional as F
  19. from peft import (
  20. LoraConfig,
  21. get_peft_model,
  22. get_peft_model_state_dict,
  23. prepare_model_for_int8_training,
  24. set_peft_model_state_dict,
  25. )
  26. from transformers import LlamaForCausalLM, LlamaTokenizer
  27. from torch.distributed.fsdp import StateDictType
  28. import torch.distributed as dist
  29. from pkg_resources import packaging
  30. from .memory_utils import MemoryTrace
  31. import model_checkpointing
  32. import torch.cuda.nccl as nccl
  33. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  34. from pathlib import Path
  35. sys.path.append(str(Path(__file__).resolve().parent.parent))
  36. from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
  37. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  38. tokenizer.pad_token_id = 0
  39. tokenizer.padding_side = "left"
  40. # Converting Bytes to Megabytes
  41. def byte2mb(x):
  42. return int(x / 2**20)
  43. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  44. """
  45. Trains the model on the given dataloader
  46. Args:
  47. model: The model to be trained
  48. train_dataloader: The dataloader containing the training data
  49. optimizer: The optimizer used for training
  50. lr_scheduler: The learning rate scheduler
  51. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  52. num_epochs: The number of epochs to train for
  53. local_rank: The rank of the current node in a distributed setting
  54. train_config: The training configuration
  55. eval_dataloader: The dataloader containing the eval data
  56. tokenizer: tokenizer used in the eval for decoding the predicitons
  57. Returns: results dictionary containing average training and validation perplexity and loss
  58. """
  59. # Create a gradient scaler for fp16
  60. if train_config.use_fp16 and train_config.enable_fsdp:
  61. scaler = ShardedGradScaler()
  62. elif train_config.use_fp16 and not train_config.enable_fsdp:
  63. scaler = torch.cuda.amp.GradScaler()
  64. if train_config.enable_fsdp:
  65. world_size = int(os.environ["WORLD_SIZE"])
  66. train_prep = []
  67. train_loss = []
  68. val_prep = []
  69. val_loss =[]
  70. epoch_times = []
  71. checkpoint_times = []
  72. results = {}
  73. best_val_loss = float("inf")
  74. for epoch in range(train_config.num_epochs):
  75. epoch_start_time = time.perf_counter()
  76. with MemoryTrace() as memtrace: # track the memory usage
  77. model.train()
  78. total_loss = 0.0
  79. total_length = len(train_dataloader)//gradient_accumulation_steps
  80. pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
  81. for step, batch in enumerate(train_dataloader):
  82. for key in batch.keys():
  83. if train_config.enable_fsdp:
  84. batch[key] = batch[key].to(local_rank)
  85. else:
  86. batch[key] = batch[key].to('cuda:0')
  87. loss = model(**batch).loss
  88. loss = loss / gradient_accumulation_steps
  89. total_loss += loss.detach().float()
  90. if train_config.use_fp16:
  91. # if fp16 is enabled, use gradient scaler to handle gradient update
  92. scaler.scale(loss).backward()
  93. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  94. scaler.step(optimizer)
  95. scaler.update()
  96. optimizer.zero_grad()
  97. pbar.update(step//gradient_accumulation_steps)
  98. else:
  99. # regular backpropagation when fp16 is not used
  100. loss.backward()
  101. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  102. optimizer.step()
  103. optimizer.zero_grad()
  104. pbar.update(step//gradient_accumulation_steps)
  105. pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
  106. epoch_end_time = time.perf_counter()-epoch_start_time
  107. epoch_times.append(epoch_end_time)
  108. # Reducing total_loss across all devices if there's more than one CUDA device
  109. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  110. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  111. train_epoch_loss = total_loss / len(train_dataloader)
  112. if train_config.enable_fsdp:
  113. train_epoch_loss = train_epoch_loss/world_size
  114. train_perplexity = torch.exp(train_epoch_loss)
  115. train_prep.append(train_perplexity)
  116. train_loss.append(train_epoch_loss)
  117. if train_config.enable_fsdp:
  118. if rank==0:
  119. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  120. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  121. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  122. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  123. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  124. else:
  125. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  126. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  127. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  128. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  129. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  130. # Update the learning rate as needed
  131. lr_scheduler.step()
  132. if train_config.run_validation:
  133. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
  134. checkpoint_start_time = time.perf_counter()
  135. if train_config.save_model and eval_epoch_loss < best_val_loss:
  136. if train_config.enable_fsdp:
  137. dist.barrier()
  138. if train_config.use_peft:
  139. if train_config.enable_fsdp:
  140. if rank==0:
  141. print(f"we are about to save the PEFT modules")
  142. else:
  143. print(f"we are about to save the PEFT modules")
  144. model.save_pretrained(train_config.output_dir)
  145. if train_config.enable_fsdp:
  146. if rank==0:
  147. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  148. else:
  149. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  150. else:
  151. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  152. model_checkpointing.save_model_checkpoint(
  153. model, optimizer, rank, train_config, epoch=epoch
  154. )
  155. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  156. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  157. print("=====================================================")
  158. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
  159. if train_config.save_optimizer:
  160. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  161. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  162. print("=====================================================")
  163. if not train_config.use_peft and train_config.save_optimizer:
  164. model_checkpointing.save_optimizer_checkpoint(
  165. model, optimizer, rank, train_config, epoch=epoch
  166. )
  167. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  168. print("=====================================================")
  169. if train_config.enable_fsdp:
  170. dist.barrier()
  171. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  172. checkpoint_times.append(checkpoint_end_time)
  173. if eval_epoch_loss < best_val_loss:
  174. best_val_loss = eval_epoch_loss
  175. if train_config.enable_fsdp:
  176. if rank==0:
  177. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  178. else:
  179. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  180. val_loss.append(best_val_loss)
  181. val_prep.append(eval_ppl)
  182. if train_config.enable_fsdp:
  183. if rank==0:
  184. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
  185. else:
  186. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
  187. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  188. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times)
  189. avg_train_prep = sum(train_prep)/len(train_prep)
  190. avg_train_loss = sum(train_loss)/len(train_loss)
  191. if train_config.run_validation:
  192. avg_eval_prep = sum(val_prep)/len(val_prep)
  193. avg_eval_loss = sum(val_loss)/len(val_loss)
  194. results['avg_train_prep'] = avg_train_prep
  195. results['avg_train_loss'] = avg_train_loss
  196. if train_config.run_validation:
  197. results['avg_eval_prep'] = avg_eval_prep
  198. results['avg_eval_loss'] = avg_eval_loss
  199. results["avg_epoch_time"] = avg_epoch_time
  200. results["avg_checkpoint_time"] = avg_checkpoint_time
  201. #saving the training params including fsdp setting for reference.
  202. if train_config.enable_fsdp and not train_config.use_peft:
  203. save_train_params(train_config, fsdp_config, rank)
  204. return results
  205. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  206. """
  207. Evaluates the model on the given dataloader
  208. Args:
  209. model: The model to evaluate
  210. eval_dataloader: The dataloader containing the evaluation data
  211. local_rank: The rank of the current node in a distributed setting
  212. tokenizer: The tokenizer used to decode predictions
  213. Returns: eval_ppl, eval_epoch_loss
  214. """
  215. if train_config.enable_fsdp:
  216. world_size = int(os.environ["WORLD_SIZE"])
  217. model.eval()
  218. eval_preds = []
  219. eval_loss = 0.0 # Initialize evaluation loss
  220. with MemoryTrace() as memtrace:
  221. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
  222. for key in batch.keys():
  223. if train_config.enable_fsdp:
  224. batch[key] = batch[key].to(local_rank)
  225. else:
  226. batch[key] = batch[key].to('cuda:0')
  227. # Ensure no gradients are computed for this scope to save memory
  228. with torch.no_grad():
  229. # Forward pass and compute loss
  230. outputs = model(**batch)
  231. loss = outputs.loss
  232. eval_loss += loss.detach().float()
  233. # Decode predictions and add to evaluation predictions list
  234. preds = torch.argmax(outputs.logits, -1)
  235. eval_preds.extend(
  236. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  237. )
  238. # If there's more than one CUDA device, reduce evaluation loss across all devices
  239. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  240. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  241. # Compute average loss and perplexity
  242. eval_epoch_loss = eval_loss / len(eval_dataloader)
  243. if train_config.enable_fsdp:
  244. eval_epoch_loss = eval_epoch_loss/world_size
  245. eval_ppl = torch.exp(eval_epoch_loss)
  246. # Print evaluation metrics
  247. if train_config.enable_fsdp:
  248. if local_rank==0:
  249. print(f" {eval_ppl=} {eval_epoch_loss=}")
  250. else:
  251. print(f" {eval_ppl=} {eval_epoch_loss=}")
  252. return eval_ppl, eval_epoch_loss
  253. def freeze_transformer_layers(model, num_layer):
  254. for i, layer in enumerate(model.model.layers):
  255. if i < num_layer:
  256. for param in layer.parameters():
  257. param.requires_grad = False
  258. def check_frozen_layers_peft_model(model):
  259. for i, layer in enumerate(model.base_model.model.model.layers):
  260. for name, param in layer.named_parameters():
  261. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  262. def setup():
  263. """Initialize the process group for distributed training"""
  264. dist.init_process_group("nccl")
  265. def setup_environ_flags(rank):
  266. """Set environment flags for debugging purposes"""
  267. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  268. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  269. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  270. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  271. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  272. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  273. if rank == 0:
  274. print(f"--> Running with torch dist debug set to detail")
  275. def cleanup():
  276. """Clean up the process group after training"""
  277. dist.destroy_process_group()
  278. def clear_gpu_cache(rank=None):
  279. """Clear the GPU cache for all ranks"""
  280. if rank == 0:
  281. print(f"Clearing GPU cache for all ranks")
  282. torch.cuda.empty_cache()
  283. def get_parameter_dtypes(model):
  284. """Get the data types of model parameters"""
  285. parameter_dtypes = {}
  286. for name, parameter in model.named_parameters():
  287. parameter_dtypes[name] = parameter.dtype
  288. return parameter_dtypes
  289. def print_model_size(model, config, rank: int = 0) -> None:
  290. """
  291. Print model name, the number of trainable parameters and initialization time.
  292. Args:
  293. model: The PyTorch model.
  294. model_name (str): Name of the model.
  295. init_time_start (float): Initialization start time.
  296. init_time_end (float): Initialization end time.
  297. rank (int, optional): Current process's rank. Defaults to 0.
  298. """
  299. if rank == 0:
  300. print(f"--> Model {config.model_name}")
  301. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  302. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  303. def get_policies(cfg, rank):
  304. """Get the policies for mixed precision and fsdp wrapping"""
  305. verify_bfloat_support = (
  306. torch.version.cuda
  307. and torch.cuda.is_bf16_supported()
  308. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  309. and dist.is_nccl_available()
  310. and nccl.version() >= (2, 10)
  311. )
  312. mixed_precision_policy = None
  313. wrapping_policy = None
  314. # Mixed precision
  315. if cfg.mixed_precision:
  316. bf16_ready = verify_bfloat_support
  317. if bf16_ready and not cfg.use_fp16:
  318. mixed_precision_policy = bfSixteen_mixed
  319. if rank == 0:
  320. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  321. elif cfg.use_fp16:
  322. mixed_precision_policy = fpSixteen
  323. if rank == 0:
  324. print(f"FP16 enabled")
  325. else:
  326. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  327. wrapping_policy = get_llama_wrapper()
  328. return mixed_precision_policy, wrapping_policy
  329. def save_train_params(train_config, fsdp_config, rank):
  330. """
  331. This function saves the train_config and FSDP config into a train_params.yaml.
  332. This will be used by converter script in the inference folder to fetch the HF model name or path.
  333. It also would be hepful as a log for future references.
  334. """
  335. # Convert the train_config and fsdp_config objects to dictionaries,
  336. # converting all values to strings to ensure they can be serialized into a YAML file
  337. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  338. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  339. # Merge the two dictionaries into one
  340. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  341. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  342. folder_name = (
  343. train_config.dist_checkpoint_root_folder
  344. + "/"
  345. + train_config.dist_checkpoint_folder
  346. + "-"
  347. + train_config.model_name
  348. )
  349. save_dir = Path.cwd() / folder_name
  350. # If the directory does not exist, create it
  351. if not os.path.exists(save_dir):
  352. os.makedirs(save_dir)
  353. # Convert the dictionary to a YAML string
  354. config_yaml = yaml.dump(train_params_dict, indent=4)
  355. file_name = os.path.join(save_dir,'train_params.yaml')
  356. # Check if there's a directory with the same name as the file
  357. if os.path.isdir(file_name):
  358. print(f"Error: {file_name} is a directory, not a file.")
  359. else:
  360. # Write the YAML string to the file
  361. with open(file_name, 'w') as f:
  362. f.write(config_yaml)
  363. if rank==0:
  364. print(f"training params are saved in {file_name}")