Skip to content
134 changes: 123 additions & 11 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from fast_array_utils.stats._power import power as fau_power # TODO: upstream
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

Expand Down Expand Up @@ -371,16 +370,129 @@ def aggregate_dask_mean_var(
mask: NDArray[np.bool] | None = None,
dof: int = 1,
) -> MeanVarDict:
mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"]
sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"]
# TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
if isinstance(data._meta, CSRBase):
sq_mean = sq_mean.compute()
var = sq_mean - fau_power(mean, 2)
if dof != 0:
group_counts = np.bincount(by.codes)
var *= (group_counts / (group_counts - dof))[:, np.newaxis]
return MeanVarDict(mean=mean, var=var)
"""Compute group-wise mean and variance for a dask array.

Per chunk we compute ``(count, mean, M2)`` (where ``M2 = sum((x - mean)**2)``),
then combine across chunks with the pairwise parallel algorithm from
Chan et al. (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm)
so the across-chunk reduction avoids the catastrophic cancellation of
``E[X**2] - E[X]**2``.
"""
import dask.array as da

n_categories = len(by.categories)
n_features = data.shape[1]
chunked_axis = 0 if isinstance(data._meta, CSRBase | np.ndarray) else 1

if chunked_axis == 1:
# Each block already sees every observation, so mean/var per chunk is final.
def per_block_col(chunk: Array) -> NDArray[np.float64]:
mean_, var_ = Aggregate(groupby=by, data=chunk, mask=mask).mean_var(dof=dof)
return np.concatenate([mean_, var_], axis=0)

combined = data.map_blocks(
per_block_col,
chunks=((2 * n_categories,), data.chunks[1]),
meta=np.array([], dtype=np.float64),
)
return MeanVarDict(mean=combined[:n_categories], var=combined[n_categories:])

n_blocks = data.numblocks[0]

def per_block_row(
chunk: Array, block_info: dict | None = None
) -> NDArray[np.float64]:
row_subset = slice(*block_info[0]["array-location"][0])
by_sub = by[row_subset]
mask_sub = mask[row_subset] if mask is not None else None
return _block_moments(chunk, by_sub, mask=mask_sub, n_categories=n_categories)[
None
]

per_block_stats = data.map_blocks(
per_block_row,
chunks=((1,) * n_blocks, (3,), (n_categories,), (n_features,)),
new_axis=(1, 2),
meta=np.array([], dtype=np.float64),
)

combined = da.reduction(
per_block_stats,
chunk=lambda x, axis=None, keepdims=False: x,
aggregate=_chan_reduce_axis_0,
axis=0,
keepdims=False,
concatenate=True,
dtype=np.float64,
meta=np.array([], dtype=np.float64),
)
counts = combined[0]
mean_ = combined[1]
m2 = combined[2]
denom = counts - dof if dof > 0 else counts
return MeanVarDict(mean=mean_, var=m2 / denom)


def _block_moments(
data: np.ndarray | CSBase,
by: pd.Categorical,
*,
mask: NDArray[np.bool] | None,
n_categories: int,
) -> NDArray[np.float64]:
"""Per-chunk ``(count, mean, M2)`` array of shape ``(3, n_categories, n_features)``.

Groups with no observations in the chunk get zeros for mean and M2 so
they combine cleanly under ``_chan_combine``.
"""
codes = by.codes
valid = codes >= 0
if mask is not None:
valid = valid & mask
counts = np.bincount(codes[valid], minlength=n_categories).astype(np.float64)

out = np.zeros((3, n_categories, data.shape[1]), dtype=np.float64)
out[0] = counts[:, None]
nonempty = counts > 0
if not nonempty.any():
return out

agg = Aggregate(groupby=by, data=data, mask=mask)
sum_ = agg.sum()
sum_sq = agg._sum(_power(data, 2))
safe_counts = np.where(nonempty, counts, 1)[:, None]
mean_ = sum_ / safe_counts
# M2 = sum((x - mean)**2) = sum_sq - count * mean**2; clip cancellation noise to 0.
m2 = np.maximum(sum_sq - sum_ * mean_, 0)
out[1, nonempty] = mean_[nonempty]
out[2, nonempty] = m2[nonempty]
return out


def _chan_combine(
a: NDArray[np.float64], b: NDArray[np.float64]
) -> NDArray[np.float64]:
"""Combine two ``(3, K, F)`` ``(count, mean, M2)`` stat blocks pairwise."""
n_a, mean_a, m2_a = a[0], a[1], a[2]
n_b, mean_b, m2_b = b[0], b[1], b[2]
n = n_a + n_b
safe_n = np.where(n > 0, n, 1)
delta = mean_b - mean_a
new_mean = mean_a + delta * n_b / safe_n
new_m2 = m2_a + m2_b + delta * delta * n_a * n_b / safe_n
return np.stack([n, new_mean, new_m2])


def _chan_reduce_axis_0(
stats: NDArray[np.float64],
axis: int | None,
keepdims: bool, # noqa: FBT001
) -> NDArray[np.float64]:
"""Aggregate per-block stats along axis 0 with the parallel variance algorithm."""
result = stats[0]
for i in range(1, stats.shape[0]):
result = _chan_combine(result, stats[i])
return result[None] if keepdims else result


@_aggregate.register(DaskArray)
Expand Down
33 changes: 23 additions & 10 deletions src/scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from inspect import signature
from typing import TYPE_CHECKING, TypedDict, cast

import anndata as ad
import numba
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -178,16 +179,29 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
index=adata.obs_names, data={"__hvg_v3_batch_info__": batch_info}
),
)
aggregated_mean_var = aggregate(
adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"]
)
mean_global, var_global = (aggregated_mean_var.layers[l] for l in ["mean", "var"])
if isinstance(mean_global, DaskArray):
import dask.array as da
if batch_key is not None:
aggregated_mean_var = aggregate(
adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"]
)
mean_global, var_global = (
aggregated_mean_var.layers[l] for l in ["mean", "var"]
)
if isinstance(mean_global, DaskArray):
import dask.array as da

mean_global, var_global = da.compute(mean_global, var_global)
aggregated_mean_var.layers["mean"] = mean_global
aggregated_mean_var.layers["var"] = var_global
mean_global, var_global = da.compute(mean_global, var_global)
aggregated_mean_var.layers["mean"] = mean_global
aggregated_mean_var.layers["var"] = var_global
else:
aggregated_mean_var = ad.AnnData(
obs=pd.DataFrame(
index=np.array(["one"]), data={"__hvg_v3_batch_info__": np.array([0])}
),
layers={
"mean": df["means"].to_numpy().reshape((1, -1)),
"var": df["variances"].to_numpy().reshape((1, -1)),
},
)
batch_info = batch_info.to_numpy()
for b in np.unique(batch_info):
data_batch = data[batch_info == b]
Expand Down Expand Up @@ -752,7 +766,6 @@ def highly_variable_genes( # noqa: PLR0913
from .. import settings

flavor = settings.preset.highly_variable_genes.flavor

start = logg.info("extracting highly variable genes")

if not isinstance(adata, AnnData):
Expand Down
Loading