@@ -5,6 +5,7 @@ import os
 import sys
 from typing import List
 import yaml
+import time
 
 import fire
 import torch
@@ -73,9 +74,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     train_loss = []
     val_prep = []
     val_loss =[]
+    epoch_times = []
+    checkpoint_times = []
     results = {}
     best_val_loss = float("inf")
     for epoch in range(train_config.num_epochs):
+        epoch_start_time = time.perf_counter()
         with MemoryTrace() as memtrace: # track the memory usage
             model.train()
             total_loss = 0.0
@@ -106,7 +110,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"\n step {step} is completed and loss is {loss.detach().float()}")
                 else:
                     print(f"\n step {step} is completed and loss is {loss.detach().float()}")
-
+        epoch_end_time = time.perf_counter() - epoch_start_time
+        epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -117,6 +122,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
+
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
@@ -136,6 +142,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
                     dist.barrier()
@@ -165,18 +172,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
-                            print(" Saving the FSDP model checkpoints qnd optimizer using SHARDED_STATE_DICT")
+                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                             print("=====================================================")
 
                     if not train_config.use_peft and train_config.save_optimizer:
                         model_checkpointing.save_optimizer_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
-                        print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
+                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
                         print("=====================================================")
                 if train_config.enable_fsdp:
                     dist.barrier()
-
+            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
+            checkpoint_times.append(checkpoint_end_time)
             if eval_epoch_loss < best_val_loss:
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
@@ -189,10 +197,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         if train_config.enable_fsdp:
             if rank==0:
-                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
+                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
-            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
-
+            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+    avg_epoch_time = sum(epoch_times)/len(epoch_times)
+    avg_checkpoint_time = sum(checkpoint_times)/len(checkpoint_times)
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
@@ -204,7 +213,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     if train_config.run_validation:
         results['avg_eval_prep'] = avg_eval_prep
         results['avg_eval_loss'] = avg_eval_loss
-
+    results["avg_epoch_time"] = avg_epoch_time
+    results["avg_checkpoint_time"] = avg_checkpoint_time
+
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)