Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a625c55
perf: "two-pass" seurat hvg3 via `scanpy.get.aggregate`
ilan-gold Mar 26, 2026
d839e98
chore: hvg v3 benchmark
ilan-gold Mar 26, 2026
86db499
fix: use counts
ilan-gold Mar 26, 2026
d5a6a78
fix: use a batch key
ilan-gold Mar 26, 2026
fdc5653
fix: not again
ilan-gold Mar 26, 2026
8f0e426
fix: `compute` single pass!
ilan-gold Apr 8, 2026
8ad893d
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold Apr 8, 2026
7e0390e
fix: unique
ilan-gold Apr 9, 2026
17be530
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold Apr 10, 2026
cc0d67e
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold Apr 16, 2026
96c16e9
chore: add new `dask` benchmark
ilan-gold May 4, 2026
db4bc2c
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold May 4, 2026
478af4a
fix: actually use dask lol
ilan-gold May 4, 2026
54db31b
chore: really do dask
ilan-gold May 4, 2026
4fe84c5
fix: layers support
ilan-gold May 4, 2026
35590a4
fix: no view check needed
ilan-gold May 4, 2026
db81d6e
fix: no layers eeded
ilan-gold May 4, 2026
b37444e
fix: reduce number of batches
ilan-gold May 5, 2026
cf65665
fix: a little bit more
ilan-gold May 5, 2026
8f4ef78
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold May 15, 2026
a7b067d
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold May 16, 2026
6f7ad6a
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold May 18, 2026
e624939
perf: chan's parallel mean-var algorithm for dask
ilan-gold Jun 5, 2026
61332fd
fix: params
ilan-gold Jun 5, 2026
1df5fda
fix: iteration
ilan-gold Jun 5, 2026
9a70581
fix: zarr link
ilan-gold Jun 5, 2026
5313ea2
fix: `median` calculation skipped
ilan-gold Jun 5, 2026
e19a7d8
fix: no-batch-key accel
ilan-gold Jun 5, 2026
8482561
fix: don't run all benchmarks with dask
ilan-gold Jun 5, 2026
44606f0
Merge branch 'ig/chan_mean_var_main' into ig/two_pass_hvg_v3
ilan-gold Jun 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion benchmarks/asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@
// "psutil": [""]
"pooch": [""],
"scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29
// "scikit-misc": [""],
"scikit-misc": [""],
"dask": [""],
},

// Combinations of libraries/python versions can be excluded/included
Expand Down
32 changes: 25 additions & 7 deletions benchmarks/benchmarks/preprocessing_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import TYPE_CHECKING

import anndata as ad
import zarr

import scanpy as sc
from scanpy._utils import get_literal_vals
Expand Down Expand Up @@ -151,17 +152,34 @@ def peakmem_log1p(self, *_) -> None:


class Agg: # noqa: D101
params: tuple[AggType] = tuple(get_literal_vals(AggType))
param_names = ("agg_name",)
params: tuple[list[str], list[bool]] = (
list(get_literal_vals(AggType)),
[True, False],
)
param_names = ("agg_name", "use_dask")

