
adding wandb_run to eval

Hamid Shojanazeri 1 year ago
commit 761b7e6e51
1 changed file with 2 additions and 2 deletions:
  1. src/llama_recipes/utils/train_utils.py (+2, -2)

src/llama_recipes/utils/train_utils.py  +2 -2

@@ -169,7 +169,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
@@ -492,4 +492,4 @@ def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_
         "val_epoch_perplexity": val_epoch_ppl
         "val_epoch_perplexity": val_epoch_ppl
     }
     }
     with open(output_filename, "w") as f:
     with open(output_filename, "w") as f:
-        json.dump(metrics_data, f)
+        json.dump(metrics_data, f)
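
For context, this commit threads the existing wandb_run handle through to evaluation() so validation metrics can be logged to the same Weights & Biases run as the training metrics. Below is a minimal sketch of what the updated evaluation() might look like, assuming wandb_run is the object returned by wandb.init(); the loop body and metric names are illustrative, not the repo's exact code.

# Hypothetical sketch (not the exact implementation in train_utils.py) of
# evaluation() after this change: it accepts an optional wandb_run handle,
# assumed to be the run object returned by wandb.init(), and logs eval metrics.
import torch

def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run=None):
    model.eval()
    val_step_loss, val_step_perplexity = [], []
    with torch.no_grad():
        for batch in eval_dataloader:
            outputs = model(**batch)
            loss = outputs.loss.detach().float()
            val_step_loss.append(loss.item())
            val_step_perplexity.append(float(torch.exp(loss)))
    eval_epoch_loss = sum(val_step_loss) / len(val_step_loss)
    eval_ppl = float(torch.exp(torch.tensor(eval_epoch_loss)))
    if wandb_run:
        # New in this commit: eval metrics go to the same W&B run as training metrics.
        wandb_run.log({"eval/perplexity": eval_ppl, "eval/loss": eval_epoch_loss})
    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity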