Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions paper/experiments/spherical_mnist_reconstruction/.gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
figures/
figures/*
!figures/paper_orbits.pdf
__pycache__/
*.png
*.pdf
*.json
*.log
31 changes: 20 additions & 11 deletions paper/experiments/spherical_mnist_reconstruction/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,30 @@ Default settings (`lmax=12`, `n_steps=8000`, `n_recon_restarts=4`,
For a faster sanity check use `--n_digits 3 --n_recon_restarts 1
--n_steps 3000 --align_n_restarts 4`.

To regenerate **only the compact 2x2 paper figure** (~5 min), skip the
comprehensive sweep:
To regenerate **only the compact paper figure** (`2 x len(--paper_digits)`,
full-width on a NeurIPS column), skip the comprehensive sweep:

```bash
python reconstruct.py --paper_only --paper_digits 0 1
python reconstruct.py --paper_only --paper_digits 0 1 2 3 4 5
```

This caches all the recon tensors to `figures/state.pt`. Iterate on the
figure layout *without* re-running the optimization:

```bash
python reconstruct.py --paper_only --paper_digits 0 1 2 3 4 5 --replot
```

Outputs land in `figures/`:

| File | Description |
|---------------------------------|----------------------------------------------------------------------|
| `orbits.{pdf,png}` | `(n_digits x 3(1+K))` grid: target / raw recon / aligned recon spheres |
| `paper_orbits.{pdf,png}` | Compact `(len(--paper_digits) x 2)` figure for the NeurIPS paper |
| `convergence.{pdf,png}` | Median + IQR of relative bispectrum residual vs. step |
| `invariance_vs_recon.{pdf,png}` | Per-pair scatter of invariance vs. recon residual |
| `results.json` | All scalar metrics + per-step traces |
| File | Description |
|---------------------------------|--------------------------------------------------------------------------------------|
| `orbits.{pdf,png}` | `(n_digits x 3(1+K))` grid: target / raw recon / aligned recon spheres |
| `paper_orbits.{pdf,png}` | Compact `(2 rows x len(--paper_digits) cols)` figure for the NeurIPS paper |
| `state.pt` | Cached `ReconResult` tensors for fast `--replot` iteration |
| `convergence.{pdf,png}` | Median + IQR of relative bispectrum residual vs. step |
| `invariance_vs_recon.{pdf,png}` | Per-pair scatter of invariance vs. recon residual |
| `results.json` | All scalar metrics + per-step traces |

## Key CLI arguments

Expand All @@ -79,9 +87,10 @@ Outputs land in `figures/`:
| `--view_size` | `128` | Orthographic view resolution |
| `--elev_deg / --azim_deg` | `25/30` | Fallback / fixed-view camera direction (degrees) |
| `--fixed_view` | off | Disable per-panel auto-centering on the signal centroid (use a single shared camera direction instead) |
| `--paper_digits` | `0 1` | `digit_idx` values to use for the compact `paper_orbits.pdf` figure |
| `--paper_digits` | `0 1 2 3 4 5` | `digit_idx` values to use as columns of the compact `paper_orbits.pdf` figure |
| `--paper_figure_path` | auto | Override output path; defaults to `<output_dir>/paper_orbits.pdf` |
| `--paper_only` | off | Run only the digits in `--paper_digits` and emit only the paper figure (fast regeneration path) |
| `--replot` | off | Skip recon + alignment, rebuild figures from `<output_dir>/state.pt` (~3 s) |
| `--full_bispectrum` | off | `O(L^3)` full bispectrum instead of selective `O(L^2)` |
| `--no_bandlimit_project` | off | Disable the per-step `IRealSHT(RealSHT(.))` projection |
| `--seed` | `0` | Controls digit selection, rotations, Gaussian init, and alignment seeds |
Expand Down
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
133 changes: 80 additions & 53 deletions paper/experiments/spherical_mnist_reconstruction/reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1063,17 +1063,19 @@ def make_paper_figure(
digit_indices: list[int],
path: Path,
view_size: int = 192,
panel_size: float = 1.9,
panel_size: float = 1.15,
auto_center: bool = True,
elev_deg: float = 25.0,
azim_deg: float = 30.0,
) -> None:
"""Compact 2-row x 2-col figure for the NeurIPS paper.

For each requested ``digit_idx`` we pick the rotation with the lowest
aligned image residual and render ``[target | aligned recon]`` as a row.
The result is a clean ``len(digit_indices) x 2`` grid suitable for a
``\\includegraphics[width=0.6\\linewidth]`` inclusion.
"""Compact ``2 x len(digit_indices)`` figure for the NeurIPS paper.

Rows: ``Target f`` / ``Recon \\hat R \\cdot \\hat f``. Columns: one per
requested ``digit_idx`` (we keep the rotation with the lowest aligned
image residual). Both rows of a column share the *target's* camera so
any visible mismatch is a true reconstruction error rather than a
viewpoint difference. Sized for ``\\includegraphics[width=\\linewidth]``
inclusion in a NeurIPS column.
"""
plt.rcParams.update(NEURIPS_RCPARAMS)

Expand All @@ -1090,12 +1092,12 @@ def make_paper_figure(
best = min(by_digit[d], key=lambda r: r.aligned_image_space_rel)
picks.append(best)

n_rows = len(picks)
n_cols = 2
n_cols = len(picks)
n_rows = 2
fig, axes = plt.subplots(
n_rows,
n_cols,
figsize=(panel_size * n_cols, panel_size * n_rows),
figsize=(panel_size * n_cols + 0.6, panel_size * n_rows + 0.5),
squeeze=False,
)

Expand All @@ -1107,7 +1109,7 @@ def _view_for(arr: np.ndarray) -> tuple[float, float, list[tuple[np.ndarray, np.
e, a = _signal_view_angle(arr, fallback_elev=elev_deg, fallback_azim=azim_deg)
return e, a, _sphere_grid_lines(e, a)

for i, r in enumerate(picks):
for j, r in enumerate(picks):
target = r.target.numpy()
aligned = r.aligned.numpy()
vmax = max(abs(target).max(), abs(aligned).max(), 1e-8)
Expand All @@ -1116,33 +1118,38 @@ def _view_for(arr: np.ndarray) -> tuple[float, float, list[tuple[np.ndarray, np.
e_t, a_t, gl_t = _view_for(target)

_draw_sphere(
axes[i][0], target, vmin, vmax, view_size, e_t, a_t, grid_lines=gl_t
axes[0][j], target, vmin, vmax, view_size, e_t, a_t, grid_lines=gl_t
)
_draw_sphere(
axes[i][1], aligned, vmin, vmax, view_size, e_t, a_t, grid_lines=gl_t
axes[1][j], aligned, vmin, vmax, view_size, e_t, a_t, grid_lines=gl_t
)

if i == 0:
axes[i][0].set_title('Target $f$', fontsize=10, pad=4)
axes[i][1].set_title(
r'Recon $\hat{R}\!\cdot\!\hat{f}$', fontsize=10, pad=4
)
axes[i][0].set_ylabel(
f'class {r.label}',
rotation=0,
ha='right',
va='center',
fontsize=9,
labelpad=10,
)
axes[i][1].set_xlabel(
rf'$\|\hat{{R}}\!\cdot\!\hat{{f}}-f\|/\|f\|={r.aligned_image_space_rel:.2f}$',
fontsize=7,
axes[0][j].set_title(f'class {r.label}', fontsize=9, pad=3)
axes[1][j].set_xlabel(
rf'${r.aligned_image_space_rel:.2f}$',
fontsize=8,
labelpad=2,
)

axes[0][0].set_ylabel(
'Target $f$',
rotation=90,
ha='center',
va='center',
fontsize=10,
labelpad=8,
)
axes[1][0].set_ylabel(
r'Recon $\hat{R}\!\cdot\!\hat{f}$',
rotation=90,
ha='center',
va='center',
fontsize=10,
labelpad=8,
)

fig.tight_layout(pad=0.3)
fig.subplots_adjust(wspace=0.02, hspace=0.08)
fig.subplots_adjust(wspace=0.04, hspace=0.06)
fig.savefig(path)
fig.savefig(str(path).replace('.pdf', '.png'))
plt.close(fig)
Expand Down Expand Up @@ -1230,15 +1237,21 @@ def parse_args() -> argparse.Namespace:
parser.add_argument('--fixed_view', action='store_true',
help='Use --elev_deg / --azim_deg for every panel instead of '
'auto-centering on each signal\'s positive-mass centroid.')
parser.add_argument('--paper_digits', type=int, nargs='+', default=[0, 1],
help='Digit indices used for the compact paper figure (2x2).')
parser.add_argument('--paper_digits', type=int, nargs='+',
default=[0, 1, 2, 3, 4, 5],
help='Digit indices used for the compact paper figure '
'(rows = Target/Recon, columns = digits).')
parser.add_argument('--paper_figure_path', type=Path, default=None,
help='Override output path for the paper figure '
'(default: <output_dir>/paper_orbits.pdf).')
parser.add_argument('--paper_only', action='store_true',
help='Run only the digits required for the paper figure '
'(--paper_digits) and skip the comprehensive figures. '
'Fast regeneration path for the paper.')
parser.add_argument('--replot', action='store_true',
help='Skip the reconstruction and rebuild figures from the '
'<output_dir>/state.pt cache produced by a previous run. '
'Use this to iterate on figure aesthetics in seconds.')
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--log_every', type=int, default=300)
Expand Down Expand Up @@ -1279,28 +1292,42 @@ def main() -> int:
n_digits_run = args.n_digits
n_rotations_run = args.n_rotations

results, meta = run_demo(
data_dir=args.data_dir,
output_dir=args.output_dir,
n_digits=n_digits_run,
n_rotations=n_rotations_run,
nlat=args.nlat,
nlon=args.nlon,
lmax=args.lmax,
selective=not args.full_bispectrum,
n_steps=args.n_steps,
lr=args.lr,
bandlimit_project=not args.no_bandlimit_project,
device=device,
seed=args.seed,
log_every=args.log_every,
align_n_restarts=args.align_n_restarts,
align_n_steps=args.align_n_steps,
align_lr=args.align_lr,
n_recon_restarts=args.n_recon_restarts,
)

args.output_dir.mkdir(parents=True, exist_ok=True)
state_path = args.output_dir / 'state.pt'

if args.replot:
if not state_path.exists():
raise FileNotFoundError(
f'--replot requires {state_path} from a previous run; rerun once without --replot.'
)
logger.info('Loading cached state from %s (skipping reconstruction)', state_path)
cached = torch.load(state_path, weights_only=False)
results = cached['results']
meta = cached['meta']
else:
results, meta = run_demo(
data_dir=args.data_dir,
output_dir=args.output_dir,
n_digits=n_digits_run,
n_rotations=n_rotations_run,
nlat=args.nlat,
nlon=args.nlon,
lmax=args.lmax,
selective=not args.full_bispectrum,
n_steps=args.n_steps,
lr=args.lr,
bandlimit_project=not args.no_bandlimit_project,
device=device,
seed=args.seed,
log_every=args.log_every,
align_n_restarts=args.align_n_restarts,
align_n_steps=args.align_n_steps,
align_lr=args.align_lr,
n_recon_restarts=args.n_recon_restarts,
)
torch.save({'results': results, 'meta': meta}, state_path)
logger.info('Cached state to %s for fast --replot iterations', state_path)

paper_path = args.paper_figure_path or (args.output_dir / 'paper_orbits.pdf')

if args.paper_only:
Expand Down
Binary file modified paper/paper.pdf
Binary file not shown.
Loading
Loading