def setup_cache(self) -> None:
"""Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
adata, _ = get_dataset("lung93k")
adata.write_h5ad("lung93k.h5ad")

def setup(self, agg_name: AggType) -> None:
self.adata = ad.read_h5ad("lung93k.h5ad")
self.agg_name = agg_name
adata.write_zarr("lung93k.zarr")

def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001
if use_dask:
if agg_name == "median":
# Skip this one: https://asv.readthedocs.io/en/stable/writing_benchmarks.html#setup-and-teardown-functions
raise NotImplementedError()
z = zarr.open("lung93k.zarr")
self.adata = ad.AnnData(
obs=ad.io.read_elem(z["obs"]),
var=ad.io.read_elem(z["var"]),
layers={
"counts": ad.experimental.read_elem_lazy(z["layers"]["counts"])
},
X=ad.experimental.read_elem_lazy(z["X"]),
)
else:
self.adata = ad.read_zarr("lung93k.zarr")
self.agg_name: AggType = agg_name

def time_agg(self, *_) -> None:
sc.get.aggregate(
Expand Down
78 changes: 67 additions & 11 deletions benchmarks/benchmarks/preprocessing_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
from typing import TYPE_CHECKING

import anndata as ad
import numpy as np
import zarr

import scanpy as sc

from ._utils import get_dataset, param_skipper

if TYPE_CHECKING:
from typing import Literal

from ._utils import Dataset, KeyX


Expand Down Expand Up @@ -47,17 +51,6 @@ def time_pca(self, *_) -> None:
def peakmem_pca(self, *_) -> None:
sc.pp.pca(self.adata, svd_solver="arpack")

def time_highly_variable_genes(self, *_) -> None:
# the default flavor runs on log-transformed data
sc.pp.highly_variable_genes(
self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
)

def peakmem_highly_variable_genes(self, *_) -> None:
sc.pp.highly_variable_genes(
self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
)

# regress_out is very slow for this dataset
@skip_when(dataset={"pbmc3k"})
def time_regress_out(self, *_) -> None:
Expand All @@ -72,3 +65,66 @@ def time_scale(self, *_) -> None:

def peakmem_scale(self, *_) -> None:
sc.pp.scale(self.adata, max_value=10)


class HVGSuite: # noqa: D101
params = (["seurat_v3", "cell_ranger", "seurat"], [True, False])
param_names = ("flavor", "use_dask")

def setup_cache(self) -> None:
"""Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
adata, _ = get_dataset("lung93k")
adata.write_zarr("lung93k.zarr")
obs = np.arange(adata.shape[0])
np.random.default_rng().shuffle(obs)
adata[obs].write_zarr("lung93k_shuffled.zarr")

def setup(
self,
flavor: Literal["seurat_v3", "cell_ranger", "seurat"],
use_dask: bool, # noqa: FBT001
) -> None:
if use_dask:
if flavor != "seurat_v3":
# This benchmark only really makes sense for seurat v3 as that has been optimized.
raise NotImplementedError()
z = zarr.open("lung93k_shuffled.zarr")
self.adata = ad.AnnData(
obs=ad.io.read_elem(z["obs"]),
var=ad.io.read_elem(z["var"]),
layers={
"counts": ad.experimental.read_elem_lazy(z["layers"]["counts"])
},
X=ad.experimental.read_elem_lazy(z["X"]),
)
# Times out on the benchmark machine with full dataset
self.adata = self.adata[
self.adata.obs["PatientNumber"].isin(["1", "2", "3"])
].copy()
else:
self.adata = ad.read_zarr("lung93k.zarr")
sc.pp.filter_genes(self.adata, min_cells=3)
self.flavor = flavor

def time_highly_variable_genes(self, *_) -> None:
# the default flavor runs on log-transformed data
sc.pp.highly_variable_genes(
self.adata,
min_mean=0.0125,
max_mean=3,
min_disp=0.5,
flavor=self.flavor,
batch_key="PatientNumber",
**({"layer": "counts"} if self.flavor == "seurat_v3" else {}),
)

def peakmem_highly_variable_genes(self, *_) -> None:
sc.pp.highly_variable_genes(
self.adata,
min_mean=0.0125,
max_mean=3,
min_disp=0.5,
flavor=self.flavor,
batch_key="PatientNumber",
**({"layer": "counts"} if self.flavor == "seurat_v3" else {}),
)
134 changes: 123 additions & 11 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from fast_array_utils.stats._power import power as fau_power # TODO: upstream
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

Expand Down Expand Up @@ -371,16 +370,129 @@ def aggregate_dask_mean_var(
mask: NDArray[np.bool] | None = None,
dof: int = 1,
) -> MeanVarDict:
mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"]
sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"]
# TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
if isinstance(data._meta, CSRBase):
sq_mean = sq_mean.compute()
var = sq_mean - fau_power(mean, 2)
if dof != 0:
group_counts = np.bincount(by.codes)
var *= (group_counts / (group_counts - dof))[:, np.newaxis]
return MeanVarDict(mean=mean, var=var)
"""Compute group-wise mean and variance for a dask array.

Per chunk we compute ``(count, mean, M2)`` (where ``M2 = sum((x - mean)**2)``),
then combine across chunks with the pairwise parallel algorithm from
Chan et al. (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm)
so the across-chunk reduction avoids the catastrophic cancellation of
``E[X**2] - E[X]**2``.
"""
import dask.array as da

