diff --git a/docs/api.md b/docs/api.md index 1bf4fdf92..d9da2508d 100644 --- a/docs/api.md +++ b/docs/api.md @@ -149,6 +149,9 @@ See the {doc}`extensibility guide ` for how to implement a custo experimental.tl.calculate_tiling_qc experimental.tl.TilingQCParams + experimental.tl.align + experimental.tl.align_by_landmarks + experimental.tl.AlignResult experimental.pl.tiling_qc experimental.im.fit_stain_reference experimental.im.apply_stain_normalization diff --git a/hatch.toml b/hatch.toml index 065fe2dfb..f2672cd35 100644 --- a/hatch.toml +++ b/hatch.toml @@ -17,10 +17,14 @@ extra-dependencies = ["diff-cover"] matrix = [ { deps = ["stable"], python = ["3.11", "3.12", "3.13"] }, { deps = ["pre"], python = ["3.13"] }, + { deps = ["stable"], python = ["3.13"], extras = ["jax"] }, ] overrides.matrix.deps.env-vars = [ { key = "UV_PRERELEASE", value = "allow", if = ["pre"] }, ] +overrides.matrix.extras.features = [ + { value = "jax", if = ["jax"] }, +] # default commands (only `cov-report` is overridden) scripts.run = "pytest{env:HATCH_TEST_ARGS:} -p no:cov {args}" scripts.run-cov = "coverage run -m pytest{env:HATCH_TEST_ARGS:} -p no:cov {args}" diff --git a/pyproject.toml b/pyproject.toml index a3a9308cb..e24deb7cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,6 +76,9 @@ dependencies = [ "xarray>=2024.10", "zarr>=3", ] +optional-dependencies.jax = [ + "jax", +] optional-dependencies.leiden = [ "leidenalg", "spatialleiden>=0.4", diff --git a/src/squidpy/experimental/_methods/__init__.py b/src/squidpy/experimental/_methods/__init__.py new file mode 100644 index 000000000..1fa74d939 --- /dev/null +++ b/src/squidpy/experimental/_methods/__init__.py @@ -0,0 +1,8 @@ +"""In-memory model-fitting core for experimental methods.""" + +from __future__ import annotations + +from squidpy.experimental._methods._protocols import AlignLandmarksFn, AlignResult, AlignSamplesFn +from squidpy.experimental._methods._registry import Registry + +__all__ = ["Registry", "AlignResult", 
"AlignSamplesFn", "AlignLandmarksFn"] diff --git a/src/squidpy/experimental/_methods/_families.py b/src/squidpy/experimental/_methods/_families.py new file mode 100644 index 000000000..6f3610f8d --- /dev/null +++ b/src/squidpy/experimental/_methods/_families.py @@ -0,0 +1,14 @@ +"""Estimator family registries.""" + +from __future__ import annotations + +from squidpy.experimental._methods._protocols import AlignLandmarksFn, AlignSamplesFn +from squidpy.experimental._methods._registry import Registry + +#: Sample-to-sample alignment estimators -- ref/query point clouds in, transform out. +#: Consumed by ``squidpy.experimental.tl.align``. +ALIGN_SAMPLES: Registry[AlignSamplesFn] = Registry("align_samples") + +#: Closed-form landmark alignment estimators -- paired landmarks in, affine out. +#: Consumed by ``squidpy.experimental.tl.align_by_landmarks``. +ALIGN_LANDMARKS: Registry[AlignLandmarksFn] = Registry("align_landmarks") diff --git a/src/squidpy/experimental/_methods/_protocols.py b/src/squidpy/experimental/_methods/_protocols.py new file mode 100644 index 000000000..6f2366a65 --- /dev/null +++ b/src/squidpy/experimental/_methods/_protocols.py @@ -0,0 +1,59 @@ +"""Structural contracts shared across the alignment estimator families. + +These :class:`~typing.Protocol` types are what the public API and the registries +are typed against, so the orchestration layer never names a concrete estimator +result (e.g. ``StalignResult``). A new estimator only has to satisfy +:class:`AlignResult` -- a ``transform`` that maps points into the reference +frame -- to plug into :func:`squidpy.experimental.tl.align`. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable + +import numpy.typing as npt + +from squidpy._utils import NDArrayA + +if TYPE_CHECKING: + from squidpy.experimental._methods.align_landmarks._landmark import AffineFitResult + +__all__ = ["AlignResult", "AlignSamplesFn", "AlignLandmarksFn"] + + +@runtime_checkable +class AlignResult(Protocol): + """A fitted alignment that maps ``(N, 2)`` ``(x, y)`` points into the reference frame. + + This is the only thing the public ``align*`` functions require of an + estimator's result, so ``output_mode="object"`` is agnostic to the method + that produced it. + """ + + def transform(self, points: npt.ArrayLike, /) -> NDArrayA: + """Map an ``(N, 2)`` ``(x, y)`` array into the reference frame.""" + ... + + +class AlignSamplesFn(Protocol): + """Calling convention for ``align_samples`` estimators. + + Two point clouds in (passed by keyword as ``ref`` / ``query`` so the + direction can never be silently swapped), one :class:`AlignResult` out. + Solver-specific options arrive through ``**kwargs``. + """ + + def __call__(self, ref: npt.ArrayLike, query: npt.ArrayLike, **kwargs: Any) -> AlignResult: ... + + +class AlignLandmarksFn(Protocol): + """Calling convention for ``align_landmarks`` estimators: paired landmarks in, affine out.""" + + def __call__( + self, + landmarks_ref: npt.ArrayLike, + landmarks_query: npt.ArrayLike, + *, + source_cs: str | None = ..., + target_cs: str | None = ..., + ) -> AffineFitResult: ... 
diff --git a/src/squidpy/experimental/_methods/_registry.py b/src/squidpy/experimental/_methods/_registry.py new file mode 100644 index 000000000..75d7c1ab0 --- /dev/null +++ b/src/squidpy/experimental/_methods/_registry.py @@ -0,0 +1,73 @@ +"""A flat registry mapping method names to fitting functions.""" + +from __future__ import annotations + +import functools +import importlib.util +from collections.abc import Callable +from typing import Any, Generic, TypeVar + +#: The calling convention a family's registry advertises (returned by :meth:`Registry.get`). +F = TypeVar("F", bound=Callable[..., Any]) +#: The concrete function being registered. Kept separate from ``F`` so an estimator may +#: declare specific keyword parameters (e.g. ``config=``) without having to structurally +#: match the family's open-ended ``**kwargs`` calling convention. +RegisteredT = TypeVar("RegisteredT", bound=Callable[..., Any]) + + +class Registry(Generic[F]): + """A flat ``name -> function`` registry for one *family* of methods. + + One :class:`Registry` is created per family (e.g. ``align``, ``impute``), + so keys are plain method names -- there is no ``(method, mode)`` compound + key, because the family already pins the rest. + + The type parameter ``F`` is the family's calling convention (a callable + :class:`~typing.Protocol`); :meth:`get` returns it, so dispatch sites are + typed against the family contract rather than ``Callable[..., Any]``. + """ + + def __init__(self, name: str) -> None: + self.name = name + self._registry: dict[str, F] = {} + + def register(self, key: str, *, requires: tuple[str, ...] 
= ()) -> Callable[[RegisteredT], RegisteredT]: + """Return a decorator registering a method/function under ``key``.""" + + def decorator(func: RegisteredT) -> RegisteredT: + if key in self._registry: + raise ValueError(f"Method {key!r} is already registered in the {self.name!r} registry.") + + if requires: + + @functools.wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Any: + missing = [pkg for pkg in requires if importlib.util.find_spec(pkg) is None] + if missing: + verb = "is" if len(missing) == 1 else "are" + names = ", ".join(repr(p) for p in missing) + extras = ",".join(missing) + raise ImportError( + f"Method {key!r} requires {names}, which {verb} not installed. " + f'Install with `pip install "squidpy[{extras}]"`.' + ) + return func(*args, **kwargs) + + self._registry[key] = wrapped # type: ignore[assignment] + return wrapped # type: ignore[return-value] + else: + self._registry[key] = func # type: ignore[assignment] + return func + + return decorator + + def get(self, key: str) -> F: + """Return the function registered under ``key``.""" + try: + return self._registry[key] + except KeyError: + raise ValueError(f"Unknown {self.name} method {key!r}. 
Available: {sorted(self._registry)}.") from None + + def keys(self) -> tuple[str, ...]: + """Return the registered method names.""" + return tuple(self._registry) diff --git a/src/squidpy/experimental/_methods/align_landmarks/__init__.py b/src/squidpy/experimental/_methods/align_landmarks/__init__.py new file mode 100644 index 000000000..37ee80aa0 --- /dev/null +++ b/src/squidpy/experimental/_methods/align_landmarks/__init__.py @@ -0,0 +1,17 @@ +"""``align_landmarks`` family: closed-form alignment from paired landmarks.""" + +from __future__ import annotations + +from squidpy.experimental._methods._families import ALIGN_LANDMARKS +from squidpy.experimental._methods.align_landmarks._landmark import ( + AffineFitResult, + fit_affine, + fit_similarity, +) + +__all__ = [ + "ALIGN_LANDMARKS", + "AffineFitResult", + "fit_affine", + "fit_similarity", +] diff --git a/src/squidpy/experimental/_methods/align_landmarks/_landmark.py b/src/squidpy/experimental/_methods/align_landmarks/_landmark.py new file mode 100644 index 000000000..714309f1b --- /dev/null +++ b/src/squidpy/experimental/_methods/align_landmarks/_landmark.py @@ -0,0 +1,157 @@ +"""Closed-form landmark alignment estimators.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import numpy.typing as npt + +from squidpy._utils import NDArrayA +from squidpy.experimental._methods._families import ALIGN_LANDMARKS + + +@dataclass +class AffineFitResult: + """A fitted ``(3, 3)`` homogeneous affine mapping query onto ref, in ``(x, y)``.""" + + matrix: np.ndarray + source_cs: str | None = None + target_cs: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if self.matrix.shape != (3, 3): + raise ValueError(f"Expected a (3, 3) homogeneous matrix, found shape {self.matrix.shape}.") + + def transform(self, x: npt.ArrayLike) -> NDArrayA: + """Apply 
the affine to an ``(N, 2)`` ``(x, y)`` coordinate array.""" + coords = np.asarray(x, dtype=float) + if coords.ndim != 2 or coords.shape[1] != 2: + raise ValueError(f"Expected an (N, 2) coordinate array, found shape {coords.shape}.") + return coords @ self.matrix[:2, :2].T + self.matrix[:2, 2] + + +def _fit_landmark_relation( + landmarks_ref: np.ndarray, + landmarks_query: np.ndarray, + *, + method: str, + solve_fn: Callable[[np.ndarray, np.ndarray], np.ndarray], + source_cs: str | None = None, + target_cs: str | None = None, +) -> AffineFitResult: + ref = _validate_landmarks(landmarks_ref, name="landmarks_ref") + query = _validate_landmarks(landmarks_query, name="landmarks_query") + if ref.shape != query.shape: + raise ValueError( + f"`landmarks_ref` and `landmarks_query` must have the same shape; got {ref.shape} and {query.shape}." + ) + if ref.shape[0] < 3: + raise ValueError(f"`{method}` needs at least 3 landmark pairs, got {ref.shape[0]}.") + + matrix = solve_fn(ref, query) + return AffineFitResult( + matrix=matrix, + source_cs=source_cs, + target_cs=target_cs, + metadata={"method": method}, + ) + + +@ALIGN_LANDMARKS.register("similarity") +def fit_similarity( + landmarks_ref: np.ndarray, + landmarks_query: np.ndarray, + *, + source_cs: str | None = None, + target_cs: str | None = None, +) -> AffineFitResult: + """4-DOF similarity fit (rotation + uniform scale + translation), via spatialdata. + + Parameters + ---------- + landmarks_ref, landmarks_query + Pre-paired ``(N, 2)`` ``(x, y)`` landmark arrays (``N >= 3``). + source_cs, target_cs + Optional coordinate-system labels stamped onto the result for + traceability; they do not affect the fit. 
+ """ + return _fit_landmark_relation( + landmarks_ref, + landmarks_query, + method="similarity", + solve_fn=_fit_similarity, + source_cs=source_cs, + target_cs=target_cs, + ) + + +@ALIGN_LANDMARKS.register("affine") +def fit_affine( + landmarks_ref: np.ndarray, + landmarks_query: np.ndarray, + *, + source_cs: str | None = None, + target_cs: str | None = None, +) -> AffineFitResult: + """6-DOF affine fit (rotation + non-uniform scale + shear + translation), via skimage. + + Parameters + ---------- + landmarks_ref, landmarks_query + Pre-paired ``(N, 2)`` ``(x, y)`` landmark arrays (``N >= 3``). + source_cs, target_cs + Optional coordinate-system labels stamped onto the result for + traceability; they do not affect the fit. + """ + return _fit_landmark_relation( + landmarks_ref, + landmarks_query, + method="affine", + solve_fn=_fit_affine, + source_cs=source_cs, + target_cs=target_cs, + ) + + +def _validate_landmarks(points: np.ndarray, *, name: str) -> np.ndarray: + arr = np.asarray(points, dtype=float) + if arr.ndim != 2 or arr.shape[1] != 2: + raise ValueError(f"`{name}` must be a sequence of (x, y) pairs, got shape {arr.shape}.") + if not np.all(np.isfinite(arr)): + raise ValueError(f"`{name}` must contain only finite values.") + return arr + + +def _fit_similarity(ref_xy: np.ndarray, query_xy: np.ndarray) -> np.ndarray: + """4-DOF similarity fit, delegated to spatialdata.""" + from spatialdata.models import PointsModel + from spatialdata.transformations import get_transformation_between_landmarks + + refs_pts = PointsModel.parse(ref_xy) + moving_pts = PointsModel.parse(query_xy) + sd_transform = get_transformation_between_landmarks(refs_pts, moving_pts) + return _extract_affine_matrix(sd_transform) + + +def _fit_affine(ref_xy: np.ndarray, query_xy: np.ndarray) -> np.ndarray: + """Full 6-DOF affine fit, delegated to skimage's least-squares estimator.""" + from skimage.transform import estimate_transform + + model_obj = estimate_transform("affine", src=query_xy, 
dst=ref_xy) + return np.asarray(model_obj.params) + + +def _extract_affine_matrix(sd_transform: object) -> np.ndarray: + """Pull a ``(3, 3)`` homogeneous matrix out of a spatialdata transformation.""" + from spatialdata.transformations import Affine as SDAffine + from spatialdata.transformations import Sequence as SDSequence + + if isinstance(sd_transform, SDAffine): + return np.asarray(sd_transform.matrix) + if isinstance(sd_transform, SDSequence): + return np.asarray(sd_transform.to_affine_matrix(input_axes=("x", "y"), output_axes=("x", "y"))) + raise TypeError(f"Unexpected transformation type from spatialdata: {type(sd_transform).__name__}.") diff --git a/src/squidpy/experimental/_methods/align_samples/__init__.py b/src/squidpy/experimental/_methods/align_samples/__init__.py new file mode 100644 index 000000000..933ea7d06 --- /dev/null +++ b/src/squidpy/experimental/_methods/align_samples/__init__.py @@ -0,0 +1,14 @@ +"""``align_samples`` family: align two samples' point clouds (STalign). + +Importing this package registers the family's estimators into +:data:`~squidpy.experimental._methods._families.ALIGN_SAMPLES`. It stays cheap -- +JAX is pulled in lazily, only when an estimator such as ``fit_stalign`` is called. 
+""" + +from __future__ import annotations + +from squidpy.experimental._methods._families import ALIGN_SAMPLES +from squidpy.experimental._methods.align_samples._stalign import fit_stalign +from squidpy.experimental._methods.align_samples._stalign_impl._tools import StalignResult + +__all__ = ["ALIGN_SAMPLES", "fit_stalign", "StalignResult"] diff --git a/src/squidpy/experimental/_methods/align_samples/_stalign.py b/src/squidpy/experimental/_methods/align_samples/_stalign.py new file mode 100644 index 000000000..59490b9c8 --- /dev/null +++ b/src/squidpy/experimental/_methods/align_samples/_stalign.py @@ -0,0 +1,67 @@ +"""STalign estimator: JAX LDDMM point-cloud registration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy.typing as npt + +from squidpy.experimental._methods._families import ALIGN_SAMPLES + +if TYPE_CHECKING: + from ._stalign_impl._tools import STalignConfig, StalignResult + + +@ALIGN_SAMPLES.register("stalign", requires=("jax",)) +def fit_stalign( + ref: npt.ArrayLike, + query: npt.ArrayLike, + *, + config: STalignConfig | None = None, + landmarks_source: npt.ArrayLike | None = None, + landmarks_target: npt.ArrayLike | None = None, +) -> StalignResult: + """Fit a deformation mapping ``query`` onto ``ref``. + + Parameters + ---------- + ref + ``(N, 2)`` reference point cloud in ``(x, y)`` order. + query + ``(M, 2)`` query point cloud in ``(x, y)`` order, to be aligned to + ``ref``. Both are plain in-memory arrays; extracting them from an + ``AnnData`` / ``SpatialData`` is the caller's responsibility. + config + Optional :class:`STalignConfig` of solver hyperparameters. + landmarks_source, landmarks_target + Optional corresponding ``(x, y)`` landmark arrays used to + initialise the affine. Must be provided together. 
+ + Returns + ------- + A :class:`StalignResult` whose :meth:`~StalignResult.transform` maps + ``(x, y)`` points into the reference frame; ``aligned_points`` is the fitted + ``query`` already mapped. + """ + # Import the JAX-backed solver only after requirements pass, so callers + # without JAX get the clean ImportError from the registry's `requires` gate rather + # than a confusing failure from a module-level `import jax`. + import jax.numpy as jnp + + from ._stalign_impl._helpers import validate_points + from ._stalign_impl._tools import stalign_points + + ref_xy = validate_points(ref, name="ref") + query_xy = validate_points(query, name="query") + + # The solver runs internally in row-col (y, x); inputs are (x, y) -- swap at the boundary. + lm_src = None if landmarks_source is None else jnp.asarray(landmarks_source)[:, ::-1] + lm_tgt = None if landmarks_target is None else jnp.asarray(landmarks_target)[:, ::-1] + + return stalign_points( + source_points=query_xy[:, ::-1], + target_points=ref_xy[:, ::-1], + config=config, + landmarks_source=lm_src, + landmarks_target=lm_tgt, + ) diff --git a/src/squidpy/experimental/_methods/align_samples/_stalign_impl/__init__.py b/src/squidpy/experimental/_methods/align_samples/_stalign_impl/__init__.py new file mode 100644 index 000000000..26c72f7e2 --- /dev/null +++ b/src/squidpy/experimental/_methods/align_samples/_stalign_impl/__init__.py @@ -0,0 +1,6 @@ +"""Ported STalign JAX LDDMM solver. + +Pure numerics only; these functions are gated by the jax requirement. 
+""" + +from __future__ import annotations diff --git a/src/squidpy/experimental/_methods/align_samples/_stalign_impl/_core.py b/src/squidpy/experimental/_methods/align_samples/_stalign_impl/_core.py new file mode 100644 index 000000000..39996a099 --- /dev/null +++ b/src/squidpy/experimental/_methods/align_samples/_stalign_impl/_core.py @@ -0,0 +1,369 @@ +"""Core JAX implementation for experimental STalign point registration.""" + +from __future__ import annotations + +from typing import Any, Literal + +import jax +import jax.numpy as jnp +import jax.scipy as jsp +import numpy as np + +__all__ = ["jax_dtype", "lddmm", "transform_points_row_col"] + + +def jax_dtype() -> jnp.dtype: + """Resolve the active JAX float dtype at call time, not import time.""" + return jnp.float64 if jax.config.x64_enabled else jnp.float32 + + +def _to_affine(linear: jax.Array, translation: jax.Array) -> jax.Array: + return jnp.array( + [ + [linear[0, 0], linear[0, 1], translation[0]], + [linear[1, 0], linear[1, 1], translation[1]], + [0.0, 0.0, 1.0], + ], + dtype=linear.dtype, + ) + + +def _grid_points(x: tuple[jax.Array, jax.Array]) -> jax.Array: + yy, xx = jnp.meshgrid(x[0], x[1], indexing="ij") + return jnp.stack((yy, xx)) + + +def _interp( + x: tuple[jax.Array, jax.Array], + image: jax.Array, + phii: jax.Array, + *, + mode: str = "nearest", +) -> jax.Array: + """Interpolate a channels-first image on physical row-column coordinates.""" + arr = jnp.asarray(image) + coords = jnp.asarray(phii) + if coords.shape[0] != 2: + raise ValueError(f"Expected interpolation coordinates to have leading axis of size 2, found `{coords.shape}`.") + + if arr.ndim == 2: + arr = arr[None, ...] 
+ + row_step = x[0][1] - x[0][0] + col_step = x[1][1] - x[1][0] + row_idx = (coords[0] - x[0][0]) / row_step + col_idx = (coords[1] - x[1][0]) / col_step + idx = jnp.stack((row_idx.reshape(-1), col_idx.reshape(-1))) + + def _sample(channel: jax.Array) -> jax.Array: + values = jsp.ndimage.map_coordinates(channel, idx, order=1, mode=mode) + return values.reshape(coords.shape[1:]) + + return jax.vmap(_sample)(arr) + + +def transform_points_row_col( + xv: tuple[jax.Array, jax.Array], + velocity: jax.Array, + affine: jax.Array, + points: np.ndarray | jax.Array, + *, + direction: Literal["forward", "backward"] = "forward", +) -> jax.Array: + pts = jnp.asarray(points) + n_steps = velocity.shape[0] + time_steps = range(n_steps) + flow_sign = 1.0 + if direction == "backward": + affine = jnp.linalg.inv(affine) + pts = pts @ affine[:2, :2].T + affine[:2, -1] + flow_sign = -1.0 + time_steps = reversed(time_steps) + + for t in time_steps: + disp = _interp( + xv, + jnp.moveaxis(flow_sign * velocity[t], -1, 0), + pts.T[:, :, None], + mode="nearest", + )[:, :, 0].T + pts = pts + disp / n_steps + + if direction == "forward": + pts = pts @ affine[:2, :2].T + affine[:2, -1] + + return pts + + +def _transform_grid_backward( + x_target: tuple[jax.Array, jax.Array], + xv: tuple[jax.Array, jax.Array], + velocity: jax.Array, + affine: jax.Array, +) -> jax.Array: + target_grid = _grid_points(x_target) + affine_inv = jnp.linalg.inv(affine) + source_grid = jnp.einsum("ij,jhw->ihw", affine_inv[:2, :2], target_grid) + affine_inv[:2, -1][:, None, None] + + for t in range(velocity.shape[0] - 1, -1, -1): + disp = _interp(xv, jnp.moveaxis(-velocity[t], -1, 0), source_grid, mode="nearest") + source_grid = source_grid + disp / velocity.shape[0] + + return source_grid + + +def _contrast_transform(source_image: jax.Array, target_image: jax.Array, weights: jax.Array) -> jax.Array: + flat_source = source_image.reshape(source_image.shape[0], -1) + flat_target = target_image.reshape(target_image.shape[0], 
-1) + flat_weights = weights.reshape(-1) + + design = jnp.concatenate((jnp.ones((1, flat_source.shape[1]), dtype=source_image.dtype), flat_source), axis=0) + weighted_design = design * flat_weights[None, :] + design_cov = weighted_design @ design.T + target_cov = weighted_design @ flat_target.T + regularized = design_cov + 0.1 * jnp.eye(design_cov.shape[0], dtype=design_cov.dtype) + coefficients = jnp.linalg.solve(regularized, target_cov) + return (coefficients.T @ design).reshape(target_image.shape) + + +def _build_velocity_grid(x_source: tuple[jax.Array, jax.Array], *, a: float, expand: float) -> tuple[jax.Array, jax.Array]: + minimum = jnp.array([x_source[0][0], x_source[1][0]]) + maximum = jnp.array([x_source[0][-1], x_source[1][-1]]) + center = (minimum + maximum) / 2.0 + half_width = (maximum - minimum) * expand / 2.0 + step = a * 0.5 + return ( + jnp.arange(center[0] - half_width[0], center[0] + half_width[0] + step, step), + jnp.arange(center[1] - half_width[1], center[1] + half_width[1] + step, step), + ) + + +def _build_regularizer( + xv: tuple[jax.Array, jax.Array], + *, + a: float, + p: float, +) -> tuple[jax.Array, jax.Array, float | jax.Array]: + dv = jnp.array([xv[0][1] - xv[0][0], xv[1][1] - xv[1][0]]) + shape = (xv[0].shape[0], xv[1].shape[0]) + fy = jnp.arange(shape[0], dtype=xv[0].dtype) / (shape[0] * dv[0]) + fx = jnp.arange(shape[1], dtype=xv[1].dtype) / (shape[1] * dv[1]) + frequency_grid = jnp.stack(jnp.meshgrid(fy, fx, indexing="ij"), axis=-1) + ll = (1.0 + 2.0 * a**2 * jnp.sum((1.0 - jnp.cos(2.0 * np.pi * frequency_grid * dv)) / (dv**2), axis=-1)) ** ( + 2.0 * p + ) + kernel = 1.0 / ll + dv_prod = jnp.prod(dv) + return kernel, ll, dv_prod + + +def _update_mixture_weights( + transformed_source: jax.Array, + target_image: jax.Array, + match_weights: jax.Array, + artifact_weights: jax.Array, + background_weights: jax.Array, + *, + sigmaM: float, + sigmaA: float, + sigmaB: float, + estimate_muA: bool, + estimate_muB: bool, + muA: jax.Array, + 
muB: jax.Array, + iteration: int, +) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]: + if estimate_muA: + muA = jnp.sum(artifact_weights * target_image, axis=(-1, -2)) / jnp.maximum(jnp.sum(artifact_weights), 1e-12) + if estimate_muB: + muB = jnp.sum(background_weights * target_image, axis=(-1, -2)) / jnp.maximum( + jnp.sum(background_weights), 1e-12 + ) + + if iteration < 50: + return match_weights, artifact_weights, background_weights, muA, muB + + weights = jnp.stack((match_weights, artifact_weights, background_weights)) + mixing = jnp.sum(weights, axis=(1, 2)) + mixing = mixing + jnp.max(mixing) * 1e-6 + mixing = mixing / jnp.sum(mixing) + + n_channels = target_image.shape[0] + norm_match = (2.0 * np.pi * sigmaM**2) ** (n_channels / 2.0) + norm_artifact = (2.0 * np.pi * sigmaA**2) ** (n_channels / 2.0) + norm_background = (2.0 * np.pi * sigmaB**2) ** (n_channels / 2.0) + + match_weights = mixing[0] * jnp.exp(-jnp.sum((transformed_source - target_image) ** 2, axis=0) / (2.0 * sigmaM**2)) + match_weights = match_weights / norm_match + artifact_weights = mixing[1] * jnp.exp( + -jnp.sum((muA[:, None, None] - target_image) ** 2, axis=0) / (2.0 * sigmaA**2) + ) + artifact_weights = artifact_weights / norm_artifact + background_weights = mixing[2] * jnp.exp( + -jnp.sum((muB[:, None, None] - target_image) ** 2, axis=0) / (2.0 * sigmaB**2) + ) + background_weights = background_weights / norm_background + + total = match_weights + artifact_weights + background_weights + total = total + jnp.max(total) * 1e-6 + return match_weights / total, artifact_weights / total, background_weights / total, muA, muB + + +def _lddmm_loss( + linear: jax.Array, + translation: jax.Array, + velocity: jax.Array, + *, + x_source: tuple[jax.Array, jax.Array], + source_image: jax.Array, + x_target: tuple[jax.Array, jax.Array], + target_image: jax.Array, + xv: tuple[jax.Array, jax.Array], + match_weights: jax.Array, + ll: jax.Array, + dv_prod: float | jax.Array, + points_source: 
jax.Array, + points_target: jax.Array, + sigmaM: float, + sigmaR: float, + sigmaP: float, +) -> tuple[jax.Array, tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]]: + affine = _to_affine(linear, translation) + source_grid = _transform_grid_backward(x_target, xv, velocity, affine) + warped_source = _interp(x_source, source_image, source_grid, mode="nearest") + contrast_source = _contrast_transform(warped_source, target_image, match_weights) + + match_energy = jnp.sum((contrast_source - target_image) ** 2 * match_weights) / (2.0 * sigmaM**2) + fft_velocity = jnp.fft.fftn(velocity, axes=(1, 2)) + reg_energy = ( + jnp.sum(jnp.sum(jnp.abs(fft_velocity) ** 2, axis=(0, 3)) * ll) + * dv_prod + / 2.0 + / velocity.shape[1] + / velocity.shape[2] + / sigmaR**2 + ) + + transformed_points = transform_points_row_col(xv, velocity, affine, points_source, direction="forward") + if points_source.shape[0] == 0: + point_energy = jnp.array(0.0, dtype=source_image.dtype) + else: + point_energy = jnp.sum((transformed_points - points_target) ** 2) / (2.0 * sigmaP**2) + + total = match_energy + reg_energy + point_energy + return total, (contrast_source, transformed_points, match_energy, reg_energy, point_energy) + + +def lddmm( + xI: tuple[np.ndarray | jax.Array, np.ndarray | jax.Array], + I: np.ndarray | jax.Array, + xJ: tuple[np.ndarray | jax.Array, np.ndarray | jax.Array], + J: np.ndarray | jax.Array, + *, + L: np.ndarray | jax.Array, + T: np.ndarray | jax.Array, + points_source: np.ndarray | jax.Array | None = None, + points_target: np.ndarray | jax.Array | None = None, + a: float = 500.0, + p: float = 2.0, + expand: float = 2.0, + nt: int = 3, + niter: int = 5000, + diffeo_start: int = 0, + epL: float = 2e-8, + epT: float = 2e-1, + epV: float = 2e3, + sigmaM: float = 1.0, + sigmaB: float = 2.0, + sigmaA: float = 5.0, + sigmaR: float = 5e5, + sigmaP: float = 2e1, +) -> dict[str, Any]: + x_source = (jnp.asarray(xI[0]), jnp.asarray(xI[1])) + x_target = (jnp.asarray(xJ[0]), 
jnp.asarray(xJ[1])) + source_image = jnp.asarray(I, dtype=jax_dtype()) + target_image = jnp.asarray(J, dtype=jax_dtype()) + linear = jnp.asarray(L, dtype=jax_dtype()) + translation = jnp.asarray(T, dtype=jax_dtype()) + + if points_source is None: + source_landmarks = jnp.zeros((0, 2), dtype=jax_dtype()) + target_landmarks = jnp.zeros((0, 2), dtype=jax_dtype()) + else: + source_landmarks = jnp.asarray(points_source, dtype=jax_dtype()) + target_landmarks = jnp.asarray(points_target, dtype=jax_dtype()) + + xv = _build_velocity_grid(x_source, a=a, expand=expand) + velocity = jnp.zeros((nt, xv[0].shape[0], xv[1].shape[0], 2), dtype=jax_dtype()) + kernel, ll, dv_prod = _build_regularizer(xv, a=a, p=p) + + match_weights = jnp.full(target_image.shape[1:], 0.5, dtype=target_image.dtype) + background_weights = jnp.full(target_image.shape[1:], 0.4, dtype=target_image.dtype) + artifact_weights = jnp.full(target_image.shape[1:], 0.1, dtype=target_image.dtype) + muA = jnp.mean(target_image, axis=(1, 2)) + muB = jnp.zeros_like(muA) + estimate_muA = True + estimate_muB = True + + loss_and_grad = jax.jit(jax.value_and_grad(_lddmm_loss, argnums=(0, 1, 2), has_aux=True)) + + for iteration in range(niter): + (energy, aux), (grad_linear, grad_translation, grad_velocity) = loss_and_grad( + linear, + translation, + velocity, + x_source=x_source, + source_image=source_image, + x_target=x_target, + target_image=target_image, + xv=xv, + match_weights=match_weights, + ll=ll, + dv_prod=dv_prod, + points_source=source_landmarks, + points_target=target_landmarks, + sigmaM=sigmaM, + sigmaR=sigmaR, + sigmaP=sigmaP, + ) + contrast_source, transformed_points, _, _, _ = aux + + affine_scale = 1.0 + 9.0 * float(iteration >= diffeo_start) + linear = linear - (epL / affine_scale) * grad_linear + translation = translation - (epT / affine_scale) * grad_translation + + grad_velocity = jnp.fft.ifftn( + jnp.fft.fftn(grad_velocity, axes=(1, 2)) * kernel[None, ..., None], + axes=(1, 2), + ).real + if 
iteration >= diffeo_start: + velocity = velocity - epV * grad_velocity + + if iteration % 5 == 0: + match_weights, artifact_weights, background_weights, muA, muB = _update_mixture_weights( + contrast_source, + target_image, + match_weights, + artifact_weights, + background_weights, + sigmaM=sigmaM, + sigmaA=sigmaA, + sigmaB=sigmaB, + estimate_muA=estimate_muA, + estimate_muB=estimate_muB, + muA=muA, + muB=muB, + iteration=iteration, + ) + + affine = _to_affine(linear, translation) + return { + "A": affine, + "v": velocity, + "xv": xv, + "WM": match_weights, + "WB": background_weights, + "WA": artifact_weights, + "E": energy, + "points": transformed_points, + } diff --git a/src/squidpy/experimental/_methods/align_samples/_stalign_impl/_helpers.py b/src/squidpy/experimental/_methods/align_samples/_stalign_impl/_helpers.py new file mode 100644 index 000000000..9da0a2067 --- /dev/null +++ b/src/squidpy/experimental/_methods/align_samples/_stalign_impl/_helpers.py @@ -0,0 +1,126 @@ +"""Numeric helpers for STalign point-cloud registration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import jax.numpy as jnp +import numpy as np + +from ._core import jax_dtype + +if TYPE_CHECKING: + import jax + + JaxArray = jax.Array +else: # pragma: no cover - typing only + JaxArray = Any + +__all__ = [ + "affine_from_points", + "rasterize", + "validate_points", +] + + +def validate_points(points: Any, *, name: str) -> JaxArray: + """Coerce ``points`` to a finite ``(n, 2)`` JAX array.""" + arr = jnp.asarray(points, dtype=jax_dtype()) + if arr.ndim != 2 or arr.shape[1] != 2: + raise ValueError(f"Expected `{name}` to have shape `(n, 2)`, found `{arr.shape}`.") + if not bool(jnp.all(jnp.isfinite(arr))): + raise ValueError(f"Expected `{name}` to contain only finite values.") + return arr + + +def rasterize( + x: np.ndarray, + y: np.ndarray, + *, + dx: float = 30.0, + blur: float | list[float] = 1.0, + expand: float = 1.1, +) -> tuple[np.ndarray, 
np.ndarray, np.ndarray]: + """Rasterize a point cloud into a multi-scale Gaussian density image. + + Each point splats a normalized Gaussian over a fixed ``(2r + 1)`` patch and + the patches are accumulated onto the grid. + """ + x = np.asarray(x, dtype=float).reshape(-1) + y = np.asarray(y, dtype=float).reshape(-1) + if x.shape != y.shape: + raise ValueError("Expected `x` and `y` to be 1D arrays with the same length.") + if x.size == 0: + raise ValueError("Expected at least one point to rasterize.") + if dx <= 0: + raise ValueError("Expected `dx` to be positive.") + if expand <= 0: + raise ValueError("Expected `expand` to be positive.") + + blur_values = np.atleast_1d(np.asarray(blur, dtype=float)) + if blur_values.ndim != 1 or np.any(blur_values <= 0): + raise ValueError("Expected `blur` to be a positive scalar or a 1D sequence of positive values.") + + min_x = float(np.min(x)) + max_x = float(np.max(x)) + min_y = float(np.min(y)) + max_y = float(np.max(y)) + + center_x = (min_x + max_x) / 2.0 + center_y = (min_y + max_y) / 2.0 + half_x = (max_x - min_x) * expand / 2.0 + half_y = (max_y - min_y) * expand / 2.0 + + grid_x = np.arange(center_x - half_x, center_x + half_x + dx, dx, dtype=float) + grid_y = np.arange(center_y - half_y, center_y + half_y + dx, dx, dtype=float) + if grid_x.size < 2 or grid_y.size < 2: + raise ValueError("Rasterized grid is too small. 
Increase the point spread or lower `dx`.") + + mesh_x, mesh_y = np.meshgrid(grid_x, grid_y) + out = np.zeros((len(blur_values), grid_y.size, grid_x.size), dtype=float) + radius = int(np.ceil(float(np.max(blur_values)) * 4.0)) + denom = 2.0 * (dx * blur_values * 2.0) ** 2 + + for x_i, y_i in zip(x, y, strict=False): + col = int(np.rint((x_i - grid_x[0]) / dx)) + row = int(np.rint((y_i - grid_y[0]) / dx)) + + row0 = max(row - radius, 0) + row1 = min(row + radius, out.shape[1] - 1) + col0 = max(col - radius, 0) + col1 = min(col + radius, out.shape[2] - 1) + + patch_x = mesh_x[row0 : row1 + 1, col0 : col1 + 1] + patch_y = mesh_y[row0 : row1 + 1, col0 : col1 + 1] + + kernels = np.exp(-((patch_x[..., None] - x_i) ** 2 + (patch_y[..., None] - y_i) ** 2) / denom) + kernels_sum = kernels.sum(axis=(0, 1), keepdims=True) + kernels /= np.where(kernels_sum == 0.0, 1.0, kernels_sum) + out[:, row0 : row1 + 1, col0 : col1 + 1] += np.moveaxis(kernels, -1, 0) + + return grid_x, grid_y, out + + +def affine_from_points( + points_source: JaxArray, + points_target: JaxArray, +) -> tuple[np.ndarray, np.ndarray]: + """Compute an affine initialization from corresponding landmarks.""" + source = np.asarray(points_source, dtype=float) + target = np.asarray(points_target, dtype=float) + if source.shape != target.shape: + raise ValueError( + f"Expected `points_source` and `points_target` to have the same shape, found " + f"`{source.shape}` and `{target.shape}`." 
@dataclass(slots=True)
class STalignPreprocessConfig:
    """Rasterization settings applied to a point cloud before registration.

    The fields are forwarded as the ``dx``, ``blur`` and ``expand`` keyword
    arguments of the ``rasterize`` helper.
    """

    # Grid spacing of the rasterized density image, in coordinate units.
    dx: float = 30.0
    # Blur scale(s) for the Gaussian splats; a scalar or a sequence of scales.
    blur: BlurScales = (2.0, 1.0, 0.5)
    # Multiplier on the point cloud's bounding-box extent used to size the grid.
    expand: float = 1.1
@dataclass(slots=True)
class StalignResult:
    """Outcome of an STalign fit that can map further points.

    ``aligned_points`` holds the fitted query cloud already expressed in the
    reference frame; :meth:`transform` accepts and returns ``(x, y)`` points.
    """

    affine: JaxArray
    velocity: JaxArray
    velocity_grid: tuple[JaxArray, JaxArray]
    aligned_points: JaxArray

    def transform(
        self,
        points: JaxArray,
        *,
        direction: Literal["forward", "backward"] = "forward",
    ) -> JaxArray:
        """Apply the fitted map to an ``(N, 2)`` array of ``(x, y)`` points.

        ``direction`` is passed through to the row-column point transformer.
        Raises :class:`ValueError` when ``points`` is not ``(N, 2)``.
        """
        import jax.numpy as jnp

        from ._core import jax_dtype, transform_points_row_col

        xy = jnp.asarray(points, dtype=jax_dtype())
        is_n_by_2 = xy.ndim == 2 and xy.shape[1] == 2
        if not is_n_by_2:
            raise ValueError(f"Expected an (N, 2) `(x, y)` array, found shape {xy.shape}.")

        # The point transformer works in row-column order: flip the columns
        # on the way in and again on the way out.
        row_col = xy[:, ::-1]
        mapped_row_col = transform_points_row_col(
            self.velocity_grid,
            self.velocity,
            self.affine,
            row_col,
            direction=direction,
        )
        return mapped_row_col[:, ::-1]
target_points: JaxArray, + *, + config: STalignConfig | None = None, + landmarks_source: JaxArray | None = None, + landmarks_target: JaxArray | None = None, +) -> StalignResult: + """Align a source point cloud onto a target with a JAX LDDMM solver. + + All point arrays are in the solver's row-column frame; the returned + :class:`StalignResult` speaks ``(x, y)``. + """ + import jax.numpy as jnp + + from ._core import jax_dtype, lddmm, transform_points_row_col + from ._helpers import affine_from_points, validate_points + + config = STalignConfig() if config is None else config + registration = config.registration + source_points = validate_points(source_points, name="source_points") + target_points = validate_points(target_points, name="target_points") + source_grid, source_image = _rasterize_cloud(source_points, config.preprocess) + target_grid, target_image = _rasterize_cloud(target_points, config.preprocess) + + if (landmarks_source is None) != (landmarks_target is None): + raise ValueError("Expected both landmark arrays to be provided together.") + + dtype = jax_dtype() + if landmarks_source is None: + linear = jnp.eye(2, dtype=dtype) + translation = jnp.zeros(2, dtype=dtype) + source_landmarks = None + target_landmarks = None + else: + source_landmarks = validate_points(landmarks_source, name="landmarks_source") + target_landmarks = validate_points(landmarks_target, name="landmarks_target") + linear_np, translation_np = affine_from_points(source_landmarks, target_landmarks) + linear = jnp.asarray(linear_np, dtype=dtype) + translation = jnp.asarray(translation_np, dtype=dtype) + + result = lddmm( + source_grid, + source_image, + target_grid, + target_image, + L=linear, + T=translation, + points_source=source_landmarks, + points_target=target_landmarks, + **asdict(registration), + ) + transformed_rc = transform_points_row_col( + result["xv"], + result["v"], + result["A"], + source_points, + direction="forward", + ) + return StalignResult( + affine=result["A"], + 
velocity=result["v"], + velocity_grid=result["xv"], + aligned_points=transformed_rc[:, ::-1], + ) diff --git a/src/squidpy/experimental/tl/__init__.py b/src/squidpy/experimental/tl/__init__.py index 1c2f97ece..803c04039 100644 --- a/src/squidpy/experimental/tl/__init__.py +++ b/src/squidpy/experimental/tl/__init__.py @@ -1,5 +1,19 @@ from __future__ import annotations +# `AlignResult` is the only result type on the public surface: it is the estimator +# contract (a `transform` mapping points into the reference frame) and the declared +# return of `align` / `align_by_landmarks`. The concrete results (`StalignResult`, +# `AffineFitResult`) stay in their home modules under `squidpy.experimental._methods` +# for callers that need raw fields -- the public API stays method-agnostic. +from squidpy.experimental._methods import AlignResult + +from ._align import align, align_by_landmarks from ._tiling_qc import TilingQCParams, calculate_tiling_qc -__all__ = ["TilingQCParams", "calculate_tiling_qc"] +__all__ = [ + "align", + "align_by_landmarks", + "calculate_tiling_qc", + "TilingQCParams", + "AlignResult", +] diff --git a/src/squidpy/experimental/tl/_align/__init__.py b/src/squidpy/experimental/tl/_align/__init__.py new file mode 100644 index 000000000..d386fa112 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/__init__.py @@ -0,0 +1,7 @@ +"""Public alignment API for :mod:`squidpy.experimental.tl`.""" + +from __future__ import annotations + +from squidpy.experimental.tl._align._api import align, align_by_landmarks + +__all__ = ["align", "align_by_landmarks"] diff --git a/src/squidpy/experimental/tl/_align/_api.py b/src/squidpy/experimental/tl/_align/_api.py new file mode 100644 index 000000000..53bf306c2 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_api.py @@ -0,0 +1,165 @@ +"""Public alignment functions built on the :mod:`squidpy.experimental._methods` core. 
def align(
    data_ref: AnnData | SpatialData,
    data_query: AnnData | SpatialData | None = None,
    *,
    method: str = "stalign",
    on: Literal["obs", "image"] = "obs",
    ref_key: str | None = None,
    query_key: str | None = None,
    spatial_key: str = "spatial",
    output_mode: Literal["object", "copy", "inplace"] = "object",
    key_added: str | None = None,
    **method_kwargs: Any,
) -> AlignResult | AnnData | SpatialData | None:
    """Fit a map that brings a query sample into the frame of a reference sample.

    Parameters
    ----------
    data_ref, data_query
        Two :class:`~anndata.AnnData` objects, two
        :class:`~spatialdata.SpatialData` objects, or a single SpatialData as
        ``data_ref`` with ``data_query=None`` to align two of its own tables
        (chosen via ``ref_key`` / ``query_key``).
    method
        Name of an estimator registered in the ``align_samples`` family,
        e.g. ``"stalign"`` (JAX LDDMM).
    on
        ``"obs"`` aligns the ``obsm`` point clouds. ``"image"`` is reserved
        and raises :class:`NotImplementedError`.
    ref_key, query_key
        Table keys; required for SpatialData inputs and rejected otherwise.
    spatial_key
        ``obsm`` key that stores the ``(x, y)`` coordinates.
    output_mode
        - ``"object"`` -- return the fitted result object without writing anything.
        - ``"inplace"`` -- store the aligned coordinates on the query and return ``None``.
        - ``"copy"`` -- store them on a copy of the query and return that copy.
    key_added
        ``obsm`` key receiving the aligned coordinates. When omitted,
        ``f"aligned_{spatial_key}"`` is used, and a :class:`ValueError` is
        raised if that key is already present.
    method_kwargs
        Extra keyword arguments handed to the selected estimator.
    """
    assert_one_of(output_mode, OUTPUT_MODES, name="output_mode")
    assert_one_of(on, ON_VALUES, name="on")
    if on == "image":
        raise NotImplementedError("`align(on='image')` is not implemented yet; use `on='obs'`.")

    resolved = resolve_obs_pair(data_ref, data_query, ref_key, query_key)
    reference, query, query_container, query_element = resolved

    # Read both coordinate arrays before looking up the estimator so a missing
    # `obsm` key is reported ahead of an unknown method.
    reference_xy = get_coords(reference, spatial_key)
    moving_xy = get_coords(query, spatial_key)

    estimator = ALIGN_SAMPLES.get(method)
    fitted = estimator(ref=reference_xy, query=moving_xy, **method_kwargs)

    return writeback_obs(
        fitted,
        output_mode=output_mode,
        query_adata=query,
        container=query_container,
        element_key=query_element,
        spatial_key=spatial_key,
        key_added=key_added,
    )
