@@ -39,7 +39,7 @@ from utils.train_utils import (
     clear_gpu_cache,
     get_parameter_dtypes,
     print_model_size,
-    get_policies
+    get_policies
 )
 
 from utils.dataset_utils import get_preprocessed_dataset
@@ -88,10 +88,10 @@ def main(**kwargs):
     if torch.distributed.is_initialized():
         torch.cuda.set_device(rank)
         setup_environ_flags(rank)
-
+
     # Calculate gradient accumulation steps
     gradient_accumulation_steps = train_config.batch_size_training // train_config.micro_batch_size
-
+
     # Load the pre-trained model and setup its configuration
     if train_config.enable_fsdp and train_config.low_cpu_fsdp:
         # for FSDP, we can save cpu memory by loading pretrained model on rank0 only.
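
The integer division above assumes `batch_size_training` is a multiple of `micro_batch_size` (any remainder is silently dropped). A minimal sketch of the arithmetic, with illustrative values rather than the repo's defaults:

```python
# Illustrative values only; the real numbers come from train_config.
batch_size_training = 4  # desired effective batch size per device
micro_batch_size = 1     # batch size actually pushed through the model per step

# Same integer division as in the hunk above: how many micro-batches
# to accumulate gradients over before each optimizer step.
gradient_accumulation_steps = batch_size_training // micro_batch_size
print(gradient_accumulation_steps)  # -> 4
```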
@@ -113,19 +113,20 @@ def main(**kwargs):
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
             with torch.device("meta"):
                 model = LlamaForCausalLM(llama_config)
+
     else:
         model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
             load_in_8bit=True if train_config.quantization else None,
             device_map="auto" if train_config.quantization else None,
         )
-
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
+
     # Prepare the model for int8 training if quantization is enabled
     if train_config.quantization:
         model = prepare_model_for_int8_training(model)
-
+
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
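
The rank 0 / meta-device split shown above is the heart of the `low_cpu_fsdp` path: only rank 0 materializes real weights, while every other rank builds a shape-only shell that costs almost no host memory. A minimal standalone sketch of that pattern, using a toy module in place of `LlamaForCausalLM` (assumes PyTorch 2.0+, which supports `torch.device("meta")` as a context manager):

```python
import torch
import torch.nn as nn

def build_model() -> nn.Module:
    # Stand-in for LlamaForCausalLM(llama_config); any nn.Module behaves the same.
    return nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))

rank = 1  # in the real script this comes from torch.distributed

if rank == 0:
    # Rank 0 holds the only real copy of the weights
    # (from_pretrained onto CPU in the actual script).
    model = build_model()
else:
    # Other ranks allocate no storage: parameters live on the "meta" device,
    # carrying only shape and dtype until FSDP materializes the local shards.
    with torch.device("meta"):
        model = build_model()

print(next(model.parameters()).device)  # rank 0 -> cpu, other ranks -> meta
```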
@@ -134,7 +135,7 @@ def main(**kwargs):
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
     tokenizer.add_special_tokens(
             {
-
+
                 "pad_token": "<PAD>",
             }
         )
@@ -142,11 +143,11 @@ def main(**kwargs):
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
         model.print_trainable_parameters()
-
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
-
+
             freeze_transformer_layers(train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
@@ -159,8 +160,9 @@ def main(**kwargs):
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.cuda.current_device(),
             limit_all_gathers=True,
-            sync_module_states=True,
-            param_init_fn=None if rank == 0 else lambda module: module.to_empty(device=torch.device("cuda"), recurse=False),
+            sync_module_states=True if train_config.low_cpu_fsdp else False,
+            param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
+            if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             policies.apply_fsdp_checkpointing(model)
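
One subtlety worth noting about the added `param_init_fn` lines: `lambda` binds more loosely than a conditional expression, so the `if ... else None` ends up inside the lambda body rather than choosing between a lambda and `None` up front. The sketch below contrasts the two readings, using illustrative flag values rather than the repo's config:

```python
import torch

low_cpu_fsdp, rank = True, 0  # illustrative values for a rank-0 process

# How Python actually parses the added lines: the value is always a lambda,
# and the low_cpu_fsdp / rank check runs inside its body at call time.
param_init_fn = lambda module: module.to_empty(
    device=torch.device("cuda"), recurse=False
) if low_cpu_fsdp and rank != 0 else None

# The reading the line break suggests: choose a lambda or None once, up front.
param_init_fn_explicit = (
    (lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
    if low_cpu_fsdp and rank != 0
    else None
)

print(callable(param_init_fn))         # True  (a lambda that returns None on rank 0)
print(param_init_fn_explicit is None)  # True  (no callable at all on rank 0)
```

In practice both forms should behave the same in this script, since FSDP's documented use of `param_init_fn` is to initialize modules whose parameters are still on the meta device, and those only exist on non-zero ranks when `low_cpu_fsdp` is enabled.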
@@ -168,14 +170,14 @@ def main(**kwargs):
         model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
-
+
     # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
         tokenizer,
         dataset_config,
         split="train",
     )
-
+
     if not train_config.enable_fsdp or rank == 0:
         print(f"--> Training Set Length = {len(dataset_train)}")
 
@@ -202,7 +204,7 @@ def main(**kwargs):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
             )
-
+
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -224,7 +226,7 @@ def main(**kwargs):
             drop_last=True,
             collate_fn=default_data_collator,
         )
-
+
     # Initialize the optimizer and learning rate scheduler
     if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
         optimizer = AnyPrecisionAdamW(
@@ -246,7 +248,7 @@ def main(**kwargs):
     results = train(
         model,
         train_dataloader,
-        eval_dataloader,
+        eval_dataloader,
         tokenizer,
         optimizer,
         scheduler,