train_utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import sys
  5. from typing import List
  6. import yaml
  7. import fire
  8. import torch
  9. import transformers
  10. from datasets import load_dataset
  11. from tqdm import tqdm
  12. """
  13. Unused imports:
  14. import torch.nn as nn
  15. import bitsandbytes as bnb
  16. """
  17. from torch.nn import functional as F
  18. from peft import (
  19. LoraConfig,
  20. get_peft_model,
  21. get_peft_model_state_dict,
  22. prepare_model_for_int8_training,
  23. set_peft_model_state_dict,
  24. )
  25. from transformers import LlamaForCausalLM, LlamaTokenizer
  26. from torch.distributed.fsdp import StateDictType
  27. import torch.distributed as dist
  28. from pkg_resources import packaging
  29. from .memory_utils import MemoryTrace
  30. import model_checkpointing
  31. import torch.cuda.nccl as nccl
  32. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  33. from pathlib import Path
  34. sys.path.append(str(Path(__file__).resolve().parent.parent))
  35. from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
  36. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  37. tokenizer.pad_token_id = 0
  38. tokenizer.padding_side = "left"
  39. # Converting Bytes to Megabytes
  40. def byte2mb(x):
  41. return int(x / 2**20)
  42. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  43. """
  44. Trains the model on the given dataloader
  45. Args:
  46. model: The model to be trained
  47. train_dataloader: The dataloader containing the training data
  48. optimizer: The optimizer used for training
  49. lr_scheduler: The learning rate scheduler
  50. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  51. num_epochs: The number of epochs to train for
  52. local_rank: The rank of the current node in a distributed setting
  53. train_config: The training configuration
  54. eval_dataloader: The dataloader containing the eval data
  55. tokenizer: tokenizer used in the eval for decoding the predicitons
  56. Returns: results dictionary containing average training and validation perplexity and loss
  57. """
  58. # Create a gradient scaler for fp16
  59. if train_config.use_fp16 and train_config.enable_fsdp:
  60. scaler = ShardedGradScaler()
  61. elif train_config.use_fp16 and not train_config.enable_fsdp:
  62. scaler = torch.cuda.amp.GradScaler()
  63. if train_config.enable_fsdp:
  64. world_size = int(os.environ["WORLD_SIZE"])
  65. train_prep = []
  66. train_loss = []
  67. val_prep = []
  68. val_loss =[]
  69. results = {}
  70. best_val_loss = float("inf")
  71. for epoch in range(train_config.num_epochs):
  72. with MemoryTrace() as memtrace: # track the memory usage
  73. model.train()
  74. total_loss = 0.0
  75. for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
  76. for key in batch.keys():
  77. if train_config.enable_fsdp:
  78. batch[key] = batch[key].to(local_rank)
  79. else:
  80. batch[key] = batch[key].to('cuda:0')
  81. loss = model(**batch).loss
  82. loss = loss / gradient_accumulation_steps
  83. total_loss += loss.detach().float()
  84. if train_config.use_fp16:
  85. # if fp16 is enabled, use gradient scaler to handle gradient update
  86. scaler.scale(loss).backward()
  87. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  88. scaler.step(optimizer)
  89. scaler.update()
  90. optimizer.zero_grad()
  91. else:
  92. # regular backpropagation when fp16 is not used
  93. loss.backward()
  94. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  95. optimizer.step()
  96. optimizer.zero_grad()
  97. if train_config.enable_fsdp:
  98. if rank==0:
  99. print(f"\n step {step} is completed and loss is {loss.detach().float()}")
  100. else:
  101. print(f"\n step {step} is completed and loss is {loss.detach().float()}")
  102. # Reducing total_loss across all devices if there's more than one CUDA device
  103. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  104. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  105. train_epoch_loss = total_loss / len(train_dataloader)
  106. if train_config.enable_fsdp:
  107. train_epoch_loss = train_epoch_loss/world_size
  108. train_perplexity = torch.exp(train_epoch_loss)
  109. train_prep.append(train_perplexity)
  110. train_loss.append(train_epoch_loss)
  111. if train_config.enable_fsdp:
  112. if rank==0:
  113. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  114. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  115. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  116. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  117. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  118. else:
  119. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  120. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  121. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  122. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  123. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  124. # Update the learning rate as needed
  125. lr_scheduler.step()
  126. if train_config.run_validation:
  127. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
  128. if train_config.save_model and eval_epoch_loss < best_val_loss:
  129. if train_config.enable_fsdp:
  130. dist.barrier()
  131. if train_config.use_peft:
  132. if train_config.enable_fsdp:
  133. if rank==0:
  134. print(f"we are about to save the PEFT modules")
  135. else:
  136. print(f"we are about to save the PEFT modules")
  137. model.save_pretrained(train_config.output_dir)
  138. if train_config.enable_fsdp:
  139. if rank==0:
  140. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  141. else:
  142. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  143. else:
  144. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  145. model_checkpointing.save_model_checkpoint(
  146. model, optimizer, rank, train_config, epoch=epoch
  147. )
  148. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  149. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  150. print("=====================================================")
  151. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
  152. if train_config.save_optimizer:
  153. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  154. print(" Saving the FSDP model checkpoints qnd optimizer using SHARDED_STATE_DICT")
  155. print("=====================================================")
  156. if not train_config.use_peft and train_config.save_optimizer:
  157. model_checkpointing.save_optimizer_checkpoint(
  158. model, optimizer, rank, train_config, epoch=epoch
  159. )
  160. print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
  161. print("=====================================================")
  162. if train_config.enable_fsdp:
  163. dist.barrier()
  164. if eval_epoch_loss < best_val_loss:
  165. best_val_loss = eval_epoch_loss
  166. if train_config.enable_fsdp:
  167. if rank==0:
  168. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  169. else:
  170. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  171. val_loss.append(best_val_loss)
  172. val_prep.append(eval_ppl)
  173. if train_config.enable_fsdp:
  174. if rank==0:
  175. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
  176. else:
  177. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
  178. avg_train_prep = sum(train_prep)/len(train_prep)
  179. avg_train_loss = sum(train_loss)/len(train_loss)
  180. if train_config.run_validation:
  181. avg_eval_prep = sum(val_prep)/len(val_prep)
  182. avg_eval_loss = sum(val_loss)/len(val_loss)
  183. results['avg_train_prep'] = avg_train_prep
  184. results['avg_train_loss'] = avg_train_loss
  185. if train_config.run_validation:
  186. results['avg_eval_prep'] = avg_eval_prep
  187. results['avg_eval_loss'] = avg_eval_loss
  188. #saving the training params including fsdp setting for reference.
  189. if train_config.enable_fsdp and fsdp_config:
  190. save_train_params(train_config, fsdp_config, rank)
  191. return results
  192. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  193. """
  194. Evaluates the model on the given dataloader
  195. Args:
  196. model: The model to evaluate
  197. eval_dataloader: The dataloader containing the evaluation data
  198. local_rank: The rank of the current node in a distributed setting
  199. tokenizer: The tokenizer used to decode predictions
  200. Returns: eval_ppl, eval_epoch_loss
  201. """
  202. if train_config.enable_fsdp:
  203. world_size = int(os.environ["WORLD_SIZE"])
  204. model.eval()
  205. eval_preds = []
  206. eval_loss = 0.0 # Initialize evaluation loss
  207. with MemoryTrace() as memtrace:
  208. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
  209. for key in batch.keys():
  210. if train_config.enable_fsdp:
  211. batch[key] = batch[key].to(local_rank)
  212. else:
  213. batch[key] = batch[key].to('cuda:0')
  214. # Ensure no gradients are computed for this scope to save memory
  215. with torch.no_grad():
  216. # Forward pass and compute loss
  217. outputs = model(**batch)
  218. loss = outputs.loss
  219. eval_loss += loss.detach().float()
  220. # Decode predictions and add to evaluation predictions list
  221. preds = torch.argmax(outputs.logits, -1)
  222. eval_preds.extend(
  223. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  224. )
  225. # If there's more than one CUDA device, reduce evaluation loss across all devices
  226. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  227. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  228. # Compute average loss and perplexity
  229. eval_epoch_loss = eval_loss / len(eval_dataloader)
  230. if train_config.enable_fsdp:
  231. eval_epoch_loss = eval_epoch_loss/world_size
  232. eval_ppl = torch.exp(eval_epoch_loss)
  233. # Print evaluation metrics
  234. if train_config.enable_fsdp:
  235. if local_rank==0:
  236. print(f" {eval_ppl=} {eval_epoch_loss=}")
  237. else:
  238. print(f" {eval_ppl=} {eval_epoch_loss=}")
  239. return eval_ppl, eval_epoch_loss
  240. def freeze_transformer_layers(model, num_layer):
  241. for i, layer in enumerate(model.model.layers):
  242. if i < num_layer:
  243. for param in layer.parameters():
  244. param.requires_grad = False
  245. def check_frozen_layers_peft_model(model):
  246. for i, layer in enumerate(model.base_model.model.model.layers):
  247. for name, param in layer.named_parameters():
  248. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  249. def setup():
  250. """Initialize the process group for distributed training"""
  251. dist.init_process_group("nccl")
  252. def setup_environ_flags(rank):
  253. """Set environment flags for debugging purposes"""
  254. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  255. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  256. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  257. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  258. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  259. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  260. if rank == 0:
  261. print(f"--> Running with torch dist debug set to detail")
  262. def cleanup():
  263. """Clean up the process group after training"""
  264. dist.destroy_process_group()
  265. def clear_gpu_cache(rank=None):
  266. """Clear the GPU cache for all ranks"""
  267. if rank == 0:
  268. print(f"Clearing GPU cache for all ranks")
  269. torch.cuda.empty_cache()
  270. def get_parameter_dtypes(model):
  271. """Get the data types of model parameters"""
  272. parameter_dtypes = {}
  273. for name, parameter in model.named_parameters():
  274. parameter_dtypes[name] = parameter.dtype
  275. return parameter_dtypes
  276. def print_model_size(model, config, rank: int = 0) -> None:
  277. """
  278. Print model name, the number of trainable parameters and initialization time.
  279. Args:
  280. model: The PyTorch model.
  281. model_name (str): Name of the model.
  282. init_time_start (float): Initialization start time.
  283. init_time_end (float): Initialization end time.
  284. rank (int, optional): Current process's rank. Defaults to 0.
  285. """
  286. if rank == 0:
  287. print(f"--> Model {config.model_name}")
  288. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  289. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  290. def get_policies(cfg, rank):
  291. """Get the policies for mixed precision and fsdp wrapping"""
  292. verify_bfloat_support = (
  293. torch.version.cuda
  294. and torch.cuda.is_bf16_supported()
  295. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  296. and dist.is_nccl_available()
  297. and nccl.version() >= (2, 10)
  298. )
  299. mixed_precision_policy = None
  300. wrapping_policy = None
  301. # Mixed precision
  302. if cfg.mixed_precision:
  303. bf16_ready = verify_bfloat_support
  304. if bf16_ready and not cfg.use_fp16:
  305. mixed_precision_policy = bfSixteen_mixed
  306. if rank == 0:
  307. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  308. elif cfg.use_fp16:
  309. mixed_precision_policy = fpSixteen
  310. if rank == 0:
  311. print(f"FP16 enabled")
  312. else:
  313. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  314. wrapping_policy = get_llama_wrapper()
  315. return mixed_precision_policy, wrapping_policy
  316. def save_train_params(train_config, fsdp_config, rank):
  317. """
  318. This function saves the train_config and FSDP config into a train_params.yaml.
  319. This will be used by converter script in the inference folder to fetch the HF model name or path.
  320. It also would be hepful as a log for future references.
  321. """
  322. # Convert the train_config and fsdp_config objects to dictionaries,
  323. # converting all values to strings to ensure they can be serialized into a YAML file
  324. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  325. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  326. # Merge the two dictionaries into one
  327. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  328. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  329. folder_name = (
  330. train_config.dist_checkpoint_root_folder
  331. + "/"
  332. + train_config.dist_checkpoint_folder
  333. + "-"
  334. + train_config.model_name
  335. )
  336. save_dir = Path.cwd() / folder_name
  337. # If the directory does not exist, create it
  338. if not os.path.exists(save_dir):
  339. os.makedirs(save_dir)
  340. # Convert the dictionary to a YAML string
  341. config_yaml = yaml.dump(train_params_dict, indent=4)
  342. file_name = os.path.join(save_dir,'train_params.yaml')
  343. # Check if there's a directory with the same name as the file
  344. if os.path.isdir(file_name):
  345. print(f"Error: {file_name} is a directory, not a file.")
  346. else:
  347. # Write the YAML string to the file
  348. with open(file_name, 'w') as f:
  349. f.write(config_yaml)
  350. if rank==0:
  351. print(f"training params are saved in {file_name}")