diff --git a/CHANGELOG.md b/CHANGELOG.md index 9cfe1bfe..089d0535 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning][]. - {attr}`annbatch.types.LoadRequest.splits` now index in **request order** -- position `j` is the `j`-th observation when the request's `chunks` are concatenated in the order given. Previously, `splits` had to index into the loader's internal dataset-grouped memory layout. The {class}`~annbatch.Loader` now remaps splits to that layout itself, so custom samplers must produce chunk-order splits and stop compensating for the dataset reordering. - Deprecated `annbatch.types.LoadRequest.chunks` in favor of {attr}`annbatch.types.LoadRequest.requests`. +### Fixed +- Handling of different data types i.e., `float32` vs `float64` in the same {class}`~annbatch.Loader` ## [0.1.6] diff --git a/docs/detailed-walkthrough.md b/docs/detailed-walkthrough.md index bb8d34cf..13285169 100644 --- a/docs/detailed-walkthrough.md +++ b/docs/detailed-walkthrough.md @@ -18,6 +18,8 @@ collection = DatasetCollection("path/to/output/store.zarr").add_adatas( First, you convert your existing `.h5ad` files into a zarr-backed anndata format. In the process, the data gets shuffled and is distributed across several anndata files. Shuffling is important to ensure model convergence, especially because of our contiguous data fetching scheme which is not perfectly random. +Shuffling also helps improve performance because it ensures uniform data types across your training dataset i.e., matrices all of `float64` type. +This allows efficient preallocation of pinned memory. The output is a collection of sharded zarr anndata files, meant to reduce the burden on file systems of indexing. See the [zarr docs on sharding][] for more information. For performance considerations, see our dedicated docs page: {doc}`preshuffling`. 
@@ -55,6 +57,11 @@ Using {mod}`zarr` on its own will not yield high performance for local filesyste We have not tested remote data (i.e., using {func}`zarr.open` with a {class}`zarr.storage.ObjectStore`) but because we use {mod}`zarr`, this data loader should also work over cloud connections via relevant zarr stores. Note that {doc}`zarrs-python ` cannot be used with these sorts of non-local stores. +:::{important} +As mentioned above, ensuring uniform dtypes across your training collection is important for performance. +For in-memory data, this transformation should be trivial. +::: + ### User configurable sampling strategy We support user-configurable sampling strategies like weighting or sampling by implementing the abstract {class}`annbatch.abc.Sampler`. diff --git a/src/annbatch/loader.py b/src/annbatch/loader.py index aa215cc0..2286c761 100644 --- a/src/annbatch/loader.py +++ b/src/annbatch/loader.py @@ -176,9 +176,10 @@ class Loader[ _shapes: list[tuple[int, int]] _preload_to_gpu: bool = True _to_torch: bool = True - _dataset_elem_cache: dict[int, CSRDatasetElems] + _sparse_dataset_elem_cache: dict[int, CSRDatasetElems] _batch_sampler: Sampler _collection_added: bool = False + _dtypes_homogeneous: bool = True def __init__( self, @@ -231,7 +232,7 @@ def __init__( self._to_torch = to_torch self._train_datasets = [] self._shapes = [] - self._dataset_elem_cache = {} + self._sparse_dataset_elem_cache = {} def __len__(self) -> int: return self._batch_sampler.n_batches(self.n_obs) @@ -480,6 +481,14 @@ def _add_dataset_unchecked( raise TypeError("var must be a pandas DataFrame") datasets = self._train_datasets + [dataset] check_var_shapes(datasets) + self._dtypes_homogeneous = self._datasets_share_dtype(datasets) + if self._train_datasets and not self._dtypes_homogeneous: + warn( + f"Adding dataset with dtype {dataset.dtype!r} that differs from the existing dataset dtype(s) " + f"(first dataset: {self._train_datasets[0].dtype!r}). 
Heterogeneous dtypes incur extra per-batch " + "allocation and dtype promotion in the loader; consider casting all datasets to a common dtype.", + stacklevel=2, + ) self._shapes = self._shapes + [dataset.shape] self._train_datasets = datasets if self._obs is not None: # obs exist @@ -537,6 +546,13 @@ def _requests_to_dataset_rows(self, requests: list[slice] | np.ndarray) -> Order result[ds] = global_index[order[gs:ge]] - starts[ds] return result, order + def _alloc(self, shape: tuple[int, ...], dtype: np.dtype, *, use_pinned: bool) -> np.ndarray: + if use_pinned: + import cupyx as cpx + + return cpx.empty_pinned(shape, dtype) + return np.empty(shape, dtype) + def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> CSRContainer | np.ndarray: """Preallocate a single contiguous output buffer covering all datasets and rows. @@ -550,17 +566,10 @@ def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> """ total_rows = sum(len(rows) for rows in dataset_index_to_rows.values()) - def _alloc(shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray: - if self._preload_to_gpu: - import cupyx as cpx - - return cpx.empty_pinned(shape, dtype) - return np.empty(shape, dtype) - if (is_backed := issubclass(self.dataset_type, ad.abc.CSRDataset)) or issubclass( self.dataset_type, sp.csr_array | sp.csr_matrix ): - datasets = self._dataset_elem_cache if is_backed else self._train_datasets + datasets = self._sparse_dataset_elem_cache if is_backed else self._train_datasets total_nnz = sum( int((datasets[idx].indptr[rows + 1] - datasets[idx].indptr[rows]).sum()) for idx, rows in dataset_index_to_rows.items() @@ -571,8 +580,8 @@ def _alloc(shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray: indptr_dtype = datasets[first_idx].indptr.dtype return CSRContainer( elems=( - _alloc((total_nnz,), data_dtype), - _alloc((total_nnz,), indices_dtype), + self._alloc((total_nnz,), data_dtype, use_pinned=self._preload_to_gpu), + self._alloc((total_nnz,), 
@staticmethod
def _datasets_share_dtype(datasets: list[BackingArray]) -> bool:
    """Whether all given dataset-like objects share the same dtype(s).

    Sparse inputs are compared on the pair ``(data dtype, indices dtype)``;
    everything else is compared on its single ``dtype``. Zero or one
    dataset is trivially homogeneous.

    Parameters
    ----------
    datasets
        The dataset-like objects to compare.

    Returns
    -------
    ``True`` when every dataset has the same dtype signature as the first one.
    """
    if len(datasets) <= 1:
        return True

    def dtypes_of(d):
        # Backed sparse datasets expose their arrays through ``group``.
        if isinstance(d, ad.abc.CSRDataset):
            return (d.group["data"].dtype, d.group["indices"].dtype)
        # Anything carrying both ``data`` and ``indices`` is treated as sparse.
        if hasattr(d, "data") and hasattr(d, "indices"):
            return (d.data.dtype, d.indices.dtype)
        return (d.dtype,)

    first = dtypes_of(datasets[0])
    return all(dtypes_of(d) == first for d in datasets[1:])


def _allocate_per_dataset_outs(
    self, dataset_index_to_rows: OrderedDict[int, np.ndarray]
) -> OrderedDict[int, CSRContainer | np.ndarray]:
    """Allocate one output buffer per dataset, each using that dataset's native dtype(s).

    Used when datasets have differing dtypes — the per-dataset buffers are concatenated
    into a final buffer of the promoted dtype by :meth:`_concatenate_outs`.
    Must be called after :meth:`_ensure_sparse_cache` for backed-sparse datasets.

    Parameters
    ----------
    dataset_index_to_rows
        Mapping from dataset index to the row indices requested from that dataset.

    Returns
    -------
    Mapping from dataset index to its buffer, in the iteration order of the input.
    These intermediate buffers are never pinned; only the final concatenated
    buffer may be.
    """
    is_backed_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset)
    is_sparse = is_backed_sparse or issubclass(self.dataset_type, sp.csr_array | sp.csr_matrix)
    outs: OrderedDict[int, CSRContainer | np.ndarray] = OrderedDict()
    if is_sparse:
        datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets
        for idx, rows in dataset_index_to_rows.items():
            ds = datasets[idx]
            # Number of stored values covered by the requested rows.
            nnz = int((ds.indptr[rows + 1] - ds.indptr[rows]).sum())
            outs[idx] = CSRContainer(
                elems=(
                    self._alloc((nnz,), ds.data.dtype, use_pinned=False),
                    self._alloc((nnz,), ds.indices.dtype, use_pinned=False),
                    # Smallest dtype able to hold this dataset's own nnz; it may be
                    # narrower than the dtype of the final, concatenated indptr.
                    self._alloc((len(rows) + 1,), np.min_scalar_type(nnz), use_pinned=False),
                ),
                shape=(len(rows), self.n_var),
                dtype=ds.data.dtype,
            )
    else:
        for idx, rows in dataset_index_to_rows.items():
            ds = self._train_datasets[idx]
            outs[idx] = self._alloc((len(rows), *ds.shape[1:]), ds.dtype, use_pinned=False)
    return outs


def _concatenate_outs(self, outs: OrderedDict[int, CSRContainer | np.ndarray]) -> CSRContainer | np.ndarray:
    """Concatenate per-dataset buffers into a single buffer with promoted dtype(s).

    Parameters
    ----------
    outs
        Per-dataset buffers, concatenated in iteration order. For sparse buffers
        each ``indptr`` is expected to be zero-based: its first entry is skipped
        and the remaining entries are shifted by the running nnz offset.

    Returns
    -------
    A single dense array, or a single ``CSRContainer``, whose data (and indices)
    dtype is the ``np.result_type`` of the inputs. The final buffers are pinned
    when ``self._preload_to_gpu`` is set.
    """
    values = list(outs.values())
    if isinstance(values[0], CSRContainer):
        data_dtype = np.result_type(*[o.elems[0].dtype for o in values])
        indices_dtype = np.result_type(*[o.elems[1].dtype for o in values])
        total_nnz = sum(o.elems[0].size for o in values)
        total_rows = sum(o.shape[0] for o in values)
        data = self._alloc((total_nnz,), data_dtype, use_pinned=self._preload_to_gpu)
        indices = self._alloc((total_nnz,), indices_dtype, use_pinned=self._preload_to_gpu)
        # NOTE(review): the single-buffer path (`_allocate_out`) keeps the source indptr
        # dtype, whereas this path uses the smallest dtype that fits `total_nnz` —
        # confirm downstream consumers accept both.
        indptr = self._alloc((total_rows + 1,), np.min_scalar_type(total_nnz), use_pinned=self._preload_to_gpu)
        indptr[0] = 0
        nnz_offset = 0
        row_offset = 0
        for o in values:
            n = o.elems[0].size
            r = o.shape[0]
            data[nnz_offset : nnz_offset + n] = o.elems[0]
            indices[nnz_offset : nnz_offset + n] = o.elems[1]
            # Widen BEFORE adding the offset. The per-dataset indptr may be narrower
            # than the destination (e.g. uint8 vs uint16); adding in the narrow dtype
            # would wrap around silently, or raise OverflowError under NumPy 2 when
            # the offset itself does not fit the narrow dtype.
            shifted = o.elems[2][1:].astype(indptr.dtype, copy=False) + nnz_offset
            indptr[row_offset + 1 : row_offset + r + 1] = shifted
            nnz_offset += n
            row_offset += r
        return CSRContainer(
            elems=(data, indices, indptr),
            shape=(total_rows, self.n_var),
            dtype=data_dtype,
        )
    dtype = np.result_type(*[o.dtype for o in values])
    total_rows = sum(o.shape[0] for o in values)
    out = self._alloc((total_rows, *values[0].shape[1:]), dtype, use_pinned=self._preload_to_gpu)
    offset = 0
    for o in values:
        # Assignment casts each block to the promoted dtype.
        out[offset : offset + o.shape[0]] = o
        offset += o.shape[0]
    return out
""" - if dataset_idx not in self._dataset_elem_cache: + if dataset_idx not in self._sparse_dataset_elem_cache: raise ValueError("Cache not prepared") - return self._dataset_elem_cache[dataset_idx] + return self._sparse_dataset_elem_cache[dataset_idx] @_fetch_data.register async def _fetch_data_csr_matrix( @@ -768,6 +861,27 @@ async def _index_datasets( if is_backed_sparse: await self._ensure_sparse_cache() + if not self._dtypes_homogeneous: + per_dataset_outs = self._allocate_per_dataset_outs(dataset_index_to_rows) + tasks = [ + self._fetch_data( + self._get_elem_from_cache(dataset_idx) if is_backed_sparse else self._train_datasets[dataset_idx], + rows, + per_dataset_outs[dataset_idx], + ) + for dataset_idx, rows in dataset_index_to_rows.items() + ] + await asyncio.gather(*tasks) + if is_sparse: + datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets + for dataset_idx, rows in dataset_index_to_rows.items(): + sub_out = per_dataset_outs[dataset_idx] + cached_indptr = datasets[dataset_idx].indptr + per_row_nnz = cached_indptr[rows + 1] - cached_indptr[rows] + sub_out.elems[2][0] = 0 + np.cumsum(per_row_nnz, out=sub_out.elems[2][1:]) + return self._concatenate_outs(per_dataset_outs) + out = self._allocate_out(dataset_index_to_rows) tasks = [] @@ -777,7 +891,7 @@ async def _index_datasets( for dataset_idx, rows in dataset_index_to_rows.items(): nrows = len(rows) if is_sparse: - datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets + datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets cached_indptr = datasets[dataset_idx].indptr nnz = int((cached_indptr[rows + 1] - cached_indptr[rows]).sum()) out_view: CSRContainer | np.ndarray = CSRContainer( @@ -805,7 +919,7 @@ async def _index_datasets( await asyncio.gather(*tasks) if is_sparse: - datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets + datasets = self._sparse_dataset_elem_cache if 
@pytest.fixture(scope="session", params=[False, True], ids=["same-dtype", "mixed-dtype"])
def maybe_mixed_dtype_collection(
    request, tmpdir_factory, adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]
) -> tuple[ad.AnnData, DatasetCollection, bool]:
    """Build a collection whose first dataset may have a different dtype from the rest.

    For the ``mixed-dtype`` parameter the first dataset's ``X`` is rewritten as
    ``f8`` and its ``sparse`` layer as ``int64``, so that loading exercises the
    dtype-promotion path in ``Loader._concatenate_outs``. For ``same-dtype``
    the collection is left as written.

    Returns ``(adata, collection, is_mixed)``, where ``adata`` is the outer
    concatenation of every dataset in the collection.
    """
    source_dir = adata_with_zarr_path_same_var_space[1]
    store_dirs = sorted(entry for entry in source_dir.iterdir() if entry.is_dir())
    target = Path(tmpdir_factory.mktemp("zarr_folder")) / "mixed_dtype_fixture.zarr"
    collection = DatasetCollection(target).add_adatas(
        store_dirs,
        n_obs_per_chunk=10,
        shard_size=20,
        dataset_size=60,
        shuffle_chunk_size=10,
    )
    is_mixed = bool(request.param)
    if is_mixed:
        with ad.settings.override(auto_shard_zarr_v3=True, zarr_write_format=3):
            head = next(iter(collection))
            # Replace the dense matrix with a float64 copy.
            widened_x = head["X"][...].astype("f8")
            del head["X"]
            ad.io.write_elem(head, "X", widened_x)
            # Replace the sparse layer with an int64 copy.
            widened_layer = ad.io.read_elem(head["layers"]["sparse"]).astype("int64")
            del head["layers"]["sparse"]
            ad.io.write_elem(head["layers"], "sparse", widened_layer)

        # Sanity-check that the rewrite really produced a heterogeneous collection.
        stores = list(collection)
        reference, others = stores[0], stores[1:]
        reference_x_dtype = reference["X"].dtype
        reference_sparse_dtype = reference["layers"]["sparse"]["data"].dtype
        x_differs = any(store["X"].dtype != reference_x_dtype for store in others)
        assert x_differs, "mixed-dtype fixture failed to produce differing X dtypes"
        sparse_differs = any(
            store["layers"]["sparse"]["data"].dtype != reference_sparse_dtype for store in others
        )
        assert sparse_differs, "mixed-dtype fixture failed to produce differing sparse layer dtypes"
    merged = ad.concat([ad.io.read_elem(store) for store in collection], join="outer")
    return merged, collection, is_mixed