@@ -59,19 +59,23 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"])
 
-    metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+
+
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
 
     train_prep = []
     train_loss = []
-    train_step_perplexity = []
-    train_step_loss = []
-
     val_prep = []
     val_loss =[]
-    val_step_loss = []
-    val_step_perplexity = []
 
+    if train_config.save_metrics:
+        print(f"Save metrics TRUE {train_config.save_metrics}")
+        metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+        train_step_perplexity = []
+        train_step_loss = []
+        val_step_loss = []
+        val_step_perplexity = []
+
     epoch_times = []
     checkpoint_times = []
     results = {}
@@ -112,6 +116,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     pbar.update(1)
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+
+                if train_config.save_metrics:
+                    save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
             pbar.close()
 
         epoch_end_time = time.perf_counter()-epoch_start_time
@@ -146,8 +153,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
-            val_step_loss.extend(temp_val_loss)
-            val_step_perplexity.extend(temp_step_perplexity)
+            if train_config.save_metrics:
+                val_step_loss.extend(temp_val_loss)
+                val_step_perplexity.extend(temp_step_perplexity)
+
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
@@ -207,7 +216,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
 
         # Saving the results every epoch to plot later
-        save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+        if train_config.save_metrics:
+            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
 
     avg_epoch_time = sum(epoch_times)/ len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
@@ -262,8 +272,10 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             # Forward pass and compute loss
             outputs = model(**batch)
             loss = outputs.loss
-            val_step_loss.append(loss.detach().float().item())
-            val_step_perplexity.append(float(torch.exp(loss.detach().float())))
+            if train_config.save_metrics:
+                val_step_loss.append(loss.detach().float().item())
+                val_step_perplexity.append(float(torch.exp(loss.detach().float())))
+
             eval_loss += loss.detach().float()
         # Decode predictions and add to evaluation predictions list
         preds = torch.argmax(outputs.logits, -1)
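
Note: the `save_to_json` helper that the guarded calls above invoke is not part of these hunks. A minimal sketch of what it could look like, assuming it simply serializes the eight metric lists under correspondingly named keys; the key names, the `json` import, and the parameter names other than the call-site arguments are assumptions, not part of this patch:

import json

def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl,
                 val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
    # Bundle per-step and per-epoch metrics into one dict; key names are assumed.
    metrics_data = {
        "train_step_loss": train_step_loss,
        "train_epoch_loss": train_epoch_loss,
        "train_step_perplexity": train_step_ppl,
        "train_epoch_perplexity": train_epoch_ppl,
        "val_step_loss": val_step_loss,
        "val_epoch_loss": val_epoch_loss,
        "val_step_perplexity": val_step_ppl,
        "val_epoch_perplexity": val_epoch_ppl,
    }
    # Overwrite the per-rank metrics file on each call so it always holds the latest lists.
    with open(output_filename, "w") as f:
        json.dump(metrics_data, f)

Under that assumption, the file named `metrics_data_{local_rank}-{timestamp}.json` is rewritten after every training step and every epoch, so it always contains the full metric history so far, matching the "plot later" comment in the patch.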