diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index fe5140171b..66f251c316 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd from anndata import AnnData -from fast_array_utils.stats._power import power as fau_power # TODO: upstream from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 @@ -371,16 +370,129 @@ def aggregate_dask_mean_var( mask: NDArray[np.bool] | None = None, dof: int = 1, ) -> MeanVarDict: - mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"] - sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"] - # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse. - if isinstance(data._meta, CSRBase): - sq_mean = sq_mean.compute() - var = sq_mean - fau_power(mean, 2) - if dof != 0: - group_counts = np.bincount(by.codes) - var *= (group_counts / (group_counts - dof))[:, np.newaxis] - return MeanVarDict(mean=mean, var=var) + """Compute group-wise mean and variance for a dask array. + + Per chunk we compute ``(count, mean, M2)`` (where ``M2 = sum((x - mean)**2)``), + then combine across chunks with the pairwise parallel algorithm from + Chan et al. (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm) + so the across-chunk reduction avoids the catastrophic cancellation of + ``E[X**2] - E[X]**2``. + """ + import dask.array as da + + n_categories = len(by.categories) + n_features = data.shape[1] + chunked_axis = 0 if isinstance(data._meta, CSRBase | np.ndarray) else 1 + + if chunked_axis == 1: + # Each block already sees every observation, so mean/var per chunk is final. + def per_block_col(chunk: Array) -> NDArray[np.float64]: + mean_, var_ = Aggregate(groupby=by, data=chunk, mask=mask).mean_var(dof=dof) + return np.concatenate([mean_, var_], axis=0) + + combined = data.map_blocks( + per_block_col, + chunks=((2 * n_categories,), data.chunks[1]), + meta=np.array([], dtype=np.float64), + ) + return MeanVarDict(mean=combined[:n_categories], var=combined[n_categories:]) + + n_blocks = data.numblocks[0] + + def per_block_row( + chunk: Array, block_info: dict | None = None + ) -> NDArray[np.float64]: + row_subset = slice(*block_info[0]["array-location"][0]) + by_sub = by[row_subset] + mask_sub = mask[row_subset] if mask is not None else None + return _block_moments(chunk, by_sub, mask=mask_sub, n_categories=n_categories)[ + None + ] + + per_block_stats = data.map_blocks( + per_block_row, + chunks=((1,) * n_blocks, (3,), (n_categories,), (n_features,)), + new_axis=(1, 2), + meta=np.array([], dtype=np.float64), + ) + + combined = da.reduction( + per_block_stats, + chunk=lambda x, axis=None, keepdims=False: x, + aggregate=_chan_reduce_axis_0, + axis=0, + keepdims=False, + concatenate=True, + dtype=np.float64, + meta=np.array([], dtype=np.float64), + ) + counts = combined[0] + mean_ = combined[1] + m2 = combined[2] + denom = counts - dof if dof > 0 else counts + return MeanVarDict(mean=mean_, var=m2 / denom) + + +def _block_moments( + data: np.ndarray | CSBase, + by: pd.Categorical, + *, + mask: NDArray[np.bool] | None, + n_categories: int, +) -> NDArray[np.float64]: + """Per-chunk ``(count, mean, M2)`` array of shape ``(3, n_categories, n_features)``. + + Groups with no observations in the chunk get zeros for mean and M2 so + they combine cleanly under ``_chan_combine``. + """ + codes = by.codes + valid = codes >= 0 + if mask is not None: + valid = valid & mask + counts = np.bincount(codes[valid], minlength=n_categories).astype(np.float64) + + out = np.zeros((3, n_categories, data.shape[1]), dtype=np.float64) + out[0] = counts[:, None] + nonempty = counts > 0 + if not nonempty.any(): + return out + + agg = Aggregate(groupby=by, data=data, mask=mask) + sum_ = agg.sum() + sum_sq = agg._sum(_power(data, 2)) + safe_counts = np.where(nonempty, counts, 1)[:, None] + mean_ = sum_ / safe_counts + # M2 = sum((x - mean)**2) = sum_sq - count * mean**2; clip cancellation noise to 0. + m2 = np.maximum(sum_sq - sum_ * mean_, 0) + out[1, nonempty] = mean_[nonempty] + out[2, nonempty] = m2[nonempty] + return out + + +def _chan_combine( + a: NDArray[np.float64], b: NDArray[np.float64] +) -> NDArray[np.float64]: + """Combine two ``(3, K, F)`` ``(count, mean, M2)`` stat blocks pairwise.""" + n_a, mean_a, m2_a = a[0], a[1], a[2] + n_b, mean_b, m2_b = b[0], b[1], b[2] + n = n_a + n_b + safe_n = np.where(n > 0, n, 1) + delta = mean_b - mean_a + new_mean = mean_a + delta * n_b / safe_n + new_m2 = m2_a + m2_b + delta * delta * n_a * n_b / safe_n + return np.stack([n, new_mean, new_m2]) + + +def _chan_reduce_axis_0( + stats: NDArray[np.float64], + axis: int | None, + keepdims: bool, # noqa: FBT001 +) -> NDArray[np.float64]: + """Aggregate per-block stats along axis 0 with the parallel variance algorithm.""" + result = stats[0] + for i in range(1, stats.shape[0]): + result = _chan_combine(result, stats[i]) + return result[None] if keepdims else result @_aggregate.register(DaskArray) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index f588aa019b..5cc5b98b2c 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -6,6 +6,7 @@ from inspect import signature from typing import TYPE_CHECKING, TypedDict, cast +import anndata as ad import numba import numpy as np import pandas as pd @@ -178,16 +179,29 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 index=adata.obs_names, data={"__hvg_v3_batch_info__": batch_info} ), ) - aggregated_mean_var = aggregate( - adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"] - ) - mean_global, var_global = (aggregated_mean_var.layers[l] for l in ["mean", "var"]) - if isinstance(mean_global, DaskArray): - import dask.array as da + if batch_key is not None: + aggregated_mean_var = aggregate( + adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"] + ) + mean_global, var_global = ( + aggregated_mean_var.layers[l] for l in ["mean", "var"] + ) + if isinstance(mean_global, DaskArray): + import dask.array as da - mean_global, var_global = da.compute(mean_global, var_global) - aggregated_mean_var.layers["mean"] = mean_global - aggregated_mean_var.layers["var"] = var_global + mean_global, var_global = da.compute(mean_global, var_global) + aggregated_mean_var.layers["mean"] = mean_global + aggregated_mean_var.layers["var"] = var_global + else: + aggregated_mean_var = ad.AnnData( + obs=pd.DataFrame( + index=np.array(["one"]), data={"__hvg_v3_batch_info__": np.array([0])} + ), + layers={ + "mean": df["means"].to_numpy().reshape((1, -1)), + "var": df["variances"].to_numpy().reshape((1, -1)), + }, + ) batch_info = batch_info.to_numpy() for b in np.unique(batch_info): data_batch = data[batch_info == b] @@ -752,7 +766,6 @@ def highly_variable_genes( # noqa: PLR0913 from .. import settings flavor = settings.preset.highly_variable_genes.flavor - start = logg.info("extracting highly variable genes") if not isinstance(adata, AnnData):