plot_metrics.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import json
  2. import matplotlib.pyplot as plt
  3. import sys
  4. import os
  5. def plot_metrics(file_path):
  6. # Read the JSON file
  7. with open(file_path, 'r') as f:
  8. data = json.load(f)
  9. # Get directory and filename information
  10. directory = os.path.dirname(file_path)
  11. filename_prefix = os.path.basename(file_path).split('.')[0]
  12. # Plotting metrics for training and validation step loss
  13. plt.figure(figsize=(14, 6))
  14. plt.subplot(1, 2, 1)
  15. plt.plot(data['train_step_loss'], label='Train Step Loss', color='b')
  16. plt.plot(data['val_step_loss'], label='Validation Step Loss', color='r')
  17. plt.xlabel('Step')
  18. plt.ylabel('Loss')
  19. plt.title('Train and Validation Step Loss')
  20. plt.legend()
  21. # Plotting metrics for training and validation epoch loss
  22. plt.subplot(1, 2, 2)
  23. plt.plot(data['train_epoch_loss'], label='Train Epoch Loss', color='b')
  24. plt.plot(data['val_epoch_loss'], label='Validation Epoch Loss', color='r')
  25. plt.xlabel('Epoch')
  26. plt.ylabel('Loss')
  27. plt.title('Train and Validation Epoch Loss')
  28. plt.legend()
  29. plt.tight_layout()
  30. plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss.png"))
  31. plt.close()
  32. # Plotting perplexity
  33. plt.figure(figsize=(14, 6))
  34. plt.subplot(1, 2, 1)
  35. plt.plot(data['train_step_perplexity'],
  36. label='Train Step Perplexity', color='g')
  37. plt.plot(data['val_step_perplexity'],
  38. label='Validation Step Perplexity', color='m')
  39. plt.xlabel('Step')
  40. plt.ylabel('Perplexity')
  41. plt.title('Train and Validation Step Perplexity')
  42. plt.legend()
  43. plt.subplot(1, 2, 2)
  44. plt.plot(data['train_epoch_perplexity'],
  45. label='Train Epoch Perplexity', color='g')
  46. plt.plot(data['val_epoch_perplexity'],
  47. label='Validation Epoch Perplexity', color='m')
  48. plt.xlabel('Epoch')
  49. plt.ylabel('Perplexity')
  50. plt.title('Train and Validation Epoch Perplexity')
  51. plt.legend()
  52. plt.tight_layout()
  53. plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity.png"))
  54. plt.close()
  55. if __name__ == "__main__":
  56. if len(sys.argv) < 2:
  57. print("Usage: python script.py <path_to_metrics_json>")
  58. sys.exit(1)
  59. file_path = sys.argv[1]
  60. plot_metrics(file_path)