@@ -2,71 +2,49 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-import sys
-from typing import List, Union
 
 import fire
 import torch
-import transformers
-from datasets import load_dataset
-import os.path as osp
-from tqdm import tqdm
-
-# Unused imports removed
-from utils import fsdp_auto_wrap_policy
+import torch.distributed as dist
+import torch.optim as optim
+from peft import get_peft_model, prepare_model_for_int8_training
+from pkg_resources import packaging
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel as FSDP,
+)
+from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
-    AutoModelForCausalLM,
-    AutoModelForSeq2SeqLM,
-    AutoTokenizer,
+    LlamaConfig,
     default_data_collator,
-    BitsAndBytesConfig
-)
-import torch.distributed as dist
-# Unused imports removed
-from utils.train_utils import (
-    set_tokenizer_params,
-    train,
-    evaluation,
-    freeze_transformer_layers,
-    check_frozen_layers_peft_model,
-    setup,
-    setup_environ_flags,
-    cleanup,
-    clear_gpu_cache,
-    get_parameter_dtypes,
-    print_model_size,
-    get_policies
 )
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from utils.dataset_utils import get_preprocessed_dataset
+import policies
+from configs import fsdp_config, train_config
+from policies import AnyPrecisionAdamW
 
+from utils import fsdp_auto_wrap_policy
 from utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
 )
-from peft import get_peft_model, TaskType, prepare_model_for_int8_training
-import configs
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel as FSDP,
-    MixedPrecision,
+from utils.dataset_utils import get_preprocessed_dataset
+
+from utils.train_utils import (
+    train,
+    freeze_transformer_layers,
+    setup,
+    setup_environ_flags,
+    clear_gpu_cache,
+    print_model_size,
+    get_policies
 )
-from torch.utils.data import DistributedSampler
-import policies
-from policies import AnyPrecisionAdamW
-from configs import fsdp_config, train_config
-import torch.optim as optim
-from torch.optim.lr_scheduler import StepLR
-from pkg_resources import packaging
-import torch
-import torch.cuda.nccl as nccl
-import torch.distributed as dist
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 from accelerate.utils import is_xpu_available
 
-
 def main(**kwargs):
     # Update the configuration for the training and sharding process
     update_config((train_config, fsdp_config), **kwargs)
@@ -90,17 +68,42 @@ def main(**kwargs):
             torch.xpu.set_device(rank)
         else:
             torch.cuda.set_device(rank)
+        clear_gpu_cache(rank)
         setup_environ_flags(rank)
-    
+
     # Calculate gradient accumulation steps
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
-    
+
     # Load the pre-trained model and setup its configuration
-    model = LlamaForCausalLM.from_pretrained(
-        train_config.model_name,
-        load_in_8bit=True if train_config.quantization else None,
-        device_map="auto" if train_config.quantization else None,
-    )
+    if train_config.enable_fsdp and train_config.low_cpu_fsdp:
+        """
+        for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
+        this avoids cpu oom when loading large models like llama 70B, in which case
+        model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
+        overhead and currently requires latest nightly.
+        """
+        v = packaging.version.parse(torch.__version__)
+        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
+        if not verify_latest_nightly:
+            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
+                            "please install latest nightly.")
+        if rank == 0:
+            model = LlamaForCausalLM.from_pretrained(
+                train_config.model_name,
+                load_in_8bit=True if train_config.quantization else None,
+                device_map="auto" if train_config.quantization else None,
+            )
+        else:
+            llama_config = LlamaConfig.from_pretrained(train_config.model_name)
+            with torch.device("meta"):
+                model = LlamaForCausalLM(llama_config)
+
+    else:
+        model = LlamaForCausalLM.from_pretrained(
+            train_config.model_name,
+            load_in_8bit=True if train_config.quantization else None,
+            device_map="auto" if train_config.quantization else None,
+        )
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
@@ -113,11 +116,11 @@ def main(**kwargs):
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-    
+
     # Prepare the model for int8 training if quantization is enabled
     if train_config.quantization:
         model = prepare_model_for_int8_training(model)
-        
+
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
@@ -126,7 +129,7 @@ def main(**kwargs):
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
     tokenizer.add_special_tokens(
             {
-            
+
                 "pad_token": "<PAD>",
             }
         )
@@ -134,16 +137,16 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-    
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
-            
+
             freeze_transformer_layers(train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-        
+
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
@@ -151,6 +154,9 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
             limit_all_gathers=True,
+            sync_module_states=train_config.low_cpu_fsdp,
+            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
+            if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
@@ -161,14 +167,14 @@ def main(**kwargs):
             model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
-    
+
     # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
         tokenizer,
         dataset_config,
         split="train",
     )
-    
+
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
@@ -195,7 +201,7 @@ def main(**kwargs):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
             )
-        
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -217,7 +223,7 @@ def main(**kwargs):
             drop_last=True,
             collate_fn=default_data_collator,
         )
-    
+
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         optimizer = AnyPrecisionAdamW(
@@ -239,7 +245,7 @@ def main(**kwargs):
     results = train(
         model,
         train_dataloader,
-        eval_dataloader, 
+        eval_dataloader,
         tokenizer,
         optimizer,
         scheduler,
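
Note: the low_cpu_fsdp path introduced above is the core of this change: rank 0 loads the pretrained
weights once, every other rank builds the model on the meta device, and FSDP broadcasts the real
weights during wrapping via sync_module_states, which needs a sufficiently recent PyTorch (the diff
checks for a 2023-07+ nightly). A minimal, standalone sketch of that pattern follows; the checkpoint
id, process-group setup, and launch command are illustrative assumptions, not the recipe's exact code.

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import LlamaConfig, LlamaForCausalLM


def load_model_low_cpu(model_name: str) -> FSDP:
    rank = dist.get_rank()

    if rank == 0:
        # Only rank 0 materializes the full pretrained weights in CPU RAM.
        model = LlamaForCausalLM.from_pretrained(model_name)
    else:
        # Other ranks allocate no parameter storage: tensors live on the meta device.
        config = LlamaConfig.from_pretrained(model_name)
        with torch.device("meta"):
            model = LlamaForCausalLM(config)

    return FSDP(
        model,
        device_id=torch.cuda.current_device(),
        # Broadcast rank 0's weights to every rank while wrapping.
        sync_module_states=True,
        # Non-zero ranks swap meta tensors for empty CUDA storage so the broadcast
        # has somewhere to land; rank 0 keeps the weights it already loaded.
        param_init_fn=(
            (lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
            if rank != 0
            else None
        ),
    )


if __name__ == "__main__":
    # Typically launched with torchrun, e.g.
    #   torchrun --nproc_per_node=8 low_cpu_fsdp_sketch.py
    dist.init_process_group("nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    fsdp_model = load_model_low_cpu("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint id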