From de9c481a14dde3208830d00e7a0b67c2adceb9ce Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 23 Feb 2026 16:26:42 +0100 Subject: [PATCH 01/20] feat: support np.random.Generator --- src/scanpy/_utils/random.py | 29 ++++++---- src/scanpy/datasets/_datasets.py | 9 ++-- src/scanpy/experimental/_docs.py | 4 +- src/scanpy/experimental/pp/_normalization.py | 5 +- src/scanpy/experimental/pp/_recipes.py | 6 ++- src/scanpy/neighbors/__init__.py | 29 +++++----- src/scanpy/preprocessing/_pca/__init__.py | 33 ++++++------ src/scanpy/preprocessing/_pca/_compat.py | 16 +++--- src/scanpy/preprocessing/_recipes.py | 6 +-- .../preprocessing/_scrublet/__init__.py | 53 +++++++++---------- src/scanpy/preprocessing/_scrublet/core.py | 22 ++++---- .../preprocessing/_scrublet/pipeline.py | 22 ++++---- .../preprocessing/_scrublet/sparse_utils.py | 7 +-- src/scanpy/preprocessing/_simple.py | 46 ++++++---------- src/scanpy/preprocessing/_utils.py | 7 +-- src/scanpy/tools/_diffmap.py | 15 +++--- src/scanpy/tools/_dpt.py | 13 +++-- src/scanpy/tools/_draw_graph.py | 36 ++++++------- src/scanpy/tools/_leiden.py | 16 +++--- src/scanpy/tools/_score_genes.py | 19 +++---- src/scanpy/tools/_tsne.py | 9 ++-- src/scanpy/tools/_umap.py | 29 +++++----- src/scanpy/tools/_utils.py | 9 ++-- 23 files changed, 223 insertions(+), 217 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index 98d6ce8a1b..aa405c0ce8 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING import numpy as np -from sklearn.utils import check_random_state from . import ensure_igraph @@ -23,6 +22,7 @@ "_LegacyRandom", "ith_k_tuple", "legacy_numpy_gen", + "legacy_random_state", "random_k_tuples", "random_str", ] @@ -43,29 +43,29 @@ class _RNGIgraph: See :func:`igraph.set_random_number_generator` for the requirements. 
""" - def __init__(self, random_state: int | np.random.RandomState = 0) -> None: - self._rng = check_random_state(random_state) + def __init__(self, rng: SeedLike | RNGLike | None) -> None: + self._rng = np.random.default_rng(rng) def getrandbits(self, k: int) -> int: - return self._rng.tomaxint() & ((1 << k) - 1) + lims = np.iinfo(np.uint64) + i = int(self._rng.integers(0, lims.max, dtype=np.uint64)) + return i & ((1 << k) - 1) - def randint(self, a: int, b: int) -> int: - return self._rng.randint(a, b + 1) + def randint(self, a: int, b: int) -> np.int64: + return self._rng.integers(a, b + 1) def __getattr__(self, attr: str): return getattr(self._rng, "normal" if attr == "gauss" else attr) @contextmanager -def set_igraph_random_state( - random_state: int | np.random.RandomState, -) -> Generator[None, None, None]: +def set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: ensure_igraph() import igraph - rng = _RNGIgraph(random_state) + ig_rng = _RNGIgraph(rng) try: - igraph.set_random_number_generator(rng) + igraph.set_random_number_generator(ig_rng) yield None finally: igraph.set_random_number_generator(random) @@ -114,6 +114,13 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs): _FakeRandomGen._delegate() +def legacy_random_state(rng: SeedLike | RNGLike | None) -> np.random.RandomState: + rng = np.random.default_rng(rng) + if isinstance(rng, _FakeRandomGen): + return rng._state + return np.random.RandomState(rng.bit_generator.spawn(1)[0]) + + ################### # Random k-tuples # ################### diff --git a/src/scanpy/datasets/_datasets.py b/src/scanpy/datasets/_datasets.py index d6221fe8d2..116c8a572d 100644 --- a/src/scanpy/datasets/_datasets.py +++ b/src/scanpy/datasets/_datasets.py @@ -12,13 +12,14 @@ from .._compat import deprecated, old_positionals from .._settings import settings from .._utils._doctests import doctest_internet, doctest_needs +from .._utils.random import legacy_random_state from ..readwrite import read, read_h5ad, 
read_visium from ._utils import check_datasetdir_exists if TYPE_CHECKING: from typing import Literal - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike type VisiumSampleID = Literal[ "V1_Breast_Cancer_Block_A_Section_1", @@ -63,7 +64,7 @@ def blobs( n_centers: int = 5, cluster_std: float = 1.0, n_observations: int = 640, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> AnnData: """Gaussian Blobs. @@ -78,7 +79,7 @@ def blobs( n_observations Number of observations. By default, this is the same observation number as in :func:`scanpy.datasets.krumsiek11`. - random_state + rng Determines random number generation for dataset creation. Returns @@ -101,7 +102,7 @@ def blobs( n_features=n_variables, centers=n_centers, cluster_std=cluster_std, - random_state=random_state, + random_state=legacy_random_state(rng), ) return AnnData(x, obs=dict(blobs=y.astype(str))) diff --git a/src/scanpy/experimental/_docs.py b/src/scanpy/experimental/_docs.py index c6f1bf2f8b..a449c317f1 100644 --- a/src/scanpy/experimental/_docs.py +++ b/src/scanpy/experimental/_docs.py @@ -60,8 +60,8 @@ doc_pca_chunk = """\ n_comps Number of principal components to compute in the PCA step. -random_state - Random seed for setting the initial states for the optimization in the PCA step. +rng + Random number generator for setting the initial states for the optimization in the PCA step. kwargs_pca Dictionary of further keyword arguments passed on to `scanpy.pp.pca()`. 
""" diff --git a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index cb34b9902b..eddb899bb5 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -27,6 +27,7 @@ from typing import Any from ..._utils import Empty + from ..._utils.random import RNGLike, SeedLike def _pearson_residuals( @@ -166,7 +167,7 @@ def normalize_pearson_residuals_pca( theta: float = 100, clip: float | None = None, n_comps: int | None = 50, - random_state: float = 0, + rng: SeedLike | RNGLike | None = None, kwargs_pca: Mapping[str, Any] = MappingProxyType({}), mask_var: np.ndarray | str | None | Empty = _empty, use_highly_variable: bool | None = None, @@ -233,7 +234,7 @@ def normalize_pearson_residuals_pca( normalize_pearson_residuals( adata_pca, theta=theta, clip=clip, check_values=check_values ) - pca(adata_pca, n_comps=n_comps, random_state=random_state, **kwargs_pca) + pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca) n_comps = adata_pca.obsm["X_pca"].shape[1] # might be None if inplace: diff --git a/src/scanpy/experimental/pp/_recipes.py b/src/scanpy/experimental/pp/_recipes.py index 27d272fc4d..5b0357dc4d 100644 --- a/src/scanpy/experimental/pp/_recipes.py +++ b/src/scanpy/experimental/pp/_recipes.py @@ -24,6 +24,8 @@ import pandas as pd from anndata import AnnData + from ..._utils.random import RNGLike, SeedLike + @_doc_params( adata=doc_adata, @@ -42,7 +44,7 @@ def recipe_pearson_residuals( # noqa: PLR0913 batch_key: str | None = None, chunksize: int = 1000, n_comps: int | None = 50, - random_state: float | None = 0, + rng: SeedLike | RNGLike | None = None, kwargs_pca: Mapping[str, Any] = MappingProxyType({}), check_values: bool = True, inplace: bool = True, @@ -133,7 +135,7 @@ def recipe_pearson_residuals( # noqa: PLR0913 experimental.pp.normalize_pearson_residuals( adata_pca, theta=theta, clip=clip, check_values=check_values ) - pca(adata_pca, n_comps=n_comps, 
random_state=random_state, **kwargs_pca) + pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca) if inplace: normalization_param = adata_pca.uns["pearson_residuals_normalization"] diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 3efb392a1d..4fa2b5d597 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -11,13 +11,13 @@ import numpy as np import scipy from scipy import sparse -from sklearn.utils import check_random_state from .. import _utils from .. import logging as logg from .._compat import CSBase, CSRBase, SpBase, old_positionals, warn from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals +from .._utils.random import legacy_random_state from . import _connectivity from ._common import ( _get_indices_distances_from_dense_matrix, @@ -36,7 +36,7 @@ from igraph import Graph from numpy.typing import NDArray - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike, _LegacyRandom from ._types import KnnTransformerLike, _Metric, _MetricFn # TODO: make `type` when https://github.com/sphinx-doc/sphinx/pull/13508 is released @@ -90,7 +90,7 @@ def neighbors( # noqa: PLR0913 transformer: KnnTransformerLike | _KnownTransformer | None = None, metric: _Metric | _MetricFn | None = None, metric_kwds: Mapping[str, Any] = MappingProxyType({}), - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, key_added: str | None = None, copy: bool = False, ) -> AnnData | None: @@ -158,8 +158,8 @@ def neighbors( # noqa: PLR0913 Options for the metric. *ignored if ``transformer`` is an instance.* - random_state - A numpy random seed. + rng + A numpy random number generator. 
*ignored if ``transformer`` is an instance.* key_added @@ -220,14 +220,14 @@ def neighbors( # noqa: PLR0913 transformer=transformer, metric=metric, metric_kwds=metric_kwds, - random_state=random_state, + rng=rng, ) else: params = locals() if ignored := { p.name for p in signature(neighbors).parameters.values() - if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds", "random_state"} + if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds", "rng"} if params[p.name] != p.default }: warn( @@ -262,7 +262,7 @@ def neighbors( # noqa: PLR0913 key_added, n_neighbors=neighbors_.n_neighbors, method=method, - random_state=random_state, + random_state=rng, metric=metric, **({} if not metric_kwds else dict(metric_kwds=metric_kwds)), **({} if use_rep is None else dict(use_rep=use_rep)), @@ -583,7 +583,7 @@ def compute_neighbors( transformer: KnnTransformerLike | _KnownTransformer | None = None, metric: _Metric | _MetricFn = "euclidean", metric_kwds: Mapping[str, Any] = MappingProxyType({}), - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> None: """Compute distances and connectivities of neighbors. @@ -619,7 +619,7 @@ def compute_neighbors( n_neighbors=n_neighbors, metric=metric, metric_params=metric_kwds, # most use _params, not _kwds - random_state=random_state, + random_state=legacy_random_state(rng), ) method, transformer, shortcut = self._handle_transformer( method, transformer, knn=knn, kwds=transformer_kwds_default @@ -848,7 +848,7 @@ def compute_eigen( n_comps: int = 15, sym: bool | None = None, sort: Literal["decrease", "increase"] = "decrease", - random_state: _LegacyRandom = 0, + rng: np.random.Generator, ): """Compute eigen decomposition of transition matrix. @@ -861,8 +861,8 @@ def compute_eigen( Instead of computing the eigendecomposition of the assymetric transition matrix, computed the eigendecomposition of the symmetric Ktilde matrix. 
- random_state - A numpy random seed + rng + A numpy random number generator Returns ------- @@ -895,8 +895,7 @@ def compute_eigen( matrix = matrix.astype(np.float64) # Setting the random initial vector - random_state = check_random_state(random_state) - v0 = random_state.standard_normal(matrix.shape[0]) + v0 = rng.standard_normal(matrix.shape[0]) evals, evecs = sparse.linalg.eigsh( matrix, k=n_comps, which=which, ncv=ncv, v0=v0 ) diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index b370c09d7b..87ad1752ee 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -5,12 +5,12 @@ import numpy as np from anndata import AnnData from packaging.version import Version -from sklearn.utils import check_random_state from ... import logging as logg from ..._compat import CSBase, DaskArray, pkg_version, warn from ..._settings import settings from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type +from ..._utils.random import legacy_random_state from ...get import _check_mask, _get_obs_rep from .._docs import doc_mask_var_hvg from ._compat import _pca_compat_sparse @@ -25,7 +25,7 @@ from numpy.typing import DTypeLike, NDArray from ..._utils import Empty - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike type MethodDaskML = type[dmld.PCA | dmld.IncrementalPCA | dmld.TruncatedSVD] @@ -64,7 +64,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 svd_solver: SvdSolver | None = None, chunked: bool = False, chunk_size: int | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, return_info: bool = False, mask_var: NDArray[np.bool_] | str | None | Empty = _empty, use_highly_variable: bool | None = None, @@ -157,7 +157,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 chunk_size Number of observations to include in each chunk. Required if `chunked=True` was passed. 
- random_state + rng Change to use different initial states for the optimization. return_info Only relevant when not passing an :class:`~anndata.AnnData`: @@ -241,18 +241,17 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 msg = f"PCA is not implemented for matrices of type {type(x)} from layers/obsm" raise NotImplementedError(msg) - # check_random_state returns a numpy RandomState when passed an int but # dask needs an int for random state if not isinstance(x, DaskArray): - random_state = check_random_state(random_state) - elif not isinstance(random_state, int): - msg = f"random_state needs to be an int, not a {type(random_state).__name__} when passing a dask array" + rng = np.random.default_rng(rng) + elif not isinstance(rng, int): + msg = f"rng needs to be an int, not a {type(rng).__name__} when passing a dask array" raise TypeError(msg) if chunked: if ( not zero_center - or random_state + or rng is not None or (svd_solver is not None and svd_solver != "arpack") ): logg.debug("Ignoring zero_center, random_state, svd_solver") @@ -287,9 +286,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 "Also the lobpcg solver has been observed to be inaccurate. Please use 'arpack' instead." 
) warn(msg, FutureWarning) - x_pca, pca_ = _pca_compat_sparse( - x, n_comps, solver=svd_solver, random_state=random_state - ) + x_pca, pca_ = _pca_compat_sparse(x, n_comps, solver=svd_solver, rng=rng) else: if not isinstance(x, DaskArray): from sklearn.decomposition import PCA @@ -300,13 +297,13 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 pca_ = PCA( n_components=n_comps, svd_solver=svd_solver, - random_state=random_state, + random_state=legacy_random_state(rng), ) elif isinstance(x._meta, CSBase) or svd_solver == "covariance_eigh": from ._dask import PCAEighDask - if random_state != 0: - msg = f"Ignoring {random_state=} when using a sparse dask array" + if rng is not None: + msg = f"Ignoring {rng=} when using a sparse dask array" warn(msg, UserWarning) if svd_solver not in {None, "covariance_eigh"}: msg = f"Ignoring {svd_solver=} when using a sparse dask array" @@ -319,7 +316,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 pca_ = PCA( n_components=n_comps, svd_solver=svd_solver, - random_state=random_state, + random_state=legacy_random_state(rng), ) x_pca = pca_.fit_transform(x) else: @@ -345,7 +342,9 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 " the following components often resemble the exact PCA very closely" ) pca_ = TruncatedSVD( - n_components=n_comps, random_state=random_state, algorithm=svd_solver + n_components=n_comps, + random_state=legacy_random_state(rng), + algorithm=svd_solver, ) x_pca = pca_.fit_transform(x) diff --git a/src/scanpy/preprocessing/_pca/_compat.py b/src/scanpy/preprocessing/_pca/_compat.py index b1c64e0735..133be7b0f0 100644 --- a/src/scanpy/preprocessing/_pca/_compat.py +++ b/src/scanpy/preprocessing/_pca/_compat.py @@ -6,10 +6,11 @@ from fast_array_utils.stats import mean_var from packaging.version import Version from scipy.sparse.linalg import LinearOperator, svds -from sklearn.utils import check_array, check_random_state +from sklearn.utils import check_array from sklearn.utils.extmath import svd_flip from ..._compat import 
pkg_version +from ..._utils.random import legacy_random_state if TYPE_CHECKING: from typing import Literal @@ -18,7 +19,7 @@ from sklearn.decomposition import PCA from ..._compat import CSBase - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike def _pca_compat_sparse( @@ -27,12 +28,11 @@ def _pca_compat_sparse( *, solver: Literal["arpack", "lobpcg"], mu: NDArray[np.floating] | None = None, - random_state: _LegacyRandom = None, + rng: SeedLike | RNGLike | None = None, ) -> tuple[NDArray[np.floating], PCA]: """Sparse PCA for scikit-learn <1.4.""" - random_state = check_random_state(random_state) - np.random.set_state(random_state.get_state()) - random_init = np.random.rand(np.min(x.shape)) + rng = np.random.default_rng(rng) + random_init = rng.uniform(size=np.min(x.shape)) x = check_array(x, accept_sparse=["csr", "csc"]) if mu is None: @@ -70,7 +70,9 @@ def rmat_op(v: NDArray[np.floating]): from sklearn.decomposition import PCA - pca = PCA(n_components=n_pcs, svd_solver=solver, random_state=random_state) + pca = PCA( + n_components=n_pcs, svd_solver=solver, random_state=legacy_random_state(rng) + ) pca.explained_variance_ = ev pca.explained_variance_ratio_ = ev_ratio pca.components_ = v diff --git a/src/scanpy/preprocessing/_recipes.py b/src/scanpy/preprocessing/_recipes.py index d223036873..34cf986d57 100644 --- a/src/scanpy/preprocessing/_recipes.py +++ b/src/scanpy/preprocessing/_recipes.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike @old_positionals( @@ -36,7 +36,7 @@ def recipe_weinreb17( cv_threshold: int = 2, n_pcs: int = 50, svd_solver="randomized", - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, copy: bool = False, ) -> AnnData | None: """Normalize and filter as of :cite:p:`Weinreb2017`. 
@@ -72,7 +72,7 @@ def recipe_weinreb17( zscore_deprecated(adata.X), n_comps=n_pcs, svd_solver=svd_solver, - random_state=random_state, + rng=rng, ) # update adata adata.obsm["X_pca"] = x_pca diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index b8ea1a267c..b5c7d60a55 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -16,7 +16,7 @@ from .core import Scrublet if TYPE_CHECKING: - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike from ...neighbors import _Metric, _MetricFn @@ -59,7 +59,7 @@ def scrublet( # noqa: PLR0913 threshold: float | None = None, verbose: bool = True, copy: bool = False, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> AnnData | None: """Predict doublets using Scrublet :cite:p:`Wolock2019`. @@ -147,7 +147,7 @@ def scrublet( # noqa: PLR0913 copy If :data:`True`, return a copy of the input ``adata`` with Scrublet results added. Otherwise, Scrublet results are added in place. - random_state + rng Initial state for doublet simulation and nearest neighbors. Returns @@ -178,6 +178,7 @@ def scrublet( # noqa: PLR0913 scores for observed transcriptomes and simulated doublets. 
""" + rng = np.random.default_rng(rng) if threshold is None and not find_spec("skimage"): # pragma: no cover # Scrublet.call_doublets requires `skimage` with `threshold=None` but PCA # is called early, which is wasteful if there is not `skimage` @@ -224,7 +225,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): layer="raw", sim_doublet_ratio=sim_doublet_ratio, synthetic_doublet_umi_subsampling=synthetic_doublet_umi_subsampling, - random_seed=random_state, + rng=rng, ) del ad_obs.layers["raw"] if log_transform: @@ -249,7 +250,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): knn_dist_metric=knn_dist_metric, get_doublet_neighbor_parents=get_doublet_neighbor_parents, threshold=threshold, - random_state=random_state, + rng=rng, verbose=verbose, ) @@ -307,18 +308,18 @@ def _scrublet_call_doublets( # noqa: PLR0913 adata_obs: AnnData, adata_sim: AnnData, *, - n_neighbors: int | None = None, - expected_doublet_rate: float = 0.05, - stdev_doublet_rate: float = 0.02, - mean_center: bool = True, - normalize_variance: bool = True, - n_prin_comps: int = 30, - use_approx_neighbors: bool | None = None, - knn_dist_metric: _Metric | _MetricFn = "euclidean", - get_doublet_neighbor_parents: bool = False, - threshold: float | None = None, - random_state: _LegacyRandom = 0, - verbose: bool = True, + n_neighbors: int | None, + expected_doublet_rate: float, + stdev_doublet_rate: float, + mean_center: bool, + normalize_variance: bool, + n_prin_comps: int, + use_approx_neighbors: bool | None, + knn_dist_metric: _Metric | _MetricFn, + get_doublet_neighbor_parents: bool, + threshold: float | None, + rng: np.random.Generator, + verbose: bool, ) -> AnnData: """Core function for predicting doublets using Scrublet :cite:p:`Wolock2019`. @@ -376,8 +377,8 @@ def _scrublet_call_doublets( # noqa: PLR0913 practice to check the threshold visually using the `doublet_scores_sim_` histogram and/or based on co-localization of predicted doublets in a 2-D embedding. 
- random_state - Initial state for doublet simulation and nearest neighbors. + rng + Random number generator for doublet simulation and nearest neighbors. verbose If :data:`True`, log progress updates. @@ -414,7 +415,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 n_neighbors=n_neighbors, expected_doublet_rate=expected_doublet_rate, stdev_doublet_rate=stdev_doublet_rate, - random_state=random_state, + rng=rng, ) # Ensure normalised matrix sparseness as Scrublet does @@ -440,12 +441,10 @@ def _scrublet_call_doublets( # noqa: PLR0913 if mean_center: logg.info("Embedding transcriptomes using PCA...") - pipeline.pca(scrub, n_prin_comps=n_prin_comps, random_state=scrub._random_state) + pipeline.pca(scrub, n_prin_comps=n_prin_comps, rng=scrub._rng) else: logg.info("Embedding transcriptomes using Truncated SVD...") - pipeline.truncated_svd( - scrub, n_prin_comps=n_prin_comps, random_state=scrub._random_state - ) + pipeline.truncated_svd(scrub, n_prin_comps=n_prin_comps, rng=scrub._rng) # Score the doublets @@ -477,7 +476,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 .get("sim_doublet_ratio", None) ), "n_neighbors": n_neighbors, - "random_state": random_state, + "rng": rng, }, } @@ -511,7 +510,7 @@ def scrublet_simulate_doublets( layer: str | None = None, sim_doublet_ratio: float = 2.0, synthetic_doublet_umi_subsampling: float = 1.0, - random_seed: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, ) -> AnnData: """Simulate doublets by adding the counts of random observed transcriptome pairs. 
@@ -556,7 +555,7 @@ def scrublet_simulate_doublets( """ x = _get_obs_rep(adata, layer=layer) - scrub = Scrublet(x, random_state=random_seed) + scrub = Scrublet(x, rng=rng) scrub.simulate_doublets( sim_doublet_ratio=sim_doublet_ratio, diff --git a/src/scanpy/preprocessing/_scrublet/core.py b/src/scanpy/preprocessing/_scrublet/core.py index a5cb80fd43..511feac334 100644 --- a/src/scanpy/preprocessing/_scrublet/core.py +++ b/src/scanpy/preprocessing/_scrublet/core.py @@ -7,7 +7,6 @@ import pandas as pd from anndata import AnnData, concat from scipy import sparse -from sklearn.utils import check_random_state from ... import logging as logg from ...neighbors import ( @@ -18,11 +17,10 @@ from .sparse_utils import subsample_counts if TYPE_CHECKING: - from numpy.random import RandomState from numpy.typing import NDArray from ..._compat import CSBase, CSCBase - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike from ...neighbors import _Metric, _MetricFn __all__ = ["Scrublet"] @@ -58,8 +56,8 @@ class Scrublet: stdev_doublet_rate Uncertainty in the expected doublet rate. - random_state - Random state for doublet simulation, approximate + rng + Random number generator for doublet simulation, approximate nearest neighbor search, and PCA/TruncatedSVD. 
""" @@ -72,12 +70,12 @@ class Scrublet: n_neighbors: InitVar[int | None] = None expected_doublet_rate: float = 0.1 stdev_doublet_rate: float = 0.02 - random_state: InitVar[_LegacyRandom] = 0 + rng: InitVar[SeedLike | RNGLike | None] = None # private fields _n_neighbors: int = field(init=False, repr=False) - _random_state: RandomState = field(init=False, repr=False) + _rng: np.random.Generator = field(init=False, repr=False) _counts_obs: CSCBase = field(init=False, repr=False) _total_counts_obs: NDArray[np.integer] = field(init=False, repr=False) @@ -169,7 +167,7 @@ def __post_init__( counts_obs: CSBase | NDArray[np.integer], total_counts_obs: NDArray[np.integer] | None, n_neighbors: int | None, - random_state: _LegacyRandom, + rng: SeedLike | RNGLike | None, ) -> None: self._counts_obs = sparse.csc_matrix(counts_obs) # noqa: TID251 self._total_counts_obs = ( @@ -182,7 +180,7 @@ def __post_init__( if n_neighbors is None else n_neighbors ) - self._random_state = check_random_state(random_state) + self._rng = np.random.default_rng(rng) def simulate_doublets( self, @@ -218,7 +216,7 @@ def simulate_doublets( n_obs = self._counts_obs.shape[0] n_sim = int(n_obs * sim_doublet_ratio) - pair_ix = sample_comb((n_obs, n_obs), n_sim, random_state=self._random_state) + pair_ix = sample_comb((n_obs, n_obs), n_sim, rng=self._rng) e1 = cast("CSCBase", self._counts_obs[pair_ix[:, 0], :]) e2 = cast("CSCBase", self._counts_obs[pair_ix[:, 1], :]) @@ -229,7 +227,7 @@ def simulate_doublets( e1 + e2, rate=synthetic_doublet_umi_subsampling, original_totals=tots1 + tots2, - random_seed=self._random_state, + rng=self._rng, ) else: self._counts_sim = e1 + e2 @@ -348,7 +346,7 @@ def _nearest_neighbor_classifier( knn=True, transformer=transformer, method=None, - random_state=self._random_state, + rng=self._rng, ) neighbors, _ = _get_indices_distances_from_sparse_matrix(knn.distances, k_adj) if use_approx_neighbors: diff --git a/src/scanpy/preprocessing/_scrublet/pipeline.py 
b/src/scanpy/preprocessing/_scrublet/pipeline.py index edc3417cd9..53d98ff8ed 100644 --- a/src/scanpy/preprocessing/_scrublet/pipeline.py +++ b/src/scanpy/preprocessing/_scrublet/pipeline.py @@ -6,12 +6,12 @@ from fast_array_utils.stats import mean_var from scipy import sparse +from ..._utils.random import legacy_random_state from .sparse_utils import sparse_multiply, sparse_zscore if TYPE_CHECKING: from typing import Literal - from ..._utils.random import _LegacyRandom from .core import Scrublet @@ -46,10 +46,10 @@ def zscore(self: Scrublet) -> None: def truncated_svd( self: Scrublet, - n_prin_comps: int = 30, + n_prin_comps: int, *, - random_state: _LegacyRandom = 0, - algorithm: Literal["arpack", "randomized"] = "arpack", + rng: np.random.Generator, + algorithm: Literal["arpack", "randomized"], ) -> None: if self._counts_sim_norm is None: msg = "_counts_sim_norm is not set" @@ -57,7 +57,9 @@ def truncated_svd( from sklearn.decomposition import TruncatedSVD svd = TruncatedSVD( - n_components=n_prin_comps, random_state=random_state, algorithm=algorithm + n_components=n_prin_comps, + random_state=legacy_random_state(rng), + algorithm=algorithm, ).fit(self._counts_obs_norm) self.set_manifold( svd.transform(self._counts_obs_norm), svd.transform(self._counts_sim_norm) @@ -66,10 +68,10 @@ def truncated_svd( def pca( self: Scrublet, - n_prin_comps: int = 50, + n_prin_comps: int, *, - random_state: _LegacyRandom = 0, - svd_solver: Literal["auto", "full", "arpack", "randomized"] = "arpack", + rng: np.random.Generator, + svd_solver: Literal["auto", "full", "arpack", "randomized"], ) -> None: if self._counts_sim_norm is None: msg = "_counts_sim_norm is not set" @@ -80,6 +82,8 @@ def pca( x_sim = self._counts_sim_norm.toarray() pca = PCA( - n_components=n_prin_comps, random_state=random_state, svd_solver=svd_solver + n_components=n_prin_comps, + random_state=legacy_random_state(rng), + svd_solver=svd_solver, ).fit(x_obs) self.set_manifold(pca.transform(x_obs), 
pca.transform(x_sim)) diff --git a/src/scanpy/preprocessing/_scrublet/sparse_utils.py b/src/scanpy/preprocessing/_scrublet/sparse_utils.py index 5b7e7aaaaf..611754f91b 100644 --- a/src/scanpy/preprocessing/_scrublet/sparse_utils.py +++ b/src/scanpy/preprocessing/_scrublet/sparse_utils.py @@ -5,13 +5,11 @@ import numpy as np from fast_array_utils.stats import mean_var from scipy import sparse -from sklearn.utils import check_random_state if TYPE_CHECKING: from numpy.typing import NDArray from ..._compat import CSBase - from ..._utils.random import _LegacyRandom def sparse_multiply( @@ -48,11 +46,10 @@ def subsample_counts( *, rate: float, original_totals, - random_seed: _LegacyRandom = 0, + rng: np.random.Generator, ) -> tuple[CSBase, NDArray[np.int64]]: if rate < 1: - random_seed = check_random_state(random_seed) - e.data = random_seed.binomial(np.round(e.data).astype(int), rate) + e.data = rng.binomial(np.round(e.data).astype(int), rate) current_totals = np.asarray(e.sum(1)).squeeze() unsampled_orig_totals = original_totals - current_totals unsampled_downsamp_totals = np.random.binomial( diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 9eebeb0246..9d6abf1a38 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -41,7 +41,7 @@ import pandas as pd from numpy.typing import NDArray - from .._utils.random import RNGLike, SeedLike, _LegacyRandom + from .._utils.random import RNGLike, SeedLike @old_positionals( @@ -853,7 +853,7 @@ def sample( fraction: float | None = None, *, n: int | None = None, - rng: RNGLike | SeedLike | None = 0, + rng: RNGLike | SeedLike | None = None, copy: Literal[False] = False, replace: bool = False, axis: Literal["obs", 0, "var", 1] = "obs", @@ -996,7 +996,7 @@ def downsample_counts( counts_per_cell: int | Collection[int] | None = None, total_counts: int | None = None, *, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, replace: bool 
= False, copy: bool = False, ) -> AnnData | None: @@ -1020,7 +1020,7 @@ def downsample_counts( total_counts Target total counts. If the count matrix has more than `total_counts` it will be downsampled to have this number. - random_state + rng Random seed for subsampling. replace Whether to sample the counts with replacement. @@ -1037,6 +1037,7 @@ def downsample_counts( """ raise_not_implemented_error_if_backed_type(adata.X, "downsample_counts") # This logic is all dispatch + rng = np.random.default_rng(rng) total_counts_call = total_counts is not None counts_per_cell_call = counts_per_cell is not None if total_counts_call is counts_per_cell_call: @@ -1046,11 +1047,11 @@ def downsample_counts( adata = adata.copy() if total_counts_call: adata.X = _downsample_total_counts( - adata.X, total_counts, random_state=random_state, replace=replace + adata.X, total_counts, rng=rng, replace=replace ) elif counts_per_cell_call: adata.X = _downsample_per_cell( - adata.X, counts_per_cell, random_state=random_state, replace=replace + adata.X, counts_per_cell, rng=rng, replace=replace ) if copy: return adata @@ -1061,7 +1062,7 @@ def _downsample_per_cell( /, counts_per_cell: int, *, - random_state: _LegacyRandom, + rng: np.random.Generator, replace: bool, ) -> CSBase: n_obs = x.shape[0] @@ -1088,11 +1089,7 @@ def _downsample_per_cell( for rowidx in under_target: row = rows[rowidx] _downsample_array( - row, - counts_per_cell[rowidx], - random_state=random_state, - replace=replace, - inplace=True, + row, counts_per_cell[rowidx], rng=rng, replace=replace, inplace=True ) x.eliminate_zeros() if not issubclass(original_type, CSRBase): # Put it back @@ -1103,11 +1100,7 @@ def _downsample_per_cell( for rowidx in under_target: row = x[rowidx, :] _downsample_array( - row, - counts_per_cell[rowidx], - random_state=random_state, - replace=replace, - inplace=True, + row, counts_per_cell[rowidx], rng=rng, replace=replace, inplace=True ) return x @@ -1117,7 +1110,7 @@ def _downsample_total_counts( 
/, total_counts: int, *, - random_state: _LegacyRandom, + rng: np.random.Generator, replace: bool, ) -> CSBase: total_counts = int(total_counts) @@ -1128,21 +1121,13 @@ def _downsample_total_counts( original_type = type(x) if not isinstance(x, CSRBase): x = x.tocsr() - _downsample_array( - x.data, - total_counts, - random_state=random_state, - replace=replace, - inplace=True, - ) + _downsample_array(x.data, total_counts, rng=rng, replace=replace, inplace=True) x.eliminate_zeros() if not issubclass(original_type, CSRBase): x = original_type(x) else: v = x.reshape(np.multiply(*x.shape)) - _downsample_array( - v, total_counts, random_state=random_state, replace=replace, inplace=True - ) + _downsample_array(v, total_counts, rng=rng, replace=replace, inplace=True) return x @@ -1152,7 +1137,7 @@ def _downsample_array( col: np.ndarray, target: int, *, - random_state: _LegacyRandom = 0, + rng: np.random.Generator, replace: bool = True, inplace: bool = False, ): @@ -1162,14 +1147,13 @@ def _downsample_array( * total counts in cell must be less than target """ - np.random.seed(random_state) cumcounts = col.cumsum() if inplace: col[:] = 0 else: col = np.zeros_like(col) total = np.int_(cumcounts[-1]) - sample = np.random.choice(total, target, replace=replace) + sample = rng.choice(total, target, replace=replace) sample.sort() geneptr = 0 for count in sample: diff --git a/src/scanpy/preprocessing/_utils.py b/src/scanpy/preprocessing/_utils.py index 0ba97c7b4e..3da0fed2d5 100644 --- a/src/scanpy/preprocessing/_utils.py +++ b/src/scanpy/preprocessing/_utils.py @@ -5,24 +5,25 @@ import numpy as np from sklearn.random_projection import sample_without_replacement +from .._utils.random import legacy_random_state + if TYPE_CHECKING: from typing import Literal from numpy.typing import NDArray - from .._utils.random import _LegacyRandom - def sample_comb( dims: tuple[int, ...], nsamp: int, *, - random_state: _LegacyRandom = None, + rng: np.random.Generator, method: Literal[ "auto", 
"tracking_selection", "reservoir_sampling", "pool" ] = "auto", ) -> NDArray[np.int64]: """Randomly sample indices from a grid, without repeating the same tuple.""" + random_state = legacy_random_state(rng) idx = sample_without_replacement( np.prod(dims), nsamp, random_state=random_state, method=method ) diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 62fa55b92e..eae7050fda 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -2,13 +2,15 @@ from typing import TYPE_CHECKING +import numpy as np + from .._compat import old_positionals from ._dpt import _diffmap if TYPE_CHECKING: from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike @old_positionals("neighbors_key", "random_state", "copy") @@ -17,7 +19,7 @@ def diffmap( n_comps: int = 15, *, neighbors_key: str | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, copy: bool = False, ) -> AnnData | None: """Diffusion Maps :cite:p:`Coifman2005,Haghverdi2015,Wolf2018`. @@ -50,8 +52,8 @@ def diffmap( .obsp[.uns[neighbors_key]['connectivities_key']] and .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances, respectively. - random_state - A numpy random seed + rng + A numpy random number generator copy Return a copy instead of writing to adata. @@ -75,6 +77,7 @@ def diffmap( e.g. `adata.obsm["X_diffmap"][:,1]` """ + rng = np.random.default_rng(rng) if neighbors_key is None: neighbors_key = "neighbors" @@ -85,7 +88,5 @@ def diffmap( msg = "Provide any value greater than 2 for `n_comps`. 
" raise ValueError(msg) adata = adata.copy() if copy else adata - _diffmap( - adata, n_comps=n_comps, neighbors_key=neighbors_key, random_state=random_state - ) + _diffmap(adata, n_comps=n_comps, neighbors_key=neighbors_key, rng=rng) return adata if copy else None diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 2a52633026..7d8e9dc661 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -9,6 +9,7 @@ from .. import logging as logg from .._compat import old_positionals +from .._utils.random import legacy_numpy_gen from ..neighbors import Neighbors, OnFlySymMatrix if TYPE_CHECKING: @@ -17,11 +18,17 @@ from anndata import AnnData -def _diffmap(adata, n_comps=15, neighbors_key=None, random_state=0): +def _diffmap( + adata: AnnData, + n_comps: int = 15, + *, + neighbors_key: str | None, + rng: np.random.Generator, +) -> None: start = logg.info(f"computing Diffusion Maps using {n_comps=}(=n_dcs)") dpt = DPT(adata, neighbors_key=neighbors_key) dpt.compute_transitions() - dpt.compute_eigen(n_comps=n_comps, random_state=random_state) + dpt.compute_eigen(n_comps=n_comps, rng=rng) adata.obsm["X_diffmap"] = dpt.eigen_basis adata.uns["diffmap_evals"] = dpt.eigen_values logg.info( @@ -144,7 +151,7 @@ def dpt( "Trying to run `tl.dpt` without prior call of `tl.diffmap`. " "Falling back to `tl.diffmap` with default parameters." ) - _diffmap(adata, neighbors_key=neighbors_key) + _diffmap(adata, neighbors_key=neighbors_key, rng=legacy_numpy_gen(0)) # start with the actual computation dpt = DPT( adata, diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index a76bad1c30..5a6cc6eec3 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -1,6 +1,5 @@ from __future__ import annotations -import random from importlib.util import find_spec from typing import TYPE_CHECKING, Literal @@ -10,6 +9,7 @@ from .. 
import logging as logg from .._compat import old_positionals from .._utils import _choose_graph, get_literal_vals +from .._utils.random import set_igraph_rng from ._utils import get_init_pos_from_paga if TYPE_CHECKING: @@ -18,7 +18,7 @@ from anndata import AnnData from .._compat import SpBase - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike type _Layout = Literal["fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa"] @@ -41,7 +41,7 @@ def draw_graph( # noqa: PLR0913 *, init_pos: str | bool | None = None, root: int | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, n_jobs: int | None = None, adjacency: SpBase | None = None, key_added_ext: str | None = None, @@ -83,7 +83,7 @@ def draw_graph( # noqa: PLR0913 'rt' (Reingold Tilford tree layout). root Root for tree layouts. - random_state + rng For layouts with random initialization like 'fr', change this to use different intial states for the optimization. If `None`, no seed is set. adjacency @@ -123,6 +123,7 @@ def draw_graph( # noqa: PLR0913 """ start = logg.info(f"drawing single-cell graph using layout {layout!r}") + rng = np.random.default_rng(rng) if layout not in (layouts := get_literal_vals(_Layout)): msg = f"Provide a valid layout, one of {layouts}." 
raise ValueError(msg) @@ -136,33 +137,30 @@ def draw_graph( # noqa: PLR0913 init_coords = get_init_pos_from_paga( adata, adjacency, - random_state=random_state, + rng=rng, neighbors_key=neighbors_key, obsp=obsp, ) else: - np.random.seed(random_state) - init_coords = np.random.random((adjacency.shape[0], 2)) + init_coords = rng.random((adjacency.shape[0], 2)) layout = coerce_fa2_layout(layout) # actual drawing if layout == "fa": positions = np.array(fa2_positions(adjacency, init_coords, **kwds)) else: - # igraph doesn't use numpy seed - random.seed(random_state) - g = _utils.get_igraph_from_adjacency(adjacency) - if layout in {"fr", "drl", "kk", "grid_fr"}: - ig_layout = g.layout(layout, seed=init_coords.tolist(), **kwds) - elif "rt" in layout: - if root is not None: - root = [root] - ig_layout = g.layout(layout, root=root, **kwds) - else: - ig_layout = g.layout(layout, **kwds) + with set_igraph_rng(rng): + if layout in {"fr", "drl", "kk", "grid_fr"}: + ig_layout = g.layout(layout, seed=init_coords.tolist(), **kwds) + elif "rt" in layout: + if root is not None: + root = [root] + ig_layout = g.layout(layout, root=root, **kwds) + else: + ig_layout = g.layout(layout, **kwds) positions = np.array(ig_layout.coords) adata.uns["draw_graph"] = {} - adata.uns["draw_graph"]["params"] = dict(layout=layout, random_state=random_state) + adata.uns["draw_graph"]["params"] = dict(layout=layout, random_state=rng) key_added = f"X_draw_graph_{key_added_ext or layout}" adata.obsm[key_added] = positions logg.info( diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 61967f2ce8..58ae9d04f3 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Hashable from typing import TYPE_CHECKING, cast import numpy as np @@ -9,7 +10,7 @@ from .. import _utils from .. 
import logging as logg from .._compat import warn -from .._utils.random import set_igraph_random_state +from .._utils.random import set_igraph_rng from ._utils_clustering import rename_groups, restrict_adjacency if TYPE_CHECKING: @@ -19,7 +20,7 @@ from anndata import AnnData from .._compat import CSBase - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike try: # sphinx-autodoc-typehints + optional dependency from leidenalg.VertexPartition import MutableVertexPartition @@ -34,7 +35,7 @@ def leiden( # noqa: PLR0913 resolution: float = 1, *, restrict_to: tuple[str, Sequence[str]] | None = None, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, key_added: str = "leiden", adjacency: CSBase | None = None, directed: bool | None = None, @@ -67,7 +68,7 @@ def leiden( # noqa: PLR0913 Higher values lead to more clusters. Set to `None` if overriding `partition_type` to one that doesn’t accept a `resolution_parameter`. - random_state + rng Change the initialization of the optimization. 
restrict_to Restrict the clustering to the categories within the key for sample @@ -160,7 +161,8 @@ def leiden( # noqa: PLR0913 partition_type = leidenalg.RBConfigurationVertexPartition if use_weights: clustering_args["weights"] = np.array(g.es["weight"]).astype(np.float64) - clustering_args["seed"] = random_state + if isinstance(rng, Hashable): + clustering_args["seed"] = rng part = cast( "MutableVertexPartition", leidenalg.find_partition(g, partition_type, **clustering_args), @@ -172,7 +174,7 @@ def leiden( # noqa: PLR0913 if resolution is not None: clustering_args["resolution"] = resolution clustering_args.setdefault("objective_function", "modularity") - with set_igraph_random_state(random_state): + with set_igraph_rng(rng): part = g.community_leiden(**clustering_args) # store output into adata.obs groups = np.array(part.membership) @@ -195,7 +197,7 @@ def leiden( # noqa: PLR0913 adata.uns[key_added] = {} adata.uns[key_added]["params"] = dict( resolution=resolution, - random_state=random_state, + random_state=rng, n_iterations=n_iterations, ) adata.uns[key_added]["modularity"] = part.modularity diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index 426497182d..afef115ce0 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -19,7 +19,8 @@ from anndata import AnnData from numpy.typing import DTypeLike, NDArray - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike + type _StrIdx = pd.Index[str] type _GetSubset = Callable[[_StrIdx], np.ndarray | CSBase] @@ -61,7 +62,7 @@ def score_genes( # noqa: PLR0913 gene_pool: Sequence[str] | pd.Index[str] | None = None, n_bins: int = 25, score_name: str = "score", - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, copy: bool = False, use_raw: bool | None = None, layer: str | None = None, @@ -94,8 +95,8 @@ def score_genes( # noqa: PLR0913 Number of expression level bins for sampling. 
score_name Name of the field to be added in `.obs`. - random_state - The random seed for sampling. + rng + The random number generator for sampling. copy Copy `adata` or modify it inplace. use_raw @@ -119,19 +120,17 @@ def score_genes( # noqa: PLR0913 """ start = logg.info(f"computing score {score_name!r}") + rng = np.random.default_rng(rng) adata = adata.copy() if copy else adata use_raw = check_use_raw(adata, use_raw, layer=layer) if is_backed_type(adata.X) and not use_raw: msg = f"score_genes is not implemented for matrices of type {type(adata.X)}" raise NotImplementedError(msg) - if random_state is not None: - np.random.seed(random_state) - gene_list, gene_pool, get_subset = _check_score_genes_args( adata, gene_list, gene_pool, use_raw=use_raw, layer=layer ) - del use_raw, layer, random_state + del use_raw, layer # Trying here to match the Seurat approach in scoring cells. # Basically we need to compare genes against random genes in a matched @@ -145,6 +144,7 @@ def score_genes( # noqa: PLR0913 ctrl_size=ctrl_size, n_bins=n_bins, get_subset=get_subset, + rng=rng, ): control_genes = control_genes.union(r_genes) @@ -224,6 +224,7 @@ def _score_genes_bins( ctrl_size: int, n_bins: int, get_subset: _GetSubset, + rng: np.random.Generator, ) -> Generator[pd.Index[str], None, None]: # average expression of genes obs_avg = pd.Series(_nan_means(get_subset(gene_pool), axis=0), index=gene_pool) @@ -244,7 +245,7 @@ def _score_genes_bins( ) logg.warning(msg) if ctrl_size < len(r_genes): - r_genes = r_genes.to_series().sample(ctrl_size).index + r_genes = r_genes.to_series().sample(ctrl_size, random_state=rng).index if ctrl_as_ref: # otherwise `r_genes` is already filtered r_genes = r_genes.difference(gene_list) yield r_genes diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index de3f3a7300..5aca180291 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -6,13 +6,14 @@ from .._compat import old_positionals, warn from .._settings import 
settings from .._utils import _doc_params, raise_not_implemented_error_if_backed_type +from .._utils.random import legacy_random_state from ..neighbors._doc import doc_n_pcs, doc_use_rep from ._utils import _choose_representation if TYPE_CHECKING: from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike @old_positionals( @@ -36,7 +37,7 @@ def tsne( # noqa: PLR0913 metric: str = "euclidean", early_exaggeration: float = 12, learning_rate: float = 1000, - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, use_fast_tsne: bool = False, n_jobs: int | None = None, key_added: str | None = None, @@ -84,7 +85,7 @@ def tsne( # noqa: PLR0913 optimization, the early exaggeration factor or the learning rate might be too high. If the cost function gets stuck in a bad local minimum increasing the learning rate helps sometimes. - random_state + rng Change this to use different intial states for the optimization. If `None`, the initial state is not reproducible. n_jobs @@ -118,7 +119,7 @@ def tsne( # noqa: PLR0913 n_jobs = settings.n_jobs if n_jobs is None else n_jobs params_sklearn = dict( perplexity=perplexity, - random_state=random_state, + random_state=legacy_random_state(rng), verbose=settings.verbosity > 3, early_exaggeration=early_exaggeration, learning_rate=learning_rate, diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index a1848173fa..ba0b908e8d 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -4,12 +4,13 @@ from typing import TYPE_CHECKING import numpy as np -from sklearn.utils import check_array, check_random_state +from sklearn.utils import check_array from .. 
import logging as logg from .._compat import old_positionals, warn from .._settings import settings from .._utils import NeighborsView +from .._utils.random import legacy_random_state from ._utils import _choose_representation, get_init_pos_from_paga if TYPE_CHECKING: @@ -17,7 +18,8 @@ from anndata import AnnData - from .._utils.random import _LegacyRandom + from .._utils.random import RNGLike, SeedLike + type _InitPos = Literal["paga", "spectral", "random"] @@ -49,7 +51,7 @@ def umap( # noqa: PLR0913, PLR0915 gamma: float = 1.0, negative_sample_rate: int = 5, init_pos: _InitPos | np.ndarray | None = "spectral", - random_state: _LegacyRandom = 0, + rng: SeedLike | RNGLike | None = None, a: float | None = None, b: float | None = None, method: Literal["umap", "rapids"] = "umap", @@ -111,11 +113,10 @@ def umap( # noqa: PLR0913, PLR0915 * 'spectral': use a spectral embedding of the graph. * 'random': assign initial embedding positions at random. * A numpy array of initial embedding positions. - random_state - If `int`, `random_state` is the seed used by the random number generator; - If `RandomState` or `Generator`, `random_state` is the random number generator; - If `None`, the random number generator is the `RandomState` instance used - by `np.random`. + rng + If `int`, `rng` is the seed used by the random number generator; + If `Generator`, `random_state` is the random number generator; + If `None`, the random number generator is not reproducible. a More specific parameters controlling the embedding. If `None` these values are set automatically as determined by `min_dist` and @@ -158,6 +159,7 @@ def umap( # noqa: PLR0913, PLR0915 UMAP parameters. 
""" + rng = np.random.default_rng(rng) adata = adata.copy() if copy else adata key_obsm, key_uns = ("X_umap", "umap") if key_added is None else [key_added] * 2 @@ -191,16 +193,15 @@ def umap( # noqa: PLR0913, PLR0915 init_coords = adata.obsm[init_pos] elif isinstance(init_pos, str) and init_pos == "paga": init_coords = get_init_pos_from_paga( - adata, random_state=random_state, neighbors_key=neighbors_key + adata, rng=rng, neighbors_key=neighbors_key ) else: init_coords = init_pos # Let umap handle it if hasattr(init_coords, "dtype"): init_coords = check_array(init_coords, dtype=np.float32, accept_sparse=False) - if random_state != 0: - adata.uns[key_uns]["params"]["random_state"] = random_state - random_state = check_random_state(random_state) + if rng is not None: + adata.uns[key_uns]["params"]["random_state"] = rng neigh_params = neighbors["params"] x = _choose_representation( @@ -225,7 +226,7 @@ def umap( # noqa: PLR0913, PLR0915 negative_sample_rate=negative_sample_rate, n_epochs=n_epochs, init=init_coords, - random_state=random_state, + random_state=legacy_random_state(rng), metric=neigh_params.get("metric", "euclidean"), metric_kwds=neigh_params.get("metric_kwds", {}), densmap=False, @@ -265,7 +266,7 @@ def umap( # noqa: PLR0913, PLR0915 a=a, b=b, verbose=settings.verbosity > 3, - random_state=random_state, + random_state=legacy_random_state(rng), ) x_umap = umap.fit_transform(x_contiguous) adata.obsm[key_obsm] = x_umap # annotate samples with UMAP coordinates diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py index fdcff28720..bdb9ae7f90 100644 --- a/src/scanpy/tools/_utils.py +++ b/src/scanpy/tools/_utils.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from anndata import AnnData + from numpy.typing import NDArray from .._compat import CSRBase, SpBase @@ -77,12 +78,12 @@ def _get_pca_or_small_x(adata: AnnData, n_pcs: int | None) -> np.ndarray | CSRBa def get_init_pos_from_paga( adata: AnnData, + *, + rng: np.random.Generator, adjacency: SpBase | 
None = None, - random_state=0, neighbors_key: str | None = None, obsp: str | None = None, -): - np.random.seed(random_state) +) -> NDArray[np.float64]: if adjacency is None: adjacency = _choose_graph(adata, obsp, neighbors_key) if "pos" not in adata.uns.get("paga", {}): @@ -99,7 +100,7 @@ def get_init_pos_from_paga( if len(neighbors[1]) > 0: connectivities = connectivities_coarse[i][neighbors] nearest_neighbor = neighbors[1][np.argmax(connectivities)] - noise = np.random.random((len(subset[subset]), 2)) + noise = rng.random((len(subset[subset]), 2)) dist = group_pos - pos[nearest_neighbor] noise = noise * dist init_pos[subset] = group_pos - 0.5 * dist + noise From 8ab6661858c7cab8c03f76909a0a1370a9a3a187 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 24 Feb 2026 13:59:37 +0100 Subject: [PATCH 02/20] add decorator --- src/scanpy/_utils/random.py | 70 ++++++++++++++----- src/scanpy/datasets/_datasets.py | 3 +- src/scanpy/experimental/pp/_normalization.py | 2 + src/scanpy/experimental/pp/_recipes.py | 3 + src/scanpy/neighbors/__init__.py | 5 +- .../preprocessing/_deprecated/sampling.py | 4 +- src/scanpy/preprocessing/_pca/__init__.py | 3 +- src/scanpy/preprocessing/_pca/_compat.py | 3 +- src/scanpy/preprocessing/_recipes.py | 2 + .../preprocessing/_scrublet/__init__.py | 3 + src/scanpy/preprocessing/_simple.py | 2 + src/scanpy/tools/_diffmap.py | 2 + src/scanpy/tools/_dpt.py | 4 +- src/scanpy/tools/_draw_graph.py | 3 +- src/scanpy/tools/_leiden.py | 3 +- src/scanpy/tools/_tsne.py | 3 +- tests/test_utils.py | 4 +- 17 files changed, 88 insertions(+), 31 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index aa405c0ce8..49074f22f3 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from collections.abc import Generator + from typing import Self from numpy.typing import NDArray @@ -20,8 +21,8 @@ "RNGLike", "SeedLike", "_LegacyRandom", + "accepts_legacy_random_state", 
"ith_k_tuple", - "legacy_numpy_gen", "legacy_random_state", "random_k_tuples", "random_str", @@ -76,23 +77,23 @@ def set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: ################################### -def legacy_numpy_gen( - random_state: _LegacyRandom | None = None, -) -> np.random.Generator: - """Return a random generator that behaves like the legacy one.""" - if random_state is not None: - if isinstance(random_state, np.random.RandomState): - np.random.set_state(random_state.get_state(legacy=False)) - return _FakeRandomGen(random_state) - np.random.seed(random_state) - return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) - - class _FakeRandomGen(np.random.Generator): + _arg: _LegacyRandom _state: np.random.RandomState - def __init__(self, random_state: np.random.RandomState) -> None: - self._state = random_state + def __init__(self, seed_or_state: _LegacyRandom) -> None: + self._arg = seed_or_state + self._state = np.random.RandomState(seed_or_state) + + @classmethod + def wrap_global(cls, random_state: _LegacyRandom | None = None) -> Self: + """Create a generator that wraps the global `RandomState` backing the legacy `np.random` functions.""" + if random_state is not None: + if isinstance(random_state, np.random.RandomState): + np.random.set_state(random_state.get_state(legacy=False)) + return _FakeRandomGen(random_state) + np.random.seed(random_state) + return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) @classmethod def _delegate(cls) -> None: @@ -114,13 +115,46 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs): _FakeRandomGen._delegate() -def legacy_random_state(rng: SeedLike | RNGLike | None) -> np.random.RandomState: - rng = np.random.default_rng(rng) +def legacy_random_state(rng: SeedLike | RNGLike | None) -> _LegacyRandom: + """Convert a np.random.Generator into a legacy `random_state` argument. + + If `rng` is already a `_FakeRandomGen`, return its original `_arg` attribute. 
+ """ if isinstance(rng, _FakeRandomGen): - return rng._state + return rng._arg + rng = np.random.default_rng(rng) return np.random.RandomState(rng.bit_generator.spawn(1)[0]) +def accepts_legacy_random_state[**P, R]( + random_state_default: _LegacyRandom, +) -> callable[[callable[P, R]], callable[P, R]]: + """Make a function accept `random_state: _LegacyRandom` and pass it as `rng`. + + If the decorated function is called with a `random_state` argument, + it’ll be wrapped in a :class:`_FakeRandomGen`. + Passing both ``rng`` and ``random_state`` at the same time is an error. + If neither is given, ``random_state_default`` is used. + """ + + def decorator(func: callable[P, R]) -> callable[P, R]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + match "random_state" in kwargs, "rng" in kwargs: + case True, True: + msg = "Specify at most one of `rng` and `random_state`." + raise TypeError(msg) + case True, False: + kwargs["rng"] = _FakeRandomGen(kwargs.pop("random_state")) + case False, False: + kwargs["rng"] = _FakeRandomGen(random_state_default) + return func(*args, **kwargs) + + return wrapper + + return decorator + + ################### # Random k-tuples # ################### diff --git a/src/scanpy/datasets/_datasets.py b/src/scanpy/datasets/_datasets.py index 116c8a572d..d4b56f055d 100644 --- a/src/scanpy/datasets/_datasets.py +++ b/src/scanpy/datasets/_datasets.py @@ -12,7 +12,7 @@ from .._compat import deprecated, old_positionals from .._settings import settings from .._utils._doctests import doctest_internet, doctest_needs -from .._utils.random import legacy_random_state +from .._utils.random import accepts_legacy_random_state, legacy_random_state from ..readwrite import read, read_h5ad, read_visium from ._utils import check_datasetdir_exists @@ -58,6 +58,7 @@ @old_positionals( "n_variables", "n_centers", "cluster_std", "n_observations", "random_state" ) +@accepts_legacy_random_state(0) def blobs( *, n_variables: int = 11, diff --git 
a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index eddb899bb5..9cd2d6d092 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -9,6 +9,7 @@ from ... import logging as logg from ..._compat import CSBase, warn from ..._utils import _doc_params, _empty, check_nonnegative_integers, view_to_actual +from ..._utils.random import accepts_legacy_random_state from ...experimental._docs import ( doc_adata, doc_check_values, @@ -161,6 +162,7 @@ def normalize_pearson_residuals( check_values=doc_check_values, inplace=doc_inplace, ) +@accepts_legacy_random_state(0) def normalize_pearson_residuals_pca( adata: AnnData, *, diff --git a/src/scanpy/experimental/pp/_recipes.py b/src/scanpy/experimental/pp/_recipes.py index 5b0357dc4d..b688b8d4d8 100644 --- a/src/scanpy/experimental/pp/_recipes.py +++ b/src/scanpy/experimental/pp/_recipes.py @@ -17,6 +17,8 @@ ) from scanpy.preprocessing import pca +from ..._utils.random import accepts_legacy_random_state + if TYPE_CHECKING: from collections.abc import Mapping from typing import Any @@ -35,6 +37,7 @@ check_values=doc_check_values, inplace=doc_inplace, ) +@accepts_legacy_random_state(0) def recipe_pearson_residuals( # noqa: PLR0913 adata: AnnData, *, diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 4fa2b5d597..705398c053 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -17,7 +17,7 @@ from .._compat import CSBase, CSRBase, SpBase, old_positionals, warn from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals -from .._utils.random import legacy_random_state +from .._utils.random import accepts_legacy_random_state, legacy_random_state from . 
import _connectivity from ._common import ( _get_indices_distances_from_dense_matrix, @@ -78,6 +78,7 @@ class NeighborsParams(TypedDict): # noqa: D101 @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) +@accepts_legacy_random_state(0) def neighbors( # noqa: PLR0913 adata: AnnData, n_neighbors: int = 15, @@ -572,6 +573,7 @@ def to_igraph(self) -> Graph: return _utils.get_igraph_from_adjacency(self.connectivities) @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) + @accepts_legacy_random_state(0) def compute_neighbors( self, n_neighbors: int = 30, @@ -842,6 +844,7 @@ def compute_transitions(self, *, density_normalize: bool = True): self._transitions_sym = self.Z @ conn_norm @ self.Z logg.info(" finished", time=start) + @accepts_legacy_random_state(0) def compute_eigen( self, *, diff --git a/src/scanpy/preprocessing/_deprecated/sampling.py b/src/scanpy/preprocessing/_deprecated/sampling.py index 2280f3c9a0..b2dfeb0e92 100644 --- a/src/scanpy/preprocessing/_deprecated/sampling.py +++ b/src/scanpy/preprocessing/_deprecated/sampling.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from ..._compat import old_positionals -from ..._utils.random import legacy_numpy_gen +from ..._utils.random import _FakeRandomGen from .._simple import sample if TYPE_CHECKING: @@ -52,7 +52,7 @@ def subsample( returns a subsampled copy of it (`copy == True`). 
""" - rng = legacy_numpy_gen(random_state) + rng = _FakeRandomGen.wrap_global(random_state) return sample( data=data, fraction=fraction, n=n_obs, rng=rng, copy=copy, replace=False, axis=0 ) diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 87ad1752ee..c6d70e7c36 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -10,7 +10,7 @@ from ..._compat import CSBase, DaskArray, pkg_version, warn from ..._settings import settings from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type -from ..._utils.random import legacy_random_state +from ..._utils.random import accepts_legacy_random_state, legacy_random_state from ...get import _check_mask, _get_obs_rep from .._docs import doc_mask_var_hvg from ._compat import _pca_compat_sparse @@ -54,6 +54,7 @@ @_doc_params( mask_var_hvg=doc_mask_var_hvg, ) +@accepts_legacy_random_state(0) def pca( # noqa: PLR0912, PLR0913, PLR0915 data: AnnData | np.ndarray | CSBase, n_comps: int | None = None, diff --git a/src/scanpy/preprocessing/_pca/_compat.py b/src/scanpy/preprocessing/_pca/_compat.py index 133be7b0f0..f1f4e33b0d 100644 --- a/src/scanpy/preprocessing/_pca/_compat.py +++ b/src/scanpy/preprocessing/_pca/_compat.py @@ -10,7 +10,7 @@ from sklearn.utils.extmath import svd_flip from ..._compat import pkg_version -from ..._utils.random import legacy_random_state +from ..._utils.random import accepts_legacy_random_state, legacy_random_state if TYPE_CHECKING: from typing import Literal @@ -22,6 +22,7 @@ from ..._utils.random import RNGLike, SeedLike +@accepts_legacy_random_state(None) def _pca_compat_sparse( x: CSBase, n_pcs: int, diff --git a/src/scanpy/preprocessing/_recipes.py b/src/scanpy/preprocessing/_recipes.py index 34cf986d57..504561dcf7 100644 --- a/src/scanpy/preprocessing/_recipes.py +++ b/src/scanpy/preprocessing/_recipes.py @@ -7,6 +7,7 @@ from .. import logging as logg from .. 
import preprocessing as pp from .._compat import CSBase, old_positionals +from .._utils.random import accepts_legacy_random_state from ._deprecated.highly_variable_genes import ( filter_genes_cv_deprecated, filter_genes_dispersion, @@ -28,6 +29,7 @@ "random_state", "copy", ) +@accepts_legacy_random_state(0) def recipe_weinreb17( adata: AnnData, *, diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index b5c7d60a55..7b9f86335b 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -11,6 +11,7 @@ from ... import logging as logg from ... import preprocessing as pp from ..._compat import old_positionals +from ..._utils.random import accepts_legacy_random_state from ...get import _get_obs_rep from . import pipeline from .core import Scrublet @@ -39,6 +40,7 @@ "copy", "random_state", ) +@accepts_legacy_random_state() def scrublet( # noqa: PLR0913 adata: AnnData, adata_sim: AnnData | None = None, @@ -304,6 +306,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): return adata if copy else None +@accepts_legacy_random_state() def _scrublet_call_doublets( # noqa: PLR0913 adata_obs: AnnData, adata_sim: AnnData, diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 9d6abf1a38..89816d4152 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -30,6 +30,7 @@ sanitize_anndata, view_to_actual, ) +from .._utils.random import accepts_legacy_random_state from ..get import _check_mask, _get_obs_rep, _set_obs_rep from ._distributed import materialize_as_ndarray @@ -991,6 +992,7 @@ def sample( # noqa: PLR0912 @renamed_arg("target_counts", "counts_per_cell") +@accepts_legacy_random_state(0) def downsample_counts( adata: AnnData, counts_per_cell: int | Collection[int] | None = None, diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 
eae7050fda..bf533938dd 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -5,6 +5,7 @@ import numpy as np from .._compat import old_positionals +from .._utils.random import accepts_legacy_random_state from ._dpt import _diffmap if TYPE_CHECKING: @@ -14,6 +15,7 @@ @old_positionals("neighbors_key", "random_state", "copy") +@accepts_legacy_random_state(0) def diffmap( adata: AnnData, n_comps: int = 15, diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 7d8e9dc661..5225ec2164 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -9,7 +9,7 @@ from .. import logging as logg from .._compat import old_positionals -from .._utils.random import legacy_numpy_gen +from .._utils.random import _FakeRandomGen from ..neighbors import Neighbors, OnFlySymMatrix if TYPE_CHECKING: @@ -151,7 +151,7 @@ def dpt( "Trying to run `tl.dpt` without prior call of `tl.diffmap`. " "Falling back to `tl.diffmap` with default parameters." ) - _diffmap(adata, neighbors_key=neighbors_key, rng=legacy_numpy_gen(0)) + _diffmap(adata, neighbors_key=neighbors_key, rng=_FakeRandomGen(0)) # start with the actual computation dpt = DPT( adata, diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index 5a6cc6eec3..b0e9406575 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -9,7 +9,7 @@ from .. 
import logging as logg from .._compat import old_positionals from .._utils import _choose_graph, get_literal_vals -from .._utils.random import set_igraph_rng +from .._utils.random import accepts_legacy_random_state, set_igraph_rng from ._utils import get_init_pos_from_paga if TYPE_CHECKING: @@ -35,6 +35,7 @@ "obsp", "copy", ) +@accepts_legacy_random_state(0) def draw_graph( # noqa: PLR0913 adata: AnnData, layout: _Layout = "fa", diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 58ae9d04f3..80efd4d6b0 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -10,7 +10,7 @@ from .. import _utils from .. import logging as logg from .._compat import warn -from .._utils.random import set_igraph_rng +from .._utils.random import accepts_legacy_random_state, set_igraph_rng from ._utils_clustering import rename_groups, restrict_adjacency if TYPE_CHECKING: @@ -30,6 +30,7 @@ MutableVertexPartition.__module__ = "leidenalg.VertexPartition" +@accepts_legacy_random_state(0) def leiden( # noqa: PLR0913 adata: AnnData, resolution: float = 1, diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index 5aca180291..f49456f505 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -6,7 +6,7 @@ from .._compat import old_positionals, warn from .._settings import settings from .._utils import _doc_params, raise_not_implemented_error_if_backed_type -from .._utils.random import legacy_random_state +from .._utils.random import accepts_legacy_random_state, legacy_random_state from ..neighbors._doc import doc_n_pcs, doc_use_rep from ._utils import _choose_representation @@ -26,6 +26,7 @@ "n_jobs", "copy", ) +@accepts_legacy_random_state(0) @_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep) def tsne( # noqa: PLR0913 adata: AnnData, diff --git a/tests/test_utils.py b/tests/test_utils.py index b82deea324..6c8709d296 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -18,8 +18,8 @@ 
descend_classes_and_funcs, ) from scanpy._utils.random import ( + _FakeRandomGen, ith_k_tuple, - legacy_numpy_gen, random_k_tuples, random_str, ) @@ -206,7 +206,7 @@ def test_legacy_numpy_gen(*, seed: int, pass_seed: bool, func: str): def _mk_random(func: str, *, direct: bool, seed: int | None) -> np.ndarray: if direct and seed is not None: np.random.seed(seed) - gen = np.random if direct else legacy_numpy_gen(seed) + gen = np.random if direct else _FakeRandomGen.wrap_global(seed) match func: case "choice": arr = np.arange(1000) From 1ef87803bcbb1d9e01f0ca919bcf917beff4320a Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 24 Feb 2026 14:19:23 +0100 Subject: [PATCH 03/20] scrublet --- src/scanpy/preprocessing/_scrublet/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index 7b9f86335b..6967b4ecd1 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -40,7 +40,7 @@ "copy", "random_state", ) -@accepts_legacy_random_state() +@accepts_legacy_random_state(0) def scrublet( # noqa: PLR0913 adata: AnnData, adata_sim: AnnData | None = None, @@ -306,7 +306,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): return adata if copy else None -@accepts_legacy_random_state() +@accepts_legacy_random_state(0) def _scrublet_call_doublets( # noqa: PLR0913 adata_obs: AnnData, adata_sim: AnnData, From 5308a1acb90d6de3c3bf432b3d24d40c83dc187c Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Tue, 24 Feb 2026 18:14:04 +0100 Subject: [PATCH 04/20] almost done --- pyproject.toml | 2 + src/scanpy/_utils/random.py | 57 ++++++++++++++----- src/scanpy/neighbors/__init__.py | 19 +++++-- src/scanpy/preprocessing/_pca/__init__.py | 25 +++++--- .../preprocessing/_scrublet/__init__.py | 18 ++++-- src/scanpy/preprocessing/_simple.py | 19 +++++-- src/scanpy/tools/_draw_graph.py | 7 ++- src/scanpy/tools/_leiden.py | 10 ++-- src/scanpy/tools/_score_genes.py | 5 ++ src/scanpy/tools/_umap.py | 7 ++- tests/test_pca.py | 4 +- 11 files changed, 125 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 126098f593..545c6f8a3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -279,6 +279,8 @@ filterwarnings = [ "ignore:The `igraph` implementation of leiden clustering:UserWarning", # everybody uses this zarr 3 feature, including us, XArray, lots of data out there … "ignore:Consolidated metadata is currently not part:UserWarning", + # joblib fallback to serial mode in restricted multiprocessing environments + "ignore:.*joblib will operate in serial mode:UserWarning", ] [tool.coverage] diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index 49074f22f3..b2692192b5 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -21,6 +21,7 @@ "RNGLike", "SeedLike", "_LegacyRandom", + "_if_legacy_apply_global", "accepts_legacy_random_state", "ith_k_tuple", "legacy_random_state", @@ -81,22 +82,38 @@ class _FakeRandomGen(np.random.Generator): _arg: _LegacyRandom _state: np.random.RandomState - def __init__(self, seed_or_state: _LegacyRandom) -> None: - self._arg = seed_or_state - self._state = np.random.RandomState(seed_or_state) + def __init__( + self, arg: _LegacyRandom, state: np.random.RandomState | None = None + ) -> None: + self._arg = arg + self._state = np.random.RandomState(arg) if state is None else state + super().__init__(self._state._bit_generator) @classmethod - def wrap_global(cls, random_state: 
_LegacyRandom | None = None) -> Self: + def wrap_global( + cls, + arg: _LegacyRandom = None, + state: np.random.RandomState | None = None, + ) -> Self: """Create a generator that wraps the global `RandomState` backing the legacy `np.random` functions.""" - if random_state is not None: - if isinstance(random_state, np.random.RandomState): - np.random.set_state(random_state.get_state(legacy=False)) - return _FakeRandomGen(random_state) - np.random.seed(random_state) - return _FakeRandomGen(np.random.RandomState(np.random.get_bit_generator())) + if arg is not None: + if isinstance(arg, np.random.RandomState): + np.random.set_state(arg.get_state(legacy=False)) + return _FakeRandomGen(arg, state) + np.random.seed(arg) + return _FakeRandomGen(arg, np.random.RandomState(np.random.get_bit_generator())) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _FakeRandomGen): + return False + return self._arg == other._arg + + def __hash__(self) -> int: + return hash((type(self), self._arg)) @classmethod def _delegate(cls) -> None: + names = dict(integers="randint") for name, meth in np.random.Generator.__dict__.items(): if name.startswith("_") or not callable(meth): continue @@ -109,19 +126,33 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs): return wrapper - setattr(cls, name, mk_wrapper(name, meth)) + setattr(cls, names.get(name, name), mk_wrapper(name, meth)) _FakeRandomGen._delegate() -def legacy_random_state(rng: SeedLike | RNGLike | None) -> _LegacyRandom: +def _if_legacy_apply_global(rng: np.random.Generator) -> np.random.Generator: + """Re-apply legacy `random_state` semantics when `rng` is a `_FakeRandomGen`. + + This resets the global legacy RNG from the original `_arg` and returns a + generator which continues drawing from the same internal state. 
+ """ + if not isinstance(rng, _FakeRandomGen): + return rng + + return _FakeRandomGen.wrap_global(rng._arg, rng._state) + + +def legacy_random_state( + rng: SeedLike | RNGLike | None, *, always_state: bool = False +) -> _LegacyRandom: """Convert a np.random.Generator into a legacy `random_state` argument. If `rng` is already a `_FakeRandomGen`, return its original `_arg` attribute. """ if isinstance(rng, _FakeRandomGen): - return rng._arg + return rng._state if always_state else rng._arg rng = np.random.default_rng(rng) return np.random.RandomState(rng.bit_generator.spawn(1)[0]) diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 705398c053..3e03c60da5 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -17,7 +17,11 @@ from .._compat import CSBase, CSRBase, SpBase, old_positionals, warn from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals -from .._utils.random import accepts_legacy_random_state, legacy_random_state +from .._utils.random import ( + _FakeRandomGen, + accepts_legacy_random_state, + legacy_random_state, +) from . import _connectivity from ._common import ( _get_indices_distances_from_dense_matrix, @@ -225,17 +229,20 @@ def neighbors( # noqa: PLR0913 ) else: params = locals() - if ignored := { + ignored = { p.name for p in signature(neighbors).parameters.values() - if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds", "rng"} + if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds"} if params[p.name] != p.default - }: + } + if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: + ignored.add("rng/random_state") + rng = _FakeRandomGen(0) + if ignored: warn( f"Parameter(s) ignored if `distances` is given: {ignored}", UserWarning, ) - random_state = 0 if callable(metric): msg = "`metric` must be a string if `distances` is given." 
raise TypeError(msg) @@ -263,7 +270,7 @@ def neighbors( # noqa: PLR0913 key_added, n_neighbors=neighbors_.n_neighbors, method=method, - random_state=rng, + random_state=legacy_random_state(rng), metric=metric, **({} if not metric_kwds else dict(metric_kwds=metric_kwds)), **({} if use_rep is None else dict(use_rep=use_rep)), diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index c6d70e7c36..da1888a1cc 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -10,7 +10,11 @@ from ..._compat import CSBase, DaskArray, pkg_version, warn from ..._settings import settings from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type -from ..._utils.random import accepts_legacy_random_state, legacy_random_state +from ..._utils.random import ( + _FakeRandomGen, + accepts_legacy_random_state, + legacy_random_state, +) from ...get import _check_mask, _get_obs_rep from .._docs import doc_mask_var_hvg from ._compat import _pca_compat_sparse @@ -243,19 +247,22 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 raise NotImplementedError(msg) # dask needs an int for random state - if not isinstance(x, DaskArray): - rng = np.random.default_rng(rng) - elif not isinstance(rng, int): - msg = f"rng needs to be an int, not a {type(rng).__name__} when passing a dask array" + rng = np.random.default_rng(rng) + if not isinstance(rng, _FakeRandomGen) and not isinstance( + rng._arg, int | np.random.RandomState + ): + # TODO: remove this error and if we don’t have a _FakeRandomGen, + # just use rng.integers to make a seed farther down + msg = f"rng needs to be an int or a np.random.RandomState, not a {type(rng).__name__} when passing a dask array" raise TypeError(msg) if chunked: if ( not zero_center - or rng is not None + or (not isinstance(rng, _FakeRandomGen) or rng._arg != 0) or (svd_solver is not None and svd_solver != "arpack") ): - logg.debug("Ignoring zero_center, random_state, 
svd_solver") + logg.debug("Ignoring zero_center, rng, svd_solver") incremental_pca_kwargs = dict() if isinstance(x, DaskArray): @@ -303,8 +310,8 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 elif isinstance(x._meta, CSBase) or svd_solver == "covariance_eigh": from ._dask import PCAEighDask - if rng is not None: - msg = f"Ignoring {rng=} when using a sparse dask array" + if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: + msg = f"Ignoring rng={legacy_random_state(rng)} when using a sparse dask array" warn(msg, UserWarning) if svd_solver not in {None, "covariance_eigh"}: msg = f"Ignoring {svd_solver=} when using a sparse dask array" diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index 6967b4ecd1..a75f27a33b 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -11,7 +11,11 @@ from ... import logging as logg from ... import preprocessing as pp from ..._compat import old_positionals -from ..._utils.random import accepts_legacy_random_state +from ..._utils.random import ( + _if_legacy_apply_global, + accepts_legacy_random_state, + legacy_random_state, +) from ...get import _get_obs_rep from . 
import pipeline from .core import Scrublet @@ -181,6 +185,7 @@ def scrublet( # noqa: PLR0913 """ rng = np.random.default_rng(rng) + rng = _if_legacy_apply_global(rng) if threshold is None and not find_spec("skimage"): # pragma: no cover # Scrublet.call_doublets requires `skimage` with `threshold=None` but PCA # is called early, which is wasteful if there is not `skimage` @@ -444,10 +449,14 @@ def _scrublet_call_doublets( # noqa: PLR0913 if mean_center: logg.info("Embedding transcriptomes using PCA...") - pipeline.pca(scrub, n_prin_comps=n_prin_comps, rng=scrub._rng) + pipeline.pca( + scrub, n_prin_comps=n_prin_comps, svd_solver="arpack", rng=scrub._rng + ) else: logg.info("Embedding transcriptomes using Truncated SVD...") - pipeline.truncated_svd(scrub, n_prin_comps=n_prin_comps, rng=scrub._rng) + pipeline.truncated_svd( + scrub, n_prin_comps=n_prin_comps, algorithm="arpack", rng=scrub._rng + ) # Score the doublets @@ -479,7 +488,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 .get("sim_doublet_ratio", None) ), "n_neighbors": n_neighbors, - "rng": rng, + "random_state": legacy_random_state(rng), }, } @@ -557,6 +566,7 @@ def scrublet_simulate_doublets( scores for observed transcriptomes and simulated doublets. 
""" + rng = _if_legacy_apply_global(rng) x = _get_obs_rep(adata, layer=layer) scrub = Scrublet(x, rng=rng) diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 89816d4152..83e5e96c58 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -30,7 +30,7 @@ sanitize_anndata, view_to_actual, ) -from .._utils.random import accepts_legacy_random_state +from .._utils.random import _if_legacy_apply_global, accepts_legacy_random_state from ..get import _check_mask, _get_obs_rep, _set_obs_rep from ._distributed import materialize_as_ndarray @@ -1040,6 +1040,7 @@ def downsample_counts( raise_not_implemented_error_if_backed_type(adata.X, "downsample_counts") # This logic is all dispatch rng = np.random.default_rng(rng) + rng = _if_legacy_apply_global(rng) total_counts_call = total_counts is not None counts_per_cell_call = counts_per_cell is not None if total_counts_call is counts_per_cell_call: @@ -1134,7 +1135,6 @@ def _downsample_total_counts( # TODO: can/should this be parallelized? -@numba.njit(cache=True) # noqa: TID251 def _downsample_array( col: np.ndarray, target: int, @@ -1142,7 +1142,7 @@ def _downsample_array( rng: np.random.Generator, replace: bool = True, inplace: bool = False, -): +) -> np.ndarray: """Evenly reduce counts in cell to target amount. 
This is an internal function and has some restrictions: @@ -1150,13 +1150,20 @@ def _downsample_array( * total counts in cell must be less than target """ cumcounts = col.cumsum() + total = np.int_(cumcounts[-1]) + sample = rng.choice(total, target, replace=replace) + sample.sort() + return _downsample_array_inner(col, cumcounts, sample, inplace=inplace) + + +@numba.njit(cache=True) # noqa: TID251 +def _downsample_array_inner( + col: np.ndarray, cumcounts: np.ndarray, sample: np.ndarray, *, inplace: bool +) -> np.ndarray: if inplace: col[:] = 0 else: col = np.zeros_like(col) - total = np.int_(cumcounts[-1]) - sample = rng.choice(total, target, replace=replace) - sample.sort() geneptr = 0 for count in sample: while count >= cumcounts[geneptr]: diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index b0e9406575..ce0d5182b3 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -9,7 +9,11 @@ from .. import logging as logg from .._compat import old_positionals from .._utils import _choose_graph, get_literal_vals -from .._utils.random import accepts_legacy_random_state, set_igraph_rng +from .._utils.random import ( + _if_legacy_apply_global, + accepts_legacy_random_state, + set_igraph_rng, +) from ._utils import get_init_pos_from_paga if TYPE_CHECKING: @@ -125,6 +129,7 @@ def draw_graph( # noqa: PLR0913 """ start = logg.info(f"drawing single-cell graph using layout {layout!r}") rng = np.random.default_rng(rng) + rng = _if_legacy_apply_global(rng) if layout not in (layouts := get_literal_vals(_Layout)): msg = f"Provide a valid layout, one of {layouts}." 
raise ValueError(msg) diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 80efd4d6b0..e4a50ef949 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Hashable from typing import TYPE_CHECKING, cast import numpy as np @@ -10,7 +9,11 @@ from .. import _utils from .. import logging as logg from .._compat import warn -from .._utils.random import accepts_legacy_random_state, set_igraph_rng +from .._utils.random import ( + accepts_legacy_random_state, + legacy_random_state, + set_igraph_rng, +) from ._utils_clustering import rename_groups, restrict_adjacency if TYPE_CHECKING: @@ -162,8 +165,7 @@ def leiden( # noqa: PLR0913 partition_type = leidenalg.RBConfigurationVertexPartition if use_weights: clustering_args["weights"] = np.array(g.es["weight"]).astype(np.float64) - if isinstance(rng, Hashable): - clustering_args["seed"] = rng + clustering_args["seed"] = legacy_random_state(rng) part = cast( "MutableVertexPartition", leidenalg.find_partition(g, partition_type, **clustering_args), diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index afef115ce0..bf977a7531 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -10,6 +10,7 @@ from .. 
import logging as logg from .._compat import CSBase, old_positionals from .._utils import check_use_raw, is_backed_type +from .._utils.random import _if_legacy_apply_global, accepts_legacy_random_state from ..get import _get_obs_rep if TYPE_CHECKING: @@ -53,6 +54,7 @@ def _sparse_nanmean(x: CSBase, /, axis: Literal[0, 1]) -> NDArray[np.float64]: @old_positionals( "ctrl_size", "gene_pool", "n_bins", "score_name", "random_state", "copy", "use_raw" ) +@accepts_legacy_random_state(0) def score_genes( # noqa: PLR0913 adata: AnnData, gene_list: Sequence[str] | pd.Index[str], @@ -120,7 +122,10 @@ def score_genes( # noqa: PLR0913 """ start = logg.info(f"computing score {score_name!r}") + rng_was_passed = rng is not None rng = np.random.default_rng(rng) + if rng_was_passed: # backwards compatibility: call np.random.seed() by default + rng = _if_legacy_apply_global(rng) adata = adata.copy() if copy else adata use_raw = check_use_raw(adata, use_raw, layer=layer) if is_backed_type(adata.X) and not use_raw: diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index ba0b908e8d..1e0cdc0bcc 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -10,7 +10,7 @@ from .._compat import old_positionals, warn from .._settings import settings from .._utils import NeighborsView -from .._utils.random import legacy_random_state +from .._utils.random import accepts_legacy_random_state, legacy_random_state from ._utils import _choose_representation, get_init_pos_from_paga if TYPE_CHECKING: @@ -40,6 +40,7 @@ "method", "neighbors_key", ) +@accepts_legacy_random_state(0) def umap( # noqa: PLR0913, PLR0915 adata: AnnData, *, @@ -201,7 +202,7 @@ def umap( # noqa: PLR0913, PLR0915 init_coords = check_array(init_coords, dtype=np.float32, accept_sparse=False) if rng is not None: - adata.uns[key_uns]["params"]["random_state"] = rng + adata.uns[key_uns]["params"]["random_state"] = legacy_random_state(rng) neigh_params = neighbors["params"] x = _choose_representation( @@ 
-226,7 +227,7 @@ def umap( # noqa: PLR0913, PLR0915 negative_sample_rate=negative_sample_rate, n_epochs=n_epochs, init=init_coords, - random_state=legacy_random_state(rng), + random_state=legacy_random_state(rng, always_state=True), metric=neigh_params.get("metric", "euclidean"), metric_kwds=neigh_params.get("metric_kwds", {}), densmap=False, diff --git a/tests/test_pca.py b/tests/test_pca.py index ba996d8ad9..d4cd5e58f3 100644 --- a/tests/test_pca.py +++ b/tests/test_pca.py @@ -244,7 +244,7 @@ def test_pca_transform_randomized(array_type): warnings.filterwarnings("error") if isinstance(adata.X, DaskArray) and isinstance(adata.X._meta, CSBase): patterns = ( - r"Ignoring random_state=14 when using a sparse dask array", + r"Ignoring rng=14 when using a sparse dask array", r"Ignoring svd_solver='randomized' when using a sparse dask array", ) ctx = _helpers.MultiContext( @@ -338,7 +338,7 @@ def test_pca_reproducible(array_type): pbmc.X = array_type(pbmc.X) with ( - pytest.warns(UserWarning, match=r"Ignoring random_state.*sparse dask array") + pytest.warns(UserWarning, match=r"Ignoring rng.*sparse dask array") if isinstance(pbmc.X, DaskArray) and isinstance(pbmc.X._meta, CSBase) else nullcontext() ): From 32b3ddc9617df423930f512a71f2afbd698b5616 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 26 Feb 2026 13:09:57 +0100 Subject: [PATCH 05/20] fix scrublet_simulate_doublets --- src/scanpy/preprocessing/_scrublet/__init__.py | 9 ++------- src/scanpy/preprocessing/_simple.py | 4 ++-- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index 718672df25..b9215ae163 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -10,11 +10,7 @@ from ... import logging as logg from ... 
import preprocessing as pp -from ..._utils.random import ( - _if_legacy_apply_global, - accepts_legacy_random_state, - legacy_random_state, -) +from ..._utils.random import accepts_legacy_random_state, legacy_random_state from ...get import _get_obs_rep from . import pipeline from .core import Scrublet @@ -165,7 +161,6 @@ def scrublet( # noqa: PLR0913 """ rng = np.random.default_rng(rng) - rng = _if_legacy_apply_global(rng) if threshold is None and not find_spec("skimage"): # pragma: no cover # Scrublet.call_doublets requires `skimage` with `threshold=None` but PCA # is called early, which is wasteful if there is not `skimage` @@ -493,6 +488,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 return adata_obs +@accepts_legacy_random_state(0) def scrublet_simulate_doublets( adata: AnnData, *, @@ -543,7 +539,6 @@ def scrublet_simulate_doublets( scores for observed transcriptomes and simulated doublets. """ - rng = _if_legacy_apply_global(rng) x = _get_obs_rep(adata, layer=layer) scrub = Scrublet(x, rng=rng) diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index e0c699d7b6..51f8b4a726 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -1020,7 +1020,6 @@ def downsample_counts( raise_not_implemented_error_if_backed_type(adata.X, "downsample_counts") # This logic is all dispatch rng = np.random.default_rng(rng) - rng = _if_legacy_apply_global(rng) total_counts_call = total_counts is not None counts_per_cell_call = counts_per_cell is not None if total_counts_call is counts_per_cell_call: @@ -1114,7 +1113,6 @@ def _downsample_total_counts( return x -# TODO: can/should this be parallelized? 
def _downsample_array( col: np.ndarray, target: int, @@ -1129,6 +1127,7 @@ def _downsample_array( * total counts in cell must be less than target """ + rng = _if_legacy_apply_global(rng) cumcounts = col.cumsum() total = np.int_(cumcounts[-1]) sample = rng.choice(total, target, replace=replace) @@ -1136,6 +1135,7 @@ def _downsample_array( return _downsample_array_inner(col, cumcounts, sample, inplace=inplace) +# TODO: can/should this be parallelized? @numba.njit(cache=True) # noqa: TID251 def _downsample_array_inner( col: np.ndarray, cumcounts: np.ndarray, sample: np.ndarray, *, inplace: bool From c3da2bb235cc1410f5a97b341b70812d0d47243e Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 26 Feb 2026 13:25:37 +0100 Subject: [PATCH 06/20] fix _RNGIgraph compat --- src/scanpy/_utils/random.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index b2692192b5..22c1c113d2 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -40,7 +40,7 @@ class _RNGIgraph: - """Random number generator for igraph so global seed is not changed. + """Random number generator for igraph so global random state is not changed. See :func:`igraph.set_random_number_generator` for the requirements. 
""" @@ -49,8 +49,11 @@ def __init__(self, rng: SeedLike | RNGLike | None) -> None: self._rng = np.random.default_rng(rng) def getrandbits(self, k: int) -> int: - lims = np.iinfo(np.uint64) - i = int(self._rng.integers(0, lims.max, dtype=np.uint64)) + if isinstance(self._rng, _FakeRandomGen): + i = self._rng._state.tomaxint() + else: + lims = np.iinfo(np.uint64) + i = int(self._rng.integers(0, lims.max, dtype=np.uint64)) return i & ((1 << k) - 1) def randint(self, a: int, b: int) -> np.int64: @@ -128,6 +131,9 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs): setattr(cls, names.get(name, name), mk_wrapper(name, meth)) + def __getattribute__(self, name: str) -> object: + return super().__getattribute__(name) + _FakeRandomGen._delegate() From bd85d959e049ae495150e7a500a9aef33493fb61 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 26 Feb 2026 13:38:31 +0100 Subject: [PATCH 07/20] whoops --- src/scanpy/tools/_louvain.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scanpy/tools/_louvain.py b/src/scanpy/tools/_louvain.py index b10491c28b..4e131affd6 100644 --- a/src/scanpy/tools/_louvain.py +++ b/src/scanpy/tools/_louvain.py @@ -86,7 +86,8 @@ def louvain( # noqa: PLR0912, PLR0913, PLR0915 resolution (higher resolution means finding more and smaller clusters), which defaults to 1.0. See “Time as a resolution parameter” in :cite:t:`Lambiotte2014`. - {random_state} + random_state + Change the initialization of the optimization. {restrict_to} key_added Key under which to add the cluster labels. (default: ``'louvain'``) From 8247cdbf9cff12c76686c412b84a07f5137cf10b Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Thu, 26 Feb 2026 13:43:31 +0100 Subject: [PATCH 08/20] relnote --- docs/release-notes/3983.feat.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/3983.feat.md diff --git a/docs/release-notes/3983.feat.md b/docs/release-notes/3983.feat.md new file mode 100644 index 0000000000..ca67baba94 --- /dev/null +++ b/docs/release-notes/3983.feat.md @@ -0,0 +1 @@ +Add support for {class}`numpy.random.Generator` to all functions previously accepting a `random_state` parameter {smaller}`P Angerer` From 47f3ceba1823d81a383f78cbbde5cf7e180f616f Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 26 Feb 2026 14:37:49 +0100 Subject: [PATCH 09/20] =?UTF-8?q?don=E2=80=99t=20store=20rng=20in=20random?= =?UTF-8?q?=5Fstate=20arg?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/scanpy/_utils/random.py | 9 +-------- src/scanpy/tools/_draw_graph.py | 5 ++++- src/scanpy/tools/_leiden.py | 2 +- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index baf0c8fdc3..e80eac1613 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING import numpy as np +from numpy.random._generator import Generator from . 
import ensure_igraph @@ -107,14 +108,6 @@ def wrap_global( np.random.seed(arg) return _FakeRandomGen(arg, np.random.RandomState(np.random.get_bit_generator())) - def __eq__(self, other: object) -> bool: - if not isinstance(other, _FakeRandomGen): - return False - return self._arg == other._arg - - def __hash__(self) -> int: - return hash((type(self), self._arg)) - @classmethod def _delegate(cls) -> None: names = dict(integers="randint") diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index 4047d63a41..01c9b151a5 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -11,6 +11,7 @@ from .._utils.random import ( _if_legacy_apply_global, accepts_legacy_random_state, + legacy_random_state, set_igraph_rng, ) from ._utils import get_init_pos_from_paga @@ -154,7 +155,9 @@ def draw_graph( # noqa: PLR0913 ig_layout = g.layout(layout, **kwds) positions = np.array(ig_layout.coords) adata.uns["draw_graph"] = {} - adata.uns["draw_graph"]["params"] = dict(layout=layout, random_state=rng) + adata.uns["draw_graph"]["params"] = dict( + layout=layout, random_state=legacy_random_state(rng) + ) key_added = f"X_draw_graph_{key_added_ext or layout}" adata.obsm[key_added] = positions logg.info( diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 1f4fd2a7ea..f43723c6e0 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -204,7 +204,7 @@ def leiden( # noqa: PLR0913 adata.uns[key_added] = {} adata.uns[key_added]["params"] = dict( resolution=resolution, - random_state=rng, + random_state=legacy_random_state(rng), n_iterations=n_iterations, ) adata.uns[key_added]["modularity"] = part.modularity From 1e43b2ab4eb4be897fd37d8c07e55ebb5f22d0e3 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Thu, 26 Feb 2026 14:50:08 +0100 Subject: [PATCH 10/20] make consistent --- src/scanpy/_utils/random.py | 18 +++++++++--------- src/scanpy/datasets/_datasets.py | 6 +++--- src/scanpy/experimental/pp/_normalization.py | 4 ++-- src/scanpy/experimental/pp/_recipes.py | 4 ++-- src/scanpy/neighbors/__init__.py | 14 +++++++------- src/scanpy/preprocessing/_pca/__init__.py | 14 +++++++------- src/scanpy/preprocessing/_pca/_compat.py | 6 +++--- src/scanpy/preprocessing/_recipes.py | 4 ++-- src/scanpy/preprocessing/_scrublet/__init__.py | 10 +++++----- src/scanpy/preprocessing/_scrublet/pipeline.py | 6 +++--- src/scanpy/preprocessing/_simple.py | 4 ++-- src/scanpy/preprocessing/_utils.py | 4 ++-- src/scanpy/tools/_diffmap.py | 4 ++-- src/scanpy/tools/_draw_graph.py | 12 ++++++------ src/scanpy/tools/_leiden.py | 14 +++++++------- src/scanpy/tools/_score_genes.py | 4 ++-- src/scanpy/tools/_tsne.py | 6 +++--- src/scanpy/tools/_umap.py | 10 +++++----- 18 files changed, 72 insertions(+), 72 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index e80eac1613..f5dc3c6531 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -7,12 +7,11 @@ from typing import TYPE_CHECKING import numpy as np -from numpy.random._generator import Generator from . 
import ensure_igraph if TYPE_CHECKING: - from collections.abc import Generator + from collections.abc import Callable, Generator from typing import Self from numpy.typing import NDArray @@ -22,10 +21,11 @@ "RNGLike", "SeedLike", "_LegacyRandom", + "_accepts_legacy_random_state", "_if_legacy_apply_global", - "accepts_legacy_random_state", + "_legacy_random_state", + "_set_igraph_rng", "ith_k_tuple", - "legacy_random_state", "random_k_tuples", "random_str", ] @@ -66,7 +66,7 @@ def __getattr__(self, attr: str): @contextmanager -def set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: +def _set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: ensure_igraph() import igraph @@ -141,7 +141,7 @@ def _if_legacy_apply_global(rng: np.random.Generator) -> np.random.Generator: return _FakeRandomGen.wrap_global(rng._arg, rng._state) -def legacy_random_state( +def _legacy_random_state( rng: SeedLike | RNGLike | None, *, always_state: bool = False ) -> _LegacyRandom: """Convert a np.random.Generator into a legacy `random_state` argument. @@ -154,9 +154,9 @@ def legacy_random_state( return np.random.RandomState(rng.bit_generator.spawn(1)[0]) -def accepts_legacy_random_state[**P, R]( +def _accepts_legacy_random_state[**P, R]( random_state_default: _LegacyRandom, -) -> callable[[callable[P, R]], callable[P, R]]: +) -> Callable[[Callable[P, R]], Callable[P, R]]: """Make a function accept `random_state: _LegacyRandom` and pass it as `rng`. If the decorated function is called with a `random_state` argument, @@ -165,7 +165,7 @@ def accepts_legacy_random_state[**P, R]( If neither is given, ``random_state_default`` is used. 
""" - def decorator(func: callable[P, R]) -> callable[P, R]: + def decorator(func: Callable[P, R]) -> Callable[P, R]: @wraps(func) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: match "random_state" in kwargs, "rng" in kwargs: diff --git a/src/scanpy/datasets/_datasets.py b/src/scanpy/datasets/_datasets.py index 5b56be0fad..a3775fe3a4 100644 --- a/src/scanpy/datasets/_datasets.py +++ b/src/scanpy/datasets/_datasets.py @@ -12,7 +12,7 @@ from .._compat import deprecated from .._settings import settings from .._utils._doctests import doctest_internet, doctest_needs -from .._utils.random import accepts_legacy_random_state, legacy_random_state +from .._utils.random import _accepts_legacy_random_state, _legacy_random_state from ..readwrite import read, read_h5ad, read_visium from ._utils import check_datasetdir_exists @@ -55,7 +55,7 @@ HERE = Path(__file__).parent -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def blobs( *, n_variables: int = 11, @@ -100,7 +100,7 @@ def blobs( n_features=n_variables, centers=n_centers, cluster_std=cluster_std, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), ) return AnnData(x, obs=dict(blobs=y.astype(str))) diff --git a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index 9cd2d6d092..74c03d518c 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -9,7 +9,7 @@ from ... 
import logging as logg from ..._compat import CSBase, warn from ..._utils import _doc_params, _empty, check_nonnegative_integers, view_to_actual -from ..._utils.random import accepts_legacy_random_state +from ..._utils.random import _accepts_legacy_random_state from ...experimental._docs import ( doc_adata, doc_check_values, @@ -162,7 +162,7 @@ def normalize_pearson_residuals( check_values=doc_check_values, inplace=doc_inplace, ) -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def normalize_pearson_residuals_pca( adata: AnnData, *, diff --git a/src/scanpy/experimental/pp/_recipes.py b/src/scanpy/experimental/pp/_recipes.py index b688b8d4d8..ba944350bf 100644 --- a/src/scanpy/experimental/pp/_recipes.py +++ b/src/scanpy/experimental/pp/_recipes.py @@ -17,7 +17,7 @@ ) from scanpy.preprocessing import pca -from ..._utils.random import accepts_legacy_random_state +from ..._utils.random import _accepts_legacy_random_state if TYPE_CHECKING: from collections.abc import Mapping @@ -37,7 +37,7 @@ check_values=doc_check_values, inplace=doc_inplace, ) -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def recipe_pearson_residuals( # noqa: PLR0913 adata: AnnData, *, diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 853695c30b..332dae931a 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -18,9 +18,9 @@ from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals from .._utils.random import ( + _accepts_legacy_random_state, _FakeRandomGen, - accepts_legacy_random_state, - legacy_random_state, + _legacy_random_state, ) from . 
import _connectivity from ._common import ( @@ -82,7 +82,7 @@ class NeighborsParams(TypedDict): # noqa: D101 @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def neighbors( # noqa: PLR0913 adata: AnnData, n_neighbors: int = 15, @@ -270,7 +270,7 @@ def neighbors( # noqa: PLR0913 key_added, n_neighbors=neighbors_.n_neighbors, method=method, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), metric=metric, **({} if not metric_kwds else dict(metric_kwds=metric_kwds)), **({} if use_rep is None else dict(use_rep=use_rep)), @@ -579,7 +579,7 @@ def to_igraph(self) -> Graph: return _utils.get_igraph_from_adjacency(self.connectivities) @_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) - @accepts_legacy_random_state(0) + @_accepts_legacy_random_state(0) def compute_neighbors( self, n_neighbors: int = 30, @@ -627,7 +627,7 @@ def compute_neighbors( n_neighbors=n_neighbors, metric=metric, metric_params=metric_kwds, # most use _params, not _kwds - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), ) method, transformer, shortcut = self._handle_transformer( method, transformer, knn=knn, kwds=transformer_kwds_default @@ -849,7 +849,7 @@ def compute_transitions(self, *, density_normalize: bool = True) -> None: self._transitions_sym = self.Z @ conn_norm @ self.Z logg.info(" finished", time=start) - @accepts_legacy_random_state(0) + @_accepts_legacy_random_state(0) def compute_eigen( self, *, diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 98f27aa030..1863506447 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -11,9 +11,9 @@ from ..._settings import settings from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type from ..._utils.random import ( + _accepts_legacy_random_state, _FakeRandomGen, - accepts_legacy_random_state, - 
legacy_random_state, + _legacy_random_state, ) from ...get import _check_mask, _get_obs_rep from .._docs import doc_mask_var_hvg @@ -58,7 +58,7 @@ @_doc_params( mask_var_hvg=doc_mask_var_hvg, ) -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def pca( # noqa: PLR0912, PLR0913, PLR0915 data: AnnData | np.ndarray | CSBase, n_comps: int | None = None, @@ -305,13 +305,13 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 pca_ = PCA( n_components=n_comps, svd_solver=svd_solver, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), ) elif isinstance(x._meta, CSBase) or svd_solver == "covariance_eigh": from ._dask import PCAEighDask if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: - msg = f"Ignoring rng={legacy_random_state(rng)} when using a sparse dask array" + msg = f"Ignoring rng={_legacy_random_state(rng)} when using a sparse dask array" warn(msg, UserWarning) if svd_solver not in {None, "covariance_eigh"}: msg = f"Ignoring {svd_solver=} when using a sparse dask array" @@ -324,7 +324,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 pca_ = PCA( n_components=n_comps, svd_solver=svd_solver, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), ) x_pca = pca_.fit_transform(x) else: @@ -351,7 +351,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 ) pca_ = TruncatedSVD( n_components=n_comps, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), algorithm=svd_solver, ) x_pca = pca_.fit_transform(x) diff --git a/src/scanpy/preprocessing/_pca/_compat.py b/src/scanpy/preprocessing/_pca/_compat.py index f1f4e33b0d..7809fff83a 100644 --- a/src/scanpy/preprocessing/_pca/_compat.py +++ b/src/scanpy/preprocessing/_pca/_compat.py @@ -10,7 +10,7 @@ from sklearn.utils.extmath import svd_flip from ..._compat import pkg_version -from ..._utils.random import accepts_legacy_random_state, legacy_random_state +from ..._utils.random import _accepts_legacy_random_state, 
_legacy_random_state if TYPE_CHECKING: from typing import Literal @@ -22,7 +22,7 @@ from ..._utils.random import RNGLike, SeedLike -@accepts_legacy_random_state(None) +@_accepts_legacy_random_state(None) def _pca_compat_sparse( x: CSBase, n_pcs: int, @@ -72,7 +72,7 @@ def rmat_op(v: NDArray[np.floating]): from sklearn.decomposition import PCA pca = PCA( - n_components=n_pcs, svd_solver=solver, random_state=legacy_random_state(rng) + n_components=n_pcs, svd_solver=solver, random_state=_legacy_random_state(rng) ) pca.explained_variance_ = ev pca.explained_variance_ratio_ = ev_ratio diff --git a/src/scanpy/preprocessing/_recipes.py b/src/scanpy/preprocessing/_recipes.py index 179a9ebdbe..4737cbe016 100644 --- a/src/scanpy/preprocessing/_recipes.py +++ b/src/scanpy/preprocessing/_recipes.py @@ -7,7 +7,7 @@ from .. import logging as logg from .. import preprocessing as pp from .._compat import CSBase -from .._utils.random import accepts_legacy_random_state +from .._utils.random import _accepts_legacy_random_state from ._deprecated.highly_variable_genes import ( filter_genes_cv_deprecated, filter_genes_dispersion, @@ -20,7 +20,7 @@ from .._utils.random import RNGLike, SeedLike -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def recipe_weinreb17( adata: AnnData, *, diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index b9215ae163..502204537b 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -10,7 +10,7 @@ from ... import logging as logg from ... import preprocessing as pp -from ..._utils.random import accepts_legacy_random_state, legacy_random_state +from ..._utils.random import _accepts_legacy_random_state, _legacy_random_state from ...get import _get_obs_rep from . 
import pipeline from .core import Scrublet @@ -20,7 +20,7 @@ from ...neighbors import _Metric, _MetricFn -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def scrublet( # noqa: PLR0913 adata: AnnData, adata_sim: AnnData | None = None, @@ -286,7 +286,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): return adata if copy else None -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def _scrublet_call_doublets( # noqa: PLR0913 adata_obs: AnnData, adata_sim: AnnData, @@ -463,7 +463,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 .get("sim_doublet_ratio", None) ), "n_neighbors": n_neighbors, - "random_state": legacy_random_state(rng), + "random_state": _legacy_random_state(rng), }, } @@ -488,7 +488,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 return adata_obs -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def scrublet_simulate_doublets( adata: AnnData, *, diff --git a/src/scanpy/preprocessing/_scrublet/pipeline.py b/src/scanpy/preprocessing/_scrublet/pipeline.py index 53d98ff8ed..78b8d0aa7f 100644 --- a/src/scanpy/preprocessing/_scrublet/pipeline.py +++ b/src/scanpy/preprocessing/_scrublet/pipeline.py @@ -6,7 +6,7 @@ from fast_array_utils.stats import mean_var from scipy import sparse -from ..._utils.random import legacy_random_state +from ..._utils.random import _legacy_random_state from .sparse_utils import sparse_multiply, sparse_zscore if TYPE_CHECKING: @@ -58,7 +58,7 @@ def truncated_svd( svd = TruncatedSVD( n_components=n_prin_comps, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), algorithm=algorithm, ).fit(self._counts_obs_norm) self.set_manifold( @@ -83,7 +83,7 @@ def pca( pca = PCA( n_components=n_prin_comps, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), svd_solver=svd_solver, ).fit(x_obs) self.set_manifold(pca.transform(x_obs), pca.transform(x_sim)) diff --git a/src/scanpy/preprocessing/_simple.py 
b/src/scanpy/preprocessing/_simple.py index 51f8b4a726..42f6a70931 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -29,7 +29,7 @@ sanitize_anndata, view_to_actual, ) -from .._utils.random import _if_legacy_apply_global, accepts_legacy_random_state +from .._utils.random import _accepts_legacy_random_state, _if_legacy_apply_global from ..get import _check_mask, _get_obs_rep, _set_obs_rep from ._distributed import materialize_as_ndarray @@ -972,7 +972,7 @@ def sample( # noqa: PLR0912 return subset, indices -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def downsample_counts( adata: AnnData, counts_per_cell: int | Collection[int] | None = None, diff --git a/src/scanpy/preprocessing/_utils.py b/src/scanpy/preprocessing/_utils.py index 3da0fed2d5..87d2402ad1 100644 --- a/src/scanpy/preprocessing/_utils.py +++ b/src/scanpy/preprocessing/_utils.py @@ -5,7 +5,7 @@ import numpy as np from sklearn.random_projection import sample_without_replacement -from .._utils.random import legacy_random_state +from .._utils.random import _legacy_random_state if TYPE_CHECKING: from typing import Literal @@ -23,7 +23,7 @@ def sample_comb( ] = "auto", ) -> NDArray[np.int64]: """Randomly sample indices from a grid, without repeating the same tuple.""" - random_state = legacy_random_state(rng) + random_state = _legacy_random_state(rng) idx = sample_without_replacement( np.prod(dims), nsamp, random_state=random_state, method=method ) diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 878f0d1d1d..b07f4e3d20 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -4,7 +4,7 @@ import numpy as np -from .._utils.random import accepts_legacy_random_state +from .._utils.random import _accepts_legacy_random_state from ._dpt import _diffmap if TYPE_CHECKING: @@ -13,7 +13,7 @@ from .._utils.random import RNGLike, SeedLike -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) 
def diffmap( adata: AnnData, n_comps: int = 15, diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index 01c9b151a5..6c019dd6fb 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -9,10 +9,10 @@ from .. import logging as logg from .._utils import _choose_graph, get_literal_vals from .._utils.random import ( + _accepts_legacy_random_state, _if_legacy_apply_global, - accepts_legacy_random_state, - legacy_random_state, - set_igraph_rng, + _legacy_random_state, + _set_igraph_rng, ) from ._utils import get_init_pos_from_paga @@ -28,7 +28,7 @@ type _Layout = Literal["fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa"] -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def draw_graph( # noqa: PLR0913 adata: AnnData, layout: _Layout = "fa", @@ -144,7 +144,7 @@ def draw_graph( # noqa: PLR0913 positions = np.array(fa2_positions(adjacency, init_coords, **kwds)) else: g = _utils.get_igraph_from_adjacency(adjacency) - with set_igraph_rng(rng): + with _set_igraph_rng(rng): if layout in {"fr", "drl", "kk", "grid_fr"}: ig_layout = g.layout(layout, seed=init_coords.tolist(), **kwds) elif "rt" in layout: @@ -156,7 +156,7 @@ def draw_graph( # noqa: PLR0913 positions = np.array(ig_layout.coords) adata.uns["draw_graph"] = {} adata.uns["draw_graph"]["params"] = dict( - layout=layout, random_state=legacy_random_state(rng) + layout=layout, random_state=_legacy_random_state(rng) ) key_added = f"X_draw_graph_{key_added_ext or layout}" adata.obsm[key_added] = positions diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index f43723c6e0..a5c9d493b1 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -11,9 +11,9 @@ from .._compat import warn from .._utils import _doc_params from .._utils.random import ( - accepts_legacy_random_state, - legacy_random_state, - set_igraph_rng, + _accepts_legacy_random_state, + _legacy_random_state, + _set_igraph_rng, ) from ._docs 
import ( doc_adata, @@ -48,7 +48,7 @@ neighbors_key=doc_neighbors_key.format(method="leiden"), obsp=doc_obsp, ) -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def leiden( # noqa: PLR0913 adata: AnnData, resolution: float = 1, @@ -169,7 +169,7 @@ def leiden( # noqa: PLR0913 partition_type = leidenalg.RBConfigurationVertexPartition if use_weights: clustering_args["weights"] = np.array(g.es["weight"]).astype(np.float64) - clustering_args["seed"] = legacy_random_state(rng) + clustering_args["seed"] = _legacy_random_state(rng) part = cast( "MutableVertexPartition", leidenalg.find_partition(g, partition_type, **clustering_args), @@ -181,7 +181,7 @@ def leiden( # noqa: PLR0913 if resolution is not None: clustering_args["resolution"] = resolution clustering_args.setdefault("objective_function", "modularity") - with set_igraph_rng(rng): + with _set_igraph_rng(rng): part = g.community_leiden(**clustering_args) # store output into adata.obs groups = np.array(part.membership) @@ -204,7 +204,7 @@ def leiden( # noqa: PLR0913 adata.uns[key_added] = {} adata.uns[key_added]["params"] = dict( resolution=resolution, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), n_iterations=n_iterations, ) adata.uns[key_added]["modularity"] = part.modularity diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index 97bd90dc8f..ae27a02904 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -10,7 +10,7 @@ from .. 
import logging as logg from .._compat import CSBase from .._utils import check_use_raw, is_backed_type -from .._utils.random import _if_legacy_apply_global, accepts_legacy_random_state +from .._utils.random import _accepts_legacy_random_state, _if_legacy_apply_global from ..get import _get_obs_rep if TYPE_CHECKING: @@ -51,7 +51,7 @@ def _sparse_nanmean(x: CSBase, /, axis: Literal[0, 1]) -> NDArray[np.float64]: return m -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def score_genes( # noqa: PLR0913 adata: AnnData, gene_list: Sequence[str] | pd.Index[str], diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index f18e954dc8..e8d9b4414f 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -6,7 +6,7 @@ from .._compat import warn from .._settings import settings from .._utils import _doc_params, raise_not_implemented_error_if_backed_type -from .._utils.random import accepts_legacy_random_state, legacy_random_state +from .._utils.random import _accepts_legacy_random_state, _legacy_random_state from ..neighbors._doc import doc_n_pcs, doc_use_rep from ._utils import _choose_representation @@ -16,7 +16,7 @@ from .._utils.random import RNGLike, SeedLike -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) @_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep) def tsne( # noqa: PLR0913 adata: AnnData, @@ -110,7 +110,7 @@ def tsne( # noqa: PLR0913 n_jobs = settings.n_jobs if n_jobs is None else n_jobs params_sklearn = dict( perplexity=perplexity, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), verbose=settings.verbosity > 3, early_exaggeration=early_exaggeration, learning_rate=learning_rate, diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index 8d46298ec5..ba59c2712f 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -10,7 +10,7 @@ from .._compat import warn from .._settings import settings from .._utils import NeighborsView -from 
.._utils.random import accepts_legacy_random_state, legacy_random_state +from .._utils.random import _accepts_legacy_random_state, _legacy_random_state from ._utils import _choose_representation, get_init_pos_from_paga if TYPE_CHECKING: @@ -24,7 +24,7 @@ type _InitPos = Literal["paga", "spectral", "random"] -@accepts_legacy_random_state(0) +@_accepts_legacy_random_state(0) def umap( # noqa: PLR0913, PLR0915 adata: AnnData, *, @@ -186,7 +186,7 @@ def umap( # noqa: PLR0913, PLR0915 init_coords = check_array(init_coords, dtype=np.float32, accept_sparse=False) if rng is not None: - adata.uns[key_uns]["params"]["random_state"] = legacy_random_state(rng) + adata.uns[key_uns]["params"]["random_state"] = _legacy_random_state(rng) neigh_params = neighbors["params"] x = _choose_representation( @@ -211,7 +211,7 @@ def umap( # noqa: PLR0913, PLR0915 negative_sample_rate=negative_sample_rate, n_epochs=n_epochs, init=init_coords, - random_state=legacy_random_state(rng, always_state=True), + random_state=_legacy_random_state(rng, always_state=True), metric=neigh_params.get("metric", "euclidean"), metric_kwds=neigh_params.get("metric_kwds", {}), densmap=False, @@ -251,7 +251,7 @@ def umap( # noqa: PLR0913, PLR0915 a=a, b=b, verbose=settings.verbosity > 3, - random_state=legacy_random_state(rng), + random_state=_legacy_random_state(rng), ) x_umap = umap.fit_transform(x_contiguous) adata.obsm[key_obsm] = x_umap # annotate samples with UMAP coordinates From 64a0f26d732e8b4987d65944e38914d2576eff48 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 27 Feb 2026 11:25:56 +0100 Subject: [PATCH 11/20] use sub-generators --- hatch.toml | 1 + src/scanpy/_utils/random.py | 26 ++-- src/scanpy/neighbors/__init__.py | 33 +++-- src/scanpy/preprocessing/_pca/_compat.py | 25 ++-- .../preprocessing/_scrublet/__init__.py | 34 ++++-- src/scanpy/preprocessing/_simple.py | 114 +++++++++--------- src/scanpy/tools/_draw_graph.py | 17 +-- src/scanpy/tools/_leiden.py | 9 +- src/scanpy/tools/_score_genes.py | 9 +- src/scanpy/tools/_umap.py | 24 ++-- src/scanpy/tools/_utils.py | 6 +- 11 files changed, 176 insertions(+), 122 deletions(-) diff --git a/hatch.toml b/hatch.toml index a5051b65d2..b7e7a5985b 100644 --- a/hatch.toml +++ b/hatch.toml @@ -36,6 +36,7 @@ overrides.matrix.deps.python = [ { if = [ "low-vers" ], value = "3.12" }, ] overrides.matrix.deps.extra-dependencies = [ + { if = [ "stable" ], value = "scipy>=1.17" }, { if = [ "pre" ], value = "anndata @ git+https://github.com/scverse/anndata.git" }, { if = [ "pre" ], value = "pandas>=3rc0" }, ] diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index f5dc3c6531..5ab2f57dc3 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -108,11 +108,19 @@ def wrap_global( np.random.seed(arg) return _FakeRandomGen(arg, np.random.RandomState(np.random.get_bit_generator())) + def spawn(self, n_children: int) -> list[Self]: + """Return `self` `n_children` times. + + In a real generator, the spawned children are independent, + but for backwards compatibility we return the same instance. 
+ """ + return [self] * n_children + @classmethod def _delegate(cls) -> None: names = dict(integers="randint") for name, meth in np.random.Generator.__dict__.items(): - if name.startswith("_") or not callable(meth): + if name.startswith("_") or not callable(meth) or name in cls.__dict__: continue def mk_wrapper(name: str, meth): @@ -129,11 +137,11 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs): _FakeRandomGen._delegate() -def _if_legacy_apply_global(rng: np.random.Generator) -> np.random.Generator: - """Re-apply legacy `random_state` semantics when `rng` is a `_FakeRandomGen`. +def _if_legacy_apply_global(rng: np.random.Generator, /) -> np.random.Generator: + """Wrap the global legacy RNG if `rng` is a `_FakeRandomGen`. - This resets the global legacy RNG from the original `_arg` and returns a - generator which continues drawing from the same internal state. + This is used where our code used to call `np.random.seed()`. + It’s a no-op if `rng` is not a `_FakeRandomGen`. """ if not isinstance(rng, _FakeRandomGen): return rng @@ -142,7 +150,7 @@ def _if_legacy_apply_global(rng: np.random.Generator) -> np.random.Generator: def _legacy_random_state( - rng: SeedLike | RNGLike | None, *, always_state: bool = False + rng: SeedLike | RNGLike | None, /, *, always_state: bool = False ) -> _LegacyRandom: """Convert a np.random.Generator into a legacy `random_state` argument. @@ -150,12 +158,12 @@ def _legacy_random_state( """ if isinstance(rng, _FakeRandomGen): return rng._state if always_state else rng._arg - rng = np.random.default_rng(rng) - return np.random.RandomState(rng.bit_generator.spawn(1)[0]) + [bitgen] = np.random.default_rng(rng).bit_generator.spawn(1) + return np.random.RandomState(bitgen) def _accepts_legacy_random_state[**P, R]( - random_state_default: _LegacyRandom, + random_state_default: _LegacyRandom, / ) -> Callable[[Callable[P, R]], Callable[P, R]]: """Make a function accept `random_state: _LegacyRandom` and pass it as `rng`. 
diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 332dae931a..d442967c5e 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -10,11 +10,12 @@ import numpy as np import scipy +from packaging.version import Version from scipy import sparse from .. import _utils from .. import logging as logg -from .._compat import CSBase, CSRBase, SpBase, warn +from .._compat import CSBase, CSRBase, SpBase, pkg_version, warn from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals from .._utils.random import ( @@ -46,9 +47,8 @@ # TODO: make `type` when https://github.com/sphinx-doc/sphinx/pull/13508 is released RPForestDict: TypeAlias = Mapping[str, Mapping[str, np.ndarray]] # noqa: UP040 -N_DCS: int = 15 # default number of diffusion components -# Backwards compat, constants should be defined in only one place. -N_PCS: int = settings.N_PCS + +SCIPY_1_17 = pkg_version("scipy") >= Version("1.17") class KwdsForTransformer(TypedDict): @@ -208,6 +208,10 @@ def neighbors( # noqa: PLR0913 :doc:`/how-to/knn-transformers` """ + meta_random_state = ( + dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} + ) + if distances is None: if metric is None: metric = "euclidean" @@ -235,9 +239,8 @@ def neighbors( # noqa: PLR0913 if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds"} if params[p.name] != p.default } - if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: + if meta_random_state.get("random_state") != 0: # rng or random_state was passed ignored.add("rng/random_state") - rng = _FakeRandomGen(0) if ignored: warn( f"Parameter(s) ignored if `distances` is given: {ignored}", @@ -270,8 +273,8 @@ def neighbors( # noqa: PLR0913 key_added, n_neighbors=neighbors_.n_neighbors, method=method, - random_state=_legacy_random_state(rng), metric=metric, + **meta_random_state, **({} if not metric_kwds else dict(metric_kwds=metric_kwds)), **({} if use_rep is None else 
dict(use_rep=use_rep)), **({} if n_pcs is None else dict(n_pcs=n_pcs)), @@ -849,15 +852,13 @@ def compute_transitions(self, *, density_normalize: bool = True) -> None: self._transitions_sym = self.Z @ conn_norm @ self.Z logg.info(" finished", time=start) - @_accepts_legacy_random_state(0) def compute_eigen( self, *, n_comps: int = 15, - sym: bool | None = None, sort: Literal["decrease", "increase"] = "decrease", rng: np.random.Generator, - ): + ) -> None: """Compute eigen decomposition of transition matrix. Parameters @@ -886,6 +887,9 @@ def compute_eigen( plotting. """ + [rng_init, rng_eigsh] = np.random.default_rng(rng).spawn(2) + del rng + np.set_printoptions(precision=10) if self._transitions_sym is None: msg = "Run `.compute_transitions` first." @@ -903,9 +907,14 @@ def compute_eigen( matrix = matrix.astype(np.float64) # Setting the random initial vector - v0 = rng.standard_normal(matrix.shape[0]) + v0 = rng_init.standard_normal(matrix.shape[0]) evals, evecs = sparse.linalg.eigsh( - matrix, k=n_comps, which=which, ncv=ncv, v0=v0 + matrix, + k=n_comps, + which=which, + ncv=ncv, + v0=v0, + **(dict(rng=rng_eigsh) if SCIPY_1_17 else {}), ) evals, evecs = evals.astype(np.float32), evecs.astype(np.float32) if sort == "decrease": diff --git a/src/scanpy/preprocessing/_pca/_compat.py b/src/scanpy/preprocessing/_pca/_compat.py index 7809fff83a..e144a6776f 100644 --- a/src/scanpy/preprocessing/_pca/_compat.py +++ b/src/scanpy/preprocessing/_pca/_compat.py @@ -22,6 +22,9 @@ from ..._utils.random import RNGLike, SeedLike +SCIPY_1_15 = pkg_version("scikit-learn") >= Version("1.5.0rc1") + + @_accepts_legacy_random_state(None) def _pca_compat_sparse( x: CSBase, @@ -33,7 +36,11 @@ def _pca_compat_sparse( ) -> tuple[NDArray[np.floating], PCA]: """Sparse PCA for scikit-learn <1.4.""" rng = np.random.default_rng(rng) - random_init = rng.uniform(size=np.min(x.shape)) + # this exists only to be stored in our PCA container object + random_state_meta = _legacy_random_state(rng) + 
[rng_init, rng_svds] = rng.spawn(2) + del rng + x = check_array(x, accept_sparse=["csr", "csc"]) if mu is None: @@ -55,11 +62,15 @@ def rmat_op(v: NDArray[np.floating]): rmatmat=rmat_op, ) - u, s, v = svds(linop, solver=solver, k=n_pcs, v0=random_init) - # u_based_decision was changed in https://github.com/scikit-learn/scikit-learn/pull/27491 - u, v = svd_flip( - u, v, u_based_decision=pkg_version("scikit-learn") < Version("1.5.0rc1") + random_init = rng_init.uniform(size=np.min(x.shape)) + kw = ( + dict(rng=rng_svds) + if SCIPY_1_15 + else dict(random_state=_legacy_random_state(rng_svds)) ) + u, s, v = svds(linop, solver=solver, k=n_pcs, v0=random_init, **kw) + # u_based_decision was changed in https://github.com/scikit-learn/scikit-learn/pull/27491 + u, v = svd_flip(u, v, u_based_decision=not SCIPY_1_15) idx = np.argsort(-s) v = v[idx, :] @@ -71,9 +82,7 @@ def rmat_op(v: NDArray[np.floating]): from sklearn.decomposition import PCA - pca = PCA( - n_components=n_pcs, svd_solver=solver, random_state=_legacy_random_state(rng) - ) + pca = PCA(n_components=n_pcs, svd_solver=solver, random_state=random_state_meta) pca.explained_variance_ = ev pca.explained_variance_ratio_ = ev_ratio pca.components_ = v diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index 502204537b..c0979d212f 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -10,7 +10,7 @@ from ... import logging as logg from ... import preprocessing as pp -from ..._utils.random import _accepts_legacy_random_state, _legacy_random_state +from ..._utils.random import _accepts_legacy_random_state, _FakeRandomGen from ...get import _get_obs_rep from . 
import pipeline from .core import Scrublet @@ -177,10 +177,12 @@ def scrublet( # noqa: PLR0913 adata_obs = adata.copy() - def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): + def _run_scrublet( + ad_obs: AnnData, ad_sim: AnnData | None, *, rng: np.random.Generator + ): + rng_sim, rng_call = rng.spawn(2) # With no adata_sim we assume the regular use case, starting with raw # counts and simulating doublets - if ad_sim is None: pp.filter_genes(ad_obs, min_cells=3) pp.filter_cells(ad_obs, min_genes=3) @@ -207,7 +209,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): layer="raw", sim_doublet_ratio=sim_doublet_ratio, synthetic_doublet_umi_subsampling=synthetic_doublet_umi_subsampling, - rng=rng, + rng=rng_sim, ) del ad_obs.layers["raw"] if log_transform: @@ -232,7 +234,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): knn_dist_metric=knn_dist_metric, get_doublet_neighbor_parents=get_doublet_neighbor_parents, threshold=threshold, - rng=rng, + rng=rng_call, verbose=verbose, ) @@ -249,12 +251,14 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): # Run Scrublet independently on batches and return just the # scrublet-relevant parts of the objects to add to the input object batches = np.unique(adata.obs[batch_key]) + sub_rngs = rng.spawn(len(batches)) scrubbed = [ _run_scrublet( adata_obs[adata_obs.obs[batch_key] == batch].copy(), adata_sim, + rng=sub_rng, ) - for batch in batches + for batch, sub_rng in zip(batches, sub_rngs, strict=True) ] scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed]).astype( adata.obs.dtypes @@ -274,7 +278,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None): adata.uns["scrublet"]["batched_by"] = batch_key else: - scrubbed = _run_scrublet(adata_obs, adata_sim) + scrubbed = _run_scrublet(adata_obs, adata_sim, rng=rng) # Copy outcomes to input object from our processed version adata.obs["doublet_score"] = scrubbed["obs"]["doublet_score"] @@ -385,6 
+389,12 @@ def _scrublet_call_doublets( # noqa: PLR0913 Dictionary of Scrublet parameters """ + meta_random_state = ( + dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} + ) + rng_scrub, rng_pca = rng.spawn(2) + del rng + # Estimate n_neighbors if not provided, and create scrublet object. if n_neighbors is None: @@ -398,7 +408,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 n_neighbors=n_neighbors, expected_doublet_rate=expected_doublet_rate, stdev_doublet_rate=stdev_doublet_rate, - rng=rng, + rng=rng_scrub, ) # Ensure normalised matrix sparseness as Scrublet does @@ -424,13 +434,11 @@ def _scrublet_call_doublets( # noqa: PLR0913 if mean_center: logg.info("Embedding transcriptomes using PCA...") - pipeline.pca( - scrub, n_prin_comps=n_prin_comps, svd_solver="arpack", rng=scrub._rng - ) + pipeline.pca(scrub, n_prin_comps=n_prin_comps, svd_solver="arpack", rng=rng_pca) else: logg.info("Embedding transcriptomes using Truncated SVD...") pipeline.truncated_svd( - scrub, n_prin_comps=n_prin_comps, algorithm="arpack", rng=scrub._rng + scrub, n_prin_comps=n_prin_comps, algorithm="arpack", rng=rng_pca ) # Score the doublets @@ -463,7 +471,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 .get("sim_doublet_ratio", None) ), "n_neighbors": n_neighbors, - "random_state": _legacy_random_state(rng), + **meta_random_state, }, } diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 42f6a70931..09a1729e6c 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -6,6 +6,8 @@ from __future__ import annotations import warnings +from contextlib import AbstractContextManager +from dataclasses import dataclass, field from functools import singledispatch from itertools import repeat from typing import TYPE_CHECKING, overload @@ -15,6 +17,7 @@ from anndata import AnnData from fast_array_utils import stats from fast_array_utils.conv import to_dense +from numpy._typing._array_like import NDArray 
from pandas.api.types import CategoricalDtype from sklearn.utils import check_array, sparsefuncs @@ -36,7 +39,7 @@ if TYPE_CHECKING: from collections.abc import Collection, Iterable, Sequence from numbers import Number - from typing import Literal + from typing import Literal, Self import pandas as pd from numpy.typing import NDArray @@ -1020,97 +1023,78 @@ def downsample_counts( raise_not_implemented_error_if_backed_type(adata.X, "downsample_counts") # This logic is all dispatch rng = np.random.default_rng(rng) - total_counts_call = total_counts is not None - counts_per_cell_call = counts_per_cell is not None - if total_counts_call is counts_per_cell_call: + if (total_counts is not None) is (counts_per_cell is not None): msg = "Must specify exactly one of `total_counts` or `counts_per_cell`." raise ValueError(msg) if copy: adata = adata.copy() - if total_counts_call: + if total_counts is not None: adata.X = _downsample_total_counts( adata.X, total_counts, rng=rng, replace=replace ) - elif counts_per_cell_call: + elif counts_per_cell is not None: adata.X = _downsample_per_cell( adata.X, counts_per_cell, rng=rng, replace=replace ) - if copy: - return adata + return adata if copy else None -def _downsample_per_cell( - x: np.ndarray | CSBase, +def _downsample_per_cell[T: (np.ndarray, CSBase)]( + x: T, /, - counts_per_cell: int, + counts_per_cell: int | Collection[int], *, rng: np.random.Generator, replace: bool, -) -> CSBase: +) -> T: n_obs = x.shape[0] - if isinstance(counts_per_cell, int): - counts_per_cell = np.full(n_obs, counts_per_cell) - else: - counts_per_cell = np.asarray(counts_per_cell) - # np.random.choice needs int arguments in numba code: - counts_per_cell = counts_per_cell.astype(np.int_, copy=False) - if not isinstance(counts_per_cell, np.ndarray) or len(counts_per_cell) != n_obs: + counts_per_cell = ( + np.full(n_obs, counts_per_cell) + if isinstance(counts_per_cell, int) + else np.asarray(counts_per_cell, np.int_) + ) + if counts_per_cell.shape != 
(n_obs,): msg = ( "If provided, 'counts_per_cell' must be either an integer, or " "coercible to an `np.ndarray` of length as number of observations" " by `np.asarray(counts_per_cell)`." ) raise ValueError(msg) - if isinstance(x, CSBase): - original_type = type(x) - if not isinstance(x, CSRBase): - x = x.tocsr() - totals = stats.sum(x, axis=1) # Faster for csr matrix - under_target = np.nonzero(totals > counts_per_cell)[0] - rows = np.split(x.data, x.indptr[1:-1]) - for rowidx in under_target: - row = rows[rowidx] - _downsample_array( - row, counts_per_cell[rowidx], rng=rng, replace=replace, inplace=True - ) - x.eliminate_zeros() - if not issubclass(original_type, CSRBase): # Put it back - x = original_type(x) - else: + with sparse_as_csr(x) as spc: + x = spc.x # we only mutate x, so spc.x receives the changes + rows = np.split(x.data, x.indptr[1:-1]) if isinstance(x, CSRBase) else x totals = stats.sum(x, axis=1) - under_target = np.nonzero(totals > counts_per_cell)[0] - for rowidx in under_target: - row = x[rowidx, :] + under_target = np.flatnonzero(totals > counts_per_cell) + for rowidx, sub_rng in zip( + under_target, rng.spawn(len(under_target)), strict=True + ): _downsample_array( - row, counts_per_cell[rowidx], rng=rng, replace=replace, inplace=True + rows[rowidx], + counts_per_cell[rowidx], + rng=sub_rng, + replace=replace, + inplace=True, ) - return x + return spc.x # use x that was converted back -def _downsample_total_counts( - x: np.ndarray | CSBase, +def _downsample_total_counts[T: (np.ndarray, CSBase)]( + x: T, /, total_counts: int, *, rng: np.random.Generator, replace: bool, -) -> CSBase: +) -> T: total_counts = int(total_counts) total = x.sum() if total < total_counts: return x - if isinstance(x, CSBase): - original_type = type(x) - if not isinstance(x, CSRBase): - x = x.tocsr() - _downsample_array(x.data, total_counts, rng=rng, replace=replace, inplace=True) - x.eliminate_zeros() - if not issubclass(original_type, CSRBase): - x = original_type(x) - 
else: - v = x.reshape(np.multiply(*x.shape)) + with sparse_as_csr(x) as spc: + x = spc.x # we only mutate x, so spc.x receives the changes + v = x.data if isinstance(x, CSBase) else x.reshape(-1) _downsample_array(v, total_counts, rng=rng, replace=replace, inplace=True) - return x + return spc.x # use x that was converted back def _downsample_array( @@ -1150,3 +1134,25 @@ def _downsample_array_inner( geneptr += 1 col[geneptr] += 1 return col + + +@dataclass +class sparse_as_csr[T: (np.ndarray | CSBase)](AbstractContextManager): # noqa: N801 + """Context manager that converts to CSR while active.""" + + x: T + _original_type: type[T] = field(init=False) + + def __post_init__(self) -> None: + self._original_type = type(self.x) + if isinstance(self.x, CSBase) and not isinstance(self.x, CSRBase): + self.x = self.x.tocsr() + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args: object) -> bool | None: + if isinstance(self.x, CSBase): + self.x.eliminate_zeros() + if not issubclass(self._original_type, CSRBase): + self.x = self._original_type(self.x) diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index 6c019dd6fb..75e46c570a 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -10,8 +10,8 @@ from .._utils import _choose_graph, get_literal_vals from .._utils.random import ( _accepts_legacy_random_state, + _FakeRandomGen, _if_legacy_apply_global, - _legacy_random_state, _set_igraph_rng, ) from ._utils import get_init_pos_from_paga @@ -118,7 +118,12 @@ def draw_graph( # noqa: PLR0913 """ start = logg.info(f"drawing single-cell graph using layout {layout!r}") rng = np.random.default_rng(rng) + meta_random_state = ( + dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} + ) rng = _if_legacy_apply_global(rng) + rng_init, rng_layout = rng.spawn(2) + del rng if layout not in (layouts := get_literal_vals(_Layout)): msg = f"Provide a valid layout, one of {layouts}." 
raise ValueError(msg) @@ -132,19 +137,19 @@ def draw_graph( # noqa: PLR0913 init_coords = get_init_pos_from_paga( adata, adjacency, - rng=rng, + rng=rng_init, neighbors_key=neighbors_key, obsp=obsp, ) else: - init_coords = rng.random((adjacency.shape[0], 2)) + init_coords = rng_init.random((adjacency.shape[0], 2)) layout = coerce_fa2_layout(layout) # actual drawing if layout == "fa": positions = np.array(fa2_positions(adjacency, init_coords, **kwds)) else: g = _utils.get_igraph_from_adjacency(adjacency) - with _set_igraph_rng(rng): + with _set_igraph_rng(rng_layout): if layout in {"fr", "drl", "kk", "grid_fr"}: ig_layout = g.layout(layout, seed=init_coords.tolist(), **kwds) elif "rt" in layout: @@ -155,9 +160,7 @@ def draw_graph( # noqa: PLR0913 ig_layout = g.layout(layout, **kwds) positions = np.array(ig_layout.coords) adata.uns["draw_graph"] = {} - adata.uns["draw_graph"]["params"] = dict( - layout=layout, random_state=_legacy_random_state(rng) - ) + adata.uns["draw_graph"]["params"] = dict(layout=layout, **meta_random_state) key_added = f"X_draw_graph_{key_added_ext or layout}" adata.obsm[key_added] = positions logg.info( diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index a5c9d493b1..6e0f52a98e 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -12,6 +12,7 @@ from .._utils import _doc_params from .._utils.random import ( _accepts_legacy_random_state, + _FakeRandomGen, _legacy_random_state, _set_igraph_rng, ) @@ -140,6 +141,10 @@ def leiden( # noqa: PLR0913 _utils.ensure_igraph() clustering_args = dict(clustering_args) + meta_random_state = ( + dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} + ) + start = logg.info("running Leiden clustering") adata = adata.copy() if copy else adata # are we clustering a user-provided graph or the default AnnData one? 
@@ -203,9 +208,7 @@ def leiden( # noqa: PLR0913 # store information on the clustering parameters adata.uns[key_added] = {} adata.uns[key_added]["params"] = dict( - resolution=resolution, - random_state=_legacy_random_state(rng), - n_iterations=n_iterations, + resolution=resolution, n_iterations=n_iterations, **meta_random_state ) adata.uns[key_added]["modularity"] = part.modularity logg.info( diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index ae27a02904..0ba868a888 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -121,10 +121,8 @@ def score_genes( # noqa: PLR0913 """ start = logg.info(f"computing score {score_name!r}") - rng_was_passed = rng is not None rng = np.random.default_rng(rng) - if rng_was_passed: # backwards compatibility: call np.random.seed() by default - rng = _if_legacy_apply_global(rng) + rng = _if_legacy_apply_global(rng) adata = adata.copy() if copy else adata use_raw = check_use_raw(adata, use_raw, layer=layer) if is_backed_type(adata.X) and not use_raw: @@ -240,7 +238,8 @@ def _score_genes_bins( keep_ctrl_in_obs_cut = np.False_ if ctrl_as_ref else obs_cut.index.isin(gene_list) # now pick `ctrl_size` genes from every cut - for cut in np.unique(obs_cut.loc[gene_list]): + cuts = np.unique(obs_cut.loc[gene_list]) + for cut, sub_rng in zip(cuts, rng.spawn(len(cuts)), strict=True): r_genes: pd.Index[str] = obs_cut[(obs_cut == cut) & ~keep_ctrl_in_obs_cut].index if len(r_genes) == 0: msg = ( @@ -249,7 +248,7 @@ def _score_genes_bins( ) logg.warning(msg) if ctrl_size < len(r_genes): - r_genes = r_genes.to_series().sample(ctrl_size, random_state=rng).index + r_genes = r_genes.to_series().sample(ctrl_size, random_state=sub_rng).index if ctrl_as_ref: # otherwise `r_genes` is already filtered r_genes = r_genes.difference(gene_list) yield r_genes diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index ba59c2712f..6e8982874a 100644 --- a/src/scanpy/tools/_umap.py +++ 
b/src/scanpy/tools/_umap.py @@ -10,7 +10,11 @@ from .._compat import warn from .._settings import settings from .._utils import NeighborsView -from .._utils.random import _accepts_legacy_random_state, _legacy_random_state +from .._utils.random import ( + _accepts_legacy_random_state, + _FakeRandomGen, + _legacy_random_state, +) from ._utils import _choose_representation, get_init_pos_from_paga if TYPE_CHECKING: @@ -144,7 +148,12 @@ def umap( # noqa: PLR0913, PLR0915 UMAP parameters. """ - rng = np.random.default_rng(rng) + rng_init, rng_umap = np.random.default_rng(rng).spawn(2) + meta_random_state = ( + dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else None + ) + del rng + adata = adata.copy() if copy else adata key_obsm, key_uns = ("X_umap", "umap") if key_added is None else [key_added] * 2 @@ -173,21 +182,18 @@ def umap( # noqa: PLR0913, PLR0915 if a is None or b is None: a, b = find_ab_params(spread, min_dist) - adata.uns[key_uns] = dict(params=dict(a=a, b=b)) + adata.uns[key_uns] = dict(params=dict(a=a, b=b, **meta_random_state)) if isinstance(init_pos, str) and init_pos in adata.obsm: init_coords = adata.obsm[init_pos] elif isinstance(init_pos, str) and init_pos == "paga": init_coords = get_init_pos_from_paga( - adata, rng=rng, neighbors_key=neighbors_key + adata, rng=rng_init, neighbors_key=neighbors_key ) else: init_coords = init_pos # Let umap handle it if hasattr(init_coords, "dtype"): init_coords = check_array(init_coords, dtype=np.float32, accept_sparse=False) - if rng is not None: - adata.uns[key_uns]["params"]["random_state"] = _legacy_random_state(rng) - neigh_params = neighbors["params"] x = _choose_representation( adata, @@ -211,7 +217,7 @@ def umap( # noqa: PLR0913, PLR0915 negative_sample_rate=negative_sample_rate, n_epochs=n_epochs, init=init_coords, - random_state=_legacy_random_state(rng, always_state=True), + random_state=_legacy_random_state(rng_umap, always_state=True), metric=neigh_params.get("metric", "euclidean"), 
metric_kwds=neigh_params.get("metric_kwds", {}), densmap=False, @@ -251,7 +257,7 @@ def umap( # noqa: PLR0913, PLR0915 a=a, b=b, verbose=settings.verbosity > 3, - random_state=_legacy_random_state(rng), + random_state=_legacy_random_state(rng_umap), ) x_umap = umap.fit_transform(x_contiguous) adata.obsm[key_obsm] = x_umap # annotate samples with UMAP coordinates diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py index bdb9ae7f90..8655d93ccd 100644 --- a/src/scanpy/tools/_utils.py +++ b/src/scanpy/tools/_utils.py @@ -94,13 +94,15 @@ def get_init_pos_from_paga( pos = adata.uns["paga"]["pos"] connectivities_coarse = adata.uns["paga"]["connectivities"] init_pos = np.ones((adjacency.shape[0], 2)) - for i, group_pos in enumerate(pos): + for i, group_pos, sub_rng in zip( + range(len(pos)), pos, rng.spawn(len(pos)), strict=True + ): subset = (groups == groups.cat.categories[i]).values neighbors = connectivities_coarse[i].nonzero() if len(neighbors[1]) > 0: connectivities = connectivities_coarse[i][neighbors] nearest_neighbor = neighbors[1][np.argmax(connectivities)] - noise = rng.random((len(subset[subset]), 2)) + noise = sub_rng.random((len(subset[subset]), 2)) dist = group_pos - pos[nearest_neighbor] noise = noise * dist init_pos[subset] = group_pos - 0.5 * dist + noise From baf2c852cff5e1e465cd5442d99097530d10f110 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 27 Feb 2026 12:23:29 +0100 Subject: [PATCH 12/20] docs --- src/scanpy/_utils/random.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index 5ab2f57dc3..de42de0785 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -14,6 +14,7 @@ from collections.abc import Callable, Generator from typing import Self + from numpy.random import BitGenerator from numpy.typing import NDArray @@ -84,6 +85,13 @@ def _set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: class _FakeRandomGen(np.random.Generator): + """A `Generator` that wraps a legacy `RandomState` instance. + + To behave like a `RandomState`, it’s not enough to just use a MT19937 `bit_generator` + (as in `Generator(RandomState(seed).bit_generator)`), + so instead this hack uses the exact same random numbers as `RandomState(seed)`. + """ + _arg: _LegacyRandom _state: np.random.RandomState @@ -92,7 +100,11 @@ def __init__( ) -> None: self._arg = arg self._state = np.random.RandomState(arg) if state is None else state - super().__init__(self._state._bit_generator) + + @property + def bit_generator(self) -> BitGenerator: + msg = "A _FakeRandomGen instance has no `bit_generator` attribute." + raise AttributeError(msg) @classmethod def wrap_global( From 7e2fab542f4d34568624f640b1051cc4916b06d2 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 27 Feb 2026 12:23:32 +0100 Subject: [PATCH 13/20] paga --- src/scanpy/plotting/_tools/paga.py | 95 ++++++++++++++++-------------- src/scanpy/preprocessing/_utils.py | 23 +++----- 2 files changed, 58 insertions(+), 60 deletions(-) diff --git a/src/scanpy/plotting/_tools/paga.py b/src/scanpy/plotting/_tools/paga.py index 2fb9276fe8..a751bd8cd3 100644 --- a/src/scanpy/plotting/_tools/paga.py +++ b/src/scanpy/plotting/_tools/paga.py @@ -2,6 +2,7 @@ import warnings from collections.abc import Collection, Mapping, Sequence +from contextlib import nullcontext from pathlib import Path from types import MappingProxyType from typing import TYPE_CHECKING, TypedDict @@ -13,7 +14,6 @@ from matplotlib import pyplot as plt from matplotlib.colors import is_color_like from pandas.api.types import CategoricalDtype -from sklearn.utils import check_random_state from scanpy.tools._draw_graph import coerce_fa2_layout, fa2_positions @@ -21,6 +21,11 @@ from ... import logging as logg from ..._compat import CSBase from ..._settings import settings +from ..._utils.random import ( + _accepts_legacy_random_state, + _FakeRandomGen, + _set_igraph_rng, +) from .. 
import _utils from .._utils import matrix @@ -30,9 +35,10 @@ from anndata import AnnData from matplotlib.axes import Axes from matplotlib.colors import Colormap + from numpy.typing import NDArray from ..._compat import SpBase - from ..._utils.random import _LegacyRandom + from ..._utils.random import RNGLike, SeedLike from ...tools._draw_graph import _Layout as _LayoutWithoutEqTree from .._utils import _FontSize, _FontWeight, _LegendLoc @@ -187,27 +193,24 @@ def paga_compare( # noqa: PLR0912, PLR0913 def _compute_pos( # noqa: PLR0912 adjacency_solid: SpBase | np.ndarray, *, - layout: _Layout | None = None, - random_state: _LegacyRandom = 0, - init_pos: np.ndarray | None = None, - adj_tree=None, - root: int = 0, + layout: _Layout | None, + rng: np.random.Generator, + init_pos: np.ndarray | None, + adj_tree, + root: int, layout_kwds: Mapping[str, Any] = MappingProxyType({}), -): +) -> NDArray[np.float64]: import random import networkx as nx - random_state = check_random_state(random_state) - nx_g_solid = nx.Graph(adjacency_solid) if layout is None: layout = "fr" layout = coerce_fa2_layout(layout) if layout == "fa": - # np.random.seed(random_state) if init_pos is None: - init_coords = random_state.random_sample((adjacency_solid.shape[0], 2)) + init_coords = rng.random((adjacency_solid.shape[0], 2)) else: init_coords = init_pos.copy() pos_list = fa2_positions(adjacency_solid, init_coords, **layout_kwds) @@ -221,41 +224,41 @@ def _compute_pos( # noqa: PLR0912 "Try another `layout`, e.g., {'fr'}." 
) raise ValueError(msg) - else: - # igraph layouts - random.seed(random_state.bytes(8)) - g = _sc_utils.get_igraph_from_adjacency(adjacency_solid) - if "rt" in layout: - g_tree = _sc_utils.get_igraph_from_adjacency(adj_tree) - pos_list = g_tree.layout( - layout, root=root if isinstance(root, list) else [root] - ).coords - elif layout == "circle": - pos_list = g.layout(layout).coords + else: # igraph layouts + if isinstance(rng, _FakeRandomGen): # backwards compat + random.seed(rng.bytes(8)) + ctx = nullcontext() else: - # I don't know why this is necessary - # np.random.seed(random_state) - if init_pos is None: - init_coords = random_state.random_sample(( - adjacency_solid.shape[0], - 2, - )).tolist() - else: - init_pos = init_pos.copy() - # this is a super-weird hack that is necessary as igraph’s - # layout function seems to do some strange stuff here - init_pos[:, 1] *= -1 - init_coords = init_pos.tolist() - try: - pos_list = g.layout( - layout, seed=init_coords, weights="weight", **layout_kwds + ctx = _set_igraph_rng(rng) + g = _sc_utils.get_igraph_from_adjacency(adjacency_solid) + with ctx: + if "rt" in layout: + g_tree = _sc_utils.get_igraph_from_adjacency(adj_tree) + pos_list = g_tree.layout( + layout, root=root if isinstance(root, list) else [root] ).coords - except AttributeError: # hack for empty graphs... - pos_list = g.layout(layout, seed=init_coords, **layout_kwds).coords + elif layout == "circle": + pos_list = g.layout(layout).coords + else: + # I don't know why this is necessary + if init_pos is None: + init_coords = rng.random((adjacency_solid.shape[0], 2)).tolist() + else: + init_pos = init_pos.copy() + # this is a super-weird hack that is necessary as igraph’s + # layout function seems to do some strange stuff here + init_pos[:, 1] *= -1 + init_coords = init_pos.tolist() + try: + pos_list = g.layout( + layout, seed=init_coords, weights="weight", **layout_kwds + ).coords + except AttributeError: # hack for empty graphs... 
+ pos_list = g.layout(layout, seed=init_coords, **layout_kwds).coords pos = {n: (x, -y) for n, (x, y) in enumerate(pos_list)} if len(pos) == 1: pos[0] = (0.5, 0.5) - pos_array = np.array([pos[n] for count, n in enumerate(nx_g_solid)]) + pos_array = np.array([pos[n] for n in nx_g_solid]) return pos_array @@ -331,6 +334,7 @@ def make_pos( return make_pos({}) +@_accepts_legacy_random_state(0) def paga( # noqa: PLR0912, PLR0913, PLR0915 adata: AnnData, *, @@ -357,7 +361,7 @@ def paga( # noqa: PLR0912, PLR0913, PLR0915 arrowsize: int = 30, title: str | None = None, left_margin: float = 0.01, - random_state: int | None = 0, + rng: SeedLike | RNGLike | None = None, pos: np.ndarray | Path | str | None = None, normalize_to_color: bool = False, cmap: str | Colormap | None = None, @@ -422,7 +426,7 @@ def paga( # noqa: PLR0912, PLR0913, PLR0915 init_pos Two-column array storing the x and y coordinates for initializing the layout. - random_state + rng For layouts with random initialization like `'fr'`, change this to use different intial states for the optimization. If `None`, the initial state is not reproducible. 
@@ -529,6 +533,7 @@ def paga( # noqa: PLR0912, PLR0913, PLR0915 pl.paga_path """ + rng = np.random.default_rng(rng) if groups is not None: # backwards compat labels = groups logg.warning("`groups` is deprecated in `pl.paga`: use `labels` instead") @@ -607,7 +612,7 @@ def is_flat(x): pos = _compute_pos( adjacency_solid, layout=layout, - random_state=random_state, + rng=rng, init_pos=init_pos, layout_kwds=layout_kwds, adj_tree=adj_tree, diff --git a/src/scanpy/preprocessing/_utils.py b/src/scanpy/preprocessing/_utils.py index 87d2402ad1..9647739957 100644 --- a/src/scanpy/preprocessing/_utils.py +++ b/src/scanpy/preprocessing/_utils.py @@ -3,28 +3,21 @@ from typing import TYPE_CHECKING import numpy as np -from sklearn.random_projection import sample_without_replacement -from .._utils.random import _legacy_random_state +from .._utils.random import _FakeRandomGen if TYPE_CHECKING: - from typing import Literal - from numpy.typing import NDArray def sample_comb( - dims: tuple[int, ...], - nsamp: int, - *, - rng: np.random.Generator, - method: Literal[ - "auto", "tracking_selection", "reservoir_sampling", "pool" - ] = "auto", + dims: tuple[int, ...], nsamp: int, *, rng: np.random.Generator ) -> NDArray[np.int64]: """Randomly sample indices from a grid, without repeating the same tuple.""" - random_state = _legacy_random_state(rng) - idx = sample_without_replacement( - np.prod(dims), nsamp, random_state=random_state, method=method - ) + if isinstance(rng, _FakeRandomGen): + from sklearn.random_projection import sample_without_replacement + + idx = sample_without_replacement(np.prod(dims), nsamp, random_state=rng._arg) + else: + idx = rng.choice(np.prod(dims), size=nsamp, replace=False) return np.vstack(np.unravel_index(idx, dims)).T From 8ad699ae7eab95fa091e93b1552c3fa4908521eb Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 27 Feb 2026 16:19:12 +0100 Subject: [PATCH 14/20] test --- src/scanpy/_utils/random.py | 3 +- src/scanpy/preprocessing/_pca/__init__.py | 2 +- src/scanpy/preprocessing/_pca/_compat.py | 2 +- .../preprocessing/_scrublet/sparse_utils.py | 2 +- src/scanpy/tools/_leiden.py | 17 ++-- src/scanpy/tools/_umap.py | 4 +- tests/test_clustering.py | 65 +++++++--------- tests/test_embedding.py | 27 ++++--- tests/test_highly_variable_genes.py | 56 ++++++++------ tests/test_neighbors_key_added.py | 77 ++++++++++++------- tests/test_paga.py | 22 ++++-- tests/test_pca.py | 18 +++-- tests/test_scrublet.py | 11 +-- 13 files changed, 180 insertions(+), 126 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index de42de0785..15fa95978b 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING import numpy as np +from sklearn.utils.random import check_random_state from . import ensure_igraph @@ -99,7 +100,7 @@ def __init__( self, arg: _LegacyRandom, state: np.random.RandomState | None = None ) -> None: self._arg = arg - self._state = np.random.RandomState(arg) if state is None else state + self._state = check_random_state(arg) if state is None else state @property def bit_generator(self) -> BitGenerator: diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 1863506447..94d69570dd 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -311,7 +311,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 from ._dask import PCAEighDask if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: - msg = f"Ignoring rng={_legacy_random_state(rng)} when using a sparse dask array" + msg = f"Ignoring random_state={_legacy_random_state(rng)} when using a sparse dask array" warn(msg, UserWarning) if svd_solver not in {None, "covariance_eigh"}: msg = f"Ignoring {svd_solver=} when using a sparse dask 
array" diff --git a/src/scanpy/preprocessing/_pca/_compat.py b/src/scanpy/preprocessing/_pca/_compat.py index e144a6776f..a89f69fa50 100644 --- a/src/scanpy/preprocessing/_pca/_compat.py +++ b/src/scanpy/preprocessing/_pca/_compat.py @@ -22,7 +22,7 @@ from ..._utils.random import RNGLike, SeedLike -SCIPY_1_15 = pkg_version("scikit-learn") >= Version("1.5.0rc1") +SCIPY_1_15 = pkg_version("scipy") >= Version("1.5.0rc1") @_accepts_legacy_random_state(None) diff --git a/src/scanpy/preprocessing/_scrublet/sparse_utils.py b/src/scanpy/preprocessing/_scrublet/sparse_utils.py index 611754f91b..518f0216e1 100644 --- a/src/scanpy/preprocessing/_scrublet/sparse_utils.py +++ b/src/scanpy/preprocessing/_scrublet/sparse_utils.py @@ -52,7 +52,7 @@ def subsample_counts( e.data = rng.binomial(np.round(e.data).astype(int), rate) current_totals = np.asarray(e.sum(1)).squeeze() unsampled_orig_totals = original_totals - current_totals - unsampled_downsamp_totals = np.random.binomial( + unsampled_downsamp_totals = rng.binomial( np.round(unsampled_orig_totals).astype(int), rate ) final_downsamp_totals = current_totals + unsampled_downsamp_totals diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 6e0f52a98e..b38440de9e 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -13,7 +13,6 @@ from .._utils.random import ( _accepts_legacy_random_state, _FakeRandomGen, - _legacy_random_state, _set_igraph_rng, ) from ._docs import ( @@ -127,8 +126,8 @@ def leiden( # noqa: PLR0913 (``'0'``, ``'1'``, ...) for each cell. `adata.uns['leiden' | key_added]['params']` : :class:`dict` - A dict with the values for the parameters `resolution`, `random_state`, - and `n_iterations`. + A dict with the values for the parameters `resolution`, `n_iterations`, + and `random_state` (if applicable). 
`adata.uns['leiden' | key_added]['modularity']` : :class:`float` The modularity score of the final clustering, @@ -140,7 +139,7 @@ def leiden( # noqa: PLR0913 flavor = _validate_flavor(flavor, partition_type=partition_type, directed=directed) _utils.ensure_igraph() clustering_args = dict(clustering_args) - + rng = np.random.default_rng(rng) meta_random_state = ( dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} ) @@ -174,10 +173,16 @@ def leiden( # noqa: PLR0913 partition_type = leidenalg.RBConfigurationVertexPartition if use_weights: clustering_args["weights"] = np.array(g.es["weight"]).astype(np.float64) - clustering_args["seed"] = _legacy_random_state(rng) + seed = ( + rng._arg + if isinstance(rng, _FakeRandomGen) + and isinstance(rng._arg, int | np.integer) + # for some reason leidenalg only accepts int32 (signed) seeds … + else rng.integers((i := np.iinfo(np.int32)).min, i.max, dtype=np.int32) + ) part = cast( "MutableVertexPartition", - leidenalg.find_partition(g, partition_type, **clustering_args), + leidenalg.find_partition(g, partition_type, seed=seed, **clustering_args), ) else: g = _utils.get_igraph_from_adjacency(adjacency, directed=False) diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index 6e8982874a..8e3cb0d0ca 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -104,7 +104,7 @@ def umap( # noqa: PLR0913, PLR0915 * A numpy array of initial embedding positions. rng If `int`, `rng` is the seed used by the random number generator; - If `Generator`, `random_state` is the random number generator; + If `np.random.Generator`, `rng` is the random number generator; If `None`, the random number generator is not reproducible. a More specific parameters controlling the embedding. 
If `None` these @@ -150,7 +150,7 @@ def umap( # noqa: PLR0913, PLR0915 """ rng_init, rng_umap = np.random.default_rng(rng).spawn(2) meta_random_state = ( - dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else None + dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} ) del rng diff --git a/tests/test_clustering.py b/tests/test_clustering.py index fbfe0cb6ff..51fa863d9e 100644 --- a/tests/test_clustering.py +++ b/tests/test_clustering.py @@ -66,47 +66,40 @@ def test_leiden_basic( @needs.leidenalg @needs.igraph +@pytest.mark.parametrize("rng_arg", ["rng", "random_state"]) def test_leiden_random_state( - adata_neighbors: AnnData, flavor: Literal["igraph", "leidenalg"] + subtests: pytest.Subtests, + adata_neighbors: AnnData, + flavor: Literal["igraph", "leidenalg"], + rng_arg: Literal["rng", "random_state"], ) -> None: is_leiden_alg = flavor == "leidenalg" n_iterations = 2 if is_leiden_alg else -1 - adata_1 = sc.tl.leiden( - adata_neighbors, - flavor=flavor, - random_state=1, - copy=True, - directed=is_leiden_alg, - n_iterations=n_iterations, - ) - adata_1_again = sc.tl.leiden( - adata_neighbors, - flavor=flavor, - random_state=1, - copy=True, - directed=is_leiden_alg, - n_iterations=n_iterations, - ) - adata_2 = sc.tl.leiden( - adata_neighbors, - flavor=flavor, - random_state=3, - copy=True, - directed=is_leiden_alg, - n_iterations=n_iterations, - ) - # reproducible - pd.testing.assert_series_equal(adata_1.obs["leiden"], adata_1_again.obs["leiden"]) - assert ( - pytest.approx(adata_1.uns["leiden"]["modularity"]) - == adata_1_again.uns["leiden"]["modularity"] - ) - # different clustering - assert not adata_2.obs["leiden"].equals(adata_1_again.obs["leiden"]) - assert ( - pytest.approx(adata_2.uns["leiden"]["modularity"]) - != adata_1_again.uns["leiden"]["modularity"] + adata_1, adata_1_again, adata_2 = ( + sc.tl.leiden( + adata_neighbors, + flavor=flavor, + copy=True, + directed=is_leiden_alg, + n_iterations=n_iterations, + **{rng_arg: 
seed}, + ) + for seed in (1, 1, 42) ) + with subtests.test("reproducible"): + pd.testing.assert_series_equal( + adata_1.obs["leiden"], adata_1_again.obs["leiden"] + ) + assert ( + pytest.approx(adata_1.uns["leiden"]["modularity"]) + == adata_1_again.uns["leiden"]["modularity"] + ) + with subtests.test("different clustering"): + assert not adata_2.obs["leiden"].equals(adata_1_again.obs["leiden"]) + assert ( + pytest.approx(adata_2.uns["leiden"]["modularity"]) + != adata_1_again.uns["leiden"]["modularity"] + ) @needs.igraph diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 692157a084..10c09f932a 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np import pytest from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_raises @@ -8,6 +10,9 @@ from testing.scanpy._helpers.data import pbmc68k_reduced from testing.scanpy._pytest.marks import needs +if TYPE_CHECKING: + from typing import Literal + @pytest.mark.parametrize( ("key_added", "key_obsm", "key_uns"), @@ -75,16 +80,18 @@ def test_umap_init_paga(layout): sc.tl.umap(pbmc, init_pos="paga") -def test_diffmap(): +@pytest.mark.parametrize("rng_arg", ["rng", "random_state"]) +def test_diffmap( + subtests: pytest.Subtests, rng_arg: Literal["rng", "random_state"] +) -> None: pbmc = pbmc68k_reduced() - sc.tl.diffmap(pbmc) - d1 = pbmc.obsm["X_diffmap"].copy() - sc.tl.diffmap(pbmc) - d2 = pbmc.obsm["X_diffmap"].copy() - assert_array_equal(d1, d2) + d1, d2, d3 = ( + sc.tl.diffmap(pbmc, copy=True, **{rng_arg: seed}).obsm["X_diffmap"].copy() + for seed in (0, 0, 1234) + ) - # Checking if specifying random_state works, arrays shouldn't be equal - sc.tl.diffmap(pbmc, random_state=1234) - d3 = pbmc.obsm["X_diffmap"].copy() - assert_raises(AssertionError, assert_array_equal, d1, d3) + with subtests.test("reproducible"): + assert_array_equal(d1, d2) + with 
subtests.test("different embedding"): + assert_raises(AssertionError, assert_array_equal, d1, d3) diff --git a/tests/test_highly_variable_genes.py b/tests/test_highly_variable_genes.py index 670c647ab9..45e7133050 100644 --- a/tests/test_highly_variable_genes.py +++ b/tests/test_highly_variable_genes.py @@ -667,7 +667,12 @@ def test_seurat_v3_bad_chunking(adata, array_type, flavor): ], ) @pytest.mark.parametrize("batch_key", [None, "batch"]) -def test_subset_inplace_consistency(flavor, array_type, batch_key): +def test_subset_inplace_consistency( + subtests: pytest.Subtests, + flavor: Literal["seurat", "cell_ranger", "seurat_v3", "seurat_v3_paper"], + array_type, + batch_key: Literal["batch"] | None, +) -> None: """Tests `n_top_genes=n`. - if `inplace` and `subset` interact correctly @@ -675,12 +680,12 @@ def test_subset_inplace_consistency(flavor, array_type, batch_key): - for dask arrays and non-dask arrays - for both with and without batch_key """ + rng = np.random.default_rng(0) adata = ( - sc.datasets.blobs(n_observations=20, n_variables=80, random_state=0) + sc.datasets.blobs(n_observations=20, n_variables=80, rng=rng) if "seurat_v3" not in flavor else pbmc3k()[:1500, :1000].copy() ) - rng = np.random.default_rng(0) adata.obs["batch"] = rng.choice(["a", "b"], adata.shape[0]) adata.X = array_type(np.abs(adata.X).astype(int)) @@ -705,32 +710,35 @@ def test_subset_inplace_consistency(flavor, array_type, batch_key): inplace=inplace, ) - assert (output_df is None) == inplace - assert len(adata_copy.var if inplace else output_df) == ( - 15 if subset else n_genes - ) - assert sum((adata_copy.var if inplace else output_df)["highly_variable"]) == 15 + with subtests.test(subset=subset, inplace=inplace): + assert (output_df is None) == inplace + assert len(adata_copy.var if inplace else output_df) == ( + 15 if subset else n_genes + ) + assert ( + sum((adata_copy.var if inplace else output_df)["highly_variable"]) == 15 + ) - if not inplace: - assert isinstance(output_df, 
pd.DataFrame) + if not inplace: + assert isinstance(output_df, pd.DataFrame) - if inplace: - assert subset not in adatas - adatas[subset] = adata_copy - else: - assert subset not in dfs - dfs[subset] = output_df + if inplace: + assert subset not in adatas + adatas[subset] = adata_copy + else: + assert subset not in dfs + dfs[subset] = output_df - # check that the results are consistent for subset True/False: inplace True - adata_subset = adatas[False][:, adatas[False].var["highly_variable"]] - assert adata_subset.var_names.equals(adatas[True].var_names) + with subtests.test("consistency", inplace=True): + adata_subset = adatas[False][:, adatas[False].var["highly_variable"]] + assert adata_subset.var_names.equals(adatas[True].var_names) - # check that the results are consistent for subset True/False: inplace False - df_subset = dfs[False][dfs[False]["highly_variable"]] - assert df_subset.index.equals(dfs[True].index) + with subtests.test("consistency", inplace=False): + df_subset = dfs[False][dfs[False]["highly_variable"]] + assert df_subset.index.equals(dfs[True].index) - # check that the results are consistent for inplace True/False: subset True - assert adatas[True].var_names.equals(dfs[True].index) + with subtests.test("consistency", subset=True): + assert adatas[True].var_names.equals(dfs[True].index) @pytest.mark.parametrize( diff --git a/tests/test_neighbors_key_added.py b/tests/test_neighbors_key_added.py index 4256410c38..4e7b8ba88d 100644 --- a/tests/test_neighbors_key_added.py +++ b/tests/test_neighbors_key_added.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import numpy as np import pytest @@ -7,6 +9,12 @@ from testing.scanpy._helpers.data import pbmc68k_reduced from testing.scanpy._pytest.marks import needs +if TYPE_CHECKING: + from typing import Literal + + from anndata import AnnData + + n_neighbors = 5 key = "test" @@ -23,9 +31,12 @@ def adata(adata_session: sc.AnnData) -> sc.AnnData: return 
adata_session.copy() -def test_neighbors_key_added(adata: sc.AnnData) -> None: - sc.pp.neighbors(adata, n_neighbors=n_neighbors, random_state=0) - sc.pp.neighbors(adata, n_neighbors=n_neighbors, random_state=0, key_added=key) +@pytest.mark.parametrize("rng_arg", ["rng", "random_state"]) +def test_neighbors_key_added( + adata: sc.AnnData, rng_arg: Literal["rng", "random_state"] +) -> None: + sc.pp.neighbors(adata, n_neighbors=n_neighbors, **{rng_arg: 0}) + sc.pp.neighbors(adata, n_neighbors=n_neighbors, **{rng_arg: 0}, key_added=key) conns_key = adata.uns[key]["connectivities_key"] dists_key = adata.uns[key]["distances_key"] @@ -50,48 +61,60 @@ def test_neighbors_pca_keys_added_without_previous_pca_run(adata: sc.AnnData) -> assert "pca" in adata.uns -# test functions with neighbors_key and obsp @needs.igraph @pytest.mark.parametrize("field", ["neighbors_key", "obsp"]) -def test_neighbors_key_obsp(adata, field): +@pytest.mark.parametrize("rng_arg", ["rng", "random_state"]) +def test_neighbors_key_obsp( + subtests: pytest.Subtests, + adata: AnnData, + field: Literal["neighbors_key", "obsp"], + rng_arg: Literal["rng", "random_state"], +) -> None: + """Test functions with neighbors_key and obsp.""" adata1 = adata.copy() - sc.pp.neighbors(adata, n_neighbors=n_neighbors, random_state=0) - sc.pp.neighbors(adata1, n_neighbors=n_neighbors, random_state=0, key_added=key) + sc.pp.neighbors(adata, n_neighbors=n_neighbors, **{rng_arg: 0}) + sc.pp.neighbors(adata1, n_neighbors=n_neighbors, **{rng_arg: 0}, key_added=key) if field == "neighbors_key": arg = {field: key} else: arg = {field: adata1.uns[key]["connectivities_key"]} - sc.tl.draw_graph(adata, layout="fr", random_state=1) - sc.tl.draw_graph(adata1, layout="fr", random_state=1, **arg) + sc.tl.draw_graph(adata, layout="fr", **{rng_arg: 1}) + sc.tl.draw_graph(adata1, layout="fr", **{rng_arg: 1}, **arg) - assert adata.uns["draw_graph"]["params"] == adata1.uns["draw_graph"]["params"] - assert 
np.allclose(adata.obsm["X_draw_graph_fr"], adata1.obsm["X_draw_graph_fr"]) + with subtests.test("draw_graph"): + assert adata.uns["draw_graph"]["params"] == adata1.uns["draw_graph"]["params"] + assert np.allclose( + adata.obsm["X_draw_graph_fr"], adata1.obsm["X_draw_graph_fr"] + ) - sc.tl.leiden(adata, flavor="igraph", random_state=0) - sc.tl.leiden(adata1, flavor="igraph", random_state=0, **arg) + sc.tl.leiden(adata, flavor="igraph", **{rng_arg: 0}) + sc.tl.leiden(adata1, flavor="igraph", **{rng_arg: 0}, **arg) - assert adata.uns["leiden"]["params"] == adata1.uns["leiden"]["params"] - assert np.all(adata.obs["leiden"] == adata1.obs["leiden"]) + with subtests.test("leiden"): + assert adata.uns["leiden"]["params"] == adata1.uns["leiden"]["params"] + assert np.all(adata.obs["leiden"] == adata1.obs["leiden"]) # no obsp in umap, paga if field == "neighbors_key": - sc.tl.umap(adata, random_state=0) - sc.tl.umap(adata1, random_state=0, neighbors_key=key) + sc.tl.umap(adata, **{rng_arg: 0}) + sc.tl.umap(adata1, **{rng_arg: 0}, neighbors_key=key) - assert adata.uns["umap"]["params"] == adata1.uns["umap"]["params"] - assert np.allclose(adata.obsm["X_umap"], adata1.obsm["X_umap"]) + with subtests.test("umap"): + assert adata.uns["umap"]["params"] == adata1.uns["umap"]["params"] + assert np.allclose(adata.obsm["X_umap"], adata1.obsm["X_umap"]) sc.tl.paga(adata, groups="leiden") sc.tl.paga(adata1, groups="leiden", neighbors_key=key) - assert np.allclose( - adata.uns["paga"]["connectivities"].toarray(), - adata1.uns["paga"]["connectivities"].toarray(), - ) - assert np.allclose( - adata.uns["paga"]["connectivities_tree"].toarray(), - adata1.uns["paga"]["connectivities_tree"].toarray(), - ) + with subtests.test("paga"): + assert np.allclose( + adata.uns["paga"]["connectivities"].toarray(), + adata1.uns["paga"]["connectivities"].toarray(), + ) + assert np.allclose( + adata.uns["paga"]["connectivities_tree"].toarray(), + adata1.uns["paga"]["connectivities_tree"].toarray(), + ) diff 
--git a/tests/test_paga.py b/tests/test_paga.py index 5975f61eb4..e5b9a34c55 100644 --- a/tests/test_paga.py +++ b/tests/test_paga.py @@ -3,6 +3,7 @@ from functools import partial from importlib.util import find_spec from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import pytest @@ -14,6 +15,10 @@ from testing.scanpy._helpers.data import pbmc3k_processed, pbmc68k_reduced from testing.scanpy._pytest.marks import needs +if TYPE_CHECKING: + from typing import Literal + + HERE: Path = Path(__file__).parent ROOT = HERE / "_images" @@ -111,7 +116,10 @@ def test_paga_compare(image_comparer): save_and_compare_images("paga_compare_pbmc3k") -def test_paga_positions_reproducible(): +@pytest.mark.parametrize("rng_arg", ["rng", "random_state"]) +def test_paga_positions_reproducible( + subtests: pytest.Subtests, rng_arg: Literal["rng", "random_state"] +) -> None: """Check exact reproducibility and effect of random_state on paga positions.""" # https://github.com/scverse/scanpy/issues/1859 pbmc = pbmc68k_reduced() @@ -121,9 +129,11 @@ def test_paga_positions_reproducible(): b = pbmc.copy() c = pbmc.copy() - sc.pl.paga(a, show=False, random_state=42) - sc.pl.paga(b, show=False, random_state=42) - sc.pl.paga(c, show=False, random_state=13) + sc.pl.paga(a, show=False, **{rng_arg: 42}) + sc.pl.paga(b, show=False, **{rng_arg: 42}) + sc.pl.paga(c, show=False, **{rng_arg: 13}) - np.testing.assert_array_equal(a.uns["paga"]["pos"], b.uns["paga"]["pos"]) - assert a.uns["paga"]["pos"].tolist() != c.uns["paga"]["pos"].tolist() + with subtests.test("reproducible"): + np.testing.assert_array_equal(a.uns["paga"]["pos"], b.uns["paga"]["pos"]) + with subtests.test("different positions"): + assert a.uns["paga"]["pos"].tolist() != c.uns["paga"]["pos"].tolist() diff --git a/tests/test_pca.py b/tests/test_pca.py index d4cd5e58f3..af618fbdee 100644 --- a/tests/test_pca.py +++ b/tests/test_pca.py @@ -333,7 +333,10 @@ def test_pca_sparse(key_added: str | None, keys_expected: 
tuple[str, str, str]): np.testing.assert_allclose(implicit.varm["PCs"], explicit.varm[key_varm]) -def test_pca_reproducible(array_type): +@pytest.mark.parametrize("rng_arg", ["rng", "random_state"]) +def test_pca_reproducible( + subtests: pytest.Subtests, array_type, rng_arg: Literal["rng", "random_state"] +): pbmc = pbmc3k_normalized() pbmc.X = array_type(pbmc.X) @@ -342,18 +345,21 @@ def test_pca_reproducible(array_type): if isinstance(pbmc.X, DaskArray) and isinstance(pbmc.X._meta, CSBase) else nullcontext() ): - a = sc.pp.pca(pbmc, copy=True, dtype=np.float64, random_state=42) - b = sc.pp.pca(pbmc, copy=True, dtype=np.float64, random_state=42) - c = sc.pp.pca(pbmc, copy=True, dtype=np.float64, random_state=0) + a, b, c = ( + sc.pp.pca(pbmc, copy=True, dtype=np.float64, **{rng_arg: seed}) + for seed in (42, 42, 0) + ) - assert_equal(a, b) + with subtests.test("reproducible"): + assert_equal(a, b) # Test that changing random seed changes result # Does not show up reliably with 32 bit computation # sparse-in-dask doesn’t use a random seed, so it also doesn’t work there. 
if not (isinstance(pbmc.X, DaskArray) and isinstance(pbmc.X._meta, CSBase)): a, c = map(AnnData.to_memory, [a, c]) - assert not np.array_equal(a.obsm["X_pca"], c.obsm["X_pca"]) + with subtests.test("different embedding"): + assert not np.array_equal(a.obsm["X_pca"], c.obsm["X_pca"]) def test_pca_chunked() -> None: diff --git a/tests/test_scrublet.py b/tests/test_scrublet.py index 9844d908c3..1b33a5dbd9 100644 --- a/tests/test_scrublet.py +++ b/tests/test_scrublet.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from collections.abc import Callable - from typing import Any + from typing import Any, Literal pytestmark = [needs.skimage] @@ -119,19 +119,20 @@ def _create_sim_from_parents(adata: AnnData, parents: np.ndarray) -> AnnData: ) -def test_scrublet_data(cache: pytest.Cache): +@pytest.mark.parametrize("rng_arg", ["rng", "random_state"]) +def test_scrublet_data(rng_arg: Literal["rng", "random_state"]) -> None: """Test that Scrublet processing is arranged correctly. Check that simulations run on raw data. """ - random_state = 1234 + seed = 1234 # Run Scrublet and let the main function run simulations adata_scrublet_auto_sim = sc.pp.scrublet( pbmc200(), use_approx_neighbors=False, copy=True, - random_state=random_state, + **{rng_arg: seed}, ) # Now make our own simulated data so we can check the result from function @@ -154,7 +155,7 @@ def test_scrublet_data(cache: pytest.Cache): adata_sim=adata_sim, use_approx_neighbors=False, copy=True, - random_state=random_state, + **{rng_arg: seed}, ) # Require that the doublet scores are the same whether simulation is via From a4b2d121a90aa5e2490ef21167e87c16e1e57e78 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Fri, 27 Feb 2026 16:31:20 +0100 Subject: [PATCH 15/20] =?UTF-8?q?Selman=E2=80=99s=20findings?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/scanpy/neighbors/__init__.py | 7 +++---- src/scanpy/preprocessing/_pca/__init__.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index d442967c5e..988afee7e6 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -852,12 +852,15 @@ def compute_transitions(self, *, density_normalize: bool = True) -> None: self._transitions_sym = self.Z @ conn_norm @ self.Z logg.info(" finished", time=start) + @_accepts_legacy_random_state(0) def compute_eigen( self, *, n_comps: int = 15, sort: Literal["decrease", "increase"] = "decrease", rng: np.random.Generator, + # unused + sym: None = None, ) -> None: """Compute eigen decomposition of transition matrix. @@ -866,10 +869,6 @@ def compute_eigen( n_comps Number of eigenvalues/vectors to be computed, set `n_comps = 0` if you need all eigenvectors. - sym - Instead of computing the eigendecomposition of the assymetric - transition matrix, computed the eigendecomposition of the symmetric - Ktilde matrix. rng A numpy random number generator diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 94d69570dd..c96ef7e0d4 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -248,7 +248,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 # dask needs an int for random state rng = np.random.default_rng(rng) - if not isinstance(rng, _FakeRandomGen) and not isinstance( + if not isinstance(rng, _FakeRandomGen) or not isinstance( rng._arg, int | np.random.RandomState ): # TODO: remove this error and if we don’t have a _FakeRandomGen, From 0c30aa4796d5ab478723a3adbc24b47c79f79ff7 Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Mon, 9 Mar 2026 09:44:09 +0100 Subject: [PATCH 16/20] ingest --- src/scanpy/tools/_ingest.py | 80 +++++++++++++++++++++++++++++++------ 1 file changed, 68 insertions(+), 12 deletions(-) diff --git a/src/scanpy/tools/_ingest.py b/src/scanpy/tools/_ingest.py index 7bc0e2eb6d..03ab5af698 100644 --- a/src/scanpy/tools/_ingest.py +++ b/src/scanpy/tools/_ingest.py @@ -10,15 +10,19 @@ from .. import logging as logg from .._compat import CSBase from .._settings import settings -from .._utils import NeighborsView, raise_not_implemented_error_if_backed_type +from .._utils import NeighborsView, _empty, raise_not_implemented_error_if_backed_type from .._utils._doctests import doctest_skip +from .._utils.random import _FakeRandomGen, _legacy_random_state from ..neighbors import FlatTree if TYPE_CHECKING: from collections.abc import Generator, Iterable from anndata import AnnData + from pynndescent import NNDescent + from umap import UMAP + from .._utils import Empty from ..neighbors import RPForestDict @@ -201,19 +205,56 @@ class Ingest: Parameters ---------- - adata : :class:`~anndata.AnnData` + adata The annotated data matrix of shape `n_obs` × `n_vars` with embeddings and labels. + rng + Random number generator. + `None` means non-determinism. + By default uses the tools’ seeds if available, else seed `0`. 
""" - def _init_umap(self, adata): + # rng + _rng: np.random.Generator | None + # neighbors + _rep: np.ndarray + _use_rep: str + _metric: str + _metric_kwds: dict[str, object] + _n_neighbors: int + _n_pcs: int | None + _nnd_idx: NNDescent + # umap + _umap: UMAP + # pca + _pca_centered: bool + _pca_use_hvg: bool + _pca_basis: np.ndarray + # adata + _adata_ref: AnnData + _adata_new: AnnData | None + _obs: pd.DataFrame | None + _obsm: _DimDict | None + _labels: pd.Series | None + _indices: np.ndarray | None + _distances: np.ndarray | None + + def _get_rng(self, params: dict[str, object]) -> np.random.Generator | None: + if self._rng is None: # indicates we want non-determinism + return None + random_state = params.get("random_state", 0) + return _FakeRandomGen(random_state) + + def _init_umap(self, adata: AnnData) -> None: from umap import UMAP + rng = self._get_rng(adata.uns["umap"]["params"]) self._umap = UMAP( metric=self._metric, - random_state=adata.uns["umap"]["params"].get("random_state", 0), - n_jobs=1, # umap can’t be run in parallel with random_state != None + random_state=_legacy_random_state(rng), + # umap can’t be run in parallel with `random_state is not None` + n_jobs=-1 if rng is None else 1, ) self._umap._initial_alpha = self._umap.learning_rate @@ -237,7 +278,9 @@ def _init_umap(self, adata): self._umap._input_hash = None - def _init_pynndescent(self, distances): + def _init_pynndescent( + self, distances: CSBase, rng: np.random.Generator | None + ) -> None: from pynndescent import NNDescent first_col = np.arange(distances.shape[0])[:, None] @@ -249,7 +292,7 @@ def _init_pynndescent(self, distances): metric_kwds=self._metric_kwds, n_neighbors=self._n_neighbors, init_graph=init_indices, - random_state=self._neigh_random_state, + random_state=_legacy_random_state(rng), ) # temporary hack for the broken forest storage @@ -267,7 +310,7 @@ def _init_pynndescent(self, distances): self._nnd_idx._angular_trees, ) - def _init_neighbors(self, adata, 
neighbors_key): + def _init_neighbors(self, adata: AnnData, neighbors_key: str | None) -> None: neighbors = NeighborsView(adata, neighbors_key) self._n_neighbors = neighbors["params"]["n_neighbors"] @@ -287,10 +330,11 @@ def _init_neighbors(self, adata, neighbors_key): self._metric_kwds = neighbors["params"].get("metric_kwds", {}) self._metric = neighbors["params"]["metric"] - self._neigh_random_state = neighbors["params"].get("random_state", 0) - self._init_pynndescent(neighbors["distances"]) + self._init_pynndescent( + neighbors["distances"], rng=self._get_rng(neighbors["params"]) + ) - def _init_pca(self, adata): + def _init_pca(self, adata: AnnData) -> None: self._pca_centered = adata.uns["pca"]["params"]["zero_center"] self._pca_use_hvg = adata.uns["pca"]["params"]["use_highly_variable"] @@ -304,11 +348,23 @@ def _init_pca(self, adata): else: self._pca_basis = adata.varm["PCs"] - def __init__(self, adata: AnnData, neighbors_key: str | None = None): + def __init__( + self, + adata: AnnData, + neighbors_key: str | None = None, + *, + rng: np.random.Generator | None | Empty = _empty, + ) -> None: # assume rep is X if all initializations fail to identify it self._rep = adata.X self._use_rep = "X" + self._rng = ( + None + if rng is None + else np.random.default_rng(None if rng is _empty else rng) + ) + self._n_pcs = None self._adata_ref = adata From e8a411eb649dc4682b5900d02a6b5a910835c66f Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Mon, 9 Mar 2026 09:49:10 +0100 Subject: [PATCH 17/20] rename --- src/scanpy/_utils/random.py | 52 +++++++++---------- src/scanpy/neighbors/__init__.py | 4 +- src/scanpy/plotting/_tools/paga.py | 4 +- .../preprocessing/_deprecated/sampling.py | 4 +- src/scanpy/preprocessing/_pca/__init__.py | 12 ++--- .../preprocessing/_scrublet/__init__.py | 4 +- src/scanpy/preprocessing/_utils.py | 6 +-- src/scanpy/tools/_dpt.py | 4 +- src/scanpy/tools/_draw_graph.py | 4 +- src/scanpy/tools/_ingest.py | 4 +- src/scanpy/tools/_leiden.py | 13 ++--- src/scanpy/tools/_umap.py | 4 +- tests/test_utils.py | 9 +--- 13 files changed, 57 insertions(+), 67 deletions(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index 15fa95978b..2e5c63007d 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -52,15 +52,15 @@ def __init__(self, rng: SeedLike | RNGLike | None) -> None: self._rng = np.random.default_rng(rng) def getrandbits(self, k: int) -> int: - if isinstance(self._rng, _FakeRandomGen): - i = self._rng._state.tomaxint() + if isinstance(self._rng, _LegacyRng): + i = self._rng.state.tomaxint() else: lims = np.iinfo(np.uint64) i = int(self._rng.integers(0, lims.max, dtype=np.uint64, endpoint=True)) return i & ((1 << k) - 1) def randint(self, a: int, b: int) -> np.int64: - """Can’t use `endpoint` here as _FakeRandomGen doesn’t support it.""" + """Can’t use `endpoint` here as _LegacyRng doesn’t support it.""" return self._rng.integers(a, b + 1) def __getattr__(self, attr: str): @@ -85,7 +85,7 @@ def _set_igraph_rng(rng: SeedLike | RNGLike | None) -> Generator[None]: ################################### -class _FakeRandomGen(np.random.Generator): +class _LegacyRng(np.random.Generator): """A `Generator` that wraps a legacy `RandomState` instance. 
To behave like a `RandomState`, it’s not enough to just use a MT19937 `bit_generator` @@ -93,18 +93,18 @@ class _FakeRandomGen(np.random.Generator): so instead this hack uses the exact same random numbers as `RandomState(seed)`. """ - _arg: _LegacyRandom - _state: np.random.RandomState + arg: _LegacyRandom + state: np.random.RandomState def __init__( self, arg: _LegacyRandom, state: np.random.RandomState | None = None ) -> None: - self._arg = arg - self._state = check_random_state(arg) if state is None else state + self.arg = arg + self.state = check_random_state(arg) if state is None else state @property def bit_generator(self) -> BitGenerator: - msg = "A _FakeRandomGen instance has no `bit_generator` attribute." + msg = "A _LegacyRng instance has no `bit_generator` attribute." raise AttributeError(msg) @classmethod @@ -117,9 +117,9 @@ def wrap_global( if arg is not None: if isinstance(arg, np.random.RandomState): np.random.set_state(arg.get_state(legacy=False)) - return _FakeRandomGen(arg, state) + return _LegacyRng(arg, state) np.random.seed(arg) - return _FakeRandomGen(arg, np.random.RandomState(np.random.get_bit_generator())) + return _LegacyRng(arg, np.random.RandomState(np.random.get_bit_generator())) def spawn(self, n_children: int) -> list[Self]: """Return `self` `n_children` times. @@ -139,27 +139,27 @@ def _delegate(cls) -> None: def mk_wrapper(name: str, meth): # Old pytest versions try to run the doctests @wraps(meth, assigned=set(WRAPPER_ASSIGNMENTS) - {"__doc__"}) - def wrapper(self: _FakeRandomGen, *args, **kwargs): - return getattr(self._state, name)(*args, **kwargs) + def wrapper(self: _LegacyRng, *args, **kwargs): + return getattr(self.state, name)(*args, **kwargs) return wrapper setattr(cls, names.get(name, name), mk_wrapper(name, meth)) -_FakeRandomGen._delegate() +_LegacyRng._delegate() def _if_legacy_apply_global(rng: np.random.Generator, /) -> np.random.Generator: - """Wrap the global legacy RNG if `rng` is a `_FakeRandomGen`. 
+ """Wrap the global legacy RNG if `rng` is a `_LegacyRng`. This is used where our code used to call `np.random.seed()`. - It’s a no-op if `rng` is not a `_FakeRandomGen`. + It’s a no-op if `rng` is not a `_LegacyRng`. """ - if not isinstance(rng, _FakeRandomGen): + if not isinstance(rng, _LegacyRng): return rng - return _FakeRandomGen.wrap_global(rng._arg, rng._state) + return _LegacyRng.wrap_global(rng.arg, rng.state) def _legacy_random_state( @@ -167,10 +167,10 @@ def _legacy_random_state( ) -> _LegacyRandom: """Convert a np.random.Generator into a legacy `random_state` argument. - If `rng` is already a `_FakeRandomGen`, return its original `_arg` attribute. + If `rng` is already a `_LegacyRng`, return its original `arg` attribute. """ - if isinstance(rng, _FakeRandomGen): - return rng._state if always_state else rng._arg + if isinstance(rng, _LegacyRng): + return rng.state if always_state else rng.arg [bitgen] = np.random.default_rng(rng).bit_generator.spawn(1) return np.random.RandomState(bitgen) @@ -181,9 +181,9 @@ def _accepts_legacy_random_state[**P, R]( """Make a function accept `random_state: _LegacyRandom` and pass it as `rng`. If the decorated function is called with a `random_state` argument, - it’ll be wrapped in a :class:`_FakeRandomGen`. - Passing both ``rng`` and ``random_state`` at the same time is an error. - If neither is given, ``random_state_default`` is used. + it’ll be wrapped in a `_LegacyRng`. + Passing both `rng` and `random_state` at the same time is an error. + If neither is given, `random_state_default` is used. """ def decorator(func: Callable[P, R]) -> Callable[P, R]: @@ -194,9 +194,9 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: msg = "Specify at most one of `rng` and `random_state`." 
raise TypeError(msg) case True, False: - kwargs["rng"] = _FakeRandomGen(kwargs.pop("random_state")) + kwargs["rng"] = _LegacyRng(kwargs.pop("random_state")) case False, False: - kwargs["rng"] = _FakeRandomGen(random_state_default) + kwargs["rng"] = _LegacyRng(random_state_default) return func(*args, **kwargs) return wrapper diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 988afee7e6..13ce6c8e2d 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -20,8 +20,8 @@ from .._utils import NeighborsView, _doc_params, get_literal_vals from .._utils.random import ( _accepts_legacy_random_state, - _FakeRandomGen, _legacy_random_state, + _LegacyRng, ) from . import _connectivity from ._common import ( @@ -209,7 +209,7 @@ def neighbors( # noqa: PLR0913 """ meta_random_state = ( - dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} + dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {} ) if distances is None: diff --git a/src/scanpy/plotting/_tools/paga.py b/src/scanpy/plotting/_tools/paga.py index a751bd8cd3..fa286ef760 100644 --- a/src/scanpy/plotting/_tools/paga.py +++ b/src/scanpy/plotting/_tools/paga.py @@ -23,7 +23,7 @@ from ..._settings import settings from ..._utils.random import ( _accepts_legacy_random_state, - _FakeRandomGen, + _LegacyRng, _set_igraph_rng, ) from .. 
import _utils @@ -225,7 +225,7 @@ def _compute_pos( # noqa: PLR0912 ) raise ValueError(msg) else: # igraph layouts - if isinstance(rng, _FakeRandomGen): # backwards compat + if isinstance(rng, _LegacyRng): # backwards compat random.seed(rng.bytes(8)) ctx = nullcontext() else: diff --git a/src/scanpy/preprocessing/_deprecated/sampling.py b/src/scanpy/preprocessing/_deprecated/sampling.py index 2663462a68..40913f2c35 100644 --- a/src/scanpy/preprocessing/_deprecated/sampling.py +++ b/src/scanpy/preprocessing/_deprecated/sampling.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from ..._utils.random import _FakeRandomGen +from ..._utils.random import _LegacyRng from .._simple import sample if TYPE_CHECKING: @@ -50,7 +50,7 @@ def subsample( returns a subsampled copy of it (`copy == True`). """ - rng = _FakeRandomGen.wrap_global(random_state) + rng = _LegacyRng.wrap_global(random_state) return sample( data=data, fraction=fraction, n=n_obs, rng=rng, copy=copy, replace=False, axis=0 ) diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index c96ef7e0d4..d9b349a1fb 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -12,8 +12,8 @@ from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type from ..._utils.random import ( _accepts_legacy_random_state, - _FakeRandomGen, _legacy_random_state, + _LegacyRng, ) from ...get import _check_mask, _get_obs_rep from .._docs import doc_mask_var_hvg @@ -248,10 +248,10 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 # dask needs an int for random state rng = np.random.default_rng(rng) - if not isinstance(rng, _FakeRandomGen) or not isinstance( - rng._arg, int | np.random.RandomState + if not isinstance(rng, _LegacyRng) or not isinstance( + rng.arg, int | np.random.RandomState ): - # TODO: remove this error and if we don’t have a _FakeRandomGen, + # TODO: remove this error and if we don’t have a _LegacyRng, # just use 
rng.integers to make a seed farther down msg = f"rng needs to be an int or a np.random.RandomState, not a {type(rng).__name__} when passing a dask array" raise TypeError(msg) @@ -259,7 +259,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 if chunked: if ( not zero_center - or (not isinstance(rng, _FakeRandomGen) or rng._arg != 0) + or (not isinstance(rng, _LegacyRng) or rng.arg != 0) or (svd_solver is not None and svd_solver != "arpack") ): logg.debug("Ignoring zero_center, rng, svd_solver") @@ -310,7 +310,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 elif isinstance(x._meta, CSBase) or svd_solver == "covariance_eigh": from ._dask import PCAEighDask - if not isinstance(rng, _FakeRandomGen) or rng._arg != 0: + if not isinstance(rng, _LegacyRng) or rng.arg != 0: msg = f"Ignoring random_state={_legacy_random_state(rng)} when using a sparse dask array" warn(msg, UserWarning) if svd_solver not in {None, "covariance_eigh"}: diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index c0979d212f..68dce62abb 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -10,7 +10,7 @@ from ... import logging as logg from ... import preprocessing as pp -from ..._utils.random import _accepts_legacy_random_state, _FakeRandomGen +from ..._utils.random import _accepts_legacy_random_state, _LegacyRng from ...get import _get_obs_rep from . 
import pipeline from .core import Scrublet @@ -390,7 +390,7 @@ def _scrublet_call_doublets( # noqa: PLR0913 """ meta_random_state = ( - dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} + dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {} ) rng_scrub, rng_pca = rng.spawn(2) del rng diff --git a/src/scanpy/preprocessing/_utils.py b/src/scanpy/preprocessing/_utils.py index 9647739957..6ff2112c93 100644 --- a/src/scanpy/preprocessing/_utils.py +++ b/src/scanpy/preprocessing/_utils.py @@ -4,7 +4,7 @@ import numpy as np -from .._utils.random import _FakeRandomGen +from .._utils.random import _LegacyRng if TYPE_CHECKING: from numpy.typing import NDArray @@ -14,10 +14,10 @@ def sample_comb( dims: tuple[int, ...], nsamp: int, *, rng: np.random.Generator ) -> NDArray[np.int64]: """Randomly sample indices from a grid, without repeating the same tuple.""" - if isinstance(rng, _FakeRandomGen): + if isinstance(rng, _LegacyRng): from sklearn.random_projection import sample_without_replacement - idx = sample_without_replacement(np.prod(dims), nsamp, random_state=rng._arg) + idx = sample_without_replacement(np.prod(dims), nsamp, random_state=rng.arg) else: idx = rng.choice(np.prod(dims), size=nsamp, replace=False) return np.vstack(np.unravel_index(idx, dims)).T diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 2b80b61bc9..74d7fe66ae 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -8,7 +8,7 @@ from natsort import natsorted from .. import logging as logg -from .._utils.random import _FakeRandomGen +from .._utils.random import _LegacyRng from ..neighbors import Neighbors, OnFlySymMatrix if TYPE_CHECKING: @@ -147,7 +147,7 @@ def dpt( "Trying to run `tl.dpt` without prior call of `tl.diffmap`. " "Falling back to `tl.diffmap` with default parameters." 
) - _diffmap(adata, neighbors_key=neighbors_key, rng=_FakeRandomGen(0)) + _diffmap(adata, neighbors_key=neighbors_key, rng=_LegacyRng(0)) # start with the actual computation dpt = DPT( adata, diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index 75e46c570a..2baa7f9e9d 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -10,8 +10,8 @@ from .._utils import _choose_graph, get_literal_vals from .._utils.random import ( _accepts_legacy_random_state, - _FakeRandomGen, _if_legacy_apply_global, + _LegacyRng, _set_igraph_rng, ) from ._utils import get_init_pos_from_paga @@ -119,7 +119,7 @@ def draw_graph( # noqa: PLR0913 start = logg.info(f"drawing single-cell graph using layout {layout!r}") rng = np.random.default_rng(rng) meta_random_state = ( - dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} + dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {} ) rng = _if_legacy_apply_global(rng) rng_init, rng_layout = rng.spawn(2) diff --git a/src/scanpy/tools/_ingest.py b/src/scanpy/tools/_ingest.py index 03ab5af698..e24eab7873 100644 --- a/src/scanpy/tools/_ingest.py +++ b/src/scanpy/tools/_ingest.py @@ -12,7 +12,7 @@ from .._settings import settings from .._utils import NeighborsView, _empty, raise_not_implemented_error_if_backed_type from .._utils._doctests import doctest_skip -from .._utils.random import _FakeRandomGen, _legacy_random_state +from .._utils.random import _legacy_random_state, _LegacyRng from ..neighbors import FlatTree if TYPE_CHECKING: @@ -244,7 +244,7 @@ def _get_rng(self, params: dict[str, object]) -> np.random.Generator | None: if self._rng is None: # indicates we want non-determinism return None random_state = params.get("random_state", 0) - return _FakeRandomGen(random_state) + return _LegacyRng(random_state) def _init_umap(self, adata: AnnData) -> None: from umap import UMAP diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 
b38440de9e..49179e48bb 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -10,11 +10,7 @@ from .. import logging as logg from .._compat import warn from .._utils import _doc_params -from .._utils.random import ( - _accepts_legacy_random_state, - _FakeRandomGen, - _set_igraph_rng, -) +from .._utils.random import _accepts_legacy_random_state, _LegacyRng, _set_igraph_rng from ._docs import ( doc_adata, doc_adjacency, @@ -141,7 +137,7 @@ def leiden( # noqa: PLR0913 clustering_args = dict(clustering_args) rng = np.random.default_rng(rng) meta_random_state = ( - dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} + dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {} ) start = logg.info("running Leiden clustering") @@ -174,9 +170,8 @@ def leiden( # noqa: PLR0913 if use_weights: clustering_args["weights"] = np.array(g.es["weight"]).astype(np.float64) seed = ( - rng._arg - if isinstance(rng, _FakeRandomGen) - and isinstance(rng._arg, int | np.integer) + rng.arg + if isinstance(rng, _LegacyRng) and isinstance(rng.arg, int | np.integer) # for some reason leidenalg only accepts int32 (signed) seeds … else rng.integers((i := np.iinfo(np.int32)).min, i.max, dtype=np.int32) ) diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index 8e3cb0d0ca..7850674400 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -12,8 +12,8 @@ from .._utils import NeighborsView from .._utils.random import ( _accepts_legacy_random_state, - _FakeRandomGen, _legacy_random_state, + _LegacyRng, ) from ._utils import _choose_representation, get_init_pos_from_paga @@ -150,7 +150,7 @@ def umap( # noqa: PLR0913, PLR0915 """ rng_init, rng_umap = np.random.default_rng(rng).spawn(2) meta_random_state = ( - dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {} + dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {} ) del rng diff --git a/tests/test_utils.py b/tests/test_utils.py index 
40852a972e..c838ccf00f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,12 +17,7 @@ check_nonnegative_integers, descend_classes_and_funcs, ) -from scanpy._utils.random import ( - _FakeRandomGen, - ith_k_tuple, - random_k_tuples, - random_str, -) +from scanpy._utils.random import _LegacyRng, ith_k_tuple, random_k_tuples, random_str from testing.scanpy._pytest.params import ( ARRAY_TYPES, ARRAY_TYPES_DASK, @@ -206,7 +201,7 @@ def test_legacy_numpy_gen(*, seed: int, pass_seed: bool, func: str): def _mk_random(func: str, *, direct: bool, seed: int | None) -> np.ndarray: if direct and seed is not None: np.random.seed(seed) - gen = np.random if direct else _FakeRandomGen.wrap_global(seed) + gen = np.random if direct else _LegacyRng.wrap_global(seed) match func: case "choice": arr = np.arange(1000) From 40775020456a8293f2dd0e9290cbae09872e9926 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 9 Mar 2026 09:50:15 +0100 Subject: [PATCH 18/20] spawn docs --- src/scanpy/_utils/random.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scanpy/_utils/random.py b/src/scanpy/_utils/random.py index 2e5c63007d..591fac5562 100644 --- a/src/scanpy/_utils/random.py +++ b/src/scanpy/_utils/random.py @@ -125,7 +125,7 @@ def spawn(self, n_children: int) -> list[Self]: """Return `self` `n_children` times. In a real generator, the spawned children are independent, - but for backwards compatibility we return the same instance. + but for backwards compatibility we return the same instance so that its internal state is advanced by each child. """ return [self] * n_children From 57971295c3f552a9d6adb52937e116b3962fab1f Mon Sep 17 00:00:00 2001 From: "Philipp A." 
Date: Mon, 9 Mar 2026 10:06:44 +0100 Subject: [PATCH 19/20] fix paga --- src/scanpy/tools/_draw_graph.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index 2baa7f9e9d..b26c671789 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -17,6 +17,7 @@ from ._utils import get_init_pos_from_paga if TYPE_CHECKING: + from contextlib import AbstractContextManager from typing import LiteralString from anndata import AnnData @@ -121,7 +122,6 @@ def draw_graph( # noqa: PLR0913 meta_random_state = ( dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {} ) - rng = _if_legacy_apply_global(rng) rng_init, rng_layout = rng.spawn(2) del rng if layout not in (layouts := get_literal_vals(_Layout)): @@ -133,7 +133,7 @@ def draw_graph( # noqa: PLR0913 # init coordinates if init_pos in adata.obsm: init_coords = adata.obsm[init_pos] - elif init_pos == "paga" or init_pos: + elif init_pos: # "paga" or True init_coords = get_init_pos_from_paga( adata, adjacency, @@ -142,6 +142,7 @@ def draw_graph( # noqa: PLR0913 obsp=obsp, ) else: + _if_legacy_apply_global(rng_init) init_coords = rng_init.random((adjacency.shape[0], 2)) layout = coerce_fa2_layout(layout) # actual drawing @@ -149,7 +150,7 @@ def draw_graph( # noqa: PLR0913 positions = np.array(fa2_positions(adjacency, init_coords, **kwds)) else: g = _utils.get_igraph_from_adjacency(adjacency) - with _set_igraph_rng(rng_layout): + with _igraph_rng_compat(rng_layout): if layout in {"fr", "drl", "kk", "grid_fr"}: ig_layout = g.layout(layout, seed=init_coords.tolist(), **kwds) elif "rt" in layout: @@ -217,3 +218,17 @@ def coerce_fa2_layout[S: LiteralString](layout: S) -> S | Literal["fa", "fr"]: return "fr" return "fa" + + +def _igraph_rng_compat(rng: SeedLike | RNGLike | None) -> AbstractContextManager[None]: + """Context manager that sets the igraph RNG to the given RNG. 
+
+    For legacy code, this just calls `random.seed()`.
+    """
+    import random
+    from contextlib import nullcontext
+
+    if isinstance(rng, _LegacyRng):
+        random.seed(rng.arg)
+        return nullcontext()
+    return _set_igraph_rng(rng)

From 9bc61f0d420e651ce053f259c3bbde8f9dcd0504 Mon Sep 17 00:00:00 2001
From: "Philipp A."
Date: Mon, 9 Mar 2026 11:22:05 +0100
Subject: [PATCH 20/20] docs reuse

---
 src/scanpy/_docs.py                            | 17 +++++++++++++++++
 src/scanpy/datasets/_datasets.py               |  6 ++++--
 src/scanpy/neighbors/__init__.py               | 10 +++++-----
 src/scanpy/plotting/_tools/paga.py             | 10 ++++++----
 src/scanpy/preprocessing/_pca/__init__.py      |  8 +++-----
 src/scanpy/preprocessing/_scrublet/__init__.py |  6 ++++--
 src/scanpy/preprocessing/_scrublet/core.py     |  9 ++++++---
 src/scanpy/preprocessing/_simple.py            | 12 ++++++++----
 src/scanpy/tools/_diffmap.py                   |  6 ++++--
 src/scanpy/tools/_draw_graph.py                | 10 ++++++----
 src/scanpy/tools/_ingest.py                    | 14 +++++++++-----
 src/scanpy/tools/_leiden.py                    |  5 +++--
 src/scanpy/tools/_score_genes.py               |  7 ++++---
 src/scanpy/tools/_tsne.py                      |  7 +++----
 src/scanpy/tools/_umap.py                      |  9 ++++-----
 15 files changed, 86 insertions(+), 50 deletions(-)
 create mode 100644 src/scanpy/_docs.py

diff --git a/src/scanpy/_docs.py b/src/scanpy/_docs.py
new file mode 100644
index 0000000000..d4fccb5418
--- /dev/null
+++ b/src/scanpy/_docs.py
@@ -0,0 +1,17 @@
+"""Shared docstrings for general parameters."""
+
+from __future__ import annotations
+
+__all__ = ["doc_rng"]
+
+doc_rng = """\
+rng
+    Random number generation to control stochasticity.
+
+    If a :type:`SeedLike` value, it’s used to seed a new random number generator;
+    If a :class:`numpy.random.Generator`, `rng`’s state will be directly advanced;
+    If :data:`None`, a non-reproducible random number generator is used.
+    See :func:`numpy.random.default_rng` for more details.
+
+    The default value matches legacy scanpy behavior and will change to `None` in scanpy 2.0.
+""" diff --git a/src/scanpy/datasets/_datasets.py b/src/scanpy/datasets/_datasets.py index a3775fe3a4..0ca4dd7feb 100644 --- a/src/scanpy/datasets/_datasets.py +++ b/src/scanpy/datasets/_datasets.py @@ -10,7 +10,9 @@ from .. import _utils from .._compat import deprecated +from .._docs import doc_rng from .._settings import settings +from .._utils import _doc_params from .._utils._doctests import doctest_internet, doctest_needs from .._utils.random import _accepts_legacy_random_state, _legacy_random_state from ..readwrite import read, read_h5ad, read_visium @@ -55,6 +57,7 @@ HERE = Path(__file__).parent +@_doc_params(rng=doc_rng) @_accepts_legacy_random_state(0) def blobs( *, @@ -77,8 +80,7 @@ def blobs( n_observations Number of observations. By default, this is the same observation number as in :func:`scanpy.datasets.krumsiek11`. - rng - Determines random number generation for dataset creation. + {rng} Returns ------- diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 13ce6c8e2d..01c03e4533 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -16,6 +16,7 @@ from .. import _utils from .. import logging as logg from .._compat import CSBase, CSRBase, SpBase, pkg_version, warn +from .._docs import doc_rng from .._settings import settings from .._utils import NeighborsView, _doc_params, get_literal_vals from .._utils.random import ( @@ -81,7 +82,7 @@ class NeighborsParams(TypedDict): # noqa: D101 n_pcs: NotRequired[int] -@_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep) +@_doc_params(n_pcs=doc_n_pcs, use_rep=doc_use_rep, rng=doc_rng) @_accepts_legacy_random_state(0) def neighbors( # noqa: PLR0913 adata: AnnData, @@ -163,8 +164,7 @@ def neighbors( # noqa: PLR0913 Options for the metric. *ignored if ``transformer`` is an instance.* - rng - A numpy random number generator. 
+ {rng} *ignored if ``transformer`` is an instance.* key_added @@ -852,6 +852,7 @@ def compute_transitions(self, *, density_normalize: bool = True) -> None: self._transitions_sym = self.Z @ conn_norm @ self.Z logg.info(" finished", time=start) + @_doc_params(rng=doc_rng) @_accepts_legacy_random_state(0) def compute_eigen( self, @@ -869,8 +870,7 @@ def compute_eigen( n_comps Number of eigenvalues/vectors to be computed, set `n_comps = 0` if you need all eigenvectors. - rng - A numpy random number generator + {rng} Returns ------- diff --git a/src/scanpy/plotting/_tools/paga.py b/src/scanpy/plotting/_tools/paga.py index fa286ef760..6736bd171d 100644 --- a/src/scanpy/plotting/_tools/paga.py +++ b/src/scanpy/plotting/_tools/paga.py @@ -20,7 +20,9 @@ from ... import _utils as _sc_utils from ... import logging as logg from ..._compat import CSBase +from ..._docs import doc_rng from ..._settings import settings +from ..._utils import _doc_params from ..._utils.random import ( _accepts_legacy_random_state, _LegacyRng, @@ -334,6 +336,7 @@ def make_pos( return make_pos({}) +@_doc_params(rng=doc_rng) @_accepts_legacy_random_state(0) def paga( # noqa: PLR0912, PLR0913, PLR0915 adata: AnnData, @@ -426,10 +429,9 @@ def paga( # noqa: PLR0912, PLR0913, PLR0915 init_pos Two-column array storing the x and y coordinates for initializing the layout. - rng - For layouts with random initialization like `'fr'`, change this to use - different intial states for the optimization. If `None`, the initial - state is not reproducible. + {rng} + + Applies to layouts with random initialization like `'fr'`. root If choosing a tree layout, this is the index of the root node or a list of root node indices. 
If this is a non-empty vector then the supplied diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index d9b349a1fb..ea5576fa22 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -8,6 +8,7 @@ from ... import logging as logg from ..._compat import CSBase, DaskArray, pkg_version, warn +from ..._docs import doc_rng from ..._settings import settings from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type from ..._utils.random import ( @@ -55,9 +56,7 @@ type SvdSolver = SvdSolvDaskML | SvdSolvSkearn | SvdSolvPCACustom -@_doc_params( - mask_var_hvg=doc_mask_var_hvg, -) +@_doc_params(mask_var_hvg=doc_mask_var_hvg, rng=doc_rng) @_accepts_legacy_random_state(0) def pca( # noqa: PLR0912, PLR0913, PLR0915 data: AnnData | np.ndarray | CSBase, @@ -162,8 +161,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 chunk_size Number of observations to include in each chunk. Required if `chunked=True` was passed. - rng - Change to use different initial states for the optimization. + {rng} return_info Only relevant when not passing an :class:`~anndata.AnnData`: see “Returns”. diff --git a/src/scanpy/preprocessing/_scrublet/__init__.py b/src/scanpy/preprocessing/_scrublet/__init__.py index 68dce62abb..f69b0c4471 100644 --- a/src/scanpy/preprocessing/_scrublet/__init__.py +++ b/src/scanpy/preprocessing/_scrublet/__init__.py @@ -10,6 +10,8 @@ from ... import logging as logg from ... import preprocessing as pp +from ..._docs import doc_rng +from ..._utils import _doc_params from ..._utils.random import _accepts_legacy_random_state, _LegacyRng from ...get import _get_obs_rep from . 
import pipeline @@ -21,6 +23,7 @@ @_accepts_legacy_random_state(0) +@_doc_params(rng=doc_rng) def scrublet( # noqa: PLR0913 adata: AnnData, adata_sim: AnnData | None = None, @@ -129,8 +132,7 @@ def scrublet( # noqa: PLR0913 copy If :data:`True`, return a copy of the input ``adata`` with Scrublet results added. Otherwise, Scrublet results are added in place. - rng - Initial state for doublet simulation and nearest neighbors. + {rng} Returns ------- diff --git a/src/scanpy/preprocessing/_scrublet/core.py b/src/scanpy/preprocessing/_scrublet/core.py index a11ebf2f98..601bd574af 100644 --- a/src/scanpy/preprocessing/_scrublet/core.py +++ b/src/scanpy/preprocessing/_scrublet/core.py @@ -9,6 +9,8 @@ from scipy import sparse from ... import logging as logg +from ..._docs import doc_rng +from ..._utils import _doc_params from ...neighbors import ( Neighbors, _get_indices_distances_from_sparse_matrix, @@ -26,6 +28,7 @@ __all__ = ["Scrublet"] +@_doc_params(rng=doc_rng) @dataclass(kw_only=True) class Scrublet: """Initialize Scrublet object with counts matrix and doublet prediction parameters. @@ -56,9 +59,9 @@ class Scrublet: stdev_doublet_rate Uncertainty in the expected doublet rate. - rng - Random number generator for doublet simulation, approximate - nearest neighbor search, and PCA/TruncatedSVD. + {rng} + + Used for doublet simulation, nearest neighbor search, and PCA/TruncatedSVD. """ diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index 09a1729e6c..18eb79c52d 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -23,8 +23,10 @@ from .. 
import logging as logg from .._compat import CSBase, CSRBase, DaskArray, deprecated, njit +from .._docs import doc_rng from .._settings import settings as sett from .._utils import ( + _doc_params, _resolve_axis, check_array_function_arguments, is_backed_type, @@ -868,6 +870,9 @@ def sample[A: np.ndarray | CSBase | DaskArray]( axis: Literal["obs", 0, "var", 1] = "obs", p: str | NDArray[np.bool] | NDArray[np.floating] | None = None, ) -> tuple[A, NDArray[np.int64]]: ... + + +@_doc_params(rng=doc_rng) def sample( # noqa: PLR0912 data: AnnData | np.ndarray | CSBase | DaskArray, fraction: float | None = None, @@ -895,8 +900,7 @@ def sample( # noqa: PLR0912 See `axis` and `replace`. n Sample to this number of observations or variables. See `axis`. - rng - Random seed to change subsampling. + {rng} copy If an :class:`~anndata.AnnData` is passed, determines whether a copy is returned. @@ -976,6 +980,7 @@ def sample( # noqa: PLR0912 @_accepts_legacy_random_state(0) +@_doc_params(rng=doc_rng) def downsample_counts( adata: AnnData, counts_per_cell: int | Collection[int] | None = None, @@ -1005,8 +1010,7 @@ def downsample_counts( total_counts Target total counts. If the count matrix has more than `total_counts` it will be downsampled to have this number. - rng - Random seed for subsampling. + {rng} replace Whether to sample the counts with replacement. 
copy diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index b07f4e3d20..90f2976837 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -4,6 +4,8 @@ import numpy as np +from .._docs import doc_rng +from .._utils import _doc_params from .._utils.random import _accepts_legacy_random_state from ._dpt import _diffmap @@ -13,6 +15,7 @@ from .._utils.random import RNGLike, SeedLike +@_doc_params(rng=doc_rng) @_accepts_legacy_random_state(0) def diffmap( adata: AnnData, @@ -52,8 +55,7 @@ def diffmap( .obsp[.uns[neighbors_key]['connectivities_key']] and .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances, respectively. - rng - A numpy random number generator + {rng} copy Return a copy instead of writing to adata. diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index b26c671789..1335b329de 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -7,7 +7,8 @@ from .. import _utils from .. import logging as logg -from .._utils import _choose_graph, get_literal_vals +from .._docs import doc_rng +from .._utils import _choose_graph, _doc_params, get_literal_vals from .._utils.random import ( _accepts_legacy_random_state, _if_legacy_apply_global, @@ -29,6 +30,7 @@ type _Layout = Literal["fr", "drl", "kk", "grid_fr", "lgl", "rt", "rt_circular", "fa"] +@_doc_params(rng=doc_rng) @_accepts_legacy_random_state(0) def draw_graph( # noqa: PLR0913 adata: AnnData, @@ -78,9 +80,9 @@ def draw_graph( # noqa: PLR0913 'rt' (Reingold Tilford tree layout). root Root for tree layouts. - rng - For layouts with random initialization like 'fr', change this to use - different intial states for the optimization. If `None`, no seed is set. + {rng} + + Applies to layouts with random initialization like `'fr'`. adjacency Sparse adjacency matrix of the graph, defaults to neighbors connectivities. 
key_added_ext diff --git a/src/scanpy/tools/_ingest.py b/src/scanpy/tools/_ingest.py index e24eab7873..3e274208e7 100644 --- a/src/scanpy/tools/_ingest.py +++ b/src/scanpy/tools/_ingest.py @@ -9,8 +9,14 @@ from .. import logging as logg from .._compat import CSBase +from .._docs import doc_rng from .._settings import settings -from .._utils import NeighborsView, _empty, raise_not_implemented_error_if_backed_type +from .._utils import ( + NeighborsView, + _doc_params, + _empty, + raise_not_implemented_error_if_backed_type, +) from .._utils._doctests import doctest_skip from .._utils.random import _legacy_random_state, _LegacyRng from ..neighbors import FlatTree @@ -197,6 +203,7 @@ def __repr__(self): return f"{type(self).__name__}({self._data})" +@_doc_params(rng=doc_rng) class Ingest: """Class to map labels and embeddings from existing data to new data. @@ -208,10 +215,7 @@ class Ingest: adata The annotated data matrix of shape `n_obs` × `n_vars` with embeddings and labels. - rng - Random number generator. - `None` means non-determinism. - By default uses the tools’ seeds if available, else seed `0`. + {rng} """ diff --git a/src/scanpy/tools/_leiden.py b/src/scanpy/tools/_leiden.py index 49179e48bb..0e9a59573f 100644 --- a/src/scanpy/tools/_leiden.py +++ b/src/scanpy/tools/_leiden.py @@ -9,6 +9,7 @@ from .. import _utils from .. import logging as logg from .._compat import warn +from .._docs import doc_rng from .._utils import _doc_params from .._utils.random import _accepts_legacy_random_state, _LegacyRng, _set_igraph_rng from ._docs import ( @@ -43,6 +44,7 @@ adjacency=doc_adjacency, neighbors_key=doc_neighbors_key.format(method="leiden"), obsp=doc_obsp, + rng=doc_rng, ) @_accepts_legacy_random_state(0) def leiden( # noqa: PLR0913 @@ -82,8 +84,7 @@ def leiden( # noqa: PLR0913 Higher values lead to more clusters. Set to `None` if overriding `partition_type` to one that doesn’t accept a `resolution_parameter`. - rng - Change the initialization of the optimization. 
+ {rng} {restrict_to} key_added `adata.obs` key under which to add the cluster labels. diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index 0ba868a888..9c0708fe62 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -9,7 +9,8 @@ from .. import logging as logg from .._compat import CSBase -from .._utils import check_use_raw, is_backed_type +from .._docs import doc_rng +from .._utils import _doc_params, check_use_raw, is_backed_type from .._utils.random import _accepts_legacy_random_state, _if_legacy_apply_global from ..get import _get_obs_rep @@ -51,6 +52,7 @@ def _sparse_nanmean(x: CSBase, /, axis: Literal[0, 1]) -> NDArray[np.float64]: return m +@_doc_params(rng=doc_rng) @_accepts_legacy_random_state(0) def score_genes( # noqa: PLR0913 adata: AnnData, @@ -96,8 +98,7 @@ def score_genes( # noqa: PLR0913 Number of expression level bins for sampling. score_name Name of the field to be added in `.obs`. - rng - The random number generator for sampling. + {rng} copy Copy `adata` or modify it inplace. use_raw diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index e8d9b4414f..96ba4e2a05 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -4,6 +4,7 @@ from .. import logging as logg from .._compat import warn +from .._docs import doc_rng from .._settings import settings from .._utils import _doc_params, raise_not_implemented_error_if_backed_type from .._utils.random import _accepts_legacy_random_state, _legacy_random_state @@ -17,7 +18,7 @@ @_accepts_legacy_random_state(0) -@_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep) +@_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep, rng=doc_rng) def tsne( # noqa: PLR0913 adata: AnnData, n_pcs: int | None = None, @@ -76,9 +77,7 @@ def tsne( # noqa: PLR0913 optimization, the early exaggeration factor or the learning rate might be too high. 
If the cost function gets stuck in a bad local minimum increasing the learning rate helps sometimes. - rng - Change this to use different intial states for the optimization. - If `None`, the initial state is not reproducible. + {rng} n_jobs Number of jobs for parallel computation. `None` means using :attr:`scanpy.settings.n_jobs`. diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index 7850674400..d1cc0dd1e7 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -8,8 +8,9 @@ from .. import logging as logg from .._compat import warn +from .._docs import doc_rng from .._settings import settings -from .._utils import NeighborsView +from .._utils import NeighborsView, _doc_params from .._utils.random import ( _accepts_legacy_random_state, _legacy_random_state, @@ -29,6 +30,7 @@ @_accepts_legacy_random_state(0) +@_doc_params(rng=doc_rng) def umap( # noqa: PLR0913, PLR0915 adata: AnnData, *, @@ -102,10 +104,7 @@ def umap( # noqa: PLR0913, PLR0915 * 'spectral': use a spectral embedding of the graph. * 'random': assign initial embedding positions at random. * A numpy array of initial embedding positions. - rng - If `int`, `rng` is the seed used by the random number generator; - If `np.random.Generator`, `rng` is the random number generator; - If `None`, the random number generator is not reproducible. + {rng} a More specific parameters controlling the embedding. If `None` these values are set automatically as determined by `min_dist` and