From 4c351be5e6a602d569ef3676964194e45d8bb763 Mon Sep 17 00:00:00 2001 From: wdwzyyg Date: Tue, 3 Feb 2026 21:33:21 -0500 Subject: [PATCH 1/6] reload h5 dtype fix --- phaser/execute.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/phaser/execute.py b/phaser/execute.py index 593a71c..cf9c2f3 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -307,6 +307,8 @@ def initialize_reconstruction( if init_state.scan is not None and plan.init.scan is None: logging.info("Re-using scan from initial state...") scan = init_state.scan + scan = scan.astype(dtype) + else: logging.info("Initializing scan...") scan = pane.from_data(scan_hook, ScanHook)( # type: ignore @@ -316,6 +318,8 @@ def initialize_reconstruction( if init_state.tilt is not None and plan.init.tilt is None: logging.info("Re-using tilt from initial state...") tilt = init_state.tilt + tilt = tilt.astype(dtype) + elif tilt_hook is not None: logging.info("Initializing tilt...") tilt = pane.from_data(tilt_hook, TiltHook)( # type: ignore From 49e53bd5fe2c15ee878763ee74b59852ffe8d732 Mon Sep 17 00:00:00 2001 From: wdwzyyg Date: Sun, 29 Mar 2026 18:13:17 -0400 Subject: [PATCH 2/6] adding ssim to progress --- phaser/engines/conventional/run.py | 23 ++++- phaser/engines/gradient/run.py | 19 ++++ phaser/execute.py | 9 +- phaser/observer.py | 85 ++++++++++++------ phaser/plan.py | 13 ++- phaser/utils/analysis.py | 134 ++++++++++++++++++++++++++++- pyproject.toml | 1 + 7 files changed, 250 insertions(+), 34 deletions(-) diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 2264805..ddf51dd 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -1,7 +1,10 @@ import logging +import numpy +from copy import deepcopy from phaser.utils.misc import mask_fraction_of_groups from phaser.utils.num import assert_dtype, cast_array_module, to_numpy, to_complex_dtype +from phaser.utils.analysis import structural_similarity from phaser.observer import Observer from phaser.hooks import EngineArgs from phaser.plan import ConventionalEnginePlan @@ -56,7 +59,12 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: else: other_keys = () - for k in ('detector_loss', 'total_loss', *other_keys): + ssim_keys = ( + *(('obj_ssim',) if props.track_ssim else ()), + *(('probe_ssim',) if props.track_ssim else ()), + ) + + for k in ('detector_loss', 'total_loss', *other_keys, *ssim_keys): if k not in sim.state.progress: sim.state.progress[k] = ProgressState() @@ -88,6 +96,8 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: iter_update_positions = update_positions({'state': sim.state, 'niter': props.niter}) iter_shuffle_groups = shuffle_groups({'state': sim.state, 'niter': props.niter}) + state_store = deepcopy(sim.state) if props.track_ssim else None + sim, pos_update, group_errors = solver.run_iteration( sim, groups.iter(sim.state.scan, i, iter_shuffle_groups), patterns=patterns, pattern_mask=pattern_mask, propagators=propagators, @@ -101,6 +111,17 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: assert_dtype(sim.state.object.data, cdtype) assert_dtype(sim.state.probe.data, cdtype) + if state_store is not None: + obj_before = numpy.angle(to_numpy(state_store.object.data[0])) + obj_after = numpy.angle(to_numpy(sim.state.object.data[0])) + progress['obj_ssim'].iters.append(i + start_i) + progress['obj_ssim'].values.append(structural_similarity(obj_after, obj_before)) + + probe_before = numpy.abs(to_numpy(state_store.probe.data[0])) + probe_after = numpy.abs(to_numpy(sim.state.probe.data[0])) + progress['probe_ssim'].iters.append(i + start_i) + progress['probe_ssim'].values.append(structural_similarity(probe_after, probe_before)) + sim = sim.apply_iter_constraints() if iter_update_positions: diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index 489ac37..ca20c28 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -1,4 +1,5 @@ import logging +from copy import deepcopy from functools import partial import typing as t @@ -22,6 +23,7 @@ from phaser.plan import GradientEnginePlan from phaser.types import process_flag, ReconsVar from ..common.simulation import GroupManager, make_propagators, tilt_propagators, slice_forwards, stream_patterns +from phaser.utils.analysis import structural_similarity logger = logging.getLogger(__name__) @@ -235,6 +237,9 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple other_keys = ( *(('pos_update_rms',) if 'positions' in all_vars else ()), *(('tilt_update_rms',) if 'tilt' in all_vars else ()), + *(('obj_ssim',) if props.track_ssim else ()), + *(('probe_ssim',) if props.track_ssim else ()), + ) # populate missing keys in progress dictionary @@ -314,8 +319,22 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple (update, iter_solver_states[sol_i]) = solver.update( state, iter_solver_states[sol_i], filter_vars(iter_grads, solver.params), losses['total_loss'] ) + + state_store = deepcopy(state) if props.track_ssim else None state = apply_update(state, update) + if 'obj_ssim' in other_keys and state_store is not None: + obj_before = numpy.angle(to_numpy(state_store.object.data[0])) + obj_after = numpy.angle(to_numpy(state.object.data[0])) + progress['obj_ssim'].iters.append(i + start_i) + progress['obj_ssim'].values.append(structural_similarity(obj_after, obj_before)) + + if 'probe_ssim' in other_keys and state_store is not None: + probe_before = numpy.abs(to_numpy(state_store.probe.data[0])) + probe_after = numpy.abs(to_numpy(state.probe.data[0])) + progress['probe_ssim'].iters.append(i + start_i) + progress['probe_ssim'].values.append(structural_similarity(probe_after, probe_before)) + if 'positions' in update: pos_update_rms = float(xp.mean(xp.linalg.norm(update['positions'], axis=-1))) progress['pos_update_rms'].iters.append(i + start_i) diff --git a/phaser/execute.py b/phaser/execute.py index cf9c2f3..37332e4 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -57,9 +57,14 @@ def execute_engine( engine_i = recons.state.iter.engine_num - if plan.early_termination: + if any(v is not None for v in ( + plan.early_termination_loss, plan.early_termination_obj_ssim, plan.early_termination_probe_ssim + )): engine_observer = ObserverSet((recons.observer, PatienceObserver( - plan.early_termination, plan.early_termination_smoothing + patience_loss=plan.early_termination_loss, + patience_obj_ssim=plan.early_termination_obj_ssim, + patience_probe_ssim=plan.early_termination_probe_ssim, + smoothing=plan.early_termination_smoothing, ))) else: engine_observer = recons.observer diff --git a/phaser/observer.py b/phaser/observer.py index dbb9f53..87c1cac 100644 --- a/phaser/observer.py +++ b/phaser/observer.py @@ -169,44 +169,75 @@ def finish_recons(self, state: ReconsState): class PatienceObserver(Observer): - def __init__(self, patience: int, smoothing: float = 0.1, continue_next_engine: bool = True): - self.patience: int = patience - self.no_improvement_iter: int = 0 - self.best_error: t.Optional[float] = None - self.smoothed_error: t.Optional[float] = None + # metrics where higher values indicate improvement + _HIGHER_IS_BETTER: t.FrozenSet[str] = frozenset({'obj_ssim', 'probe_ssim'}) + + def __init__( + self, + patience_loss: t.Optional[int] = None, + patience_obj_ssim: t.Optional[int] = None, + patience_probe_ssim: t.Optional[int] = None, + smoothing: float = 0.1, + continue_next_engine: bool = True, + ): self.smoothing: float = smoothing self.continue_next_engine: bool = continue_next_engine + # build active metric table: key -> patience + self._patience: t.Dict[str, int] = {} + if patience_loss is not None: + self._patience['total_loss'] = patience_loss + if patience_obj_ssim is not None: + self._patience['obj_ssim'] = patience_obj_ssim + if patience_probe_ssim is not None: + self._patience['probe_ssim'] = patience_probe_ssim + + self._best: t.Dict[str, float] = {} + self._no_improvement: t.Dict[str, int] = {} + self._smoothed: t.Dict[str, float] = {} + def init_engine( self, init_state: ReconsState, *, recons_name: str, plan: EnginePlan, **kwargs: t.Any ): - self.no_improvement_iter = 0 - - def _error_from_state(self, state: t.Union[ReconsState, PartialReconsState]) -> t.Optional[float]: - if state.progress is None or (progress := state.progress['total_loss']) is None or not len(progress.values): - return None - return progress.values[-1] + self._best = {} + self._no_improvement = {k: 0 for k in self._patience} + self._smoothed = {} def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]): - if (error := errors.get('total_loss')) is None: - return - - if self.best_error is None or error < self.best_error: - self.best_error = error - self.no_improvement_iter = 0 - else: - self.no_improvement_iter += 1 + for key, patience in self._patience.items(): + # read value: loss comes from errors dict, ssim metrics from state.progress + if key == 'total_loss': + value: t.Optional[float] = errors.get('total_loss') + else: + prog = state.progress.get(key) if state.progress else None + value = prog.values[-1] if prog is not None and len(prog.values) else None + + if value is None: + continue + + # exponential moving average + if key not in self._smoothed: + self._smoothed[key] = value + else: + self._smoothed[key] = (1 - self.smoothing) * self._smoothed[key] + self.smoothing * value + + higher_is_better = key in self._HIGHER_IS_BETTER + improved = ( + key not in self._best + or (higher_is_better and value > self._best[key]) + or (not higher_is_better and value < self._best[key]) + ) - # Exponential moving average - if self.smoothed_error is None: - self.smoothed_error = error - else: - self.smoothed_error = (1 - self.smoothing) * self.smoothed_error + self.smoothing * error + if improved: + self._best[key] = value + self._no_improvement[key] = 0 + else: + self._no_improvement[key] += 1 - if self.no_improvement_iter >= self.patience: - logging.info(f"Early termination: no improvement for {self.patience} iterations") - raise EarlyTermination(state, self.continue_next_engine) + if self._no_improvement[key] >= patience: + logging.info(f"Early termination: {key} no improvement for {patience} iterations") + raise EarlyTermination(state, self.continue_next_engine) class SaveObserver(Observer): diff --git a/phaser/plan.py b/phaser/plan.py index 0e68ad6..f5574cb 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -85,8 +85,12 @@ class EnginePlan(Dataclass, kw_only=True): save_images: FlagLike = False save_options: SaveOptions = SaveOptions() - early_termination: t.Optional[int] = None - """Terminate after n iterations without improvement""" + early_termination_loss: t.Optional[int] = None + """Terminate after n iterations without improvement in total_loss""" + early_termination_obj_ssim: t.Optional[int] = None + """Terminate after n iterations without improvement in obj_ssim (requires track_ssim=True)""" + early_termination_probe_ssim: t.Optional[int] = None + """Terminate after n iterations without improvement in probe_ssim (requires track_ssim=True)""" early_termination_smoothing: float = 0.9 """ Smoothing factor to apply to error measurement for early termination. @@ -94,6 +98,9 @@ class EnginePlan(Dataclass, kw_only=True): (smooths over ~1/smoothing iterations) """ + track_ssim: bool = False + """Track SSIM between consecutive iterations as a convergence metric.""" + check_every_group: bool = False send_every_group: bool = False @@ -154,7 +161,7 @@ class GradientEnginePlan(EnginePlan): regularizers: t.List[CostRegularizerHook] group_constraints: t.List[GroupConstraintHook] iter_constraints: t.List[IterConstraintHook] - + class SGDSolverPlan(Dataclass, kw_only=True): learning_rate: ScheduleLike diff --git a/phaser/utils/analysis.py b/phaser/utils/analysis.py index 024c66c..e11de9e 100644 --- a/phaser/utils/analysis.py +++ b/phaser/utils/analysis.py @@ -294,4 +294,136 @@ def align_and_correlate(mat: NDArray[numpy.floating]) -> NDArray[numpy.floating] if return_crop: return upsamp_obj[(slice(None), *crop)], ground_truth[tuple(crop)], crop - return upsamp_obj[(slice(None), *crop)], ground_truth[tuple(crop)] \ No newline at end of file + return upsamp_obj[(slice(None), *crop)], ground_truth[tuple(crop)] + + +def _resample_to_shape(im, target_shape: t.Tuple[int, ...], xp: t.Any): + """Resample im to target_shape, staying on the input device.""" + xp_name = getattr(xp, '__name__', '') + + if xp_name == 'numpy': + from scipy.ndimage import zoom + zoom_factors = tuple(s1 / s2 for s1, s2 in zip(target_shape, im.shape)) + return zoom(im, zoom_factors, order=1) + + if 'cupy' in xp_name: + from cupyx.scipy.ndimage import zoom + zoom_factors = tuple(s1 / s2 for s1, s2 in zip(target_shape, im.shape)) + return zoom(im, zoom_factors, order=1) + + # JAX + import jax.image + return jax.image.resize(im, target_shape, method='linear') + + +def _uniform_filter_2d(im, size: int, xp: t.Any): + """ + Separable 2D box filter, staying on the input device. + + Dispatches to: + - scipy.ndimage.uniform_filter for numpy + - cupyx.scipy.ndimage.uniform_filter for cupy (GPU-native) + - cumsum-based separable filter for JAX / other backends (XLA-friendly) + """ + xp_name = getattr(xp, '__name__', '') + + if xp_name == 'numpy': + from scipy.ndimage import uniform_filter + return uniform_filter(im, size) + + if 'cupy' in xp_name: + from cupyx.scipy.ndimage import uniform_filter + return uniform_filter(im, size) + + # JAX or other: separable cumsum box filter (no D2H transfer, XLA-friendly) + def _along_axis(arr, axis: int): + pad = size // 2 + pad_config = [(0, 0)] * arr.ndim + pad_config[axis] = (pad, pad) + padded = xp.pad(arr, pad_config, mode='reflect') + + # prepend a zero slice so that cs[i+size] - cs[i] = sum(padded[i : i+size]) + zero_shape = list(padded.shape) + zero_shape[axis] = 1 + cs = xp.concatenate( + [xp.zeros(zero_shape, dtype=padded.dtype), xp.cumsum(padded, axis=axis)], + axis=axis, + ) + n = arr.shape[axis] + sl_end = [slice(None)] * arr.ndim + sl_end[axis] = slice(size, size + n) + sl_beg = [slice(None)] * arr.ndim + sl_beg[axis] = slice(0, n) + return (cs[tuple(sl_end)] - cs[tuple(sl_beg)]) / size + + return _along_axis(_along_axis(im, -2), -1) + + +## simplified version from skimage/metrics/_structural_similarity.py +def structural_similarity( + im1, + im2, + data_range=None, + win_size=3, + **kwargs, +) -> float: + """ + Compute the mean structural similarity index between two images. + Please pay attention to the `data_range` parameter with floating-point images. + + Parameters + ---------- + im1, im2 : ndarray + Arrays from any supported backend (numpy, JAX, cupy). All computation + stays on the input device; only the final scalar crosses the boundary. + data_range : float, optional + The data range of the input image (difference between maximum and + minimum possible values). Computed from im2 if not provided. + win_size : odd int in px, 3 as the smallest. + + Returns + ------- + mssim : float + The mean structural similarity index over the image. + """ + xp = get_array_module(im1, im2) + + im1 = im1.astype(numpy.float64) + im2 = im2.astype(numpy.float64) + + if im1.shape != im2.shape: + im2 = _resample_to_shape(im2, im1.shape, xp) + + if data_range is None: + data_range = float(im2.max() - im2.min()) + + K1 = 0.01 + K2 = 0.03 + + ux = _uniform_filter_2d(im1, win_size, xp) + uy = _uniform_filter_2d(im2, win_size, xp) + uxx = _uniform_filter_2d(im1 * im1, win_size, xp) + uyy = _uniform_filter_2d(im2 * im2, win_size, xp) + uxy = _uniform_filter_2d(im1 * im2, win_size, xp) + + vx = uxx - ux * ux + vy = uyy - uy * uy + vxy = uxy - ux * uy + + R = data_range + C1 = (K1 * R) ** 2 + C2 = (K2 * R) ** 2 + + A1 = 2 * ux * uy + C1 + A2 = 2 * vxy + C2 + B1 = ux ** 2 + uy ** 2 + C1 + B2 = vx + vy + C2 + + S = (A1 * A2) / (B1 * B2) + + # crop edges to avoid filter boundary effects + pad = (win_size - 1) // 2 + slices = tuple(slice(pad, s - pad) for s in im1.shape) + + # single scalar D2H transfer + return float(xp.mean(S[slices])) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 3d08640..47e6bbf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ requires-python = ">=3.10" dependencies = [ "numpy>=2.0,<2.6", # tested on 2.3 "scipy>=1.7.0,<1.19", # tested on 1.11, 1.16 + "scikit-image>=0.19.0", "matplotlib~=3.8", "h5py~=3.8", "pyyaml>=5.3.1", From 23ff7dfcf6bd01e214badcb8141c16bcdd61174c Mon Sep 17 00:00:00 2001 From: wdwzyyg Date: Tue, 31 Mar 2026 00:48:49 -0400 Subject: [PATCH 3/6] multiscale ssim and interval between iters --- phaser/engines/conventional/run.py | 36 ++++++----- phaser/engines/gradient/run.py | 40 ++++++------ phaser/plan.py | 8 +-- phaser/utils/analysis.py | 97 +++++++++++++++++------------- 4 files changed, 103 insertions(+), 78 deletions(-) diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index ddf51dd..5fb5796 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -1,4 +1,5 @@ import logging +import typing as t import numpy from copy import deepcopy @@ -59,9 +60,12 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: else: other_keys = () + calc_ssim_flag = process_flag(props.calc_ssim) + ssim_enabled = flag_any_true(props.calc_ssim, props.niter) + ssim_keys = ( - *(('obj_ssim',) if props.track_ssim else ()), - *(('probe_ssim',) if props.track_ssim else ()), + *(('obj_ssim',) if ssim_enabled else ()), + *(('probe_ssim',) if ssim_enabled else ()), ) for k in ('detector_loss', 'total_loss', *other_keys, *ssim_keys): @@ -88,6 +92,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: sim.state.progress = progress observer.start_engine(sim.state) + prev_ssim_state: t.Optional[ReconsState] = None for i in range(1, props.niter+1): sim.state.iter.engine_iter = i @@ -96,8 +101,6 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: iter_update_positions = update_positions({'state': sim.state, 'niter': props.niter}) iter_shuffle_groups = shuffle_groups({'state': sim.state, 'niter': props.niter}) - state_store = deepcopy(sim.state) if props.track_ssim else None - sim, pos_update, group_errors = solver.run_iteration( sim, groups.iter(sim.state.scan, i, iter_shuffle_groups), patterns=patterns, pattern_mask=pattern_mask, propagators=propagators, @@ -111,17 +114,6 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: assert_dtype(sim.state.object.data, cdtype) assert_dtype(sim.state.probe.data, cdtype) - if state_store is not None: - obj_before = numpy.angle(to_numpy(state_store.object.data[0])) - obj_after = numpy.angle(to_numpy(sim.state.object.data[0])) - progress['obj_ssim'].iters.append(i + start_i) - progress['obj_ssim'].values.append(structural_similarity(obj_after, obj_before)) - - probe_before = numpy.abs(to_numpy(state_store.probe.data[0])) - probe_after = numpy.abs(to_numpy(sim.state.probe.data[0])) - progress['probe_ssim'].iters.append(i + start_i) - progress['probe_ssim'].values.append(structural_similarity(probe_after, probe_before)) - sim = sim.apply_iter_constraints() if iter_update_positions: @@ -152,6 +144,20 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: progress[k].iters.append(i + start_i) progress[k].values.append(error) + # ssim: compare end-of-iteration state against the reference saved at the + # previous flag firing, then update the reference for the next interval + if ssim_enabled and calc_ssim_flag({'state': sim.state, 'niter': props.niter}): + if prev_ssim_state is not None: + ssim_o = structural_similarity(xp.angle(sim.state.object.data[0]), xp.angle(prev_ssim_state.object.data[0])) + progress['obj_ssim'].iters.append(i + start_i) + progress['obj_ssim'].values.append(ssim_o) + + ssim_p = structural_similarity(xp.abs(sim.state.probe.data[0]), xp.abs(prev_ssim_state.probe.data[0])) + progress['probe_ssim'].iters.append(i + start_i) + progress['probe_ssim'].values.append(ssim_p) + + prev_ssim_state = deepcopy(sim.state) + sim.state.progress = progress observer.update_iteration(sim.state, i, props.niter, {'total_loss': error}) diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index ca20c28..b3830d6 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -21,7 +21,7 @@ from phaser.hooks.solver import GradientSolver from phaser.hooks.regularization import CostRegularizer, GroupConstraint from phaser.plan import GradientEnginePlan -from phaser.types import process_flag, ReconsVar +from phaser.types import process_flag, flag_any_true, ReconsVar from ..common.simulation import GroupManager, make_propagators, tilt_propagators, slice_forwards, stream_patterns from phaser.utils.analysis import structural_similarity @@ -234,12 +234,14 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple loss_keys = ( 'detector_loss', 'total_loss', *(reg.name() for reg in regularizers), ) + calc_ssim_flag = process_flag(props.calc_ssim) + ssim_enabled = flag_any_true(props.calc_ssim, props.niter) + other_keys = ( *(('pos_update_rms',) if 'positions' in all_vars else ()), *(('tilt_update_rms',) if 'tilt' in all_vars else ()), - *(('obj_ssim',) if props.track_ssim else ()), - *(('probe_ssim',) if props.track_ssim else ()), - + *(('obj_ssim',) if ssim_enabled else ()), + *(('probe_ssim',) if ssim_enabled else ()), ) # populate missing keys in progress dictionary @@ -249,6 +251,7 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple # progress gets clobbered by the jits, so we keep track of it manually progress = state.progress + prev_ssim_state: t.Optional[ReconsState] = None for i in range(1, props.niter+1): state.iter.engine_iter = i @@ -320,21 +323,8 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple state, iter_solver_states[sol_i], filter_vars(iter_grads, solver.params), losses['total_loss'] ) - state_store = deepcopy(state) if props.track_ssim else None state = apply_update(state, update) - if 'obj_ssim' in other_keys and state_store is not None: - obj_before = numpy.angle(to_numpy(state_store.object.data[0])) - obj_after = numpy.angle(to_numpy(state.object.data[0])) - progress['obj_ssim'].iters.append(i + start_i) - progress['obj_ssim'].values.append(structural_similarity(obj_after, obj_before)) - - if 'probe_ssim' in other_keys and state_store is not None: - probe_before = numpy.abs(to_numpy(state_store.probe.data[0])) - probe_after = numpy.abs(to_numpy(state.probe.data[0])) - progress['probe_ssim'].iters.append(i + start_i) - progress['probe_ssim'].values.append(structural_similarity(probe_after, probe_before)) - if 'positions' in update: pos_update_rms = float(xp.mean(xp.linalg.norm(update['positions'], axis=-1))) progress['pos_update_rms'].iters.append(i + start_i) @@ -361,6 +351,22 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple state.object.sampling.check_scan(state.scan, state.probe.sampling.extent / 2.) assert_dtype(state.scan, dtype) + # ssim: compare end-of-iteration state against the reference saved at the + # previous flag firing, then update the reference for the next interval + if ssim_enabled and calc_ssim_flag({'state': state, 'niter': props.niter}): + if prev_ssim_state is not None: + ssim_o = structural_similarity(xp.angle(state.object.data), xp.angle(prev_ssim_state.object.data)) + progress['obj_ssim'].iters.append(i + start_i) + progress['obj_ssim'].values.append(ssim_o) + logger.info(f" Object phase SSIM: {ssim_o}") + + ssim_p = structural_similarity(xp.abs(state.probe.data), xp.abs(prev_ssim_state.probe.data)) + progress['probe_ssim'].iters.append(i + start_i) + progress['probe_ssim'].values.append(ssim_p) + logger.info(f" Probe intensity SSIM: {ssim_p}") + + prev_ssim_state = deepcopy(state) + state.progress = progress observer.update_iteration(state, i, props.niter, losses) diff --git a/phaser/plan.py b/phaser/plan.py index f5574cb..c2958be 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -88,9 +88,9 @@ class EnginePlan(Dataclass, kw_only=True): early_termination_loss: t.Optional[int] = None """Terminate after n iterations without improvement in total_loss""" early_termination_obj_ssim: t.Optional[int] = None - """Terminate after n iterations without improvement in obj_ssim (requires track_ssim=True)""" + """Terminate after n iterations without improvement in obj_ssim (requires calc_ssim to be enabled)""" early_termination_probe_ssim: t.Optional[int] = None - """Terminate after n iterations without improvement in probe_ssim (requires track_ssim=True)""" + """Terminate after n iterations without improvement in probe_ssim (requires calc_ssim to be enabled)""" early_termination_smoothing: float = 0.9 """ Smoothing factor to apply to error measurement for early termination. @@ -98,8 +98,8 @@ class EnginePlan(Dataclass, kw_only=True): (smooths over ~1/smoothing iterations) """ - track_ssim: bool = False - """Track SSIM between consecutive iterations as a convergence metric.""" + calc_ssim: FlagLike = False + """Compute SSIM between consecutive iterations as a convergence metric. Use SimpleFlag(every=N) to compute every N iterations.""" check_every_group: bool = False send_every_group: bool = False diff --git a/phaser/utils/analysis.py b/phaser/utils/analysis.py index e11de9e..2ea96c0 100644 --- a/phaser/utils/analysis.py +++ b/phaser/utils/analysis.py @@ -298,7 +298,7 @@ def align_and_correlate(mat: NDArray[numpy.floating]) -> NDArray[numpy.floating] def _resample_to_shape(im, target_shape: t.Tuple[int, ...], xp: t.Any): - """Resample im to target_shape, staying on the input device.""" + """Resample im to target_shape, staying on the input device. first dimension untouched""" xp_name = getattr(xp, '__name__', '') if xp_name == 'numpy': @@ -316,9 +316,11 @@ def _resample_to_shape(im, target_shape: t.Tuple[int, ...], xp: t.Any): return jax.image.resize(im, target_shape, method='linear') -def _uniform_filter_2d(im, size: int, xp: t.Any): +def _uniform_filter_spatial(im, size: int, xp: t.Any): """ - Separable 2D box filter, staying on the input device. + Separable box filter over the last two spatial dims only (any ndim >= 2). + Accepts stacked inputs e.g. (N, H, W), filtering H and W only — enabling + fused multi-statistic computation in one call. Dispatches to: - scipy.ndimage.uniform_filter for numpy @@ -326,23 +328,22 @@ def _uniform_filter_2d(im, size: int, xp: t.Any): - cumsum-based separable filter for JAX / other backends (XLA-friendly) """ xp_name = getattr(xp, '__name__', '') + sizes = [1] * (im.ndim - 2) + [size, size] if xp_name == 'numpy': from scipy.ndimage import uniform_filter - return uniform_filter(im, size) + return uniform_filter(im, sizes) if 'cupy' in xp_name: from cupyx.scipy.ndimage import uniform_filter - return uniform_filter(im, size) + return uniform_filter(im, sizes) - # JAX or other: separable cumsum box filter (no D2H transfer, XLA-friendly) + # JAX or other: cumsum box filter along axes -2 and -1 only (XLA-friendly) def _along_axis(arr, axis: int): pad = size // 2 pad_config = [(0, 0)] * arr.ndim pad_config[axis] = (pad, pad) padded = xp.pad(arr, pad_config, mode='reflect') - - # prepend a zero slice so that cs[i+size] - cs[i] = sum(padded[i : i+size]) zero_shape = list(padded.shape) zero_shape[axis] = 1 cs = xp.concatenate( @@ -359,32 +360,43 @@ def _along_axis(arr, axis: int): return _along_axis(_along_axis(im, -2), -1) -## simplified version from skimage/metrics/_structural_similarity.py def structural_similarity( im1, im2, data_range=None, - win_size=3, + win_size: int = 3, + num_scales: int = 3, **kwargs, ) -> float: """ - Compute the mean structural similarity index between two images. - Please pay attention to the `data_range` parameter with floating-point images. + Multi-scale contrast-structure similarity (geometric mean across scales). + + Computes the contrast-structure (CS) component of SSIM at each scale of a + bilinear downsampling pyramid, then combines as a geometric mean: + result = (cs_1 * cs_2 * ... * cs_N)^(1/N) + + Luminance is omitted. Equal scale weights are used. + + Efficient implementation: + - fused filter pass: all statistics filtered in one call per scale + - bilinear downsampling pyramid via _resample_to_shape + - fully on-device: only the final scalar crosses the device boundary Parameters ---------- im1, im2 : ndarray - Arrays from any supported backend (numpy, JAX, cupy). All computation - stays on the input device; only the final scalar crosses the boundary. + Arrays from any supported backend (numpy, JAX, cupy). data_range : float, optional - The data range of the input image (difference between maximum and - minimum possible values). Computed from im2 if not provided. - win_size : odd int in px, 3 as the smallest. + Computed from im2 if not provided. + win_size : int + Box filter size in pixels (default 3). + num_scales : int + Number of pyramid levels (default 3). Returns ------- mssim : float - The mean structural similarity index over the image. + MS-SSIM value in [0, 1]. """ xp = get_array_module(im1, im2) @@ -393,37 +405,38 @@ def structural_similarity( if im1.shape != im2.shape: im2 = _resample_to_shape(im2, im1.shape, xp) - if data_range is None: data_range = float(im2.max() - im2.min()) - K1 = 0.01 - K2 = 0.03 + C2 = (0.03 * data_range) ** 2 + + pad = (win_size - 1) // 2 + weight = 1.0 / num_scales - ux = _uniform_filter_2d(im1, win_size, xp) - uy = _uniform_filter_2d(im2, win_size, xp) - uxx = _uniform_filter_2d(im1 * im1, win_size, xp) - uyy = _uniform_filter_2d(im2 * im2, win_size, xp) - uxy = _uniform_filter_2d(im1 * im2, win_size, xp) + mssim = 1.0 + for scale in range(num_scales): + if min(im1.shape[-2:]) < win_size: + break - vx = uxx - ux * ux - vy = uyy - uy * uy - vxy = uxy - ux * uy + # fused: stack [im1, im2, im1², im2², im1·im2] and filter in one pass + stacked = xp.stack([im1, im2, im1 * im1, im2 * im2, im1 * im2]) + f = _uniform_filter_spatial(stacked, win_size, xp) + ux, uy, uxx, uyy, uxy = f[0], f[1], f[2], f[3], f[4] - R = data_range - C1 = (K1 * R) ** 2 - C2 = (K2 * R) ** 2 + vx = uxx - ux * ux + vy = uyy - uy * uy + vxy = uxy - ux * uy - A1 = 2 * ux * uy + C1 - A2 = 2 * vxy + C2 - B1 = ux ** 2 + uy ** 2 + C1 - B2 = vx + vy + C2 + # crop boundary artifacts + s = (slice(pad, -pad), slice(pad, -pad)) + vx, vy, vxy = vx[s], vy[s], vxy[s] - S = (A1 * A2) / (B1 * B2) + cs = float(xp.mean((2 * vxy + C2) / (vx + vy + C2))) + mssim *= cs ** weight - # crop edges to avoid filter boundary effects - pad = (win_size - 1) // 2 - slices = tuple(slice(pad, s - pad) for s in im1.shape) + if scale < num_scales - 1: + new_shape = (im1.shape[0], im1.shape[-2] // 2, im1.shape[-1] // 2) + im1 = _resample_to_shape(im1, new_shape, xp) + im2 = _resample_to_shape(im2, new_shape, xp) - # single scalar D2H transfer - return float(xp.mean(S[slices])) \ No newline at end of file + return mssim \ No newline at end of file From 6ee250fd0592ca0ddf74b59c137358ebae5b7679 Mon Sep 17 00:00:00 2001 From: wdwzyyg Date: Wed, 1 Apr 2026 13:48:29 -0400 Subject: [PATCH 4/6] minor change about iters --- phaser/engines/conventional/run.py | 4 ++-- phaser/engines/gradient/run.py | 4 ++-- phaser/observer.py | 27 ++++++++++++++++++--------- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 5fb5796..0b878ac 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -149,11 +149,11 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: if ssim_enabled and calc_ssim_flag({'state': sim.state, 'niter': props.niter}): if prev_ssim_state is not None: ssim_o = structural_similarity(xp.angle(sim.state.object.data[0]), xp.angle(prev_ssim_state.object.data[0])) - progress['obj_ssim'].iters.append(i + start_i) + progress['obj_ssim'].iters.append(int(sim.state.iter.total_iter)) progress['obj_ssim'].values.append(ssim_o) ssim_p = structural_similarity(xp.abs(sim.state.probe.data[0]), xp.abs(prev_ssim_state.probe.data[0])) - progress['probe_ssim'].iters.append(i + start_i) + progress['probe_ssim'].iters.append(int(sim.state.iter.total_iter)) progress['probe_ssim'].values.append(ssim_p) prev_ssim_state = deepcopy(sim.state) diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index b3830d6..f928c1d 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -356,12 +356,12 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple if ssim_enabled and calc_ssim_flag({'state': state, 'niter': props.niter}): if prev_ssim_state is not None: ssim_o = structural_similarity(xp.angle(state.object.data), xp.angle(prev_ssim_state.object.data)) - progress['obj_ssim'].iters.append(i + start_i) + progress['obj_ssim'].iters.append(int(state.iter.total_iter)) progress['obj_ssim'].values.append(ssim_o) logger.info(f" Object phase SSIM: {ssim_o}") ssim_p = structural_similarity(xp.abs(state.probe.data), xp.abs(prev_ssim_state.probe.data)) - progress['probe_ssim'].iters.append(i + start_i) + progress['probe_ssim'].iters.append(int(state.iter.total_iter)) progress['probe_ssim'].values.append(ssim_p) logger.info(f" Probe intensity SSIM: {ssim_p}") diff --git a/phaser/observer.py b/phaser/observer.py index 87c1cac..743ad7d 100644 --- a/phaser/observer.py +++ b/phaser/observer.py @@ -193,7 +193,7 @@ def __init__( self._patience['probe_ssim'] = patience_probe_ssim self._best: t.Dict[str, float] = {} - self._no_improvement: t.Dict[str, int] = {} + self._last_improvement_iter: t.Dict[str, int] = {} self._smoothed: t.Dict[str, float] = {} def init_engine( @@ -201,17 +201,25 @@ def init_engine( plan: EnginePlan, **kwargs: t.Any ): self._best = {} - self._no_improvement = {k: 0 for k in self._patience} + self._last_improvement_iter = {} self._smoothed = {} def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]): + current_iter = int(state.iter.total_iter) + for key, patience in self._patience.items(): - # read value: loss comes from errors dict, ssim metrics from state.progress + # read value: loss from errors dict every iteration; + # ssim metrics only when a new value was computed this iteration if key == 'total_loss': value: t.Optional[float] = errors.get('total_loss') else: prog = state.progress.get(key) if state.progress else None - value = prog.values[-1] if prog is not None and len(prog.values) else None + if prog is None or not len(prog.values): + continue + # skip if no new ssim value was produced this iteration + if not len(prog.iters) or prog.iters[-1] != current_iter: + continue + value = prog.values[-1] if value is None: continue @@ -231,12 +239,13 @@ def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[st if improved: self._best[key] = value - self._no_improvement[key] = 0 - else: - self._no_improvement[key] += 1 + self._last_improvement_iter[key] = current_iter - if self._no_improvement[key] >= patience: - logging.info(f"Early termination: {key} no improvement for {patience} iterations") + iters_without_improvement = current_iter - self._last_improvement_iter.get(key, current_iter) + if iters_without_improvement >= patience: + logging.info( + f"Early termination: {key} no improvement for {iters_without_improvement} iterations" + ) raise EarlyTermination(state, self.continue_next_engine) From 6c3f51f9e3fdcb356338a21356fc81257f140fcf Mon Sep 17 00:00:00 2001 From: wdwzyyg Date: Fri, 3 Apr 2026 15:54:10 -0400 Subject: [PATCH 5/6] transfer change from run to obsever --- phaser/engines/conventional/run.py | 27 +------------ phaser/engines/gradient/run.py | 24 ------------ phaser/execute.py | 14 +++++-- phaser/observer.py | 62 +++++++++++++++++++++++++++++- 4 files changed, 73 insertions(+), 54 deletions(-) diff --git a/phaser/engines/conventional/run.py b/phaser/engines/conventional/run.py index 0b878ac..04dee87 100644 --- a/phaser/engines/conventional/run.py +++ b/phaser/engines/conventional/run.py @@ -1,11 +1,9 @@ import logging import typing as t import numpy -from copy import deepcopy from phaser.utils.misc import mask_fraction_of_groups from phaser.utils.num import assert_dtype, cast_array_module, to_numpy, to_complex_dtype -from phaser.utils.analysis import structural_similarity from phaser.observer import Observer from phaser.hooks import EngineArgs from phaser.plan import ConventionalEnginePlan @@ -60,15 +58,7 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: else: other_keys = () - calc_ssim_flag = process_flag(props.calc_ssim) - ssim_enabled = flag_any_true(props.calc_ssim, props.niter) - - ssim_keys = ( - *(('obj_ssim',) if ssim_enabled else ()), - *(('probe_ssim',) if ssim_enabled else ()), - ) - - for k in ('detector_loss', 'total_loss', *other_keys, *ssim_keys): + for k in ('detector_loss', 'total_loss', *other_keys): if k not in sim.state.progress: sim.state.progress[k] = ProgressState() @@ -92,7 +82,6 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: sim.state.progress = progress observer.start_engine(sim.state) - prev_ssim_state: t.Optional[ReconsState] = None for i in range(1, props.niter+1): sim.state.iter.engine_iter = i @@ -144,20 +133,6 @@ def run_engine(args: EngineArgs, props: ConventionalEnginePlan) -> ReconsState: progress[k].iters.append(i + start_i) progress[k].values.append(error) - # ssim: compare end-of-iteration state against the reference saved at the - # previous flag firing, then update the reference for the next interval - if ssim_enabled and calc_ssim_flag({'state': sim.state, 'niter': props.niter}): - if prev_ssim_state is not None: - ssim_o = structural_similarity(xp.angle(sim.state.object.data[0]), xp.angle(prev_ssim_state.object.data[0])) - progress['obj_ssim'].iters.append(int(sim.state.iter.total_iter)) - progress['obj_ssim'].values.append(ssim_o) - - ssim_p = structural_similarity(xp.abs(sim.state.probe.data[0]), xp.abs(prev_ssim_state.probe.data[0])) - progress['probe_ssim'].iters.append(int(sim.state.iter.total_iter)) - progress['probe_ssim'].values.append(ssim_p) - - prev_ssim_state = deepcopy(sim.state) - sim.state.progress = progress observer.update_iteration(sim.state, i, props.niter, {'total_loss': error}) diff --git a/phaser/engines/gradient/run.py b/phaser/engines/gradient/run.py index f928c1d..524ea2e 100644 --- a/phaser/engines/gradient/run.py +++ b/phaser/engines/gradient/run.py @@ -1,5 +1,4 @@ import logging -from copy import deepcopy from functools import partial import typing as t @@ -23,7 +22,6 @@ from phaser.plan import GradientEnginePlan from phaser.types import process_flag, flag_any_true, ReconsVar from ..common.simulation import GroupManager, make_propagators, tilt_propagators, slice_forwards, stream_patterns -from phaser.utils.analysis import structural_similarity logger = logging.getLogger(__name__) @@ -234,14 +232,9 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple loss_keys = ( 'detector_loss', 'total_loss', *(reg.name() for reg in regularizers), ) - calc_ssim_flag = process_flag(props.calc_ssim) - ssim_enabled = flag_any_true(props.calc_ssim, props.niter) - other_keys = ( *(('pos_update_rms',) if 'positions' in all_vars else ()), *(('tilt_update_rms',) if 'tilt' in all_vars else ()), - *(('obj_ssim',) if ssim_enabled else ()), - *(('probe_ssim',) if ssim_enabled else ()), ) # populate missing keys in progress dictionary @@ -251,7 +244,6 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple # progress gets clobbered by the jits, so we keep track of it manually progress = state.progress - prev_ssim_state: t.Optional[ReconsState] = None for i in range(1, props.niter+1): state.iter.engine_iter = i @@ -351,22 +343,6 @@ def iter_patterns(groups: t.Iterable[NDArray[numpy.int_]]) -> t.Iterable[t.Tuple state.object.sampling.check_scan(state.scan, state.probe.sampling.extent / 2.) assert_dtype(state.scan, dtype) - # ssim: compare end-of-iteration state against the reference saved at the - # previous flag firing, then update the reference for the next interval - if ssim_enabled and calc_ssim_flag({'state': state, 'niter': props.niter}): - if prev_ssim_state is not None: - ssim_o = structural_similarity(xp.angle(state.object.data), xp.angle(prev_ssim_state.object.data)) - progress['obj_ssim'].iters.append(int(state.iter.total_iter)) - progress['obj_ssim'].values.append(ssim_o) - logger.info(f" Object phase SSIM: {ssim_o}") - - ssim_p = structural_similarity(xp.abs(state.probe.data), xp.abs(prev_ssim_state.probe.data)) - progress['probe_ssim'].iters.append(int(state.iter.total_iter)) - progress['probe_ssim'].values.append(ssim_p) - logger.info(f" Probe intensity SSIM: {ssim_p}") - - prev_ssim_state = deepcopy(state) - state.progress = progress observer.update_iteration(state, i, props.niter, losses) diff --git a/phaser/execute.py b/phaser/execute.py index 37332e4..d8ed699 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -15,7 +15,7 @@ from .hooks import EngineHook, Hook, ObjectHook, RawData from .plan import GradientEnginePlan, ReconsPlan, EnginePlan, ScanHook, ProbeHook, TiltHook from .state import Patterns, ReconsState, PartialReconsState, IterState, PreparedRecons -from .observer import Observer, LoggingObserver, PatienceObserver, SaveObserver, ObserverSet +from .observer import Observer, LoggingObserver, PatienceObserver, SSIMObserver, SaveObserver, ObserverSet def execute_plan( @@ -57,15 +57,23 @@ def execute_engine( engine_i = recons.state.iter.engine_num + extra_observers: t.List[Observer] = [] + + if plan.calc_ssim is not False: + extra_observers.append(SSIMObserver(plan.calc_ssim)) + if any(v is not None for v in ( plan.early_termination_loss, plan.early_termination_obj_ssim, plan.early_termination_probe_ssim )): - engine_observer = ObserverSet((recons.observer, PatienceObserver( + extra_observers.append(PatienceObserver( patience_loss=plan.early_termination_loss, patience_obj_ssim=plan.early_termination_obj_ssim, patience_probe_ssim=plan.early_termination_probe_ssim, smoothing=plan.early_termination_smoothing, - ))) + )) + + if extra_observers: + engine_observer = ObserverSet((recons.observer, *extra_observers)) else: engine_observer = recons.observer diff --git a/phaser/observer.py b/phaser/observer.py index 743ad7d..4fb5217 100644 --- a/phaser/observer.py +++ b/phaser/observer.py @@ -10,7 +10,7 @@ from phaser.types import EarlyTermination, flag_any_true, process_flag if t.TYPE_CHECKING: - from phaser.hooks.schedule import FlagArgs + from phaser.hooks.schedule import FlagArgs, FlagLike from typing_extensions import Self P = t.ParamSpec('P') @@ -249,6 +249,66 @@ def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[st raise EarlyTermination(state, self.continue_next_engine) +class SSIMObserver(Observer): + """Computes obj_ssim and probe_ssim at each calc_ssim flag firing.""" + + def __init__(self, calc_ssim: 'FlagLike'): + from phaser.types import process_flag, flag_any_true + self._calc_ssim_raw = calc_ssim + self._calc_ssim_flag = process_flag(calc_ssim) + self._ssim_enabled: bool = False + self._prev_state: t.Optional[ReconsState] = None + + def init_engine( + self, init_state: ReconsState, *, recons_name: str, + plan: EnginePlan, **kwargs: t.Any + ): + from phaser.types import flag_any_true + self._ssim_enabled = flag_any_true(self._calc_ssim_raw, plan.niter) + self._prev_state = None + + if self._ssim_enabled: + for k in ('obj_ssim', 'probe_ssim'): + if k not in init_state.progress: + init_state.progress[k] = ProgressState() + + def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]): + if not self._ssim_enabled: + return + if not self._calc_ssim_flag({'state': state, 'niter': n}): + return + + from copy import deepcopy + from phaser.utils.num import get_array_module + from phaser.utils.analysis import structural_similarity + + if self._prev_state is not None: + xp = get_array_module(state.object.data) + total_iter = int(state.iter.total_iter) + prev_iter = int(self._prev_state.iter.total_iter) + + ssim_o = structural_similarity( + xp.angle(state.object.data), + xp.angle(self._prev_state.object.data), + ) + state.progress['obj_ssim'].iters.append(total_iter) + state.progress['obj_ssim'].values.append(ssim_o) + + ssim_p = structural_similarity( + xp.abs(state.probe.data), + xp.abs(self._prev_state.probe.data), + ) + state.progress['probe_ssim'].iters.append(total_iter) + state.progress['probe_ssim'].values.append(ssim_p) + + logging.info( + f"SSIM (iters {prev_iter}→{total_iter}): " + f"obj={ssim_o:.4f} probe={ssim_p:.4f}" + ) + + self._prev_state = deepcopy(state) + + class SaveObserver(Observer): def __init__(self): self.out_dir: t.Optional[Path] = None From 8c48b0940ab062d0351e973f16f2e971fae7cdb5 Mon Sep 17 00:00:00 2001 From: wdwzyyg Date: Fri, 10 Apr 2026 17:41:37 -0400 Subject: [PATCH 6/6] made changes accordingly, naming, affine_transform and using cpu --- phaser/execute.py | 12 +++---- phaser/observer.py | 73 ++++++++++++++++++++-------------------- phaser/plan.py | 10 +++--- phaser/utils/analysis.py | 36 +++++++------------- pyproject.toml | 1 - 5 files changed, 60 insertions(+), 72 deletions(-) diff --git a/phaser/execute.py b/phaser/execute.py index d8ed699..2fb3a58 100644 --- a/phaser/execute.py +++ b/phaser/execute.py @@ -15,7 +15,7 @@ from .hooks import EngineHook, Hook, ObjectHook, RawData from .plan import GradientEnginePlan, ReconsPlan, EnginePlan, ScanHook, ProbeHook, TiltHook from .state import Patterns, ReconsState, PartialReconsState, IterState, PreparedRecons -from .observer import Observer, LoggingObserver, PatienceObserver, SSIMObserver, SaveObserver, ObserverSet +from .observer import Observer, LoggingObserver, PatienceObserver, RelMsSSIMObserver, SaveObserver, ObserverSet def execute_plan( @@ -59,16 +59,16 @@ def execute_engine( extra_observers: t.List[Observer] = [] - if plan.calc_ssim is not False: - extra_observers.append(SSIMObserver(plan.calc_ssim)) + if plan.calc_rel_msssim is not False: + extra_observers.append(RelMsSSIMObserver(plan.calc_rel_msssim)) if any(v is not None for v in ( - plan.early_termination_loss, plan.early_termination_obj_ssim, plan.early_termination_probe_ssim + plan.early_termination_loss, plan.early_termination_obj_rel_msssim, plan.early_termination_probe_rel_msssim )): extra_observers.append(PatienceObserver( patience_loss=plan.early_termination_loss, - patience_obj_ssim=plan.early_termination_obj_ssim, - patience_probe_ssim=plan.early_termination_probe_ssim, + patience_obj_rel_msssim=plan.early_termination_obj_rel_msssim, + patience_probe_rel_msssim=plan.early_termination_probe_rel_msssim, smoothing=plan.early_termination_smoothing, )) diff --git a/phaser/observer.py b/phaser/observer.py index 4fb5217..6693736 100644 --- a/phaser/observer.py +++ b/phaser/observer.py @@ -170,13 +170,13 @@ def finish_recons(self, state: ReconsState): class PatienceObserver(Observer): # metrics where higher values indicate improvement - _HIGHER_IS_BETTER: t.FrozenSet[str] = frozenset({'obj_ssim', 'probe_ssim'}) + _HIGHER_IS_BETTER: t.FrozenSet[str] = frozenset({'obj_rel_msssim', 'probe_rel_msssim'}) def __init__( self, patience_loss: t.Optional[int] = None, - patience_obj_ssim: t.Optional[int] = None, - patience_probe_ssim: t.Optional[int] = None, + patience_obj_rel_msssim: t.Optional[int] = None, + patience_probe_rel_msssim: t.Optional[int] = None, smoothing: float = 0.1, continue_next_engine: bool = True, ): @@ -187,10 +187,10 @@ def __init__( self._patience: t.Dict[str, int] = {} if patience_loss is not None: self._patience['total_loss'] = patience_loss - if patience_obj_ssim is not None: - self._patience['obj_ssim'] = patience_obj_ssim - if patience_probe_ssim is not None: - self._patience['probe_ssim'] = patience_probe_ssim + if patience_obj_rel_msssim is not None: + self._patience['obj_rel_msssim'] = patience_obj_rel_msssim + if patience_probe_rel_msssim is not None: + self._patience['probe_rel_msssim'] = patience_probe_rel_msssim self._best: t.Dict[str, float] = {} self._last_improvement_iter: t.Dict[str, int] = {} @@ -249,64 +249,63 @@ def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[st raise EarlyTermination(state, self.continue_next_engine) -class SSIMObserver(Observer): - """Computes obj_ssim and probe_ssim at each calc_ssim flag firing.""" +class RelMsSSIMObserver(Observer): + """Computes obj_rel_msssim and probe_rel_msssim at each calc_rel_msssim flag firing.""" - def __init__(self, calc_ssim: 'FlagLike'): + def __init__(self, calc_rel_msssim: 'FlagLike'): from phaser.types import process_flag, flag_any_true - self._calc_ssim_raw = calc_ssim - self._calc_ssim_flag = process_flag(calc_ssim) + self._calc_rel_msssim_raw = calc_rel_msssim + self._calc_rel_msssim_flag = process_flag(calc_rel_msssim) self._ssim_enabled: bool = False - self._prev_state: t.Optional[ReconsState] = None + # CPU-side snapshot: (total_iter, obj_phase, probe_abs) as numpy arrays + self._prev_snapshot: t.Optional[t.Tuple[int, 'numpy.ndarray', 'numpy.ndarray']] = None def init_engine( self, init_state: ReconsState, *, recons_name: str, plan: EnginePlan, **kwargs: t.Any ): from phaser.types import flag_any_true - self._ssim_enabled = flag_any_true(self._calc_ssim_raw, plan.niter) - self._prev_state = None + self._ssim_enabled = flag_any_true(self._calc_rel_msssim_raw, plan.niter) + self._prev_snapshot = None if self._ssim_enabled: - for k in ('obj_ssim', 'probe_ssim'): + for k in ('obj_rel_msssim', 'probe_rel_msssim'): if k not in init_state.progress: init_state.progress[k] = ProgressState() def update_iteration(self, state: ReconsState, i: int, n: int, errors: t.Dict[str, float]): if not self._ssim_enabled: return - if not self._calc_ssim_flag({'state': state, 'niter': n}): + if not self._calc_rel_msssim_flag({'state': state, 'niter': n}): return - from copy import deepcopy - from phaser.utils.num import get_array_module + from phaser.utils.num import get_array_module, to_numpy from phaser.utils.analysis import structural_similarity - if self._prev_state is not None: - xp = get_array_module(state.object.data) - total_iter = int(state.iter.total_iter) - prev_iter = int(self._prev_state.iter.total_iter) + xp = get_array_module(state.object.data) + total_iter = int(state.iter.total_iter) - ssim_o = structural_similarity( - xp.angle(state.object.data), - xp.angle(self._prev_state.object.data), - ) - state.progress['obj_ssim'].iters.append(total_iter) - state.progress['obj_ssim'].values.append(ssim_o) + # transfer only the two arrays needed; forces GPU→CPU sync here + obj_now = to_numpy(xp.angle(state.object.data)) + probe_now = to_numpy(xp.abs(state.probe.data)) - ssim_p = structural_similarity( - xp.abs(state.probe.data), - xp.abs(self._prev_state.probe.data), - ) - state.progress['probe_ssim'].iters.append(total_iter) - state.progress['probe_ssim'].values.append(ssim_p) + if self._prev_snapshot is not None: + prev_iter, obj_prev, probe_prev = self._prev_snapshot + + ssim_o = structural_similarity(obj_now, obj_prev) + state.progress['obj_rel_msssim'].iters.append(total_iter) + state.progress['obj_rel_msssim'].values.append(ssim_o) + + ssim_p = structural_similarity(probe_now, probe_prev) + state.progress['probe_rel_msssim'].iters.append(total_iter) + state.progress['probe_rel_msssim'].values.append(ssim_p) logging.info( - f"SSIM (iters {prev_iter}→{total_iter}): " + f"Relative multiscale SSIM (iters {prev_iter}→{total_iter}): " f"obj={ssim_o:.4f} probe={ssim_p:.4f}" ) - self._prev_state = deepcopy(state) + self._prev_snapshot = (total_iter, obj_now, probe_now) class SaveObserver(Observer): diff --git a/phaser/plan.py b/phaser/plan.py index c2958be..5e1247a 100644 --- a/phaser/plan.py +++ b/phaser/plan.py @@ -87,10 +87,10 @@ class EnginePlan(Dataclass, kw_only=True): early_termination_loss: t.Optional[int] = None """Terminate after n iterations without improvement in total_loss""" - early_termination_obj_ssim: t.Optional[int] = None - """Terminate after n iterations without improvement in obj_ssim (requires calc_ssim to be enabled)""" - early_termination_probe_ssim: t.Optional[int] = None - """Terminate after n iterations without improvement in probe_ssim (requires calc_ssim to be enabled)""" + early_termination_obj_rel_msssim: t.Optional[int] = None + """Terminate after n iterations without improvement in obj_rel_msssim (requires calc_rel_msssim to be enabled)""" + early_termination_probe_rel_msssim: t.Optional[int] = None + """Terminate after n iterations without improvement in probe_rel_msssim (requires calc_rel_msssim to be enabled)""" early_termination_smoothing: float = 0.9 """ Smoothing factor to apply to error measurement for early termination. @@ -98,7 +98,7 @@ class EnginePlan(Dataclass, kw_only=True): (smooths over ~1/smoothing iterations) """ - calc_ssim: FlagLike = False + calc_rel_msssim: FlagLike = False """Compute SSIM between consecutive iterations as a convergence metric. Use SimpleFlag(every=N) to compute every N iterations.""" check_every_group: bool = False diff --git a/phaser/utils/analysis.py b/phaser/utils/analysis.py index 2ea96c0..9c3c9c5 100644 --- a/phaser/utils/analysis.py +++ b/phaser/utils/analysis.py @@ -297,25 +297,6 @@ def align_and_correlate(mat: NDArray[numpy.floating]) -> NDArray[numpy.floating] return upsamp_obj[(slice(None), *crop)], ground_truth[tuple(crop)] -def _resample_to_shape(im, target_shape: t.Tuple[int, ...], xp: t.Any): - """Resample im to target_shape, staying on the input device. first dimension untouched""" - xp_name = getattr(xp, '__name__', '') - - if xp_name == 'numpy': - from scipy.ndimage import zoom - zoom_factors = tuple(s1 / s2 for s1, s2 in zip(target_shape, im.shape)) - return zoom(im, zoom_factors, order=1) - - if 'cupy' in xp_name: - from cupyx.scipy.ndimage import zoom - zoom_factors = tuple(s1 / s2 for s1, s2 in zip(target_shape, im.shape)) - return zoom(im, zoom_factors, order=1) - - # JAX - import jax.image - return jax.image.resize(im, target_shape, method='linear') - - def _uniform_filter_spatial(im, size: int, xp: t.Any): """ Separable box filter over the last two spatial dims only (any ndim >= 2). @@ -379,7 +360,7 @@ def structural_similarity( Efficient implementation: - fused filter pass: all statistics filtered in one call per scale - - bilinear downsampling pyramid via _resample_to_shape + - bilinear downsampling pyramid via affine_transform - fully on-device: only the final scalar crosses the device boundary Parameters @@ -398,13 +379,22 @@ def structural_similarity( mssim : float MS-SSIM value in [0, 1]. """ + from phaser.utils.image import affine_transform as _affine_transform + + def _resample(im, target_shape): + scale_y = im.shape[-2] / target_shape[-2] + scale_x = im.shape[-1] / target_shape[-1] + matrix = numpy.array([[scale_y, 0.0], [0.0, scale_x]]) + offset = numpy.array([0.5 * (scale_y - 1.0), 0.5 * (scale_x - 1.0)]) + return _affine_transform(im, matrix, offset=offset, output_shape=target_shape[-2:], order=1) + xp = get_array_module(im1, im2) im1 = im1.astype(numpy.float64) im2 = im2.astype(numpy.float64) if im1.shape != im2.shape: - im2 = _resample_to_shape(im2, im1.shape, xp) + im2 = _resample(im2, im1.shape) if data_range is None: data_range = float(im2.max() - im2.min()) @@ -436,7 +426,7 @@ def structural_similarity( if scale < num_scales - 1: new_shape = (im1.shape[0], im1.shape[-2] // 2, im1.shape[-1] // 2) - im1 = _resample_to_shape(im1, new_shape, xp) - im2 = _resample_to_shape(im2, new_shape, xp) + im1 = _resample(im1, new_shape) + im2 = _resample(im2, new_shape) return mssim \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 47e6bbf..3d08640 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ requires-python = ">=3.10" dependencies = [ "numpy>=2.0,<2.6", # tested on 2.3 "scipy>=1.7.0,<1.19", # tested on 1.11, 1.16 - "scikit-image>=0.19.0", "matplotlib~=3.8", "h5py~=3.8", "pyyaml>=5.3.1",