None = None, +) -> AlignResult | AnnData | SpatialData | None: + """Align by a closed-form fit on pre-paired landmarks. + + Parameters + ---------- + landmarks_ref, landmarks_query + Equal-length ``(N, 2)`` ``(x, y)`` landmark arrays (``N >= 3``), paired by + row order. No automatic correspondence matching is performed. + method + Estimator in the ``align_landmarks`` family: ``"similarity"`` (4 DOF) or + ``"affine"`` (6 DOF). + data + Target to write the alignment into. Required for ``output_mode`` other + than ``"object"``. + cs_name_ref, cs_name_query + Coordinate-system names. For a SpatialData ``data`` the fitted affine is + registered on every element in ``cs_name_query``, mapping into + ``cs_name_ref``. + spatial_key + ``obsm`` key when ``data`` is an :class:`~anndata.AnnData`. + output_mode + See :func:`align`. ``"object"`` (default) returns the fitted + :class:`~squidpy.experimental._methods.AlignResult` (an + :class:`~squidpy.experimental._methods.align_landmarks.AffineFitResult`). + key_added + Destination ``obsm`` key when ``data`` is an AnnData (see :func:`align`). 
+ """ + assert_one_of(output_mode, OUTPUT_MODES, name="output_mode") + + result = ALIGN_LANDMARKS.get(method)( + landmarks_ref=landmarks_ref, + landmarks_query=landmarks_query, + source_cs=cs_name_query, + target_cs=cs_name_ref, + ) + + if output_mode == "object": + return result + if data is None: + raise ValueError("`data` is required when `output_mode` is 'copy' or 'inplace'.") + + if isinstance(data, SpatialData): + return writeback_affine_sdata( + result, data, output_mode=output_mode, moving_cs=cs_name_query, target_cs=cs_name_ref + ) + if isinstance(data, AnnData): + return writeback_obs( + result, + output_mode=output_mode, + query_adata=data, + container=None, + element_key=None, + spatial_key=spatial_key, + key_added=key_added, + ) + raise TypeError(f"`data` must be AnnData or SpatialData, got {type(data).__name__}.") diff --git a/src/squidpy/experimental/tl/_align/_io.py b/src/squidpy/experimental/tl/_align/_io.py new file mode 100644 index 000000000..84f0b5b08 --- /dev/null +++ b/src/squidpy/experimental/tl/_align/_io.py @@ -0,0 +1,170 @@ +"""Input resolvers and output write-back for the public align functions. + +This is the *only* layer that knows about the ``AnnData | SpatialData`` argument +shapes and the ``output_mode`` write-back strategies. The fit estimators in +:mod:`squidpy.experimental._methods` operate on plain arrays and never see a +container. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np +from anndata import AnnData +from spatialdata import SpatialData + +from squidpy._validators import assert_key_in_sdata + +if TYPE_CHECKING: + from squidpy.experimental._methods import AlignResult + from squidpy.experimental._methods.align_landmarks import AffineFitResult + + +# --------------------------------------------------------------------------- +# Read side +# --------------------------------------------------------------------------- + + +def resolve_obs_pair( + data_ref: AnnData | SpatialData, + data_query: AnnData | SpatialData | None, + ref_key: str | None, + query_key: str | None, +) -> tuple[AnnData, AnnData, SpatialData | None, str | None]: + """Normalise ``align(on="obs")`` inputs. + + Returns ``(ref_adata, query_adata, query_container, query_key)`` where + ``query_container`` is the SpatialData to write back into (``None`` for plain + AnnData inputs). + """ + if isinstance(data_ref, AnnData): + if data_query is None: + raise ValueError("`data_query` is required when `data_ref` is an AnnData.") + if not isinstance(data_query, AnnData): + raise TypeError( + f"Mixed AnnData/SpatialData inputs are not supported; `data_query` is {type(data_query).__name__}." + ) + if ref_key is not None or query_key is not None: + raise ValueError("`ref_key`/`query_key` are only valid for SpatialData inputs.") + return data_ref, data_query, None, None + + if not isinstance(data_ref, SpatialData): + raise TypeError(f"`data_ref` must be AnnData or SpatialData, got {type(data_ref).__name__}.") + + sdata_query = data_ref if data_query is None else data_query + if not isinstance(sdata_query, SpatialData): + raise TypeError( + f"Mixed AnnData/SpatialData inputs are not supported; `data_query` is {type(data_query).__name__}." 
def get_coords(adata: AnnData, spatial_key: str) -> np.ndarray:
    """Fetch ``adata.obsm[spatial_key]`` as a float ``(N, 2)`` array.

    Raises :class:`KeyError` when the key is absent from ``obsm`` and
    :class:`ValueError` when the stored array is not two-dimensional with
    exactly two columns.
    """
    obsm = adata.obsm
    if spatial_key not in obsm:
        raise KeyError(f"`obsm[{spatial_key!r}]` not found; pass `spatial_key=` to select the coordinate key.")

    coords = np.asarray(obsm[spatial_key], dtype=float)
    if coords.ndim == 2 and coords.shape[1] == 2:
        return coords
    raise ValueError(f"`obsm[{spatial_key!r}]` must be an (N, 2) array, found shape {coords.shape}.")
we mutate so `output_mode="copy"` truly leaves the input untouched. + target_table = target_table.copy() + sdata.tables[element_key] = target_table + target_table.obsm[dest] = new_coords + return None if output_mode == "inplace" else sdata + + +def writeback_affine_sdata( + result: AffineFitResult, + sdata: SpatialData, + *, + output_mode: str, + moving_cs: str | None, + target_cs: str | None, +) -> SpatialData | None: + """Register the fitted affine on every element living in ``moving_cs``. + + Non-destructive: it adds a transformation into ``target_cs`` so the whole + coordinate system inherits the alignment. Nothing is materialised. + """ + from spatialdata import deepcopy as sd_deepcopy + from spatialdata.transformations import Affine, get_transformation, set_transformation + + if moving_cs is None or target_cs is None: + raise ValueError("`cs_name_query` and `cs_name_ref` are required to register a transform on a SpatialData.") + + out = sdata if output_mode == "inplace" else shallow_copy_sdata(sdata) + sd_affine = Affine(np.asarray(result.matrix), input_axes=("x", "y"), output_axes=("x", "y")) + touched = False + for etype, name, element in list(out.gen_elements()): + if isinstance(element, AnnData): + continue + if moving_cs not in get_transformation(element, get_all=True): + continue + if output_mode == "copy": + # `shallow_copy_sdata` shares element objects with the original; deep-copy each + # element we register a transform on so `output_mode="copy"` leaves the input untouched. 
+ element = sd_deepcopy(element) + getattr(out, etype)[name] = element + set_transformation(element, sd_affine, to_coordinate_system=target_cs) + touched = True + if not touched: + raise KeyError(f"No elements in the SpatialData are registered to coordinate system {moving_cs!r}.") + return None if output_mode == "inplace" else out + + +def _resolve_dest(adata: AnnData, *, spatial_key: str, key_added: str | None) -> str: + """Resolve the destination obsm key, guarding against silent overwrite.""" + if key_added is not None: + return key_added + dest = f"aligned_{spatial_key}" + if dest in adata.obsm: + raise ValueError( + f"`obsm[{dest!r}]` already exists. Pass `key_added` explicitly to choose the destination " + f"(or to overwrite it intentionally)." + ) + return dest + + +def shallow_copy_sdata(sdata: SpatialData) -> SpatialData: + """Shallow copy of a SpatialData for ``output_mode='copy'`` (via ``subset``).""" + names = [name for _, name, _ in sdata.gen_elements()] + return sdata.subset(names, filter_tables=False, include_orphan_tables=True) diff --git a/tests/experimental/_methods/__init__.py b/tests/experimental/_methods/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/experimental/_methods/test_core.py b/tests/experimental/_methods/test_core.py new file mode 100644 index 000000000..a657e1bd1 --- /dev/null +++ b/tests/experimental/_methods/test_core.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import pytest + +from squidpy.experimental._methods import Registry + + +@dataclass +class _MeanShiftResult: + """Toy result: a constant per-axis offset baked into ``transform``.""" + + delta: np.ndarray + metadata: dict[str, Any] = field(default_factory=dict) + + def transform(self, x: np.ndarray) -> np.ndarray: + return np.asarray(x, dtype=float) + self.delta + + +def fit_mean_shift(ref: np.ndarray, query: np.ndarray) -> _MeanShiftResult: + 
"""Toy estimator function: fit the offset that maps the query centroid onto the ref centroid.""" + delta = np.asarray(ref, dtype=float).mean(0) - np.asarray(query, dtype=float).mean(0) + return _MeanShiftResult(delta=delta, metadata={"method": "mean_shift"}) + + +def test_fit_then_transform_round_trip() -> None: + ref = np.array([[1.0, 1.0], [3.0, 3.0]]) # centroid (2, 2) + query = np.array([[0.0, 0.0], [2.0, 2.0]]) # centroid (1, 1) + + result = fit_mean_shift(ref, query) + + np.testing.assert_allclose(result.delta, [1.0, 1.0]) + np.testing.assert_allclose(result.transform(query), query + 1.0) + assert result.metadata == {"method": "mean_shift"} + + +def test_registry_register_get_keys() -> None: + reg = Registry("demo") + + @reg.register("mean_shift") + def _registered(ref: np.ndarray, query: np.ndarray) -> _MeanShiftResult: + return fit_mean_shift(ref, query) + + assert reg.keys() == ("mean_shift",) + assert reg.get("mean_shift") is _registered + assert isinstance(reg.get("mean_shift")(np.ones((2, 2)), np.zeros((2, 2))), _MeanShiftResult) + + +def test_registry_unknown_key_lists_available() -> None: + reg = Registry("demo") + reg.register("a")(fit_mean_shift) + + with pytest.raises(ValueError, match=r"Unknown demo method 'b'. Available: \['a'\]"): + reg.get("b") + + +def test_registry_rejects_duplicate_key() -> None: + reg = Registry("demo") + reg.register("dup")(fit_mean_shift) + + with pytest.raises(ValueError, match="already registered"): + reg.register("dup")(fit_mean_shift) + + +def test_check_requirements_passes_when_none() -> None: + reg = Registry("demo") + # By default, registering without requires parameter does not wrap/check. 
+ reg.register("mean_shift")(fit_mean_shift) + result = reg.get("mean_shift")(np.ones((2, 2)), np.zeros((2, 2))) + assert isinstance(result, _MeanShiftResult) + + +def test_check_requirements_raises_for_missing_dependency() -> None: + reg = Registry("demo") + + @reg.register("needs_ghost", requires=("squidpy_nonexistent_pkg_xyz",)) + def _needs_ghost(ref: np.ndarray, query: np.ndarray) -> _MeanShiftResult: + return fit_mean_shift(ref, query) + + with pytest.raises( + ImportError, match=r"Method 'needs_ghost' requires 'squidpy_nonexistent_pkg_xyz'.*squidpy\[squidpy_nonexistent_pkg_xyz\]" + ): + reg.get("needs_ghost")(np.ones((2, 2)), np.zeros((2, 2))) diff --git a/tests/experimental/_methods/test_stalign.py b/tests/experimental/_methods/test_stalign.py new file mode 100644 index 000000000..960328733 --- /dev/null +++ b/tests/experimental/_methods/test_stalign.py @@ -0,0 +1,110 @@ +"""Integration tests for the ported STalign estimator. + +Tiny synthetic fixtures with ``niter=1`` keep these fast; they verify wiring +and shapes (dispatch -> JAX LDDMM -> StalignResult), not solver quality. 
+""" + +from __future__ import annotations + +import numpy as np +import pytest + +pytest.importorskip("jax") + +from squidpy.experimental._methods.align_samples import ALIGN_SAMPLES, fit_stalign +from squidpy.experimental._methods.align_samples._stalign_impl._tools import ( + STalignConfig, + STalignPreprocessConfig, + STalignRegistrationConfig, + StalignResult, +) + + +def _points_xy() -> np.ndarray: + return np.array( + [ + [10.0, 1.0], + [12.0, 1.0], + [11.0, 2.0], + [10.0, 3.0], + [12.0, 3.0], + ] + ) + + +def _tiny_config() -> STalignConfig: + """Single-iteration LDDMM hyperparameters - the smallest possible solve.""" + return STalignConfig( + preprocess=STalignPreprocessConfig(dx=0.5, blur=1.0), + registration=STalignRegistrationConfig(a=1.0, expand=1.0, nt=1, niter=1, epV=1.0), + ) + + +def test_stalign_registered_in_align_family() -> None: + assert "stalign" in ALIGN_SAMPLES.keys() + assert ALIGN_SAMPLES.get("stalign") is fit_stalign + + +def test_stalign_fit_returns_diffeomorphism() -> None: + ref, query = _points_xy(), _points_xy() + + result = fit_stalign(ref, query, config=_tiny_config()) + + assert isinstance(result, StalignResult) + assert result.aligned_points.shape == query.shape + assert np.all(np.isfinite(np.asarray(result.aligned_points))) + assert result.affine.shape == (3, 3) + assert result.velocity.ndim == 4 + + +def test_stalign_transform_matches_aligned_points() -> None: + ref, query = _points_xy(), _points_xy() + + result = fit_stalign(ref, query, config=_tiny_config()) + + np.testing.assert_allclose(np.asarray(result.transform(query)), np.asarray(result.aligned_points), atol=1e-5) + + +def test_stalign_transform_accepts_arbitrary_points() -> None: + ref, query = _points_xy(), _points_xy() + result = fit_stalign(ref, query, config=_tiny_config()) + + out = result.transform(np.zeros((1, 2))) + assert np.asarray(out).shape == (1, 2) + + +def test_stalign_transform_backward_inverts_forward() -> None: + ref, query = _points_xy(), _points_xy() 
+ result = fit_stalign(ref, query, config=_tiny_config()) + + forward = result.transform(query, direction="forward") + roundtrip = result.transform(forward, direction="backward") + np.testing.assert_allclose(np.asarray(roundtrip), query, atol=1e-3) + + +def test_stalign_transform_rejects_non_2d() -> None: + ref, query = _points_xy(), _points_xy() + result = fit_stalign(ref, query, config=_tiny_config()) + + with pytest.raises(ValueError, match=r"Expected an \(N, 2\)"): + result.transform(np.zeros((5, 3))) + + +def test_stalign_fit_with_landmarks() -> None: + ref, query = _points_xy(), _points_xy() + landmarks = ref[:3] + + result = fit_stalign( + ref, + query, + config=_tiny_config(), + landmarks_source=landmarks, + landmarks_target=landmarks, + ) + + assert result.aligned_points.shape == query.shape + + +def test_stalign_fit_rejects_non_2d_input() -> None: + with pytest.raises(ValueError, match=r"Expected `query` to have shape `\(n, 2\)`"): + fit_stalign(_points_xy(), np.zeros((5, 3)), config=_tiny_config()) diff --git a/tests/experimental/tl/__init__.py b/tests/experimental/tl/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/experimental/tl/test_align.py b/tests/experimental/tl/test_align.py new file mode 100644 index 000000000..7136be25b --- /dev/null +++ b/tests/experimental/tl/test_align.py @@ -0,0 +1,141 @@ +"""Integration tests for the public ``align`` (STalign) function. + +Tiny synthetic fixtures with ``niter=1`` keep these fast; they verify wiring, +write-back modes, and the key guard -- not solver quality. 
+""" + +from __future__ import annotations + +import numpy as np +import pytest +from anndata import AnnData + +pytest.importorskip("jax") + +from squidpy.experimental._methods.align_samples._stalign_impl._tools import ( + STalignConfig, + STalignPreprocessConfig, + STalignRegistrationConfig, + StalignResult, +) +from squidpy.experimental.tl import align + + +def _adata(*, key: str = "spatial") -> AnnData: + pts = np.array([[10.0, 1.0], [12.0, 1.0], [11.0, 2.0], [10.0, 3.0], [12.0, 3.0]]) + adata = AnnData(np.zeros((pts.shape[0], 1))) + adata.obsm[key] = pts + return adata + + +def _tiny() -> STalignConfig: + return STalignConfig( + preprocess=STalignPreprocessConfig(dx=0.5, blur=1.0), + registration=STalignRegistrationConfig(a=1.0, expand=1.0, nt=1, niter=1, epV=1.0), + ) + + +def test_object_mode_returns_result_and_touches_nothing() -> None: + ref, query = _adata(), _adata() + result = align(ref, query, method="stalign", output_mode="object", config=_tiny()) + assert isinstance(result, StalignResult) + assert result.aligned_points.shape == query.obsm["spatial"].shape + assert "aligned_spatial" not in query.obsm + + +def test_inplace_writes_explicit_key() -> None: + ref, query = _adata(), _adata() + out = align(ref, query, output_mode="inplace", key_added="spatial_aligned", config=_tiny()) + assert out is None + assert query.obsm["spatial_aligned"].shape == query.obsm["spatial"].shape + + +def test_inplace_default_key() -> None: + ref, query = _adata(), _adata() + align(ref, query, output_mode="inplace", config=_tiny()) + assert "aligned_spatial" in query.obsm + + +def test_copy_leaves_original_untouched() -> None: + ref, query = _adata(), _adata() + out = align(ref, query, output_mode="copy", key_added="aligned", config=_tiny()) + assert isinstance(out, AnnData) and out is not query + assert "aligned" in out.obsm + assert "aligned" not in query.obsm + + +def test_existing_default_key_requires_explicit_key_added() -> None: + ref, query = _adata(), _adata() + 
query.obsm["aligned_spatial"] = np.zeros_like(query.obsm["spatial"]) + with pytest.raises(ValueError, match="aligned_spatial.*already exists"): + align(ref, query, output_mode="inplace", config=_tiny()) + + +def test_image_not_implemented() -> None: + ref, query = _adata(), _adata() + with pytest.raises(NotImplementedError, match="on='image'"): + align(ref, query, on="image", config=_tiny()) + + +def test_missing_spatial_key() -> None: + ref, query = _adata(), _adata() + with pytest.raises(KeyError, match="missing.*not found"): + align(ref, query, spatial_key="missing", config=_tiny()) + + +def test_public_surface_is_align_result_only() -> None: + import squidpy.experimental.tl as tl + + # `AlignResult` is the only result type exposed; concretes stay in their home modules. + assert "AlignResult" in tl.__all__ + assert not hasattr(tl, "StalignResult") + assert not hasattr(tl, "AffineFitResult") + + +def test_object_result_satisfies_align_result_protocol() -> None: + from squidpy.experimental.tl import AlignResult + + result = align(_adata(), _adata(), method="stalign", output_mode="object", config=_tiny()) + assert isinstance(result, AlignResult) + + +def _sdata_tables(**tables: AnnData): + sd = pytest.importorskip("spatialdata") + from spatialdata.models import TableModel + + return sd.SpatialData(tables={name: TableModel.parse(adata) for name, adata in tables.items()}) + + +def test_sdata_object_mode() -> None: + sdata = _sdata_tables(ref=_adata(), query=_adata()) + result = align(sdata, method="stalign", ref_key="ref", query_key="query", output_mode="object", config=_tiny()) + assert isinstance(result, StalignResult) + assert "aligned_spatial" not in sdata.tables["query"].obsm + + +def test_sdata_copy_leaves_original_untouched() -> None: + sd = pytest.importorskip("spatialdata") + + sdata = _sdata_tables(ref=_adata(), query=_adata()) + out = align(sdata, ref_key="ref", query_key="query", output_mode="copy", key_added="aligned", config=_tiny()) + assert 
isinstance(out, sd.SpatialData) and out is not sdata + assert "aligned" in out.tables["query"].obsm + assert "aligned" not in sdata.tables["query"].obsm + + +def test_align_with_landmarks() -> None: + ref, query = _adata(), _adata() + landmarks = ref.obsm["spatial"][:3] + + result = align( + ref, + query, + method="stalign", + output_mode="object", + landmarks_source=landmarks, + landmarks_target=landmarks, + config=_tiny(), + ) + + assert isinstance(result, StalignResult) + assert result.aligned_points.shape == query.obsm["spatial"].shape diff --git a/tests/experimental/tl/test_align_by_landmarks.py b/tests/experimental/tl/test_align_by_landmarks.py new file mode 100644 index 000000000..e54d93bf7 --- /dev/null +++ b/tests/experimental/tl/test_align_by_landmarks.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import numpy as np +import pytest +from anndata import AnnData + +from squidpy.experimental._methods.align_landmarks import AffineFitResult +from squidpy.experimental.tl import align_by_landmarks + +# square corners; query = ref shifted by (5, 7) -> a pure translation both models recover +_REF = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]) +_SHIFT = np.array([5.0, 7.0]) +_QUERY = _REF + _SHIFT + + +def _adata(coords: np.ndarray = _QUERY, *, key: str = "spatial") -> AnnData: + adata = AnnData(np.zeros((coords.shape[0], 1))) + adata.obsm[key] = coords.copy() + return adata + + +@pytest.mark.parametrize("method", ["similarity", "affine"]) +def test_object_mode_returns_affine_result(method: str) -> None: + result = align_by_landmarks(_REF, _QUERY, method=method, output_mode="object") + assert isinstance(result, AffineFitResult) + assert result.matrix.shape == (3, 3) + # affine maps query -> ref + np.testing.assert_allclose(result.transform(_QUERY), _REF, atol=1e-6) + assert result.metadata["method"] == method + + +def test_object_is_default() -> None: + assert isinstance(align_by_landmarks(_REF, _QUERY), AffineFitResult) + + +def 
test_anndata_inplace_writes_default_key() -> None: + adata = _adata() + out = align_by_landmarks(_REF, _QUERY, method="affine", data=adata, output_mode="inplace") + assert out is None + assert "aligned_spatial" in adata.obsm + np.testing.assert_allclose(adata.obsm["aligned_spatial"], _REF, atol=1e-6) + # source coords untouched + np.testing.assert_allclose(adata.obsm["spatial"], _QUERY) + + +def test_anndata_copy_leaves_original_untouched() -> None: + adata = _adata() + out = align_by_landmarks(_REF, _QUERY, method="affine", data=adata, output_mode="copy", key_added="xy_aligned") + assert isinstance(out, AnnData) and out is not adata + assert "xy_aligned" in out.obsm + assert "xy_aligned" not in adata.obsm + + +def test_custom_spatial_key() -> None: + adata = _adata(key="loc") + align_by_landmarks(_REF, _QUERY, method="affine", data=adata, spatial_key="loc", output_mode="inplace") + assert "aligned_loc" in adata.obsm + + +def test_existing_default_key_requires_explicit_key_added() -> None: + adata = _adata() + adata.obsm["aligned_spatial"] = np.zeros_like(_QUERY) + with pytest.raises(ValueError, match="aligned_spatial.*already exists"): + align_by_landmarks(_REF, _QUERY, method="affine", data=adata, output_mode="inplace") + # explicit key_added overwrites without error + align_by_landmarks(_REF, _QUERY, method="affine", data=adata, output_mode="inplace", key_added="aligned_spatial") + np.testing.assert_allclose(adata.obsm["aligned_spatial"], _REF, atol=1e-6) + + +def test_write_mode_requires_data() -> None: + with pytest.raises(ValueError, match="`data` is required"): + align_by_landmarks(_REF, _QUERY, method="affine", output_mode="inplace") + + +def test_too_few_landmarks() -> None: + with pytest.raises(ValueError, match="at least 3 landmark pairs"): + align_by_landmarks(_REF[:2], _QUERY[:2], method="affine", output_mode="object") + + +def test_length_mismatch() -> None: + with pytest.raises(ValueError, match="same shape"): + align_by_landmarks(_REF, _QUERY[:3], 
method="affine", output_mode="object") + + +def test_unknown_method_lists_available() -> None: + with pytest.raises(ValueError, match=r"Unknown align_landmarks method 'nope'"): + align_by_landmarks(_REF, _QUERY, method="nope", output_mode="object") + + +def test_non_finite_landmarks_rejected() -> None: + bad = _QUERY.copy() + bad[0, 0] = np.nan + with pytest.raises(ValueError, match="must contain only finite values"): + align_by_landmarks(_REF, bad, method="affine", output_mode="object") + + +def test_bad_data_type_raises() -> None: + with pytest.raises(TypeError, match="must be AnnData or SpatialData"): + align_by_landmarks(_REF, _QUERY, method="affine", data=object(), output_mode="inplace") # type: ignore[arg-type] + + +def test_spatialdata_copy_leaves_original_untouched() -> None: + sd = pytest.importorskip("spatialdata") + from spatialdata.models import PointsModel + from spatialdata.transformations import Identity, get_transformation + + pts = PointsModel.parse(_QUERY, transformations={"query_cs": Identity()}) + sdata = sd.SpatialData(points={"pts": pts}) + + out = align_by_landmarks( + _REF, + _QUERY, + method="affine", + data=sdata, + cs_name_query="query_cs", + cs_name_ref="ref_cs", + output_mode="copy", + ) + assert out is not sdata + assert "ref_cs" in get_transformation(out.points["pts"], get_all=True) + assert "ref_cs" not in get_transformation(sdata.points["pts"], get_all=True) + + +def test_spatialdata_registers_transformation() -> None: + sd = pytest.importorskip("spatialdata") + from spatialdata.models import PointsModel + from spatialdata.transformations import Identity, get_transformation + + pts = PointsModel.parse(_QUERY, transformations={"query_cs": Identity()}) + sdata = sd.SpatialData(points={"pts": pts}) + + out = align_by_landmarks( + _REF, + _QUERY, + method="affine", + data=sdata, + cs_name_query="query_cs", + cs_name_ref="ref_cs", + output_mode="inplace", + ) + assert out is None + assert "ref_cs" in 
get_transformation(sdata.points["pts"], get_all=True) diff --git a/tests/experimental/tl/test_align_io.py b/tests/experimental/tl/test_align_io.py new file mode 100644 index 000000000..409364471 --- /dev/null +++ b/tests/experimental/tl/test_align_io.py @@ -0,0 +1,297 @@ +"""Unit tests for the I/O layer of the public ``align`` functions. + +These exercise input resolution and write-back directly -- no estimator, no JAX -- +so they cover the ``AnnData`` *and* ``SpatialData`` branches and every error guard +in :mod:`squidpy.experimental.tl._align._io` cheaply. +""" + +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +import pytest +from anndata import AnnData + +from squidpy.experimental._methods.align_landmarks import AffineFitResult +from squidpy.experimental.tl._align._io import ( + get_coords, + resolve_obs_pair, + shallow_copy_sdata, + writeback_affine_sdata, + writeback_obs, +) + +_PTS = np.array([[10.0, 1.0], [12.0, 1.0], [11.0, 2.0], [10.0, 3.0], [12.0, 3.0]]) + + +@dataclass +class _ShiftResult: + """Minimal :class:`AlignResult`: a constant offset baked into ``transform``.""" + + delta: float = 100.0 + + def transform(self, points: np.ndarray) -> np.ndarray: + return np.asarray(points, dtype=float) + self.delta + + +def _adata(coords: np.ndarray = _PTS, *, key: str = "spatial") -> AnnData: + adata = AnnData(np.zeros((coords.shape[0], 1))) + adata.obsm[key] = coords.copy() + return adata + + +def _sdata_tables(**tables: AnnData): + sd = pytest.importorskip("spatialdata") + from spatialdata.models import TableModel + + return sd.SpatialData(tables={name: TableModel.parse(adata) for name, adata in tables.items()}) + + +def _sdata_points(cs: str = "qcs"): + sd = pytest.importorskip("spatialdata") + from spatialdata.models import PointsModel + from spatialdata.transformations import Identity + + pts = PointsModel.parse(_PTS, transformations={cs: Identity()}) + return sd.SpatialData(points={"pts": pts}) + + +# 
--------------------------------------------------------------------------- +# resolve_obs_pair +# --------------------------------------------------------------------------- + + +def test_resolve_anndata_pair() -> None: + ref, query = _adata(), _adata() + r_adata, q_adata, container, element_key = resolve_obs_pair(ref, query, None, None) + assert r_adata is ref and q_adata is query + assert container is None and element_key is None + + +def test_resolve_anndata_requires_query() -> None: + with pytest.raises(ValueError, match="`data_query` is required when `data_ref` is an AnnData"): + resolve_obs_pair(_adata(), None, None, None) + + +def test_resolve_anndata_rejects_mixed_query() -> None: + with pytest.raises(TypeError, match="Mixed AnnData/SpatialData"): + resolve_obs_pair(_adata(), _sdata_points(), None, None) + + +def test_resolve_anndata_rejects_keys() -> None: + with pytest.raises(ValueError, match="only valid for SpatialData"): + resolve_obs_pair(_adata(), _adata(), "ref", None) + + +def test_resolve_bad_ref_type() -> None: + with pytest.raises(TypeError, match="must be AnnData or SpatialData"): + resolve_obs_pair(object(), _adata(), None, None) # type: ignore[arg-type] + + +def test_resolve_sdata_pair() -> None: + pytest.importorskip("spatialdata") + ref_sd = _sdata_tables(ref=_adata()) + query_sd = _sdata_tables(query=_adata(_PTS + 5)) + r_adata, q_adata, container, element_key = resolve_obs_pair(ref_sd, query_sd, "ref", "query") + assert r_adata is ref_sd.tables["ref"] + assert q_adata is query_sd.tables["query"] + assert container is query_sd + assert element_key == "query" + + +def test_resolve_sdata_single_two_tables() -> None: + pytest.importorskip("spatialdata") + both = _sdata_tables(ref=_adata(), query=_adata(_PTS + 5)) + r_adata, q_adata, container, element_key = resolve_obs_pair(both, None, "ref", "query") + assert r_adata is both.tables["ref"] + assert q_adata is both.tables["query"] + assert container is both + assert element_key == "query" + + 
+def test_resolve_sdata_requires_keys() -> None: + pytest.importorskip("spatialdata") + both = _sdata_tables(ref=_adata(), query=_adata()) + with pytest.raises(ValueError, match="`ref_key` and `query_key` are required"): + resolve_obs_pair(both, None, None, None) + + +def test_resolve_sdata_rejects_mixed_query() -> None: + pytest.importorskip("spatialdata") + with pytest.raises(TypeError, match="Mixed AnnData/SpatialData"): + resolve_obs_pair(_sdata_tables(ref=_adata()), _adata(), "ref", "query") + + +def test_resolve_sdata_missing_key() -> None: + pytest.importorskip("spatialdata") + both = _sdata_tables(ref=_adata()) + with pytest.raises(KeyError, match="nope"): + resolve_obs_pair(both, None, "nope", "ref") + + +# --------------------------------------------------------------------------- +# get_coords +# --------------------------------------------------------------------------- + + +def test_get_coords_missing_key() -> None: + with pytest.raises(KeyError, match="missing.*not found"): + get_coords(_adata(), "missing") + + +def test_get_coords_rejects_non_2d() -> None: + adata = _adata() + adata.obsm["bad"] = np.zeros((_PTS.shape[0], 3)) + with pytest.raises(ValueError, match=r"must be an \(N, 2\) array"): + get_coords(adata, "bad") + + +# --------------------------------------------------------------------------- +# writeback_obs -- AnnData +# --------------------------------------------------------------------------- + + +def test_writeback_obs_object_mode_returns_result() -> None: + result = _ShiftResult() + out = writeback_obs( + result, + output_mode="object", + query_adata=_adata(), + container=None, + element_key=None, + spatial_key="spatial", + key_added=None, + ) + assert out is result + + +def test_writeback_obs_anndata_inplace() -> None: + query = _adata() + out = writeback_obs( + _ShiftResult(), + output_mode="inplace", + query_adata=query, + container=None, + element_key=None, + spatial_key="spatial", + key_added="aligned", + ) + assert out is None + 
np.testing.assert_allclose(query.obsm["aligned"], _PTS + 100.0) + + +def test_writeback_obs_anndata_copy_leaves_original_untouched() -> None: + query = _adata() + out = writeback_obs( + _ShiftResult(), + output_mode="copy", + query_adata=query, + container=None, + element_key=None, + spatial_key="spatial", + key_added="aligned", + ) + assert isinstance(out, AnnData) and out is not query + assert "aligned" in out.obsm + assert "aligned" not in query.obsm + + +# --------------------------------------------------------------------------- +# writeback_obs -- SpatialData +# --------------------------------------------------------------------------- + + +def test_writeback_obs_sdata_inplace() -> None: + pytest.importorskip("spatialdata") + sdata = _sdata_tables(query=_adata()) + out = writeback_obs( + _ShiftResult(), + output_mode="inplace", + query_adata=sdata.tables["query"], + container=sdata, + element_key="query", + spatial_key="spatial", + key_added="aligned", + ) + assert out is None + np.testing.assert_allclose(sdata.tables["query"].obsm["aligned"], _PTS + 100.0) + + +def test_writeback_obs_sdata_copy_leaves_original_untouched() -> None: + pytest.importorskip("spatialdata") + sdata = _sdata_tables(query=_adata()) + out = writeback_obs( + _ShiftResult(), + output_mode="copy", + query_adata=sdata.tables["query"], + container=sdata, + element_key="query", + spatial_key="spatial", + key_added="aligned", + ) + assert out is not sdata + assert "aligned" in out.tables["query"].obsm + # regression: copy must not leak the new key back into the input container + assert "aligned" not in sdata.tables["query"].obsm + + +# --------------------------------------------------------------------------- +# writeback_affine_sdata +# --------------------------------------------------------------------------- + + +def test_writeback_affine_inplace_registers_transform() -> None: + pytest.importorskip("spatialdata") + from spatialdata.transformations import get_transformation + + sdata = 
_sdata_points() + out = writeback_affine_sdata( + AffineFitResult(matrix=np.eye(3)), sdata, output_mode="inplace", moving_cs="qcs", target_cs="tcs" + ) + assert out is None + assert "tcs" in get_transformation(sdata.points["pts"], get_all=True) + + +def test_writeback_affine_copy_leaves_original_untouched() -> None: + pytest.importorskip("spatialdata") + from spatialdata.transformations import get_transformation + + sdata = _sdata_points() + out = writeback_affine_sdata( + AffineFitResult(matrix=np.eye(3)), sdata, output_mode="copy", moving_cs="qcs", target_cs="tcs" + ) + assert out is not sdata + assert "tcs" in get_transformation(out.points["pts"], get_all=True) + # regression: copy must not register the transform on the input container + assert "tcs" not in get_transformation(sdata.points["pts"], get_all=True) + + +def test_writeback_affine_requires_cs_names() -> None: + pytest.importorskip("spatialdata") + with pytest.raises(ValueError, match="`cs_name_query` and `cs_name_ref` are required"): + writeback_affine_sdata( + AffineFitResult(matrix=np.eye(3)), _sdata_points(), output_mode="inplace", moving_cs=None, target_cs="tcs" + ) + + +def test_writeback_affine_no_matching_cs() -> None: + pytest.importorskip("spatialdata") + sdata = _sdata_points(cs="qcs") + with pytest.raises(KeyError, match="No elements .* registered to coordinate system 'other'"): + writeback_affine_sdata( + AffineFitResult(matrix=np.eye(3)), sdata, output_mode="inplace", moving_cs="other", target_cs="tcs" + ) + + +# --------------------------------------------------------------------------- +# shallow_copy_sdata +# --------------------------------------------------------------------------- + + +def test_shallow_copy_sdata_preserves_elements() -> None: + pytest.importorskip("spatialdata") + sdata = _sdata_tables(ref=_adata(), query=_adata(_PTS + 5)) + copy = shallow_copy_sdata(sdata) + assert copy is not sdata + assert set(copy.tables) == {"ref", "query"}