diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..655b33d9a8 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -83,7 +83,8 @@ // "psutil": [""] "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 - // "scikit-misc": [""], + "scikit-misc": [""], + "dask": [""], }, // Combinations of libraries/python versions can be excluded/included diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 9a20e7eda3..bc3bedce63 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING import anndata as ad +import zarr import scanpy as sc from scanpy._utils import get_literal_vals @@ -151,17 +152,34 @@ def peakmem_log1p(self, *_) -> None: class Agg: # noqa: D101 - params: tuple[AggType] = tuple(get_literal_vals(AggType)) - param_names = ("agg_name",) + params: tuple[list[str], list[bool]] = ( + list(get_literal_vals(AggType)), + [True, False], + ) + param_names = ("agg_name", "use_dask") def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" adata, _ = get_dataset("lung93k") - adata.write_h5ad("lung93k.h5ad") - - def setup(self, agg_name: AggType) -> None: - self.adata = ad.read_h5ad("lung93k.h5ad") - self.agg_name = agg_name + adata.write_zarr("lung93k.zarr") + + def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001 + if use_dask: + if agg_name == "median": + # Skip this one: https://asv.readthedocs.io/en/stable/writing_benchmarks.html#setup-and-teardown-functions + raise NotImplementedError() + z = zarr.open("lung93k.zarr") + self.adata = ad.AnnData( + obs=ad.io.read_elem(z["obs"]), + var=ad.io.read_elem(z["var"]), + layers={ + "counts": ad.experimental.read_elem_lazy(z["layers"]["counts"]) + }, + X=ad.experimental.read_elem_lazy(z["X"]), + ) + else: + self.adata = ad.read_zarr("lung93k.zarr") + self.agg_name: AggType = agg_name def time_agg(self, *_) -> None: sc.get.aggregate( diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 9633c8e208..350bb66883 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -9,12 +9,16 @@ from typing import TYPE_CHECKING import anndata as ad +import numpy as np +import zarr import scanpy as sc from ._utils import get_dataset, param_skipper if TYPE_CHECKING: + from typing import Literal + from ._utils import Dataset, KeyX @@ -47,17 +51,6 @@ def time_pca(self, *_) -> None: def peakmem_pca(self, *_) -> None: sc.pp.pca(self.adata, svd_solver="arpack") - def time_highly_variable_genes(self, *_) -> None: - # the default flavor runs on log-transformed data - sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5 - ) - - def peakmem_highly_variable_genes(self, *_) -> None: - sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5 - ) - # regress_out is very slow for this dataset @skip_when(dataset={"pbmc3k"}) def time_regress_out(self, *_) -> None: @@ -72,3 +65,66 @@ def time_scale(self, *_) -> None: def peakmem_scale(self, *_) -> None: sc.pp.scale(self.adata, max_value=10) + + +class HVGSuite: # noqa: D101 + params = (["seurat_v3", "cell_ranger", "seurat"], [True, False]) + param_names = ("flavor", "use_dask") + + def setup_cache(self) -> None: + """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" + adata, _ = get_dataset("lung93k") + adata.write_zarr("lung93k.zarr") + obs = np.arange(adata.shape[0]) + np.random.default_rng().shuffle(obs) + adata[obs].write_zarr("lung93k_shuffled.zarr") + + def setup( + self, + flavor: Literal["seurat_v3", "cell_ranger", "seurat"], + use_dask: bool, # noqa: FBT001 + ) -> None: + if use_dask: + if flavor != "seurat_v3": + # This benchmark only really makes sense for seurat v3 as that has been optimized. + raise NotImplementedError() + z = zarr.open("lung93k_shuffled.zarr") + self.adata = ad.AnnData( + obs=ad.io.read_elem(z["obs"]), + var=ad.io.read_elem(z["var"]), + layers={ + "counts": ad.experimental.read_elem_lazy(z["layers"]["counts"]) + }, + X=ad.experimental.read_elem_lazy(z["X"]), + ) + # Times out on the benchmark machine with full dataset + self.adata = self.adata[ + self.adata.obs["PatientNumber"].isin(["1", "2", "3"]) + ].copy() + else: + self.adata = ad.read_zarr("lung93k.zarr") + sc.pp.filter_genes(self.adata, min_cells=3) + self.flavor = flavor + + def time_highly_variable_genes(self, *_) -> None: + # the default flavor runs on log-transformed data + sc.pp.highly_variable_genes( + self.adata, + min_mean=0.0125, + max_mean=3, + min_disp=0.5, + flavor=self.flavor, + batch_key="PatientNumber", + **({"layer": "counts"} if self.flavor == "seurat_v3" else {}), + ) + + def peakmem_highly_variable_genes(self, *_) -> None: + sc.pp.highly_variable_genes( + self.adata, + min_mean=0.0125, + max_mean=3, + min_disp=0.5, + flavor=self.flavor, + batch_key="PatientNumber", + **({"layer": "counts"} if self.flavor == "seurat_v3" else {}), + ) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index fe5140171b..66f251c316 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd from anndata import AnnData -from fast_array_utils.stats._power import power as fau_power # TODO: upstream from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 @@ -371,16 +370,129 @@ def aggregate_dask_mean_var( mask: NDArray[np.bool] | None = None, dof: int = 1, ) -> MeanVarDict: - mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"] - sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"] - # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse. - if isinstance(data._meta, CSRBase): - sq_mean = sq_mean.compute() - var = sq_mean - fau_power(mean, 2) - if dof != 0: - group_counts = np.bincount(by.codes) - var *= (group_counts / (group_counts - dof))[:, np.newaxis] - return MeanVarDict(mean=mean, var=var) + """Compute group-wise mean and variance for a dask array. + + Per chunk we compute ``(count, mean, M2)`` (where ``M2 = sum((x - mean)**2)``), + then combine across chunks with the pairwise parallel algorithm from + Chan et al. (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm) + so the across-chunk reduction avoids the catastrophic cancellation of + ``E[X**2] - E[X]**2``. + """ + import dask.array as da + + n_categories = len(by.categories) + n_features = data.shape[1] + chunked_axis = 0 if isinstance(data._meta, CSRBase | np.ndarray) else 1 + + if chunked_axis == 1: + # Each block already sees every observation, so mean/var per chunk is final. + def per_block_col(chunk: Array) -> NDArray[np.float64]: + mean_, var_ = Aggregate(groupby=by, data=chunk, mask=mask).mean_var(dof=dof) + return np.concatenate([mean_, var_], axis=0) + + combined = data.map_blocks( + per_block_col, + chunks=((2 * n_categories,), data.chunks[1]), + meta=np.array([], dtype=np.float64), + ) + return MeanVarDict(mean=combined[:n_categories], var=combined[n_categories:]) + + n_blocks = data.numblocks[0] + + def per_block_row( + chunk: Array, block_info: dict | None = None + ) -> NDArray[np.float64]: + row_subset = slice(*block_info[0]["array-location"][0]) + by_sub = by[row_subset] + mask_sub = mask[row_subset] if mask is not None else None + return _block_moments(chunk, by_sub, mask=mask_sub, n_categories=n_categories)[ + None + ] + + per_block_stats = data.map_blocks( + per_block_row, + chunks=((1,) * n_blocks, (3,), (n_categories,), (n_features,)), + new_axis=(1, 2), + meta=np.array([], dtype=np.float64), + ) + + combined = da.reduction( + per_block_stats, + chunk=lambda x, axis=None, keepdims=False: x, + aggregate=_chan_reduce_axis_0, + axis=0, + keepdims=False, + concatenate=True, + dtype=np.float64, + meta=np.array([], dtype=np.float64), + ) + counts = combined[0] + mean_ = combined[1] + m2 = combined[2] + denom = counts - dof if dof > 0 else counts + return MeanVarDict(mean=mean_, var=m2 / denom) + + +def _block_moments( + data: np.ndarray | CSBase, + by: pd.Categorical, + *, + mask: NDArray[np.bool] | None, + n_categories: int, +) -> NDArray[np.float64]: + """Per-chunk ``(count, mean, M2)`` array of shape ``(3, n_categories, n_features)``. + + Groups with no observations in the chunk get zeros for mean and M2 so + they combine cleanly under ``_chan_combine``. + """ + codes = by.codes + valid = codes >= 0 + if mask is not None: + valid = valid & mask + counts = np.bincount(codes[valid], minlength=n_categories).astype(np.float64) + + out = np.zeros((3, n_categories, data.shape[1]), dtype=np.float64) + out[0] = counts[:, None] + nonempty = counts > 0 + if not nonempty.any(): + return out + + agg = Aggregate(groupby=by, data=data, mask=mask) + sum_ = agg.sum() + sum_sq = agg._sum(_power(data, 2)) + safe_counts = np.where(nonempty, counts, 1)[:, None] + mean_ = sum_ / safe_counts + # M2 = sum((x - mean)**2) = sum_sq - count * mean**2; clip cancellation noise to 0. + m2 = np.maximum(sum_sq - sum_ * mean_, 0) + out[1, nonempty] = mean_[nonempty] + out[2, nonempty] = m2[nonempty] + return out + + +def _chan_combine( + a: NDArray[np.float64], b: NDArray[np.float64] +) -> NDArray[np.float64]: + """Combine two ``(3, K, F)`` ``(count, mean, M2)`` stat blocks pairwise.""" + n_a, mean_a, m2_a = a[0], a[1], a[2] + n_b, mean_b, m2_b = b[0], b[1], b[2] + n = n_a + n_b + safe_n = np.where(n > 0, n, 1) + delta = mean_b - mean_a + new_mean = mean_a + delta * n_b / safe_n + new_m2 = m2_a + m2_b + delta * delta * n_a * n_b / safe_n + return np.stack([n, new_mean, new_m2]) + + +def _chan_reduce_axis_0( + stats: NDArray[np.float64], + axis: int | None, + keepdims: bool, # noqa: FBT001 +) -> NDArray[np.float64]: + """Aggregate per-block stats along axis 0 with the parallel variance algorithm.""" + result = stats[0] + for i in range(1, stats.shape[0]): + result = _chan_combine(result, stats[i]) + return result[None] if keepdims else result @_aggregate.register(DaskArray) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index d5f3d2cc79..cad3ff470d 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -20,7 +20,7 @@ raise_if_dask_feature_axis_chunked, sanitize_anndata, ) -from ..get import _get_obs_rep +from ..get import _get_obs_rep, aggregate from ._distributed import materialize_as_ndarray from ._simple import filter_genes @@ -36,7 +36,7 @@ @singledispatch def clip_square_sum( data_batch: np.ndarray, clip_val: np.ndarray -) -> tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray] | tuple[DaskArray, DaskArray]: """Clip data_batch by clip_val. Parameters @@ -64,24 +64,19 @@ def clip_square_sum( @clip_square_sum.register(DaskArray) -def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]: +def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[DaskArray, DaskArray]: n_blocks = data_batch.blocks.size def sum_and_sum_squares_clipped_from_block(block): return np.vstack(clip_square_sum(block, clip_val))[None, ...] - squared_batch_counts_sum, batch_counts_sum = ( - data_batch - .map_blocks( - sum_and_sum_squares_clipped_from_block, - new_axis=(1,), - chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)), - meta=np.array([]), - dtype=np.float64, - ) - .sum(axis=0) - .compute() - ) + squared_batch_counts_sum, batch_counts_sum = data_batch.map_blocks( + sum_and_sum_squares_clipped_from_block, + new_axis=(1,), + chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)), + meta=np.array([]), + dtype=np.float64, + ).sum(axis=0) return squared_batch_counts_sum, batch_counts_sum @@ -172,17 +167,55 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 batch_info = ( pd.Categorical(np.zeros(adata.shape[0], dtype=int)) if batch_key is None - else adata.obs[batch_key].to_numpy() + else adata.obs[batch_key] ) - norm_gene_vars = [] + + adata_agg = AnnData( + X=data, + var=pd.DataFrame(index=adata.var_names), + obs=pd.DataFrame( + index=adata.obs_names, data={"__hvg_v3_batch_info__": batch_info} + ), + ) + if batch_key is not None: + aggregated_mean_var = aggregate( + adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"] + ) + mean_global, var_global = ( + aggregated_mean_var.layers[l] for l in ["mean", "var"] + ) + if isinstance(mean_global, DaskArray): + import dask.array as da + + mean_global, var_global = da.compute(mean_global, var_global) + aggregated_mean_var.layers["mean"] = mean_global + aggregated_mean_var.layers["var"] = var_global + else: + aggregated_mean_var = AnnData( + obs=pd.DataFrame( + index=np.array(["one"]), data={"__hvg_v3_batch_info__": np.array([0])} + ), + layers={ + "mean": df["means"].to_numpy().reshape((1, -1)), + "var": df["variances"].to_numpy().reshape((1, -1)), + }, + ) + batch_info = batch_info.to_numpy() for b in np.unique(batch_info): data_batch = data[batch_info == b] - - mean, var = stats.mean_var(data_batch, axis=0, correction=1) - # These get computed anyway for loess - if isinstance(mean, DaskArray): - mean, var = mean.compute(), var.compute() + mean, var = ( + aggregated_mean_var[ + aggregated_mean_var.obs["__hvg_v3_batch_info__"] == b + ].layers[l] + for l in ["mean", "var"] + ) + if isinstance(mean, CSBase): + mean = mean.toarray() + mean = mean.ravel() + if isinstance(var, CSBase): + var = var.toarray() + var = var.ravel() estimat_var = np.zeros(data.shape[1], dtype=np.float64) if (not_const := var > 0).any(): y = np.log10(var[not_const]) @@ -204,8 +237,13 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 + squared_batch_counts_sum - 2 * batch_counts_sum * mean ) - norm_gene_vars.append(norm_gene_var.reshape(1, -1)) + norm_gene_vars.append(norm_gene_var) + if any(isinstance(e, DaskArray) for e in norm_gene_vars): + import dask.array as da + + norm_gene_vars = da.compute(*norm_gene_vars) + norm_gene_vars = [ngv.reshape(1, -1) for ngv in norm_gene_vars] norm_gene_vars = np.concatenate(norm_gene_vars, axis=0) # argsort twice gives ranks, small rank means most variable ranked_norm_gene_vars = np.argsort(np.argsort(-norm_gene_vars, axis=1), axis=1)