n_categories = len(by.categories)
n_features = data.shape[1]
chunked_axis = 0 if isinstance(data._meta, CSRBase | np.ndarray) else 1

if chunked_axis == 1:
# Each block already sees every observation, so mean/var per chunk is final.
def per_block_col(chunk: Array) -> NDArray[np.float64]:
mean_, var_ = Aggregate(groupby=by, data=chunk, mask=mask).mean_var(dof=dof)
return np.concatenate([mean_, var_], axis=0)

combined = data.map_blocks(
per_block_col,
chunks=((2 * n_categories,), data.chunks[1]),
meta=np.array([], dtype=np.float64),
)
return MeanVarDict(mean=combined[:n_categories], var=combined[n_categories:])

n_blocks = data.numblocks[0]

def per_block_row(
chunk: Array, block_info: dict | None = None
) -> NDArray[np.float64]:
row_subset = slice(*block_info[0]["array-location"][0])
by_sub = by[row_subset]
mask_sub = mask[row_subset] if mask is not None else None
return _block_moments(chunk, by_sub, mask=mask_sub, n_categories=n_categories)[
None
]

per_block_stats = data.map_blocks(
per_block_row,
chunks=((1,) * n_blocks, (3,), (n_categories,), (n_features,)),
new_axis=(1, 2),
meta=np.array([], dtype=np.float64),
)

combined = da.reduction(
per_block_stats,
chunk=lambda x, axis=None, keepdims=False: x,
aggregate=_chan_reduce_axis_0,
axis=0,
keepdims=False,
concatenate=True,
dtype=np.float64,
meta=np.array([], dtype=np.float64),
)
counts = combined[0]
mean_ = combined[1]
m2 = combined[2]
denom = counts - dof if dof > 0 else counts
return MeanVarDict(mean=mean_, var=m2 / denom)


def _block_moments(
data: np.ndarray | CSBase,
by: pd.Categorical,
*,
mask: NDArray[np.bool] | None,
n_categories: int,
) -> NDArray[np.float64]:
"""Per-chunk ``(count, mean, M2)`` array of shape ``(3, n_categories, n_features)``.

Groups with no observations in the chunk get zeros for mean and M2 so
they combine cleanly under ``_chan_combine``.
"""
codes = by.codes
valid = codes >= 0
if mask is not None:
valid = valid & mask
counts = np.bincount(codes[valid], minlength=n_categories).astype(np.float64)

out = np.zeros((3, n_categories, data.shape[1]), dtype=np.float64)
out[0] = counts[:, None]
nonempty = counts > 0
if not nonempty.any():
return out

agg = Aggregate(groupby=by, data=data, mask=mask)
sum_ = agg.sum()
sum_sq = agg._sum(_power(data, 2))
safe_counts = np.where(nonempty, counts, 1)[:, None]
mean_ = sum_ / safe_counts
# M2 = sum((x - mean)**2) = sum_sq - count * mean**2; clip cancellation noise to 0.
m2 = np.maximum(sum_sq - sum_ * mean_, 0)
out[1, nonempty] = mean_[nonempty]
out[2, nonempty] = m2[nonempty]
return out


def _chan_combine(
a: NDArray[np.float64], b: NDArray[np.float64]
) -> NDArray[np.float64]:
"""Combine two ``(3, K, F)`` ``(count, mean, M2)`` stat blocks pairwise."""
n_a, mean_a, m2_a = a[0], a[1], a[2]
n_b, mean_b, m2_b = b[0], b[1], b[2]
n = n_a + n_b
safe_n = np.where(n > 0, n, 1)
delta = mean_b - mean_a
new_mean = mean_a + delta * n_b / safe_n
new_m2 = m2_a + m2_b + delta * delta * n_a * n_b / safe_n
return np.stack([n, new_mean, new_m2])


def _chan_reduce_axis_0(
stats: NDArray[np.float64],
axis: int | None,
keepdims: bool, # noqa: FBT001
) -> NDArray[np.float64]:
"""Aggregate per-block stats along axis 0 with the parallel variance algorithm."""
result = stats[0]
for i in range(1, stats.shape[0]):
result = _chan_combine(result, stats[i])
return result[None] if keepdims else result


@_aggregate.register(DaskArray)
Expand Down
Loading
Loading