Skip to content

Commit 07503bf

Browse files
committed
fix for jax implementation of sfadamw
1 parent db4bcd1 commit 07503bf

File tree

13 files changed

+350
-181
lines changed

13 files changed

+350
-181
lines changed

algoperf/workloads/fastmri/input_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212
from algoperf import data_utils
1313

14-
_TRAIN_DIR = 'fastmri/knee_singlecoil_train'
15-
_VAL_DIR = 'fastmri/knee_singlecoil_val'
14+
_TRAIN_DIR = 'knee_singlecoil_train'
15+
_VAL_DIR = 'knee_singlecoil_val'
1616
_EVAL_SEED = 0
1717

1818

algoperf/workloads/imagenet_resnet/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,4 +145,4 @@ def _build_input_queue(
145145
@property
146146
def step_hint(self) -> int:
147147
"""Approx. steps the baseline can do in the allowed runtime budget."""
148-
return 195_999
148+
return 186_666

algoperf/workloads/imagenet_vit/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,4 @@ def _build_dataset(
121121
@property
122122
def step_hint(self) -> int:
123123
"""Approx. steps the baseline can do in the allowed runtime budget."""
124-
return 167_999
124+
return 186_666

algoperf/workloads/librispeech_conformer/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,4 @@ def eval_period_time_sec(self) -> int:
8989
@property
9090
def step_hint(self) -> int:
9191
"""Approx. steps the baseline can do in the allowed runtime budget."""
92-
return 76_000
92+
return 80_000

algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_target_value(self) -> float:
9696
@property
9797
def step_hint(self) -> int:
9898
"""Approx. steps the baseline can do in the allowed runtime budget."""
99-
return 38_400
99+
return 48_000
100100

101101
@property
102102
def max_allowed_runtime_sec(self) -> int:

algoperf/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def test_target_value(self) -> float:
9292
@property
9393
def step_hint(self) -> int:
9494
"""Approx. steps the baseline can do in the allowed runtime budget."""
95-
return 38_400
95+
return 48_000
9696

9797
@property
9898
def max_allowed_runtime_sec(self) -> int:

algoperf/workloads/ogbg/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def loss_fn(
144144
@property
145145
def step_hint(self) -> int:
146146
"""Approx. steps the baseline can do in the allowed runtime budget."""
147-
return 52_000
147+
return 80_000
148148

149149
@abc.abstractmethod
150150
def _normalize_eval_metrics(

algoperf/workloads/wmt/workload.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def eval_period_time_sec(self) -> int:
9898
@property
9999
def step_hint(self) -> int:
100100
"""Approx. steps the baseline can do in the allowed runtime budget."""
101-
return 120_000
101+
return 133_333
102102

103103
@property
104104
def pre_ln(self) -> bool:

docker/scripts/plot.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# plot_results.py
2+
3+
import pandas as pd
4+
import matplotlib.pyplot as plt
5+
from pathlib import Path
6+
import collections
7+
8+
# --- Configuration ---
9+
# The base directory to start searching from.
10+
# The script assumes it's run from the 'experiment_runs' directory.
11+
SEARCH_DIR = Path("~/experiment_runs/tests/regression_tests/adamw").expanduser()
12+
13+
# The directory where the output plots will be saved.
14+
OUTPUT_DIR = Path("ssim_plots")
15+
16+
# The columns to use for the x and y axes.
17+
X_AXIS_COL = "global_step"
18+
Y_AXIS_CANDIDATES = ["validation/loss", "validation/ctc_loss"]
19+
# ---------------------
20+
21+
def generate_plots():
22+
"""
23+
Finds all 'measurements.csv' files, groups them by workflow,
24+
and generates a JAX vs. PyTorch plot for each.
25+
"""
26+
# Create the output directory if it doesn't already exist
27+
OUTPUT_DIR.mkdir(exist_ok=True)
28+
print(f"📊 Plots will be saved to the '{OUTPUT_DIR}' directory.")
29+
30+
# Use a dictionary to group file paths by their workflow name
31+
# e.g., {'fastmri': [...], 'wmt': [...]}
32+
workflow_files = collections.defaultdict(list)
33+
34+
# Recursively find all 'measurements.csv' files in the search directory
35+
for csv_path in SEARCH_DIR.rglob("measurements.csv"):
36+
try:
37+
# The directory name looks like 'fastmri_jax' or 'wmt_pytorch'.
38+
# We get this from the parent of the parent of the csv file.
39+
# e.g., .../fastmri_jax/trial_1/measurements.csv
40+
workflow_framework_name = csv_path.parent.parent.name
41+
42+
# Split the name to get the framework (last part) and workflow (everything else)
43+
parts = workflow_framework_name.split('_')
44+
framework = parts[-1]
45+
workflow = '_'.join(parts[:-1])
46+
47+
# Store the path and framework for this workflow
48+
if framework in ['jax', 'pytorch']:
49+
workflow_files[workflow].append({'path': csv_path, 'framework': framework})
50+
51+
except IndexError:
52+
# This handles cases where the directory name might not match the expected pattern
53+
print(f"⚠️ Could not parse workflow/framework from path: {csv_path}")
54+
continue
55+
56+
if not workflow_files:
57+
print("❌ No 'measurements.csv' files found. Check the SEARCH_DIR variable and your folder structure.")
58+
return
59+
60+
print(f"\nFound {len(workflow_files)} workflows. Generating plots...")
61+
62+
# Iterate through each workflow and its associated files to create a plot
63+
for workflow, files in workflow_files.items():
64+
plt.style.use('seaborn-v0_8-whitegrid')
65+
fig, ax = plt.subplots(figsize=(12, 7))
66+
67+
print(f" -> Processing workflow: '{workflow}'")
68+
69+
y_axis_col_used = None # To store the name of the y-axis column for the plot labels
70+
71+
# Plot data for each framework (JAX and PyTorch) on the same figure
72+
for item in files:
73+
try:
74+
df = pd.read_csv(item['path'])
75+
76+
y_axis_col = None
77+
for candidate in Y_AXIS_CANDIDATES:
78+
if candidate in df.columns:
79+
y_axis_col = candidate
80+
if not y_axis_col_used:
81+
y_axis_col_used = y_axis_col # Set the label from the first file
82+
break # Found a valid column, no need to check further
83+
84+
# if item['framework'] == 'jax':
85+
# y_axis_col = None
86+
87+
# Check if the required columns exist in the CSV
88+
if X_AXIS_COL in df.columns and y_axis_col:
89+
90+
# 1. Forward-fill 'global_step' to propagate the last valid step downwards.
91+
df[X_AXIS_COL] = df[X_AXIS_COL].ffill()
92+
93+
# 2. Drop any rows where 'validation/ssim' is empty (NaN).
94+
df_cleaned = df.dropna(subset=[y_axis_col])
95+
96+
# Plot the cleaned data
97+
ax.plot(
98+
df_cleaned[X_AXIS_COL],
99+
df_cleaned[y_axis_col],
100+
label=item['framework'].capitalize(), # e.g., 'Jax'
101+
marker='.',
102+
linestyle='-',
103+
alpha=0.8
104+
)
105+
else:
106+
print(f" - Skipping {item['path']} (missing required columns).")
107+
108+
except Exception as e:
109+
print(f" - ❗️ Error reading {item['path']}: {e}")
110+
111+
# Customize and save the plot
112+
ax.set_title(f'Validation loss vs. Global Step for {workflow.replace("_", " ").title()}', fontsize=16)
113+
ax.set_xlabel("Global Step", fontsize=12)
114+
ax.set_ylabel("Validation loss", fontsize=12)
115+
ax.legend(title="Framework", fontsize=10)
116+
plt.tight_layout()
117+
plt.yscale('log')
118+
119+
# Define the output filename and save the figure
120+
output_filename = OUTPUT_DIR / f"{workflow}_comparison.png"
121+
plt.savefig(output_filename, dpi=150)
122+
plt.close(fig) # Close the figure to free up memory
123+
124+
print("\n✅ All plots generated successfully!")
125+
126+
127+
if __name__ == "__main__":
128+
generate_plots()

submission_runner.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
import importlib
2020
import itertools
2121
import json
22-
import jax
2322
import os
2423
import struct
2524
import time
25+
import optax
2626
from inspect import signature
2727
from types import MappingProxyType
2828
from typing import Any, Dict, Optional, Tuple
@@ -861,9 +861,11 @@ def main(_):
861861

862862

863863
if __name__ == '__main__':
864+
print(optax.__version__)
865+
print("!!!!")
864866
flags.mark_flag_as_required('workload')
865867
flags.mark_flag_as_required('framework')
866868
flags.mark_flag_as_required('submission_path')
867869
flags.mark_flag_as_required('experiment_dir')
868870
flags.mark_flag_as_required('experiment_name')
869-
app.run(main)
871+
app.run(main)

0 commit comments

Comments
 (0)