diff --git a/docs/release-notes/4147.perf.md b/docs/release-notes/4147.perf.md new file mode 100644 index 0000000000..f6044655c8 --- /dev/null +++ b/docs/release-notes/4147.perf.md @@ -0,0 +1 @@ +Use [Welford's algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm) for mean-var calculation in {func}`scanpy.get.aggregate` for in-memory (i.e., non-dask) arrays {smaller}`I Gold` diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index fe5140171b..3fad260f7d 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -13,7 +13,13 @@ from scanpy._compat import CSBase, CSRBase, DaskArray from .._utils import _resolve_axis, get_literal_vals -from ._kernels import agg_sum_csc, agg_sum_csr, mean_var_csc, mean_var_csr +from ._kernels import ( + agg_sum_csc, + agg_sum_csr, + mean_var_csc, + mean_var_csr, + mean_var_dense, +) from .get import _check_mask if TYPE_CHECKING: @@ -117,11 +123,8 @@ def mean(self) -> Array: def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: """Compute the count, as well as mean and variance per feature, per group of observations. - The formula `Var(X) = E(X^2) - E(X)^2` suffers loss of precision when the variance is a - very small fraction of the squared mean. In particular, when X is constant, the formula may - nonetheless be non-zero. By default, our implementation resets the variance to exactly zero - when the computed variance, relative to the squared mean, nears limit of precision of the - floating-point significand. + Mean and variance are computed with Welford's online algorithm, which is + numerically stable for constant or near-constant inputs. Params ------ @@ -137,21 +140,11 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: group_counts = np.bincount(self.groupby.codes) if isinstance(self.data, np.ndarray): - mean_ = self.mean() - # sparse matrices do not support ** for elementwise power. 
- mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None] - sq_mean = mean_**2 - var_ = mean_sq - sq_mean + mean_, var_ = mean_var_dense(self.indicator_matrix.tocsr(), self.data) else: mean_, var_ = ( mean_var_csr if isinstance(self.data, CSRBase) else mean_var_csc )(self.indicator_matrix, self.data) - sq_mean = mean_**2 - # TODO: Why these values exactly? Because they are high relative to the datatype? - # (unchanged from original code: https://github.com/scverse/anndata/pull/564) - precision = 2 << (42 if self.data.dtype == np.float64 else 20) - # detects loss of precision in mean_sq - sq_mean, which suggests variance is 0 - var_[precision * var_ < sq_mean] = 0 if dof != 0: var_ *= (group_counts / (group_counts - dof))[:, np.newaxis] return mean_, var_ diff --git a/src/scanpy/get/_kernels.py b/src/scanpy/get/_kernels.py index 4d25bd06be..a0ecd7f35f 100644 --- a/src/scanpy/get/_kernels.py +++ b/src/scanpy/get/_kernels.py @@ -48,34 +48,84 @@ def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray) -> None: out[cat, col] += data.data[j] +@njit +def mean_var_dense( + indicator: CSRBase, data: NDArray +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + # Welford's online algorithm, parallelized over categories. The indicator + # CSR lists which observations belong to each category, allowing mask + # handling to be folded in naturally. 
+ n_cats = indicator.shape[0] + n_features = data.shape[1] + mean = np.zeros((n_cats, n_features), dtype="float64") + var = np.zeros((n_cats, n_features), dtype="float64") + + for cat in numba.prange(n_cats): + start = indicator.indptr[cat] + stop = indicator.indptr[cat + 1] + n = 0 + for row_num in range(start, stop): + obs = indicator.indices[row_num] + n += 1 + for col in range(n_features): + value = np.float64(data[obs, col]) + delta = value - mean[cat, col] + mean[cat, col] += delta / n + delta2 = value - mean[cat, col] + var[cat, col] += delta * delta2 + if n > 0: + for col in range(n_features): + var[cat, col] /= n + return mean, var + + @njit def mean_var_csr( indicator: CSRBase, data: CSCBase, ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") - var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") - - for cat_num in numba.prange(indicator.shape[0]): + # Welford's online algorithm over nonzeros, then merge with the block of + # implicit zeros per (category, feature). 
Merging a Welford accumulator + # (n_A, mean_A, M2_A) with k zeros gives: + # mean_new = mean_A * n_A / (n_A + k) + # M2_new = M2_A + mean_A^2 * n_A * k / (n_A + k) + n_cats = indicator.shape[0] + n_features = data.shape[1] + mean = np.zeros((n_cats, n_features), dtype="float64") + var = np.zeros((n_cats, n_features), dtype="float64") + + for cat_num in numba.prange(n_cats): start_cat_idx = indicator.indptr[cat_num] stop_cat_idx = indicator.indptr[cat_num + 1] + n_obs = stop_cat_idx - start_cat_idx + if n_obs == 0: + continue + + n_nonzero = np.zeros(n_features, dtype=np.int64) + for row_num in range(start_cat_idx, stop_cat_idx): obs_per_cat = indicator.indices[row_num] - start_obs = data.indptr[obs_per_cat] end_obs = data.indptr[obs_per_cat + 1] for j in range(start_obs, end_obs): col = data.indices[j] value = np.float64(data.data[j]) - value = data.data[j] - mean[cat_num, col] += value - var[cat_num, col] += value * value - - n_obs = stop_cat_idx - start_cat_idx - mean_cat = mean[cat_num, :] / n_obs - mean[cat_num, :] = mean_cat - var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat) + n_nonzero[col] += 1 + n = n_nonzero[col] + delta = value - mean[cat_num, col] + mean[cat_num, col] += delta / n + delta2 = value - mean[cat_num, col] + var[cat_num, col] += delta * delta2 + + for col in range(n_features): + n_nz = n_nonzero[col] + k = n_obs - n_nz + if k > 0 and n_nz > 0: + mean_a = mean[cat_num, col] + mean[cat_num, col] = mean_a * n_nz / n_obs + var[cat_num, col] += mean_a * mean_a * n_nz * k / n_obs + var[cat_num, col] /= n_obs return mean, var @@ -83,34 +133,49 @@ def mean_var_csr( def mean_var_csc( indicator: CSRBase, data: CSCBase ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + # Welford's online algorithm, parallelized over columns. For each column + # we accumulate per-category over the explicit nonzeros, then merge each + # category's accumulator with its block of implicit zeros (see merge + # formula in `mean_var_csr`). 
+ n_cats = indicator.shape[0] + n_features = data.shape[1] obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) - - mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") - var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") - - for cat in range(indicator.shape[0]): + n_obs_per_cat = np.zeros(n_cats, dtype=np.int64) + for cat in range(n_cats): + n_obs_per_cat[cat] = indicator.indptr[cat + 1] - indicator.indptr[cat] for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]): obs_to_cat[indicator.indices[k]] = cat - for col in numba.prange(data.shape[1]): + mean = np.zeros((n_cats, n_features), dtype="float64") + var = np.zeros((n_cats, n_features), dtype="float64") + + for col in numba.prange(n_features): + n_nonzero = np.zeros(n_cats, dtype=np.int64) start = data.indptr[col] end = data.indptr[col + 1] for j in range(start, end): obs = data.indices[j] cat = obs_to_cat[obs] - - if cat != -1: - value = np.float64(data.data[j]) - value = data.data[j] - mean[cat, col] += value - var[cat, col] += value * value - - for cat_num in numba.prange(indicator.shape[0]): - start_cat_idx = indicator.indptr[cat_num] - stop_cat_idx = indicator.indptr[cat_num + 1] - n_obs = stop_cat_idx - start_cat_idx - mean_cat = mean[cat_num, :] / n_obs - mean[cat_num, :] = mean_cat - var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat) + if cat == -1: + continue + value = np.float64(data.data[j]) + n_nonzero[cat] += 1 + n = n_nonzero[cat] + delta = value - mean[cat, col] + mean[cat, col] += delta / n + delta2 = value - mean[cat, col] + var[cat, col] += delta * delta2 + + for cat in range(n_cats): + n_obs = n_obs_per_cat[cat] + if n_obs == 0: + continue + n_nz = n_nonzero[cat] + k = n_obs - n_nz + if k > 0 and n_nz > 0: + mean_a = mean[cat, col] + mean[cat, col] = mean_a * n_nz / n_obs + var[cat, col] += mean_a * mean_a * n_nz * k / n_obs + var[cat, col] /= n_obs return mean, var