Bladeren bron

Modularizing the graphs further. Plotting train and validation separetly due to difference in number of steps.

Beto 1 jaar geleden
bovenliggende
commit
5a011246e1
1 gewijzigde bestanden met toevoegingen van 27 en 9 verwijderingen
  1. 27 9
      examples/plot_metrics.py

+ 27 - 9
examples/plot_metrics.py

@@ -4,24 +4,34 @@ import argparse
 import os
 
 def plot_metric(data, metric_name, x_label, y_label, title, colors):
-    plt.figure(figsize=(14, 6))
-    plt.subplot(1, 2, 1)
-    plt.plot(data[f'train_step_{metric_name}'], label=f'Train Step {metric_name.capitalize()}', color=colors[0])
-    plt.plot(data[f'val_step_{metric_name}'], label=f'Validation Step {metric_name.capitalize()}', color=colors[1])
+    plt.figure(figsize=(7, 6))
+    
+    plt.plot(data[f'train_epoch_{metric_name}'], label=f'Train Epoch {metric_name.capitalize()}', color=colors[0])
+    plt.plot(data[f'val_epoch_{metric_name}'], label=f'Validation Epoch {metric_name.capitalize()}', color=colors[1])
     plt.xlabel(x_label)
     plt.ylabel(y_label)
-    plt.title(f'Train and Validation Step {title}')
+    plt.title(f'Train and Validation Epoch {title}')
     plt.legend()
+    plt.tight_layout()
 
-    plt.subplot(1, 2, 2)
-    plt.plot(data[f'train_epoch_{metric_name}'], label=f'Train Epoch {metric_name.capitalize()}', color=colors[0])
-    plt.plot(data[f'val_epoch_{metric_name}'], label=f'Validation Epoch {metric_name.capitalize()}', color=colors[1])
+def plot_single_metric_by_step(data, metric_name, x_label, y_label, title, color):
+    plt.plot(data[f'{metric_name}'], label=f'{title}', color=color)
     plt.xlabel(x_label)
     plt.ylabel(y_label)
-    plt.title(f'Train and Validation Epoch {title}')
+    plt.title(title)
     plt.legend()
     plt.tight_layout()
 
+def plot_metrics_by_step(data, metric_name, x_label, y_label, colors):
+    plt.figure(figsize=(14, 6))
+
+    plt.subplot(1, 2, 1)
+    plot_single_metric_by_step(data, f'train_step_{metric_name}', x_label, y_label, f'Train Step {metric_name.capitalize()}', colors[0])
+    plt.subplot(1, 2, 2)
+    plot_single_metric_by_step(data, f'val_step_{metric_name}', x_label, y_label, f'Validation Step {metric_name.capitalize()}', colors[1])
+    plt.tight_layout()
+
+    
 def plot_metrics(file_path):
     if not os.path.exists(file_path):
         print(f"File {file_path} does not exist.")
@@ -45,6 +55,14 @@ def plot_metrics(file_path):
     plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity.png"))
     plt.close()
 
+    plot_metrics_by_step(data, 'loss', 'Step', 'Loss', ['b', 'r'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss_by_step.png"))
+    plt.close()
+
+    plot_metrics_by_step(data, 'perplexity', 'Step', 'Loss', ['g', 'm'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity_by_step.png"))
+    plt.close()
+    
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Plot metrics from JSON file.')
     parser.add_argument('file_path', type=str, help='Path to the metrics JSON file.')