@@ -6,6 +6,7 @@ import time
 import yaml
 from pathlib import Path
 from pkg_resources import packaging
+from datetime import datetime


 import torch
@@ -15,6 +16,8 @@ from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from tqdm import tqdm
 from transformers import LlamaTokenizer
+import matplotlib.pyplot as plt
+import json


 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
@@ -55,10 +58,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"])
+
+    metrics_filename = f"{train_config.output_dir}/metrics_data_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
     train_prep = []
     train_loss = []
+    train_step_perplexity = []
+    train_step_loss = []
+
     val_prep = []
     val_loss =[]
+    val_step_loss = []
+    val_step_perplexity = []
+
     epoch_times = []
     checkpoint_times = []
     results = {}
@@ -78,6 +89,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         batch[key] = batch[key].to('cuda:0')
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
+                train_step_loss.append(loss.detach().float().item())
+                train_step_perplexity.append(float(torch.exp(loss.detach().float())))
                 total_loss += loss.detach().float()
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
@@ -108,8 +121,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)

-        train_prep.append(train_perplexity)
-        train_loss.append(train_epoch_loss)
+        train_prep.append(float(train_perplexity))
+        train_loss.append(float(train_epoch_loss))

         if train_config.enable_fsdp:
             if rank==0:
@@ -129,7 +142,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         lr_scheduler.step()

         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            val_step_loss.extend(temp_val_loss)
+            val_step_perplexity.extend(temp_step_perplexity)
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
@@ -180,13 +195,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
                     print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
-            val_loss.append(best_val_loss)
-            val_prep.append(eval_ppl)
+            val_loss.append(float(best_val_loss))
+            val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+
+        # Saving the results every epoch to plot later
+        save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+
     avg_epoch_time = sum(epoch_times)/ len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
@@ -225,6 +244,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
         world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
+    val_step_loss = []
+    val_step_perplexity = []
     eval_loss = 0.0 # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
@@ -238,6 +259,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 # Forward pass and compute loss
                 outputs = model(**batch)
                 loss = outputs.loss
+                val_step_loss.append(loss.detach().float().item())
+                val_step_perplexity.append(float(torch.exp(loss.detach().float())))
                 eval_loss += loss.detach().float()
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
@@ -262,7 +285,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")

-    return eval_ppl, eval_epoch_loss
+    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity

 def freeze_transformer_layers(model, num_layer):
     for i, layer in enumerate(model.model.layers):
@@ -402,3 +425,17 @@ def save_train_params(train_config, fsdp_config, rank):
         f.write(config_yaml)
     if rank==0:
         print(f"training params are saved in {file_name}")
+
+def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
+    metrics_data = {
+        "train_step_loss": train_step_loss,
+        "train_epoch_loss": train_epoch_loss,
+        "train_step_perplexity": train_step_ppl,
+        "train_epoch_perplexity": train_epoch_ppl,
+        "val_step_loss": val_step_loss,
+        "val_epoch_loss": val_epoch_loss,
+        "val_step_perplexity": val_step_ppl,
+        "val_epoch_perplexity": val_epoch_ppl
+    }
+    with open(output_filename, "w") as f:
+        json.dump(metrics_data, f)
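
The hunks above import matplotlib.pyplot but never call it, and save_to_json only dumps the raw metric lists to JSON; any plotting presumably happens outside the changed lines. As a minimal sketch (not part of the patch), the saved file could be read back and plotted afterwards like this, using the key names defined in save_to_json and an assumed, illustrative file path:

import json
import matplotlib.pyplot as plt

# Hypothetical path: the real file name is timestamped, e.g.
# <output_dir>/metrics_data_2024-01-01_12-00-00.json
with open("output_dir/metrics_data_2024-01-01_12-00-00.json") as f:
    metrics = json.load(f)

fig, (ax_loss, ax_ppl) = plt.subplots(1, 2, figsize=(10, 4))

# Per-step curves; keys come from save_to_json in the diff above
ax_loss.plot(metrics["train_step_loss"], label="train_step_loss")
ax_loss.plot(metrics["val_step_loss"], label="val_step_loss")
ax_ppl.plot(metrics["train_step_perplexity"], label="train_step_perplexity")
ax_ppl.plot(metrics["val_step_perplexity"], label="val_step_perplexity")

ax_loss.set_xlabel("step")
ax_loss.set_ylabel("loss")
ax_ppl.set_xlabel("step")
ax_ppl.set_ylabel("perplexity")
ax_loss.legend()
ax_ppl.legend()
fig.tight_layout()
fig.savefig("metrics.png")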