@@ -9,10 +9,7 @@ import random
 import torch
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_kbit_training
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    ShardingStrategy
-)
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
@@ -49,6 +46,7 @@ from llama_recipes.utils.train_utils import (
 )
 from accelerate.utils import is_xpu_available

+
 def setup_wandb(train_config, fsdp_config, **kwargs):
     try:
         import wandb
@@ -58,6 +56,7 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
             "Please install it using pip install wandb"
         )
     from llama_recipes.configs import wandb_config as WANDB_CONFIG
+
     wandb_config = WANDB_CONFIG()
     update_config(wandb_config, **kwargs)
     init_dict = dataclasses.asdict(wandb_config)
@@ -67,10 +66,34 @@ def setup_wandb(train_config, fsdp_config, **kwargs):
     return run


+def display(config):
+    from io import StringIO
+    from re import compile
+    from termcolor import colored
+
+    buffer = StringIO()
+    print(config, file=buffer)
+    text = buffer.getvalue()
+
+    pat = compile("^(\w+)\(([\w\d]+=[^,]+(, [\w\d]+=[^,]+)*)\)$")
+
+    result = pat.match(text)
+    assert result is not None
+    name = result.group(1)
+    print()
+    print(colored(name.replace("_", " ").upper(), "blue"))
+    for key, value in map(lambda s: s.split("="), result.group(2).split(", ")):
+        print(colored(key, "green"), "=", colored(value, "red"))
+
+
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
+
+    display(train_config)
+    display(fsdp_config)
+
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(train_config.seed)
@@ -95,7 +118,7 @@ def main(**kwargs):
     wandb_run = None

     if train_config.use_wandb:
-        if not train_config.enable_fsdp or rank==0:
+        if not train_config.enable_fsdp or rank == 0:
             wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

     # Load the pre-trained model and setup its configuration
@@ -131,13 +154,19 @@ def main(**kwargs):
     )

     # Load the tokenizer and add special tokens
-    tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+    tokenizer = AutoTokenizer.from_pretrained(
+        train_config.model_name
+        if train_config.tokenizer_name is None
+        else train_config.tokenizer_name
+    )
     tokenizer.pad_token_id = tokenizer.eos_token_id

-    # If there is a mismatch between tokenizer vocab size and embedding matrix, 
+    # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
-        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        print(
+            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
+        )
         model.resize_token_embeddings(len(tokenizer))

     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
@@ -157,13 +186,18 @@ def main(**kwargs):
         if wandb_run:
             wandb_run.config.update(peft_config)

-
     hsdp_device_mesh = None
-    if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
-        hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
+    if (
+        fsdp_config.hsdp
+        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
+    ):
+        hsdp_device_mesh = hsdp_device_mesh(
+            replica_group_size=fsdp_config.replica_group_size,
+            sharding_group_size=fsdp_config.sharding_group_size,
+        )
         print("HSDP device mesh is ready")

-    #setting up FSDP if enable_fsdp is enabled
+    # setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:

@@ -180,16 +214,27 @@ def main(**kwargs):

         model = FSDP(
             model,
-            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
-            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
-            mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
+            auto_wrap_policy=(
+                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
+            ),
+            cpu_offload=(
+                CPUOffload(offload_params=True)
+                if fsdp_config.fsdp_cpu_offload
+                else None
+            ),
+            mixed_precision=(
+                mixed_precision_policy if not fsdp_config.pure_bf16 else None
+            ),
             sharding_strategy=fsdp_config.sharding_strategy,
             device_mesh=hsdp_device_mesh,
             device_id=device_id,
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
-            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
-            if train_config.low_cpu_fsdp and rank != 0 else None,
+            param_init_fn=lambda module: (
+                module.to_empty(device=torch.device("cuda"), recurse=False)
+                if train_config.low_cpu_fsdp and rank != 0
+                else None
+            ),
         )
         if fsdp_config.fsdp_activation_checkpointing:
             apply_fsdp_checkpointing(model)
@@ -201,7 +246,6 @@ def main(**kwargs):

     dataset_config = generate_dataset_config(train_config, kwargs)

-    # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
         tokenizer,
         dataset_config,
@@ -217,12 +261,16 @@ def main(**kwargs):
         split="test",
     )
     if not train_config.enable_fsdp or rank == 0:
-            print(f"--> Validation Set Length = {len(dataset_val)}")
+        print(f"--> Validation Set Length = {len(dataset_val)}")

     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+        dataset_train = ConcatDataset(
+            dataset_train, chunk_size=train_config.context_length
+        )

-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+    train_dl_kwargs = get_dataloader_kwargs(
+        train_config, dataset_train, tokenizer, "train"
+    )

     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
@@ -235,9 +283,13 @@ def main(**kwargs):
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            dataset_val = ConcatDataset(
+                dataset_val, chunk_size=train_config.context_length
+            )

-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(
+            train_config, dataset_val, tokenizer, "val"
+        )

         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
@@ -279,11 +331,12 @@ def main(**kwargs):
         rank if train_config.enable_fsdp else None,
         wandb_run,
     )
-    if not train_config.enable_fsdp or rank==0:
-        [print(f'Key: {k}, Value: {v}') for k, v in results.items()]
+    if not train_config.enable_fsdp or rank == 0:
+        [print(f"Key: {k}, Value: {v}") for k, v in results.items()]
         if train_config.use_wandb:
-            for k,v in results.items():
+            for k, v in results.items():
                 wandb_run.summary[k] = v

+
 if __name__ == "__main__":
     fire.Fire(main)
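
The only functional addition in this patch is the display() helper; the rest is formatter-style cleanup. For reference, below is a minimal standalone sketch of the same idea, kept separate from the diff so it can be run on its own. It is not part of the patch: it assumes termcolor is installed, it reads repr(config) instead of capturing print() output through a StringIO buffer, it uses a raw-string regex to avoid the invalid-escape warning the patch's non-raw pattern can trigger, and the sample_config dataclass with its fields is purely hypothetical.

# Standalone sketch of the display() helper (illustration only, not part of the patch).
# Assumptions: `termcolor` is installed; the config is a dataclass whose repr has the
# form ClassName(field=value, field=value, ...); `sample_config` is a hypothetical example.
import dataclasses
import re

from termcolor import colored


@dataclasses.dataclass
class sample_config:
    model_name: str = "model"
    batch_size_training: int = 4
    use_peft: bool = True


def display(config):
    text = repr(config)  # same shape of string the patch captures via print(config, file=buffer)
    # Match a dataclass-style repr: ClassName(field=value, field=value, ...)
    pat = re.compile(r"^(\w+)\((\w+=[^,]+(, \w+=[^,]+)*)\)$")
    result = pat.match(text)
    assert result is not None, "config repr did not match the expected dataclass format"
    name = result.group(1)
    print()
    print(colored(name.replace("_", " ").upper(), "blue"))
    for key, value in (pair.split("=", 1) for pair in result.group(2).split(", ")):
        print(colored(key, "green"), "=", colored(value, "red"))


if __name__ == "__main__":
    display(sample_config())  # prints a colored "SAMPLE CONFIG" header and key = value lines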