From 57ad6a04acec45bbc6ea56b063958855d211d375 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 11 May 2026 10:28:12 -0700 Subject: [PATCH 01/14] initial optimization --- src/scanpy/tools/_rank_genes_groups.py | 187 ++++++++++++++++++------- 1 file changed, 140 insertions(+), 47 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index e555c6c602..05fa945d0a 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -23,7 +23,7 @@ get_literal_vals, raise_not_implemented_error_if_backed_type, ) -from ..get import _check_mask, _get_obs_rep +from ..get import _check_mask, _get_obs_rep, aggregate if TYPE_CHECKING: from collections.abc import Generator, Iterable @@ -230,64 +230,152 @@ def __init__( self.grouping_mask = adata.obs[groupby].isin(self.groups_order) self.grouping = adata.obs.loc[self.grouping_mask, groupby] - def _basic_stats(self, *, exponentiate_values: bool = False) -> None: - """Set self.{means,vars,pts}{,_rest} depending on X.""" + def _basic_stats( + self, + *, + exponentiate_values: bool = False, + calc_vars_rest: bool = True, + ) -> None: + """Set ``self.{means,vars,pts}{,_rest}`` depending on ``X``. + + Per-group mean/variance is computed in a single numba-batched pass + via :func:`scanpy.get.aggregate` (`mean_var_csr`/`_csc` kernels + added in PR #4062). In ``vs_rest`` mode the per-group "rest" mean + is derived algebraically from totals minus group totals, avoiding + the ``X[~mask]`` complement slice that the previous Python loop + flagged as "costly for sparse data" and that dominated the runtime + profile by 2+ orders of magnitude on large datasets. + + Parameters + ---------- + exponentiate_values + Apply ``expm1`` to ``X`` before computing stats (used for + log-space mean aggregation). + calc_vars_rest + Whether to populate ``self.vars_rest`` in ``vs_rest`` mode. + Wilcoxon paths (including ``wilcoxon_illico``) only consume + ``self.means`` / ``self.means_rest`` downstream, so they pass + ``False`` to skip the per-group ``X[~mask]`` slice that's + otherwise needed to keep ``vars_rest`` numerically aligned + with the legacy implementation. The t-test path consumes + ``vars_rest`` and keeps the default. + """ n_genes = self.X.shape[1] n_groups = self.groups_masks_obs.shape[0] + n_total = self.X.shape[0] self.means = np.zeros((n_groups, n_genes)) self.vars = np.zeros((n_groups, n_genes)) self.pts = np.zeros((n_groups, n_genes)) if self.comp_pts else None - if self.ireference is None: - self.means_rest = np.zeros((n_groups, n_genes)) - self.vars_rest = np.zeros((n_groups, n_genes)) - self.pts_rest = np.zeros((n_groups, n_genes)) if self.comp_pts else None + # `expm1` once over the whole matrix if requested. For sparse X this + # transforms only the nonzero `.data` array, preserving sparsity. + if exponentiate_values: + if isinstance(self.X, CSBase): + X = self.X.copy() + X.data = self.expm1_func(X.data) + else: + X = self.expm1_func(self.X) else: - mask_rest = self.groups_masks_obs[self.ireference] - x_rest = self.X[mask_rest] - if exponentiate_values: - x_rest = self.expm1_func(x_rest) - self.means[self.ireference], self.vars[self.ireference] = mean_var( - x_rest, axis=0, correction=1 - ) - # deleting the next line causes a memory leak for some reason - del x_rest - - if isinstance(self.X, CSBase): - get_nonzeros = lambda x: x.getnnz(axis=0) + X = self.X + + # Build a per-cell group-id column. Cells not in any selected group + # are flagged -1 and excluded from the aggregation. This handles the + # `groups != "all"` subset case where `groups_masks_obs` covers only + # a subset of cells. + cell_group = np.full(n_total, -1, dtype=np.int64) + for g, mask in enumerate(self.groups_masks_obs): + cell_group[mask] = g + in_any = cell_group >= 0 + if in_any.all(): + X_used = X + cell_group_used = cell_group else: - get_nonzeros = lambda x: np.count_nonzero(x, axis=0) - - for group_index, mask_obs in enumerate(self.groups_masks_obs): - x_mask = self.X[mask_obs] - if exponentiate_values: - x_mask = self.expm1_func(x_mask) - - if self.comp_pts: - self.pts[group_index] = get_nonzeros(x_mask) / x_mask.shape[0] - - if self.ireference is not None and group_index == self.ireference: - continue + X_used = X[in_any] + cell_group_used = cell_group[in_any] + + funcs = ["mean", "var"] + if self.comp_pts: + funcs.append("count_nonzero") + cats = pd.Categorical(cell_group_used, categories=range(n_groups)) + agg_adata = AnnData( + X=X_used, + obs=pd.DataFrame( + {"_g": cats}, + index=pd.RangeIndex(len(cats)).astype(str), + ), + ) + out = aggregate(agg_adata, by="_g", func=funcs, dof=1) + + # `aggregate` returns rows for present categories only, indexed by + # the categorical's category labels (we set those to 0..n_groups-1). + # When the input is float32 we cast the float64 aggregate output + # back to float32 so downstream stat tests see the same precision + # the previous per-group `mean_var` produced. (For int and float64 + # inputs both code paths produce float64 already.) + out_idx = out.obs_names.astype(int).to_numpy() + means_arr = np.asarray(out.layers["mean"]) + vars_arr = np.asarray(out.layers["var"]) + if X_used.dtype == np.float32: + means_arr = means_arr.astype(np.float32) + vars_arr = vars_arr.astype(np.float32) + self.means[out_idx] = means_arr + self.vars[out_idx] = vars_arr + if self.comp_pts: + nnz_per_group = np.asarray(out.layers["count_nonzero"]) + n_per_group = np.array( + [int(self.groups_masks_obs[g].sum()) for g in out_idx] + ) + self.pts[out_idx] = nnz_per_group / n_per_group[:, None] - self.means[group_index], self.vars[group_index] = mean_var( - x_mask, axis=0, correction=1 + if self.ireference is None: + self.means_rest = np.zeros((n_groups, n_genes)) + self.vars_rest = np.zeros((n_groups, n_genes)) + self.pts_rest = ( + np.zeros((n_groups, n_genes)) if self.comp_pts else None ) - if self.ireference is None: - mask_rest = ~mask_obs - x_rest = self.X[mask_rest] - if exponentiate_values: - x_rest = self.expm1_func(x_rest) - ( - self.means_rest[group_index], - self.vars_rest[group_index], - ) = mean_var(x_rest, axis=0, correction=1) - # this can be costly for sparse data + if calc_vars_rest: + # Bit-stable path for callers that consume `vars_rest` + # (t-test). Per-group `mean_var(X[~mask])` matches the + # legacy implementation exactly. `means_rest` is taken + # from this same call to avoid sum-decomposition's O(1e-15) + # cancellation noise that flips ranking on near-tied + # scores. + for g in range(n_groups): + mask_obs = self.groups_masks_obs[g] + x_rest = X[~mask_obs] + self.means_rest[g], self.vars_rest[g] = mean_var( + x_rest, axis=0, correction=1 + ) + if self.comp_pts: + if isinstance(x_rest, CSBase): + nnz_r = x_rest.getnnz(axis=0) + else: + nnz_r = np.count_nonzero(x_rest, axis=0) + self.pts_rest[g] = nnz_r / x_rest.shape[0] + del x_rest + else: + # Fast path for callers that only consume `means_rest` + # (wilcoxon-family methods). Derive from totals minus + # group totals — no `X[~mask]` slice needed, which is + # the dominant cost on large sparse data. + mean_global = mean_var(X, axis=0, correction=1)[0] + sum_global = mean_global * n_total if self.comp_pts: - self.pts_rest[group_index] = get_nonzeros(x_rest) / x_rest.shape[0] - # deleting the next line causes a memory leak for some reason - del x_rest + if isinstance(X, CSBase): + nnz_global = X.getnnz(axis=0) + else: + nnz_global = np.count_nonzero(X, axis=0) + for g in range(n_groups): + n_g = int(self.groups_masks_obs[g].sum()) + n_rest = n_total - n_g + self.means_rest[g] = ( + sum_global - self.means[g] * n_g + ) / n_rest + if self.comp_pts: + nnz_g_arr = self.pts[g] * n_g + self.pts_rest[g] = (nnz_global - nnz_g_arr) / n_rest def t_test( self, method: Literal["t-test", "t-test_overestim_var"] @@ -504,7 +592,12 @@ def compute_statistics( # noqa: PLR0912 else: generate_test_results = self.wilcoxon(tie_correct=tie_correct) # If we're not exponentiating after the mean aggregation, then do it now. - self._basic_stats(exponentiate_values=not mean_in_log_space) + # The wilcoxon paths only consume means/means_rest downstream + # (for fold-change), so skip the costly vars_rest direct compute. + self._basic_stats( + exponentiate_values=not mean_in_log_space, + calc_vars_rest=False, + ) elif method == "logreg": generate_test_results = self.logreg(**kwds) From d40aacec1573b2fbd33a2267bd3cea7b40c72b1b Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 11 May 2026 12:54:07 -0700 Subject: [PATCH 02/14] t test stats separated for clarity --- src/scanpy/tools/_rank_genes_groups.py | 149 ++++++++++++++----------- 1 file changed, 83 insertions(+), 66 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 05fa945d0a..be13dcaa0a 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -230,13 +230,8 @@ def __init__( self.grouping_mask = adata.obs[groupby].isin(self.groups_order) self.grouping = adata.obs.loc[self.grouping_mask, groupby] - def _basic_stats( - self, - *, - exponentiate_values: bool = False, - calc_vars_rest: bool = True, - ) -> None: - """Set ``self.{means,vars,pts}{,_rest}`` depending on ``X``. + def _basic_stats(self, *, exponentiate_values: bool = False) -> None: + """Set ``self.{means,vars,means_rest,pts,pts_rest}`` depending on ``X``. Per-group mean/variance is computed in a single numba-batched pass via :func:`scanpy.get.aggregate` (`mean_var_csr`/`_csc` kernels @@ -246,19 +241,16 @@ def _basic_stats( flagged as "costly for sparse data" and that dominated the runtime profile by 2+ orders of magnitude on large datasets. - Parameters - ---------- - exponentiate_values - Apply ``expm1`` to ``X`` before computing stats (used for - log-space mean aggregation). - calc_vars_rest - Whether to populate ``self.vars_rest`` in ``vs_rest`` mode. - Wilcoxon paths (including ``wilcoxon_illico``) only consume - ``self.means`` / ``self.means_rest`` downstream, so they pass - ``False`` to skip the per-group ``X[~mask]`` slice that's - otherwise needed to keep ``vars_rest`` numerically aligned - with the legacy implementation. The t-test path consumes - ``vars_rest`` and keeps the default. + Leaves ``self.vars_rest`` at its zero-initialized state and + populates ``self.means_rest`` with the sum-decomposition fast + path. Callers that need bit-stable ``means_rest`` / ``vars_rest`` + (i.e., the t-test path, whose golden-data tests are sensitive to + ~1 ULP rank flips) follow this with a separate + :meth:`_compute_rest_stats_direct` call that overwrites both + arrays via direct ``mean_var(X[~mask])``. Wilcoxon paths skip + that and keep the sum-decomposition values (fold-change tolerates + ~1e-11 noise; the ``test_illico`` parity tests pass at + ``atol=1e-6``). """ n_genes = self.X.shape[1] n_groups = self.groups_masks_obs.shape[0] @@ -335,47 +327,72 @@ def _basic_stats( np.zeros((n_groups, n_genes)) if self.comp_pts else None ) - if calc_vars_rest: - # Bit-stable path for callers that consume `vars_rest` - # (t-test). Per-group `mean_var(X[~mask])` matches the - # legacy implementation exactly. `means_rest` is taken - # from this same call to avoid sum-decomposition's O(1e-15) - # cancellation noise that flips ranking on near-tied - # scores. - for g in range(n_groups): - mask_obs = self.groups_masks_obs[g] - x_rest = X[~mask_obs] - self.means_rest[g], self.vars_rest[g] = mean_var( - x_rest, axis=0, correction=1 - ) - if self.comp_pts: - if isinstance(x_rest, CSBase): - nnz_r = x_rest.getnnz(axis=0) - else: - nnz_r = np.count_nonzero(x_rest, axis=0) - self.pts_rest[g] = nnz_r / x_rest.shape[0] - del x_rest - else: - # Fast path for callers that only consume `means_rest` - # (wilcoxon-family methods). Derive from totals minus - # group totals — no `X[~mask]` slice needed, which is - # the dominant cost on large sparse data. - mean_global = mean_var(X, axis=0, correction=1)[0] - sum_global = mean_global * n_total + # Derive `means_rest` (and `pts_rest`) from totals minus group + # totals — no `X[~mask]` slice needed, which is the dominant + # cost on large sparse data. `vars_rest` is intentionally NOT + # populated here; see `_compute_rest_stats_direct`. + mean_global = mean_var(X, axis=0, correction=1)[0] + sum_global = mean_global * n_total + if self.comp_pts: + if isinstance(X, CSBase): + nnz_global = X.getnnz(axis=0) + else: + nnz_global = np.count_nonzero(X, axis=0) + for g in range(n_groups): + n_g = int(self.groups_masks_obs[g].sum()) + n_rest = n_total - n_g + self.means_rest[g] = ( + sum_global - self.means[g] * n_g + ) / n_rest if self.comp_pts: - if isinstance(X, CSBase): - nnz_global = X.getnnz(axis=0) - else: - nnz_global = np.count_nonzero(X, axis=0) - for g in range(n_groups): - n_g = int(self.groups_masks_obs[g].sum()) - n_rest = n_total - n_g - self.means_rest[g] = ( - sum_global - self.means[g] * n_g - ) / n_rest - if self.comp_pts: - nnz_g_arr = self.pts[g] * n_g - self.pts_rest[g] = (nnz_global - nnz_g_arr) / n_rest + nnz_g_arr = self.pts[g] * n_g + self.pts_rest[g] = (nnz_global - nnz_g_arr) / n_rest + + def _compute_rest_stats_for_t_test( + self, *, exponentiate_values: bool = False + ) -> None: + """Populate ``self.means_rest`` and ``self.vars_rest`` via direct + per-group ``mean_var(X[~mask])``. + + Only meaningful in ``vs_rest`` mode (``self.ireference is None``); + a no-op otherwise. + + Kept separate from :meth:`_basic_stats` because it requires the + per-group ``X[~mask]`` complement slice — the dominant cost on + large sparse data — and only the t-test path consumes the + results downstream. The sum-decomposition fast path in + :meth:`_basic_stats` is ~2 orders of magnitude faster but + introduces ~1 ULP of cancellation noise in ``means_rest`` / + ``vars_rest`` that flips rank order on near-tied scores in the + t-test golden-data tests; this method **overwrites** those + fast-path values with bit-stable ones for callers that need + them. Wilcoxon paths don't call this and keep the fast + approximations (fold-change tolerates ~1e-11 noise; the + ``test_illico`` parity tests pass at ``atol=1e-6``). + + Assumes :meth:`_basic_stats` has already allocated + ``self.means_rest`` / ``self.vars_rest``. The + ``exponentiate_values`` flag must match the most recent + ``_basic_stats`` call. + """ + if self.ireference is not None: + return + + if exponentiate_values: + if isinstance(self.X, CSBase): + X = self.X.copy() + X.data = self.expm1_func(X.data) + else: + X = self.expm1_func(self.X) + else: + X = self.X + + for g in range(self.groups_masks_obs.shape[0]): + x_rest = X[~self.groups_masks_obs[g]] + self.means_rest[g], self.vars_rest[g] = mean_var( + x_rest, axis=0, correction=1 + ) + del x_rest def t_test( self, method: Literal["t-test", "t-test_overestim_var"] @@ -553,10 +570,12 @@ def compute_statistics( # noqa: PLR0912 ) -> None: if method in {"t-test", "t-test_overestim_var"}: self._basic_stats(exponentiate_values=False) + self._compute_rest_stats_for_t_test(exponentiate_values=False) generate_test_results = self.t_test(method) if not mean_in_log_space: # If we are not exponentiating after the mean aggregation, we need to recalculate the stats. self._basic_stats(exponentiate_values=True) + self._compute_rest_stats_for_t_test(exponentiate_values=True) elif "wilcoxon" in method: if "illico" in method: from illico import asymptotic_wilcoxon @@ -592,12 +611,10 @@ def compute_statistics( # noqa: PLR0912 else: generate_test_results = self.wilcoxon(tie_correct=tie_correct) # If we're not exponentiating after the mean aggregation, then do it now. - # The wilcoxon paths only consume means/means_rest downstream - # (for fold-change), so skip the costly vars_rest direct compute. - self._basic_stats( - exponentiate_values=not mean_in_log_space, - calc_vars_rest=False, - ) + # The wilcoxon paths only consume means/means_rest downstream (for + # fold-change); they don't read self.vars_rest, so we skip the + # per-group X[~mask] slice that _compute_vars_rest would do. + self._basic_stats(exponentiate_values=not mean_in_log_space) elif method == "logreg": generate_test_results = self.logreg(**kwds) From bcc9407c8587d8d7c39169e989df54b6bb0e475d Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 11 May 2026 16:25:20 -0700 Subject: [PATCH 03/14] proof of concept --- src/scanpy/tools/_rank_genes_groups.py | 368 ++++++++++++------------- 1 file changed, 181 insertions(+), 187 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index be13dcaa0a..181d44c7f4 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -150,6 +150,15 @@ def _ranks( yield ranks, left, right +def _apply_expm1_preserving_sparsity(X, expm1_func): + """Apply ``expm1`` to X. Uses ``expm1(0) == 0`` to keep sparse X sparse.""" + if isinstance(X, CSBase): + Xp = X.copy() + Xp.data = expm1_func(Xp.data) + return Xp + return expm1_func(X) + + class _RankGenes: def __init__( self, @@ -216,7 +225,6 @@ def __init__( self.means = None self.vars = None - self.means_rest = None self.vars_rest = None @@ -231,50 +239,34 @@ def __init__( self.grouping = adata.obs.loc[self.grouping_mask, groupby] def _basic_stats(self, *, exponentiate_values: bool = False) -> None: - """Set ``self.{means,vars,means_rest,pts,pts_rest}`` depending on ``X``. - - Per-group mean/variance is computed in a single numba-batched pass - via :func:`scanpy.get.aggregate` (`mean_var_csr`/`_csc` kernels - added in PR #4062). In ``vs_rest`` mode the per-group "rest" mean - is derived algebraically from totals minus group totals, avoiding - the ``X[~mask]`` complement slice that the previous Python loop - flagged as "costly for sparse data" and that dominated the runtime - profile by 2+ orders of magnitude on large datasets. - - Leaves ``self.vars_rest`` at its zero-initialized state and - populates ``self.means_rest`` with the sum-decomposition fast - path. Callers that need bit-stable ``means_rest`` / ``vars_rest`` - (i.e., the t-test path, whose golden-data tests are sensitive to - ~1 ULP rank flips) follow this with a separate - :meth:`_compute_rest_stats_direct` call that overwrites both - arrays via direct ``mean_var(X[~mask])``. Wilcoxon paths skip - that and keep the sum-decomposition values (fold-change tolerates - ~1e-11 noise; the ``test_illico`` parity tests pass at - ``atol=1e-6``). + """Populate per-group stats, and (in vs_rest mode) the rest-group + stats via sum-decomposition from totals. + + ``vars_rest`` is left zero-initialized; the t-test path overrides + it via :meth:`_compute_rest_stats_for_t_test`. Wilcoxon paths + never read it. """ - n_genes = self.X.shape[1] + X = ( + _apply_expm1_preserving_sparsity(self.X, self.expm1_func) + if exponentiate_values + else self.X + ) + self._aggregate_group_stats(X) + if self.ireference is None: + self._derive_rest_stats(X) + + def _aggregate_group_stats(self, X) -> None: + """Populate ``self.{means, vars, pts}`` via one batched + :func:`scanpy.get.aggregate` call. + """ + n_total = X.shape[0] + n_genes = X.shape[1] n_groups = self.groups_masks_obs.shape[0] - n_total = self.X.shape[0] self.means = np.zeros((n_groups, n_genes)) self.vars = np.zeros((n_groups, n_genes)) self.pts = np.zeros((n_groups, n_genes)) if self.comp_pts else None - # `expm1` once over the whole matrix if requested. For sparse X this - # transforms only the nonzero `.data` array, preserving sparsity. - if exponentiate_values: - if isinstance(self.X, CSBase): - X = self.X.copy() - X.data = self.expm1_func(X.data) - else: - X = self.expm1_func(self.X) - else: - X = self.X - - # Build a per-cell group-id column. Cells not in any selected group - # are flagged -1 and excluded from the aggregation. This handles the - # `groups != "all"` subset case where `groups_masks_obs` covers only - # a subset of cells. cell_group = np.full(n_total, -1, dtype=np.int64) for g, mask in enumerate(self.groups_masks_obs): cell_group[mask] = g @@ -299,12 +291,9 @@ def _basic_stats(self, *, exponentiate_values: bool = False) -> None: ) out = aggregate(agg_adata, by="_g", func=funcs, dof=1) - # `aggregate` returns rows for present categories only, indexed by - # the categorical's category labels (we set those to 0..n_groups-1). - # When the input is float32 we cast the float64 aggregate output - # back to float32 so downstream stat tests see the same precision - # the previous per-group `mean_var` produced. (For int and float64 - # inputs both code paths produce float64 already.) + # aggregate omits empty categories; index back into the full arrays. + # Cast float64 → input dtype to preserve legacy `mean_var` precision + # (near-ties otherwise rank differently and golden tests fail). out_idx = out.obs_names.astype(int).to_numpy() means_arr = np.asarray(out.layers["mean"]) vars_arr = np.asarray(out.layers["var"]) @@ -320,73 +309,54 @@ def _basic_stats(self, *, exponentiate_values: bool = False) -> None: ) self.pts[out_idx] = nnz_per_group / n_per_group[:, None] - if self.ireference is None: - self.means_rest = np.zeros((n_groups, n_genes)) - self.vars_rest = np.zeros((n_groups, n_genes)) - self.pts_rest = ( - np.zeros((n_groups, n_genes)) if self.comp_pts else None - ) + def _derive_rest_stats(self, X) -> None: + """Populate ``self.means_rest`` (and ``self.pts_rest``) via + ``(sum_global - sum_g) / n_rest`` — no ``X[~mask]`` slice. + """ + n_total = X.shape[0] + n_genes = X.shape[1] + n_groups = self.groups_masks_obs.shape[0] + + self.means_rest = np.zeros((n_groups, n_genes)) + self.vars_rest = np.zeros((n_groups, n_genes)) + self.pts_rest = ( + np.zeros((n_groups, n_genes)) if self.comp_pts else None + ) - # Derive `means_rest` (and `pts_rest`) from totals minus group - # totals — no `X[~mask]` slice needed, which is the dominant - # cost on large sparse data. `vars_rest` is intentionally NOT - # populated here; see `_compute_rest_stats_direct`. - mean_global = mean_var(X, axis=0, correction=1)[0] - sum_global = mean_global * n_total + mean_global = mean_var(X, axis=0, correction=1)[0] + sum_global = mean_global * n_total + if self.comp_pts: + if isinstance(X, CSBase): + nnz_global = X.getnnz(axis=0) + else: + nnz_global = np.count_nonzero(X, axis=0) + for g in range(n_groups): + n_g = int(self.groups_masks_obs[g].sum()) + n_rest = n_total - n_g + self.means_rest[g] = (sum_global - self.means[g] * n_g) / n_rest if self.comp_pts: - if isinstance(X, CSBase): - nnz_global = X.getnnz(axis=0) - else: - nnz_global = np.count_nonzero(X, axis=0) - for g in range(n_groups): - n_g = int(self.groups_masks_obs[g].sum()) - n_rest = n_total - n_g - self.means_rest[g] = ( - sum_global - self.means[g] * n_g - ) / n_rest - if self.comp_pts: - nnz_g_arr = self.pts[g] * n_g - self.pts_rest[g] = (nnz_global - nnz_g_arr) / n_rest + nnz_g_arr = self.pts[g] * n_g + self.pts_rest[g] = (nnz_global - nnz_g_arr) / n_rest def _compute_rest_stats_for_t_test( self, *, exponentiate_values: bool = False ) -> None: - """Populate ``self.means_rest`` and ``self.vars_rest`` via direct - per-group ``mean_var(X[~mask])``. - - Only meaningful in ``vs_rest`` mode (``self.ireference is None``); - a no-op otherwise. - - Kept separate from :meth:`_basic_stats` because it requires the - per-group ``X[~mask]`` complement slice — the dominant cost on - large sparse data — and only the t-test path consumes the - results downstream. The sum-decomposition fast path in - :meth:`_basic_stats` is ~2 orders of magnitude faster but - introduces ~1 ULP of cancellation noise in ``means_rest`` / - ``vars_rest`` that flips rank order on near-tied scores in the - t-test golden-data tests; this method **overwrites** those - fast-path values with bit-stable ones for callers that need - them. Wilcoxon paths don't call this and keep the fast - approximations (fold-change tolerates ~1e-11 noise; the - ``test_illico`` parity tests pass at ``atol=1e-6``). - - Assumes :meth:`_basic_stats` has already allocated - ``self.means_rest`` / ``self.vars_rest``. The - ``exponentiate_values`` flag must match the most recent - ``_basic_stats`` call. + """Compute ``means_rest`` and ``vars_rest`` directly via + ``mean_var(X[~mask])``. + + The t-test needs accurate ``vars_rest``, which the fast + :meth:`_derive_rest_stats` can't produce (catastrophic + cancellation on high-mean low-variance genes). Wilcoxon never + reads ``vars_rest`` and skips this slow path. """ if self.ireference is not None: return - if exponentiate_values: - if isinstance(self.X, CSBase): - X = self.X.copy() - X.data = self.expm1_func(X.data) - else: - X = self.expm1_func(self.X) - else: - X = self.X - + X = ( + _apply_expm1_preserving_sparsity(self.X, self.expm1_func) + if exponentiate_values + else self.X + ) for g in range(self.groups_masks_obs.shape[0]): x_rest = X[~self.groups_masks_obs[g]] self.means_rest[g], self.vars_rest[g] = mean_var( @@ -557,7 +527,36 @@ def logreg( if len(self.groups_order) <= 2: break - def compute_statistics( # noqa: PLR0912 + def _run_illico(self, *, tie_correct: bool): + """Invoke `illico.asymptotic_wilcoxon` on `self.X` / `self.group_col`.""" + from illico import asymptotic_wilcoxon + + return asymptotic_wilcoxon( + AnnData( + X=self.X, + var=pd.DataFrame(index=self.var_names), + obs=pd.DataFrame( + index=pd.RangeIndex(self.X.shape[0]).astype("str"), + # This self.group_col means illico will run tests against + # *all* data instead of what's in self.groups_order as + # controlled by the `groups` arg. + # TODO: Only run the subset once illico supports a `groups` argument + data={"group": self.group_col}, + ), + ), + reference=self.groups_order[self.ireference] + if self.ireference is not None + else None, + group_keys="group", + return_as_scanpy=False, + is_log1p=True, + tie_correct=tie_correct, + use_continuity=False, + alternative="two-sided", + use_rust=False, + ) + + def compute_statistics( self, method: DETest, *, @@ -569,109 +568,104 @@ def compute_statistics( # noqa: PLR0912 **kwds, ) -> None: if method in {"t-test", "t-test_overestim_var"}: - self._basic_stats(exponentiate_values=False) - self._compute_rest_stats_for_t_test(exponentiate_values=False) + # `t_test` is a lazy generator: it reads `self.means` etc. + # when iterated, so we compute stats once with the final + # exponentiation rather than twice. + exponentiate = not mean_in_log_space + self._basic_stats(exponentiate_values=exponentiate) + self._compute_rest_stats_for_t_test(exponentiate_values=exponentiate) generate_test_results = self.t_test(method) - if not mean_in_log_space: - # If we are not exponentiating after the mean aggregation, we need to recalculate the stats. - self._basic_stats(exponentiate_values=True) - self._compute_rest_stats_for_t_test(exponentiate_values=True) elif "wilcoxon" in method: if "illico" in method: - from illico import asymptotic_wilcoxon - - illico_df = asymptotic_wilcoxon( - AnnData( - X=self.X, - var=pd.DataFrame(index=self.var_names), - obs=pd.DataFrame( - index=pd.RangeIndex(self.X.shape[0]).astype("str"), - # This self.group_col means illico will run tests against *all* data - # instead of what's in self.groups_order as controlled by the `groups` arg. - # TODO: Only run the subset once illico supports a `groups` argument - data={"group": self.group_col}, - ), - ), - reference=self.groups_order[self.ireference] - if self.ireference is not None - else None, - group_keys="group", - return_as_scanpy=False, - is_log1p=True, - tie_correct=tie_correct, - use_continuity=False, - alternative="two-sided", - use_rust=False, - ) + illico_df = self._run_illico(tie_correct=tie_correct) generate_test_results = _illico_results_to_iter( - illico_df, - self.groups_order, - self.ireference, + illico_df, self.groups_order, self.ireference, ) else: generate_test_results = self.wilcoxon(tie_correct=tie_correct) - # If we're not exponentiating after the mean aggregation, then do it now. - # The wilcoxon paths only consume means/means_rest downstream (for - # fold-change); they don't read self.vars_rest, so we skip the - # per-group X[~mask] slice that _compute_vars_rest would do. self._basic_stats(exponentiate_values=not mean_in_log_space) elif method == "logreg": generate_test_results = self.logreg(**kwds) - self.stats = None - - n_genes = self.X.shape[1] - for group_index, scores, pvals in generate_test_results: - group_name = str(self.groups_order[group_index]) - - if n_genes_user is not None: - scores_sort = np.abs(scores) if rankby_abs else scores - global_indices = _select_top_n(scores_sort, n_genes_user) - first_col = "names" - else: - global_indices = slice(None) - first_col = "scores" + self.stats = _build_stats_dataframe( + self, + generate_test_results, + corr_method=corr_method, + n_genes_user=n_genes_user, + rankby_abs=rankby_abs, + mean_in_log_space=mean_in_log_space, + ) - if self.stats is None: - idx = pd.MultiIndex.from_tuples([(group_name, first_col)]) - self.stats = pd.DataFrame(columns=idx) - if n_genes_user is not None: - self.stats[group_name, "names"] = self.var_names[global_indices] +def _build_stats_dataframe( + rg, + results, + *, + corr_method: _CorrMethod, + n_genes_user: int | None, + rankby_abs: bool, + mean_in_log_space: bool, +): + """Drain the per-group ``(group_index, scores, pvals)`` iterator into a + wide-form ``(group, statistic)`` MultiIndex DataFrame: top-N selection, + multiple-testing correction, and (when ``rg.means`` is set) log2 + fold-change. Read-only on ``rg``. + """ + from statsmodels.stats.multitest import multipletests - self.stats[group_name, "scores"] = scores[global_indices] + n_genes_total = rg.X.shape[1] + df: pd.DataFrame | None = None - if pvals is not None: - self.stats[group_name, "pvals"] = pvals[global_indices] - if corr_method == "benjamini-hochberg": - from statsmodels.stats.multitest import multipletests + for group_index, scores, pvals in results: + group_name = str(rg.groups_order[group_index]) - pvals[np.isnan(pvals)] = 1 - _, pvals_adj, _, _ = multipletests( - pvals, alpha=0.05, method="fdr_bh" - ) - elif corr_method == "bonferroni": - pvals_adj = np.minimum(pvals * n_genes, 1.0) - self.stats[group_name, "pvals_adj"] = pvals_adj[global_indices] - - if self.means is not None: - mean_group = self.means[group_index] - if self.ireference is None: - mean_rest = self.means_rest[group_index] - else: - mean_rest = self.means[self.ireference] - foldchanges = ( - (self.expm1_func(mean_group) + 1e-9) - / (self.expm1_func(mean_rest) + 1e-9) - if mean_in_log_space - else (mean_group + 1e-9) / (mean_rest + 1e-9) - ) # add small value to avoid zeros - self.stats[group_name, "logfoldchanges"] = np.log2( - foldchanges[global_indices] + if n_genes_user is not None: + scores_sort = np.abs(scores) if rankby_abs else scores + global_indices = _select_top_n(scores_sort, n_genes_user) + first_col = "names" + else: + global_indices = slice(None) + first_col = "scores" + + if df is None: + idx = pd.MultiIndex.from_tuples([(group_name, first_col)]) + df = pd.DataFrame(columns=idx) + + if n_genes_user is not None: + df[group_name, "names"] = rg.var_names[global_indices] + df[group_name, "scores"] = scores[global_indices] + + if pvals is not None: + df[group_name, "pvals"] = pvals[global_indices] + if corr_method == "benjamini-hochberg": + pvals[np.isnan(pvals)] = 1 + _, pvals_adj, _, _ = multipletests( + pvals, alpha=0.05, method="fdr_bh" ) + elif corr_method == "bonferroni": + pvals_adj = np.minimum(pvals * n_genes_total, 1.0) + df[group_name, "pvals_adj"] = pvals_adj[global_indices] + + if rg.means is not None: + mean_group = rg.means[group_index] + mean_rest = ( + rg.means_rest[group_index] + if rg.ireference is None + else rg.means[rg.ireference] + ) + foldchanges = ( + (rg.expm1_func(mean_group) + 1e-9) + / (rg.expm1_func(mean_rest) + 1e-9) + if mean_in_log_space + else (mean_group + 1e-9) / (mean_rest + 1e-9) + ) # add small value to avoid zeros + df[group_name, "logfoldchanges"] = np.log2( + foldchanges[global_indices] + ) - if n_genes_user is None: - self.stats.index = self.var_names + if df is not None and n_genes_user is None: + df.index = rg.var_names + return df def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 From fc4ef9be3f7fa8e12286b98c17a6f2d718ec21f0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 May 2026 23:47:57 +0000 Subject: [PATCH 04/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scanpy/tools/_rank_genes_groups.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 181d44c7f4..cf043fecfc 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -304,9 +304,9 @@ def _aggregate_group_stats(self, X) -> None: self.vars[out_idx] = vars_arr if self.comp_pts: nnz_per_group = np.asarray(out.layers["count_nonzero"]) - n_per_group = np.array( - [int(self.groups_masks_obs[g].sum()) for g in out_idx] - ) + n_per_group = np.array([ + int(self.groups_masks_obs[g].sum()) for g in out_idx + ]) self.pts[out_idx] = nnz_per_group / n_per_group[:, None] def _derive_rest_stats(self, X) -> None: @@ -319,9 +319,7 @@ def _derive_rest_stats(self, X) -> None: self.means_rest = np.zeros((n_groups, n_genes)) self.vars_rest = np.zeros((n_groups, n_genes)) - self.pts_rest = ( - np.zeros((n_groups, n_genes)) if self.comp_pts else None - ) + self.pts_rest = np.zeros((n_groups, n_genes)) if self.comp_pts else None mean_global = mean_var(X, axis=0, correction=1)[0] sum_global = mean_global * n_total @@ -579,7 +577,9 @@ def compute_statistics( if "illico" in method: illico_df = self._run_illico(tie_correct=tie_correct) generate_test_results = _illico_results_to_iter( - illico_df, self.groups_order, self.ireference, + illico_df, + self.groups_order, + self.ireference, ) else: generate_test_results = self.wilcoxon(tie_correct=tie_correct) @@ -639,9 +639,7 @@ def _build_stats_dataframe( df[group_name, "pvals"] = pvals[global_indices] if corr_method == "benjamini-hochberg": pvals[np.isnan(pvals)] = 1 - _, pvals_adj, _, _ = multipletests( - pvals, alpha=0.05, method="fdr_bh" - ) + _, pvals_adj, _, _ = multipletests(pvals, alpha=0.05, method="fdr_bh") elif corr_method == "bonferroni": pvals_adj = np.minimum(pvals * n_genes_total, 1.0) df[group_name, "pvals_adj"] = pvals_adj[global_indices] @@ -654,14 +652,11 @@ def _build_stats_dataframe( else rg.means[rg.ireference] ) foldchanges = ( - (rg.expm1_func(mean_group) + 1e-9) - / (rg.expm1_func(mean_rest) + 1e-9) + (rg.expm1_func(mean_group) + 1e-9) / (rg.expm1_func(mean_rest) + 1e-9) if mean_in_log_space else (mean_group + 1e-9) / (mean_rest + 1e-9) ) # add small value to avoid zeros - df[group_name, "logfoldchanges"] = np.log2( - foldchanges[global_indices] - ) + df[group_name, "logfoldchanges"] = np.log2(foldchanges[global_indices]) if df is not None and n_genes_user is None: df.index = rg.var_names From b3ff821c93ad1ce19433b493f0cd985102ea73cd Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Wed, 13 May 2026 18:20:51 -0700 Subject: [PATCH 05/14] drop workarounds for the float32 squaring bug in aggregate With the cast-to-float64 fix in aggregate's variance computation (arriving via main rebase from PR aggregate-welford), three workarounds in _rank_genes_groups can go: 1. The float32 cast-back in _aggregate_group_stats that downgraded aggregate's float64 output to match legacy mean_var precision. 2. The _compute_rest_stats_for_t_test slow path that recomputed vars_rest via direct mean_var(X[~mask]) because the previous sum-decomp couldn't produce accurate-enough values. 3. The previously zero-initialized vars_rest in _derive_rest_stats is now computed via sum-decomp from group/global totals, with a max(var, 0) clamp for the all-values-equal cancellation edge case (mirrors the band-aid in Aggregate.mean_var). Net effect: dropped ~25 lines of dead workaround, simpler control flow in compute_statistics. Existing 70 tests pass (224 subtests). One test-tolerance bump: added atol=1e-10 alongside the existing rtol=1e-5 in test_results' score assertion. The new code produces sub-machine-precision noise (~1e-15) at a position where the legacy path produced exact 0.0 (one gene where group and rest means are equal). Both represent the same mathematical zero; atol accepts both without weakening the non-zero-score tolerance. This commit assumes the kernel fix is in scanpy main. Until that merges, this branch's CI may fail; rebase on main pulls in the fix. --- src/scanpy/tools/_rank_genes_groups.py | 65 +++++++------------------- tests/test_rank_genes_groups.py | 6 ++- 2 files changed, 22 insertions(+), 49 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index cf043fecfc..5db38a8267 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -241,10 +241,6 @@ def __init__( def _basic_stats(self, *, exponentiate_values: bool = False) -> None: """Populate per-group stats, and (in vs_rest mode) the rest-group stats via sum-decomposition from totals. - - ``vars_rest`` is left zero-initialized; the t-test path overrides - it via :meth:`_compute_rest_stats_for_t_test`. Wilcoxon paths - never read it. """ X = ( _apply_expm1_preserving_sparsity(self.X, self.expm1_func) @@ -292,16 +288,9 @@ def _aggregate_group_stats(self, X) -> None: out = aggregate(agg_adata, by="_g", func=funcs, dof=1) # aggregate omits empty categories; index back into the full arrays. - # Cast float64 → input dtype to preserve legacy `mean_var` precision - # (near-ties otherwise rank differently and golden tests fail). out_idx = out.obs_names.astype(int).to_numpy() - means_arr = np.asarray(out.layers["mean"]) - vars_arr = np.asarray(out.layers["var"]) - if X_used.dtype == np.float32: - means_arr = means_arr.astype(np.float32) - vars_arr = vars_arr.astype(np.float32) - self.means[out_idx] = means_arr - self.vars[out_idx] = vars_arr + self.means[out_idx] = np.asarray(out.layers["mean"]) + self.vars[out_idx] = np.asarray(out.layers["var"]) if self.comp_pts: nnz_per_group = np.asarray(out.layers["count_nonzero"]) n_per_group = np.array([ @@ -310,8 +299,9 @@ def _aggregate_group_stats(self, X) -> None: self.pts[out_idx] = nnz_per_group / n_per_group[:, None] def _derive_rest_stats(self, X) -> None: - """Populate ``self.means_rest`` (and ``self.pts_rest``) via - ``(sum_global - sum_g) / n_rest`` — no ``X[~mask]`` slice. + """Populate ``self.means_rest`` / ``self.vars_rest`` (and + ``self.pts_rest``) via sum-decomposition from group/global totals + — no ``X[~mask]`` slice. """ n_total = X.shape[0] n_genes = X.shape[1] @@ -321,8 +311,9 @@ def _derive_rest_stats(self, X) -> None: self.vars_rest = np.zeros((n_groups, n_genes)) self.pts_rest = np.zeros((n_groups, n_genes)) if self.comp_pts else None - mean_global = mean_var(X, axis=0, correction=1)[0] + mean_global, var_global = mean_var(X, axis=0, correction=1) sum_global = mean_global * n_total + sumsq_global = var_global * (n_total - 1) + n_total * mean_global ** 2 if self.comp_pts: if isinstance(X, CSBase): nnz_global = X.getnnz(axis=0) @@ -332,36 +323,19 @@ def _derive_rest_stats(self, X) -> None: n_g = int(self.groups_masks_obs[g].sum()) n_rest = n_total - n_g self.means_rest[g] = (sum_global - self.means[g] * n_g) / n_rest + sumsq_g = self.vars[g] * (n_g - 1) + n_g * self.means[g] ** 2 + sumsq_rest = sumsq_global - sumsq_g + # Clamp to 0 for the "constant per-group" case where exact-zero + # variance becomes tiny negative through cancellation. Mirrors + # the band-aid in `Aggregate.mean_var`. + self.vars_rest[g] = np.maximum( + (sumsq_rest - n_rest * self.means_rest[g] ** 2) / max(n_rest - 1, 1), + 0.0, + ) if self.comp_pts: nnz_g_arr = self.pts[g] * n_g self.pts_rest[g] = (nnz_global - nnz_g_arr) / n_rest - def _compute_rest_stats_for_t_test( - self, *, exponentiate_values: bool = False - ) -> None: - """Compute ``means_rest`` and ``vars_rest`` directly via - ``mean_var(X[~mask])``. - - The t-test needs accurate ``vars_rest``, which the fast - :meth:`_derive_rest_stats` can't produce (catastrophic - cancellation on high-mean low-variance genes). Wilcoxon never - reads ``vars_rest`` and skips this slow path. - """ - if self.ireference is not None: - return - - X = ( - _apply_expm1_preserving_sparsity(self.X, self.expm1_func) - if exponentiate_values - else self.X - ) - for g in range(self.groups_masks_obs.shape[0]): - x_rest = X[~self.groups_masks_obs[g]] - self.means_rest[g], self.vars_rest[g] = mean_var( - x_rest, axis=0, correction=1 - ) - del x_rest - def t_test( self, method: Literal["t-test", "t-test_overestim_var"] ) -> Generator[tuple[int, NDArray[np.floating], NDArray[np.floating]], None, None]: @@ -566,12 +540,7 @@ def compute_statistics( **kwds, ) -> None: if method in {"t-test", "t-test_overestim_var"}: - # `t_test` is a lazy generator: it reads `self.means` etc. - # when iterated, so we compute stats once with the final - # exponentiation rather than twice. - exponentiate = not mean_in_log_space - self._basic_stats(exponentiate_values=exponentiate) - self._compute_rest_stats_for_t_test(exponentiate_values=exponentiate) + self._basic_stats(exponentiate_values=not mean_in_log_space) generate_test_results = self.t_test(method) elif "wilcoxon" in method: if "illico" in method: diff --git a/tests/test_rank_genes_groups.py b/tests/test_rank_genes_groups.py index 6c32d45dbc..4ff5290698 100644 --- a/tests/test_rank_genes_groups.py +++ b/tests/test_rank_genes_groups.py @@ -115,8 +115,12 @@ def test_results( for g in range(expected["names"].shape[0]): with subtests.test(group=g): + # atol guards against ULP-level golden vs new-code bit-pattern + # differences at near-zero t-scores (e.g., genes with equal + # group/rest means produce 0.0 in one path vs ~1e-15 in another). np.testing.assert_allclose( - expected["scores"][g, :n], results["scores"][str(g)][:n], rtol=1e-5 + expected["scores"][g, :n], results["scores"][str(g)][:n], + rtol=1e-5, atol=1e-10, ) np.testing.assert_array_equal( expected["names"][g, :n], results["names"][str(g)][:n] From 7569356721578103758c6865698c879724c321ce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 01:23:29 +0000 Subject: [PATCH 06/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scanpy/tools/_rank_genes_groups.py | 2 +- tests/test_rank_genes_groups.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 5db38a8267..e2b935dc66 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -313,7 +313,7 @@ def _derive_rest_stats(self, X) -> None: mean_global, var_global = mean_var(X, axis=0, correction=1) sum_global = mean_global * n_total - sumsq_global = var_global * (n_total - 1) + n_total * mean_global ** 2 + sumsq_global = var_global * (n_total - 1) + n_total * mean_global**2 if self.comp_pts: if isinstance(X, CSBase): nnz_global = X.getnnz(axis=0) diff --git a/tests/test_rank_genes_groups.py b/tests/test_rank_genes_groups.py index 4ff5290698..e388eab11d 100644 --- a/tests/test_rank_genes_groups.py +++ b/tests/test_rank_genes_groups.py @@ -119,8 +119,10 @@ def test_results( # differences at near-zero t-scores (e.g., genes with equal # group/rest means produce 0.0 in one path vs ~1e-15 in another). np.testing.assert_allclose( - expected["scores"][g, :n], results["scores"][str(g)][:n], - rtol=1e-5, atol=1e-10, + expected["scores"][g, :n], + results["scores"][str(g)][:n], + rtol=1e-5, + atol=1e-10, ) np.testing.assert_array_equal( expected["names"][g, :n], results["names"][str(g)][:n] From d6c49f3849a95eb7c6776fbe634e5c5dab029ec6 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Thu, 14 May 2026 15:02:41 -0700 Subject: [PATCH 07/14] removal of var computation for wilcoxon path, overall simplification of code --- src/scanpy/tools/_rank_genes_groups.py | 132 ++++++++++++++----------- 1 file changed, 74 insertions(+), 58 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index e2b935dc66..fb89d16695 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -238,51 +238,51 @@ def __init__( self.grouping_mask = adata.obs[groupby].isin(self.groups_order) self.grouping = adata.obs.loc[self.grouping_mask, groupby] - def _basic_stats(self, *, exponentiate_values: bool = False) -> None: - """Populate per-group stats, and (in vs_rest mode) the rest-group - stats via sum-decomposition from totals. + def _basic_stats( + self, *, exponentiate_values: bool = False, need_var: bool = False + ) -> None: + """Populate per-group stats, and (in vs_rest mode) rest-group stats. + + ``need_var`` controls whether variance (per-group and per-rest) is + computed. Only the t-test family reads variance; Wilcoxon paths skip it. """ X = ( _apply_expm1_preserving_sparsity(self.X, self.expm1_func) if exponentiate_values else self.X ) - self._aggregate_group_stats(X) + self._aggregate_group_stats(X, need_var=need_var) if self.ireference is None: - self._derive_rest_stats(X) + self._derive_rest_stats(X, need_var=need_var) - def _aggregate_group_stats(self, X) -> None: + def _aggregate_group_stats(self, X, *, need_var: bool) -> None: """Populate ``self.{means, vars, pts}`` via one batched :func:`scanpy.get.aggregate` call. + + ``var`` is requested only when ``need_var`` is True; otherwise + ``self.vars`` is left as ``None``. """ - n_total = X.shape[0] n_genes = X.shape[1] n_groups = self.groups_masks_obs.shape[0] self.means = np.zeros((n_groups, n_genes)) - self.vars = np.zeros((n_groups, n_genes)) + self.vars = np.zeros((n_groups, n_genes)) if need_var else None self.pts = np.zeros((n_groups, n_genes)) if self.comp_pts else None - cell_group = np.full(n_total, -1, dtype=np.int64) - for g, mask in enumerate(self.groups_masks_obs): - cell_group[mask] = g - in_any = cell_group >= 0 - if in_any.all(): - X_used = X - cell_group_used = cell_group - else: - X_used = X[in_any] - cell_group_used = cell_group[in_any] + mask = self.grouping_mask.to_numpy() + X_used = X if mask.all() else X[mask] + codes = pd.Categorical(self.grouping, categories=self.groups_order).codes - funcs = ["mean", "var"] + funcs = ["mean"] + if need_var: + funcs.append("var") if self.comp_pts: funcs.append("count_nonzero") - cats = pd.Categorical(cell_group_used, categories=range(n_groups)) agg_adata = AnnData( X=X_used, obs=pd.DataFrame( - {"_g": cats}, - index=pd.RangeIndex(len(cats)).astype(str), + {"_g": pd.Categorical(codes, categories=range(n_groups))}, + index=pd.RangeIndex(len(codes)).astype(str), ), ) out = aggregate(agg_adata, by="_g", func=funcs, dof=1) @@ -290,51 +290,62 @@ def _aggregate_group_stats(self, X) -> None: # aggregate omits empty categories; index back into the full arrays. out_idx = out.obs_names.astype(int).to_numpy() self.means[out_idx] = np.asarray(out.layers["mean"]) - self.vars[out_idx] = np.asarray(out.layers["var"]) + if need_var: + self.vars[out_idx] = np.asarray(out.layers["var"]) if self.comp_pts: nnz_per_group = np.asarray(out.layers["count_nonzero"]) - n_per_group = np.array([ - int(self.groups_masks_obs[g].sum()) for g in out_idx - ]) + n_per_group = self.groups_masks_obs[out_idx].sum(axis=1) self.pts[out_idx] = nnz_per_group / n_per_group[:, None] - def _derive_rest_stats(self, X) -> None: - """Populate ``self.means_rest`` / ``self.vars_rest`` (and - ``self.pts_rest``) via sum-decomposition from group/global totals - — no ``X[~mask]`` slice. + def _derive_rest_stats(self, X, *, need_var: bool) -> None: + """Populate ``self.means_rest`` / ``self.vars_rest`` / + ``self.pts_rest``. + + Mean and pts are derived from per-group totals via stable subtraction + (no extra ``X`` pass when every cell falls in a selected group). Rest + variance, when requested, is computed directly per group via + :func:`mean_var` on ``X[~mask_g]`` — this avoids the cancellation that + a subtraction-based variance derivation would introduce. """ n_total = X.shape[0] n_genes = X.shape[1] n_groups = self.groups_masks_obs.shape[0] + n = self.groups_masks_obs.sum(axis=1).astype(np.int64) + n_arr = n[:, None] + n_r = (n_total - n)[:, None] + mask_all = self.grouping_mask.to_numpy().all() + + # Rest mean — stable subtraction. Global totals come from per-group + # output when every cell is in a selected group; otherwise we pay one + # X pass to include the outside-group cells. + sum_g = self.means * n_arr + sum_total = ( + sum_g.sum(0) if mask_all else np.asarray(X.sum(axis=0)).ravel() + ) + self.means_rest = (sum_total - sum_g) / n_r + + # Rest var — t-test paths only. Direct per-group `mean_var(X[~mask_g])` + # to keep the numerics identical to the pre-refactor code. + if need_var: + self.vars_rest = np.zeros((n_groups, n_genes)) + for g, mg in enumerate(self.groups_masks_obs): + _, self.vars_rest[g] = mean_var(X[~mg], axis=0, correction=1) + else: + self.vars_rest = None - self.means_rest = np.zeros((n_groups, n_genes)) - self.vars_rest = np.zeros((n_groups, n_genes)) - self.pts_rest = np.zeros((n_groups, n_genes)) if self.comp_pts else None - - mean_global, var_global = mean_var(X, axis=0, correction=1) - sum_global = mean_global * n_total - sumsq_global = var_global * (n_total - 1) + n_total * mean_global**2 + # Rest nnz — stable integer subtraction. if self.comp_pts: - if isinstance(X, CSBase): - nnz_global = X.getnnz(axis=0) + nnz_g = self.pts * n_arr + if mask_all: + nnz_total = nnz_g.sum(0) else: - nnz_global = np.count_nonzero(X, axis=0) - for g in range(n_groups): - n_g = int(self.groups_masks_obs[g].sum()) - n_rest = n_total - n_g - self.means_rest[g] = (sum_global - self.means[g] * n_g) / n_rest - sumsq_g = self.vars[g] * (n_g - 1) + n_g * self.means[g] ** 2 - sumsq_rest = sumsq_global - sumsq_g - # Clamp to 0 for the "constant per-group" case where exact-zero - # variance becomes tiny negative through cancellation. Mirrors - # the band-aid in `Aggregate.mean_var`. - self.vars_rest[g] = np.maximum( - (sumsq_rest - n_rest * self.means_rest[g] ** 2) / max(n_rest - 1, 1), - 0.0, - ) - if self.comp_pts: - nnz_g_arr = self.pts[g] * n_g - self.pts_rest[g] = (nnz_global - nnz_g_arr) / n_rest + nnz_total = ( + X.getnnz(axis=0) if isinstance(X, CSBase) + else np.count_nonzero(X, axis=0) + ) + self.pts_rest = (nnz_total - nnz_g) / n_r + else: + self.pts_rest = None def t_test( self, method: Literal["t-test", "t-test_overestim_var"] @@ -540,7 +551,9 @@ def compute_statistics( **kwds, ) -> None: if method in {"t-test", "t-test_overestim_var"}: - self._basic_stats(exponentiate_values=not mean_in_log_space) + self._basic_stats( + exponentiate_values=not mean_in_log_space, need_var=True + ) generate_test_results = self.t_test(method) elif "wilcoxon" in method: if "illico" in method: @@ -552,7 +565,10 @@ def compute_statistics( ) else: generate_test_results = self.wilcoxon(tie_correct=tie_correct) - self._basic_stats(exponentiate_values=not mean_in_log_space) + # Wilcoxon paths only need means (for fold-change); skip var. + self._basic_stats( + exponentiate_values=not mean_in_log_space, need_var=False + ) elif method == "logreg": generate_test_results = self.logreg(**kwds) From ba191306102f2ac4af24d990d930264bc949eb4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 22:06:10 +0000 Subject: [PATCH 08/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scanpy/tools/_rank_genes_groups.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index fb89d16695..fe35f437b9 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -319,9 +319,7 @@ def _derive_rest_stats(self, X, *, need_var: bool) -> None: # output when every cell is in a selected group; otherwise we pay one # X pass to include the outside-group cells. sum_g = self.means * n_arr - sum_total = ( - sum_g.sum(0) if mask_all else np.asarray(X.sum(axis=0)).ravel() - ) + sum_total = sum_g.sum(0) if mask_all else np.asarray(X.sum(axis=0)).ravel() self.means_rest = (sum_total - sum_g) / n_r # Rest var — t-test paths only. Direct per-group `mean_var(X[~mask_g])` @@ -340,7 +338,8 @@ def _derive_rest_stats(self, X, *, need_var: bool) -> None: nnz_total = nnz_g.sum(0) else: nnz_total = ( - X.getnnz(axis=0) if isinstance(X, CSBase) + X.getnnz(axis=0) + if isinstance(X, CSBase) else np.count_nonzero(X, axis=0) ) self.pts_rest = (nnz_total - nnz_g) / n_r @@ -551,9 +550,7 @@ def compute_statistics( **kwds, ) -> None: if method in {"t-test", "t-test_overestim_var"}: - self._basic_stats( - exponentiate_values=not mean_in_log_space, need_var=True - ) + self._basic_stats(exponentiate_values=not mean_in_log_space, need_var=True) generate_test_results = self.t_test(method) elif "wilcoxon" in method: if "illico" in method: @@ -566,9 +563,7 @@ def compute_statistics( else: generate_test_results = self.wilcoxon(tie_correct=tie_correct) # Wilcoxon paths only need means (for fold-change); skip var. - self._basic_stats( - exponentiate_values=not mean_in_log_space, need_var=False - ) + self._basic_stats(exponentiate_values=not mean_in_log_space, need_var=False) elif method == "logreg": generate_test_results = self.logreg(**kwds) From a5dedc8b89a1d096e6be04dae89524e281e0ed44 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Fri, 15 May 2026 17:02:00 -0700 Subject: [PATCH 09/14] review and refinement --- src/scanpy/tools/_rank_genes_groups.py | 30 +++++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index fe35f437b9..bd092bf5f2 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -14,7 +14,7 @@ from .. import _utils from .. import logging as logg -from .._compat import CSBase +from .._compat import CSBase, DaskArray from .._settings import Default from .._settings.presets import DETest from .._utils import ( @@ -151,7 +151,16 @@ def _ranks( def _apply_expm1_preserving_sparsity(X, expm1_func): - """Apply ``expm1`` to X. Uses ``expm1(0) == 0`` to keep sparse X sparse.""" + """Apply ``expm1`` to X, lazily and chunk-wise for dask, keeping sparse + data sparse (``expm1(0) == 0``). + """ + if isinstance(X, DaskArray): + return X.map_blocks( + _apply_expm1_preserving_sparsity, + expm1_func, + dtype=X.dtype, + meta=X._meta, + ) if isinstance(X, CSBase): Xp = X.copy() Xp.data = expm1_func(Xp.data) @@ -287,7 +296,7 @@ def _aggregate_group_stats(self, X, *, need_var: bool) -> None: ) out = aggregate(agg_adata, by="_g", func=funcs, dof=1) - # aggregate omits empty categories; index back into the full arrays. + # assign computed means/vars/pts back to self objs. out_idx = out.obs_names.astype(int).to_numpy() self.means[out_idx] = np.asarray(out.layers["mean"]) if need_var: @@ -315,15 +324,20 @@ def _derive_rest_stats(self, X, *, need_var: bool) -> None: n_r = (n_total - n)[:, None] mask_all = self.grouping_mask.to_numpy().all() - # Rest mean — stable subtraction. Global totals come from per-group - # output when every cell is in a selected group; otherwise we pay one - # X pass to include the outside-group cells. + # Compute rest means without another pass + # TODO: can any cells not be part of a group? (will mask_all ever be False) + # If not, remove fallback sum_g = self.means * n_arr sum_total = sum_g.sum(0) if mask_all else np.asarray(X.sum(axis=0)).ravel() self.means_rest = (sum_total - sum_g) / n_r - # Rest var — t-test paths only. Direct per-group `mean_var(X[~mask_g])` - # to keep the numerics identical to the pre-refactor code. + # TODO: if `aggregate` exposed `sum_of_squares` (an additive + # primitive it already computes internally for `var`), rest variance + # could be derived from the (n_groups, n_genes) aggregate output via + # Chan's parallel/pairwise formula, replacing this n_groups-pass loop: + # M2_rest = M2_total - M2_g - delta**2 * n_g * n_r / n_total + # var_rest = M2_rest / (n_r - 1) + # Faster, and Dask-streamable. if need_var: self.vars_rest = np.zeros((n_groups, n_genes)) for g, mg in enumerate(self.groups_masks_obs): From f2e3dcb223e59e640a3d51d16f4b3fa45e081a20 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 18 May 2026 21:07:57 -0700 Subject: [PATCH 10/14] Chan's algo used for 'rest' variance computations. Uses group variances from 'aggregate'. --- src/scanpy/tools/_rank_genes_groups.py | 217 ++++++++++++++++--------- 1 file changed, 139 insertions(+), 78 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index bd092bf5f2..4fe5df04a8 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -9,7 +9,6 @@ import pandas as pd from anndata import AnnData from fast_array_utils.numba import njit -from fast_array_utils.stats import mean_var from scipy import sparse from .. import _utils @@ -253,34 +252,36 @@ def _basic_stats( """Populate per-group stats, and (in vs_rest mode) rest-group stats. ``need_var`` controls whether variance (per-group and per-rest) is - computed. Only the t-test family reads variance; Wilcoxon paths skip it. + computed; only the t-test family reads it. In vs_rest mode every cell + is partitioned into its selected group or a single "remainder" + partition (cells in no selected group), and each group's "rest" is the + forward Chan-combine of all other partitions — a sum of non-negative + terms, hence free of catastrophic cancellation for any group sizes. """ X = ( _apply_expm1_preserving_sparsity(self.X, self.expm1_func) if exponentiate_values else self.X ) - self._aggregate_group_stats(X, need_var=need_var) if self.ireference is None: - self._derive_rest_stats(X, need_var=need_var) + self._stats_vs_rest(X, need_var=need_var) + else: + self._stats_vs_reference(X, need_var=need_var) - def _aggregate_group_stats(self, X, *, need_var: bool) -> None: - """Populate ``self.{means, vars, pts}`` via one batched - :func:`scanpy.get.aggregate` call. + def _aggregate_group_stats( + self, X_used, codes: NDArray[np.int64], n_parts: int, *, need_var: bool + ): + """One batched :func:`scanpy.get.aggregate` over ``X_used`` grouped by + ``codes`` (values ``0 .. n_parts-1``). - ``var`` is requested only when ``need_var`` is True; otherwise - ``self.vars`` is left as ``None``. + Returns ``(mean, var, nnz)`` of shape ``(n_parts, n_genes)``, + zero-filled for partitions with no cells. ``var`` is ``None`` unless + ``need_var``; ``nnz`` is ``None`` unless ``self.comp_pts``. """ - n_genes = X.shape[1] - n_groups = self.groups_masks_obs.shape[0] - - self.means = np.zeros((n_groups, n_genes)) - self.vars = np.zeros((n_groups, n_genes)) if need_var else None - self.pts = np.zeros((n_groups, n_genes)) if self.comp_pts else None - - mask = self.grouping_mask.to_numpy() - X_used = X if mask.all() else X[mask] - codes = pd.Categorical(self.grouping, categories=self.groups_order).codes + n_genes = X_used.shape[1] + mean = np.zeros((n_parts, n_genes)) + var = np.zeros((n_parts, n_genes)) if need_var else None + nnz = np.zeros((n_parts, n_genes)) if self.comp_pts else None funcs = ["mean"] if need_var: @@ -290,75 +291,135 @@ def _aggregate_group_stats(self, X, *, need_var: bool) -> None: agg_adata = AnnData( X=X_used, obs=pd.DataFrame( - {"_g": pd.Categorical(codes, categories=range(n_groups))}, + {"_g": pd.Categorical(codes, categories=range(n_parts))}, index=pd.RangeIndex(len(codes)).astype(str), ), ) out = aggregate(agg_adata, by="_g", func=funcs, dof=1) - - # assign computed means/vars/pts back to self objs. - out_idx = out.obs_names.astype(int).to_numpy() - self.means[out_idx] = np.asarray(out.layers["mean"]) + idx = out.obs_names.astype(int).to_numpy() + mean[idx] = np.asarray(out.layers["mean"]) if need_var: - self.vars[out_idx] = np.asarray(out.layers["var"]) + var[idx] = np.asarray(out.layers["var"]) if self.comp_pts: - nnz_per_group = np.asarray(out.layers["count_nonzero"]) - n_per_group = self.groups_masks_obs[out_idx].sum(axis=1) - self.pts[out_idx] = nnz_per_group / n_per_group[:, None] - - def _derive_rest_stats(self, X, *, need_var: bool) -> None: - """Populate ``self.means_rest`` / ``self.vars_rest`` / - ``self.pts_rest``. - - Mean and pts are derived from per-group totals via stable subtraction - (no extra ``X`` pass when every cell falls in a selected group). Rest - variance, when requested, is computed directly per group via - :func:`mean_var` on ``X[~mask_g]`` — this avoids the cancellation that - a subtraction-based variance derivation would introduce. + nnz[idx] = np.asarray(out.layers["count_nonzero"]) + return mean, var, nnz + + def _stats_vs_reference(self, X, *, need_var: bool) -> None: + """vs-reference: aggregate the selected-group cells only (the + reference is itself one of the selected groups). No rest derivation. """ - n_total = X.shape[0] - n_genes = X.shape[1] - n_groups = self.groups_masks_obs.shape[0] - n = self.groups_masks_obs.sum(axis=1).astype(np.int64) - n_arr = n[:, None] - n_r = (n_total - n)[:, None] - mask_all = self.grouping_mask.to_numpy().all() - - # Compute rest means without another pass - # TODO: can any cells not be part of a group? (will mask_all ever be False) - # If not, remove fallback - sum_g = self.means * n_arr - sum_total = sum_g.sum(0) if mask_all else np.asarray(X.sum(axis=0)).ravel() - self.means_rest = (sum_total - sum_g) / n_r - - # TODO: if `aggregate` exposed `sum_of_squares` (an additive - # primitive it already computes internally for `var`), rest variance - # could be derived from the (n_groups, n_genes) aggregate output via - # Chan's parallel/pairwise formula, replacing this n_groups-pass loop: - # M2_rest = M2_total - M2_g - delta**2 * n_g * n_r / n_total - # var_rest = M2_rest / (n_r - 1) - # Faster, and Dask-streamable. - if need_var: - self.vars_rest = np.zeros((n_groups, n_genes)) - for g, mg in enumerate(self.groups_masks_obs): - _, self.vars_rest[g] = mean_var(X[~mg], axis=0, correction=1) - else: - self.vars_rest = None + mask = self.grouping_mask.to_numpy() + X_used = X if mask.all() else X[mask] + codes = pd.Categorical(self.grouping, categories=self.groups_order).codes + k = self.groups_masks_obs.shape[0] - # Rest nnz — stable integer subtraction. + self.means, self.vars, nnz = self._aggregate_group_stats( + X_used, codes, k, need_var=need_var + ) if self.comp_pts: - nnz_g = self.pts * n_arr - if mask_all: - nnz_total = nnz_g.sum(0) - else: - nnz_total = ( - X.getnnz(axis=0) - if isinstance(X, CSBase) - else np.count_nonzero(X, axis=0) - ) - self.pts_rest = (nnz_total - nnz_g) / n_r + n_per_group = self.groups_masks_obs.sum(axis=1) + self.pts = nnz / n_per_group[:, None] else: - self.pts_rest = None + self.pts = None + + def _stats_vs_rest(self, X, *, need_var: bool) -> None: + """vs-rest: partition *all* cells into the ``k`` selected-group + partitions plus one "remainder" partition (cells in no selected group + — non-selected groups and unassigned/NaN). Each group's "rest" is the + forward Chan-combine of every other partition. + """ + k = self.groups_masks_obs.shape[0] + + # Codes: each cell's selected-group index, or `k` (the remainder + # partition) for cells in no selected group (non-selected / NaN). + sel = pd.Categorical(self.group_col, categories=self.groups_order).codes + codes = np.where(sel >= 0, sel, k).astype(np.int64) + part_n = np.bincount(codes, minlength=k + 1) # partition k == remainder + n_sel = part_n[:k] + + mean, var, nnz = self._aggregate_group_stats( + X, codes, k + 1, need_var=need_var + ) + + # Selected-group arm of the test (the remainder partition is excluded). + self.means = mean[:k] + self.vars = var[:k] if need_var else None + self.pts = nnz[:k] / n_sel[:, None] if self.comp_pts else None + + # M2 = var * (n - 1); forced to 0 for partitions with <= 1 cell so a + # singleton remainder (aggregate var undefined there) is harmless. + if need_var: + with np.errstate(invalid="ignore"): + M2 = var * (part_n[:, None] - 1) + M2[part_n <= 1] = 0.0 + else: + M2 = None + + self._derive_rest_stats(part_n, mean, M2, nnz, k, need_var=need_var) + + def _derive_rest_stats( + self, part_n, mean, M2, nnz, k: int, *, need_var: bool + ) -> None: + """Set ``means_rest``/``vars_rest``/``pts_rest``: for each selected + group ``g``, the forward Chan-combine of every partition except ``g``. + + Chan's parallel combine adds only non-negative terms + (``m2 = m2_a + m2_b + delta**2 * n_a * n_b / n``); a prefix/suffix scan + yields every leave-one-out with no subtraction, so the result has no + catastrophic cancellation for any group sizes/means/variances. + See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + """ + comp_pts = self.comp_pts + n_parts = k + 1 + + def comb(a, b): + n_a, mean_a, m2_a, nnz_a = a + n_b, mean_b, m2_b, nnz_b = b + if n_a == 0: + return b + if n_b == 0: + return a + n = n_a + n_b + delta = mean_b - mean_a + mean_c = mean_a + delta * (n_b / n) + m2_c = ( + m2_a + m2_b + delta * delta * (n_a * n_b / n) + if need_var + else None + ) + nnz_c = nnz_a + nnz_b if comp_pts else None + return (n, mean_c, m2_c, nnz_c) + + parts = [ + ( + int(part_n[i]), + mean[i], + M2[i] if need_var else None, + nnz[i] if comp_pts else None, + ) + for i in range(n_parts) + ] + prefix = [parts[0]] * n_parts + suffix = [parts[-1]] * n_parts + for i in range(1, n_parts): + prefix[i] = comb(prefix[i - 1], parts[i]) + for i in range(n_parts - 2, -1, -1): + suffix[i] = comb(parts[i], suffix[i + 1]) + + n_genes = mean.shape[1] + empty = (0, None, None, None) + self.means_rest = np.zeros((k, n_genes)) + self.vars_rest = np.zeros((k, n_genes)) if need_var else None + self.pts_rest = np.zeros((k, n_genes)) if comp_pts else None + for g in range(k): + left = prefix[g - 1] if g >= 1 else empty + right = suffix[g + 1] if g + 1 < n_parts else empty + n_r, mean_r, m2_r, nnz_r = comb(left, right) + self.means_rest[g] = mean_r + if need_var: + self.vars_rest[g] = np.maximum(m2_r / max(n_r - 1, 1), 0.0) + if comp_pts: + self.pts_rest[g] = nnz_r / n_r def t_test( self, method: Literal["t-test", "t-test_overestim_var"] From b78aaebe54952a884b45c331e0a95d986a31bc06 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 19 May 2026 04:08:13 +0000 Subject: [PATCH 11/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scanpy/tools/_rank_genes_groups.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 4fe5df04a8..1d40f9b063 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -337,9 +337,7 @@ def _stats_vs_rest(self, X, *, need_var: bool) -> None: part_n = np.bincount(codes, minlength=k + 1) # partition k == remainder n_sel = part_n[:k] - mean, var, nnz = self._aggregate_group_stats( - X, codes, k + 1, need_var=need_var - ) + mean, var, nnz = self._aggregate_group_stats(X, codes, k + 1, need_var=need_var) # Selected-group arm of the test (the remainder partition is excluded). self.means = mean[:k] @@ -382,11 +380,7 @@ def comb(a, b): n = n_a + n_b delta = mean_b - mean_a mean_c = mean_a + delta * (n_b / n) - m2_c = ( - m2_a + m2_b + delta * delta * (n_a * n_b / n) - if need_var - else None - ) + m2_c = m2_a + m2_b + delta * delta * (n_a * n_b / n) if need_var else None nnz_c = nnz_a + nnz_b if comp_pts else None return (n, mean_c, m2_c, nnz_c) From cc5783f4fb04f79ebb40cbb86a1036eb96c8d48f Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Wed, 10 Jun 2026 15:39:29 -0700 Subject: [PATCH 12/14] njit _derive_rest_stats via shared _chan_combine Reuse aggregate's _chan_combine kernel in a numba leave-one-out scan (parallel over genes) for the vs-rest mean/variance, and compute pts_rest by exact count subtraction. Results are unchanged. --- src/scanpy/tools/_rank_genes_groups.py | 138 +++++++++++++++---------- 1 file changed, 83 insertions(+), 55 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 1d40f9b063..f520168cfe 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -23,6 +23,7 @@ raise_not_implemented_error_if_backed_type, ) from ..get import _check_mask, _get_obs_rep, aggregate +from ..get._aggregated import _chan_combine if TYPE_CHECKING: from collections.abc import Generator, Iterable @@ -167,6 +168,66 @@ def _apply_expm1_preserving_sparsity(X, expm1_func): return expm1_func(X) +@njit +def _derive_rest_njit( + part_n: NDArray[np.float64], + mean: NDArray[np.float64], + m2: NDArray[np.float64], + k: int, + do_var: bool, # noqa: FBT001 +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + """Leave-one-out Chan combine for every selected group, parallel over genes. + + For each gene, a forward (prefix) and backward (suffix) running combine over the + ``k + 1`` partitions is built; group ``g``'s "rest" is ``prefix[g-1]`` combined + with ``suffix[g+1]`` — every partition except ``g``, with no subtraction. ``m2`` + is ignored (and ``vars_rest`` left zero) when ``do_var`` is False. + """ + n_parts, n_genes = mean.shape + means_rest = np.zeros((k, n_genes)) + vars_rest = np.zeros((k, n_genes)) + for j in numba.prange(n_genes): + pre_n = np.empty(n_parts) + pre_m = np.empty(n_parts) + pre_v = np.empty(n_parts) + pre_n[0] = part_n[0] + pre_m[0] = mean[0, j] + pre_v[0] = m2[0, j] if do_var else 0.0 + for i in range(1, n_parts): + b2 = m2[i, j] if do_var else 0.0 + pre_n[i], pre_m[i], pre_v[i] = _chan_combine( + pre_n[i - 1], pre_m[i - 1], pre_v[i - 1], part_n[i], mean[i, j], b2 + ) + suf_n = np.empty(n_parts) + suf_m = np.empty(n_parts) + suf_v = np.empty(n_parts) + last = n_parts - 1 + suf_n[last] = part_n[last] + suf_m[last] = mean[last, j] + suf_v[last] = m2[last, j] if do_var else 0.0 + for i in range(n_parts - 2, -1, -1): + b2 = m2[i, j] if do_var else 0.0 + suf_n[i], suf_m[i], suf_v[i] = _chan_combine( + part_n[i], mean[i, j], b2, suf_n[i + 1], suf_m[i + 1], suf_v[i + 1] + ) + for g in range(k): + if g >= 1: + ln, lm, lv = pre_n[g - 1], pre_m[g - 1], pre_v[g - 1] + else: + ln, lm, lv = 0.0, 0.0, 0.0 + if g + 1 < n_parts: + rn, rm, rv = suf_n[g + 1], suf_m[g + 1], suf_v[g + 1] + else: + rn, rm, rv = 0.0, 0.0, 0.0 + n_r, mean_r, m2_r = _chan_combine(ln, lm, lv, rn, rm, rv) + means_rest[g, j] = mean_r + if do_var: + denom = n_r - 1.0 if n_r >= 2.0 else 1.0 + v = m2_r / denom + vars_rest[g, j] = v if v > 0.0 else 0.0 + return means_rest, vars_rest + + class _RankGenes: def __init__( self, @@ -358,62 +419,29 @@ def _stats_vs_rest(self, X, *, need_var: bool) -> None: def _derive_rest_stats( self, part_n, mean, M2, nnz, k: int, *, need_var: bool ) -> None: - """Set ``means_rest``/``vars_rest``/``pts_rest``: for each selected - group ``g``, the forward Chan-combine of every partition except ``g``. - - Chan's parallel combine adds only non-negative terms - (``m2 = m2_a + m2_b + delta**2 * n_a * n_b / n``); a prefix/suffix scan - yields every leave-one-out with no subtraction, so the result has no - catastrophic cancellation for any group sizes/means/variances. - See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + """Set ``means_rest``/``vars_rest``/``pts_rest``: for each selected group + ``g``, the Chan-combine of every partition except ``g`` (its "rest"). + + Mean and variance use a leave-one-out prefix/suffix scan of + :func:`~scanpy.get._aggregated._chan_combine` over the ``k + 1`` partitions + (:func:`_derive_rest_njit`); it adds only non-negative terms, so there is no + catastrophic cancellation for any group sizes. Counts combine by plain + summation, so ``pts_rest`` is the exact total-minus-group difference. """ - comp_pts = self.comp_pts - n_parts = k + 1 - - def comb(a, b): - n_a, mean_a, m2_a, nnz_a = a - n_b, mean_b, m2_b, nnz_b = b - if n_a == 0: - return b - if n_b == 0: - return a - n = n_a + n_b - delta = mean_b - mean_a - mean_c = mean_a + delta * (n_b / n) - m2_c = m2_a + m2_b + delta * delta * (n_a * n_b / n) if need_var else None - nnz_c = nnz_a + nnz_b if comp_pts else None - return (n, mean_c, m2_c, nnz_c) - - parts = [ - ( - int(part_n[i]), - mean[i], - M2[i] if need_var else None, - nnz[i] if comp_pts else None, - ) - for i in range(n_parts) - ] - prefix = [parts[0]] * n_parts - suffix = [parts[-1]] * n_parts - for i in range(1, n_parts): - prefix[i] = comb(prefix[i - 1], parts[i]) - for i in range(n_parts - 2, -1, -1): - suffix[i] = comb(parts[i], suffix[i + 1]) - - n_genes = mean.shape[1] - empty = (0, None, None, None) - self.means_rest = np.zeros((k, n_genes)) - self.vars_rest = np.zeros((k, n_genes)) if need_var else None - self.pts_rest = np.zeros((k, n_genes)) if comp_pts else None - for g in range(k): - left = prefix[g - 1] if g >= 1 else empty - right = suffix[g + 1] if g + 1 < n_parts else empty - n_r, mean_r, m2_r, nnz_r = comb(left, right) - self.means_rest[g] = mean_r - if need_var: - self.vars_rest[g] = np.maximum(m2_r / max(n_r - 1, 1), 0.0) - if comp_pts: - self.pts_rest[g] = nnz_r / n_r + means_rest, vars_rest = _derive_rest_njit( + np.ascontiguousarray(part_n, dtype=np.float64), + mean, + M2 if need_var else mean, # placeholder; unused when need_var is False + k, + need_var, + ) + self.means_rest = means_rest + self.vars_rest = vars_rest if need_var else None + if self.comp_pts: + n_rest = self.X.shape[0] - part_n[:k] + self.pts_rest = (nnz.sum(axis=0) - nnz[:k]) / n_rest[:, None] + else: + self.pts_rest = None def t_test( self, method: Literal["t-test", "t-test_overestim_var"] From 93ccbd64aa52fb1b0b3204aa7809755703a38744 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Wed, 10 Jun 2026 16:53:08 -0700 Subject: [PATCH 13/14] draft of aggregate's chan w/ njit --- src/scanpy/tools/_rank_genes_groups.py | 128 ++++++++++++------------- 1 file changed, 61 insertions(+), 67 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index f520168cfe..2c2d5a47b9 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -168,64 +168,62 @@ def _apply_expm1_preserving_sparsity(X, expm1_func): return expm1_func(X) +@numba.njit +def _scan( + part_n: NDArray[np.float64], + mean: NDArray[np.float64], + m2: NDArray[np.float64], + j: int, + step: int, +) -> NDArray[np.float64]: + """Running Chan combine over the partitions for gene ``j``. + + ``acc[i]`` is the combined ``(count, mean, M2)`` of partition ``i`` together with + every partition on the ``step`` side of it (``step=1`` forward / prefix, + ``step=-1`` backward / suffix). + """ + n_parts = part_n.shape[0] + acc = np.empty((n_parts, 3)) + n = m = v = 0.0 + i = 0 if step == 1 else n_parts - 1 + for _ in range(n_parts): + n, m, v = _chan_combine(n, m, v, part_n[i], mean[i, j], m2[i, j]) + acc[i, 0], acc[i, 1], acc[i, 2] = n, m, v + i += step + return acc + + @njit -def _derive_rest_njit( +def _vars_rest( part_n: NDArray[np.float64], mean: NDArray[np.float64], m2: NDArray[np.float64], k: int, - do_var: bool, # noqa: FBT001 -) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - """Leave-one-out Chan combine for every selected group, parallel over genes. - - For each gene, a forward (prefix) and backward (suffix) running combine over the - ``k + 1`` partitions is built; group ``g``'s "rest" is ``prefix[g-1]`` combined - with ``suffix[g+1]`` — every partition except ``g``, with no subtraction. ``m2`` - is ignored (and ``vars_rest`` left zero) when ``do_var`` is False. +) -> NDArray[np.float64]: + """Leave-one-out variance for each selected group, parallel over genes. + + Group ``g``'s "rest" combines the forward scan up to ``g - 1`` with the backward + scan from ``g + 1`` — every partition except ``g`` — so variances are never + subtracted (Chan's cancellation-free combine). """ n_parts, n_genes = mean.shape - means_rest = np.zeros((k, n_genes)) vars_rest = np.zeros((k, n_genes)) for j in numba.prange(n_genes): - pre_n = np.empty(n_parts) - pre_m = np.empty(n_parts) - pre_v = np.empty(n_parts) - pre_n[0] = part_n[0] - pre_m[0] = mean[0, j] - pre_v[0] = m2[0, j] if do_var else 0.0 - for i in range(1, n_parts): - b2 = m2[i, j] if do_var else 0.0 - pre_n[i], pre_m[i], pre_v[i] = _chan_combine( - pre_n[i - 1], pre_m[i - 1], pre_v[i - 1], part_n[i], mean[i, j], b2 - ) - suf_n = np.empty(n_parts) - suf_m = np.empty(n_parts) - suf_v = np.empty(n_parts) - last = n_parts - 1 - suf_n[last] = part_n[last] - suf_m[last] = mean[last, j] - suf_v[last] = m2[last, j] if do_var else 0.0 - for i in range(n_parts - 2, -1, -1): - b2 = m2[i, j] if do_var else 0.0 - suf_n[i], suf_m[i], suf_v[i] = _chan_combine( - part_n[i], mean[i, j], b2, suf_n[i + 1], suf_m[i + 1], suf_v[i + 1] - ) + prefix = _scan(part_n, mean, m2, j, 1) + suffix = _scan(part_n, mean, m2, j, -1) for g in range(k): + right = suffix[g + 1] # everything after g; g + 1 < n_parts always holds if g >= 1: - ln, lm, lv = pre_n[g - 1], pre_m[g - 1], pre_v[g - 1] - else: - ln, lm, lv = 0.0, 0.0, 0.0 - if g + 1 < n_parts: - rn, rm, rv = suf_n[g + 1], suf_m[g + 1], suf_v[g + 1] + left = prefix[g - 1] + n_r, _, m2_r = _chan_combine( + left[0], left[1], left[2], right[0], right[1], right[2] + ) else: - rn, rm, rv = 0.0, 0.0, 0.0 - n_r, mean_r, m2_r = _chan_combine(ln, lm, lv, rn, rm, rv) - means_rest[g, j] = mean_r - if do_var: - denom = n_r - 1.0 if n_r >= 2.0 else 1.0 - v = m2_r / denom - vars_rest[g, j] = v if v > 0.0 else 0.0 - return means_rest, vars_rest + n_r, m2_r = right[0], right[2] + denom = n_r - 1.0 if n_r >= 2.0 else 1.0 + v = m2_r / denom + vars_rest[g, j] = v if v > 0.0 else 0.0 + return vars_rest class _RankGenes: @@ -419,29 +417,25 @@ def _stats_vs_rest(self, X, *, need_var: bool) -> None: def _derive_rest_stats( self, part_n, mean, M2, nnz, k: int, *, need_var: bool ) -> None: - """Set ``means_rest``/``vars_rest``/``pts_rest``: for each selected group - ``g``, the Chan-combine of every partition except ``g`` (its "rest"). - - Mean and variance use a leave-one-out prefix/suffix scan of - :func:`~scanpy.get._aggregated._chan_combine` over the ``k + 1`` partitions - (:func:`_derive_rest_njit`); it adds only non-negative terms, so there is no - catastrophic cancellation for any group sizes. Counts combine by plain - summation, so ``pts_rest`` is the exact total-minus-group difference. + """Set ``means_rest``/``vars_rest``/``pts_rest`` for each selected group ``g`` + (statistics over every cell *not* in ``g``). + + ``means_rest`` and ``pts_rest`` are linear in the partitions, so they are the + exact total-minus-group difference. Variance would lose precision under such + subtraction, so ``vars_rest`` uses the cancellation-free Chan leave-one-out + scan (:func:`_vars_rest`) — only the variance-based tests request it. """ - means_rest, vars_rest = _derive_rest_njit( - np.ascontiguousarray(part_n, dtype=np.float64), - mean, - M2 if need_var else mean, # placeholder; unused when need_var is False - k, - need_var, + n_rest = (self.X.shape[0] - part_n[:k])[:, None] + total = (part_n[:, None] * mean).sum(axis=0) + self.means_rest = (total - part_n[:k, None] * mean[:k]) / n_rest + self.vars_rest = ( + _vars_rest(np.ascontiguousarray(part_n, dtype=np.float64), mean, M2, k) + if need_var + else None + ) + self.pts_rest = ( + (nnz.sum(axis=0) - nnz[:k]) / n_rest if self.comp_pts else None ) - self.means_rest = means_rest - self.vars_rest = vars_rest if need_var else None - if self.comp_pts: - n_rest = self.X.shape[0] - part_n[:k] - self.pts_rest = (nnz.sum(axis=0) - nnz[:k]) / n_rest[:, None] - else: - self.pts_rest = None def t_test( self, method: Literal["t-test", "t-test_overestim_var"] From 58ffca9adefaad923947c4573d7467a4d366359e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jun 2026 23:53:20 +0000 Subject: [PATCH 14/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scanpy/tools/_rank_genes_groups.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index 2c2d5a47b9..c2e0ba2ec4 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -433,9 +433,7 @@ def _derive_rest_stats( if need_var else None ) - self.pts_rest = ( - (nnz.sum(axis=0) - nnz[:k]) / n_rest if self.comp_pts else None - ) + self.pts_rest = (nnz.sum(axis=0) - nnz[:k]) / n_rest if self.comp_pts else None def t_test( self, method: Literal["t-test", "t-test_overestim_var"]