train_utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import sys
  5. from typing import List
  6. import yaml
  7. import time
  8. import fire
  9. import torch
  10. import transformers
  11. from datasets import load_dataset
  12. from tqdm import tqdm
  13. """
  14. Unused imports:
  15. import torch.nn as nn
  16. import bitsandbytes as bnb
  17. """
  18. from torch.nn import functional as F
  19. from peft import (
  20. LoraConfig,
  21. get_peft_model,
  22. get_peft_model_state_dict,
  23. prepare_model_for_int8_training,
  24. set_peft_model_state_dict,
  25. )
  26. from transformers import LlamaForCausalLM, LlamaTokenizer
  27. from torch.distributed.fsdp import StateDictType
  28. import torch.distributed as dist
  29. from pkg_resources import packaging
  30. from .memory_utils import MemoryTrace
  31. import model_checkpointing
  32. import torch.cuda.nccl as nccl
  33. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  34. from pathlib import Path
  35. sys.path.append(str(Path(__file__).resolve().parent.parent))
  36. from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
  37. from accelerate.utils import is_xpu_available, is_ccl_available
  38. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  39. tokenizer.pad_token_id = 0
  40. tokenizer.padding_side = "left"
  41. # Converting Bytes to Megabytes
  42. def byte2mb(x):
  43. return int(x / 2**20)
  44. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  45. """
  46. Trains the model on the given dataloader
  47. Args:
  48. model: The model to be trained
  49. train_dataloader: The dataloader containing the training data
  50. optimizer: The optimizer used for training
  51. lr_scheduler: The learning rate scheduler
  52. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  53. num_epochs: The number of epochs to train for
  54. local_rank: The rank of the current node in a distributed setting
  55. train_config: The training configuration
  56. eval_dataloader: The dataloader containing the eval data
  57. tokenizer: tokenizer used in the eval for decoding the predicitons
  58. Returns: results dictionary containing average training and validation perplexity and loss
  59. """
  60. # Create a gradient scaler for fp16
  61. if train_config.use_fp16 and train_config.enable_fsdp:
  62. scaler = ShardedGradScaler()
  63. elif train_config.use_fp16 and not train_config.enable_fsdp:
  64. scaler = torch.cuda.amp.GradScaler()
  65. if train_config.enable_fsdp:
  66. world_size = int(os.environ["WORLD_SIZE"])
  67. train_prep = []
  68. train_loss = []
  69. val_prep = []
  70. val_loss =[]
  71. epoch_times = []
  72. checkpoint_times = []
  73. results = {}
  74. best_val_loss = float("inf")
  75. for epoch in range(train_config.num_epochs):
  76. epoch_start_time = time.perf_counter()
  77. with MemoryTrace() as memtrace: # track the memory usage
  78. model.train()
  79. total_loss = 0.0
  80. for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
  81. for key in batch.keys():
  82. if train_config.enable_fsdp:
  83. batch[key] = batch[key].to(local_rank)
  84. else:
  85. if is_xpu_available():
  86. batch[key] = batch[key].to('xpu:0')
  87. else:
  88. batch[key] = batch[key].to('cuda:0')
  89. loss = model(**batch).loss
  90. loss = loss / gradient_accumulation_steps
  91. total_loss += loss.detach().float()
  92. if train_config.use_fp16:
  93. # if fp16 is enabled, use gradient scaler to handle gradient update
  94. scaler.scale(loss).backward()
  95. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  96. scaler.step(optimizer)
  97. scaler.update()
  98. optimizer.zero_grad()
  99. else:
  100. # regular backpropagation when fp16 is not used
  101. loss.backward()
  102. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  103. optimizer.step()
  104. optimizer.zero_grad()
  105. if train_config.enable_fsdp:
  106. if rank==0:
  107. print(f"\n step {step} is completed and loss is {loss.detach().float()}")
  108. else:
  109. print(f"\n step {step} is completed and loss is {loss.detach().float()}")
  110. epoch_end_time = time.perf_counter()-epoch_start_time
  111. epoch_times.append(epoch_end_time)
  112. # Reducing total_loss across all devices if there's more than one CUDA device
  113. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  114. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  115. elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  116. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  117. train_epoch_loss = total_loss / len(train_dataloader)
  118. if train_config.enable_fsdp:
  119. train_epoch_loss = train_epoch_loss/world_size
  120. train_perplexity = torch.exp(train_epoch_loss)
  121. train_prep.append(train_perplexity)
  122. train_loss.append(train_epoch_loss)
  123. if train_config.enable_fsdp:
  124. if rank==0:
  125. if is_xpu_available():
  126. print(f"Max XPU memory allocated was {memtrace.peak} GB")
  127. print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
  128. print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
  129. print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
  130. else:
  131. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  132. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  133. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  134. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  135. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  136. else:
  137. if is_xpu_available():
  138. print(f"Max XPU memory allocated was {memtrace.peak} GB")
  139. print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
  140. print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
  141. print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
  142. else:
  143. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  144. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  145. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  146. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  147. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  148. # Update the learning rate as needed
  149. lr_scheduler.step()
  150. if train_config.run_validation:
  151. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
  152. checkpoint_start_time = time.perf_counter()
  153. if train_config.save_model and eval_epoch_loss < best_val_loss:
  154. if train_config.enable_fsdp:
  155. dist.barrier()
  156. if train_config.use_peft:
  157. if train_config.enable_fsdp:
  158. if rank==0:
  159. print(f"we are about to save the PEFT modules")
  160. else:
  161. print(f"we are about to save the PEFT modules")
  162. model.save_pretrained(train_config.output_dir)
  163. if train_config.enable_fsdp:
  164. if rank==0:
  165. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  166. else:
  167. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  168. else:
  169. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  170. model_checkpointing.save_model_checkpoint(
  171. model, optimizer, rank, train_config, epoch=epoch
  172. )
  173. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  174. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  175. print("=====================================================")
  176. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
  177. if train_config.save_optimizer:
  178. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  179. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  180. print("=====================================================")
  181. if not train_config.use_peft and train_config.save_optimizer:
  182. model_checkpointing.save_optimizer_checkpoint(
  183. model, optimizer, rank, train_config, epoch=epoch
  184. )
  185. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  186. print("=====================================================")
  187. if train_config.enable_fsdp:
  188. dist.barrier()
  189. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  190. checkpoint_times.append(checkpoint_end_time)
  191. if eval_epoch_loss < best_val_loss:
  192. best_val_loss = eval_epoch_loss
  193. if train_config.enable_fsdp:
  194. if rank==0:
  195. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  196. else:
  197. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  198. val_loss.append(best_val_loss)
  199. val_prep.append(eval_ppl)
  200. if train_config.enable_fsdp:
  201. if rank==0:
  202. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
  203. else:
  204. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
  205. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  206. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times)
  207. avg_train_prep = sum(train_prep)/len(train_prep)
  208. avg_train_loss = sum(train_loss)/len(train_loss)
  209. if train_config.run_validation:
  210. avg_eval_prep = sum(val_prep)/len(val_prep)
  211. avg_eval_loss = sum(val_loss)/len(val_loss)
  212. results['avg_train_prep'] = avg_train_prep
  213. results['avg_train_loss'] = avg_train_loss
  214. if train_config.run_validation:
  215. results['avg_eval_prep'] = avg_eval_prep
  216. results['avg_eval_loss'] = avg_eval_loss
  217. results["avg_epoch_time"] = avg_epoch_time
  218. results["avg_checkpoint_time"] = avg_checkpoint_time
  219. #saving the training params including fsdp setting for reference.
  220. if train_config.enable_fsdp and not train_config.use_peft:
  221. save_train_params(train_config, fsdp_config, rank)
  222. return results
  223. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  224. """
  225. Evaluates the model on the given dataloader
  226. Args:
  227. model: The model to evaluate
  228. eval_dataloader: The dataloader containing the evaluation data
  229. local_rank: The rank of the current node in a distributed setting
  230. tokenizer: The tokenizer used to decode predictions
  231. Returns: eval_ppl, eval_epoch_loss
  232. """
  233. if train_config.enable_fsdp:
  234. world_size = int(os.environ["WORLD_SIZE"])
  235. model.eval()
  236. eval_preds = []
  237. eval_loss = 0.0 # Initialize evaluation loss
  238. with MemoryTrace() as memtrace:
  239. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
  240. for key in batch.keys():
  241. if train_config.enable_fsdp:
  242. batch[key] = batch[key].to(local_rank)
  243. else:
  244. if is_xpu_available():
  245. batch[key] = batch[key].to('xpu:0')
  246. else:
  247. batch[key] = batch[key].to('cuda:0')
  248. # Ensure no gradients are computed for this scope to save memory
  249. with torch.no_grad():
  250. # Forward pass and compute loss
  251. outputs = model(**batch)
  252. loss = outputs.loss
  253. eval_loss += loss.detach().float()
  254. # Decode predictions and add to evaluation predictions list
  255. preds = torch.argmax(outputs.logits, -1)
  256. eval_preds.extend(
  257. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  258. )
  259. # If there's more than one CUDA device, reduce evaluation loss across all devices
  260. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  261. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  262. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  263. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  264. # Compute average loss and perplexity
  265. eval_epoch_loss = eval_loss / len(eval_dataloader)
  266. if train_config.enable_fsdp:
  267. eval_epoch_loss = eval_epoch_loss/world_size
  268. eval_ppl = torch.exp(eval_epoch_loss)
  269. # Print evaluation metrics
  270. if train_config.enable_fsdp:
  271. if local_rank==0:
  272. print(f" {eval_ppl=} {eval_epoch_loss=}")
  273. else:
  274. print(f" {eval_ppl=} {eval_epoch_loss=}")
  275. return eval_ppl, eval_epoch_loss
  276. def freeze_transformer_layers(model, num_layer):
  277. for i, layer in enumerate(model.model.layers):
  278. if i < num_layer:
  279. for param in layer.parameters():
  280. param.requires_grad = False
  281. def check_frozen_layers_peft_model(model):
  282. for i, layer in enumerate(model.base_model.model.model.layers):
  283. for name, param in layer.named_parameters():
  284. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  285. def setup():
  286. """Initialize the process group for distributed training"""
  287. if is_ccl_available():
  288. # distributed training on xpus
  289. dist.init_process_group("ccl")
  290. else:
  291. dist.init_process_group("nccl")
  292. def setup_environ_flags(rank):
  293. """Set environment flags for debugging purposes"""
  294. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  295. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  296. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  297. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  298. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  299. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  300. if rank == 0:
  301. print(f"--> Running with torch dist debug set to detail")
  302. def cleanup():
  303. """Clean up the process group after training"""
  304. dist.destroy_process_group()
  305. def clear_gpu_cache(rank=None):
  306. """Clear the GPU cache for all ranks"""
  307. if rank == 0:
  308. print(f"Clearing GPU cache for all ranks")
  309. if is_xpu_available():
  310. torch.xpu_empty_cache()
  311. else:
  312. torch.cuda.empty_cache()
  313. def get_parameter_dtypes(model):
  314. """Get the data types of model parameters"""
  315. parameter_dtypes = {}
  316. for name, parameter in model.named_parameters():
  317. parameter_dtypes[name] = parameter.dtype
  318. return parameter_dtypes
  319. def print_model_size(model, config, rank: int = 0) -> None:
  320. """
  321. Print model name, the number of trainable parameters and initialization time.
  322. Args:
  323. model: The PyTorch model.
  324. model_name (str): Name of the model.
  325. init_time_start (float): Initialization start time.
  326. init_time_end (float): Initialization end time.
  327. rank (int, optional): Current process's rank. Defaults to 0.
  328. """
  329. if rank == 0:
  330. print(f"--> Model {config.model_name}")
  331. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  332. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  333. def get_policies(cfg, rank):
  334. """Get the policies for mixed precision and fsdp wrapping"""
  335. verify_bfloat_support = ((
  336. torch.version.cuda
  337. and torch.cuda.is_bf16_supported()
  338. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  339. and dist.is_nccl_available()
  340. and nccl.version() >= (2, 10)
  341. ) or
  342. (is_xpu_available()))
  343. mixed_precision_policy = None
  344. wrapping_policy = None
  345. # Mixed precision
  346. if cfg.mixed_precision:
  347. bf16_ready = verify_bfloat_support
  348. if bf16_ready and not cfg.use_fp16:
  349. mixed_precision_policy = bfSixteen_mixed
  350. if rank == 0:
  351. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  352. elif cfg.use_fp16:
  353. mixed_precision_policy = fpSixteen
  354. if rank == 0:
  355. print(f"FP16 enabled")
  356. else:
  357. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  358. wrapping_policy = get_llama_wrapper()
  359. return mixed_precision_policy, wrapping_policy
  360. def save_train_params(train_config, fsdp_config, rank):
  361. """
  362. This function saves the train_config and FSDP config into a train_params.yaml.
  363. This will be used by converter script in the inference folder to fetch the HF model name or path.
  364. It also would be hepful as a log for future references.
  365. """
  366. # Convert the train_config and fsdp_config objects to dictionaries,
  367. # converting all values to strings to ensure they can be serialized into a YAML file
  368. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  369. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  370. # Merge the two dictionaries into one
  371. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  372. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  373. folder_name = (
  374. train_config.dist_checkpoint_root_folder
  375. + "/"
  376. + train_config.dist_checkpoint_folder
  377. + "-"
  378. + train_config.model_name
  379. )
  380. save_dir = Path.cwd() / folder_name
  381. # If the directory does not exist, create it
  382. if not os.path.exists(save_dir):
  383. os.makedirs(save_dir)
  384. # Convert the dictionary to a YAML string
  385. config_yaml = yaml.dump(train_params_dict, indent=4)
  386. file_name = os.path.join(save_dir,'train_params.yaml')
  387. # Check if there's a directory with the same name as the file
  388. if os.path.isdir(file_name):
  389. print(f"Error: {file_name} is a directory, not a file.")
  390. else:
  391. # Write the YAML string to the file
  392. with open(file_name, 'w') as f:
  393. f.write(config_yaml)
  394. if rank==0:
  395. print(f"training params are saved in {file_name}")