diff --git a/paper/experiments/spherical_mnist_reconstruction/.gitignore b/paper/experiments/spherical_mnist_reconstruction/.gitignore index 45f01f7..3c7d08c 100644 --- a/paper/experiments/spherical_mnist_reconstruction/.gitignore +++ b/paper/experiments/spherical_mnist_reconstruction/.gitignore @@ -1,6 +1,5 @@ -figures/ +figures/* +!figures/paper_orbits.pdf __pycache__/ -*.png -*.pdf *.json *.log diff --git a/paper/experiments/spherical_mnist_reconstruction/README.md b/paper/experiments/spherical_mnist_reconstruction/README.md index abc68bc..e4617cc 100644 --- a/paper/experiments/spherical_mnist_reconstruction/README.md +++ b/paper/experiments/spherical_mnist_reconstruction/README.md @@ -44,22 +44,30 @@ Default settings (`lmax=12`, `n_steps=8000`, `n_recon_restarts=4`, For a faster sanity check use `--n_digits 3 --n_recon_restarts 1 --n_steps 3000 --align_n_restarts 4`. -To regenerate **only the compact 2x2 paper figure** (~5 min), skip the -comprehensive sweep: +To regenerate **only the compact paper figure** (`2 x len(--paper_digits)`, +full-width on a NeurIPS column), skip the comprehensive sweep: ```bash -python reconstruct.py --paper_only --paper_digits 0 1 +python reconstruct.py --paper_only --paper_digits 0 1 2 3 4 5 +``` + +This caches all the recon tensors to `figures/state.pt`. Iterate on the +figure layout *without* re-running the optimization: + +```bash +python reconstruct.py --paper_only --paper_digits 0 1 2 3 4 5 --replot ``` Outputs land in `figures/`: -| File | Description | -|---------------------------------|----------------------------------------------------------------------| -| `orbits.{pdf,png}` | `(n_digits x 3(1+K))` grid: target / raw recon / aligned recon spheres | -| `paper_orbits.{pdf,png}` | Compact `(len(--paper_digits) x 2)` figure for the NeurIPS paper | -| `convergence.{pdf,png}` | Median + IQR of relative bispectrum residual vs. step | -| `invariance_vs_recon.{pdf,png}` | Per-pair scatter of invariance vs. 
recon residual | -| `results.json` | All scalar metrics + per-step traces | +| File | Description | +|---------------------------------|--------------------------------------------------------------------------------------| +| `orbits.{pdf,png}` | `(n_digits x 3(1+K))` grid: target / raw recon / aligned recon spheres | +| `paper_orbits.{pdf,png}` | Compact `(2 rows x len(--paper_digits) cols)` figure for the NeurIPS paper | +| `state.pt` | Cached `ReconResult` tensors for fast `--replot` iteration | +| `convergence.{pdf,png}` | Median + IQR of relative bispectrum residual vs. step | +| `invariance_vs_recon.{pdf,png}` | Per-pair scatter of invariance vs. recon residual | +| `results.json` | All scalar metrics + per-step traces | ## Key CLI arguments @@ -79,9 +87,10 @@ Outputs land in `figures/`: | `--view_size` | `128` | Orthographic view resolution | | `--elev_deg / --azim_deg` | `25/30` | Fallback / fixed-view camera direction (degrees) | | `--fixed_view` | off | Disable per-panel auto-centering on the signal centroid (use a single shared camera direction instead) | -| `--paper_digits` | `0 1` | `digit_idx` values to use for the compact `paper_orbits.pdf` figure | +| `--paper_digits` | `0 1 2 3 4 5` | `digit_idx` values to use as columns of the compact `paper_orbits.pdf` figure | | `--paper_figure_path` | auto | Override output path; defaults to `/paper_orbits.pdf` | | `--paper_only` | off | Run only the digits in `--paper_digits` and emit only the paper figure (fast regeneration path) | +| `--replot` | off | Skip recon + alignment, rebuild figures from `/state.pt` (~3 s) | | `--full_bispectrum` | off | `O(L^3)` full bispectrum instead of selective `O(L^2)` | | `--no_bandlimit_project` | off | Disable the per-step `IRealSHT(RealSHT(.))` projection | | `--seed` | `0` | Controls digit selection, rotations, Gaussian init, and alignment seeds | diff --git a/paper/experiments/spherical_mnist_reconstruction/figures/paper_orbits.pdf 
b/paper/experiments/spherical_mnist_reconstruction/figures/paper_orbits.pdf index 229ff93..e1a8839 100644 Binary files a/paper/experiments/spherical_mnist_reconstruction/figures/paper_orbits.pdf and b/paper/experiments/spherical_mnist_reconstruction/figures/paper_orbits.pdf differ diff --git a/paper/experiments/spherical_mnist_reconstruction/figures/paper_orbits.png b/paper/experiments/spherical_mnist_reconstruction/figures/paper_orbits.png index eca1c44..48e8277 100644 Binary files a/paper/experiments/spherical_mnist_reconstruction/figures/paper_orbits.png and b/paper/experiments/spherical_mnist_reconstruction/figures/paper_orbits.png differ diff --git a/paper/experiments/spherical_mnist_reconstruction/reconstruct.py b/paper/experiments/spherical_mnist_reconstruction/reconstruct.py index 174d43e..217ec23 100644 --- a/paper/experiments/spherical_mnist_reconstruction/reconstruct.py +++ b/paper/experiments/spherical_mnist_reconstruction/reconstruct.py @@ -1063,17 +1063,19 @@ def make_paper_figure( digit_indices: list[int], path: Path, view_size: int = 192, - panel_size: float = 1.9, + panel_size: float = 1.15, auto_center: bool = True, elev_deg: float = 25.0, azim_deg: float = 30.0, ) -> None: - """Compact 2-row x 2-col figure for the NeurIPS paper. - - For each requested ``digit_idx`` we pick the rotation with the lowest - aligned image residual and render ``[target | aligned recon]`` as a row. - The result is a clean ``len(digit_indices) x 2`` grid suitable for a - ``\\includegraphics[width=0.6\\linewidth]`` inclusion. + """Compact ``2 x len(digit_indices)`` figure for the NeurIPS paper. + + Rows: ``Target f`` / ``Recon \\hat R \\cdot \\hat f``. Columns: one per + requested ``digit_idx`` (we keep the rotation with the lowest aligned + image residual). Both rows of a column share the *target's* camera so + any visible mismatch is a true reconstruction error rather than a + viewpoint difference. 
Sized for ``\\includegraphics[width=\\linewidth]`` + inclusion in a NeurIPS column. """ plt.rcParams.update(NEURIPS_RCPARAMS) @@ -1090,12 +1092,12 @@ def make_paper_figure( best = min(by_digit[d], key=lambda r: r.aligned_image_space_rel) picks.append(best) - n_rows = len(picks) - n_cols = 2 + n_cols = len(picks) + n_rows = 2 fig, axes = plt.subplots( n_rows, n_cols, - figsize=(panel_size * n_cols, panel_size * n_rows), + figsize=(panel_size * n_cols + 0.6, panel_size * n_rows + 0.5), squeeze=False, ) @@ -1107,7 +1109,7 @@ def _view_for(arr: np.ndarray) -> tuple[float, float, list[tuple[np.ndarray, np. e, a = _signal_view_angle(arr, fallback_elev=elev_deg, fallback_azim=azim_deg) return e, a, _sphere_grid_lines(e, a) - for i, r in enumerate(picks): + for j, r in enumerate(picks): target = r.target.numpy() aligned = r.aligned.numpy() vmax = max(abs(target).max(), abs(aligned).max(), 1e-8) @@ -1116,33 +1118,38 @@ def _view_for(arr: np.ndarray) -> tuple[float, float, list[tuple[np.ndarray, np. 
e_t, a_t, gl_t = _view_for(target) _draw_sphere( - axes[i][0], target, vmin, vmax, view_size, e_t, a_t, grid_lines=gl_t + axes[0][j], target, vmin, vmax, view_size, e_t, a_t, grid_lines=gl_t ) _draw_sphere( - axes[i][1], aligned, vmin, vmax, view_size, e_t, a_t, grid_lines=gl_t + axes[1][j], aligned, vmin, vmax, view_size, e_t, a_t, grid_lines=gl_t ) - if i == 0: - axes[i][0].set_title('Target $f$', fontsize=10, pad=4) - axes[i][1].set_title( - r'Recon $\hat{R}\!\cdot\!\hat{f}$', fontsize=10, pad=4 - ) - axes[i][0].set_ylabel( - f'class {r.label}', - rotation=0, - ha='right', - va='center', - fontsize=9, - labelpad=10, - ) - axes[i][1].set_xlabel( - rf'$\|\hat{{R}}\!\cdot\!\hat{{f}}-f\|/\|f\|={r.aligned_image_space_rel:.2f}$', - fontsize=7, + axes[0][j].set_title(f'class {r.label}', fontsize=9, pad=3) + axes[1][j].set_xlabel( + rf'${r.aligned_image_space_rel:.2f}$', + fontsize=8, labelpad=2, ) + axes[0][0].set_ylabel( + 'Target $f$', + rotation=90, + ha='center', + va='center', + fontsize=10, + labelpad=8, + ) + axes[1][0].set_ylabel( + r'Recon $\hat{R}\!\cdot\!\hat{f}$', + rotation=90, + ha='center', + va='center', + fontsize=10, + labelpad=8, + ) + fig.tight_layout(pad=0.3) - fig.subplots_adjust(wspace=0.02, hspace=0.08) + fig.subplots_adjust(wspace=0.04, hspace=0.06) fig.savefig(path) fig.savefig(str(path).replace('.pdf', '.png')) plt.close(fig) @@ -1230,8 +1237,10 @@ def parse_args() -> argparse.Namespace: parser.add_argument('--fixed_view', action='store_true', help='Use --elev_deg / --azim_deg for every panel instead of ' 'auto-centering on each signal\'s positive-mass centroid.') - parser.add_argument('--paper_digits', type=int, nargs='+', default=[0, 1], - help='Digit indices used for the compact paper figure (2x2).') + parser.add_argument('--paper_digits', type=int, nargs='+', + default=[0, 1, 2, 3, 4, 5], + help='Digit indices used for the compact paper figure ' + '(rows = Target/Recon, columns = digits).') parser.add_argument('--paper_figure_path', 
type=Path, default=None, help='Override output path for the paper figure ' '(default: /paper_orbits.pdf).') @@ -1239,6 +1248,10 @@ def parse_args() -> argparse.Namespace: help='Run only the digits required for the paper figure ' '(--paper_digits) and skip the comprehensive figures. ' 'Fast regeneration path for the paper.') + parser.add_argument('--replot', action='store_true', + help='Skip the reconstruction and rebuild figures from the ' + '/state.pt cache produced by a previous run. ' + 'Use this to iterate on figure aesthetics in seconds.') parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu') parser.add_argument('--seed', type=int, default=0) parser.add_argument('--log_every', type=int, default=300) @@ -1279,28 +1292,42 @@ def main() -> int: n_digits_run = args.n_digits n_rotations_run = args.n_rotations - results, meta = run_demo( - data_dir=args.data_dir, - output_dir=args.output_dir, - n_digits=n_digits_run, - n_rotations=n_rotations_run, - nlat=args.nlat, - nlon=args.nlon, - lmax=args.lmax, - selective=not args.full_bispectrum, - n_steps=args.n_steps, - lr=args.lr, - bandlimit_project=not args.no_bandlimit_project, - device=device, - seed=args.seed, - log_every=args.log_every, - align_n_restarts=args.align_n_restarts, - align_n_steps=args.align_n_steps, - align_lr=args.align_lr, - n_recon_restarts=args.n_recon_restarts, - ) - args.output_dir.mkdir(parents=True, exist_ok=True) + state_path = args.output_dir / 'state.pt' + + if args.replot: + if not state_path.exists(): + raise FileNotFoundError( + f'--replot requires {state_path} from a previous run; rerun once without --replot.' 
+ ) + logger.info('Loading cached state from %s (skipping reconstruction)', state_path) + cached = torch.load(state_path, weights_only=False) + results = cached['results'] + meta = cached['meta'] + else: + results, meta = run_demo( + data_dir=args.data_dir, + output_dir=args.output_dir, + n_digits=n_digits_run, + n_rotations=n_rotations_run, + nlat=args.nlat, + nlon=args.nlon, + lmax=args.lmax, + selective=not args.full_bispectrum, + n_steps=args.n_steps, + lr=args.lr, + bandlimit_project=not args.no_bandlimit_project, + device=device, + seed=args.seed, + log_every=args.log_every, + align_n_restarts=args.align_n_restarts, + align_n_steps=args.align_n_steps, + align_lr=args.align_lr, + n_recon_restarts=args.n_recon_restarts, + ) + torch.save({'results': results, 'meta': meta}, state_path) + logger.info('Cached state to %s for fast --replot iterations', state_path) + paper_path = args.paper_figure_path or (args.output_dir / 'paper_orbits.pdf') if args.paper_only: diff --git a/paper/paper.pdf b/paper/paper.pdf index 4c7c517..c83f365 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/paper/paper.tex b/paper/paper.tex index 7e480b7..944483b 100644 --- a/paper/paper.tex +++ b/paper/paper.tex @@ -959,47 +959,58 @@ \subsection{Empirical proof of concept: bispectrum reconstruction} \label{sec:so3-recon} To test whether $\beta_{\mathrm{sel}}$ retains enough information to -recover signals up to rotation, we run a direct inversion experiment. -For each of $8$ source signals $f:S^2\to\R$ band-limited at -$L=15$, we form three rotated copies $f_k = R_k\cdot f$ -($k\in\{0,1,2\}$) under random $R_k\in\SO(3)$, compute -$\beta_{\mathrm{sel}}(f_k)$, and reconstruct each $f_k$ by gradient -descent on the spherical-harmonic coefficients -$\hat a^m_\ell$, minimising -$\lVert \beta_{\mathrm{sel}}(\hat f_k) - -\beta_{\mathrm{sel}}(f_k)\rVert^2$ from a random initialisation. 
-After convergence, $\hat f_k$ is aligned to $f_k$ by the optimal -$R\in\SO(3)$ (degree-1 alignment + Wigner-$D$ refinement). +recover signals up to rotation, we run a direct inversion experiment +on Spherical MNIST digits band-limited at $L = 12$.\footnote{We use +$L{=}12$ here rather than the classifier's $L{=}15$ to keep the +reconstruction residual safely below the SHT discretisation floor at +the $64{\times}128$ grid; protocol details are in +Appendix~\ref{app:so3-recon}.} For each source signal +$f:S^2\to\R$ we apply a uniform random rotation $R\in\SO(3)$ to form +the target $R\cdot f$, compute $\beta_{\mathrm{sel}}(R\cdot f)$, and +recover $\hat f$ by gradient descent from a random Gaussian +initialisation, minimising +$\lVert \beta_{\mathrm{sel}}(\hat f) - +\beta_{\mathrm{sel}}(R\cdot f)\rVert^2 / \lVert +\beta_{\mathrm{sel}}(R\cdot f)\rVert^2$ (Adam, +multi-restart, cosine-annealed LR). The orbit ambiguity is then +resolved by Procrustes-style optimisation on a quaternion +parameterisation: we fit $\hat R\in\SO(3)$ minimising $\lVert +\hat R\cdot\hat f - R\cdot f\rVert^2$. \begin{figure}[t] \centering - \includegraphics[width=\linewidth]{figures/so3_reconstruction.png} - \caption{$\SO(3)$ bispectrum reconstruction (proof of concept). - For each of $8$ source signals (rows), three rotations - $f_k = R_k\cdot f$ ($k\in\{0,1,2\}$) of the same target are - inverted by gradient descent from $\beta_{\mathrm{sel}}(f_k)$ at - band-limit $L=15$. Each triplet shows: target $f_k$ (left), - reconstruction $\hat f_k$ (middle, with $\beta$-residual), and - $\hat f_k$ aligned to $f_k$ by the optimal $R\in\SO(3)$ (right, - with image-residual). 
All spheres are auto-centred on the - signal's positive-mass centroid; aligned and target panels share - the same view because they are the same signal.} + \includegraphics[width=\linewidth]{experiments/spherical_mnist_reconstruction/figures/paper_orbits.pdf} + \caption{$\SO(3)$ bispectrum reconstruction on Spherical MNIST + (proof of concept). Top row: target signal $f$ for digit classes + 0--5. Bottom row: $\hat R\!\cdot\!\hat f$, where $\hat f$ is + recovered from $\beta_{\mathrm{sel}}(f)$ alone by gradient + descent in image space and $\hat R\in\SO(3)$ is fit by Procrustes + alignment to factor out the orbit ambiguity. Both panels of a + column share the target's camera, so any visible difference is a + true reconstruction error rather than a viewpoint mismatch. The + number under each column is the relative image-space residual + $\lVert\hat R\!\cdot\!\hat f - f\rVert/\lVert f\rVert$ (median + $\approx 0.20$ across the six classes; raw + $\lVert\hat f - f\rVert/\lVert f\rVert$ before alignment is + $\mathcal{O}(1)$, $\approx 1.3$).} \label{fig:so3-reconstruction} \end{figure} \paragraph{Takeaways.} Bispectrum residuals reach -$\lVert\Delta\beta\rVert/\lVert\beta\rVert\sim 10^{-4}$--$10^{-3}$ -(machine-precision relative error), and the aligned image residuals -$\lVert R\cdot\hat f_k - f_k\rVert/\lVert f_k\rVert\sim -10^{-1}$--$5\cdot 10^{-1}$ are dominated by residual local optima of -the gradient-descent solver, not by missing invariant information -- -the recovered $\hat f_k$ visually matches $f_k$ up to rotation across -all $8\times 3 = 24$ test cases (Figure~\ref{fig:so3-reconstruction}). -We read this as proof-of-concept evidence that the selective -coefficients carry enough information to recover band-limited signals -up to $\SO(3)$; a formal proof of completeness for $\beta_{\mathrm{sel}}$ -remains an open question. 
+$\lVert\Delta\beta\rVert/\lVert\beta\rVert \sim 10^{-4}$--$10^{-3}$ +(below the SHT-discretisation invariance noise floor of +${\sim}8\times 10^{-3}$ at this grid), and the aligned image residuals +$\lVert\hat R\cdot\hat f - f\rVert/\lVert f\rVert \sim +0.1$--$0.3$ are dominated by residual local optima of the +gradient-descent solvers, not by missing invariant information---the +recovered $\hat f$ visually matches $f$ up to rotation across all six +digit classes shown (Figure~\ref{fig:so3-reconstruction}; the same +behaviour holds across our 8-digit sweep, see +Appendix~\ref{app:so3-recon}). We read this as proof-of-concept +evidence that the selective coefficients carry enough information to +recover band-limited signals up to $\SO(3)$; a formal proof of +completeness for $\beta_{\mathrm{sel}}$ remains an open question. \section{Discussion and Conclusion} @@ -2055,6 +2066,56 @@ \subsection{Aggregation procedures} of accuracy, yielding 30 measurements per model (10 rotations $\times$ 3 seeds). +\subsection{Spherical MNIST: bispectrum reconstruction protocol} +\label{app:so3-recon} + +This section gives the protocol behind +Figure~\ref{fig:so3-reconstruction} (Section~\ref{sec:so3-recon}). +The implementation lives in +\texttt{paper/experiments/spherical\_mnist\_reconstruction/reconstruct.py}. + +\paragraph{Setup.} +Spherical MNIST is rendered onto an equiangular $64{\times}128$ +lat--lon grid using the same stereographic projection as the +classifier, then band-limited at $L = 12$ via a forward/inverse SHT +roundtrip (\texttt{torch\_harmonics.RealSHT}). For each digit we +draw a target rotation $R\in\SO(3)$ uniformly at random and form +$f := R\cdot f_0$. We compute the SO(3)-on-$S^2$ selective bispectrum +$\beta := \beta_{\mathrm{sel}}(f)$. 
+ +\paragraph{Reconstruction.} +Starting from a Gaussian initialization $\hat f_0\sim\mathcal{N}(0, I)$ +on the $64{\times}128$ grid, we optimize $\hat f$ with Adam +(initial LR $5\times 10^{-2}$, cosine annealing to $5\times 10^{-4}$, +$8000$ steps) to minimize the relative complex L2 loss +$\|\beta(\hat f) - \beta\|^2 / \|\beta\|^2$. After every step we +project $\hat f$ back to the band-limited subspace via $\mathrm{SHT}^{-1}\!\circ\mathrm{SHT}$ truncated at $L$ to suppress the +unconstrained null space. We run $4$ random restarts and keep the +$\hat f$ with the lowest bispectrum residual. Reconstruction +typically reaches $\|\beta(\hat f)-\beta\|/\|\beta\|\sim 10^{-3}$, +near the SHT-discretization invariance floor for this grid. + +\paragraph{$\SO(3)$ alignment.} +The bispectrum is invariant on $\SO(3)$-orbits, so $\hat f$ lives +somewhere in the orbit of $f$ rather than at $f$ itself. We resolve +the orbit ambiguity by Procrustes-style fitting: we parameterize +$\hat R \in \SO(3)$ as a unit quaternion (auto-normalized) to avoid +gimbal lock, and minimize $\|\hat R\cdot\hat f - f\|^2$ with Adam +($12$ random restarts, $200$ steps each, cosine-annealed LR). +Rotations of the spherical image use bilinear interpolation in +lat--lon coordinates (\texttt{bispectrum.rotate\_spherical\_function}). + +\paragraph{Numerical floor.} +On a $64{\times}128$ grid, two SHT roundtrips of two $\SO(3)$-rotated +copies of the same signal already differ by +$\|f' - f\|/\|f\|\approx 2\times 10^{-2}$ in image space due to +bilinear-interpolation discretization; this is the noise floor below +which alignment cannot improve. The aligned residuals reported in +Figure~\ref{fig:so3-reconstruction} ($0.1$--$0.3$) sit roughly +$5$--$15\times$ above this floor and are dominated by gradient-descent +local optima in the bispectrum reconstruction step, not by +missing information. + \subsection{Dataset licenses} \begin{itemize}[nosep]