Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/4147.perf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use [Welford's algorithm][] for mean-var calculation in {func}`scanpy.get.aggregate` for in-memory (i.e., non-dask) arrays {smaller}`I Gold`
27 changes: 10 additions & 17 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
from scanpy._compat import CSBase, CSRBase, DaskArray

from .._utils import _resolve_axis, get_literal_vals
from ._kernels import agg_sum_csc, agg_sum_csr, mean_var_csc, mean_var_csr
from ._kernels import (
agg_sum_csc,
agg_sum_csr,
mean_var_csc,
mean_var_csr,
mean_var_dense,
)
from .get import _check_mask

if TYPE_CHECKING:
Expand Down Expand Up @@ -117,11 +123,8 @@ def mean(self) -> Array:
def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
"""Compute the count, as well as mean and variance per feature, per group of observations.

The formula `Var(X) = E(X^2) - E(X)^2` suffers loss of precision when the variance is a
very small fraction of the squared mean. In particular, when X is constant, the formula may
nonetheless be non-zero. By default, our implementation resets the variance to exactly zero
when the computed variance, relative to the squared mean, nears limit of precision of the
floating-point significand.
Mean and variance are computed with Welford's online algorithm, which is
numerically stable for constant or near-constant inputs.

Params
------
Expand All @@ -137,21 +140,11 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:

group_counts = np.bincount(self.groupby.codes)
if isinstance(self.data, np.ndarray):
mean_ = self.mean()
# sparse matrices do not support ** for elementwise power.
mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None]
sq_mean = mean_**2
var_ = mean_sq - sq_mean
mean_, var_ = mean_var_dense(self.indicator_matrix.tocsr(), self.data)
else:
mean_, var_ = (
mean_var_csr if isinstance(self.data, CSRBase) else mean_var_csc
)(self.indicator_matrix, self.data)
sq_mean = mean_**2
# TODO: Why these values exactly? Because they are high relative to the datatype?
# (unchanged from original code: https://github.com/scverse/anndata/pull/564)
precision = 2 << (42 if self.data.dtype == np.float64 else 20)
# detects loss of precision in mean_sq - sq_mean, which suggests variance is 0
var_[precision * var_ < sq_mean] = 0
if dof != 0:
var_ *= (group_counts / (group_counts - dof))[:, np.newaxis]
return mean_, var_
Expand Down
131 changes: 98 additions & 33 deletions src/scanpy/get/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,69 +48,134 @@ def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray) -> None:
out[cat, col] += data.data[j]


@njit

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should make these kernels `nogil`, or add an option so that fau can supply a `nogil` njit.

def mean_var_dense(
    indicator: CSRBase, data: NDArray
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute per-category mean and variance of a dense array.

    Uses Welford's online algorithm, parallelised over categories. Only the
    sparsity structure of `indicator` (`indptr` / `indices`) is read: row
    `c` lists the observations that belong to category `c`. The stored
    values of `indicator` are not used.

    Params
    ------
    indicator
        CSR matrix with one row per category; its column indices select
        rows of `data`.
    data
        Dense 2D array indexed as `data[observation, feature]`.

    Returns
    -------
    A `(mean, var)` pair of float64 arrays, each with one row per category
    and one column per feature. `var` is the sum of squared deviations
    divided by the number of observations in the category (no
    degrees-of-freedom correction). Rows of categories without any
    observation stay zero.
    """
    n_groups = indicator.shape[0]
    n_cols = data.shape[1]
    mean = np.zeros((n_groups, n_cols), dtype="float64")
    var = np.zeros((n_groups, n_cols), dtype="float64")

    for group in numba.prange(n_groups):
        first = indicator.indptr[group]
        last = indicator.indptr[group + 1]
        for pos in range(first, last):
            row = indicator.indices[pos]
            # 1-based count of the observations folded in so far.
            seen = pos - first + 1
            for j in range(n_cols):
                x = np.float64(data[row, j])
                before = x - mean[group, j]
                mean[group, j] += before / seen
                # `var` holds the running sum of squared deviations (M2)
                # until it is normalised below.
                var[group, j] += before * (x - mean[group, j])
        size = last - first
        if size > 0:
            for j in range(n_cols):
                var[group, j] /= size
    return mean, var


@njit
def mean_var_csr(
    indicator: CSRBase,
    data: CSRBase,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute per-category mean and variance of a CSR matrix.

    Welford's online algorithm is run over the explicitly stored entries of
    each (category, feature) pair, parallelised over categories. The result
    is then merged with the block of implicit zeros. Merging a Welford
    accumulator `(n_A, mean_A, M2_A)` with `k` zeros gives::

        mean   = mean_A * n_A / (n_A + k)
        M2_new = M2_A + mean_A^2 * n_A * k / (n_A + k)

    Only the sparsity structure of `indicator` (`indptr` / `indices`) is
    read; its stored values are not used.

    Params
    ------
    indicator
        CSR matrix with one row per category; its column indices select
        rows (observations) of `data`.
    data
        CSR matrix indexed as observations x features.

    Returns
    -------
    A `(mean, var)` pair of float64 arrays, each with one row per category
    and one column per feature. `var` is the sum of squared deviations
    divided by the number of observations in the category (no
    degrees-of-freedom correction). Rows of categories without any
    observation stay zero.
    """
    n_cats = indicator.shape[0]
    n_features = data.shape[1]
    mean = np.zeros((n_cats, n_features), dtype="float64")
    var = np.zeros((n_cats, n_features), dtype="float64")

    for cat_num in numba.prange(n_cats):
        start_cat_idx = indicator.indptr[cat_num]
        stop_cat_idx = indicator.indptr[cat_num + 1]
        n_obs = stop_cat_idx - start_cat_idx
        if n_obs == 0:
            continue

        # Number of explicitly stored entries seen so far, per feature.
        n_nonzero = np.zeros(n_features, dtype=np.int64)

        for row_num in range(start_cat_idx, stop_cat_idx):
            obs_per_cat = indicator.indices[row_num]

            start_obs = data.indptr[obs_per_cat]
            end_obs = data.indptr[obs_per_cat + 1]

            for j in range(start_obs, end_obs):
                col = data.indices[j]
                value = np.float64(data.data[j])
                n_nonzero[col] += 1
                n = n_nonzero[col]
                delta = value - mean[cat_num, col]
                mean[cat_num, col] += delta / n
                delta2 = value - mean[cat_num, col]
                # `var` holds the running sum of squared deviations (M2)
                # until it is normalised below.
                var[cat_num, col] += delta * delta2

        for col in range(n_features):
            n_nz = n_nonzero[col]
            k = n_obs - n_nz  # number of implicit zeros
            if k > 0 and n_nz > 0:
                # Fold the implicit zeros into the accumulator.
                mean_a = mean[cat_num, col]
                mean[cat_num, col] = mean_a * n_nz / n_obs
                var[cat_num, col] += mean_a * mean_a * n_nz * k / n_obs
            var[cat_num, col] /= n_obs
    return mean, var


@njit
def mean_var_csc(
    indicator: CSRBase, data: CSCBase
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute per-category mean and variance of a CSC matrix.

    Welford's online algorithm, parallelised over columns (features). For
    each column the explicitly stored entries are accumulated per category,
    then each category's accumulator `(n_A, mean_A, M2_A)` is merged with
    its block of `k` implicit zeros::

        mean   = mean_A * n_A / (n_A + k)
        M2_new = M2_A + mean_A^2 * n_A * k / (n_A + k)

    Only the sparsity structure of `indicator` (`indptr` / `indices`) is
    read; its stored values are not used.

    Params
    ------
    indicator
        CSR matrix with one row per category; its column indices are
        observation indices into `data`.
    data
        CSC matrix indexed as observations x features.

    Returns
    -------
    A `(mean, var)` pair of float64 arrays, each with one row per category
    and one column per feature. `var` is the sum of squared deviations
    divided by the number of observations in the category (no
    degrees-of-freedom correction). Rows of categories without any
    observation stay zero.
    """
    n_cats = indicator.shape[0]
    n_features = data.shape[1]

    # Invert the indicator: observation -> category, with -1 for
    # observations listed in no category.
    # NOTE(review): if an observation were listed under several categories,
    # the last one would win here - presumably each observation has at most
    # one category; confirm against the caller.
    obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64)
    n_obs_per_cat = np.zeros(n_cats, dtype=np.int64)
    for cat in range(n_cats):
        n_obs_per_cat[cat] = indicator.indptr[cat + 1] - indicator.indptr[cat]
        for idx in range(indicator.indptr[cat], indicator.indptr[cat + 1]):
            obs_to_cat[indicator.indices[idx]] = cat

    mean = np.zeros((n_cats, n_features), dtype="float64")
    var = np.zeros((n_cats, n_features), dtype="float64")

    for col in numba.prange(n_features):
        # Number of explicitly stored entries seen so far, per category.
        n_nonzero = np.zeros(n_cats, dtype=np.int64)
        start = data.indptr[col]
        end = data.indptr[col + 1]

        for j in range(start, end):
            obs = data.indices[j]
            cat = obs_to_cat[obs]
            if cat == -1:
                continue
            value = np.float64(data.data[j])
            n_nonzero[cat] += 1
            n = n_nonzero[cat]
            delta = value - mean[cat, col]
            mean[cat, col] += delta / n
            delta2 = value - mean[cat, col]
            # `var` holds the running sum of squared deviations (M2) until
            # it is normalised below.
            var[cat, col] += delta * delta2

        for cat in range(n_cats):
            n_obs = n_obs_per_cat[cat]
            if n_obs == 0:
                continue
            n_nz = n_nonzero[cat]
            k = n_obs - n_nz  # number of implicit zeros
            if k > 0 and n_nz > 0:
                # Fold the implicit zeros into the accumulator.
                mean_a = mean[cat, col]
                mean[cat, col] = mean_a * n_nz / n_obs
                var[cat, col] += mean_a * mean_a * n_nz * k / n_obs
            var[cat, col] /= n_obs
    return mean, var
Loading