1+ # plot_results.py
2+
3+ import pandas as pd
4+ import matplotlib .pyplot as plt
5+ from pathlib import Path
6+ import collections
7+
# --- Configuration ---
# Root directory that is searched recursively for 'measurements.csv' files.
# The leading '~' is expanded to the current user's home directory, so the
# location does not depend on the working directory the script is run from.
SEARCH_DIR: Path = Path("~/experiment_runs/tests/regression_tests/adamw").expanduser()

# Destination directory for the generated plots (relative to the current
# working directory).
OUTPUT_DIR: Path = Path("ssim_plots")

# CSV column plotted on the x-axis.
X_AXIS_COL: str = "global_step"
# CSV columns tried, in order, for the y-axis; the first one present in a
# given file is the one that gets plotted.
Y_AXIS_CANDIDATES: list = ["validation/loss", "validation/ctc_loss"]
# ---------------------
20+
def _split_run_dir_name(dir_name):
    """Split a run directory name into ``(workflow, framework)``.

    The framework is the text after the last underscore and the workflow is
    everything before it, e.g. ``'fastmri_jax'`` -> ``('fastmri', 'jax')``.

    Returns:
        A ``(workflow, framework)`` tuple, or ``None`` when the name has no
        underscore, has an empty workflow part, or its framework is neither
        ``'jax'`` nor ``'pytorch'``.
    """
    workflow, sep, framework = dir_name.rpartition("_")
    if not sep or not workflow or framework not in ("jax", "pytorch"):
        return None
    return workflow, framework


def _pick_y_column(columns):
    """Return the first entry of Y_AXIS_CANDIDATES found in *columns*, else None."""
    return next((name for name in Y_AXIS_CANDIDATES if name in columns), None)


def _plot_run(ax, item):
    """Read one measurements CSV and draw its curve on *ax*.

    Args:
        ax: Matplotlib axes to draw on.
        item: Dict with the keys ``'path'`` (CSV location) and ``'framework'``.

    Returns:
        The name of the y-axis column that was plotted, or ``None`` when the
        CSV lacks X_AXIS_COL or every column in Y_AXIS_CANDIDATES.
    """
    df = pd.read_csv(item["path"])

    y_axis_col = _pick_y_column(df.columns)
    if y_axis_col is None or X_AXIS_COL not in df.columns:
        return None

    # Forward-fill the step column so rows that only carry a validation
    # value inherit the last step seen above them.
    df[X_AXIS_COL] = df[X_AXIS_COL].ffill()

    # Keep only the rows that actually have a value in the y-axis column.
    df_cleaned = df.dropna(subset=[y_axis_col])

    ax.plot(
        df_cleaned[X_AXIS_COL],
        df_cleaned[y_axis_col],
        label=item["framework"].capitalize(),  # e.g., 'Jax'
        marker=".",
        linestyle="-",
        alpha=0.8,
    )
    return y_axis_col


def generate_plots():
    """Plot validation loss against global step for every workflow found.

    Recursively collects the 'measurements.csv' files below SEARCH_DIR, groups
    them by workflow (taken from the name of the directory two levels above
    each file, e.g. ``.../fastmri_jax/trial_1/measurements.csv``), draws the
    runs of each workflow on one log-scaled figure and saves it to
    ``OUTPUT_DIR / '<workflow>_comparison.png'``.

    Files whose directory name does not end in ``_jax`` or ``_pytorch`` are
    reported and ignored. Files that cannot be read or lack the required
    columns are reported and skipped. A workflow for which nothing could be
    plotted produces no image.
    """
    # Create the output directory (and any missing parents) if needed.
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    print(f"📊 Plots will be saved to the '{OUTPUT_DIR}' directory.")

    # Group file paths by workflow name, e.g. {'fastmri': [...], 'wmt': [...]}.
    workflow_files = collections.defaultdict(list)

    # rglob order depends on the filesystem; sorting keeps the plotting and
    # legend order stable between runs.
    for csv_path in sorted(SEARCH_DIR.rglob("measurements.csv")):
        parsed = _split_run_dir_name(csv_path.parent.parent.name)
        if parsed is None:
            print(f"⚠️ Could not parse workflow/framework from path: {csv_path}")
            continue
        workflow, framework = parsed
        workflow_files[workflow].append({"path": csv_path, "framework": framework})

    if not workflow_files:
        print("❌ No 'measurements.csv' files found. Check the SEARCH_DIR variable and your folder structure.")
        return

    print(f"\nFound {len(workflow_files)} workflows. Generating plots...")

    # The style is the same for every figure, so set it once up front.
    plt.style.use("seaborn-v0_8-whitegrid")

    saved_count = 0
    for workflow, files in workflow_files.items():
        print(f"  -> Processing workflow: '{workflow}'")
        fig, ax = plt.subplots(figsize=(12, 7))

        # Name of the y-axis column of the first run that was plotted; it is
        # used as the axis label and tells us whether anything was drawn.
        y_axis_col_used = None

        for item in files:
            try:
                plotted_col = _plot_run(ax, item)
            except Exception as e:
                # Deliberately best-effort: one unreadable or malformed file
                # must not abort the remaining runs and workflows.
                print(f"     - ❗️ Error reading {item['path']}: {e}")
                continue

            if plotted_col is None:
                print(f"     - Skipping {item['path']} (missing required columns).")
            elif y_axis_col_used is None:
                y_axis_col_used = plotted_col

        if y_axis_col_used is None:
            # Nothing was drawn: do not write an empty image.
            print(f"     - No plottable data for '{workflow}'; no plot saved.")
            plt.close(fig)
            continue

        # Customize the plot. The log scale is applied before tight_layout so
        # the layout is computed for the final tick labels.
        ax.set_title(
            f'Validation loss vs. Global Step for {workflow.replace("_", " ").title()}',
            fontsize=16,
        )
        ax.set_xlabel("Global Step", fontsize=12)
        ax.set_ylabel(y_axis_col_used, fontsize=12)
        ax.set_yscale("log")
        ax.legend(title="Framework", fontsize=10)
        fig.tight_layout()

        # Save the figure, then close it to free its memory.
        output_filename = OUTPUT_DIR / f"{workflow}_comparison.png"
        fig.savefig(output_filename, dpi=150)
        plt.close(fig)
        saved_count += 1

    skipped_count = len(workflow_files) - saved_count
    if skipped_count:
        print(f"\n⚠️ Generated {saved_count} plots; {skipped_count} workflows had no plottable data.")
    else:
        print("\n✅ All plots generated successfully!")
125+
126+
# Generate the plots only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    generate_plots()
0 commit comments