@@ -7,6 +7,7 @@ import yaml
 from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
+from datetime import datetime
 
 
 import torch
@@ -16,6 +17,7 @@ from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from tqdm import tqdm
 from transformers import LlamaTokenizer
+import json
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
@@ -55,13 +57,24 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"])
+        world_size = int(os.environ["WORLD_SIZE"])
+
+
+
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
 
     train_prep = []
     train_loss = []
     val_prep = []
     val_loss =[]
+
+    if train_config.save_metrics:
+        metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+        train_step_perplexity = []
+        train_step_loss = []
+        val_step_loss = []
+        val_step_perplexity = []
+
     epoch_times = []
     checkpoint_times = []
     results = {}
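The hunk above assumes a `save_metrics` flag on the training config, since it gates every metrics code path in this diff. A minimal sketch of the assumed companion config change, in llama-recipes' dataclass style (field names follow the usage in this diff; defaults and location are assumptions):

```python
# Sketch only: assumed addition to the training config, not shown in this diff.
from dataclasses import dataclass

@dataclass
class train_config:
    output_dir: str = "PATH/to/save"   # where metrics_data_*.json files land
    save_metrics: bool = False         # off by default; enables per-step capture and JSON dumps
```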
@@ -82,6 +95,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 with autocast():
                     loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
+                if train_config.save_metrics:
+                    train_step_loss.append(loss.detach().float().item())
+                    train_step_perplexity.append(float(torch.exp(loss.detach().float())))
                 total_loss += loss.detach().float()
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
@@ -111,6 +127,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         pbar.update(1)
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+
+                if train_config.save_metrics:
+                    save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
             pbar.close()
 
         epoch_end_time = time.perf_counter()-epoch_start_time
@@ -122,10 +141,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
-
-        train_prep.append(train_perplexity)
-        train_loss.append(train_epoch_loss)
-
+
+        train_prep.append(float(train_perplexity))
+        train_loss.append(float(train_epoch_loss))
+
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
@@ -144,7 +163,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            if train_config.save_metrics:
+                val_step_loss.extend(temp_val_loss)
+                val_step_perplexity.extend(temp_step_perplexity)
+
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
@@ -195,13 +218,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
                     print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
-            val_loss.append(best_val_loss)
-            val_prep.append(eval_ppl)
+            val_loss.append(float(best_val_loss))
+            val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+
+        # Saving the results every epoch to plot later
+        if train_config.save_metrics:
+            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+
     avg_epoch_time = sum(epoch_times)/ len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
@@ -217,6 +245,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
+    if train_config.save_metrics:
+        results["metrics_filename"] = metrics_filename
 
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
@@ -240,6 +270,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
         world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
+    val_step_loss = []
+    val_step_perplexity = []
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
@@ -253,6 +285,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 # Forward pass and compute loss
                 outputs = model(**batch)
                 loss = outputs.loss
+                if train_config.save_metrics:
+                    val_step_loss.append(loss.detach().float().item())
+                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))
+
                 eval_loss += loss.detach().float()
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
@@ -276,8 +312,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
-
-    return eval_ppl, eval_epoch_loss
+
+    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
 def freeze_transformer_layers(model, num_layer):
     for i, layer in enumerate(model.model.layers):
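Because `evaluation()` now returns four values, any call site beyond the one updated in `train()` above would need to unpack the per-step lists as well; a hypothetical standalone caller, for illustration only:

```python
# Hypothetical call site: the two trailing lists are empty unless
# train_config.save_metrics is enabled, so they can be discarded.
eval_ppl, eval_epoch_loss, _, _ = evaluation(
    model, train_config, eval_dataloader, local_rank, tokenizer
)
```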
@@ -417,3 +453,17 @@ def save_train_params(train_config, fsdp_config, rank):
            f.write(config_yaml)
        if rank==0:
            print(f"training params are saved in {file_name}")
+
+def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
+    metrics_data = {
+        "train_step_loss": train_step_loss,
+        "train_epoch_loss": train_epoch_loss,
+        "train_step_perplexity": train_step_ppl,
+        "train_epoch_perplexity": train_epoch_ppl,
+        "val_step_loss": val_step_loss,
+        "val_epoch_loss": val_epoch_loss,
+        "val_step_perplexity": val_step_ppl,
+        "val_epoch_perplexity": val_epoch_ppl
+    }
+    with open(output_filename, "w") as f:
+        json.dump(metrics_data, f)
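For reference, a minimal sketch of consuming a metrics file written by `save_to_json`; matplotlib and the example path are assumptions, not part of this change:

```python
# Sketch: plot per-step training/validation loss from the saved JSON.
import json
import matplotlib.pyplot as plt

# Example path; actual files are named metrics_data_<rank>-<timestamp>.json under output_dir.
with open("output_dir/metrics_data_0-2023-11-01_12-00-00.json") as f:
    metrics = json.load(f)

plt.plot(metrics["train_step_loss"], label="train step loss")
plt.plot(metrics["val_step_loss"], label="val step loss")
plt.xlabel("step")
plt.ylabel("loss")
plt.legend()
plt.savefig("loss_curves.png")
```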