diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index e555c6c602..c2e0ba2ec4 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -9,12 +9,11 @@ import pandas as pd from anndata import AnnData from fast_array_utils.numba import njit -from fast_array_utils.stats import mean_var from scipy import sparse from .. import _utils from .. import logging as logg -from .._compat import CSBase +from .._compat import CSBase, DaskArray from .._settings import Default from .._settings.presets import DETest from .._utils import ( @@ -23,7 +22,8 @@ get_literal_vals, raise_not_implemented_error_if_backed_type, ) -from ..get import _check_mask, _get_obs_rep +from ..get import _check_mask, _get_obs_rep, aggregate +from ..get._aggregated import _chan_combine if TYPE_CHECKING: from collections.abc import Generator, Iterable @@ -150,6 +150,82 @@ def _ranks( yield ranks, left, right +def _apply_expm1_preserving_sparsity(X, expm1_func): + """Apply ``expm1`` to X, lazily and chunk-wise for dask, keeping sparse + data sparse (``expm1(0) == 0``). + """ + if isinstance(X, DaskArray): + return X.map_blocks( + _apply_expm1_preserving_sparsity, + expm1_func, + dtype=X.dtype, + meta=X._meta, + ) + if isinstance(X, CSBase): + Xp = X.copy() + Xp.data = expm1_func(Xp.data) + return Xp + return expm1_func(X) + + +@numba.njit +def _scan( + part_n: NDArray[np.float64], + mean: NDArray[np.float64], + m2: NDArray[np.float64], + j: int, + step: int, +) -> NDArray[np.float64]: + """Running Chan combine over the partitions for gene ``j``. + + ``acc[i]`` is the combined ``(count, mean, M2)`` of partition ``i`` together with + every partition on the ``step`` side of it (``step=1`` forward / prefix, + ``step=-1`` backward / suffix). + """ + n_parts = part_n.shape[0] + acc = np.empty((n_parts, 3)) + n = m = v = 0.0 + i = 0 if step == 1 else n_parts - 1 + for _ in range(n_parts): + n, m, v = _chan_combine(n, m, v, part_n[i], mean[i, j], m2[i, j]) + acc[i, 0], acc[i, 1], acc[i, 2] = n, m, v + i += step + return acc + + +@njit +def _vars_rest( + part_n: NDArray[np.float64], + mean: NDArray[np.float64], + m2: NDArray[np.float64], + k: int, +) -> NDArray[np.float64]: + """Leave-one-out variance for each selected group, parallel over genes. + + Group ``g``'s "rest" combines the forward scan up to ``g - 1`` with the backward + scan from ``g + 1`` — every partition except ``g`` — so variances are never + subtracted (Chan's cancellation-free combine). + """ + n_parts, n_genes = mean.shape + vars_rest = np.zeros((k, n_genes)) + for j in numba.prange(n_genes): + prefix = _scan(part_n, mean, m2, j, 1) + suffix = _scan(part_n, mean, m2, j, -1) + for g in range(k): + right = suffix[g + 1] # everything after g; g + 1 < n_parts always holds + if g >= 1: + left = prefix[g - 1] + n_r, _, m2_r = _chan_combine( + left[0], left[1], left[2], right[0], right[1], right[2] + ) + else: + n_r, m2_r = right[0], right[2] + denom = n_r - 1.0 if n_r >= 2.0 else 1.0 + v = m2_r / denom + vars_rest[g, j] = v if v > 0.0 else 0.0 + return vars_rest + + class _RankGenes: def __init__( self, @@ -216,7 +292,6 @@ def __init__( self.means = None self.vars = None - self.means_rest = None self.vars_rest = None @@ -230,64 +305,135 @@ def __init__( self.grouping_mask = adata.obs[groupby].isin(self.groups_order) self.grouping = adata.obs.loc[self.grouping_mask, groupby] - def _basic_stats(self, *, exponentiate_values: bool = False) -> None: - """Set self.{means,vars,pts}{,_rest} depending on X.""" - n_genes = self.X.shape[1] - n_groups = self.groups_masks_obs.shape[0] - - self.means = np.zeros((n_groups, n_genes)) - self.vars = np.zeros((n_groups, n_genes)) - self.pts = np.zeros((n_groups, n_genes)) if self.comp_pts else None - + def _basic_stats( + self, *, exponentiate_values: bool = False, need_var: bool = False + ) -> None: + """Populate per-group stats, and (in vs_rest mode) rest-group stats. + + ``need_var`` controls whether variance (per-group and per-rest) is + computed; only the t-test family reads it. In vs_rest mode every cell + is partitioned into its selected group or a single "remainder" + partition (cells in no selected group), and each group's "rest" is the + forward Chan-combine of all other partitions — a sum of non-negative + terms, hence free of catastrophic cancellation for any group sizes. + """ + X = ( + _apply_expm1_preserving_sparsity(self.X, self.expm1_func) + if exponentiate_values + else self.X + ) if self.ireference is None: - self.means_rest = np.zeros((n_groups, n_genes)) - self.vars_rest = np.zeros((n_groups, n_genes)) - self.pts_rest = np.zeros((n_groups, n_genes)) if self.comp_pts else None + self._stats_vs_rest(X, need_var=need_var) else: - mask_rest = self.groups_masks_obs[self.ireference] - x_rest = self.X[mask_rest] - if exponentiate_values: - x_rest = self.expm1_func(x_rest) - self.means[self.ireference], self.vars[self.ireference] = mean_var( - x_rest, axis=0, correction=1 - ) - # deleting the next line causes a memory leak for some reason - del x_rest - - if isinstance(self.X, CSBase): - get_nonzeros = lambda x: x.getnnz(axis=0) + self._stats_vs_reference(X, need_var=need_var) + + def _aggregate_group_stats( + self, X_used, codes: NDArray[np.int64], n_parts: int, *, need_var: bool + ): + """One batched :func:`scanpy.get.aggregate` over ``X_used`` grouped by + ``codes`` (values ``0 .. n_parts-1``). + + Returns ``(mean, var, nnz)`` of shape ``(n_parts, n_genes)``, + zero-filled for partitions with no cells. ``var`` is ``None`` unless + ``need_var``; ``nnz`` is ``None`` unless ``self.comp_pts``. + """ + n_genes = X_used.shape[1] + mean = np.zeros((n_parts, n_genes)) + var = np.zeros((n_parts, n_genes)) if need_var else None + nnz = np.zeros((n_parts, n_genes)) if self.comp_pts else None + + funcs = ["mean"] + if need_var: + funcs.append("var") + if self.comp_pts: + funcs.append("count_nonzero") + agg_adata = AnnData( + X=X_used, + obs=pd.DataFrame( + {"_g": pd.Categorical(codes, categories=range(n_parts))}, + index=pd.RangeIndex(len(codes)).astype(str), + ), + ) + out = aggregate(agg_adata, by="_g", func=funcs, dof=1) + idx = out.obs_names.astype(int).to_numpy() + mean[idx] = np.asarray(out.layers["mean"]) + if need_var: + var[idx] = np.asarray(out.layers["var"]) + if self.comp_pts: + nnz[idx] = np.asarray(out.layers["count_nonzero"]) + return mean, var, nnz + + def _stats_vs_reference(self, X, *, need_var: bool) -> None: + """vs-reference: aggregate the selected-group cells only (the + reference is itself one of the selected groups). No rest derivation. + """ + mask = self.grouping_mask.to_numpy() + X_used = X if mask.all() else X[mask] + codes = pd.Categorical(self.grouping, categories=self.groups_order).codes + k = self.groups_masks_obs.shape[0] + + self.means, self.vars, nnz = self._aggregate_group_stats( + X_used, codes, k, need_var=need_var + ) + if self.comp_pts: + n_per_group = self.groups_masks_obs.sum(axis=1) + self.pts = nnz / n_per_group[:, None] else: - get_nonzeros = lambda x: np.count_nonzero(x, axis=0) - - for group_index, mask_obs in enumerate(self.groups_masks_obs): - x_mask = self.X[mask_obs] - if exponentiate_values: - x_mask = self.expm1_func(x_mask) - - if self.comp_pts: - self.pts[group_index] = get_nonzeros(x_mask) / x_mask.shape[0] - - if self.ireference is not None and group_index == self.ireference: - continue + self.pts = None + + def _stats_vs_rest(self, X, *, need_var: bool) -> None: + """vs-rest: partition *all* cells into the ``k`` selected-group + partitions plus one "remainder" partition (cells in no selected group + — non-selected groups and unassigned/NaN). Each group's "rest" is the + forward Chan-combine of every other partition. + """ + k = self.groups_masks_obs.shape[0] + + # Codes: each cell's selected-group index, or `k` (the remainder + # partition) for cells in no selected group (non-selected / NaN). + sel = pd.Categorical(self.group_col, categories=self.groups_order).codes + codes = np.where(sel >= 0, sel, k).astype(np.int64) + part_n = np.bincount(codes, minlength=k + 1) # partition k == remainder + n_sel = part_n[:k] + + mean, var, nnz = self._aggregate_group_stats(X, codes, k + 1, need_var=need_var) + + # Selected-group arm of the test (the remainder partition is excluded). + self.means = mean[:k] + self.vars = var[:k] if need_var else None + self.pts = nnz[:k] / n_sel[:, None] if self.comp_pts else None + + # M2 = var * (n - 1); forced to 0 for partitions with <= 1 cell so a + # singleton remainder (aggregate var undefined there) is harmless. + if need_var: + with np.errstate(invalid="ignore"): + M2 = var * (part_n[:, None] - 1) + M2[part_n <= 1] = 0.0 + else: + M2 = None - self.means[group_index], self.vars[group_index] = mean_var( - x_mask, axis=0, correction=1 - ) + self._derive_rest_stats(part_n, mean, M2, nnz, k, need_var=need_var) - if self.ireference is None: - mask_rest = ~mask_obs - x_rest = self.X[mask_rest] - if exponentiate_values: - x_rest = self.expm1_func(x_rest) - ( - self.means_rest[group_index], - self.vars_rest[group_index], - ) = mean_var(x_rest, axis=0, correction=1) - # this can be costly for sparse data - if self.comp_pts: - self.pts_rest[group_index] = get_nonzeros(x_rest) / x_rest.shape[0] - # deleting the next line causes a memory leak for some reason - del x_rest + def _derive_rest_stats( + self, part_n, mean, M2, nnz, k: int, *, need_var: bool + ) -> None: + """Set ``means_rest``/``vars_rest``/``pts_rest`` for each selected group ``g`` + (statistics over every cell *not* in ``g``). + + ``means_rest`` and ``pts_rest`` are linear in the partitions, so they are the + exact total-minus-group difference. Variance would lose precision under such + subtraction, so ``vars_rest`` uses the cancellation-free Chan leave-one-out + scan (:func:`_vars_rest`) — only the variance-based tests request it. + """ + n_rest = (self.X.shape[0] - part_n[:k])[:, None] + total = (part_n[:, None] * mean).sum(axis=0) + self.means_rest = (total - part_n[:k, None] * mean[:k]) / n_rest + self.vars_rest = ( + _vars_rest(np.ascontiguousarray(part_n, dtype=np.float64), mean, M2, k) + if need_var + else None + ) + self.pts_rest = (nnz.sum(axis=0) - nnz[:k]) / n_rest if self.comp_pts else None def t_test( self, method: Literal["t-test", "t-test_overestim_var"] @@ -452,7 +598,36 @@ def logreg( if len(self.groups_order) <= 2: break - def compute_statistics( # noqa: PLR0912 + def _run_illico(self, *, tie_correct: bool): + """Invoke `illico.asymptotic_wilcoxon` on `self.X` / `self.group_col`.""" + from illico import asymptotic_wilcoxon + + return asymptotic_wilcoxon( + AnnData( + X=self.X, + var=pd.DataFrame(index=self.var_names), + obs=pd.DataFrame( + index=pd.RangeIndex(self.X.shape[0]).astype("str"), + # This self.group_col means illico will run tests against + # *all* data instead of what's in self.groups_order as + # controlled by the `groups` arg. + # TODO: Only run the subset once illico supports a `groups` argument + data={"group": self.group_col}, + ), + ), + reference=self.groups_order[self.ireference] + if self.ireference is not None + else None, + group_keys="group", + return_as_scanpy=False, + is_log1p=True, + tie_correct=tie_correct, + use_continuity=False, + alternative="two-sided", + use_rust=False, + ) + + def compute_statistics( self, method: DETest, *, @@ -464,38 +639,11 @@ def compute_statistics( # noqa: PLR0912 **kwds, ) -> None: if method in {"t-test", "t-test_overestim_var"}: - self._basic_stats(exponentiate_values=False) + self._basic_stats(exponentiate_values=not mean_in_log_space, need_var=True) generate_test_results = self.t_test(method) - if not mean_in_log_space: - # If we are not exponentiating after the mean aggregation, we need to recalculate the stats. - self._basic_stats(exponentiate_values=True) elif "wilcoxon" in method: if "illico" in method: - from illico import asymptotic_wilcoxon - - illico_df = asymptotic_wilcoxon( - AnnData( - X=self.X, - var=pd.DataFrame(index=self.var_names), - obs=pd.DataFrame( - index=pd.RangeIndex(self.X.shape[0]).astype("str"), - # This self.group_col means illico will run tests against *all* data - # instead of what's in self.groups_order as controlled by the `groups` arg. - # TODO: Only run the subset once illico supports a `groups` argument - data={"group": self.group_col}, - ), - ), - reference=self.groups_order[self.ireference] - if self.ireference is not None - else None, - group_keys="group", - return_as_scanpy=False, - is_log1p=True, - tie_correct=tie_correct, - use_continuity=False, - alternative="two-sided", - use_rust=False, - ) + illico_df = self._run_illico(tie_correct=tie_correct) generate_test_results = _illico_results_to_iter( illico_df, self.groups_order, @@ -503,65 +651,85 @@ def compute_statistics( # noqa: PLR0912 ) else: generate_test_results = self.wilcoxon(tie_correct=tie_correct) - # If we're not exponentiating after the mean aggregation, then do it now. - self._basic_stats(exponentiate_values=not mean_in_log_space) + # Wilcoxon paths only need means (for fold-change); skip var. + self._basic_stats(exponentiate_values=not mean_in_log_space, need_var=False) elif method == "logreg": generate_test_results = self.logreg(**kwds) - self.stats = None - - n_genes = self.X.shape[1] - for group_index, scores, pvals in generate_test_results: - group_name = str(self.groups_order[group_index]) - - if n_genes_user is not None: - scores_sort = np.abs(scores) if rankby_abs else scores - global_indices = _select_top_n(scores_sort, n_genes_user) - first_col = "names" - else: - global_indices = slice(None) - first_col = "scores" - - if self.stats is None: - idx = pd.MultiIndex.from_tuples([(group_name, first_col)]) - self.stats = pd.DataFrame(columns=idx) + self.stats = _build_stats_dataframe( + self, + generate_test_results, + corr_method=corr_method, + n_genes_user=n_genes_user, + rankby_abs=rankby_abs, + mean_in_log_space=mean_in_log_space, + ) - if n_genes_user is not None: - self.stats[group_name, "names"] = self.var_names[global_indices] - self.stats[group_name, "scores"] = scores[global_indices] +def _build_stats_dataframe( + rg, + results, + *, + corr_method: _CorrMethod, + n_genes_user: int | None, + rankby_abs: bool, + mean_in_log_space: bool, +): + """Drain the per-group ``(group_index, scores, pvals)`` iterator into a + wide-form ``(group, statistic)`` MultiIndex DataFrame: top-N selection, + multiple-testing correction, and (when ``rg.means`` is set) log2 + fold-change. Read-only on ``rg``. + """ + from statsmodels.stats.multitest import multipletests - if pvals is not None: - self.stats[group_name, "pvals"] = pvals[global_indices] - if corr_method == "benjamini-hochberg": - from statsmodels.stats.multitest import multipletests + n_genes_total = rg.X.shape[1] + df: pd.DataFrame | None = None - pvals[np.isnan(pvals)] = 1 - _, pvals_adj, _, _ = multipletests( - pvals, alpha=0.05, method="fdr_bh" - ) - elif corr_method == "bonferroni": - pvals_adj = np.minimum(pvals * n_genes, 1.0) - self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices] - - if self.means is not None: - mean_group = self.means[group_index] - if self.ireference is None: - mean_rest = self.means_rest[group_index] - else: - mean_rest = self.means[self.ireference] - foldchanges = ( - (self.expm1_func(mean_group) + 1e-9) - / (self.expm1_func(mean_rest) + 1e-9) - if mean_in_log_space - else (mean_group + 1e-9) / (mean_rest + 1e-9) - ) # add small value to avoid zeros - self.stats[group_name, "logfoldchanges"] = np.log2( - foldchanges[global_indices] - ) + for group_index, scores, pvals in results: + group_name = str(rg.groups_order[group_index]) - if n_genes_user is None: - self.stats.index = self.var_names + if n_genes_user is not None: + scores_sort = np.abs(scores) if rankby_abs else scores + global_indices = _select_top_n(scores_sort, n_genes_user) + first_col = "names" + else: + global_indices = slice(None) + first_col = "scores" + + if df is None: + idx = pd.MultiIndex.from_tuples([(group_name, first_col)]) + df = pd.DataFrame(columns=idx) + + if n_genes_user is not None: + df[group_name, "names"] = rg.var_names[global_indices] + df[group_name, "scores"] = scores[global_indices] + + if pvals is not None: + df[group_name, "pvals"] = pvals[global_indices] + if corr_method == "benjamini-hochberg": + pvals[np.isnan(pvals)] = 1 + _, pvals_adj, _, _ = multipletests(pvals, alpha=0.05, method="fdr_bh") + elif corr_method == "bonferroni": + pvals_adj = np.minimum(pvals * n_genes_total, 1.0) + df[group_name, "pvals_adj"] = pvals_adj[global_indices] + + if rg.means is not None: + mean_group = rg.means[group_index] + mean_rest = ( + rg.means_rest[group_index] + if rg.ireference is None + else rg.means[rg.ireference] + ) + foldchanges = ( + (rg.expm1_func(mean_group) + 1e-9) / (rg.expm1_func(mean_rest) + 1e-9) + if mean_in_log_space + else (mean_group + 1e-9) / (mean_rest + 1e-9) + ) # add small value to avoid zeros + df[group_name, "logfoldchanges"] = np.log2(foldchanges[global_indices]) + + if df is not None and n_genes_user is None: + df.index = rg.var_names + return df def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 diff --git a/tests/test_rank_genes_groups.py b/tests/test_rank_genes_groups.py index 6c32d45dbc..e388eab11d 100644 --- a/tests/test_rank_genes_groups.py +++ b/tests/test_rank_genes_groups.py @@ -115,8 +115,14 @@ def test_results( for g in range(expected["names"].shape[0]): with subtests.test(group=g): + # atol guards against ULP-level golden vs new-code bit-pattern + # differences at near-zero t-scores (e.g., genes with equal + # group/rest means produce 0.0 in one path vs ~1e-15 in another). np.testing.assert_allclose( - expected["scores"][g, :n], results["scores"][str(g)][:n], rtol=1e-5 + expected["scores"][g, :n], + results["scores"][str(g)][:n], + rtol=1e-5, + atol=1e-10, ) np.testing.assert_array_equal( expected["names"][g, :n], results["names"][str(g)][:n]