train_utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import time
  5. import yaml
  6. from pathlib import Path
  7. from pkg_resources import packaging
  8. import torch
  9. import torch.cuda.nccl as nccl
  10. import torch.distributed as dist
  11. from torch.distributed.fsdp import StateDictType
  12. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  13. from tqdm import tqdm
  14. from transformers import LlamaTokenizer
  15. from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
  16. from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
  17. from llama_recipes.utils.memory_utils import MemoryTrace
  18. from accelerate.utils import is_xpu_available, is_ccl_available
  19. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  20. tokenizer.pad_token_id = 0
  21. tokenizer.padding_side = "left"
  22. # Converting Bytes to Megabytes
  23. def byte2mb(x):
  24. return int(x / 2**20)
  25. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  26. """
  27. Trains the model on the given dataloader
  28. Args:
  29. model: The model to be trained
  30. train_dataloader: The dataloader containing the training data
  31. optimizer: The optimizer used for training
  32. lr_scheduler: The learning rate scheduler
  33. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  34. num_epochs: The number of epochs to train for
  35. local_rank: The rank of the current node in a distributed setting
  36. train_config: The training configuration
  37. eval_dataloader: The dataloader containing the eval data
  38. tokenizer: tokenizer used in the eval for decoding the predicitons
  39. Returns: results dictionary containing average training and validation perplexity and loss
  40. """
  41. # Create a gradient scaler for fp16
  42. if train_config.use_fp16 and train_config.enable_fsdp:
  43. scaler = ShardedGradScaler()
  44. elif train_config.use_fp16 and not train_config.enable_fsdp:
  45. scaler = torch.cuda.amp.GradScaler()
  46. if train_config.enable_fsdp:
  47. world_size = int(os.environ["WORLD_SIZE"])
  48. train_prep = []
  49. train_loss = []
  50. val_prep = []
  51. val_loss =[]
  52. epoch_times = []
  53. checkpoint_times = []
  54. results = {}
  55. best_val_loss = float("inf")
  56. for epoch in range(train_config.num_epochs):
  57. epoch_start_time = time.perf_counter()
  58. with MemoryTrace() as memtrace: # track the memory usage
  59. model.train()
  60. total_loss = 0.0
  61. total_length = len(train_dataloader)//gradient_accumulation_steps
  62. pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
  63. for step, batch in enumerate(train_dataloader):
  64. for key in batch.keys():
  65. if train_config.enable_fsdp:
  66. batch[key] = batch[key].to(local_rank)
  67. else:
  68. batch[key] = batch[key].to('cuda:0')
  69. loss = model(**batch).loss
  70. loss = loss / gradient_accumulation_steps
  71. total_loss += loss.detach().float()
  72. loss = torch.autograd.Variable(loss, required_grad = True)
  73. if train_config.use_fp16:
  74. # if fp16 is enabled, use gradient scaler to handle gradient update
  75. scaler.scale(loss).backward()
  76. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  77. scaler.step(optimizer)
  78. scaler.update()
  79. optimizer.zero_grad()
  80. pbar.update(step//gradient_accumulation_steps)
  81. else:
  82. # regular backpropagation when fp16 is not used
  83. loss.backward()
  84. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  85. optimizer.step()
  86. optimizer.zero_grad()
  87. pbar.update(step//gradient_accumulation_steps)
  88. pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
  89. epoch_end_time = time.perf_counter()-epoch_start_time
  90. epoch_times.append(epoch_end_time)
  91. # Reducing total_loss across all devices if there's more than one CUDA device
  92. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  93. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  94. elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  95. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  96. train_epoch_loss = total_loss / len(train_dataloader)
  97. if train_config.enable_fsdp:
  98. train_epoch_loss = train_epoch_loss/world_size
  99. train_perplexity = torch.exp(train_epoch_loss)
  100. train_prep.append(train_perplexity)
  101. train_loss.append(train_epoch_loss)
  102. if train_config.enable_fsdp:
  103. if rank==0:
  104. if is_xpu_available():
  105. print(f"Max XPU memory allocated was {memtrace.peak} GB")
  106. print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
  107. print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
  108. print(f"Xpu Malloc retires : {memtrace.cuda_malloc_retires}")
  109. else:
  110. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  111. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  112. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  113. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  114. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  115. else:
  116. if is_xpu_available():
  117. print(f"Max XPU memory allocated was {memtrace.peak} GB")
  118. print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
  119. print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
  120. print(f"Xpu Malloc retires : {memtrace.cuda_malloc_retires}")
  121. else:
  122. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  123. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  124. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  125. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  126. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  127. # Update the learning rate as needed
  128. lr_scheduler.step()
  129. if train_config.run_validation:
  130. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
  131. checkpoint_start_time = time.perf_counter()
  132. if train_config.save_model and eval_epoch_loss < best_val_loss:
  133. if train_config.enable_fsdp:
  134. dist.barrier()
  135. if train_config.use_peft:
  136. if train_config.enable_fsdp:
  137. if rank==0:
  138. print(f"we are about to save the PEFT modules")
  139. else:
  140. print(f"we are about to save the PEFT modules")
  141. model.save_pretrained(train_config.output_dir)
  142. if train_config.enable_fsdp:
  143. if rank==0:
  144. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  145. else:
  146. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  147. else:
  148. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  149. save_model_checkpoint(
  150. model, optimizer, rank, train_config, epoch=epoch
  151. )
  152. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  153. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  154. print("=====================================================")
  155. save_model_and_optimizer_sharded(model, rank, train_config)
  156. if train_config.save_optimizer:
  157. save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  158. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  159. print("=====================================================")
  160. if not train_config.use_peft and train_config.save_optimizer:
  161. save_optimizer_checkpoint(
  162. model, optimizer, rank, train_config, epoch=epoch
  163. )
  164. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  165. print("=====================================================")
  166. if train_config.enable_fsdp:
  167. dist.barrier()
  168. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  169. checkpoint_times.append(checkpoint_end_time)
  170. if eval_epoch_loss < best_val_loss:
  171. best_val_loss = eval_epoch_loss
  172. if train_config.enable_fsdp:
  173. if rank==0:
  174. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  175. else:
  176. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  177. val_loss.append(best_val_loss)
  178. val_prep.append(eval_ppl)
  179. if train_config.enable_fsdp:
  180. if rank==0:
  181. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
  182. else:
  183. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
  184. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  185. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
  186. avg_train_prep = sum(train_prep)/len(train_prep)
  187. avg_train_loss = sum(train_loss)/len(train_loss)
  188. if train_config.run_validation:
  189. avg_eval_prep = sum(val_prep)/len(val_prep)
  190. avg_eval_loss = sum(val_loss)/len(val_loss)
  191. results['avg_train_prep'] = avg_train_prep
  192. results['avg_train_loss'] = avg_train_loss
  193. if train_config.run_validation:
  194. results['avg_eval_prep'] = avg_eval_prep
  195. results['avg_eval_loss'] = avg_eval_loss
  196. results["avg_epoch_time"] = avg_epoch_time
  197. results["avg_checkpoint_time"] = avg_checkpoint_time
  198. #saving the training params including fsdp setting for reference.
  199. if train_config.enable_fsdp and not train_config.use_peft:
  200. save_train_params(train_config, fsdp_config, rank)
  201. return results
  202. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  203. """
  204. Evaluates the model on the given dataloader
  205. Args:
  206. model: The model to evaluate
  207. eval_dataloader: The dataloader containing the evaluation data
  208. local_rank: The rank of the current node in a distributed setting
  209. tokenizer: The tokenizer used to decode predictions
  210. Returns: eval_ppl, eval_epoch_loss
  211. """
  212. if train_config.enable_fsdp:
  213. world_size = int(os.environ["WORLD_SIZE"])
  214. model.eval()
  215. eval_preds = []
  216. eval_loss = 0.0 # Initialize evaluation loss
  217. with MemoryTrace() as memtrace:
  218. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
  219. for key in batch.keys():
  220. if train_config.enable_fsdp:
  221. batch[key] = batch[key].to(local_rank)
  222. else:
  223. batch[key] = batch[key].to('cuda:0')
  224. # Ensure no gradients are computed for this scope to save memory
  225. with torch.no_grad():
  226. # Forward pass and compute loss
  227. outputs = model(**batch)
  228. loss = outputs.loss
  229. eval_loss += loss.detach().float()
  230. # Decode predictions and add to evaluation predictions list
  231. preds = torch.argmax(outputs.logits, -1)
  232. eval_preds.extend(
  233. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  234. )
  235. # If there's more than one CUDA device, reduce evaluation loss across all devices
  236. if is_xpu_available() and (torch.cuda.device_count() > 1 and train_config.enable_fsdp):
  237. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  238. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  239. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  240. # Compute average loss and perplexity
  241. eval_epoch_loss = eval_loss / len(eval_dataloader)
  242. if train_config.enable_fsdp:
  243. eval_epoch_loss = eval_epoch_loss/world_size
  244. eval_ppl = torch.exp(eval_epoch_loss)
  245. # Print evaluation metrics
  246. if train_config.enable_fsdp:
  247. if local_rank==0:
  248. print(f" {eval_ppl=} {eval_epoch_loss=}")
  249. else:
  250. print(f" {eval_ppl=} {eval_epoch_loss=}")
  251. return eval_ppl, eval_epoch_loss
  252. def freeze_transformer_layers(model, num_layer):
  253. for i, layer in enumerate(model.model.layers):
  254. if i < num_layer:
  255. for param in layer.parameters():
  256. param.requires_grad = False
  257. def check_frozen_layers_peft_model(model):
  258. for i, layer in enumerate(model.base_model.model.model.layers):
  259. for name, param in layer.named_parameters():
  260. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  261. def setup():
  262. """Initialize the process group for distributed training"""
  263. if is_ccl_available():
  264. # distributed training on xpus
  265. dist.init_process_group("ccl")
  266. else:
  267. dist.init_process_group("nccl")
  268. def setup_environ_flags(rank):
  269. """Set environment flags for debugging purposes"""
  270. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  271. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  272. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  273. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  274. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  275. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  276. if rank == 0:
  277. print(f"--> Running with torch dist debug set to detail")
  278. def cleanup():
  279. """Clean up the process group after training"""
  280. dist.destroy_process_group()
  281. def clear_gpu_cache(rank=None):
  282. """Clear the GPU cache for all ranks"""
  283. if rank == 0:
  284. print(f"Clearing GPU cache for all ranks")
  285. if is_xpu_available():
  286. torch.xpu_empty_cache()
  287. else:
  288. torch.cuda.empty_cache()
  289. def get_parameter_dtypes(model):
  290. """Get the data types of model parameters"""
  291. parameter_dtypes = {}
  292. for name, parameter in model.named_parameters():
  293. parameter_dtypes[name] = parameter.dtype
  294. return parameter_dtypes
  295. def print_model_size(model, config, rank: int = 0) -> None:
  296. """
  297. Print model name, the number of trainable parameters and initialization time.
  298. Args:
  299. model: The PyTorch model.
  300. model_name (str): Name of the model.
  301. init_time_start (float): Initialization start time.
  302. init_time_end (float): Initialization end time.
  303. rank (int, optional): Current process's rank. Defaults to 0.
  304. """
  305. if rank == 0:
  306. print(f"--> Model {config.model_name}")
  307. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  308. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  309. def get_policies(cfg, rank):
  310. """Get the policies for mixed precision and fsdp wrapping"""
  311. verify_bfloat_support = ((
  312. torch.version.cuda
  313. and torch.cuda.is_bf16_supported()
  314. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  315. and dist.is_nccl_available()
  316. and nccl.version() >= (2, 10)
  317. ) or
  318. (is_xpu_available()))
  319. mixed_precision_policy = None
  320. wrapping_policy = None
  321. # Mixed precision
  322. if cfg.mixed_precision:
  323. bf16_ready = verify_bfloat_support
  324. if bf16_ready and not cfg.use_fp16:
  325. mixed_precision_policy = bfSixteen_mixed
  326. if rank == 0:
  327. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  328. elif cfg.use_fp16:
  329. mixed_precision_policy = fpSixteen
  330. if rank == 0:
  331. print(f"FP16 enabled")
  332. else:
  333. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  334. wrapping_policy = get_llama_wrapper()
  335. return mixed_precision_policy, wrapping_policy
  336. def save_train_params(train_config, fsdp_config, rank):
  337. """
  338. This function saves the train_config and FSDP config into a train_params.yaml.
  339. This will be used by converter script in the inference folder to fetch the HF model name or path.
  340. It also would be hepful as a log for future references.
  341. """
  342. # Convert the train_config and fsdp_config objects to dictionaries,
  343. # converting all values to strings to ensure they can be serialized into a YAML file
  344. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  345. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  346. # Merge the two dictionaries into one
  347. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  348. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  349. folder_name = (
  350. train_config.dist_checkpoint_root_folder
  351. + "/"
  352. + train_config.dist_checkpoint_folder
  353. + "-"
  354. + train_config.model_name
  355. )
  356. save_dir = Path.cwd() / folder_name
  357. # If the directory does not exist, create it
  358. if not os.path.exists(save_dir):
  359. os.makedirs(save_dir)
  360. # Convert the dictionary to a YAML string
  361. config_yaml = yaml.dump(train_params_dict, indent=4)
  362. file_name = os.path.join(save_dir,'train_params.yaml')
  363. # Check if there's a directory with the same name as the file
  364. if os.path.isdir(file_name):
  365. print(f"Error: {file_name} is a directory, not a file.")
  366. else:
  367. # Write the YAML string to the file
  368. with open(file_name, 'w') as f:
  369. f.write(config_yaml)
  370. if rank==0:
  371. print(f"training params are saved in {file_name}")