From e62493948004da5214c5b35ab8fa995c52d838bb Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 14:02:32 +0200 Subject: [PATCH 1/7] perf: chan's parallel mean-var algorithm for dask --- benchmarks/asv.conf.json | 1 + benchmarks/benchmarks/preprocessing_counts.py | 34 ++++- src/scanpy/get/_aggregated.py | 134 ++++++++++++++++-- 3 files changed, 151 insertions(+), 18 deletions(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..a1a8d31a42 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -84,6 +84,7 @@ "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 // "scikit-misc": [""], + "dask": [""], }, // Combinations of libraries/python versions can be excluded/included diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 9a20e7eda3..672a7df5fc 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING import anndata as ad +import zarr import scanpy as sc from scanpy._utils import get_literal_vals @@ -18,6 +19,7 @@ from ._utils import get_count_dataset, get_dataset if TYPE_CHECKING: + from collections.abc import KeysView from typing import Any from ._utils import Dataset, KeyCount @@ -151,17 +153,35 @@ def peakmem_log1p(self, *_) -> None: class Agg: # noqa: D101 - params: tuple[AggType] = tuple(get_literal_vals(AggType)) - param_names = ("agg_name",) + params: tuple[KeysView[AggType], tuple[bool]] = ( + get_literal_vals(AggType), + (True, False), + ) + param_names = ("agg_name", "use_dask") def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" adata, _ = get_dataset("lung93k") - adata.write_h5ad("lung93k.h5ad") - - def setup(self, agg_name: AggType) -> None: - self.adata = ad.read_h5ad("lung93k.h5ad") - self.agg_name = agg_name + adata.write_zarr("lung93k.zarr") + + def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001 + if use_dask: + z = zarr.open("lung93k_shuffled.zarr") + self.adata = ad.AnnData( + obs=ad.io.read_elem(z["obs"]), + var=ad.io.read_elem(z["var"]), + layers={ + "counts": ad.experimental.read_elem_lazy(z["layers"]["counts"]) + }, + X=ad.experimental.read_elem_lazy(z["X"]), + ) + # Times out on the benchmark machine with full dataset + self.adata = self.adata[ + self.adata.obs["PatientNumber"].isin(["1", "2", "3"]) + ].copy() + else: + self.adata = ad.read_zarr("lung93k.zarr") + self.agg_name: AggType = agg_name def time_agg(self, *_) -> None: sc.get.aggregate( diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index fe5140171b..66f251c316 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd from anndata import AnnData -from fast_array_utils.stats._power import power as fau_power # TODO: upstream from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 @@ -371,16 +370,129 @@ def aggregate_dask_mean_var( mask: NDArray[np.bool] | None = None, dof: int = 1, ) -> MeanVarDict: - mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"] - sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"] - # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse. - if isinstance(data._meta, CSRBase): - sq_mean = sq_mean.compute() - var = sq_mean - fau_power(mean, 2) - if dof != 0: - group_counts = np.bincount(by.codes) - var *= (group_counts / (group_counts - dof))[:, np.newaxis] - return MeanVarDict(mean=mean, var=var) + """Compute group-wise mean and variance for a dask array. + + Per chunk we compute ``(count, mean, M2)`` (where ``M2 = sum((x - mean)**2)``), + then combine across chunks with the pairwise parallel algorithm from + Chan et al. (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm) + so the across-chunk reduction avoids the catastrophic cancellation of + ``E[X**2] - E[X]**2``. + """ + import dask.array as da + + n_categories = len(by.categories) + n_features = data.shape[1] + chunked_axis = 0 if isinstance(data._meta, CSRBase | np.ndarray) else 1 + + if chunked_axis == 1: + # Each block already sees every observation, so mean/var per chunk is final. + def per_block_col(chunk: Array) -> NDArray[np.float64]: + mean_, var_ = Aggregate(groupby=by, data=chunk, mask=mask).mean_var(dof=dof) + return np.concatenate([mean_, var_], axis=0) + + combined = data.map_blocks( + per_block_col, + chunks=((2 * n_categories,), data.chunks[1]), + meta=np.array([], dtype=np.float64), + ) + return MeanVarDict(mean=combined[:n_categories], var=combined[n_categories:]) + + n_blocks = data.numblocks[0] + + def per_block_row( + chunk: Array, block_info: dict | None = None + ) -> NDArray[np.float64]: + row_subset = slice(*block_info[0]["array-location"][0]) + by_sub = by[row_subset] + mask_sub = mask[row_subset] if mask is not None else None + return _block_moments(chunk, by_sub, mask=mask_sub, n_categories=n_categories)[ + None + ] + + per_block_stats = data.map_blocks( + per_block_row, + chunks=((1,) * n_blocks, (3,), (n_categories,), (n_features,)), + new_axis=(1, 2), + meta=np.array([], dtype=np.float64), + ) + + combined = da.reduction( + per_block_stats, + chunk=lambda x, axis=None, keepdims=False: x, + aggregate=_chan_reduce_axis_0, + axis=0, + keepdims=False, + concatenate=True, + dtype=np.float64, + meta=np.array([], dtype=np.float64), + ) + counts = combined[0] + mean_ = combined[1] + m2 = combined[2] + denom = counts - dof if dof > 0 else counts + return MeanVarDict(mean=mean_, var=m2 / denom) + + +def _block_moments( + data: np.ndarray | CSBase, + by: pd.Categorical, + *, + mask: NDArray[np.bool] | None, + n_categories: int, +) -> NDArray[np.float64]: + """Per-chunk ``(count, mean, M2)`` array of shape ``(3, n_categories, n_features)``. + + Groups with no observations in the chunk get zeros for mean and M2 so + they combine cleanly under ``_chan_combine``. + """ + codes = by.codes + valid = codes >= 0 + if mask is not None: + valid = valid & mask + counts = np.bincount(codes[valid], minlength=n_categories).astype(np.float64) + + out = np.zeros((3, n_categories, data.shape[1]), dtype=np.float64) + out[0] = counts[:, None] + nonempty = counts > 0 + if not nonempty.any(): + return out + + agg = Aggregate(groupby=by, data=data, mask=mask) + sum_ = agg.sum() + sum_sq = agg._sum(_power(data, 2)) + safe_counts = np.where(nonempty, counts, 1)[:, None] + mean_ = sum_ / safe_counts + # M2 = sum((x - mean)**2) = sum_sq - count * mean**2; clip cancellation noise to 0. + m2 = np.maximum(sum_sq - sum_ * mean_, 0) + out[1, nonempty] = mean_[nonempty] + out[2, nonempty] = m2[nonempty] + return out + + +def _chan_combine( + a: NDArray[np.float64], b: NDArray[np.float64] +) -> NDArray[np.float64]: + """Combine two ``(3, K, F)`` ``(count, mean, M2)`` stat blocks pairwise.""" + n_a, mean_a, m2_a = a[0], a[1], a[2] + n_b, mean_b, m2_b = b[0], b[1], b[2] + n = n_a + n_b + safe_n = np.where(n > 0, n, 1) + delta = mean_b - mean_a + new_mean = mean_a + delta * n_b / safe_n + new_m2 = m2_a + m2_b + delta * delta * n_a * n_b / safe_n + return np.stack([n, new_mean, new_m2]) + + +def _chan_reduce_axis_0( + stats: NDArray[np.float64], + axis: int | None, + keepdims: bool, # noqa: FBT001 +) -> NDArray[np.float64]: + """Aggregate per-block stats along axis 0 with the parallel variance algorithm.""" + result = stats[0] + for i in range(1, stats.shape[0]): + result = _chan_combine(result, stats[i]) + return result[None] if keepdims else result @_aggregate.register(DaskArray) From 61332fd103d7d904136e701b9f831eca6bbe041f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 15:08:26 +0200 Subject: [PATCH 2/7] fix: params --- benchmarks/benchmarks/preprocessing_counts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 672a7df5fc..db6c23f0f0 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -153,9 +153,9 @@ def peakmem_log1p(self, *_) -> None: class Agg: # noqa: D101 - params: tuple[KeysView[AggType], tuple[bool]] = ( + params: tuple[KeysView[AggType], list[bool]] = ( get_literal_vals(AggType), - (True, False), + [True, False], ) param_names = ("agg_name", "use_dask") From 1df5fdaed65ce603511c7be00cacc35126f61d3b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 15:30:21 +0200 Subject: [PATCH 3/7] fix: iteration --- benchmarks/benchmarks/preprocessing_counts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index db6c23f0f0..1809c6d3bb 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -19,7 +19,6 @@ from ._utils import get_count_dataset, get_dataset if TYPE_CHECKING: - from collections.abc import KeysView from typing import Any from ._utils import Dataset, KeyCount @@ -153,8 +152,8 @@ def peakmem_log1p(self, *_) -> None: class Agg: # noqa: D101 - params: tuple[KeysView[AggType], list[bool]] = ( - get_literal_vals(AggType), + params: tuple[list[str], list[bool]] = ( + list(get_literal_vals(AggType)), [True, False], ) param_names = ("agg_name", "use_dask") From 9a705812bce5d556678a7717004380a89454dc54 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 15:50:18 +0200 Subject: [PATCH 4/7] fix: zarr link --- benchmarks/benchmarks/preprocessing_counts.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 1809c6d3bb..a17d96e9c9 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -165,7 +165,7 @@ def setup_cache(self) -> None: def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001 if use_dask: - z = zarr.open("lung93k_shuffled.zarr") + z = zarr.open("lung93k.zarr") self.adata = ad.AnnData( obs=ad.io.read_elem(z["obs"]), var=ad.io.read_elem(z["var"]), @@ -174,10 +174,6 @@ def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001 }, X=ad.experimental.read_elem_lazy(z["X"]), ) - # Times out on the benchmark machine with full dataset - self.adata = self.adata[ - self.adata.obs["PatientNumber"].isin(["1", "2", "3"]) - ].copy() else: self.adata = ad.read_zarr("lung93k.zarr") self.agg_name: AggType = agg_name From 5313ea24ab162fb994cf2a16b5dfe34151c0b411 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 16:30:59 +0200 Subject: [PATCH 5/7] fix: `median` calculation skipped --- benchmarks/benchmarks/preprocessing_counts.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index a17d96e9c9..bc3bedce63 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -165,6 +165,9 @@ def setup_cache(self) -> None: def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001 if use_dask: + if agg_name == "median": + # Skip this one: https://asv.readthedocs.io/en/stable/writing_benchmarks.html#setup-and-teardown-functions + raise NotImplementedError() z = zarr.open("lung93k.zarr") self.adata = ad.AnnData( obs=ad.io.read_elem(z["obs"]), From a2b390b88e83112cb981570afa08200b72687e48 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 8 Jun 2026 13:13:17 +0200 Subject: [PATCH 6/7] chore: relnote --- docs/release-notes/4143.perf.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 docs/release-notes/4143.perf.md diff --git a/docs/release-notes/4143.perf.md b/docs/release-notes/4143.perf.md new file mode 100644 index 0000000000..2ffd9c00d6 --- /dev/null +++ b/docs/release-notes/4143.perf.md @@ -0,0 +1,3 @@ +Use Chan's mean-var algorithm for acceleration of dask-backed {func}`scanpy.get.aggregate` {smaller}`I Gold` + +[Chan's mean-var]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm From b71eb686174bd998ddeefd4e0034bcc186c0a154 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Thu, 11 Jun 2026 12:00:14 -0700 Subject: [PATCH 7/7] njit support for chan algorithm (#4153) Co-authored-by: Ilan Gold Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/scanpy/get/_aggregated.py | 41 ++++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 66f251c316..b05b80e435 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -3,9 +3,11 @@ from functools import partial, singledispatch from typing import TYPE_CHECKING, Literal, TypedDict, get_args +import numba import numpy as np import pandas as pd from anndata import AnnData +from fast_array_utils.numba import njit from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 @@ -469,18 +471,37 @@ def _block_moments( return out -def _chan_combine( +@numba.njit(inline="always") # noqa: TID251 +def _chan_combine( # noqa: PLR0917 + n_a: float, mean_a: float, m2_a: float, n_b: float, mean_b: float, m2_b: float +) -> tuple[float, float, float]: + """Combine two ``(count, mean, M2)`` groups pairwise.""" + if n_a == 0.0: + return n_b, mean_b, m2_b + if n_b == 0.0: + return n_a, mean_a, m2_a + n = n_a + n_b + delta = mean_b - mean_a + return n, mean_a + delta * n_b / n, m2_a + m2_b + delta * delta * n_a * n_b / n + + +@njit +def _chan_combine_blocks( a: NDArray[np.float64], b: NDArray[np.float64] ) -> NDArray[np.float64]: """Combine two ``(3, K, F)`` ``(count, mean, M2)`` stat blocks pairwise.""" - n_a, mean_a, m2_a = a[0], a[1], a[2] - n_b, mean_b, m2_b = b[0], b[1], b[2] - n = n_a + n_b - safe_n = np.where(n > 0, n, 1) - delta = mean_b - mean_a - new_mean = mean_a + delta * n_b / safe_n - new_m2 = m2_a + m2_b + delta * delta * n_a * n_b / safe_n - return np.stack([n, new_mean, new_m2]) + out = np.empty_like(a) + for i in numba.prange(a.shape[1]): + for j in range(a.shape[2]): + out[0, i, j], out[1, i, j], out[2, i, j] = _chan_combine( + a[0, i, j], + a[1, i, j], + a[2, i, j], + b[0, i, j], + b[1, i, j], + b[2, i, j], + ) + return out def _chan_reduce_axis_0( @@ -491,7 +512,7 @@ def _chan_reduce_axis_0( """Aggregate per-block stats along axis 0 with the parallel variance algorithm.""" result = stats[0] for i in range(1, stats.shape[0]): - result = _chan_combine(result, stats[i]) + result = _chan_combine_blocks(result, stats[i]) return result[None] if keepdims else result