@@ -0,0 +1,324 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+import sys
+from typing import List
+
+import fire
+import torch
+import transformers
+from datasets import load_dataset
+from tqdm import tqdm
+"""
+Unused imports:
+import torch.nn as nn
+import bitsandbytes as bnb
+"""
+from torch.nn import functional as F
+from peft import (
+    LoraConfig,
+    get_peft_model,
+    get_peft_model_state_dict,
+    prepare_model_for_int8_training,
+    set_peft_model_state_dict,
+)
+from transformers import LlamaForCausalLM, LlamaTokenizer
+from torch.distributed.fsdp import StateDictType
+import torch.distributed as dist
+from pkg_resources import packaging
+from .memory_utils import MemoryTrace
+import model_checkpointing
+import torch.cuda.nccl as nccl
+from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from pathlib import Path
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+from policies import bfSixteen, fpSixteen, bfSixteen_mixed, get_llama_wrapper
+
+
+def set_tokenizer_params(tokenizer: LlamaTokenizer):
+    tokenizer.pad_token_id = 0
+    tokenizer.padding_side = "left"
+
+# Converting Bytes to Megabytes
+def byte2mb(x):
+    return int(x / 2**20)
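+# e.g. byte2mb(3 * 2**20) == 3
+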
+def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
+    """
+    Trains the model on the given dataloader
+
+    Args:
+        model: The model to be trained
+        train_dataloader: The dataloader containing the training data
+        eval_dataloader: The dataloader containing the eval data
+        tokenizer: tokenizer used in the eval for decoding the predictions
+        optimizer: The optimizer used for training
+        lr_scheduler: The learning rate scheduler
+        gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
+        train_config: The training configuration
+        fsdp_config: The FSDP configuration, required when FSDP is enabled
+        local_rank: The rank of the current process on its node in a distributed setting
+        rank: The global rank of the current process in a distributed setting
+
+    Returns: results dictionary containing average training and validation perplexity and loss
+    """
+    # Create a gradient scaler for fp16
+    if train_config.use_fp16:
+        scaler = ShardedGradScaler() if train_config.enable_fsdp else torch.cuda.amp.GradScaler()
+    else:
+        scaler = None
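+    # Note: torch.cuda.amp.GradScaler would not see whole gradients for its
+    # inf/NaN checks under FSDP, since FSDP shards gradients across ranks;
+    # ShardedGradScaler runs those checks on the local shards and
+    # synchronizes the result across ranks.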
+
+    train_prep = []
+    train_loss = []
+    val_prep = []
+    val_loss = []
+    results = {}
+    best_val_loss = float("inf")
+    for epoch in range(train_config.num_epochs):
+        with MemoryTrace() as memtrace:  # track the memory usage
+            model.train()
+            total_loss = 0.0
+
+            for step, batch in enumerate(tqdm(train_dataloader, colour="blue", desc=f"Training Epoch {epoch}")):
+                for key in batch.keys():
+                    if train_config.enable_fsdp:
+                        batch[key] = batch[key].to(local_rank)
+                    elif not train_config.quantization:
+                        batch[key] = batch[key].to('cuda')
+                outputs = model(**batch)
+                loss = outputs.loss
+                # Accumulate the unscaled loss for epoch-level reporting
+                total_loss += loss.detach().float()
+                # Scale the loss to account for gradient accumulation
+                loss = loss / gradient_accumulation_steps
+                if train_config.use_fp16:
+                    # if fp16 is enabled, use gradient scaler to handle gradient update
+                    scaler.scale(loss).backward()
+                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        scaler.step(optimizer)
+                        scaler.update()
+                        optimizer.zero_grad()
+                else:
+                    # regular backpropagation when fp16 is not used
+                    loss.backward()
+                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        optimizer.step()
+                        optimizer.zero_grad()
+
+                print(f"\n step {step} is completed and loss is {loss.detach().float()}")
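+
+            # Worked example (hypothetical values): with a per-rank batch size of 4
+            # and gradient_accumulation_steps=4, optimizer.step() runs every 4th
+            # batch, for an effective batch size of 16 per rank.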
+
+        # Reducing total_loss across all devices if there's more than one CUDA device
+        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+            # all_reduce sums the losses across ranks, so average by the world size
+            total_loss = total_loss / dist.get_world_size()
+        train_epoch_loss = total_loss / len(train_dataloader)
+        train_perplexity = torch.exp(train_epoch_loss)
+
+        train_prep.append(train_perplexity)
+        train_loss.append(train_epoch_loss)
+
+        print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+        print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+        print(f"CUDA malloc retries : {memtrace.cuda_malloc_retires}")
+        print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+
+        # Step the learning rate scheduler once per epoch
+        lr_scheduler.step()
+
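+        # Worked example: perplexity = exp(mean loss); a mean cross-entropy loss
+        # of 2.0 corresponds to exp(2.0) ≈ 7.39 perplexity.
+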
+        if train_config.run_validation:
+            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            if train_config.save_model and eval_epoch_loss < best_val_loss:
+                if train_config.use_peft:
+                    print("Saving the PEFT modules")
+                    model.save_pretrained(train_config.output_dir)
+                    print(f"PEFT modules are saved in {train_config.output_dir} directory")
+                else:
+                    if fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                        model_checkpointing.save_model_checkpoint(
+                            model, optimizer, rank, train_config, epoch=epoch
+                        )
+                    elif fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
+                        print("Saving sharded model checkpoints")
+                        model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
+                        if train_config.save_optimizer:
+                            model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
+                    if train_config.save_optimizer:
+                        model_checkpointing.save_optimizer_checkpoint(
+                            model, optimizer, rank, train_config, epoch=epoch
+                        )
+
+            if eval_epoch_loss < best_val_loss:
+                best_val_loss = eval_epoch_loss
+                if local_rank == 0:
+                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+            val_loss.append(best_val_loss)
+            val_prep.append(eval_ppl)
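+
+            # To reload the PEFT adapter saved above (sketch; base_model is the
+            # original, un-adapted model the adapter was trained on):
+            #   from peft import PeftModel
+            #   model = PeftModel.from_pretrained(base_model, train_config.output_dir)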
+
+        print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+
+    avg_train_prep = sum(train_prep)/len(train_prep)
+    avg_train_loss = sum(train_loss)/len(train_loss)
+    if train_config.run_validation:
+        avg_eval_prep = sum(val_prep)/len(val_prep)
+        avg_eval_loss = sum(val_loss)/len(val_loss)
+
+    results['avg_train_prep'] = avg_train_prep
+    results['avg_train_loss'] = avg_train_loss
+    if train_config.run_validation:
+        results['avg_eval_prep'] = avg_eval_prep
+        results['avg_eval_loss'] = avg_eval_loss
+
+    return results
+
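+# Example call (sketch; every object below is built by the caller, e.g. a
+# fine-tuning script):
+#   results = train(model, train_dataloader, eval_dataloader, tokenizer, optimizer,
+#                   lr_scheduler, gradient_accumulation_steps, train_config,
+#                   fsdp_config=fsdp_config, local_rank=local_rank, rank=rank)
+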
+def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer):
+    """
+    Evaluates the model on the given dataloader
+
+    Args:
+        model: The model to evaluate
+        train_config: The training configuration
+        eval_dataloader: The dataloader containing the evaluation data
+        local_rank: The rank of the current process on its node in a distributed setting
+        tokenizer: The tokenizer used to decode predictions
+
+    Returns: eval_ppl, eval_epoch_loss
+    """
+    model.eval()
+    eval_preds = []
+    eval_loss = 0.0  # Initialize evaluation loss
+    with MemoryTrace() as memtrace:
+        for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="Evaluating Epoch")):
+            for key in batch.keys():
+                if train_config.enable_fsdp:
+                    batch[key] = batch[key].to(local_rank)
+                else:
+                    batch[key] = batch[key].to('cuda')
+            # Ensure no gradients are computed for this scope to save memory
+            with torch.no_grad():
+                # Forward pass and compute loss
+                outputs = model(**batch)
+                loss = outputs.loss
+                eval_loss += loss.detach().float()
+
+            # Decode predictions and add to evaluation predictions list
+            preds = torch.argmax(outputs.logits, -1)
+            eval_preds.extend(
+                tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
+            )
+
+    # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
+        # all_reduce sums the losses across ranks, so average by the world size
+        eval_loss = eval_loss / dist.get_world_size()
+
+    # Compute average loss and perplexity
+    eval_epoch_loss = eval_loss / len(eval_dataloader)
+    eval_ppl = torch.exp(eval_epoch_loss)
+
+    # Print evaluation metrics
+    print(f" {eval_ppl=} {eval_epoch_loss=}")
+    return eval_ppl, eval_epoch_loss
+
+def freeze_transformer_layers(model, num_layer):
+    """Freeze the parameters of the first num_layer decoder layers"""
+    for i, layer in enumerate(model.model.layers):
+        if i < num_layer:
+            for param in layer.parameters():
+                param.requires_grad = False
+
+
+def check_frozen_layers_peft_model(model):
+    """Print the requires_grad flag for every parameter in each decoder layer of a PEFT-wrapped model"""
+    for i, layer in enumerate(model.base_model.model.model.layers):
+        for name, param in layer.named_parameters():
+            print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
+
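+# Example usage (hypothetical layer count):
+#   freeze_transformer_layers(model, num_layer=8)   # freeze the first 8 decoder layers
+#   check_frozen_layers_peft_model(peft_model)      # verify which parameters remain trainable
+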
+
+def setup():
+    """Initialize the process group for distributed training"""
+    dist.init_process_group("nccl")
+
+
+def setup_environ_flags(rank):
+    """Set environment flags for debugging purposes"""
+    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
+    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
+    if rank == 0:
+        print("--> Running with torch dist debug set to detail")
+
+
+def cleanup():
+    """Clean up the process group after training"""
+    dist.destroy_process_group()
+
+
+def clear_gpu_cache(rank=None):
+    """Clear the GPU cache for all ranks"""
+    if rank == 0:
+        print("Clearing GPU cache for all ranks")
+    torch.cuda.empty_cache()
+
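+# Typical call order in a distributed entry point (sketch; LOCAL_RANK and RANK
+# are environment variables set by torchrun):
+#   setup()
+#   local_rank = int(os.environ["LOCAL_RANK"])
+#   rank = int(os.environ["RANK"])
+#   setup_environ_flags(rank)
+#   clear_gpu_cache(local_rank)
+#   ...training...
+#   cleanup()
+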
+
+def get_parameter_dtypes(model):
+    """Get the data types of model parameters"""
+    parameter_dtypes = {}
+    for name, parameter in model.named_parameters():
+        parameter_dtypes[name] = parameter.dtype
+    return parameter_dtypes
+
+def print_model_size(model, config, rank: int = 0) -> None:
+    """
+    Print the model name and the number of trainable parameters.
+
+    Args:
+        model: The PyTorch model.
+        config: Config object whose model_name field identifies the model.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        print(f"--> Model {config.model_name}")
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
+
+
+def get_policies(cfg, rank):
+    """Get the policies for mixed precision and fsdp wrapping"""
+
+    verify_bfloat_support = (
+        torch.version.cuda
+        and torch.cuda.is_bf16_supported()
+        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
+        and dist.is_nccl_available()
+        and nccl.version() >= (2, 10)
+    )
+
+    mixed_precision_policy = None
+    wrapping_policy = None
+
+    # Mixed precision
+    if cfg.mixed_precision:
+        bf16_ready = verify_bfloat_support
+
+        if bf16_ready and not cfg.use_fp16:
+            mixed_precision_policy = bfSixteen_mixed
+            if rank == 0:
+                print("bFloat16 enabled for mixed precision - using bfSixteen_mixed policy")
+        elif cfg.use_fp16:
+            mixed_precision_policy = fpSixteen
+            if rank == 0:
+                print("FP16 enabled")
+        else:
+            print("bFloat16 support not present. Using FP32, and not mixed precision")
+    wrapping_policy = get_llama_wrapper()
+    return mixed_precision_policy, wrapping_policy
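+
+# Example wiring into FSDP (sketch; the FSDP import and remaining kwargs are omitted):
+#   mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
+#   model = FSDP(model, auto_wrap_policy=wrapping_policy, mixed_precision=mixed_precision_policy)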