From 0ddab38d6e1e5d1c033661cf165ae5ff17acde8a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:06:08 +0200 Subject: [PATCH 01/24] feat: in-memory backend --- docs/installation.md | 1 + pyproject.toml | 6 +- src/annbatch/loader.py | 97 +++++++++++++++++---- src/annbatch/types.py | 2 +- tests/test_dataset.py | 185 ++++++++++++++++++++++++++++++++++------- 5 files changed, 245 insertions(+), 46 deletions(-) diff --git a/docs/installation.md b/docs/installation.md index bf205f3f..1c7eec92 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -28,6 +28,7 @@ Otherwise, be sure to install the `[remote]` extra for `zarr-python` to be able | `torch` | Yields batches as 0-copy {class}`torch.Tensor`s. | | `cupy-cuda12` | GPU acceleration via `cupy` for CUDA 12, highly recommended for CUDA systems. | | `cupy-cuda13` | GPU acceleration via `cupy` for CUDA 13, highly recommended for CUDA systems. | +| `numba` | CPU acceleration for indexing of in-memory sparse matrices i.e., when {meth}`Loader.add_adatas` is called on an {class}`~anndata.AnnData` object with a {class}`~scipy.sparse.csr_matrix` in-memory. | `cupy` provides accelerated handling of the data via `preload_to_gpu` once it has been read off disk, and does not need to be used in conjunction with `torch`. `cupy` is also compatible with `rocm` (AMD) devices, although we do not provide an extra for installing. diff --git a/pyproject.toml b/pyproject.toml index 32157269..893d83e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,9 @@ optional-dependencies.doc = [ "sphinxcontrib-bibtex>=1", "sphinxext-opengraph", ] +optional-dependencies.numba = [ + "numba", +] optional-dependencies.test = [ "annbatch[zarrs]", "coverage", @@ -89,7 +92,7 @@ envs.docs.scripts.open = "python -m webbrowser -t docs/_build/html/index.html" envs.hatch-test.python = "3.14" envs.hatch-test.features = [ "test" ] envs.hatch-test.matrix = [ - { deps = [ "min-low", "pre", "torch", "min-high" ] }, + { deps = [ "min-low", "pre", "torch", "min-high", "numba" ] }, ] # If the matrix variable `deps` is set to "pre", # set the environment variable `UV_PRERELEASE` to "allow". @@ -99,6 +102,7 @@ envs.hatch-test.overrides.matrix.deps.env-vars = [ ] envs.hatch-test.overrides.matrix.deps.features = [ { value = "torch", if = [ "torch" ] }, + { value = "numba", if = [ "numba" ] }, ] envs.hatch-test.overrides.matrix.deps.pre-install-commands = [ { value = "echo 'zarr @ git+https://github.com/zarr-developers/zarr-python.git' > pre-deps.txt", if = [ diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index ec08099d..b0de30a0 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -53,6 +53,34 @@ class CSRDatasetElems(NamedTuple): data: zarr.AsyncArray +if find_spec("numba"): + import numba + + @numba.njit(parallel=True, cache=True, nogil=True) + def _csr_subset_rows(src_data, src_indices, src_indptr, rows, out_data, out_indices): # type: ignore + n_rows = rows.shape[0] + row_nnz = np.empty(n_rows, dtype=np.int64) + for i in range(n_rows): + r = rows[i] + row_nnz[i] = src_indptr[r + 1] - src_indptr[r] + out_offsets = np.empty(n_rows + 1, dtype=np.int64) + out_offsets[0] = 0 + for i in range(n_rows): + out_offsets[i + 1] = out_offsets[i] + row_nnz[i] + for i in numba.prange(n_rows): + r = rows[i] + src_start = src_indptr[r] + dst_start = out_offsets[i] + n = row_nnz[i] + for j in range(n): + out_data[dst_start + j] = src_data[src_start + j] + out_indices[dst_start + j] = src_indices[src_start + j] +else: + + def _csr_subset_rows(src_data, src_indices, src_indptr, rows, out_data, out_indices): + raise ImportError("numba must be installed for in-memory data: `pip install annbatch[numba]`") + + def _cupy_dtype(dtype: np.dtype) -> np.dtype: if dtype in {np.dtype("float32"), np.dtype("float64"), np.dtype("bool")}: return dtype @@ -391,7 +419,10 @@ def add_datasets( @validate_sampler def add_dataset( - self, dataset: BackingArray, obs: pd.DataFrame | None = None, var: pd.DataFrame | None = None + self, + dataset: BackingArray, + obs: pd.DataFrame | None = None, + var: pd.DataFrame | None = None, ) -> Self: """Append a dataset to this dataset. @@ -409,7 +440,10 @@ def add_dataset( return self def _add_dataset_unchecked( - self, dataset: BackingArray, obs: pd.DataFrame | None = None, var: pd.DataFrame | None = None + self, + dataset: BackingArray, + obs: pd.DataFrame | None = None, + var: pd.DataFrame | None = None, ) -> Self: if len(self._train_datasets) > 0: if self._obs is None and obs is not None: @@ -480,7 +514,11 @@ def _requests_to_dataset_rows(self, requests: list[slice] | np.ndarray) -> Order global_index = np.concatenate([np.arange(s.start, s.stop) for s in requests]) # Locate each requested row in its dataset by binary-searching the dataset boundaries, - sizes = np.fromiter((shape[0] for shape in self._shapes), dtype=np.int64, count=len(self._shapes)) + sizes = np.fromiter( + (shape[0] for shape in self._shapes), + dtype=np.int64, + count=len(self._shapes), + ) ends = np.cumsum(sizes) starts = ends - sizes dataset_of_row = np.searchsorted(ends, global_index, side="right") @@ -517,15 +555,18 @@ def _alloc(shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray: return cpx.empty_pinned(shape, dtype) return np.empty(shape, dtype) - if issubclass(self.dataset_type, ad.abc.CSRDataset): + if (is_backed := issubclass(self.dataset_type, ad.abc.CSRDataset)) or issubclass( + self.dataset_type, sp.csr_array | sp.csr_matrix + ): + datasets = self._dataset_elem_cache if is_backed else self._train_datasets total_nnz = sum( - int((self._dataset_elem_cache[idx].indptr[rows + 1] - self._dataset_elem_cache[idx].indptr[rows]).sum()) + int((datasets[idx].indptr[rows + 1] - datasets[idx].indptr[rows]).sum()) for idx, rows in dataset_index_to_rows.items() ) first_idx = next(iter(dataset_index_to_rows)) - data_dtype = self._dataset_elem_cache[first_idx].data.dtype - indices_dtype = self._dataset_elem_cache[first_idx].indices.dtype - indptr_dtype = self._dataset_elem_cache[first_idx].indptr.dtype + data_dtype = datasets[first_idx].data.dtype + indices_dtype = datasets[first_idx].indices.dtype + indptr_dtype = datasets[first_idx].indptr.dtype return CSRContainer( elems=( _alloc((total_nnz,), data_dtype), @@ -642,6 +683,31 @@ def _get_elem_from_cache(self, dataset_idx: int) -> CSRDatasetElems | ZarrArray: raise ValueError("Cache not prepared") return self._dataset_elem_cache[dataset_idx] + @_fetch_data.register + async def _fetch_data_csr_matrix( + self, + dataset: np.ndarray, + rows: np.ndarray, + out: np.ndarray, + ) -> None: + out[:] = dataset[rows] + + @_fetch_data.register + async def _fetch_data_csr_matrix( + self, + dataset: sp.csr_matrix | sp.csr_array, + rows: np.ndarray, + out: CSRContainer, + ) -> None: + _csr_subset_rows( + dataset.data, + dataset.indices, + dataset.indptr, + np.ascontiguousarray(rows), + out.elems[0], + out.elems[1], + ) + @_fetch_data.register async def _fetch_data_sparse( self, @@ -695,8 +761,9 @@ async def _index_datasets( dataset_index_to_rows A lookup of the list-placement index of a dataset to the sorted row indices to fetch. """ - is_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset) - if is_sparse: + is_backed_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset) + is_sparse = is_backed_sparse or issubclass(self.dataset_type, sp.csr_array | sp.csr_matrix) + if is_backed_sparse: await self._ensure_sparse_cache() out = self._allocate_out(dataset_index_to_rows) @@ -708,7 +775,8 @@ async def _index_datasets( for dataset_idx, rows in dataset_index_to_rows.items(): nrows = len(rows) if is_sparse: - cached_indptr = self._dataset_elem_cache[dataset_idx].indptr + datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets + cached_indptr = datasets[dataset_idx].indptr nnz = int((cached_indptr[rows + 1] - cached_indptr[rows]).sum()) out_view: CSRContainer | np.ndarray = CSRContainer( elems=( @@ -725,7 +793,7 @@ async def _index_datasets( tasks.append( self._fetch_data( - self._get_elem_from_cache(dataset_idx) if is_sparse else self._train_datasets[dataset_idx], + self._get_elem_from_cache(dataset_idx) if is_backed_sparse else self._train_datasets[dataset_idx], rows, out_view, ) @@ -735,11 +803,12 @@ async def _index_datasets( await asyncio.gather(*tasks) if is_sparse: + datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets running_nnz = 0 row_pos = 0 out.elems[2][0] = 0 for dataset_idx, rows in dataset_index_to_rows.items(): - cached_indptr = self._dataset_elem_cache[dataset_idx].indptr + cached_indptr = datasets[dataset_idx].indptr per_row_nnz = cached_indptr[rows + 1] - cached_indptr[rows] dest = out.elems[2][row_pos + 1 : row_pos + len(rows) + 1] np.cumsum(per_row_nnz, out=dest) @@ -766,7 +835,7 @@ def __iter__( [len(self._train_datasets), self.n_obs], ["Number of datasets", "Number of observations"], ) - is_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset) + is_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset | sp.csr_matrix | sp.csr_array) # Create `positions` variable so we don't need to run `np.arange` (O(n)) every time positions = np.empty(0, dtype=np.intp) for load_request in self._batch_sampler.sample(self.n_obs): diff --git a/src/annbatch/types.py b/src/annbatch/types.py index 340689ff..3c11d2ca 100644 --- a/src/annbatch/types.py +++ b/src/annbatch/types.py @@ -11,7 +11,7 @@ from .compat import CupyArray, CupyCSRMatrix, Tensor from .utils import CSRContainer -type BackingArray_T = ad.abc.CSRDataset | ZarrArray +type BackingArray_T = ad.abc.CSRDataset | ZarrArray | sp.csr_array | sp.csr_matrix | np.ndarray type InputInMemoryArray_T = CSRContainer | np.ndarray type OutputInMemoryArray_T = sp.csr_matrix | np.ndarray | CupyCSRMatrix | CupyArray | Tensor diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c901576f..aef1000b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -32,6 +32,7 @@ skip_if_no_cupy = pytest.mark.skipif(find_spec("cupy") is None, reason="Can't test for preload_to_gpu without cupy") skip_if_no_torch = pytest.mark.skipif(find_spec("torch") is None, reason="Need torch installed.") +skip_if_no_numba = pytest.mark.skipif(find_spec("numba") is None, reason="Can't test for in-memory without numba") class Data(TypedDict): @@ -60,6 +61,36 @@ def open_sparse(path: Path | zarr.Group, *, use_zarrs: bool = False, use_anndata return data +def open_in_memory_sparse( + path: Path | zarr.Group, *, use_zarrs: bool = False, use_anndata: bool = False +) -> Data | ad.AnnData: + if not isinstance(path, zarr.Group): + path = zarr.open(path) + data = { + "dataset": ad.io.read_elem(path["layers"]["sparse"]), + "obs": ad.io.read_elem(path["obs"]), + "var": ad.io.read_elem(path["var"]), + } + if use_anndata: + return ad.AnnData(X=data["dataset"], obs=data["obs"], var=data["var"]) + return data + + +def open_in_memory_dense( + path: Path | zarr.Group, *, use_zarrs: bool = False, use_anndata: bool = False +) -> Data | ad.AnnData: + if not isinstance(path, zarr.Group): + path = zarr.open(path) + data = { + "dataset": ad.io.read_elem(path["X"]), + "obs": ad.io.read_elem(path["obs"]), + "var": ad.io.read_elem(path["var"]), + } + if use_anndata: + return ad.AnnData(X=data["dataset"], obs=data["obs"], var=data["var"]) + return data + + def open_dense(path: Path | zarr.Group, *, use_zarrs: bool = False, use_anndata: bool = False) -> Data | ad.AnnData: old_pipeline = zarr.config.get("codec_pipeline.path") @@ -92,7 +123,11 @@ def open_3d(path: Path | zarr.Group, *, use_zarrs: bool = False) -> Data: def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: return ( - {"datasets": [d["dataset"] for d in datas], "obs": [d["obs"] for d in datas], "var": [d["var"] for d in datas]} + { + "datasets": [d["dataset"] for d in datas], + "obs": [d["obs"] for d in datas], + "var": [d["var"] for d in datas], + } if all(isinstance(d, dict) for d in datas) else datas ) @@ -122,12 +157,20 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: ) ), id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-open_func={open_func.__name__[5:] if open_func is not None else 'None'}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] - marks=[skip_if_no_cupy, pytest.mark.gpu] if preload_to_gpu else [], + marks=[skip_if_no_cupy, pytest.mark.gpu] + if preload_to_gpu + else [] + ([skip_if_no_numba] if open_func is open_sparse else []), ) for chunk_size, preload_nchunks, open_func, batch_size, preload_to_gpu in [ elem for preload_to_gpu in [True, False] - for open_func in [open_sparse, open_dense, None] + for open_func in [ + open_sparse, + open_dense, + open_in_memory_dense, + open_in_memory_sparse, + None, + ] for elem in [ [ 1, @@ -162,7 +205,11 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: ], ) def test_store_load_dataset( - simple_collection: tuple[ad.AnnData, DatasetCollection], *, shuffle: bool, gen_loader, use_zarrs + simple_collection: tuple[ad.AnnData, DatasetCollection], + *, + shuffle: bool, + gen_loader, + use_zarrs, ): """ This test verifies that the DaskDataset works correctly: @@ -172,8 +219,10 @@ def test_store_load_dataset( 4. If the dataset is not shuffled, it returns the correct data """ loader: Loader = gen_loader(simple_collection[1], shuffle, use_zarrs) + if use_zarrs and loader.dataset_type in {np.ndarray, sp.csr_matrix, sp.csr_array}: + pytest.skip("No need to run zarrs with in-memory") adata = simple_collection[0] - is_dense = loader.dataset_type is zarr.Array + is_dense = loader.dataset_type in {zarr.Array, np.ndarray} n_elems = 0 batches = [] obs = [] @@ -239,7 +288,13 @@ def test_zarr_store_errors_lt_1(gen_loader, adata_with_zarr_path_same_var_space: def test_bad_adata_X_type(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]): data = open_dense(next(adata_with_zarr_path_same_var_space[1].glob("*.zarr"))) data["dataset"] = data["dataset"][...] - ds = Loader(shuffle=True, chunk_size=10, preload_nchunks=10, preload_to_gpu=False, to_torch=False) + ds = Loader( + shuffle=True, + chunk_size=10, + preload_nchunks=10, + preload_to_gpu=False, + to_torch=False, + ) with pytest.raises(TypeError, match="Cannot add"): ds.add_dataset(**data) @@ -348,10 +403,18 @@ def test_len( assert len(loader) == actual_batches -def test_bad_adata_X_hdf5(adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path]): +def test_bad_adata_X_hdf5( + adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path], +): with h5py.File(next(adata_with_h5_path_different_var_space[1].glob("*.h5ad"))) as f: data = ad.io.sparse_dataset(f["X"]) - ds = Loader(shuffle=True, chunk_size=10, preload_nchunks=10, preload_to_gpu=False, to_torch=False) + ds = Loader( + shuffle=True, + chunk_size=10, + preload_nchunks=10, + preload_to_gpu=False, + to_torch=False, + ) with pytest.raises(TypeError, match="Cannot add"): ds.add_dataset(data) @@ -375,7 +438,9 @@ def _custom_collate_fn(elems): @skip_if_no_torch @pytest.mark.parametrize("open_func", [open_sparse, open_dense]) def test_torch_multiprocess_dataloading_zarr( - adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], open_func, use_zarrs: bool + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], + open_func, + use_zarrs: bool, ): """ Test that Loader can be used with PyTorch's DataLoader in a multiprocess context and that each element of @@ -383,7 +448,13 @@ def test_torch_multiprocess_dataloading_zarr( """ from torch.utils.data import DataLoader - ds = Loader(chunk_size=10, preload_nchunks=4, shuffle=True, return_index=True, preload_to_gpu=False) + ds = Loader( + chunk_size=10, + preload_nchunks=4, + shuffle=True, + return_index=True, + preload_to_gpu=False, + ) ds.add_datasets( **concat([open_func(p, use_zarrs=use_zarrs) for p in adata_with_zarr_path_same_var_space[1].glob("*.zarr")]) ) @@ -393,7 +464,11 @@ def test_torch_multiprocess_dataloading_zarr( x_ref = adata_with_zarr_path_same_var_space[0].X dataloader = DataLoader( - ds, batch_size=32, num_workers=4, collate_fn=_custom_collate_fn, multiprocessing_context="spawn" + ds, + batch_size=32, + num_workers=4, + collate_fn=_custom_collate_fn, + multiprocessing_context="spawn", ) x_list, idx_list = [], [] for batch in dataloader: @@ -408,11 +483,20 @@ def test_torch_multiprocess_dataloading_zarr( @pytest.mark.parametrize( - "preload_to_gpu", [False, pytest.param(True, marks=[pytest.mark.gpu, skip_if_no_cupy])], ids=["no_cupy", "cupy"] + "preload_to_gpu", + [False, pytest.param(True, marks=[pytest.mark.gpu, skip_if_no_cupy])], + ids=["no_cupy", "cupy"], +) +@pytest.mark.parametrize( + "to_torch", + [False, pytest.param(True, marks=[skip_if_no_torch])], + ids=["no_torch", "torch"], ) -@pytest.mark.parametrize("to_torch", [False, pytest.param(True, marks=[skip_if_no_torch])], ids=["no_torch", "torch"]) def test_3d( - adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], use_zarrs: bool, preload_to_gpu: bool, to_torch: bool + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], + use_zarrs: bool, + preload_to_gpu: bool, + to_torch: bool, ): ds = Loader( chunk_size=10, @@ -449,17 +533,20 @@ def test_3d( @pytest.mark.skipif( - find_spec("cupy") is not None, reason="Can't test for preload_to_gpu True ImportError with cupy installed" + find_spec("cupy") is not None, + reason="Can't test for preload_to_gpu True ImportError with cupy installed", ) def test_no_cupy(): with pytest.raises( - ImportError, match=r"Follow the directions at https://docs.cupy.dev/en/stable/install.html to install cupy." + ImportError, + match=r"Follow the directions at https://docs.cupy.dev/en/stable/install.html to install cupy.", ): Loader(chunk_size=10, preload_nchunks=4, preload_to_gpu=True, to_torch=False) @pytest.mark.skipif( - find_spec("torch") is not None, reason="Can't test for to_torch True ImportError with torch installed" + find_spec("torch") is not None, + reason="Can't test for to_torch True ImportError with torch installed", ) def test_no_torch(): with pytest.raises(ImportError, match=r"Try `pip install torch`."): @@ -492,11 +579,18 @@ def get_default_sparse() -> type: ), ) def test_default_data_structures( - adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], expected_cls: type, kwargs: dict + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], + expected_cls: type, + kwargs: dict, ): # format is a smoke test for sparse ds = Loader( - chunk_size=10, preload_nchunks=4, batch_size=20, shuffle=True, return_index=False, **kwargs + chunk_size=10, + preload_nchunks=4, + batch_size=20, + shuffle=True, + return_index=False, + **kwargs, ).add_dataset( **(open_sparse if issubclass(expected_cls, get_default_sparse()) else open_dense)( list(adata_with_zarr_path_same_var_space[1].iterdir())[0] @@ -566,26 +660,43 @@ def test_mismatched_var_raises_error(tmp_path: Path, subtests): with subtests.test(msg="add_datasets"): loader = Loader(chunk_size=10, preload_nchunks=4, batch_size=20) with pytest.raises(ValueError, match="All datasets must have identical var DataFrames"): - loader.add_datasets([adata1_on_disk.X, adata2_on_disk.X], var=[adata1_on_disk.var, adata2_on_disk.var]) + loader.add_datasets( + [adata1_on_disk.X, adata2_on_disk.X], + var=[adata1_on_disk.var, adata2_on_disk.var], + ) @pytest.mark.gpu @skip_if_no_cupy @pytest.mark.parametrize( ("dtype_in", "expected"), - [(np.int16, np.float32), (np.int32, np.float64), (np.float32, np.float32), (np.float64, np.float64)], + [ + (np.int16, np.float32), + (np.int32, np.float64), + (np.float32, np.float32), + (np.float64, np.float64), + ], ) def test_preload_dtype(tmp_path: Path, dtype_in: np.dtype, expected: np.dtype): z = zarr.open(tmp_path / "foo.zarr") - write_sharded(z, ad.AnnData(X=sp.random(100, 10, dtype=dtype_in, format="csr", rng=np.random.default_rng()))) - adata = ad.AnnData(X=ad.io.sparse_dataset(z["X"])) - loader = Loader(preload_to_gpu=True, batch_size=10, chunk_size=10, preload_nchunks=2, to_torch=False).add_adata( - adata + write_sharded( + z, + ad.AnnData(X=sp.random(100, 10, dtype=dtype_in, format="csr", rng=np.random.default_rng())), ) + adata = ad.AnnData(X=ad.io.sparse_dataset(z["X"])) + loader = Loader( + preload_to_gpu=True, + batch_size=10, + chunk_size=10, + preload_nchunks=2, + to_torch=False, + ).add_adata(adata) assert next(iter(loader))["X"].dtype == expected -def test_add_dataset_validation_failure_preserves_state(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]): +def test_add_dataset_validation_failure_preserves_state( + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], +): """Test that failed validation in add_dataset doesn't modify internal state.""" class FailOnSecondValidateSampler(Sampler): @@ -686,10 +797,20 @@ def test_cannot_provide_batch_sampler_with_sampler_args(kwarg): def test_rng(simple_collection: tuple[ad.AnnData, DatasetCollection]): ds1 = Loader( - chunk_size=10, preload_nchunks=4, batch_size=20, shuffle=True, rng=np.random.default_rng(0), to_torch=False + chunk_size=10, + preload_nchunks=4, + batch_size=20, + shuffle=True, + rng=np.random.default_rng(0), + to_torch=False, ) ds2 = Loader( - chunk_size=10, preload_nchunks=4, batch_size=20, shuffle=True, rng=np.random.default_rng(0), to_torch=False + chunk_size=10, + preload_nchunks=4, + batch_size=20, + shuffle=True, + rng=np.random.default_rng(0), + to_torch=False, ) ds1.use_collection( simple_collection[1], @@ -712,7 +833,9 @@ def _sample(self, n_obs: int): yield {"requests": self._requests, "splits": self._splits} -def test_splits_are_chunk_order_across_datasets(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]): +def test_splits_are_chunk_order_across_datasets( + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], +): """Splits index in chunk order; the loader undoes its dataset-grouped buffer layout. The request lists a chunk from dataset 1 *before* a chunk from dataset 0, so the loader's @@ -750,7 +873,9 @@ def test_splits_are_chunk_order_across_datasets(adata_with_zarr_path_same_var_sp np.testing.assert_array_equal(np.asarray(batches[1]["X"]), np.asarray(data0["dataset"][0:10])) -def test_chunks_deprecation_warning(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]): +def test_chunks_deprecation_warning( + adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path], +): paths = sorted(adata_with_zarr_path_same_var_space[1].glob("*.zarr")) data0 = open_dense(paths[0]) From 2e6cc15b2b29ff3cc8a9ce17621e29d6424e9874 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:10:30 +0200 Subject: [PATCH 02/24] chore: gpu test --- .github/workflows/test-gpu.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-gpu.yaml b/.github/workflows/test-gpu.yaml index b39dc871..f48ba906 100644 --- a/.github/workflows/test-gpu.yaml +++ b/.github/workflows/test-gpu.yaml @@ -41,7 +41,7 @@ jobs: runs-on: "cirun-aws-gpu--${{ github.run_id }}" strategy: matrix: - extras: ["torch,cupy-cuda12", "torch", "cupy-cuda12"] + extras: ["torch,cupy-cuda12", "torch", "cupy-cuda12", "cupy-cuda12,numba"] # Setting a timeout of 30 minutes, as the AWS costs money # At time of writing, a typical run takes about 5 minutes timeout-minutes: 30 From df88dac8211670a45a852967b48487afd67efca7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:13:27 +0200 Subject: [PATCH 03/24] import error --- src/annbatch/loader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index b0de30a0..24b8fa5c 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -75,10 +75,10 @@ def _csr_subset_rows(src_data, src_indices, src_indptr, rows, out_data, out_indi for j in range(n): out_data[dst_start + j] = src_data[src_start + j] out_indices[dst_start + j] = src_indices[src_start + j] -else: +else: # pragma: no cover def _csr_subset_rows(src_data, src_indices, src_indptr, rows, out_data, out_indices): - raise ImportError("numba must be installed for in-memory data: `pip install annbatch[numba]`") + raise ImportError("numba must be installed for in-memory sparse data: `pip install annbatch[numba]`") def _cupy_dtype(dtype: np.dtype) -> np.dtype: @@ -472,6 +472,8 @@ def _add_dataset_unchecked( raise TypeError( "Cannot add CSRDataset backed by h5ad at the moment: see https://github.com/zarr-developers/VirtualiZarr/pull/790" ) + if isinstance(dataset, sp.csr_matrix | sp.csr_array) and not find_spec("numba"): + raise ImportError("numba must be installed for in-memory sparse data: `pip install annbatch[numba]`") if not isinstance(obs, pd.DataFrame) and obs is not None: raise TypeError("obs must be a pandas DataFrame") if not isinstance(var, pd.DataFrame) and var is not None: From f7da338444f763e69b75d6173cfceb05d4b51b06 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:15:41 +0200 Subject: [PATCH 04/24] chore: test --- tests/test_dataset.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index aef1000b..c9ffb4cd 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -553,6 +553,20 @@ def test_no_torch(): Loader(chunk_size=10, preload_nchunks=4, to_torch=True, preload_to_gpu=False) +@pytest.mark.skipif( + find_spec("numba") is not None, + reason="Can't test for sparse in-memory ImportError with numba installed", +) +def test_no_numba_in_memory_sparse(monkeypatch: pytest.MonkeyPatch): + loader = Loader(chunk_size=10, preload_nchunks=4, to_torch=False, preload_to_gpu=False) + sparse_data = sp.csr_matrix(np.eye(10, dtype=np.float32)) + with pytest.raises( + ImportError, + match=r"numba must be installed for in-memory sparse data", + ): + loader.add_dataset(sparse_data) + + def get_default_dense() -> type: if find_spec("torch"): from torch import Tensor as expected_dense From 076b14c6ffbbb1ba1c1a96f2afff292a18d69da1 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:21:23 +0200 Subject: [PATCH 05/24] fix: installation docs --- docs/installation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installation.md b/docs/installation.md index 1c7eec92..1d8fa4e7 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -28,7 +28,7 @@ Otherwise, be sure to install the `[remote]` extra for `zarr-python` to be able | `torch` | Yields batches as 0-copy {class}`torch.Tensor`s. | | `cupy-cuda12` | GPU acceleration via `cupy` for CUDA 12, highly recommended for CUDA systems. | | `cupy-cuda13` | GPU acceleration via `cupy` for CUDA 13, highly recommended for CUDA systems. | -| `numba` | CPU acceleration for indexing of in-memory sparse matrices i.e., when {meth}`Loader.add_adatas` is called on an {class}`~anndata.AnnData` object with a {class}`~scipy.sparse.csr_matrix` in-memory. | +| `numba` | CPU acceleration for indexing of in-memory sparse matrices i.e., when {meth}`annbatch.Loader.add_adatas` is called on an {class}`~anndata.AnnData` object with a {class}`~scipy.sparse.csr_matrix` in-memory. | `cupy` provides accelerated handling of the data via `preload_to_gpu` once it has been read off disk, and does not need to be used in conjunction with `torch`. `cupy` is also compatible with `rocm` (AMD) devices, although we do not provide an extra for installing. From 378cf85708c29e0e4913016b2ceb75e16d5fa510 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:22:38 +0200 Subject: [PATCH 06/24] fix: marks --- tests/test_dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c9ffb4cd..a9a3bcb6 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -157,9 +157,8 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: ) ), id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-open_func={open_func.__name__[5:] if open_func is not None else 'None'}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] - marks=[skip_if_no_cupy, pytest.mark.gpu] - if preload_to_gpu - else [] + ([skip_if_no_numba] if open_func is open_sparse else []), + marks=([skip_if_no_cupy, pytest.mark.gpu] if preload_to_gpu else []) + + ([skip_if_no_numba] if open_func is open_sparse else []), ) for chunk_size, preload_nchunks, open_func, batch_size, preload_to_gpu in [ elem From 7c794e01101cada9d8d5416f68bf2deba5edbc1e Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:26:35 +0200 Subject: [PATCH 07/24] fix: check --- tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index a9a3bcb6..13e15722 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -158,7 +158,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: ), id=f"chunk_size={chunk_size}-preload_nchunks={preload_nchunks}-open_func={open_func.__name__[5:] if open_func is not None else 'None'}-batch_size={batch_size}{'-cupy' if preload_to_gpu else ''}", # type: ignore[attr-defined] marks=([skip_if_no_cupy, pytest.mark.gpu] if preload_to_gpu else []) - + ([skip_if_no_numba] if open_func is open_sparse else []), + + ([skip_if_no_numba] if open_func is open_in_memory_sparse else []), ) for chunk_size, preload_nchunks, open_func, batch_size, preload_to_gpu in [ elem From 69628e40945d87fa607b5b5d661d5453a01fb44f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:28:37 +0200 Subject: [PATCH 08/24] chore: changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f9c269f..85282c74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning][]. ### Feature - Add a `merge` argument to {meth}`annbatch.DatasetCollection.add_adatas` to handle how columns in {attr}`~anndata.AnnData.var` are handled when creating the on-disk dataset. - Now {attr}`annbatch.types.LoadRequest.requests` (formerly `annbatch.types.LoadRequest.chunks`) can also be a numpy array of integers. +- Support in memory matrices ({class}`scipy.sparse.csr_matrix`, {class}`scipy.sparse.csr_array`, {class}`numpy.ndarray`) requiring `numba` for the sparse cases (new additional dependency group for `numba` included). ### Breaking - Removal of deprecated `annbatch.ChunkSampler` From bc5ec060d06a451955127c5878987cac8a5b8358 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:32:03 +0200 Subject: [PATCH 09/24] chore: remove test --- tests/test_dataset.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 13e15722..370c9210 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -284,20 +284,6 @@ def test_zarr_store_errors_lt_1(gen_loader, adata_with_zarr_path_same_var_space: gen_loader(adata_with_zarr_path_same_var_space[1]) -def test_bad_adata_X_type(adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]): - data = open_dense(next(adata_with_zarr_path_same_var_space[1].glob("*.zarr"))) - data["dataset"] = data["dataset"][...] - ds = Loader( - shuffle=True, - chunk_size=10, - preload_nchunks=10, - preload_to_gpu=False, - to_torch=False, - ) - with pytest.raises(TypeError, match="Cannot add"): - ds.add_dataset(**data) - - def test_use_collection_twice(simple_collection: tuple[ad.AnnData, DatasetCollection]): ds = Loader() ds = ds.use_collection(simple_collection[1]) From a7f81b13a6d15e971b250f7428ff413e85135592 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:36:20 +0200 Subject: [PATCH 10/24] chore: readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 235db444..c4361439 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@ [badge-docs]: https://img.shields.io/readthedocs/annbatch -A data loader and io utilities for mini-batched data loading of on-disk AnnData files, co-developed by [Lamin Labs][] and [scverse][] +A data loader and io utilities for mini-batched data loading of on-disk AnnData files as well as in-memory data, co-developed by [Lamin Labs][] and [scverse][] ## Getting started @@ -29,7 +29,7 @@ Please refer to the [documentation][], in particular, the [API documentation][]. pip install annbatch ``` -Please see our [installation][] page for full documentation about extras, especially [`zarrs-python`][] which is essential for local filesystems but not for remote ones. +Please see our [installation][] page for full documentation about extras, especially [`zarrs-python`][] which is essential for local filesystems but not for remote ones. `numba` is needed for in-memory sparse data. ## Performance From 2ff8c24ef8311db718e10f730b95db307e027e81 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:39:53 +0200 Subject: [PATCH 11/24] chore: add link --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c4361439..a489c462 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ Please refer to the [documentation][], in particular, the [API documentation][]. pip install annbatch ``` -Please see our [installation][] page for full documentation about extras, especially [`zarrs-python`][] which is essential for local filesystems but not for remote ones. `numba` is needed for in-memory sparse data. +Please see our [installation][] page for full documentation about extras, especially [`zarrs-python`][] which is essential for local filesystems but not for remote ones. [`numba`][] is needed for in-memory sparse data. ## Performance @@ -168,3 +168,5 @@ If you use `annbatch` in your work, please cite the `annbatch` publication as fo [in-depth section of our docs]: https://annbatch.readthedocs.io/en/stable/notebooks/example.html [installation]: https://annbatch.readthedocs.io/en/stable/installation.html + +[`numba`]: https://numba.readthedocs.io/en/stable/ From 0a8b2be319e33843502021af29c41a526eb26e48 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:42:48 +0200 Subject: [PATCH 12/24] chore: add note in notebook --- docs/notebooks/example.ipynb | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 6be4ce18..60435d2e 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -122,7 +122,9 @@ " * We recommend to choose a dataset size that comfortably fits into system memory.\n", "\n", "\n", - "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`" + "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`\n", + "\n", + "If your data fits in-memory, `annbatch` supports in-memory objects (requiring `numba` installed for sparse matrices)." ] }, { From c70b4bed6af9591bdab036175f352b2f7a6fa363 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:44:02 +0200 Subject: [PATCH 13/24] chore: important --- docs/notebooks/example.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index 60435d2e..c71c9781 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -124,7 +124,7 @@ "\n", "You can apply custom data transformations to each input h5ad file by supplying a `load_adata` function to `DatasetCollection.add`\n", "\n", - "If your data fits in-memory, `annbatch` supports in-memory objects (requiring `numba` installed for sparse matrices)." + "**IMPORTANT**: If your data fits in-memory, `annbatch` supports in-memory objects (requiring `numba` installed for sparse matrices)." ] }, { From 2b26dd5969256f0b2001a1659ee61d22643fd2d9 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:45:18 +0200 Subject: [PATCH 14/24] add note to walkthrough --- docs/detailed-walkthrough.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/detailed-walkthrough.md b/docs/detailed-walkthrough.md index 207721a3..239c92d3 100644 --- a/docs/detailed-walkthrough.md +++ b/docs/detailed-walkthrough.md @@ -108,4 +108,5 @@ Thus, as a first step to assessing your needs, if your data fits in memory, load To accelerate reading the data into memory, you may still find {doc}`zarrs-python ` in conjunction with sharding still helpful in the same way it accelerates io here. To this end, please have a look at [this gist](https://gist.github.com/ilan-gold/c73383def3798df2724405aa64e40c3d) comparing file loading speeds between {func}`anndata.io.read_zarr` and {func}`anndata.io.read_h5ad`. It highlights how {doc}`zarrs-python ` and sharding can help there as well. -However, once you have too much data to fit into memory, for whatever reason, the data loading functionality offered here can provide significant speedups over state of the art out-of-core dataloaders. +`annbatch` natively supports in-memory data with unified `var` spaces (sparse and dense). +Ince you have too much data to fit into memory, for whatever reason, the on-disk data loading functionality offered here can provide significant speedups over state of the art out-of-core dataloaders. From 33f6466fead4844e159c603218c2be7427a05476 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:46:12 +0200 Subject: [PATCH 15/24] chore: readme --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index 130d576d..62569e7a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,7 +4,7 @@ A data loader and io utilities for mini-batched data loading of on-disk {mod}`an co-developed by [Lamin Labs](https://lamin.ai/) and [scverse](https://scverse.org/). `annbatch` lets you train models on terabyte-scale collections of `AnnData` files that do not fit -into memory, while keeping your GPU fed with high-throughput, shuffled mini-batches. +into memory, while keeping your GPU fed with high-throughput, shuffled mini-batches. It also supports in-memory data. ```{image} _static/speed_comparision.png :alt: annbatch data-loading speed compared to other dataloaders From 191ab2e9fa996f15764ffd7d069126771a458d6e Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:47:33 +0200 Subject: [PATCH 16/24] chore: more notes --- docs/detailed-walkthrough.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/detailed-walkthrough.md b/docs/detailed-walkthrough.md index 239c92d3..33ff1c05 100644 --- a/docs/detailed-walkthrough.md +++ b/docs/detailed-walkthrough.md @@ -44,8 +44,9 @@ for batch in ds: The data loader implements a chunked fetching strategy where `preload_nchunks` number of contiguous-chunks of size `chunk_size` are loaded. `chunk_size` corresponds the number of rows of `anndata` store to load sequentially. +This number can be quite large for pre-shuffled data but not for un-shuffled data. -For performance reasons, you should use our dataloader directly without wrapping it into a {class}`torch.utils.data.DataLoader`. +For performance reasons, you should use our dataloader directly without wrapping it into a {class}`torch.utils.data.DataLoader` regardless of matrix type. Your code will work the same way as with a {class}`torch.utils.data.DataLoader`, but you will get better performance. In order to take advantage of the sharded zarr files performance, though, locally, you *must* set the codec pipeline to use {doc}`zarrs-python ` when reading. From 52d7b125baba823a3cb4c0c7797d15f32bf978d7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 14:48:28 +0200 Subject: [PATCH 17/24] chore: cleaner title --- docs/detailed-walkthrough.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/detailed-walkthrough.md b/docs/detailed-walkthrough.md index 33ff1c05..d1d31b49 100644 --- a/docs/detailed-walkthrough.md +++ b/docs/detailed-walkthrough.md @@ -3,7 +3,7 @@ This page walks through how `annbatch` works in depth. For a hands-on, runnable version, see the {doc}`quickstart notebook `. -## Preprocessing +## Preprocessing of On-Disk Data ```python collection = DatasetCollection("path/to/output/store.zarr").add_adatas( @@ -21,6 +21,7 @@ Shuffling is important to ensure model convergence, especially because of our co The output is a collection of sharded zarr anndata files, meant to reduce the burden on file systems of indexing. See the [zarr docs on sharding][] for more information. For performance considerations, see our dedicated docs page: {doc}`preshuffling`. +If your data fits in-memory, consider simply shuffling in-memory and passing it to the loader. [zarr docs on sharding]: https://zarr.readthedocs.io/en/stable/user-guide/arrays/#sharding From 81d6fb6dedb3c476e75f00574be32fa1ef5795a6 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 15:16:51 +0200 Subject: [PATCH 18/24] fix:spelling --- .pre-commit-config.yaml | 6 ++++++ docs/detailed-walkthrough.md | 2 +- src/annbatch/loader.py | 4 ++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fd76265b..adae5a44 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,3 +36,9 @@ repos: # Check that there are no merge conflicts (could be generated by template sync) - id: check-merge-conflict args: [--assume-in-merge] + - repo: https://github.com/codespell-project/codespell + rev: v2.4.2 + hooks: + - id: codespell + additional_dependencies: + - tomli diff --git a/docs/detailed-walkthrough.md b/docs/detailed-walkthrough.md index d1d31b49..bb8d34cf 100644 --- a/docs/detailed-walkthrough.md +++ b/docs/detailed-walkthrough.md @@ -111,4 +111,4 @@ To accelerate reading the data into memory, you may still find {doc}`zarrs-pytho To this end, please have a look at [this gist](https://gist.github.com/ilan-gold/c73383def3798df2724405aa64e40c3d) comparing file loading speeds between {func}`anndata.io.read_zarr` and {func}`anndata.io.read_h5ad`. It highlights how {doc}`zarrs-python ` and sharding can help there as well. `annbatch` natively supports in-memory data with unified `var` spaces (sparse and dense). -Ince you have too much data to fit into memory, for whatever reason, the on-disk data loading functionality offered here can provide significant speedups over state of the art out-of-core dataloaders. +Once you have too much data to fit into memory, for whatever reason, the on-disk data loading functionality offered here can provide significant speedups over state of the art out-of-core dataloaders. diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 24b8fa5c..24a87b7b 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -136,8 +136,8 @@ class Loader[ This option entails greater GPU memory usage, but is faster at least for sparse operations. :func:`torch.vstack` does not support CSR sparse matrices, hence the current use of `cupy` internally (which also means `torch` is an optional dep). Setting this to `False` is advisable when using the :class:`torch.utils.data.DataLoader` wrapper or potentially with dense data due to memory pressure. - For top performance, this should be used in conjuction with `to_torch` and then :meth:`torch.Tensor.to_dense` if you wish to densify. - :meth:`cupy.cuda.MemoryPool.free_all_blocks` (i.e., the method of the pool of :func:`cupy.get_default_memory_pool()`) is called aggresively to keep memory usage low. + For top performance, this should be used in conjunction with `to_torch` and then :meth:`torch.Tensor.to_dense` if you wish to densify. + :meth:`cupy.cuda.MemoryPool.free_all_blocks` (i.e., the method of the pool of :func:`cupy.get_default_memory_pool()`) is called aggressively to keep memory usage low. If you are using your own memory pool or allocator, you may have to free blocks on your own. to_torch Whether to return `torch.Tensor` as the output. From 334548c5ac372e820570c55053f5c15178f2a81c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 15:32:29 +0200 Subject: [PATCH 19/24] fix:spelling + docs --- CHANGELOG.md | 2 +- docs/contributing.md | 2 +- docs/custom-sampler.md | 2 +- docs/notebooks/example.ipynb | 2 +- docs/zarr-configuration.md | 2 +- pyproject.toml | 4 ++++ src/annbatch/abc/sampler.py | 2 +- src/annbatch/io.py | 4 ++-- src/annbatch/loader.py | 4 ++-- 9 files changed, 14 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c99ae29..af5adf82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -86,7 +86,7 @@ and this project adheres to [Semantic Versioning][]. ## [0.0.8] -- {class}`~annbatch.Loader` acccepts an `rng` argument now +- {class}`~annbatch.Loader` accepts an `rng` argument now ## [0.0.7] diff --git a/docs/contributing.md b/docs/contributing.md index 699d9429..9e6310cd 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -103,7 +103,7 @@ This can have undesired side-effects, such as requiring to install a lower version of a library your project depends on, only because an outdated sphinx plugin pins an older version. -To initalize a virtual environment in the `.venv` directory of your project, simply run +To initialize a virtual environment in the `.venv` directory of your project, simply run ```bash uv sync --all-extras diff --git a/docs/custom-sampler.md b/docs/custom-sampler.md index efce67f7..7d8ce922 100644 --- a/docs/custom-sampler.md +++ b/docs/custom-sampler.md @@ -41,7 +41,7 @@ This `TypedDict` is what {meth}`annbatch.abc.Sampler._sample` yields and specifi Note: The slices are purely virtual and are defined by the user through the `requests` argument. They don't necessarily need to with the underlying zarr chunks. - **Important:** The number of samples that get loaded into memory at once, must be devisible by the batch size. + **Important:** The number of samples that get loaded into memory at once, must be divisible by the batch size. Otherwise, the remainder will yield to a smaller batch size or will be dropped if `drop_last=True`. - **{attr}`~annbatch.types.LoadRequest`** (optional): A list of numpy arrays that define how the loaded data should be split into batches after being read from disk and concatenated in memory. diff --git a/docs/notebooks/example.ipynb b/docs/notebooks/example.ipynb index c71c9781..5321a846 100644 --- a/docs/notebooks/example.ipynb +++ b/docs/notebooks/example.ipynb @@ -167,7 +167,7 @@ "import anndata as ad\n", "from annbatch import DatasetCollection\n", "\n", - "# let's write out only shared colunms - otherwise DatasetCollection will warn about all the columns we are missing for good reason - mismatched columns can lead to unexpected data and missing values.\n", + "# let's write out only shared columns - otherwise DatasetCollection will warn about all the columns we are missing for good reason - mismatched columns can lead to unexpected data and missing values.\n", "shared_columns = ad.experimental.read_lazy(\"866d7d5e-436b-4dbd-b7c1-7696487d452e.h5ad\").obs.columns.intersection(\n", " ad.experimental.read_lazy(\"f81463b8-4986-4904-a0ea-20ff02cbb317.h5ad\").obs.columns\n", ")\n", diff --git a/docs/zarr-configuration.md b/docs/zarr-configuration.md index 7974deb9..8cc5e85e 100644 --- a/docs/zarr-configuration.md +++ b/docs/zarr-configuration.md @@ -7,7 +7,7 @@ import zarr zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"}) ``` -Otherwise normal use {mod}`zarr` without {doc}`zarrs-python ` (wich does not support, for example, remote stores). +Otherwise normal use {mod}`zarr` without {doc}`zarrs-python ` (which does not support, for example, remote stores). ## `zarrs` Performance diff --git a/pyproject.toml b/pyproject.toml index 893d83e2..813b9961 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -152,6 +152,10 @@ lint.per-file-ignores."docs/*" = [ "I" ] lint.per-file-ignores."tests/*" = [ "D" ] lint.pydocstyle.convention = "numpy" +[tool.codespell] +ignore-words-list = "theis,coo,homogenous,GroupT,datas" +skip = ".git,*.pdf,*.svg" + [tool.mypy] [[tool.mypy.overrides]] overrides = [ { module = [ "anndata.*", "cupy.*", "cupyx.*", "h5py.*", "torch.*" ], ignore_missing_imports = true } ] diff --git a/src/annbatch/abc/sampler.py b/src/annbatch/abc/sampler.py index 8cbc58f9..028f70bf 100644 --- a/src/annbatch/abc/sampler.py +++ b/src/annbatch/abc/sampler.py @@ -106,7 +106,7 @@ def n_iters(self, n_obs: int) -> int: def sample(self, n_obs: int) -> Iterator[LoadRequest]: """Sample load requests given the total number of observations. - Base implemention simply calls :meth:`~annbatch.abc.Sampler.validate` and then yields via :meth:`~annbatch.abc.Sampler._sample`. + Base implementation simply calls :meth:`~annbatch.abc.Sampler.validate` and then yields via :meth:`~annbatch.abc.Sampler._sample`. Parameters ---------- diff --git a/src/annbatch/io.py b/src/annbatch/io.py index 5c68c554..5a3d28b9 100644 --- a/src/annbatch/io.py +++ b/src/annbatch/io.py @@ -63,7 +63,7 @@ def _default_load_adata[T: zarr.Group | h5py.Group | PathLike[str] | str](x: T) if len(adata.obs.columns) > 0: adata.obs = ad.experimental.read_elem_lazy(group["obs"], chunks=(-1,), use_range_index=True) for col in adata.obs.columns: - # Nullables / categoricals have bad perforamnce characteristics when concatenating using dask + # Nullables / categoricals have bad performance characteristics when concatenating using dask if pd.api.types.is_extension_array_dtype(adata.obs[col].dtype): adata.obs[col] = adata.obs[col].data return adata @@ -562,7 +562,7 @@ def __init__( self._group = zarr.open_group(group, mode=mode) else: warnings.warn( - "Loading h5ad is currently not supported and thus we cannot guarantee the funcionality of the ecosystem with h5ad files." + "Loading h5ad is currently not supported and thus we cannot guarantee the functionality of the ecosystem with h5ad files." "DatasetCollection should be able to handle shuffling but we guarantee little else." "Proceed with caution.", stacklevel=2, diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 24a87b7b..aa215cc0 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -355,7 +355,7 @@ def add_adatas( Parameters ---------- adatas - List of :class:`anndata.AnnData` objects, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. + List of :class:`anndata.AnnData` objects, with :class:`zarr.Array`, :class:`scipy.sparse.csr_matrix`, :class:`scipy.sparse.csr_array`, :class:`numpy.ndarray`, or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. """ check_lt_1([len(adatas)], ["Number of adatas"]) for adata in adatas: @@ -369,7 +369,7 @@ def add_adata(self, adata: ad.AnnData) -> Self: Parameters ---------- adata - A :class:`anndata.AnnData` object, with :class:`zarr.Array` or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. + A :class:`anndata.AnnData` object, with :class:`zarr.Array`, :class:`scipy.sparse.csr_matrix`, :class:`scipy.sparse.csr_array`, :class:`numpy.ndarray`, or :class:`anndata.abc.CSRDataset` as the data matrix in :attr:`~anndata.AnnData.X`, and :attr:`~anndata.AnnData.obs` containing annotations to yield in a :class:`pandas.DataFrame`. :attr:`~anndata.AnnData.var` must match the ``var`` of any previously added datasets. """ dataset, obs, var = self._prepare_dataset_obs_and_var(adata) From bcbd18d9a1a617023b543b92f497727b8cb3c194 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 16:51:40 +0200 Subject: [PATCH 20/24] fix: handle differing dtypes --- CHANGELOG.md | 2 + src/annbatch/loader.py | 131 +++++++++++++++++++++++++++++++++++++---- tests/conftest.py | 25 +++++++- 3 files changed, 146 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index af5adf82..7a2c1e6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning][]. - {attr}`annbatch.types.LoadRequest.splits` now index in **request order** -- position `j` is the `j`-th observation when the request's `chunks` are concatenated in the order given. Previously, `splits` had to index into the loader's internal dataset-grouped memory layout. The {class}`~annbatch.Loader` now remaps splits to that layout itself, so custom samplers must produce chunk-order splits and stop compensating for the dataset reordering. - Deprecated `annbatch.types.LoadRequest.chunks` in favor of {attr}`annbatch.types.LoadRequest.requests`. +### Fixed +- Handling of different data types i.e., `float32` vs `float64` in the same {class}`~annbatch.Loader` ## [0.1.6] diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index aa215cc0..b4268036 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -537,6 +537,13 @@ def _requests_to_dataset_rows(self, requests: list[slice] | np.ndarray) -> Order result[ds] = global_index[order[gs:ge]] - starts[ds] return result, order + def _alloc(self, shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray: + if self._preload_to_gpu: + import cupyx as cpx + + return cpx.empty_pinned(shape, dtype) + return np.empty(shape, dtype) + def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> CSRContainer | np.ndarray: """Preallocate a single contiguous output buffer covering all datasets and rows. @@ -550,13 +557,6 @@ def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> """ total_rows = sum(len(rows) for rows in dataset_index_to_rows.values()) - def _alloc(shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray: - if self._preload_to_gpu: - import cupyx as cpx - - return cpx.empty_pinned(shape, dtype) - return np.empty(shape, dtype) - if (is_backed := issubclass(self.dataset_type, ad.abc.CSRDataset)) or issubclass( self.dataset_type, sp.csr_array | sp.csr_matrix ): @@ -571,8 +571,8 @@ def _alloc(shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray: indptr_dtype = datasets[first_idx].indptr.dtype return CSRContainer( elems=( - _alloc((total_nnz,), data_dtype), - _alloc((total_nnz,), indices_dtype), + self._alloc((total_nnz,), data_dtype), + self._alloc((total_nnz,), indices_dtype), np.empty(total_rows + 1, dtype=indptr_dtype), ), shape=(total_rows, self.n_var), @@ -582,7 +582,97 @@ def _alloc(shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray: first_idx = next(iter(dataset_index_to_rows)) dtype = self._train_datasets[first_idx].dtype shape_res = self._train_datasets[first_idx].shape[1:] - return _alloc((total_rows, *shape_res), dtype) + return self._alloc((total_rows, *shape_res), dtype) + + def _dtypes_homogeneous(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> bool: + """Whether all requested datasets share the same dtype(s). + + For sparse datasets the comparison covers ``data``, ``indices`` and ``indptr`` dtypes. + Must be called after :meth:`_ensure_sparse_cache` for backed-sparse datasets. + """ + idxs = list(dataset_index_to_rows.keys()) + if len(idxs) <= 1: + return True + is_backed_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset) + is_sparse = is_backed_sparse or issubclass(self.dataset_type, sp.csr_array | sp.csr_matrix) + if is_sparse: + datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets + first = datasets[idxs[0]] + return all( + datasets[i].data.dtype == first.data.dtype and datasets[i].indices.dtype == first.indices.dtype + for i in idxs[1:] + ) + first_dtype = self._train_datasets[idxs[0]].dtype + return all(self._train_datasets[i].dtype == first_dtype for i in idxs[1:]) + + def _allocate_per_dataset_outs( + self, dataset_index_to_rows: OrderedDict[int, np.ndarray] + ) -> OrderedDict[int, CSRContainer | np.ndarray]: + """Allocate one output buffer per dataset, each using that dataset's native dtype(s). + + Used when datasets have differing dtypes — the per-dataset buffers are concatenated + into a final buffer of the promoted dtype by :meth:`_concatenate_outs`. + Must be called after :meth:`_ensure_sparse_cache` for backed-sparse datasets. + """ + is_backed_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset) + is_sparse = is_backed_sparse or issubclass(self.dataset_type, sp.csr_array | sp.csr_matrix) + outs: OrderedDict[int, CSRContainer | np.ndarray] = OrderedDict() + if is_sparse: + datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets + for idx, rows in dataset_index_to_rows.items(): + ds = datasets[idx] + nnz = int((ds.indptr[rows + 1] - ds.indptr[rows]).sum()) + outs[idx] = CSRContainer( + elems=( + self._alloc((nnz,), ds.data.dtype), + self._alloc((nnz,), ds.indices.dtype), + np.empty(len(rows) + 1, dtype=ds.indptr.dtype), + ), + shape=(len(rows), self.n_var), + dtype=ds.data.dtype, + ) + else: + for idx, rows in dataset_index_to_rows.items(): + ds = self._train_datasets[idx] + outs[idx] = self._alloc((len(rows), *ds.shape[1:]), ds.dtype) + return outs + + def _concatenate_outs(self, outs: OrderedDict[int, CSRContainer | np.ndarray]) -> CSRContainer | np.ndarray: + """Concatenate per-dataset buffers into a single buffer with promoted dtype(s).""" + values = list(outs.values()) + if isinstance(values[0], CSRContainer): + data_dtype = np.result_type(*[o.elems[0].dtype for o in values]) + indices_dtype = np.result_type(*[o.elems[1].dtype for o in values]) + indptr_dtype = np.result_type(*[o.elems[2].dtype for o in values]) + total_nnz = sum(o.elems[0].size for o in values) + total_rows = sum(o.shape[0] for o in values) + data = self._alloc((total_nnz,), data_dtype) + indices = self._alloc((total_nnz,), indices_dtype) + indptr = np.empty(total_rows + 1, dtype=indptr_dtype) + indptr[0] = 0 + nnz_offset = 0 + row_offset = 0 + for o in values: + n = o.elems[0].size + r = o.shape[0] + data[nnz_offset : nnz_offset + n] = o.elems[0] + indices[nnz_offset : nnz_offset + n] = o.elems[1] + indptr[row_offset + 1 : row_offset + r + 1] = o.elems[2][1:] + nnz_offset + nnz_offset += n + row_offset += r + return CSRContainer( + elems=(data, indices, indptr), + shape=(total_rows, self.n_var), + dtype=data_dtype, + ) + dtype = np.result_type(*[o.dtype for o in values]) + total_rows = sum(o.shape[0] for o in values) + out = self._alloc((total_rows, *values[0].shape[1:]), dtype) + offset = 0 + for o in values: + out[offset : offset + o.shape[0]] = o + offset += o.shape[0] + return out @singledispatchmethod async def _fetch_data( @@ -768,6 +858,27 @@ async def _index_datasets( if is_backed_sparse: await self._ensure_sparse_cache() + if not self._dtypes_homogeneous(dataset_index_to_rows): + per_dataset_outs = self._allocate_per_dataset_outs(dataset_index_to_rows) + tasks = [ + self._fetch_data( + self._get_elem_from_cache(dataset_idx) if is_backed_sparse else self._train_datasets[dataset_idx], + rows, + per_dataset_outs[dataset_idx], + ) + for dataset_idx, rows in dataset_index_to_rows.items() + ] + await asyncio.gather(*tasks) + if is_sparse: + datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets + for dataset_idx, rows in dataset_index_to_rows.items(): + sub_out = per_dataset_outs[dataset_idx] + cached_indptr = datasets[dataset_idx].indptr + per_row_nnz = cached_indptr[rows + 1] - cached_indptr[rows] + sub_out.elems[2][0] = 0 + np.cumsum(per_row_nnz, out=sub_out.elems[2][1:]) + return self._concatenate_outs(per_dataset_outs) + out = self._allocate_out(dataset_index_to_rows) tasks = [] diff --git a/tests/conftest.py b/tests/conftest.py index dadc4e43..1a0cd7b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,9 +115,9 @@ def adata_with_h5_path_different_var_space( ), tmp_path -@pytest.fixture(scope="session") +@pytest.fixture(scope="session", params=[False, True], ids=["same-dtype", "mixed-dtype"]) def simple_collection( - tmpdir_factory, adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path] + request, tmpdir_factory, adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path] ) -> tuple[DatasetCollection, ad.AnnData]: zarr_stores = sorted(f for f in adata_with_zarr_path_same_var_space[1].iterdir() if f.is_dir()) output_path = Path(tmpdir_factory.mktemp("zarr_folder")) / "simple_fixture.zarr" @@ -128,4 +128,25 @@ def simple_collection( dataset_size=60, shuffle_chunk_size=10, ) + if request.param: + with ad.settings.override(auto_shard_zarr_v3=True, zarr_write_format=3): + # Rewrite the first dataset's X (and sparse layer) with a different dtype + # to exercise the dtype-promotion code path in Loader._concatenate_outs. + first = next(iter(collection)) + new_X = first["X"][...].astype("f8") + del first["X"] + ad.io.write_elem(first, "X", new_X) + sparse_layer = ad.io.read_elem(first["layers"]["sparse"]).astype("int64") + del first["layers"]["sparse"] + ad.io.write_elem(first["layers"], "sparse", sparse_layer) + + datasets = list(collection) + first_X_dtype = datasets[0]["X"].dtype + first_sparse_dtype = datasets[0]["layers"]["sparse"]["data"].dtype + assert any(ds["X"].dtype != first_X_dtype for ds in datasets[1:]), ( + "mixed-dtype fixture failed to produce differing X dtypes" + ) + assert any(ds["layers"]["sparse"]["data"].dtype != first_sparse_dtype for ds in datasets[1:]), ( + "mixed-dtype fixture failed to produce differing sparse layer dtypes" + ) return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection From b4f1ef23b8c2fb40b1087ad0d41dd0859fd4c459 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 16:55:49 +0200 Subject: [PATCH 21/24] chore: docs --- docs/detailed-walkthrough.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/detailed-walkthrough.md b/docs/detailed-walkthrough.md index bb8d34cf..13285169 100644 --- a/docs/detailed-walkthrough.md +++ b/docs/detailed-walkthrough.md @@ -18,6 +18,8 @@ collection = DatasetCollection("path/to/output/store.zarr").add_adatas( First, you convert your existing `.h5ad` files into a zarr-backed anndata format. In the process, the data gets shuffled and is distributed across several anndata files. Shuffling is important to ensure model convergence, especially because of our contiguous data fetching scheme which is not perfectly random. +Shuffling also helps improve performance because it ensures uniform data types across your training dataset i.e., matrices all of `float64` type. +This allows efficient preallocation of pinned memory. The output is a collection of sharded zarr anndata files, meant to reduce the burden on file systems of indexing. See the [zarr docs on sharding][] for more information. For performance considerations, see our dedicated docs page: {doc}`preshuffling`. @@ -55,6 +57,10 @@ Using {mod}`zarr` on its own will not yield high performance for local filesyste We have not tested remote data (i.e., using {func}`zarr.open` with a {class}`zarr.storage.ObjectStore`) but because we use {mod}`zarr`, this data loader should also work over cloud connections via relevant zarr stores. Note that {doc}`zarrs-python ` cannot be used with these sorts of non-local stores. +:::{important} +As mentioned above, ensuring uniform dtypes across your training collection is important for performance. +For in-memory data, this transformation should be trivial. + ### User configurable sampling strategy We support user-configurable sampling strategies like weighting or sampling by implementing the abstract {class}`annbatch.abc.Sampler`. From 19a5964e793e573ff17807ff9f6f7bb36abff57a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 17:02:56 +0200 Subject: [PATCH 22/24] fix: indptr allocation --- src/annbatch/loader.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index b4268036..3f803a70 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -537,8 +537,8 @@ def _requests_to_dataset_rows(self, requests: list[slice] | np.ndarray) -> Order result[ds] = global_index[order[gs:ge]] - starts[ds] return result, order - def _alloc(self, shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray: - if self._preload_to_gpu: + def _alloc(self, shape: tuple[int, ...], dtype: np.dtype, *, use_pinned: bool) -> np.ndarray: + if use_pinned: import cupyx as cpx return cpx.empty_pinned(shape, dtype) @@ -571,8 +571,8 @@ def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> indptr_dtype = datasets[first_idx].indptr.dtype return CSRContainer( elems=( - self._alloc((total_nnz,), data_dtype), - self._alloc((total_nnz,), indices_dtype), + self._alloc((total_nnz,), data_dtype, use_pinned=self._preload_to_gpu), + self._alloc((total_nnz,), indices_dtype, use_pinned=self._preload_to_gpu), np.empty(total_rows + 1, dtype=indptr_dtype), ), shape=(total_rows, self.n_var), @@ -582,7 +582,7 @@ def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> first_idx = next(iter(dataset_index_to_rows)) dtype = self._train_datasets[first_idx].dtype shape_res = self._train_datasets[first_idx].shape[1:] - return self._alloc((total_rows, *shape_res), dtype) + return self._alloc((total_rows, *shape_res), dtype, use_pinned=self._preload_to_gpu) def _dtypes_homogeneous(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> bool: """Whether all requested datasets share the same dtype(s). @@ -624,9 +624,9 @@ def _allocate_per_dataset_outs( nnz = int((ds.indptr[rows + 1] - ds.indptr[rows]).sum()) outs[idx] = CSRContainer( elems=( - self._alloc((nnz,), ds.data.dtype), - self._alloc((nnz,), ds.indices.dtype), - np.empty(len(rows) + 1, dtype=ds.indptr.dtype), + self._alloc((nnz,), ds.data.dtype, use_pinned=False), + self._alloc((nnz,), ds.indices.dtype, use_pinned=False), + self._alloc((len(rows) + 1,), np.min_scalar_type(nnz), use_pinned=False), ), shape=(len(rows), self.n_var), dtype=ds.data.dtype, @@ -634,7 +634,7 @@ def _allocate_per_dataset_outs( else: for idx, rows in dataset_index_to_rows.items(): ds = self._train_datasets[idx] - outs[idx] = self._alloc((len(rows), *ds.shape[1:]), ds.dtype) + outs[idx] = self._alloc((len(rows), *ds.shape[1:]), ds.dtype, use_pinned=False) return outs def _concatenate_outs(self, outs: OrderedDict[int, CSRContainer | np.ndarray]) -> CSRContainer | np.ndarray: @@ -643,12 +643,11 @@ def _concatenate_outs(self, outs: OrderedDict[int, CSRContainer | np.ndarray]) - if isinstance(values[0], CSRContainer): data_dtype = np.result_type(*[o.elems[0].dtype for o in values]) indices_dtype = np.result_type(*[o.elems[1].dtype for o in values]) - indptr_dtype = np.result_type(*[o.elems[2].dtype for o in values]) total_nnz = sum(o.elems[0].size for o in values) total_rows = sum(o.shape[0] for o in values) - data = self._alloc((total_nnz,), data_dtype) - indices = self._alloc((total_nnz,), indices_dtype) - indptr = np.empty(total_rows + 1, dtype=indptr_dtype) + data = self._alloc((total_nnz,), data_dtype, use_pinned=self._preload_to_gpu) + indices = self._alloc((total_nnz,), indices_dtype, use_pinned=self._preload_to_gpu) + indptr = self._alloc((total_rows + 1,), np.min_scalar_type(total_nnz), use_pinned=self._preload_to_gpu) indptr[0] = 0 nnz_offset = 0 row_offset = 0 @@ -667,7 +666,7 @@ def _concatenate_outs(self, outs: OrderedDict[int, CSRContainer | np.ndarray]) - ) dtype = np.result_type(*[o.dtype for o in values]) total_rows = sum(o.shape[0] for o in values) - out = self._alloc((total_rows, *values[0].shape[1:]), dtype) + out = self._alloc((total_rows, *values[0].shape[1:]), dtype, use_pinned=self._preload_to_gpu) offset = 0 for o in values: out[offset : offset + o.shape[0]] = o From 0a4997e72119d800feb4ac8274f9d39360413fde Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 17:35:20 +0200 Subject: [PATCH 23/24] chore: warning --- src/annbatch/loader.py | 59 ++++++++++++++++++++++++++---------------- tests/conftest.py | 48 +++++++++++++++++++++++----------- tests/test_dataset.py | 10 ++++--- 3 files changed, 77 insertions(+), 40 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 3f803a70..54c51576 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -176,7 +176,7 @@ class Loader[ _shapes: list[tuple[int, int]] _preload_to_gpu: bool = True _to_torch: bool = True - _dataset_elem_cache: dict[int, CSRDatasetElems] + _sparse_dataset_elem_cache: dict[int, CSRDatasetElems] _batch_sampler: Sampler _collection_added: bool = False @@ -231,7 +231,7 @@ def __init__( self._to_torch = to_torch self._train_datasets = [] self._shapes = [] - self._dataset_elem_cache = {} + self._sparse_dataset_elem_cache = {} def __len__(self) -> int: return self._batch_sampler.n_batches(self.n_obs) @@ -480,6 +480,13 @@ def _add_dataset_unchecked( raise TypeError("var must be a pandas DataFrame") datasets = self._train_datasets + [dataset] check_var_shapes(datasets) + if not self._datasets_share_dtype(datasets): + warn( + f"Adding dataset with dtype {dataset.dtype!r} that differs from the existing dataset dtype(s) " + f"(first dataset: {self._train_datasets[0].dtype!r}). Heterogeneous dtypes incur extra per-batch " + "allocation and dtype promotion in the loader; consider casting all datasets to a common dtype.", + stacklevel=2, + ) self._shapes = self._shapes + [dataset.shape] self._train_datasets = datasets if self._obs is not None: # obs exist @@ -560,7 +567,7 @@ def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> if (is_backed := issubclass(self.dataset_type, ad.abc.CSRDataset)) or issubclass( self.dataset_type, sp.csr_array | sp.csr_matrix ): - datasets = self._dataset_elem_cache if is_backed else self._train_datasets + datasets = self._sparse_dataset_elem_cache if is_backed else self._train_datasets total_nnz = sum( int((datasets[idx].indptr[rows + 1] - datasets[idx].indptr[rows]).sum()) for idx, rows in dataset_index_to_rows.items() @@ -584,6 +591,22 @@ def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> shape_res = self._train_datasets[first_idx].shape[1:] return self._alloc((total_rows, *shape_res), dtype, use_pinned=self._preload_to_gpu) + @staticmethod + def _datasets_share_dtype(datasets: list[CSRDatasetElems | BackingArray]) -> bool: + """Whether all given dataset-like objects share the same dtype(s).""" + if len(datasets) <= 1: + return True + + def dtypes_of(d): + if isinstance(d, ad.abc.CSRDataset): + return (d.group["data"].dtype, d.group["indices"].dtype) + if hasattr(d, "data") and hasattr(d, "indices"): + return (d.data.dtype, d.indices.dtype) + return (d.dtype,) + + first = dtypes_of(datasets[0]) + return all(dtypes_of(d) == first for d in datasets[1:]) + def _dtypes_homogeneous(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> bool: """Whether all requested datasets share the same dtype(s). @@ -594,16 +617,8 @@ def _dtypes_homogeneous(self, dataset_index_to_rows: OrderedDict[int, np.ndarray if len(idxs) <= 1: return True is_backed_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset) - is_sparse = is_backed_sparse or issubclass(self.dataset_type, sp.csr_array | sp.csr_matrix) - if is_sparse: - datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets - first = datasets[idxs[0]] - return all( - datasets[i].data.dtype == first.data.dtype and datasets[i].indices.dtype == first.indices.dtype - for i in idxs[1:] - ) - first_dtype = self._train_datasets[idxs[0]].dtype - return all(self._train_datasets[i].dtype == first_dtype for i in idxs[1:]) + datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets + return self._datasets_share_dtype([datasets[i] for i in idxs]) def _allocate_per_dataset_outs( self, dataset_index_to_rows: OrderedDict[int, np.ndarray] @@ -618,7 +633,7 @@ def _allocate_per_dataset_outs( is_sparse = is_backed_sparse or issubclass(self.dataset_type, sp.csr_array | sp.csr_matrix) outs: OrderedDict[int, CSRContainer | np.ndarray] = OrderedDict() if is_sparse: - datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets + datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets for idx, rows in dataset_index_to_rows.items(): ds = datasets[idx] nnz = int((ds.indptr[rows + 1] - ds.indptr[rows]).sum()) @@ -747,16 +762,16 @@ async def _create_sparse_elems(self, idx: int) -> CSRDatasetElems: async def _ensure_sparse_cache(self) -> None: """Build up the cache of datasets i.e., in-memory indptr, and backed indices and data.""" - arr_idxs = [idx for idx in range(len(self._train_datasets)) if idx not in self._dataset_elem_cache] + arr_idxs = [idx for idx in range(len(self._train_datasets)) if idx not in self._sparse_dataset_elem_cache] all_elems: list[CSRDatasetElems] = await asyncio.gather( *( self._create_sparse_elems(idx) for idx in range(len(self._train_datasets)) - if idx not in self._dataset_elem_cache + if idx not in self._sparse_dataset_elem_cache ) ) for idx, elems in zip(arr_idxs, all_elems, strict=True): - self._dataset_elem_cache[idx] = elems + self._sparse_dataset_elem_cache[idx] = elems def _get_elem_from_cache(self, dataset_idx: int) -> CSRDatasetElems | ZarrArray: """Return the arrays (zarr or otherwise) needed to represent on-disk data at a given index. @@ -770,9 +785,9 @@ def _get_elem_from_cache(self, dataset_idx: int) -> CSRDatasetElems | ZarrArray: ------- The arrays representing the sparse data. """ - if dataset_idx not in self._dataset_elem_cache: + if dataset_idx not in self._sparse_dataset_elem_cache: raise ValueError("Cache not prepared") - return self._dataset_elem_cache[dataset_idx] + return self._sparse_dataset_elem_cache[dataset_idx] @_fetch_data.register async def _fetch_data_csr_matrix( @@ -869,7 +884,7 @@ async def _index_datasets( ] await asyncio.gather(*tasks) if is_sparse: - datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets + datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets for dataset_idx, rows in dataset_index_to_rows.items(): sub_out = per_dataset_outs[dataset_idx] cached_indptr = datasets[dataset_idx].indptr @@ -887,7 +902,7 @@ async def _index_datasets( for dataset_idx, rows in dataset_index_to_rows.items(): nrows = len(rows) if is_sparse: - datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets + datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets cached_indptr = datasets[dataset_idx].indptr nnz = int((cached_indptr[rows + 1] - cached_indptr[rows]).sum()) out_view: CSRContainer | np.ndarray = CSRContainer( @@ -915,7 +930,7 @@ async def _index_datasets( await asyncio.gather(*tasks) if is_sparse: - datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets + datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets running_nnz = 0 row_pos = 0 out.elems[2][0] = 0 diff --git a/tests/conftest.py b/tests/conftest.py index 1a0cd7b7..b81e2b8d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,9 +115,9 @@ def adata_with_h5_path_different_var_space( ), tmp_path -@pytest.fixture(scope="session", params=[False, True], ids=["same-dtype", "mixed-dtype"]) +@pytest.fixture(scope="session") def simple_collection( - request, tmpdir_factory, adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path] + tmpdir_factory, adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path] ) -> tuple[DatasetCollection, ad.AnnData]: zarr_stores = sorted(f for f in adata_with_zarr_path_same_var_space[1].iterdir() if f.is_dir()) output_path = Path(tmpdir_factory.mktemp("zarr_folder")) / "simple_fixture.zarr" @@ -128,10 +128,28 @@ def simple_collection( dataset_size=60, shuffle_chunk_size=10, ) - if request.param: + return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection + + +@pytest.fixture(scope="session", params=[False, True], ids=["same-dtype", "mixed-dtype"]) +def maybe_mixed_dtype_collection( + request, tmpdir_factory, adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path] +) -> tuple[ad.AnnData, DatasetCollection, bool]: + """Like ``simple_collection``, but optionally rewrites the first dataset's + X (and sparse layer) with a different dtype to exercise the dtype-promotion + code path in ``Loader._concatenate_outs``. Returns ``(adata, collection, is_mixed)``.""" + zarr_stores = sorted(f for f in adata_with_zarr_path_same_var_space[1].iterdir() if f.is_dir()) + output_path = Path(tmpdir_factory.mktemp("zarr_folder")) / "mixed_dtype_fixture.zarr" + collection = DatasetCollection(output_path).add_adatas( + zarr_stores, + n_obs_per_chunk=10, + shard_size=20, + dataset_size=60, + shuffle_chunk_size=10, + ) + is_mixed = bool(request.param) + if is_mixed: with ad.settings.override(auto_shard_zarr_v3=True, zarr_write_format=3): - # Rewrite the first dataset's X (and sparse layer) with a different dtype - # to exercise the dtype-promotion code path in Loader._concatenate_outs. first = next(iter(collection)) new_X = first["X"][...].astype("f8") del first["X"] @@ -140,13 +158,13 @@ def simple_collection( del first["layers"]["sparse"] ad.io.write_elem(first["layers"], "sparse", sparse_layer) - datasets = list(collection) - first_X_dtype = datasets[0]["X"].dtype - first_sparse_dtype = datasets[0]["layers"]["sparse"]["data"].dtype - assert any(ds["X"].dtype != first_X_dtype for ds in datasets[1:]), ( - "mixed-dtype fixture failed to produce differing X dtypes" - ) - assert any(ds["layers"]["sparse"]["data"].dtype != first_sparse_dtype for ds in datasets[1:]), ( - "mixed-dtype fixture failed to produce differing sparse layer dtypes" - ) - return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection + datasets = list(collection) + first_X_dtype = datasets[0]["X"].dtype + first_sparse_dtype = datasets[0]["layers"]["sparse"]["data"].dtype + assert any(ds["X"].dtype != first_X_dtype for ds in datasets[1:]), ( + "mixed-dtype fixture failed to produce differing X dtypes" + ) + assert any(ds["layers"]["sparse"]["data"].dtype != first_sparse_dtype for ds in datasets[1:]), ( + "mixed-dtype fixture failed to produce differing sparse layer dtypes" + ) + return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection, is_mixed diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 370c9210..fe3b8fc2 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -204,7 +204,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]: ], ) def test_store_load_dataset( - simple_collection: tuple[ad.AnnData, DatasetCollection], + maybe_mixed_dtype_collection: tuple[ad.AnnData, DatasetCollection, bool], *, shuffle: bool, gen_loader, @@ -217,10 +217,14 @@ def test_store_load_dataset( 3. All samples from the dataset are processed 4. If the dataset is not shuffled, it returns the correct data """ - loader: Loader = gen_loader(simple_collection[1], shuffle, use_zarrs) + adata, collection, is_mixed = maybe_mixed_dtype_collection + if is_mixed: + with pytest.warns(UserWarning, match="Adding dataset with dtype"): + loader: Loader = gen_loader(collection, shuffle, use_zarrs) + else: + loader: Loader = gen_loader(collection, shuffle, use_zarrs) if use_zarrs and loader.dataset_type in {np.ndarray, sp.csr_matrix, sp.csr_array}: pytest.skip("No need to run zarrs with in-memory") - adata = simple_collection[0] is_dense = loader.dataset_type in {zarr.Array, np.ndarray} n_elems = 0 batches = [] From 1a12389276e2dfc329d61c53a3ca9188208cbee3 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 11 Jun 2026 17:39:28 +0200 Subject: [PATCH 24/24] chore: cache --- src/annbatch/loader.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index 54c51576..2286c761 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -179,6 +179,7 @@ class Loader[ _sparse_dataset_elem_cache: dict[int, CSRDatasetElems] _batch_sampler: Sampler _collection_added: bool = False + _dtypes_homogeneous: bool = True def __init__( self, @@ -480,7 +481,8 @@ def _add_dataset_unchecked( raise TypeError("var must be a pandas DataFrame") datasets = self._train_datasets + [dataset] check_var_shapes(datasets) - if not self._datasets_share_dtype(datasets): + self._dtypes_homogeneous = self._datasets_share_dtype(datasets) + if self._train_datasets and not self._dtypes_homogeneous: warn( f"Adding dataset with dtype {dataset.dtype!r} that differs from the existing dataset dtype(s) " f"(first dataset: {self._train_datasets[0].dtype!r}). Heterogeneous dtypes incur extra per-batch " @@ -592,7 +594,7 @@ def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> return self._alloc((total_rows, *shape_res), dtype, use_pinned=self._preload_to_gpu) @staticmethod - def _datasets_share_dtype(datasets: list[CSRDatasetElems | BackingArray]) -> bool: + def _datasets_share_dtype(datasets: list[BackingArray]) -> bool: """Whether all given dataset-like objects share the same dtype(s).""" if len(datasets) <= 1: return True @@ -607,19 +609,6 @@ def dtypes_of(d): first = dtypes_of(datasets[0]) return all(dtypes_of(d) == first for d in datasets[1:]) - def _dtypes_homogeneous(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> bool: - """Whether all requested datasets share the same dtype(s). - - For sparse datasets the comparison covers ``data``, ``indices`` and ``indptr`` dtypes. - Must be called after :meth:`_ensure_sparse_cache` for backed-sparse datasets. - """ - idxs = list(dataset_index_to_rows.keys()) - if len(idxs) <= 1: - return True - is_backed_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset) - datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets - return self._datasets_share_dtype([datasets[i] for i in idxs]) - def _allocate_per_dataset_outs( self, dataset_index_to_rows: OrderedDict[int, np.ndarray] ) -> OrderedDict[int, CSRContainer | np.ndarray]: @@ -872,7 +861,7 @@ async def _index_datasets( if is_backed_sparse: await self._ensure_sparse_cache() - if not self._dtypes_homogeneous(dataset_index_to_rows): + if not self._dtypes_homogeneous: per_dataset_outs = self._allocate_per_dataset_outs(dataset_index_to_rows) tasks = [ self._fetch_data(