diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..a1a8d31a42 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -84,6 +84,7 @@ "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 // "scikit-misc": [""], + "dask": [""], }, // Combinations of libraries/python versions can be excluded/included diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 9a20e7eda3..bc3bedce63 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING import anndata as ad +import zarr import scanpy as sc from scanpy._utils import get_literal_vals @@ -151,17 +152,34 @@ def peakmem_log1p(self, *_) -> None: class Agg: # noqa: D101 - params: tuple[AggType] = tuple(get_literal_vals(AggType)) - param_names = ("agg_name",) + params: tuple[list[str], list[bool]] = ( + list(get_literal_vals(AggType)), + [True, False], + ) + param_names = ("agg_name", "use_dask") def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" adata, _ = get_dataset("lung93k") - adata.write_h5ad("lung93k.h5ad") - - def setup(self, agg_name: AggType) -> None: - self.adata = ad.read_h5ad("lung93k.h5ad") - self.agg_name = agg_name + adata.write_zarr("lung93k.zarr") + + def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001 + if use_dask: + if agg_name == "median": + # Skip this one: https://asv.readthedocs.io/en/stable/writing_benchmarks.html#setup-and-teardown-functions + raise NotImplementedError() + z = zarr.open("lung93k.zarr") + self.adata = ad.AnnData( + obs=ad.io.read_elem(z["obs"]), + var=ad.io.read_elem(z["var"]), + layers={ + "counts": ad.experimental.read_elem_lazy(z["layers"]["counts"]) + }, + X=ad.experimental.read_elem_lazy(z["X"]), + ) + else: + self.adata = ad.read_zarr("lung93k.zarr") + self.agg_name: AggType = agg_name def time_agg(self, *_) -> None: sc.get.aggregate( diff --git a/docs/release-notes/4143.perf.md b/docs/release-notes/4143.perf.md new file mode 100644 index 0000000000..2ffd9c00d6 --- /dev/null +++ b/docs/release-notes/4143.perf.md @@ -0,0 +1,3 @@ +Use Chan's mean-var algorithm for acceleration of dask-backed {func}`scanpy.get.aggregate` {smaller}`I Gold` + +[Chan's mean-var]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index fe5140171b..b05b80e435 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -3,10 +3,11 @@ from functools import partial, singledispatch from typing import TYPE_CHECKING, Literal, TypedDict, get_args +import numba import numpy as np import pandas as pd from anndata import AnnData -from fast_array_utils.stats._power import power as fau_power # TODO: upstream +from fast_array_utils.numba import njit from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 @@ -371,16 +372,148 @@ def aggregate_dask_mean_var( mask: NDArray[np.bool] | None = None, dof: int = 1, ) -> MeanVarDict: - mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"] - sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"] - # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse. - if isinstance(data._meta, CSRBase): - sq_mean = sq_mean.compute() - var = sq_mean - fau_power(mean, 2) - if dof != 0: - group_counts = np.bincount(by.codes) - var *= (group_counts / (group_counts - dof))[:, np.newaxis] - return MeanVarDict(mean=mean, var=var) + """Compute group-wise mean and variance for a dask array. + + Per chunk we compute ``(count, mean, M2)`` (where ``M2 = sum((x - mean)**2)``), + then combine across chunks with the pairwise parallel algorithm from + Chan et al. (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm) + so the across-chunk reduction avoids the catastrophic cancellation of + ``E[X**2] - E[X]**2``. + """ + import dask.array as da + + n_categories = len(by.categories) + n_features = data.shape[1] + chunked_axis = 0 if isinstance(data._meta, CSRBase | np.ndarray) else 1 + + if chunked_axis == 1: + # Each block already sees every observation, so mean/var per chunk is final. + def per_block_col(chunk: Array) -> NDArray[np.float64]: + mean_, var_ = Aggregate(groupby=by, data=chunk, mask=mask).mean_var(dof=dof) + return np.concatenate([mean_, var_], axis=0) + + combined = data.map_blocks( + per_block_col, + chunks=((2 * n_categories,), data.chunks[1]), + meta=np.array([], dtype=np.float64), + ) + return MeanVarDict(mean=combined[:n_categories], var=combined[n_categories:]) + + n_blocks = data.numblocks[0] + + def per_block_row( + chunk: Array, block_info: dict | None = None + ) -> NDArray[np.float64]: + row_subset = slice(*block_info[0]["array-location"][0]) + by_sub = by[row_subset] + mask_sub = mask[row_subset] if mask is not None else None + return _block_moments(chunk, by_sub, mask=mask_sub, n_categories=n_categories)[ + None + ] + + per_block_stats = data.map_blocks( + per_block_row, + chunks=((1,) * n_blocks, (3,), (n_categories,), (n_features,)), + new_axis=(1, 2), + meta=np.array([], dtype=np.float64), + ) + + combined = da.reduction( + per_block_stats, + chunk=lambda x, axis=None, keepdims=False: x, + aggregate=_chan_reduce_axis_0, + axis=0, + keepdims=False, + concatenate=True, + dtype=np.float64, + meta=np.array([], dtype=np.float64), + ) + counts = combined[0] + mean_ = combined[1] + m2 = combined[2] + denom = counts - dof if dof > 0 else counts + return MeanVarDict(mean=mean_, var=m2 / denom) + + +def _block_moments( + data: np.ndarray | CSBase, + by: pd.Categorical, + *, + mask: NDArray[np.bool] | None, + n_categories: int, +) -> NDArray[np.float64]: + """Per-chunk ``(count, mean, M2)`` array of shape ``(3, n_categories, n_features)``. + + Groups with no observations in the chunk get zeros for mean and M2 so + they combine cleanly under ``_chan_combine``. + """ + codes = by.codes + valid = codes >= 0 + if mask is not None: + valid = valid & mask + counts = np.bincount(codes[valid], minlength=n_categories).astype(np.float64) + + out = np.zeros((3, n_categories, data.shape[1]), dtype=np.float64) + out[0] = counts[:, None] + nonempty = counts > 0 + if not nonempty.any(): + return out + + agg = Aggregate(groupby=by, data=data, mask=mask) + sum_ = agg.sum() + sum_sq = agg._sum(_power(data, 2)) + safe_counts = np.where(nonempty, counts, 1)[:, None] + mean_ = sum_ / safe_counts + # M2 = sum((x - mean)**2) = sum_sq - count * mean**2; clip cancellation noise to 0. + m2 = np.maximum(sum_sq - sum_ * mean_, 0) + out[1, nonempty] = mean_[nonempty] + out[2, nonempty] = m2[nonempty] + return out + + +@numba.njit(inline="always") # noqa: TID251 +def _chan_combine( # noqa: PLR0917 + n_a: float, mean_a: float, m2_a: float, n_b: float, mean_b: float, m2_b: float +) -> tuple[float, float, float]: + """Combine two ``(count, mean, M2)`` groups pairwise.""" + if n_a == 0.0: + return n_b, mean_b, m2_b + if n_b == 0.0: + return n_a, mean_a, m2_a + n = n_a + n_b + delta = mean_b - mean_a + return n, mean_a + delta * n_b / n, m2_a + m2_b + delta * delta * n_a * n_b / n + + +@njit +def _chan_combine_blocks( + a: NDArray[np.float64], b: NDArray[np.float64] +) -> NDArray[np.float64]: + """Combine two ``(3, K, F)`` ``(count, mean, M2)`` stat blocks pairwise.""" + out = np.empty_like(a) + for i in numba.prange(a.shape[1]): + for j in range(a.shape[2]): + out[0, i, j], out[1, i, j], out[2, i, j] = _chan_combine( + a[0, i, j], + a[1, i, j], + a[2, i, j], + b[0, i, j], + b[1, i, j], + b[2, i, j], + ) + return out + + +def _chan_reduce_axis_0( + stats: NDArray[np.float64], + axis: int | None, + keepdims: bool, # noqa: FBT001 +) -> NDArray[np.float64]: + """Aggregate per-block stats along axis 0 with the parallel variance algorithm.""" + result = stats[0] + for i in range(1, stats.shape[0]): + result = _chan_combine_blocks(result, stats[i]) + return result[None] if keepdims else result @_aggregate.register(DaskArray)