|
@@ -169,7 +169,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
|
|
lr_scheduler.step()
|
|
lr_scheduler.step()
|
|
|
|
|
|
if train_config.run_validation:
|
|
if train_config.run_validation:
|
|
- eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
|
|
|
|
|
|
+ eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
|
|
if train_config.save_metrics:
|
|
if train_config.save_metrics:
|
|
val_step_loss.extend(temp_val_loss)
|
|
val_step_loss.extend(temp_val_loss)
|
|
val_step_perplexity.extend(temp_step_perplexity)
|
|
val_step_perplexity.extend(temp_step_perplexity)
|
|
@@ -492,4 +492,4 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
|
|
"val_epoch_perplexity": val_epoch_ppl
|
|
"val_epoch_perplexity": val_epoch_ppl
|
|
}
|
|
}
|
|
with open(output_filename, "w") as f:
|
|
with open(output_filename, "w") as f:
|
|
- json.dump(metrics_data, f)
|
|
|
|
|
|
+ json.dump(metrics_data, f)
|