Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ and this project adheres to [Semantic Versioning][].
- {attr}`annbatch.types.LoadRequest.splits` now index in **request order** -- position `j` is the `j`-th observation when the request's `chunks` are concatenated in the order given. Previously, `splits` had to index into the loader's internal dataset-grouped memory layout. The {class}`~annbatch.Loader` now remaps splits to that layout itself, so custom samplers must produce chunk-order splits and stop compensating for the dataset reordering.
- Deprecated `annbatch.types.LoadRequest.chunks` in favor of {attr}`annbatch.types.LoadRequest.requests`.

### Fixed
- Handling of different data types i.e., `float32` vs `float64` in the same {class}`~annbatch.Loader`

## [0.1.6]

Expand Down
6 changes: 6 additions & 0 deletions docs/detailed-walkthrough.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ collection = DatasetCollection("path/to/output/store.zarr").add_adatas(
First, you convert your existing `.h5ad` files into a zarr-backed anndata format.
In the process, the data gets shuffled and is distributed across several anndata files.
Shuffling is important to ensure model convergence, especially because of our contiguous data fetching scheme which is not perfectly random.
Shuffling also helps improve performance because it ensures uniform data types across your training dataset i.e., matrices all of `float64` type.
This allows efficient preallocation of pinned memory.
The output is a collection of sharded zarr anndata files, meant to reduce the burden on file systems of indexing.
See the [zarr docs on sharding][] for more information.
For performance considerations, see our dedicated docs page: {doc}`preshuffling`.
Expand Down Expand Up @@ -55,6 +57,10 @@ Using {mod}`zarr` on its own will not yield high performance for local filesyste
We have not tested remote data (i.e., using {func}`zarr.open` with a {class}`zarr.storage.ObjectStore`) but because we use {mod}`zarr`, this data loader should also work over cloud connections via relevant zarr stores.
Note that {doc}`zarrs-python <zarrs:index>` cannot be used with these sorts of non-local stores.

:::{important}
As mentioned above, ensuring uniform dtypes across your training collection is important for performance.
For in-memory data, this transformation should be trivial.

### User configurable sampling strategy

We support user-configurable sampling strategies like weighting or sampling by implementing the abstract {class}`annbatch.abc.Sampler`.
Expand Down
154 changes: 134 additions & 20 deletions src/annbatch/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,10 @@ class Loader[
_shapes: list[tuple[int, int]]
_preload_to_gpu: bool = True
_to_torch: bool = True
_dataset_elem_cache: dict[int, CSRDatasetElems]
_sparse_dataset_elem_cache: dict[int, CSRDatasetElems]
_batch_sampler: Sampler
_collection_added: bool = False
_dtypes_homogeneous: bool = True

def __init__(
self,
Expand Down Expand Up @@ -231,7 +232,7 @@ def __init__(
self._to_torch = to_torch
self._train_datasets = []
self._shapes = []
self._dataset_elem_cache = {}
self._sparse_dataset_elem_cache = {}

def __len__(self) -> int:
return self._batch_sampler.n_batches(self.n_obs)
Expand Down Expand Up @@ -480,6 +481,14 @@ def _add_dataset_unchecked(
raise TypeError("var must be a pandas DataFrame")
datasets = self._train_datasets + [dataset]
check_var_shapes(datasets)
self._dtypes_homogeneous = self._datasets_share_dtype(datasets)
if self._train_datasets and not self._dtypes_homogeneous:
warn(
f"Adding dataset with dtype {dataset.dtype!r} that differs from the existing dataset dtype(s) "
f"(first dataset: {self._train_datasets[0].dtype!r}). Heterogeneous dtypes incur extra per-batch "
"allocation and dtype promotion in the loader; consider casting all datasets to a common dtype.",
stacklevel=2,
)
self._shapes = self._shapes + [dataset.shape]
self._train_datasets = datasets
if self._obs is not None: # obs exist
Expand Down Expand Up @@ -537,6 +546,13 @@ def _requests_to_dataset_rows(self, requests: list[slice] | np.ndarray) -> Order
result[ds] = global_index[order[gs:ge]] - starts[ds]
return result, order

def _alloc(self, shape: tuple[int, ...], dtype: np.dtype, *, use_pinned: bool) -> np.ndarray:
if use_pinned:
import cupyx as cpx

return cpx.empty_pinned(shape, dtype)
return np.empty(shape, dtype)

def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) -> CSRContainer | np.ndarray:
"""Preallocate a single contiguous output buffer covering all datasets and rows.

Expand All @@ -550,17 +566,10 @@ def _allocate_out(self, dataset_index_to_rows: OrderedDict[int, np.ndarray]) ->
"""
total_rows = sum(len(rows) for rows in dataset_index_to_rows.values())

def _alloc(shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray:
if self._preload_to_gpu:
import cupyx as cpx

return cpx.empty_pinned(shape, dtype)
return np.empty(shape, dtype)

if (is_backed := issubclass(self.dataset_type, ad.abc.CSRDataset)) or issubclass(
self.dataset_type, sp.csr_array | sp.csr_matrix
):
datasets = self._dataset_elem_cache if is_backed else self._train_datasets
datasets = self._sparse_dataset_elem_cache if is_backed else self._train_datasets
total_nnz = sum(
int((datasets[idx].indptr[rows + 1] - datasets[idx].indptr[rows]).sum())
for idx, rows in dataset_index_to_rows.items()
Expand All @@ -571,8 +580,8 @@ def _alloc(shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray:
indptr_dtype = datasets[first_idx].indptr.dtype
return CSRContainer(
elems=(
_alloc((total_nnz,), data_dtype),
_alloc((total_nnz,), indices_dtype),
self._alloc((total_nnz,), data_dtype, use_pinned=self._preload_to_gpu),
self._alloc((total_nnz,), indices_dtype, use_pinned=self._preload_to_gpu),
np.empty(total_rows + 1, dtype=indptr_dtype),
),
shape=(total_rows, self.n_var),
Expand All @@ -582,7 +591,91 @@ def _alloc(shape: tuple[int, ...], dtype: np.dtype) -> np.ndarray:
first_idx = next(iter(dataset_index_to_rows))
dtype = self._train_datasets[first_idx].dtype
shape_res = self._train_datasets[first_idx].shape[1:]
return _alloc((total_rows, *shape_res), dtype)
return self._alloc((total_rows, *shape_res), dtype, use_pinned=self._preload_to_gpu)

@staticmethod
def _datasets_share_dtype(datasets: list[BackingArray]) -> bool:
"""Whether all given dataset-like objects share the same dtype(s)."""
if len(datasets) <= 1:
return True

def dtypes_of(d):
if isinstance(d, ad.abc.CSRDataset):
return (d.group["data"].dtype, d.group["indices"].dtype)
if hasattr(d, "data") and hasattr(d, "indices"):
return (d.data.dtype, d.indices.dtype)
return (d.dtype,)

first = dtypes_of(datasets[0])
return all(dtypes_of(d) == first for d in datasets[1:])

def _allocate_per_dataset_outs(
self, dataset_index_to_rows: OrderedDict[int, np.ndarray]
) -> OrderedDict[int, CSRContainer | np.ndarray]:
"""Allocate one output buffer per dataset, each using that dataset's native dtype(s).

Used when datasets have differing dtypes — the per-dataset buffers are concatenated
into a final buffer of the promoted dtype by :meth:`_concatenate_outs`.
Must be called after :meth:`_ensure_sparse_cache` for backed-sparse datasets.
"""
is_backed_sparse = issubclass(self.dataset_type, ad.abc.CSRDataset)
is_sparse = is_backed_sparse or issubclass(self.dataset_type, sp.csr_array | sp.csr_matrix)
outs: OrderedDict[int, CSRContainer | np.ndarray] = OrderedDict()
if is_sparse:
datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets
for idx, rows in dataset_index_to_rows.items():
ds = datasets[idx]
nnz = int((ds.indptr[rows + 1] - ds.indptr[rows]).sum())
outs[idx] = CSRContainer(
elems=(
self._alloc((nnz,), ds.data.dtype, use_pinned=False),
self._alloc((nnz,), ds.indices.dtype, use_pinned=False),
self._alloc((len(rows) + 1,), np.min_scalar_type(nnz), use_pinned=False),
),
shape=(len(rows), self.n_var),
dtype=ds.data.dtype,
)
else:
for idx, rows in dataset_index_to_rows.items():
ds = self._train_datasets[idx]
outs[idx] = self._alloc((len(rows), *ds.shape[1:]), ds.dtype, use_pinned=False)
return outs

def _concatenate_outs(self, outs: OrderedDict[int, CSRContainer | np.ndarray]) -> CSRContainer | np.ndarray:
"""Concatenate per-dataset buffers into a single buffer with promoted dtype(s)."""
values = list(outs.values())
if isinstance(values[0], CSRContainer):
data_dtype = np.result_type(*[o.elems[0].dtype for o in values])
indices_dtype = np.result_type(*[o.elems[1].dtype for o in values])
total_nnz = sum(o.elems[0].size for o in values)
total_rows = sum(o.shape[0] for o in values)
data = self._alloc((total_nnz,), data_dtype, use_pinned=self._preload_to_gpu)
indices = self._alloc((total_nnz,), indices_dtype, use_pinned=self._preload_to_gpu)
indptr = self._alloc((total_rows + 1,), np.min_scalar_type(total_nnz), use_pinned=self._preload_to_gpu)
indptr[0] = 0
nnz_offset = 0
row_offset = 0
for o in values:
n = o.elems[0].size
r = o.shape[0]
data[nnz_offset : nnz_offset + n] = o.elems[0]
indices[nnz_offset : nnz_offset + n] = o.elems[1]
indptr[row_offset + 1 : row_offset + r + 1] = o.elems[2][1:] + nnz_offset
nnz_offset += n
row_offset += r
return CSRContainer(
elems=(data, indices, indptr),
shape=(total_rows, self.n_var),
dtype=data_dtype,
)
dtype = np.result_type(*[o.dtype for o in values])
total_rows = sum(o.shape[0] for o in values)
out = self._alloc((total_rows, *values[0].shape[1:]), dtype, use_pinned=self._preload_to_gpu)
offset = 0
for o in values:
out[offset : offset + o.shape[0]] = o
offset += o.shape[0]
return out

@singledispatchmethod
async def _fetch_data(
Expand Down Expand Up @@ -658,16 +751,16 @@ async def _create_sparse_elems(self, idx: int) -> CSRDatasetElems:

async def _ensure_sparse_cache(self) -> None:
"""Build up the cache of datasets i.e., in-memory indptr, and backed indices and data."""
arr_idxs = [idx for idx in range(len(self._train_datasets)) if idx not in self._dataset_elem_cache]
arr_idxs = [idx for idx in range(len(self._train_datasets)) if idx not in self._sparse_dataset_elem_cache]
all_elems: list[CSRDatasetElems] = await asyncio.gather(
*(
self._create_sparse_elems(idx)
for idx in range(len(self._train_datasets))
if idx not in self._dataset_elem_cache
if idx not in self._sparse_dataset_elem_cache
)
)
for idx, elems in zip(arr_idxs, all_elems, strict=True):
self._dataset_elem_cache[idx] = elems
self._sparse_dataset_elem_cache[idx] = elems

def _get_elem_from_cache(self, dataset_idx: int) -> CSRDatasetElems | ZarrArray:
"""Return the arrays (zarr or otherwise) needed to represent on-disk data at a given index.
Expand All @@ -681,9 +774,9 @@ def _get_elem_from_cache(self, dataset_idx: int) -> CSRDatasetElems | ZarrArray:
-------
The arrays representing the sparse data.
"""
if dataset_idx not in self._dataset_elem_cache:
if dataset_idx not in self._sparse_dataset_elem_cache:
raise ValueError("Cache not prepared")
return self._dataset_elem_cache[dataset_idx]
return self._sparse_dataset_elem_cache[dataset_idx]

@_fetch_data.register
async def _fetch_data_csr_matrix(
Expand Down Expand Up @@ -768,6 +861,27 @@ async def _index_datasets(
if is_backed_sparse:
await self._ensure_sparse_cache()

if not self._dtypes_homogeneous:
per_dataset_outs = self._allocate_per_dataset_outs(dataset_index_to_rows)
tasks = [
self._fetch_data(
self._get_elem_from_cache(dataset_idx) if is_backed_sparse else self._train_datasets[dataset_idx],
rows,
per_dataset_outs[dataset_idx],
)
for dataset_idx, rows in dataset_index_to_rows.items()
]
await asyncio.gather(*tasks)
if is_sparse:
datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets
for dataset_idx, rows in dataset_index_to_rows.items():
sub_out = per_dataset_outs[dataset_idx]
cached_indptr = datasets[dataset_idx].indptr
per_row_nnz = cached_indptr[rows + 1] - cached_indptr[rows]
sub_out.elems[2][0] = 0
np.cumsum(per_row_nnz, out=sub_out.elems[2][1:])
return self._concatenate_outs(per_dataset_outs)

out = self._allocate_out(dataset_index_to_rows)

tasks = []
Expand All @@ -777,7 +891,7 @@ async def _index_datasets(
for dataset_idx, rows in dataset_index_to_rows.items():
nrows = len(rows)
if is_sparse:
datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets
datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets
cached_indptr = datasets[dataset_idx].indptr
nnz = int((cached_indptr[rows + 1] - cached_indptr[rows]).sum())
out_view: CSRContainer | np.ndarray = CSRContainer(
Expand Down Expand Up @@ -805,7 +919,7 @@ async def _index_datasets(
await asyncio.gather(*tasks)

if is_sparse:
datasets = self._dataset_elem_cache if is_backed_sparse else self._train_datasets
datasets = self._sparse_dataset_elem_cache if is_backed_sparse else self._train_datasets
running_nnz = 0
row_pos = 0
out.elems[2][0] = 0
Expand Down
39 changes: 39 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,42 @@ def simple_collection(
shuffle_chunk_size=10,
)
return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection


@pytest.fixture(scope="session", params=[False, True], ids=["same-dtype", "mixed-dtype"])
def maybe_mixed_dtype_collection(
request, tmpdir_factory, adata_with_zarr_path_same_var_space: tuple[ad.AnnData, Path]
) -> tuple[ad.AnnData, DatasetCollection, bool]:
"""Like ``simple_collection``, but optionally rewrites the first dataset's
X (and sparse layer) with a different dtype to exercise the dtype-promotion
code path in ``Loader._concatenate_outs``. Returns ``(adata, collection, is_mixed)``."""
zarr_stores = sorted(f for f in adata_with_zarr_path_same_var_space[1].iterdir() if f.is_dir())
output_path = Path(tmpdir_factory.mktemp("zarr_folder")) / "mixed_dtype_fixture.zarr"
collection = DatasetCollection(output_path).add_adatas(
zarr_stores,
n_obs_per_chunk=10,
shard_size=20,
dataset_size=60,
shuffle_chunk_size=10,
)
is_mixed = bool(request.param)
if is_mixed:
with ad.settings.override(auto_shard_zarr_v3=True, zarr_write_format=3):
first = next(iter(collection))
new_X = first["X"][...].astype("f8")
del first["X"]
ad.io.write_elem(first, "X", new_X)
sparse_layer = ad.io.read_elem(first["layers"]["sparse"]).astype("int64")
del first["layers"]["sparse"]
ad.io.write_elem(first["layers"], "sparse", sparse_layer)

datasets = list(collection)
first_X_dtype = datasets[0]["X"].dtype
first_sparse_dtype = datasets[0]["layers"]["sparse"]["data"].dtype
assert any(ds["X"].dtype != first_X_dtype for ds in datasets[1:]), (
"mixed-dtype fixture failed to produce differing X dtypes"
)
assert any(ds["layers"]["sparse"]["data"].dtype != first_sparse_dtype for ds in datasets[1:]), (
"mixed-dtype fixture failed to produce differing sparse layer dtypes"
)
return ad.concat([ad.io.read_elem(ds) for ds in collection], join="outer"), collection, is_mixed
10 changes: 7 additions & 3 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def concat(datas: list[Data | ad.AnnData]) -> ListData | list[ad.AnnData]:
],
)
def test_store_load_dataset(
simple_collection: tuple[ad.AnnData, DatasetCollection],
maybe_mixed_dtype_collection: tuple[ad.AnnData, DatasetCollection, bool],
*,
shuffle: bool,
gen_loader,
Expand All @@ -217,10 +217,14 @@ def test_store_load_dataset(
3. All samples from the dataset are processed
4. If the dataset is not shuffled, it returns the correct data
"""
loader: Loader = gen_loader(simple_collection[1], shuffle, use_zarrs)
adata, collection, is_mixed = maybe_mixed_dtype_collection
if is_mixed:
with pytest.warns(UserWarning, match="Adding dataset with dtype"):
loader: Loader = gen_loader(collection, shuffle, use_zarrs)
else:
loader: Loader = gen_loader(collection, shuffle, use_zarrs)
if use_zarrs and loader.dataset_type in {np.ndarray, sp.csr_matrix, sp.csr_array}:
pytest.skip("No need to run zarrs with in-memory")
adata = simple_collection[0]
is_dense = loader.dataset_type in {zarr.Array, np.ndarray}
n_elems = 0
batches = []
Expand Down
Loading