From 2fee9df0dcd520b7eebfb970b0ed052dacf24aee Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Sun, 18 Jan 2026 18:17:15 -0800 Subject: [PATCH 1/5] running scdl on remote --- .../src/bionemo/scdl/io/chunk_sampler.py | 79 +++++++++++ .../bionemo/scdl/io/remote_chunk_loader.py | 128 ++++++++++++++++++ .../scdl/io/single_cell_memmap_dataset.py | 127 +++++++++++++---- .../src/bionemo/scdl/util/partition_scdl.py | 14 +- .../tests/bionemo/scdl/conftest.py | 19 ++- 5 files changed, 333 insertions(+), 34 deletions(-) create mode 100644 sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py create mode 100644 sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py new file mode 100644 index 0000000000..05976efe9b --- /dev/null +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Chunk-aware sampler for efficient iteration over chunked SCDL datasets.""" + +import random +from typing import Iterator, Optional + +from torch.utils.data import Sampler + +from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset + + +class ChunkAwareSampler(Sampler[int]): + """Sampler that iterates by chunks for efficient access patterns. + + This sampler ensures all rows from a chunk are accessed together before + moving to the next chunk. This is optimal for: + - Local: memory locality (chunk data stays in cache) + - Remote: prefetching (download chunk once, use all rows) + + Args: + dataset: A chunked SingleCellMemMapDataset. + shuffle_chunks: Whether to shuffle chunk order each epoch. + shuffle_within_chunk: Whether to shuffle rows within each chunk. + seed: Random seed for reproducibility. + """ + + def __init__( + self, + dataset: SingleCellMemMapDataset, + shuffle_chunks: bool = True, + shuffle_within_chunk: bool = True, + seed: Optional[int] = None, + ): + """Initialize the chunk aware sampler.""" + if not dataset._is_chunked: + raise ValueError("ChunkAwareSampler requires a chunked dataset") + + self.dataset = dataset + self.shuffle_chunks = shuffle_chunks + self.shuffle_within_chunk = shuffle_within_chunk + self.rng = random.Random(seed) + + self.chunked_info = dataset.header.chunked_info + + def __iter__(self) -> Iterator[int]: + """Yield row indices, grouped by chunk.""" + chunk_ids = list(range(self.chunked_info.num_chunks)) + + if self.shuffle_chunks: + self.rng.shuffle(chunk_ids) + + for chunk_id in chunk_ids: + start = chunk_id * self.chunked_info.chunk_size + end = min(start + self.chunked_info.chunk_size, self.chunked_info.total_rows) + + if self.shuffle_within_chunk: + row_indices = list(range(start, end)) + self.rng.shuffle(row_indices) + yield from row_indices + else: + yield from range(start, end) # Lazy, no list + + def __len__(self) -> int: + """Return total number of samples.""" + return self.chunked_info.total_rows diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py new file mode 100644 index 0000000000..92ffe2c1d1 --- /dev/null +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Remote chunk loader with LRU caching for chunked SCDL datasets. + +NOTE: This is a simple POC implementation. For production multi-worker/multi-node use: +- Add file locking for shared cache (filelock) +- Add reference counting to prevent evicting in-use chunks +- Use DistributedChunkSampler to shard chunks across nodes +""" + +import shutil +import tempfile +from collections import OrderedDict +from pathlib import Path +from typing import Optional + +import fsspec + + +class RemoteChunkLoader: + """Downloads and caches chunks from remote storage with LRU eviction. + + Args: + remote_path: Remote path (s3://bucket/path, gs://bucket/path, etc.) + cache_dir: Local directory for caching chunks. If None, uses temp directory. + max_cached_chunks: Maximum number of chunks to keep in cache. + """ + + def __init__( + self, + remote_path: str, + cache_dir: Optional[Path] = None, + max_cached_chunks: int = 2, + ): + """Initialize the remote chunk loader.""" + self.remote_path = remote_path.rstrip("/") + self.cache_dir = Path(cache_dir) if cache_dir else Path(tempfile.mkdtemp(prefix="scdl_cache_")) + self.max_cached_chunks = max_cached_chunks + self._cache: OrderedDict[int, Path] = OrderedDict() + + # Initialize filesystem + protocol = remote_path.split("://")[0] if "://" in remote_path else "file" + self._fs = fsspec.filesystem(protocol) + + # Ensure cache directory exists + self.cache_dir.mkdir(parents=True, exist_ok=True) + + def get_chunk(self, chunk_id: int) -> Path: + """Get local path to chunk, downloading if needed. + + Args: + chunk_id: The chunk index to retrieve. + + Returns: + Local path to the chunk directory. + """ + if chunk_id in self._cache: + self._cache.move_to_end(chunk_id) + return self._cache[chunk_id] + + # Evict oldest chunks if at capacity + while len(self._cache) >= self.max_cached_chunks: + old_id, old_path = self._cache.popitem(last=False) + shutil.rmtree(old_path, ignore_errors=True) + + # Download chunk + local_path = self._download_chunk(chunk_id) + self._cache[chunk_id] = local_path + return local_path + + def _download_chunk(self, chunk_id: int) -> Path: + """Download a chunk from remote storage.""" + chunk_name = f"chunk_{chunk_id:05d}" + remote_chunk = f"{self.remote_path}/{chunk_name}" + local_chunk = self.cache_dir / chunk_name + + local_chunk.mkdir(parents=True, exist_ok=True) + + # Download all files in chunk directory + for remote_file in self._fs.ls(remote_chunk): + fname = Path(remote_file).name + self._fs.get(remote_file, str(local_chunk / fname)) + + return local_chunk + + def get_metadata(self) -> Path: + """Download and return path to metadata files (header, features, etc.).""" + metadata_dir = self.cache_dir / "_metadata" + if metadata_dir.exists(): + return metadata_dir + + metadata_dir.mkdir(parents=True, exist_ok=True) + + # Download header and feature indices + for name in ["header.json", "version.json", "metadata.json"]: + remote_file = f"{self.remote_path}/{name}" + if self._fs.exists(remote_file): + self._fs.get(remote_file, str(metadata_dir / name)) + + # Download feature directories + for name in ["var_features", "obs_features"]: + remote_dir = f"{self.remote_path}/{name}" + if self._fs.exists(remote_dir): + local_dir = metadata_dir / name + self._fs.get(remote_dir, str(local_dir), recursive=True) + + return metadata_dir + + def cleanup(self): + """Delete all cached data.""" + shutil.rmtree(self.cache_dir, ignore_errors=True) + + def __del__(self): + """Cleanup on garbage collection.""" + self.cleanup() diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py index 89fa3f9d61..1e8ddf4458 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py @@ -18,6 +18,7 @@ import logging import os import shutil +import tempfile import warnings from pathlib import Path from typing import Dict, List, Optional, Tuple, Union @@ -41,6 +42,7 @@ determine_dtype, smallest_uint_dtype, ) +from bionemo.scdl.util.partition_scdl import partition_scdl from bionemo.scdl.util.scdl_constants import FLOAT_ORDER, INT_ORDER, FileNames, Mode, NeighborSamplingStrategy @@ -128,6 +130,9 @@ def __init__( self.data_path: str = data_path self.header: SCDLHeader = None self.mode: Mode = mode + self._is_chunked: bool = False + self._chunks: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = [] + self._chunk_loader = None # For remote chunked datasets self.paginated_load_cutoff = paginated_load_cutoff self.load_block_row_size = load_block_row_size self.var_feature_index_name = var_feature_index_name @@ -260,6 +265,19 @@ def version(self) -> str: """ return self._version + @classmethod + def from_remote(cls, remote_path: str, cache_dir: Optional[str] = None, max_cached_chunks: int = 2): + """Load a chunked dataset from remote storage (S3, GCS, HTTP).""" + from bionemo.scdl.io.remote_chunk_loader import RemoteChunkLoader + + loader = RemoteChunkLoader(remote_path, Path(cache_dir) if cache_dir else None, max_cached_chunks) + metadata_path = loader.get_metadata() + ds = cls.__new__(cls) + ds._chunk_loader = loader + ds.data_path = remote_path + ds.load(str(metadata_path)) + return ds + def _extract_neighbor_data(self, adata) -> bool: """Extracts neighbor data from AnnData.obsp object and saves to memmap files. @@ -436,10 +454,19 @@ def get_row( List[np.ndarray]: optional, corresponding variable (column) features. List[np.ndarray]: optional, corresponding observed (row) features. """ - start = self.row_index[index] - end = self.row_index[index + 1] - values = self.data[start:end] - columns = self.col_index[start:end] + if self._is_chunked: + chunk_id, local_idx = self.header.chunked_info.get_chunk_for_row(index) + if self._chunk_loader: + data, rowptr, colptr = self._load_chunk_from_path(self._chunk_loader.get_chunk(chunk_id)) + else: + data, rowptr, colptr = self._chunks[chunk_id] + start, end = rowptr[local_idx], rowptr[local_idx + 1] + values, columns = data[start:end], colptr[start:end] + else: + start = self.row_index[index] + end = self.row_index[index + 1] + values = self.data[start:end] + columns = self.col_index[start:end] ret = (values, columns) var_features = ( self._var_feature_index.lookup(index, select_features=var_feature_names)[0] @@ -685,37 +712,50 @@ def load(self, stored_path: str) -> None: raise ValueError(f"Array name {FileNames[array_info.name].value} not found in dtypes") self.dtypes[FileNames[array_info.name].value] = array_info.dtype.numpy_dtype_string - # Metadata is required, so we must check if it exists and fail if not. - if not os.path.exists(f"{self.data_path}/{FileNames.METADATA.value}"): - raise FileNotFoundError( - f"Error: the metadata file {self.data_path}/{FileNames.METADATA.value} does not exist." - ) - - with open(f"{self.data_path}/{FileNames.METADATA.value}", Mode.READ_APPEND.value) as mfi: - self.metadata = json.load(mfi) + # Load metadata if exists + metadata_path = f"{self.data_path}/{FileNames.METADATA.value}" + if os.path.exists(metadata_path): + with open(metadata_path, Mode.READ_APPEND.value) as mfi: + self.metadata = json.load(mfi) + # Load feature indices if os.path.exists(f"{self.data_path}/{FileNames.VAR_FEATURES.value}"): self._var_feature_index = VariableFeatureIndex.load(f"{self.data_path}/{FileNames.VAR_FEATURES.value}") - elif os.path.exists( - f"{self.data_path}/{FileNames.FEATURES.value}" - ): # Backward compatibility with old features file + elif os.path.exists(f"{self.data_path}/{FileNames.FEATURES.value}"): self._var_feature_index = VariableFeatureIndex.load(f"{self.data_path}/{FileNames.FEATURES.value}") if os.path.exists(f"{self.data_path}/{FileNames.OBS_FEATURES.value}"): self._obs_feature_index = ObservedFeatureIndex.load(f"{self.data_path}/{FileNames.OBS_FEATURES.value}") - # mmap the existing arrays - self.data = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"] - ) - self.row_index = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"] - ) - self.col_index = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"] - ) - # Load neighbor data - if self.load_neighbors: - self._load_neighbor_memmaps() + # Load data arrays - chunked vs monolithic + if self.header is not None and self.header.backend == Backend.CHUNKED_MEMMAP_V0: + self._is_chunked = True + self._load_chunk_memmaps() + else: + self.data = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"] + ) + self.row_index = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"] + ) + self.col_index = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"] + ) + if self.load_neighbors: + self._load_neighbor_memmaps() + + def _load_chunk_memmaps(self) -> None: + """Preload all chunk memmaps (lazy - just file handles, no RAM).""" + for chunk_id in range(self.header.chunked_info.num_chunks): + chunk_path = Path(self.data_path) / f"chunk_{chunk_id:05d}" + self._chunks.append(self._load_chunk_from_path(chunk_path)) + + def _load_chunk_from_path(self, chunk_path: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Load memmaps for a single chunk directory.""" + return ( + np.memmap(chunk_path / FileNames.DATA.value, dtype=self.dtypes[FileNames.DATA.value], mode="r"), + np.memmap(chunk_path / FileNames.ROWPTR.value, dtype=self.dtypes[FileNames.ROWPTR.value], mode="r"), + np.memmap(chunk_path / FileNames.COLPTR.value, dtype=self.dtypes[FileNames.COLPTR.value], mode="r"), + ) def _write_metadata(self) -> None: with open(f"{self.data_path}/{FileNames.METADATA.value}", f"{Mode.CREATE.value}") as mfi: @@ -1218,6 +1258,8 @@ def number_of_rows(self) -> int: ValueError if the length of the number of rows in the feature index does not correspond to the number of stored rows. """ + if self._is_chunked: + return self.header.chunked_info.total_rows if len(self._var_feature_index) > 0 and self._var_feature_index.number_of_rows() != self.row_index.size - 1: raise ValueError( f"""The number of rows in the feature index {self._var_feature_index.number_of_rows()} @@ -1445,3 +1487,32 @@ def concat( mode=Mode.READ_APPEND.value, ) self.save() + + def to_chunked( + self, output_path: Optional[str] = None, chunk_size: int = 100_000, delete_original: bool = False + ) -> "SingleCellMemMapDataset": + """Convert this dataset to a chunked format for efficient remote access. + + Args: + output_path: Path where the chunked dataset will be created. If None, replaces in-place. + chunk_size: Number of rows per chunk (default: 100,000). + delete_original: If True and output_path is set, delete the original after conversion. + + Returns: + A new SingleCellMemMapDataset instance pointing to the chunked data. + """ + if self._is_chunked: + raise ValueError("Dataset is already chunked") + + src = Path(self.data_path) + if output_path is None: + # In-place: partition to temp, then swap + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) / "chunked" + partition_scdl(src, tmp_path, chunk_size=chunk_size) + shutil.rmtree(src) + shutil.move(str(tmp_path), str(src)) + return SingleCellMemMapDataset(str(src)) + + partition_scdl(src, Path(output_path), chunk_size=chunk_size, delete_original=delete_original) + return SingleCellMemMapDataset(output_path) diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py index 1affa2a596..e1f1578858 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py @@ -20,7 +20,6 @@ import numpy as np -from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset from bionemo.scdl.schema.header import ChunkedInfo, SCDLHeader from bionemo.scdl.util.scdl_constants import Backend, FileNames @@ -29,8 +28,11 @@ def partition_scdl( input_path: Path, output_path: Path, chunk_size: int = 100_000, + delete_original: bool = False, ) -> SCDLHeader: """Partition an SCDL dataset into chunks.""" + from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset + input_path, output_path = Path(input_path), Path(output_path) if not input_path.exists(): @@ -44,7 +46,11 @@ def partition_scdl( source_ds = SingleCellMemMapDataset(str(input_path)) total_rows = len(source_ds) rowptr = source_ds.row_index - num_chunks = (total_rows + chunk_size - 1) // chunk_size + if chunk_size <= 0: + raise ValueError(f"Chunk size must be greater than 0, got {chunk_size}") + if total_rows <= 0: + raise ValueError(f"Total rows must be greater than 0, got {total_rows}") + num_chunks = max(1, (total_rows + chunk_size - 1) // chunk_size) # Create chunks for chunk_id in range(num_chunks): @@ -78,4 +84,8 @@ def partition_scdl( header.chunked_info = ChunkedInfo(chunk_size=chunk_size, num_chunks=num_chunks, total_rows=total_rows) header.save(str(output_path / FileNames.HEADER.value)) + if delete_original: + del source_ds # Release memmap handles + shutil.rmtree(input_path) + return header diff --git a/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py b/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py index 3b8e934471..7152d7e53f 100644 --- a/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py +++ b/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py @@ -199,13 +199,24 @@ def _make(tmp_path): @pytest.fixture def make_h5ad_with_raw(make_random_csr): - """Factory to create an h5ad with uniquely randomized data for the fields .raw.X and .X""" + """Factory to create an h5ad with uniquely randomized data for .raw.X, .X, obs, and var.""" def _make(tmp_path): - X = make_random_csr(total_nnz=100, n_cols=50, seed=42) - X_raw = make_random_csr(total_nnz=100, n_cols=50, seed=43) + n_rows, n_cols = 100, 50 + X = make_random_csr(total_nnz=n_rows, n_cols=n_cols, seed=42) + X_raw = make_random_csr(total_nnz=n_rows, n_cols=n_cols, seed=43) + + obs = pd.DataFrame( + {"cell_type": [f"type_{i % 3}" for i in range(n_rows)]}, + index=[f"cell_{i}" for i in range(n_rows)], + ) + var = pd.DataFrame( + {"gene_name": [f"gene_{i}" for i in range(n_cols)]}, + index=[f"ENSG{i:08d}" for i in range(n_cols)], + ) + h = tmp_path / "var.h5ad" - ad.AnnData(X=X, var=pd.DataFrame(index=np.arange(X.shape[1])), raw={"X": X_raw}).write_h5ad(h) + ad.AnnData(X=X, obs=obs, var=var, raw={"X": X_raw}).write_h5ad(h) return h return _make From 69196ef82dd93a48801c48da0ba8d2dbe64e7b33 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Sun, 18 Jan 2026 19:09:25 -0800 Subject: [PATCH 2/5] remote dataloader works --- .../src/bionemo/scdl/io/chunk_sampler.py | 50 ++++++--- .../bionemo/scdl/io/remote_chunk_loader.py | 26 +++-- .../scdl/io/single_cell_memmap_dataset.py | 43 ++++++-- .../bionemo-scdl/test_remote_loading.py | 102 ++++++++++++++++++ 4 files changed, 191 insertions(+), 30 deletions(-) create mode 100644 sub-packages/bionemo-scdl/test_remote_loading.py diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py index 05976efe9b..db4746137c 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py @@ -16,6 +16,7 @@ """Chunk-aware sampler for efficient iteration over chunked SCDL datasets.""" import random +import warnings from typing import Iterator, Optional from torch.utils.data import Sampler @@ -26,15 +27,16 @@ class ChunkAwareSampler(Sampler[int]): """Sampler that iterates by chunks for efficient access patterns. - This sampler ensures all rows from a chunk are accessed together before - moving to the next chunk. This is optimal for: + This sampler ensures all rows from a chunk window are accessed together + before moving to the next window. This is optimal for: - Local: memory locality (chunk data stays in cache) - - Remote: prefetching (download chunk once, use all rows) + - Remote: prefetching (download chunks once, use all rows) Args: dataset: A chunked SingleCellMemMapDataset. shuffle_chunks: Whether to shuffle chunk order each epoch. - shuffle_within_chunk: Whether to shuffle rows within each chunk. + shuffle_within_window: Whether to shuffle rows within each chunk window. + chunks_per_window: Number of chunks to load together (more = better randomness). seed: Random seed for reproducibility. """ @@ -42,7 +44,8 @@ def __init__( self, dataset: SingleCellMemMapDataset, shuffle_chunks: bool = True, - shuffle_within_chunk: bool = True, + shuffle_within_window: bool = True, + chunks_per_window: int = 1, seed: Optional[int] = None, ): """Initialize the chunk aware sampler.""" @@ -51,28 +54,41 @@ def __init__( self.dataset = dataset self.shuffle_chunks = shuffle_chunks - self.shuffle_within_chunk = shuffle_within_chunk + self.shuffle_within_window = shuffle_within_window + self.chunks_per_window = max(1, chunks_per_window) self.rng = random.Random(seed) - self.chunked_info = dataset.header.chunked_info + # Warn if chunks_per_window exceeds cache size (causes thrashing) + if dataset._chunk_loader and chunks_per_window > dataset._chunk_loader.max_cached_chunks: + warnings.warn( + f"chunks_per_window ({chunks_per_window}) > max_cached_chunks " + f"({dataset._chunk_loader.max_cached_chunks}). This causes cache thrashing. " + f"Increase max_cached_chunks or decrease chunks_per_window." + ) + def __iter__(self) -> Iterator[int]: - """Yield row indices, grouped by chunk.""" + """Yield row indices, grouped by chunk window.""" chunk_ids = list(range(self.chunked_info.num_chunks)) if self.shuffle_chunks: self.rng.shuffle(chunk_ids) - for chunk_id in chunk_ids: - start = chunk_id * self.chunked_info.chunk_size - end = min(start + self.chunked_info.chunk_size, self.chunked_info.total_rows) + # Process in windows of N chunks + for i in range(0, len(chunk_ids), self.chunks_per_window): + window_chunks = chunk_ids[i : i + self.chunks_per_window] + + # Gather all indices from this window + all_indices = [] + for chunk_id in window_chunks: + start = chunk_id * self.chunked_info.chunk_size + end = min(start + self.chunked_info.chunk_size, self.chunked_info.total_rows) + all_indices.extend(range(start, end)) + + if self.shuffle_within_window: + self.rng.shuffle(all_indices) - if self.shuffle_within_chunk: - row_indices = list(range(start, end)) - self.rng.shuffle(row_indices) - yield from row_indices - else: - yield from range(start, end) # Lazy, no list + yield from all_indices def __len__(self) -> int: """Return total number of samples.""" diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py index 92ffe2c1d1..b98b7ce856 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py @@ -37,6 +37,7 @@ class RemoteChunkLoader: remote_path: Remote path (s3://bucket/path, gs://bucket/path, etc.) cache_dir: Local directory for caching chunks. If None, uses temp directory. max_cached_chunks: Maximum number of chunks to keep in cache. + storage_options: Optional dict of options passed to fsspec (e.g., endpoint_url for S3). """ def __init__( @@ -44,6 +45,7 @@ def __init__( remote_path: str, cache_dir: Optional[Path] = None, max_cached_chunks: int = 2, + storage_options: Optional[dict] = None, ): """Initialize the remote chunk loader.""" self.remote_path = remote_path.rstrip("/") @@ -51,9 +53,9 @@ def __init__( self.max_cached_chunks = max_cached_chunks self._cache: OrderedDict[int, Path] = OrderedDict() - # Initialize filesystem + # Initialize filesystem with optional storage options protocol = remote_path.split("://")[0] if "://" in remote_path else "file" - self._fs = fsspec.filesystem(protocol) + self._fs = fsspec.filesystem(protocol, **(storage_options or {})) # Ensure cache directory exists self.cache_dir.mkdir(parents=True, exist_ok=True) @@ -96,6 +98,18 @@ def _download_chunk(self, chunk_id: int) -> Path: return local_chunk + def _remote_exists(self, remote_path: str) -> bool: + """Check if a remote path exists (uses ls instead of exists for compatibility).""" + try: + # Use ls instead of exists() because some S3-compatible storage + # doesn't support HeadObject which exists() relies on + parent = "/".join(remote_path.rsplit("/", 1)[:-1]) + name = remote_path.rsplit("/", 1)[-1] + files = self._fs.ls(parent, detail=False) + return any(f.endswith(name) for f in files) + except Exception: + return False + def get_metadata(self) -> Path: """Download and return path to metadata files (header, features, etc.).""" metadata_dir = self.cache_dir / "_metadata" @@ -104,16 +118,16 @@ def get_metadata(self) -> Path: metadata_dir.mkdir(parents=True, exist_ok=True) - # Download header and feature indices - for name in ["header.json", "version.json", "metadata.json"]: + # Download header and feature indices (header.sch is the SCDL header format) + for name in ["header.sch", "version.json", "metadata.json"]: remote_file = f"{self.remote_path}/{name}" - if self._fs.exists(remote_file): + if self._remote_exists(remote_file): self._fs.get(remote_file, str(metadata_dir / name)) # Download feature directories for name in ["var_features", "obs_features"]: remote_dir = f"{self.remote_path}/{name}" - if self._fs.exists(remote_dir): + if self._remote_exists(remote_dir): local_dir = metadata_dir / name self._fs.get(remote_dir, str(local_dir), recursive=True) diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py index 1e8ddf4458..cf887bbcd6 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py @@ -31,6 +31,7 @@ from bionemo.scdl.api.single_cell_row_dataset import SingleCellRowDataset from bionemo.scdl.index.row_feature_index import ObservedFeatureIndex, VariableFeatureIndex +from bionemo.scdl.io.remote_chunk_loader import RemoteChunkLoader from bionemo.scdl.schema.header import ArrayDType, ArrayInfo, Backend, FeatureIndexInfo, SCDLHeader from bionemo.scdl.schema.version import CurrentSCDLVersion from bionemo.scdl.util.filecopyutil import extend_files @@ -266,15 +267,37 @@ def version(self) -> str: return self._version @classmethod - def from_remote(cls, remote_path: str, cache_dir: Optional[str] = None, max_cached_chunks: int = 2): - """Load a chunked dataset from remote storage (S3, GCS, HTTP).""" - from bionemo.scdl.io.remote_chunk_loader import RemoteChunkLoader + def from_remote( + cls, + remote_path: str, + cache_dir: Optional[str] = None, + max_cached_chunks: int = 2, + storage_options: Optional[Dict] = None, + ): + """Load a chunked dataset from remote storage (S3, GCS, HTTP). - loader = RemoteChunkLoader(remote_path, Path(cache_dir) if cache_dir else None, max_cached_chunks) + Args: + remote_path: Remote path (s3://bucket/path, gs://bucket/path, etc.) + cache_dir: Local directory for caching chunks. If None, uses temp directory. + max_cached_chunks: Maximum number of chunks to keep in cache. + storage_options: Options passed to fsspec (e.g., {"endpoint_url": "https://..."} for S3). + """ + loader = RemoteChunkLoader( + remote_path, Path(cache_dir) if cache_dir else None, max_cached_chunks, storage_options + ) metadata_path = loader.get_metadata() ds = cls.__new__(cls) + # Initialize essential attributes that __init__ would set + ds._version = importlib.metadata.version("bionemo.scdl") ds._chunk_loader = loader ds.data_path = remote_path + ds.header = None + ds.mode = Mode.READ_APPEND + ds._is_chunked = False + ds._chunks = [] + ds.dtypes = {} + ds._var_feature_index = None + ds._obs_feature_index = None ds.load(str(metadata_path)) return ds @@ -708,8 +731,6 @@ def load(self, stored_path: str) -> None: if self.header is not None and hasattr(self.header, "arrays"): # Map from FileNames.value to dtype string for array_info in self.header.arrays: - if FileNames[array_info.name].value not in self.dtypes: - raise ValueError(f"Array name {FileNames[array_info.name].value} not found in dtypes") self.dtypes[FileNames[array_info.name].value] = array_info.dtype.numpy_dtype_string # Load metadata if exists @@ -744,7 +765,15 @@ def load(self, stored_path: str) -> None: self._load_neighbor_memmaps() def _load_chunk_memmaps(self) -> None: - """Preload all chunk memmaps (lazy - just file handles, no RAM).""" + """Preload all chunk memmaps (lazy - just file handles, no RAM). + + For local datasets, loads from data_path directly. + For remote datasets, this is skipped - chunks are loaded on demand. + """ + # For remote datasets, don't preload - chunks are fetched on demand via get_row() + if self._chunk_loader is not None: + return + # Local: preload all chunk paths for chunk_id in range(self.header.chunked_info.num_chunks): chunk_path = Path(self.data_path) / f"chunk_{chunk_id:05d}" self._chunks.append(self._load_chunk_from_path(chunk_path)) diff --git a/sub-packages/bionemo-scdl/test_remote_loading.py b/sub-packages/bionemo-scdl/test_remote_loading.py new file mode 100644 index 0000000000..e1cf367a12 --- /dev/null +++ b/sub-packages/bionemo-scdl/test_remote_loading.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script for remote chunked SCDL loading with ChunkAwareSampler. + +Usage: + python test_remote_loading.py s3://my-bucket/chunked_scdl + python test_remote_loading.py gs://my-bucket/chunked_scdl + python test_remote_loading.py --cache-dir /tmp/cache --max-chunks 3 s3://bucket/path +""" + +import argparse + +from torch.utils.data import DataLoader + +from bionemo.scdl.io.chunk_sampler import ChunkAwareSampler +from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset +from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch + + +def main(): + parser = argparse.ArgumentParser(description="Test remote chunked SCDL loading") + parser.add_argument( + "--remote_path", + default="s3://general-purpose/polina/chunked", + help="Remote path (s3://..., gs://..., https://...)", + ) + parser.add_argument("--endpoint-url", default="https://pbss.s8k.io", help="S3 endpoint URL (for non-AWS S3)") + parser.add_argument("--cache-dir", default="/tmp/scdl_cache", help="Local cache directory") + parser.add_argument("--max-chunks", type=int, default=3, help="Max chunks to cache") + parser.add_argument("--chunks-per-window", type=int, default=2, help="Chunks per sampling window") + parser.add_argument("--batch-size", type=int, default=1, help="Batch size") + parser.add_argument("--num-batches", type=int, default=10, help="Number of batches to iterate") + args = parser.parse_args() + + print(f"Loading remote dataset: {args.remote_path}") + print(f" Endpoint: {args.endpoint_url}") + print(f" Cache dir: {args.cache_dir}") + print(f" Max cached chunks: {args.max_chunks}") + + # Build storage_options for S3-compatible storage + # For s3fs, endpoint_url must be in client_kwargs + storage_options = {} + if args.endpoint_url: + storage_options["client_kwargs"] = {"endpoint_url": args.endpoint_url} + + # 1. Load from remote + ds = SingleCellMemMapDataset.from_remote( + args.remote_path, + cache_dir=args.cache_dir, + max_cached_chunks=args.max_chunks, + storage_options=storage_options if storage_options else None, + ) + print(f" Rows: {len(ds)}") + print(f" Chunks: {ds.header.chunked_info.num_chunks}") + print(f" Chunk size: {ds.header.chunked_info.chunk_size}") + + # 2. Create sampler + print(f"\nCreating ChunkAwareSampler (chunks_per_window={args.chunks_per_window})...") + sampler = ChunkAwareSampler( + ds, + shuffle_chunks=True, + shuffle_within_window=True, + chunks_per_window=args.chunks_per_window, + seed=42, + ) + + # 3. Create DataLoader + print(f"Creating DataLoader (batch_size={args.batch_size})...") + loader = DataLoader(ds, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_sparse_matrix_batch) + + # 4. Iterate batches + print(f"\nIterating {args.num_batches} batches...") + for i, batch in enumerate(loader): + if i >= args.num_batches: + break + print(f" Batch {i}: shape={batch.shape}") + + print("\nSuccess! Remote chunked loading works.") + + # 5. Cleanup (optional) + if ds._chunk_loader: + print(f"\nCleaning up cache at {args.cache_dir}...") + ds._chunk_loader.cleanup() + + +if __name__ == "__main__": + main() From 0ee85a150ea1c771fffabdcf479dec6f02ab498d Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 22 Jan 2026 18:16:39 -0800 Subject: [PATCH 3/5] scdl code --- .../src/bionemo/scdl/io/chunk_sampler.py | 119 +++- .../bionemo/scdl/io/remote_chunk_loader.py | 393 ++++++++++++- .../scdl/io/single_cell_memmap_dataset.py | 42 +- .../src/bionemo/scdl/util/partition_scdl.py | 39 +- .../examples/chunked_scdl_benchmark.py | 526 ++++++++++++++++++ .../src/bionemo/scspeedtest/benchmark.py | 16 +- .../src/bionemo/scspeedtest/common.py | 2 + 7 files changed, 1088 insertions(+), 49 deletions(-) create mode 100644 sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py index db4746137c..7d8bd066b2 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py @@ -17,13 +17,45 @@ import random import warnings +from pathlib import Path from typing import Iterator, Optional +import torch.utils.data from torch.utils.data import Sampler from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset +def remote_worker_init_fn(worker_id: int) -> None: + """Initialize per-worker cache directories for remote datasets. + + Use this with DataLoader's worker_init_fn to prevent cache conflicts + when using multiple workers with remote chunked SCDL: + + DataLoader( + dataset, + num_workers=4, + worker_init_fn=remote_worker_init_fn, + ... + ) + + Each worker gets its own cache directory: {base_cache}_worker_{id} + """ + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + return + + dataset = worker_info.dataset + if hasattr(dataset, "_chunk_loader") and dataset._chunk_loader is not None: + # Create per-worker cache directory + base_cache = dataset._chunk_loader.cache_dir + worker_cache = Path(str(base_cache) + f"_worker_{worker_id}") + worker_cache.mkdir(parents=True, exist_ok=True) + dataset._chunk_loader.cache_dir = worker_cache + # Clear cached chunks dict for this worker + dataset._chunk_loader._cached_chunks.clear() + + class ChunkAwareSampler(Sampler[int]): """Sampler that iterates by chunks for efficient access patterns. @@ -66,17 +98,83 @@ def __init__( f"({dataset._chunk_loader.max_cached_chunks}). This causes cache thrashing. " f"Increase max_cached_chunks or decrease chunks_per_window." ) + # Warn if cache too small for effective prefetching + if dataset._chunk_loader and 2 * chunks_per_window > dataset._chunk_loader.max_cached_chunks: + warnings.warn( + f"max_cached_chunks ({dataset._chunk_loader.max_cached_chunks}) < 2 * chunks_per_window " + f"({2 * chunks_per_window}). Prefetching disabled - no room for next window. " + f"Set max_cached_chunks >= {2 * chunks_per_window} for prefetching." + ) + + def _preload_memmaps_for_chunks(self, chunk_ids: list) -> None: + """Preload memmaps for chunks into _remote_chunk_cache. + + This should be called after prefetch completes to avoid creating + memmaps during iteration. Files are already on disk, we just need + to create the memmap objects. + """ + if not hasattr(self.dataset, "_remote_chunk_cache"): + return + + for chunk_id in chunk_ids: + if chunk_id not in self.dataset._remote_chunk_cache: + # get_chunk returns path (chunk already downloaded by prefetch) + chunk_path = self.dataset._chunk_loader.get_chunk(chunk_id) + # Load memmaps and cache them + memmaps = self.dataset._load_chunk_from_path(chunk_path) + self.dataset._remote_chunk_cache[chunk_id] = memmaps def __iter__(self) -> Iterator[int]: - """Yield row indices, grouped by chunk window.""" + """Yield row indices, grouped by chunk window. + + In multi-worker mode, each worker handles a disjoint subset of chunks. + This avoids cache conflicts and ensures efficient parallel loading. + """ chunk_ids = list(range(self.chunked_info.num_chunks)) if self.shuffle_chunks: self.rng.shuffle(chunk_ids) - # Process in windows of N chunks + # Multi-worker support: each worker handles a subset of chunks + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + # Split chunks among workers: worker i gets chunks[i::num_workers] + chunk_ids = chunk_ids[worker_info.id :: worker_info.num_workers] + + # Build list of windows from this worker's chunks + windows = [] for i in range(0, len(chunk_ids), self.chunks_per_window): - window_chunks = chunk_ids[i : i + self.chunks_per_window] + windows.append(chunk_ids[i : i + self.chunks_per_window]) + + # Check if we have room for prefetching (need 2x window size in cache) + can_prefetch = ( + self.dataset._chunk_loader and 2 * self.chunks_per_window <= self.dataset._chunk_loader.max_cached_chunks + ) + + import time as time_module # Import once outside loop + + iteration_start = time_module.perf_counter() + + # Process windows with pipelined prefetching (1 window ahead) + for window_idx, window_chunks in enumerate(windows): + if self.dataset._chunk_loader: + # For first window, prefetch synchronously since we need it immediately + if window_idx == 0: + cold_start = time_module.perf_counter() + self.dataset._chunk_loader.prefetch_chunks(window_chunks) + elapsed = time_module.perf_counter() - cold_start + self.dataset._chunk_loader.stats.record_cold_start(elapsed) + else: + # Wait for async prefetch started in previous iteration + self.dataset._chunk_loader.wait_for_prefetch() + + # Preload memmaps for current window (now that files are downloaded) + # This avoids creating memmaps during iteration + self._preload_memmaps_for_chunks(window_chunks) + + # Start async prefetch of NEXT window while we process this one + if can_prefetch and window_idx + 1 < len(windows): + self.dataset._chunk_loader.prefetch_chunks_async(windows[window_idx + 1]) # Gather all indices from this window all_indices = [] @@ -90,6 +188,19 @@ def __iter__(self) -> Iterator[int]: yield from all_indices + # Mark this window's chunks as safe to evict (prefer these over current/next window) + if self.dataset._chunk_loader: + self.dataset._chunk_loader.mark_chunks_done(window_chunks) + + # Record total wall-clock time + if self.dataset._chunk_loader: + total_time = time_module.perf_counter() - iteration_start + self.dataset._chunk_loader.stats.record_wall_clock(total_time) + def __len__(self) -> int: - """Return total number of samples.""" + """Return total number of samples. + + Note: In multi-worker mode, each worker sees only its subset, + but PyTorch DataLoader handles combining results correctly. + """ return self.chunked_info.total_rows diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py index b98b7ce856..3ca1222f96 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py @@ -21,13 +21,77 @@ - Use DistributedChunkSampler to shard chunks across nodes """ +import io import shutil import tempfile +import threading +import time from collections import OrderedDict +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass, field from pathlib import Path -from typing import Optional +from typing import ClassVar, Dict, List, Optional, Tuple import fsspec +import numpy as np + +from bionemo.scdl.util.scdl_constants import FileNames + + +@dataclass +class DownloadStats: + """Statistics for download timing.""" + + total_download_time: float = 0.0 # Cumulative time across all download threads + total_wait_time: float = 0.0 # Time spent blocked waiting for downloads + cold_start_time: float = 0.0 # Time to download first window + wall_clock_time: float = 0.0 # Total wall-clock iteration time + total_bytes_downloaded: int = 0 + download_count: int = 0 + cache_hits: int = 0 + _lock: threading.Lock = field(default_factory=threading.Lock) + + def record_download(self, duration: float, bytes_downloaded: int = 0) -> None: + """Record a download completion.""" + with self._lock: + self.total_download_time += duration + self.total_bytes_downloaded += bytes_downloaded + self.download_count += 1 + + def record_cold_start(self, duration: float) -> None: + """Record cold start time (first window download).""" + self.cold_start_time = duration + + def record_wall_clock(self, duration: float) -> None: + """Record total wall-clock time.""" + self.wall_clock_time = duration + + def record_wait(self, duration: float) -> None: + """Record time spent waiting for a chunk.""" + with self._lock: + self.total_wait_time += duration + + def record_cache_hit(self) -> None: + """Record a cache hit.""" + with self._lock: + self.cache_hits += 1 + + def summary(self) -> dict: + """Return a summary of download statistics.""" + total_mb = self.total_bytes_downloaded / 1e6 + effective_throughput = total_mb / self.wall_clock_time if self.wall_clock_time > 0 else 0 + per_thread_throughput = total_mb / self.total_download_time if self.total_download_time > 0 else 0 + + return { + "cold_start_time_sec": round(self.cold_start_time, 2), + "wall_clock_time_sec": round(self.wall_clock_time, 2), + "total_wait_time_sec": round(self.total_wait_time, 2), + "total_bytes_downloaded_mb": round(total_mb, 2), + "download_count": self.download_count, + "cache_hits": self.cache_hits, + "throughput_mbps": round(effective_throughput, 2), + "per_thread_throughput_mbps": round(per_thread_throughput, 2), + } class RemoteChunkLoader: @@ -40,18 +104,47 @@ class RemoteChunkLoader: storage_options: Optional dict of options passed to fsspec (e.g., endpoint_url for S3). """ + # Files in each chunk directory (uncompressed format) + CHUNK_FILES: ClassVar[List[str]] = [FileNames.DATA.value, FileNames.ROWPTR.value, FileNames.COLPTR.value] + # Compressed format (single file) + COMPRESSED_FILE: ClassVar[str] = "chunk.npz" + + # Type alias for cached chunk data: (file_path, (data, rowptr, colptr) memmaps) + CacheEntry = Tuple[Path, Tuple[np.ndarray, np.ndarray, np.ndarray]] + def __init__( self, remote_path: str, cache_dir: Optional[Path] = None, max_cached_chunks: int = 2, storage_options: Optional[dict] = None, + batch_download_size: int = 10, + use_async_downloads: bool = True, + dtypes: Optional[Dict[str, str]] = None, ): - """Initialize the remote chunk loader.""" + """Initialize the remote chunk loader. + + Args: + remote_path: Remote path (s3://bucket/path, gs://bucket/path, etc.) + cache_dir: Local directory for caching chunks. If None, uses temp directory. + max_cached_chunks: Maximum number of chunks to keep in cache. + storage_options: Optional dict of options passed to fsspec. + batch_download_size: Number of chunks to download at once (unused, for API compat). + use_async_downloads: Whether to use async downloads (unused, for API compat). + dtypes: Optional dict mapping file names to dtypes for memmap loading. + """ self.remote_path = remote_path.rstrip("/") self.cache_dir = Path(cache_dir) if cache_dir else Path(tempfile.mkdtemp(prefix="scdl_cache_")) self.max_cached_chunks = max_cached_chunks - self._cache: OrderedDict[int, Path] = OrderedDict() + self.batch_download_size = batch_download_size + self._use_async = use_async_downloads + self.dtypes = dtypes or {} + # Cache stores both path (for cleanup) and memmaps (for access) + self._cache: OrderedDict[int, "RemoteChunkLoader.CacheEntry"] = OrderedDict() + self._cache_lock = threading.Lock() # Protect cache access + self._downloading: set = set() # Chunks currently being downloaded + self._download_complete = threading.Condition(self._cache_lock) + self._evictable_chunks: set = set() # Chunks marked as "done" - prefer evicting these # Initialize filesystem with optional storage options protocol = remote_path.split("://")[0] if "://" in remote_path else "file" @@ -60,44 +153,296 @@ def __init__( # Ensure cache directory exists self.cache_dir.mkdir(parents=True, exist_ok=True) - def get_chunk(self, chunk_id: int) -> Path: - """Get local path to chunk, downloading if needed. + # Async prefetching + self._prefetch_executor = ThreadPoolExecutor(max_workers=1) + self._prefetch_future: Optional[Future] = None + self._prefetch_lock = threading.Lock() + + # Download statistics + self.stats = DownloadStats() + + def set_dtypes(self, dtypes: Dict[str, str]) -> None: + """Set dtypes for memmap loading (can be called after header is loaded). + + Args: + dtypes: Dict mapping file names to dtypes. + """ + self.dtypes = dtypes + + def mark_chunks_done(self, chunk_ids: List[int]) -> None: + """Mark chunks as done/evictable. + + Call this when you've finished processing a window of chunks and won't + need them again until the next epoch. Eviction will prefer these chunks. + + Args: + chunk_ids: List of chunk IDs that can be safely evicted. + """ + with self._cache_lock: + self._evictable_chunks.update(chunk_ids) + + def _find_evictable_chunk(self) -> Optional[Tuple[int, "RemoteChunkLoader.CacheEntry"]]: + """Find a chunk to evict, preferring chunks marked as done. + + Returns: + Tuple of (chunk_id, (path, memmaps)) or None if no chunks available. + """ + # First, try to evict chunks marked as "done" + for chunk_id in list(self._cache.keys()): + if chunk_id in self._evictable_chunks: + entry = self._cache.pop(chunk_id) + self._evictable_chunks.discard(chunk_id) + return chunk_id, entry + + # Fallback: evict oldest chunk (LRU) - shouldn't happen if sized correctly + if self._cache: + return self._cache.popitem(last=False) + return None + + def _evict_chunk(self, chunk_id: int, cache_entry: "RemoteChunkLoader.CacheEntry") -> None: + """Evict a chunk from cache, releasing memmaps before deleting files.""" + self._evictable_chunks.discard(chunk_id) + chunk_path, memmaps = cache_entry + # Release memmaps first (closes file handles) + del memmaps + # Now safe to delete files + shutil.rmtree(chunk_path, ignore_errors=True) + + def get_chunk(self, chunk_id: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Get chunk as memory-mapped arrays, downloading if needed. Args: chunk_id: The chunk index to retrieve. Returns: - Local path to the chunk directory. + Tuple of (data, rowptr, colptr) as memory-mapped numpy arrays. """ - if chunk_id in self._cache: - self._cache.move_to_end(chunk_id) - return self._cache[chunk_id] + start_time = time.perf_counter() + + # Collect chunks to evict inside lock, delete outside to avoid blocking + chunks_to_evict = [] + + with self._download_complete: + # Wait if prefetch is downloading this chunk + while chunk_id in self._downloading and chunk_id not in self._cache: + self._download_complete.wait(timeout=0.1) + + # Check cache - return memmaps directly + if chunk_id in self._cache: + self._cache.move_to_end(chunk_id) + self.stats.record_cache_hit() + _path, memmaps = self._cache[chunk_id] + return memmaps + + # Evict chunks if at capacity (prefer evicting "done" chunks) + while len(self._cache) >= self.max_cached_chunks and self._cache: + evict_result = self._find_evictable_chunk() + if evict_result is None: + break + old_id, old_entry = evict_result + chunks_to_evict.append((old_id, old_entry)) + + # Mark as downloading + self._downloading.add(chunk_id) - # Evict oldest chunks if at capacity - while len(self._cache) >= self.max_cached_chunks: - old_id, old_path = self._cache.popitem(last=False) - shutil.rmtree(old_path, ignore_errors=True) + # Delete evicted chunks in background thread while we download + eviction_future = None + if chunks_to_evict: + eviction_future = self._prefetch_executor.submit( + lambda: [self._evict_chunk(old_id, old_entry) for old_id, old_entry in chunks_to_evict] + ) - # Download chunk - local_path = self._download_chunk(chunk_id) - self._cache[chunk_id] = local_path - return local_path + try: + # Download chunk (outside lock) + local_path = self._download_chunk(chunk_id) + + # Load memmaps + memmaps = self._load_chunk_memmaps(local_path) + + # Wait for eviction to complete (if any) + if eviction_future: + eviction_future.result() + + with self._download_complete: + self._cache[chunk_id] = (local_path, memmaps) + self._downloading.discard(chunk_id) + self._download_complete.notify_all() + + # Record wait time (includes download time for non-prefetched chunks) + wait_time = time.perf_counter() - start_time + self.stats.record_wait(wait_time) + + return memmaps + except Exception: + with self._download_complete: + self._downloading.discard(chunk_id) + self._download_complete.notify_all() + raise + + def _load_chunk_memmaps(self, chunk_path: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Load memmaps for a single chunk directory. + + Uses np.memmap when dtypes are known (faster), falls back to np.load otherwise. + """ + if self.dtypes: + return ( + np.memmap(chunk_path / FileNames.DATA.value, dtype=self.dtypes.get(FileNames.DATA.value), mode="r"), + np.memmap( + chunk_path / FileNames.ROWPTR.value, dtype=self.dtypes.get(FileNames.ROWPTR.value), mode="r" + ), + np.memmap( + chunk_path / FileNames.COLPTR.value, dtype=self.dtypes.get(FileNames.COLPTR.value), mode="r" + ), + ) + + else: + raise ValueError("Dtypes are not set") def _download_chunk(self, chunk_id: int) -> Path: - """Download a chunk from remote storage.""" + """Download a chunk from remote storage. + + Supports both compressed (.npz) and uncompressed formats. + Compressed files are downloaded and extracted to uncompressed format for fast memmap access. + """ + start_time = time.perf_counter() + bytes_downloaded = 0 + chunk_name = f"chunk_{chunk_id:05d}" + print(f"[DOWNLOAD] Starting chunk {chunk_id} ({chunk_name})") remote_chunk = f"{self.remote_path}/{chunk_name}" local_chunk = self.cache_dir / chunk_name + # Ensure directory exists local_chunk.mkdir(parents=True, exist_ok=True) - # Download all files in chunk directory - for remote_file in self._fs.ls(remote_chunk): - fname = Path(remote_file).name - self._fs.get(remote_file, str(local_chunk / fname)) + # Check if compressed format exists (single .npz file) + remote_compressed = f"{remote_chunk}/{self.COMPRESSED_FILE}" + try: + # Try compressed format first (1 HTTP request instead of 3) + with self._fs.open(remote_compressed, "rb") as src: + compressed_data = src.read() + bytes_downloaded = len(compressed_data) + + # Load compressed data from memory + npz = np.load(io.BytesIO(compressed_data)) + + # Extract and save as uncompressed files (for fast memmap access) + np.save(local_chunk / FileNames.DATA.value, npz["data"]) + np.save(local_chunk / FileNames.ROWPTR.value, npz["row_ptr"]) + np.save(local_chunk / FileNames.COLPTR.value, npz["col_ptr"]) + + # Record download stats + duration = time.perf_counter() - start_time + self.stats.record_download(duration, bytes_downloaded) + + return local_chunk + + except FileNotFoundError: + # Fall back to uncompressed format (3 files) + pass + except Exception: + # Fall back to uncompressed format (3 files) + pass + + # Download uncompressed files (original format) + def download_file(fname: str) -> int: + remote_file = f"{remote_chunk}/{fname}" + local_file = local_chunk / fname + try: + local_file.parent.mkdir(parents=True, exist_ok=True) + except FileExistsError: + pass + try: + with self._fs.open(remote_file, "rb") as src: + content = src.read() + except FileNotFoundError: + raise FileNotFoundError( + f"Chunk {chunk_id} not found at {remote_file} (compressed also not found at {remote_compressed})" + ) + with open(local_file, "wb") as dst: + dst.write(content) + return len(content) + + with ThreadPoolExecutor(max_workers=3) as executor: + file_sizes = list(executor.map(download_file, self.CHUNK_FILES)) + + # Record download stats + bytes_downloaded = sum(file_sizes) + duration = time.perf_counter() - start_time + self.stats.record_download(duration, bytes_downloaded) return local_chunk + def prefetch_chunks(self, chunk_ids: List[int], max_parallel: int = 16) -> None: + """Prefetch multiple chunks in parallel (synchronous).""" + # Filter out already cached chunks + with self._cache_lock: + to_download = [cid for cid in chunk_ids if cid not in self._cache and cid not in self._downloading] + + if not to_download: + return + + # Collect chunks to evict inside lock, delete outside to avoid blocking + chunks_to_evict = [] + with self._download_complete: + needed = len(to_download) + while len(self._cache) + needed > self.max_cached_chunks and self._cache: + evict_result = self._find_evictable_chunk() + if evict_result is None: + break + old_id, old_entry = evict_result + chunks_to_evict.append((old_id, old_entry)) + + # Mark all as downloading + for cid in to_download: + self._downloading.add(cid) + + try: + + def download_and_load(chunk_id: int) -> Tuple[int, Path, Tuple[np.ndarray, np.ndarray, np.ndarray]]: + local_path = self._download_chunk(chunk_id) + memmaps = self._load_chunk_memmaps(local_path) + return chunk_id, local_path, memmaps + + # Run eviction and downloads in parallel + with ThreadPoolExecutor( + max_workers=min(max_parallel, len(to_download) + len(chunks_to_evict)) + ) as executor: + # Submit eviction tasks (fire and forget) + for old_id, old_entry in chunks_to_evict: + executor.submit(self._evict_chunk, old_id, old_entry) + # Download all chunks + results = list(executor.map(download_and_load, to_download)) + + # Add to cache + with self._download_complete: + for chunk_id, local_path, memmaps in results: + self._cache[chunk_id] = (local_path, memmaps) + self._downloading.discard(chunk_id) + self._download_complete.notify_all() + + except Exception: + with self._download_complete: + for cid in to_download: + self._downloading.discard(cid) + self._download_complete.notify_all() + raise + + def prefetch_chunks_async(self, chunk_ids: List[int]) -> None: + """Start prefetching chunks in background thread.""" + with self._prefetch_lock: + self._prefetch_future = self._prefetch_executor.submit(self.prefetch_chunks, chunk_ids) + + def wait_for_prefetch(self) -> None: + """Wait for any ongoing prefetch to complete.""" + with self._prefetch_lock: + if self._prefetch_future is not None: + try: + self._prefetch_future.result(timeout=300) # 5 min timeout + except Exception: + pass # Ignore prefetch errors + self._prefetch_future = None + def _remote_exists(self, remote_path: str) -> bool: """Check if a remote path exists (uses ls instead of exists for compatibility).""" try: @@ -136,7 +481,3 @@ def get_metadata(self) -> Path: def cleanup(self): """Delete all cached data.""" shutil.rmtree(self.cache_dir, ignore_errors=True) - - def __del__(self): - """Cleanup on garbage collection.""" - self.cleanup() diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py index cf887bbcd6..14cd8ab522 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py @@ -273,6 +273,8 @@ def from_remote( cache_dir: Optional[str] = None, max_cached_chunks: int = 2, storage_options: Optional[Dict] = None, + batch_download_size: int = 10, + use_async_downloads: bool = True, ): """Load a chunked dataset from remote storage (S3, GCS, HTTP). @@ -281,9 +283,16 @@ def from_remote( cache_dir: Local directory for caching chunks. If None, uses temp directory. max_cached_chunks: Maximum number of chunks to keep in cache. storage_options: Options passed to fsspec (e.g., {"endpoint_url": "https://..."} for S3). + batch_download_size: Number of chunks to download in parallel. + use_async_downloads: Whether to use async downloads (currently unused). """ loader = RemoteChunkLoader( - remote_path, Path(cache_dir) if cache_dir else None, max_cached_chunks, storage_options + remote_path, + Path(cache_dir) if cache_dir else None, + max_cached_chunks, + storage_options, + batch_download_size=batch_download_size, + use_async_downloads=use_async_downloads, ) metadata_path = loader.get_metadata() ds = cls.__new__(cls) @@ -295,10 +304,16 @@ def from_remote( ds.mode = Mode.READ_APPEND ds._is_chunked = False ds._chunks = [] - ds.dtypes = {} + ds.dtypes = {} # Will be populated from header in load() ds._var_feature_index = None ds._obs_feature_index = None + ds.load(str(metadata_path)) + + # Pass dtypes to loader after they're loaded from header + if ds.dtypes: + loader.set_dtypes(ds.dtypes) + return ds def _extract_neighbor_data(self, adata) -> bool: @@ -480,8 +495,10 @@ def get_row( if self._is_chunked: chunk_id, local_idx = self.header.chunked_info.get_chunk_for_row(index) if self._chunk_loader: - data, rowptr, colptr = self._load_chunk_from_path(self._chunk_loader.get_chunk(chunk_id)) + # Remote: loader returns memmaps directly (handles caching internally) + data, rowptr, colptr = self._chunk_loader.get_chunk(chunk_id) else: + # Local: use pre-loaded chunk memmaps data, rowptr, colptr = self._chunks[chunk_id] start, end = rowptr[local_idx], rowptr[local_idx + 1] values, columns = data[start:end], colptr[start:end] @@ -779,11 +796,22 @@ def _load_chunk_memmaps(self) -> None: self._chunks.append(self._load_chunk_from_path(chunk_path)) def _load_chunk_from_path(self, chunk_path: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """Load memmaps for a single chunk directory.""" + """Load memmaps for a single chunk directory. + + Uses np.memmap when dtypes are known (faster), falls back to np.load otherwise. + """ + # Use np.memmap when dtypes are available (faster than np.load) + if self.dtypes: + return ( + np.memmap(chunk_path / FileNames.DATA.value, dtype=self.dtypes[FileNames.DATA.value], mode="r"), + np.memmap(chunk_path / FileNames.ROWPTR.value, dtype=self.dtypes[FileNames.ROWPTR.value], mode="r"), + np.memmap(chunk_path / FileNames.COLPTR.value, dtype=self.dtypes[FileNames.COLPTR.value], mode="r"), + ) + # Fallback to np.load which auto-detects dtype from .npy header return ( - np.memmap(chunk_path / FileNames.DATA.value, dtype=self.dtypes[FileNames.DATA.value], mode="r"), - np.memmap(chunk_path / FileNames.ROWPTR.value, dtype=self.dtypes[FileNames.ROWPTR.value], mode="r"), - np.memmap(chunk_path / FileNames.COLPTR.value, dtype=self.dtypes[FileNames.COLPTR.value], mode="r"), + np.load(chunk_path / FileNames.DATA.value, mmap_mode="r"), + np.load(chunk_path / FileNames.ROWPTR.value, mmap_mode="r"), + np.load(chunk_path / FileNames.COLPTR.value, mmap_mode="r"), ) def _write_metadata(self) -> None: diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py index e1f1578858..946d48b599 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py @@ -29,8 +29,18 @@ def partition_scdl( output_path: Path, chunk_size: int = 100_000, delete_original: bool = False, + compressed: bool = False, ) -> SCDLHeader: - """Partition an SCDL dataset into chunks.""" + """Partition an SCDL dataset into chunks. + + Args: + input_path: Path to source SCDL dataset. + output_path: Path for output chunked dataset. + chunk_size: Number of rows per chunk. + delete_original: Whether to delete the source after partitioning. + compressed: If True, save each chunk as a single compressed .npz file + (faster for remote access - 3x fewer HTTP requests). + """ from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset input_path, output_path = Path(input_path), Path(output_path) @@ -61,14 +71,27 @@ def partition_scdl( data_start, data_end = int(rowptr[row_start]), int(rowptr[row_end]) - # Write chunk files using memmap slicing + # Extract chunk data chunk_rowptr = rowptr[row_start : row_end + 1] - data_start - with open(chunk_dir / FileNames.ROWPTR.value, "wb") as f: - f.write(chunk_rowptr.astype(source_ds.dtypes[FileNames.ROWPTR.value]).tobytes()) - with open(chunk_dir / FileNames.DATA.value, "wb") as f: - f.write(np.array(source_ds.data[data_start:data_end]).tobytes()) - with open(chunk_dir / FileNames.COLPTR.value, "wb") as f: - f.write(np.array(source_ds.col_index[data_start:data_end]).tobytes()) + chunk_data = np.array(source_ds.data[data_start:data_end]) + chunk_colptr = np.array(source_ds.col_index[data_start:data_end]) + + if compressed: + # Single compressed file (faster for remote access) + np.savez_compressed( + chunk_dir / "chunk.npz", + data=chunk_data, + row_ptr=chunk_rowptr.astype(source_ds.dtypes[FileNames.ROWPTR.value]), + col_ptr=chunk_colptr, + ) + else: + # Separate files (original format) + with open(chunk_dir / FileNames.ROWPTR.value, "wb") as f: + f.write(chunk_rowptr.astype(source_ds.dtypes[FileNames.ROWPTR.value]).tobytes()) + with open(chunk_dir / FileNames.DATA.value, "wb") as f: + f.write(chunk_data.tobytes()) + with open(chunk_dir / FileNames.COLPTR.value, "wb") as f: + f.write(chunk_colptr.tobytes()) # Copy features and metadata for name in [FileNames.VAR_FEATURES.value, FileNames.OBS_FEATURES.value]: diff --git a/sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py b/sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py new file mode 100644 index 0000000000..3dacf83dd8 --- /dev/null +++ b/sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py @@ -0,0 +1,526 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Benchmark comparing Regular SCDL vs Chunked SCDL vs Remote Chunked SCDL. + +Usage (with defaults): + python chunked_scdl_benchmark.py + +Custom paths: + python chunked_scdl_benchmark.py \ + --scdl-path /path/to/scdl/ \ + --chunked-path /path/to/chunked/ \ + --remote-path s3://bucket/chunked \ + --endpoint-url https://your-s3-endpoint.com + +This script benchmarks: +1. Regular SCDL - Standard DataLoader with shuffle (baseline) +2. Chunked SCDL (local) - Pre-converted chunked dataset with ChunkAwareSampler +3. Remote Chunked SCDL - S3/GCS with LRU caching and ChunkAwareSampler +""" + +import argparse +import os +import time +from datetime import datetime + +from torch.utils.data import DataLoader + +from bionemo.scdl.io.chunk_sampler import ChunkAwareSampler +from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset +from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch +from bionemo.scspeedtest import benchmark_dataloaders_with_configs, print_comparison + + +def create_regular_scdl_factory( + batch_size: int = 64, shuffle: bool = True, data_path: str | None = None, num_workers: int = 0 +): + """Create a regular SCDL dataloader factory (baseline).""" + + def factory(): + dataset = SingleCellMemMapDataset(data_path) + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + drop_last=False, + collate_fn=collate_sparse_matrix_batch, + num_workers=num_workers, + ) + + return factory + + +def create_chunked_scdl_preconverted_factory( + batch_size: int = 64, + chunked_path: str | None = None, + num_workers: int = 0, + shuffle_chunks: bool = True, + shuffle_within_window: bool = True, + chunks_per_window: int = 2, +): + """Create a chunked SCDL dataloader factory from pre-converted chunked dataset.""" + + def factory(): + dataset = SingleCellMemMapDataset(chunked_path) + + sampler = ChunkAwareSampler( + dataset, + shuffle_chunks=shuffle_chunks, + shuffle_within_window=shuffle_within_window, + chunks_per_window=chunks_per_window, + ) + + return DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + collate_fn=collate_sparse_matrix_batch, + num_workers=num_workers, + ) + + return factory + + +def create_chunked_scdl_random_factory( + batch_size: int = 64, + chunked_path: str | None = None, + num_workers: int = 0, +): + """Create a chunked SCDL dataloader with random shuffle (no ChunkAwareSampler).""" + start_time = time.perf_counter() + + def factory(): + dataset = SingleCellMemMapDataset(chunked_path) + + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, # Standard random shuffle - worst case for I/O locality + collate_fn=collate_sparse_matrix_batch, + num_workers=num_workers, + ) + + end_time = time.perf_counter() + print(f"Time taken to instantiate chunked SCDL dataset: {end_time - start_time:.2f} seconds") + return factory + + +class RemoteDataloaderFactory: + """Factory that tracks the last created dataset for stats access.""" + + def __init__( + self, + batch_size: int = 64, + remote_path: str | None = None, + cache_dir: str | None = None, + max_cached_chunks: int = 3, + storage_options: dict | None = None, + num_workers: int = 0, + shuffle_chunks: bool = True, + shuffle_within_window: bool = True, + chunks_per_window: int = 2, + batch_download_size: int = 30, + ): + """Initialize the remote dataloader factory with configuration.""" + self.batch_size = batch_size + self.remote_path = remote_path + self.cache_dir = cache_dir + self.max_cached_chunks = max_cached_chunks + self.storage_options = storage_options + self.num_workers = num_workers + self.shuffle_chunks = shuffle_chunks + self.shuffle_within_window = shuffle_within_window + self.chunks_per_window = chunks_per_window + self.batch_download_size = batch_download_size + self.last_dataset = None # Store reference for stats access + # Timing breakdown + self.cache_clear_time = 0.0 + self.dataset_init_time = 0.0 + self.sampler_init_time = 0.0 + self.dataloader_init_time = 0.0 + self.total_init_time = 0.0 + + def __call__(self): + """Create a new dataloader.""" + import shutil + + # Clear cache before each run to measure streaming performance + import time + + total_start = time.perf_counter() + + # 1. Clear cache + if self.cache_dir: + print(f"Clearing cache directory: {self.cache_dir}") + t0 = time.perf_counter() + shutil.rmtree(self.cache_dir, ignore_errors=True) + t1 = time.perf_counter() + self.cache_clear_time = t1 - t0 + print(f"Cache clear time: {self.cache_clear_time:.3f} sec") + + # 2. Create dataset + print("Instantiating SingleCellMemMapDataset.from_remote ...") + t2 = time.perf_counter() + self.last_dataset = SingleCellMemMapDataset.from_remote( + self.remote_path, + cache_dir=self.cache_dir, + max_cached_chunks=self.max_cached_chunks, + storage_options=self.storage_options, + batch_download_size=self.batch_download_size, + use_async_downloads=True, + ) + t3 = time.perf_counter() + self.dataset_init_time = t3 - t2 + print(f"SingleCellMemMapDataset.from_remote time: {self.dataset_init_time:.3f} sec") + + # 3. Create sampler + print("Instantiating ChunkAwareSampler ...") + t4 = time.perf_counter() + sampler = ChunkAwareSampler( + self.last_dataset, + shuffle_chunks=self.shuffle_chunks, + shuffle_within_window=self.shuffle_within_window, + chunks_per_window=self.chunks_per_window, + ) + t5 = time.perf_counter() + self.sampler_init_time = t5 - t4 + print(f"ChunkAwareSampler instantiation time: {self.sampler_init_time:.3f} sec") + + # 4. Create DataLoader + print("Instantiating DataLoader ...") + t6 = time.perf_counter() + dataloader = DataLoader( + self.last_dataset, + sampler=sampler, + batch_size=self.batch_size, + collate_fn=collate_sparse_matrix_batch, + num_workers=self.num_workers, + ) + t7 = time.perf_counter() + self.dataloader_init_time = t7 - t6 + self.total_init_time = t7 - total_start + print(f"DataLoader instantiation time: {self.dataloader_init_time:.3f} sec") + print(f"Total init time: {self.total_init_time:.3f} sec") + + return dataloader + + def get_download_stats(self) -> dict | None: + """Get download statistics from the last created dataset.""" + if self.last_dataset and hasattr(self.last_dataset, "_chunk_loader"): + return self.last_dataset._chunk_loader.stats.summary() + return None + + def print_download_stats(self): + """Print download statistics.""" + stats = self.get_download_stats() + if stats: + iteration_time = stats["wall_clock_time_sec"] + total_time = self.total_init_time + iteration_time + + print("\n" + "=" * 50) + print("REMOTE DOWNLOAD STATS") + print("=" * 50) + print(f"Init: {self.total_init_time:.1f}s | Iteration: {iteration_time:.1f}s | Total: {total_time:.1f}s") + print(f"Cold start: {stats['cold_start_time_sec']:.1f}s | Wait time: {stats['total_wait_time_sec']:.1f}s") + print( + f"Downloaded: {stats['total_bytes_downloaded_mb']:.0f} MB ({stats['download_count']} chunks, {stats['cache_hits']} cache hits)" + ) + print( + f"Throughput: {stats['throughput_mbps']:.1f} MB/s effective, {stats['per_thread_throughput_mbps']:.1f} MB/s per-thread" + ) + + +def create_remote_chunked_scdl_factory( + batch_size: int = 64, + remote_path: str | None = None, + cache_dir: str | None = None, + max_cached_chunks: int = 3, + storage_options: dict | None = None, + num_workers: int = 0, + shuffle_chunks: bool = True, + shuffle_within_window: bool = True, + chunks_per_window: int = 2, + batch_download_size: int = 30, +) -> RemoteDataloaderFactory: + """Create a remote chunked SCDL dataloader factory with ChunkAwareSampler.""" + return RemoteDataloaderFactory( + batch_size=batch_size, + remote_path=remote_path, + cache_dir=cache_dir, + max_cached_chunks=max_cached_chunks, + storage_options=storage_options, + num_workers=num_workers, + shuffle_chunks=shuffle_chunks, + shuffle_within_window=shuffle_within_window, + chunks_per_window=chunks_per_window, + batch_download_size=batch_download_size, + ) + + +def chunked_scdl_benchmarking( + num_epochs: int = 1, + num_runs: int = 1, + scdl_path: str | None = None, + chunked_path: str | None = None, + remote_path: str | None = None, + endpoint_url: str | None = None, + cache_dir: str = "/tmp/scdl_cache", + max_cached_chunks: int = 3, + chunks_per_window: int = 2, + max_time_seconds: float = 120.0, + warmup_time_seconds: float = 30.0, + batch_size: int = 64, +): + """Run benchmarks comparing regular SCDL vs chunked SCDL. + + Args: + num_epochs: Number of epochs per configuration + num_runs: Number of runs per configuration + scdl_path: Path to regular (non-chunked) SCDL dataset + chunked_path: Path to pre-converted chunked SCDL dataset (optional) + remote_path: Remote path to chunked dataset (s3://, gs://, etc.) + endpoint_url: Custom S3 endpoint URL (for non-AWS S3) + cache_dir: Local cache directory for remote chunks + max_cached_chunks: Max chunks to keep in LRU cache + chunks_per_window: Chunks per sampling window + max_time_seconds: Max time per configuration + warmup_time_seconds: Warmup time per configuration + batch_size: Batch size for dataloaders + """ + print("=" * 80) + print("CHUNKED SCDL BENCHMARKING") + print("=" * 80) + print() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + configurations = [] + + # Configuration 1: Regular SCDL baseline + if scdl_path: + print(f"Adding Regular SCDL baseline: {scdl_path}") + configurations.append( + { + "name": "Regular_SCDL_Baseline", + "dataloader_factory": create_regular_scdl_factory( + batch_size=batch_size, shuffle=True, data_path=scdl_path, num_workers=0 + ), + "num_epochs": num_epochs, + "max_time_seconds": max_time_seconds, + "warmup_time_seconds": warmup_time_seconds, + "data_path": scdl_path, + "num_runs": num_runs, + } + ) + + # Configuration 2: Pre-converted chunked SCDL with ChunkAwareSampler + if chunked_path: + print(f"Adding Chunked SCDL + ChunkAwareSampler: {chunked_path}") + configurations.append( + { + "name": f"Chunked_SCDL_ChunkAware_window{chunks_per_window}", + "dataloader_factory": create_chunked_scdl_preconverted_factory( + batch_size=batch_size, + chunked_path=chunked_path, + num_workers=0, + shuffle_chunks=True, + shuffle_within_window=True, + chunks_per_window=chunks_per_window, + ), + "num_epochs": num_epochs, + "max_time_seconds": max_time_seconds, + "warmup_time_seconds": warmup_time_seconds, + "data_path": chunked_path, + "num_runs": num_runs, + } + ) + + # Configuration 3: Chunked SCDL with random shuffle (no ChunkAwareSampler) + print(f"Adding Chunked SCDL + Random Shuffle: {chunked_path}") + configurations.append( + { + "name": "Chunked_SCDL_RandomShuffle", + "dataloader_factory": create_chunked_scdl_random_factory( + batch_size=batch_size, + chunked_path=chunked_path, + num_workers=0, + ), + "num_epochs": num_epochs, + "max_time_seconds": max_time_seconds, + "warmup_time_seconds": warmup_time_seconds, + "data_path": chunked_path, + "num_runs": num_runs, + } + ) + + # Configuration 4: Remote chunked SCDL + remote_factory = None + if remote_path: + storage_options = { + "default_fill_cache": False, # Don't cache file contents in memory + "default_cache_type": "none", # No block caching + "config_kwargs": {"max_pool_connections": 100}, # More parallel connections + } + if endpoint_url: + storage_options["client_kwargs"] = {"endpoint_url": endpoint_url} + + print(f"Adding Remote Chunked SCDL: {remote_path}") + if endpoint_url: + print(f" Endpoint: {endpoint_url}") + print(f" Cache dir: {cache_dir}") + print(f" Max cached chunks: {max_cached_chunks}") + remote_factory = create_remote_chunked_scdl_factory( + batch_size=batch_size, + remote_path=remote_path, + cache_dir=cache_dir, + max_cached_chunks=max_cached_chunks, + storage_options=storage_options, + num_workers=0, + shuffle_chunks=True, + shuffle_within_window=True, + chunks_per_window=chunks_per_window, + batch_download_size=max_cached_chunks, # Download batch = cache size + ) + + configurations.append( + { + "name": f"Chunked_SCDL_Remote_cache{max_cached_chunks}_window{chunks_per_window}", + "dataloader_factory": remote_factory, + "num_epochs": num_epochs, + "max_time_seconds": max_time_seconds, + "warmup_time_seconds": warmup_time_seconds, + "data_path": remote_path, + "num_runs": num_runs, + } + ) + + if not configurations: + print("ERROR: No configurations to benchmark. Provide --scdl-path, --chunked-path, or --remote-path") + return + + print() + print(f"Running {len(configurations)} configuration(s)...") + print() + + results = benchmark_dataloaders_with_configs( + dataloader_configs=configurations, + shared_dataset_factory=None, # Each config loads its own dataset + output_prefix=f"chunked_scdl_benchmark_{timestamp}", + ) + + print() + print("=" * 80) + print("RESULTS SUMMARY") + print("=" * 80) + print_comparison(results) + + # Print remote download statistics if available + if remote_factory: + # Record wall_clock_time from benchmark results (since sampler may not complete naturally) + for result in results: + if "Remote" in result.name and remote_factory.last_dataset and remote_factory.last_dataset._chunk_loader: + remote_factory.last_dataset._chunk_loader.stats.record_wall_clock(result.total_time_seconds) + break + remote_factory.print_download_stats() + + print() + print(f"Results saved to: chunked_scdl_benchmark_{timestamp}_detailed_breakdown.csv") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark Regular SCDL vs Chunked SCDL with ChunkAwareSampler", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Benchmark local regular vs local chunked + %(prog)s --scdl-path /data/scdl/ --chunked-path /data/chunked_scdl/ + + # Benchmark remote chunked dataset + %(prog)s --remote-path s3://bucket/chunked --endpoint-url https://s3.example.com + + # Full comparison + %(prog)s --scdl-path /data/scdl/ --chunked-path /data/chunked/ --remote-path s3://bucket/chunked + """, + ) + + # Data paths - with sensible defaults + parser.add_argument( + "--scdl-path", + type=str, + default="", # "/home/pbinder/bionemo-framework/small_tahoe_format", + help="Path to regular SCDL dataset (baseline)", + ) + parser.add_argument( + "--chunked-path", + type=str, + default="", # "/home/pbinder/bionemo-framework/sub-packages/bionemo-scspeedtest/example_data/tahoe_chunked", + help="Path to pre-converted chunked SCDL dataset", + ) + parser.add_argument( + "--remote-path", + type=str, + default="s3://general-purpose/polina/tahoe_chunked", + help="Remote path to chunked dataset (s3://, gs://)", + ) + parser.add_argument("--endpoint-url", type=str, default="https://pbss.s8k.io", help="Custom S3 endpoint URL") + + # Cache settings + parser.add_argument("--cache-dir", type=str, default="/tmp/scdl_cache", help="Local cache for remote chunks") + parser.add_argument( + "--max-cached-chunks", + type=int, + default=2, + help="Max chunks in LRU cache (>= 2x chunks-per-window for prefetching)", + ) + + # Chunking settings + parser.add_argument("--chunks-per-window", type=int, default=1, help="Chunks per sampling window") + + # Benchmark settings + parser.add_argument("--num-epochs", type=int, default=1, help="Epochs per configuration") + parser.add_argument("--num-runs", type=int, default=1, help="Runs per configuration") + parser.add_argument("--max-time", type=float, default=120, help="Max time per config (seconds)") + parser.add_argument("--warmup-time", type=float, default=0, help="Warmup time per config (seconds)") + parser.add_argument("--batch-size", type=int, default=64, help="Batch size") + + args = parser.parse_args() + + # Validate paths exist (for local paths) + if args.scdl_path and not os.path.exists(args.scdl_path): + print(f"Warning: SCDL path not found: {args.scdl_path}, skipping...") + args.scdl_path = None + if args.chunked_path and not os.path.exists(args.chunked_path): + print(f"Warning: Chunked path not found: {args.chunked_path}, skipping...") + args.chunked_path = None + + # Check at least one valid config + if not any([args.scdl_path, args.chunked_path, args.remote_path]): + parser.error("No valid data paths found. Provide --scdl-path, --chunked-path, or --remote-path") + + chunked_scdl_benchmarking( + num_epochs=args.num_epochs, + num_runs=args.num_runs, + scdl_path=args.scdl_path, + chunked_path=args.chunked_path, + remote_path=args.remote_path, + endpoint_url=args.endpoint_url, + cache_dir=args.cache_dir, + max_cached_chunks=args.max_cached_chunks, + chunks_per_window=args.chunks_per_window, + max_time_seconds=args.max_time, + warmup_time_seconds=args.warmup_time, + batch_size=args.batch_size, + ) diff --git a/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/benchmark.py b/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/benchmark.py index 6bf4e8acb2..707f05266f 100644 --- a/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/benchmark.py +++ b/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/benchmark.py @@ -301,7 +301,7 @@ def config_dataloader_from_dataset(): dataloader_factory = config_dataloader_from_dataset else: dataloader_factory = config_dataloader_factory - + start_time = time.perf_counter() result = benchmark_single_dataloader( dataloader_factory=dataloader_factory, data_path=dl_config.get("data_path", None), @@ -318,10 +318,11 @@ def config_dataloader_from_dataset(): output_prefix=output_prefix, dataset_instantiation_time=shared_dataset_time, ) + end_time = time.perf_counter() + result.overall_time_seconds = end_time - start_time # If this hasn't been set, set it to the minimum in the first dataloader if not shared_dataset_baseline: shared_dataset_baseline = result.memory_before_instantiation_mb - print_results(result) if isinstance(result, list): for r in result: @@ -407,7 +408,10 @@ def dataloader_from_dataset(): else 0, # Combined time when no separate dataset factory "dataloader_instantiation_time_seconds": setup_time, } - disk_size_mb = get_disk_size(data_path) + if isinstance(data_path, str) and data_path.startswith("s3"): + disk_size_mb = None + else: + disk_size_mb = get_disk_size(data_path) results = [] for run_idx in range(num_runs): @@ -460,11 +464,15 @@ def print_results(result_or_results: Union[BenchmarkResult, List[BenchmarkResult print(f"Samples/sec: {result.samples_per_second:.2f}") print(f"Total samples: {result.total_samples}") print(f"Total time: {result.total_time_seconds:.3f}s") + print(f"Overall time: {result.overall_time_seconds:.3f}s") print(f"Dataset instantiation: {result.dataset_instantiation_time_seconds:.3f}s") print(f"Dataloader instantiation: {result.dataloader_instantiation_time_seconds:.3f}s") print(f"Peak memory durint iteration: {result.peak_memory_mb:.1f} MB") print(f"Peak memory during instantiation: {result.peak_memory_during_instantiation_mb:.1f} MB") - print(f"Disk size: {result.disk_size_mb:.1f} MB") + if result.disk_size_mb is not None: + print(f"Disk size: {result.disk_size_mb:.1f} MB") + else: + print("Disk size: N/A") print("=" * 60 + "\n") diff --git a/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/common.py b/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/common.py index b43798900b..ff279726d2 100644 --- a/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/common.py +++ b/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/common.py @@ -87,10 +87,12 @@ class BenchmarkResult: peak_memory_mb: float = 0.0 avg_memory_mb: float = 0.0 disk_size_mb: float = 0.0 + overall_time_seconds: float = 0.0 def __post_init__(self): """Calculate derived metrics from epoch results.""" self.total_samples = sum(r["samples"] for r in self.epoch_results) + print("Elapsed times: ", [r["elapsed"] for r in self.epoch_results]) self.total_time_seconds = sum(r["elapsed"] for r in self.epoch_results) self.samples_per_second = self.total_samples / self.total_time_seconds if self.total_time_seconds > 0 else 0.0 self.peak_memory_mb = max(r["peak_memory"] for r in self.epoch_results) - self.memory_before_instantiation_mb From 0c8e4239dda15bd5a93aafed654e6e54cd15b90c Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 22 Jan 2026 18:27:36 -0800 Subject: [PATCH 4/5] Moving tracking to a seperate branch --- .../examples/chunked_scdl_benchmark.py | 526 ------------------ .../src/bionemo/scspeedtest/benchmark.py | 16 +- .../src/bionemo/scspeedtest/common.py | 2 - 3 files changed, 4 insertions(+), 540 deletions(-) delete mode 100644 sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py diff --git a/sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py b/sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py deleted file mode 100644 index 3dacf83dd8..0000000000 --- a/sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py +++ /dev/null @@ -1,526 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -r"""Benchmark comparing Regular SCDL vs Chunked SCDL vs Remote Chunked SCDL. - -Usage (with defaults): - python chunked_scdl_benchmark.py - -Custom paths: - python chunked_scdl_benchmark.py \ - --scdl-path /path/to/scdl/ \ - --chunked-path /path/to/chunked/ \ - --remote-path s3://bucket/chunked \ - --endpoint-url https://your-s3-endpoint.com - -This script benchmarks: -1. Regular SCDL - Standard DataLoader with shuffle (baseline) -2. Chunked SCDL (local) - Pre-converted chunked dataset with ChunkAwareSampler -3. Remote Chunked SCDL - S3/GCS with LRU caching and ChunkAwareSampler -""" - -import argparse -import os -import time -from datetime import datetime - -from torch.utils.data import DataLoader - -from bionemo.scdl.io.chunk_sampler import ChunkAwareSampler -from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset -from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch -from bionemo.scspeedtest import benchmark_dataloaders_with_configs, print_comparison - - -def create_regular_scdl_factory( - batch_size: int = 64, shuffle: bool = True, data_path: str | None = None, num_workers: int = 0 -): - """Create a regular SCDL dataloader factory (baseline).""" - - def factory(): - dataset = SingleCellMemMapDataset(data_path) - return DataLoader( - dataset, - batch_size=batch_size, - shuffle=shuffle, - drop_last=False, - collate_fn=collate_sparse_matrix_batch, - num_workers=num_workers, - ) - - return factory - - -def create_chunked_scdl_preconverted_factory( - batch_size: int = 64, - chunked_path: str | None = None, - num_workers: int = 0, - shuffle_chunks: bool = True, - shuffle_within_window: bool = True, - chunks_per_window: int = 2, -): - """Create a chunked SCDL dataloader factory from pre-converted chunked dataset.""" - - def factory(): - dataset = SingleCellMemMapDataset(chunked_path) - - sampler = ChunkAwareSampler( - dataset, - shuffle_chunks=shuffle_chunks, - shuffle_within_window=shuffle_within_window, - chunks_per_window=chunks_per_window, - ) - - return DataLoader( - dataset, - sampler=sampler, - batch_size=batch_size, - collate_fn=collate_sparse_matrix_batch, - num_workers=num_workers, - ) - - return factory - - -def create_chunked_scdl_random_factory( - batch_size: int = 64, - chunked_path: str | None = None, - num_workers: int = 0, -): - """Create a chunked SCDL dataloader with random shuffle (no ChunkAwareSampler).""" - start_time = time.perf_counter() - - def factory(): - dataset = SingleCellMemMapDataset(chunked_path) - - return DataLoader( - dataset, - batch_size=batch_size, - shuffle=True, # Standard random shuffle - worst case for I/O locality - collate_fn=collate_sparse_matrix_batch, - num_workers=num_workers, - ) - - end_time = time.perf_counter() - print(f"Time taken to instantiate chunked SCDL dataset: {end_time - start_time:.2f} seconds") - return factory - - -class RemoteDataloaderFactory: - """Factory that tracks the last created dataset for stats access.""" - - def __init__( - self, - batch_size: int = 64, - remote_path: str | None = None, - cache_dir: str | None = None, - max_cached_chunks: int = 3, - storage_options: dict | None = None, - num_workers: int = 0, - shuffle_chunks: bool = True, - shuffle_within_window: bool = True, - chunks_per_window: int = 2, - batch_download_size: int = 30, - ): - """Initialize the remote dataloader factory with configuration.""" - self.batch_size = batch_size - self.remote_path = remote_path - self.cache_dir = cache_dir - self.max_cached_chunks = max_cached_chunks - self.storage_options = storage_options - self.num_workers = num_workers - self.shuffle_chunks = shuffle_chunks - self.shuffle_within_window = shuffle_within_window - self.chunks_per_window = chunks_per_window - self.batch_download_size = batch_download_size - self.last_dataset = None # Store reference for stats access - # Timing breakdown - self.cache_clear_time = 0.0 - self.dataset_init_time = 0.0 - self.sampler_init_time = 0.0 - self.dataloader_init_time = 0.0 - self.total_init_time = 0.0 - - def __call__(self): - """Create a new dataloader.""" - import shutil - - # Clear cache before each run to measure streaming performance - import time - - total_start = time.perf_counter() - - # 1. Clear cache - if self.cache_dir: - print(f"Clearing cache directory: {self.cache_dir}") - t0 = time.perf_counter() - shutil.rmtree(self.cache_dir, ignore_errors=True) - t1 = time.perf_counter() - self.cache_clear_time = t1 - t0 - print(f"Cache clear time: {self.cache_clear_time:.3f} sec") - - # 2. Create dataset - print("Instantiating SingleCellMemMapDataset.from_remote ...") - t2 = time.perf_counter() - self.last_dataset = SingleCellMemMapDataset.from_remote( - self.remote_path, - cache_dir=self.cache_dir, - max_cached_chunks=self.max_cached_chunks, - storage_options=self.storage_options, - batch_download_size=self.batch_download_size, - use_async_downloads=True, - ) - t3 = time.perf_counter() - self.dataset_init_time = t3 - t2 - print(f"SingleCellMemMapDataset.from_remote time: {self.dataset_init_time:.3f} sec") - - # 3. Create sampler - print("Instantiating ChunkAwareSampler ...") - t4 = time.perf_counter() - sampler = ChunkAwareSampler( - self.last_dataset, - shuffle_chunks=self.shuffle_chunks, - shuffle_within_window=self.shuffle_within_window, - chunks_per_window=self.chunks_per_window, - ) - t5 = time.perf_counter() - self.sampler_init_time = t5 - t4 - print(f"ChunkAwareSampler instantiation time: {self.sampler_init_time:.3f} sec") - - # 4. Create DataLoader - print("Instantiating DataLoader ...") - t6 = time.perf_counter() - dataloader = DataLoader( - self.last_dataset, - sampler=sampler, - batch_size=self.batch_size, - collate_fn=collate_sparse_matrix_batch, - num_workers=self.num_workers, - ) - t7 = time.perf_counter() - self.dataloader_init_time = t7 - t6 - self.total_init_time = t7 - total_start - print(f"DataLoader instantiation time: {self.dataloader_init_time:.3f} sec") - print(f"Total init time: {self.total_init_time:.3f} sec") - - return dataloader - - def get_download_stats(self) -> dict | None: - """Get download statistics from the last created dataset.""" - if self.last_dataset and hasattr(self.last_dataset, "_chunk_loader"): - return self.last_dataset._chunk_loader.stats.summary() - return None - - def print_download_stats(self): - """Print download statistics.""" - stats = self.get_download_stats() - if stats: - iteration_time = stats["wall_clock_time_sec"] - total_time = self.total_init_time + iteration_time - - print("\n" + "=" * 50) - print("REMOTE DOWNLOAD STATS") - print("=" * 50) - print(f"Init: {self.total_init_time:.1f}s | Iteration: {iteration_time:.1f}s | Total: {total_time:.1f}s") - print(f"Cold start: {stats['cold_start_time_sec']:.1f}s | Wait time: {stats['total_wait_time_sec']:.1f}s") - print( - f"Downloaded: {stats['total_bytes_downloaded_mb']:.0f} MB ({stats['download_count']} chunks, {stats['cache_hits']} cache hits)" - ) - print( - f"Throughput: {stats['throughput_mbps']:.1f} MB/s effective, {stats['per_thread_throughput_mbps']:.1f} MB/s per-thread" - ) - - -def create_remote_chunked_scdl_factory( - batch_size: int = 64, - remote_path: str | None = None, - cache_dir: str | None = None, - max_cached_chunks: int = 3, - storage_options: dict | None = None, - num_workers: int = 0, - shuffle_chunks: bool = True, - shuffle_within_window: bool = True, - chunks_per_window: int = 2, - batch_download_size: int = 30, -) -> RemoteDataloaderFactory: - """Create a remote chunked SCDL dataloader factory with ChunkAwareSampler.""" - return RemoteDataloaderFactory( - batch_size=batch_size, - remote_path=remote_path, - cache_dir=cache_dir, - max_cached_chunks=max_cached_chunks, - storage_options=storage_options, - num_workers=num_workers, - shuffle_chunks=shuffle_chunks, - shuffle_within_window=shuffle_within_window, - chunks_per_window=chunks_per_window, - batch_download_size=batch_download_size, - ) - - -def chunked_scdl_benchmarking( - num_epochs: int = 1, - num_runs: int = 1, - scdl_path: str | None = None, - chunked_path: str | None = None, - remote_path: str | None = None, - endpoint_url: str | None = None, - cache_dir: str = "/tmp/scdl_cache", - max_cached_chunks: int = 3, - chunks_per_window: int = 2, - max_time_seconds: float = 120.0, - warmup_time_seconds: float = 30.0, - batch_size: int = 64, -): - """Run benchmarks comparing regular SCDL vs chunked SCDL. - - Args: - num_epochs: Number of epochs per configuration - num_runs: Number of runs per configuration - scdl_path: Path to regular (non-chunked) SCDL dataset - chunked_path: Path to pre-converted chunked SCDL dataset (optional) - remote_path: Remote path to chunked dataset (s3://, gs://, etc.) - endpoint_url: Custom S3 endpoint URL (for non-AWS S3) - cache_dir: Local cache directory for remote chunks - max_cached_chunks: Max chunks to keep in LRU cache - chunks_per_window: Chunks per sampling window - max_time_seconds: Max time per configuration - warmup_time_seconds: Warmup time per configuration - batch_size: Batch size for dataloaders - """ - print("=" * 80) - print("CHUNKED SCDL BENCHMARKING") - print("=" * 80) - print() - - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - configurations = [] - - # Configuration 1: Regular SCDL baseline - if scdl_path: - print(f"Adding Regular SCDL baseline: {scdl_path}") - configurations.append( - { - "name": "Regular_SCDL_Baseline", - "dataloader_factory": create_regular_scdl_factory( - batch_size=batch_size, shuffle=True, data_path=scdl_path, num_workers=0 - ), - "num_epochs": num_epochs, - "max_time_seconds": max_time_seconds, - "warmup_time_seconds": warmup_time_seconds, - "data_path": scdl_path, - "num_runs": num_runs, - } - ) - - # Configuration 2: Pre-converted chunked SCDL with ChunkAwareSampler - if chunked_path: - print(f"Adding Chunked SCDL + ChunkAwareSampler: {chunked_path}") - configurations.append( - { - "name": f"Chunked_SCDL_ChunkAware_window{chunks_per_window}", - "dataloader_factory": create_chunked_scdl_preconverted_factory( - batch_size=batch_size, - chunked_path=chunked_path, - num_workers=0, - shuffle_chunks=True, - shuffle_within_window=True, - chunks_per_window=chunks_per_window, - ), - "num_epochs": num_epochs, - "max_time_seconds": max_time_seconds, - "warmup_time_seconds": warmup_time_seconds, - "data_path": chunked_path, - "num_runs": num_runs, - } - ) - - # Configuration 3: Chunked SCDL with random shuffle (no ChunkAwareSampler) - print(f"Adding Chunked SCDL + Random Shuffle: {chunked_path}") - configurations.append( - { - "name": "Chunked_SCDL_RandomShuffle", - "dataloader_factory": create_chunked_scdl_random_factory( - batch_size=batch_size, - chunked_path=chunked_path, - num_workers=0, - ), - "num_epochs": num_epochs, - "max_time_seconds": max_time_seconds, - "warmup_time_seconds": warmup_time_seconds, - "data_path": chunked_path, - "num_runs": num_runs, - } - ) - - # Configuration 4: Remote chunked SCDL - remote_factory = None - if remote_path: - storage_options = { - "default_fill_cache": False, # Don't cache file contents in memory - "default_cache_type": "none", # No block caching - "config_kwargs": {"max_pool_connections": 100}, # More parallel connections - } - if endpoint_url: - storage_options["client_kwargs"] = {"endpoint_url": endpoint_url} - - print(f"Adding Remote Chunked SCDL: {remote_path}") - if endpoint_url: - print(f" Endpoint: {endpoint_url}") - print(f" Cache dir: {cache_dir}") - print(f" Max cached chunks: {max_cached_chunks}") - remote_factory = create_remote_chunked_scdl_factory( - batch_size=batch_size, - remote_path=remote_path, - cache_dir=cache_dir, - max_cached_chunks=max_cached_chunks, - storage_options=storage_options, - num_workers=0, - shuffle_chunks=True, - shuffle_within_window=True, - chunks_per_window=chunks_per_window, - batch_download_size=max_cached_chunks, # Download batch = cache size - ) - - configurations.append( - { - "name": f"Chunked_SCDL_Remote_cache{max_cached_chunks}_window{chunks_per_window}", - "dataloader_factory": remote_factory, - "num_epochs": num_epochs, - "max_time_seconds": max_time_seconds, - "warmup_time_seconds": warmup_time_seconds, - "data_path": remote_path, - "num_runs": num_runs, - } - ) - - if not configurations: - print("ERROR: No configurations to benchmark. Provide --scdl-path, --chunked-path, or --remote-path") - return - - print() - print(f"Running {len(configurations)} configuration(s)...") - print() - - results = benchmark_dataloaders_with_configs( - dataloader_configs=configurations, - shared_dataset_factory=None, # Each config loads its own dataset - output_prefix=f"chunked_scdl_benchmark_{timestamp}", - ) - - print() - print("=" * 80) - print("RESULTS SUMMARY") - print("=" * 80) - print_comparison(results) - - # Print remote download statistics if available - if remote_factory: - # Record wall_clock_time from benchmark results (since sampler may not complete naturally) - for result in results: - if "Remote" in result.name and remote_factory.last_dataset and remote_factory.last_dataset._chunk_loader: - remote_factory.last_dataset._chunk_loader.stats.record_wall_clock(result.total_time_seconds) - break - remote_factory.print_download_stats() - - print() - print(f"Results saved to: chunked_scdl_benchmark_{timestamp}_detailed_breakdown.csv") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Benchmark Regular SCDL vs Chunked SCDL with ChunkAwareSampler", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Benchmark local regular vs local chunked - %(prog)s --scdl-path /data/scdl/ --chunked-path /data/chunked_scdl/ - - # Benchmark remote chunked dataset - %(prog)s --remote-path s3://bucket/chunked --endpoint-url https://s3.example.com - - # Full comparison - %(prog)s --scdl-path /data/scdl/ --chunked-path /data/chunked/ --remote-path s3://bucket/chunked - """, - ) - - # Data paths - with sensible defaults - parser.add_argument( - "--scdl-path", - type=str, - default="", # "/home/pbinder/bionemo-framework/small_tahoe_format", - help="Path to regular SCDL dataset (baseline)", - ) - parser.add_argument( - "--chunked-path", - type=str, - default="", # "/home/pbinder/bionemo-framework/sub-packages/bionemo-scspeedtest/example_data/tahoe_chunked", - help="Path to pre-converted chunked SCDL dataset", - ) - parser.add_argument( - "--remote-path", - type=str, - default="s3://general-purpose/polina/tahoe_chunked", - help="Remote path to chunked dataset (s3://, gs://)", - ) - parser.add_argument("--endpoint-url", type=str, default="https://pbss.s8k.io", help="Custom S3 endpoint URL") - - # Cache settings - parser.add_argument("--cache-dir", type=str, default="/tmp/scdl_cache", help="Local cache for remote chunks") - parser.add_argument( - "--max-cached-chunks", - type=int, - default=2, - help="Max chunks in LRU cache (>= 2x chunks-per-window for prefetching)", - ) - - # Chunking settings - parser.add_argument("--chunks-per-window", type=int, default=1, help="Chunks per sampling window") - - # Benchmark settings - parser.add_argument("--num-epochs", type=int, default=1, help="Epochs per configuration") - parser.add_argument("--num-runs", type=int, default=1, help="Runs per configuration") - parser.add_argument("--max-time", type=float, default=120, help="Max time per config (seconds)") - parser.add_argument("--warmup-time", type=float, default=0, help="Warmup time per config (seconds)") - parser.add_argument("--batch-size", type=int, default=64, help="Batch size") - - args = parser.parse_args() - - # Validate paths exist (for local paths) - if args.scdl_path and not os.path.exists(args.scdl_path): - print(f"Warning: SCDL path not found: {args.scdl_path}, skipping...") - args.scdl_path = None - if args.chunked_path and not os.path.exists(args.chunked_path): - print(f"Warning: Chunked path not found: {args.chunked_path}, skipping...") - args.chunked_path = None - - # Check at least one valid config - if not any([args.scdl_path, args.chunked_path, args.remote_path]): - parser.error("No valid data paths found. Provide --scdl-path, --chunked-path, or --remote-path") - - chunked_scdl_benchmarking( - num_epochs=args.num_epochs, - num_runs=args.num_runs, - scdl_path=args.scdl_path, - chunked_path=args.chunked_path, - remote_path=args.remote_path, - endpoint_url=args.endpoint_url, - cache_dir=args.cache_dir, - max_cached_chunks=args.max_cached_chunks, - chunks_per_window=args.chunks_per_window, - max_time_seconds=args.max_time, - warmup_time_seconds=args.warmup_time, - batch_size=args.batch_size, - ) diff --git a/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/benchmark.py b/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/benchmark.py index 707f05266f..6bf4e8acb2 100644 --- a/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/benchmark.py +++ b/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/benchmark.py @@ -301,7 +301,7 @@ def config_dataloader_from_dataset(): dataloader_factory = config_dataloader_from_dataset else: dataloader_factory = config_dataloader_factory - start_time = time.perf_counter() + result = benchmark_single_dataloader( dataloader_factory=dataloader_factory, data_path=dl_config.get("data_path", None), @@ -318,11 +318,10 @@ def config_dataloader_from_dataset(): output_prefix=output_prefix, dataset_instantiation_time=shared_dataset_time, ) - end_time = time.perf_counter() - result.overall_time_seconds = end_time - start_time # If this hasn't been set, set it to the minimum in the first dataloader if not shared_dataset_baseline: shared_dataset_baseline = result.memory_before_instantiation_mb + print_results(result) if isinstance(result, list): for r in result: @@ -408,10 +407,7 @@ def dataloader_from_dataset(): else 0, # Combined time when no separate dataset factory "dataloader_instantiation_time_seconds": setup_time, } - if isinstance(data_path, str) and data_path.startswith("s3"): - disk_size_mb = None - else: - disk_size_mb = get_disk_size(data_path) + disk_size_mb = get_disk_size(data_path) results = [] for run_idx in range(num_runs): @@ -464,15 +460,11 @@ def print_results(result_or_results: Union[BenchmarkResult, List[BenchmarkResult print(f"Samples/sec: {result.samples_per_second:.2f}") print(f"Total samples: {result.total_samples}") print(f"Total time: {result.total_time_seconds:.3f}s") - print(f"Overall time: {result.overall_time_seconds:.3f}s") print(f"Dataset instantiation: {result.dataset_instantiation_time_seconds:.3f}s") print(f"Dataloader instantiation: {result.dataloader_instantiation_time_seconds:.3f}s") print(f"Peak memory durint iteration: {result.peak_memory_mb:.1f} MB") print(f"Peak memory during instantiation: {result.peak_memory_during_instantiation_mb:.1f} MB") - if result.disk_size_mb is not None: - print(f"Disk size: {result.disk_size_mb:.1f} MB") - else: - print("Disk size: N/A") + print(f"Disk size: {result.disk_size_mb:.1f} MB") print("=" * 60 + "\n") diff --git a/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/common.py b/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/common.py index ff279726d2..b43798900b 100644 --- a/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/common.py +++ b/sub-packages/bionemo-scspeedtest/src/bionemo/scspeedtest/common.py @@ -87,12 +87,10 @@ class BenchmarkResult: peak_memory_mb: float = 0.0 avg_memory_mb: float = 0.0 disk_size_mb: float = 0.0 - overall_time_seconds: float = 0.0 def __post_init__(self): """Calculate derived metrics from epoch results.""" self.total_samples = sum(r["samples"] for r in self.epoch_results) - print("Elapsed times: ", [r["elapsed"] for r in self.epoch_results]) self.total_time_seconds = sum(r["elapsed"] for r in self.epoch_results) self.samples_per_second = self.total_samples / self.total_time_seconds if self.total_time_seconds > 0 else 0.0 self.peak_memory_mb = max(r["peak_memory"] for r in self.epoch_results) - self.memory_before_instantiation_mb From 42d429bfc991a39fd8c6679a560d0e976c33f20a Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 22 Jan 2026 18:32:20 -0800 Subject: [PATCH 5/5] Bring partition_scdl changes from scdl_remote_profile --- .../src/bionemo/scdl/util/partition_scdl.py | 39 ++++--------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py index 946d48b599..e1f1578858 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py @@ -29,18 +29,8 @@ def partition_scdl( output_path: Path, chunk_size: int = 100_000, delete_original: bool = False, - compressed: bool = False, ) -> SCDLHeader: - """Partition an SCDL dataset into chunks. - - Args: - input_path: Path to source SCDL dataset. - output_path: Path for output chunked dataset. - chunk_size: Number of rows per chunk. - delete_original: Whether to delete the source after partitioning. - compressed: If True, save each chunk as a single compressed .npz file - (faster for remote access - 3x fewer HTTP requests). - """ + """Partition an SCDL dataset into chunks.""" from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset input_path, output_path = Path(input_path), Path(output_path) @@ -71,27 +61,14 @@ def partition_scdl( data_start, data_end = int(rowptr[row_start]), int(rowptr[row_end]) - # Extract chunk data + # Write chunk files using memmap slicing chunk_rowptr = rowptr[row_start : row_end + 1] - data_start - chunk_data = np.array(source_ds.data[data_start:data_end]) - chunk_colptr = np.array(source_ds.col_index[data_start:data_end]) - - if compressed: - # Single compressed file (faster for remote access) - np.savez_compressed( - chunk_dir / "chunk.npz", - data=chunk_data, - row_ptr=chunk_rowptr.astype(source_ds.dtypes[FileNames.ROWPTR.value]), - col_ptr=chunk_colptr, - ) - else: - # Separate files (original format) - with open(chunk_dir / FileNames.ROWPTR.value, "wb") as f: - f.write(chunk_rowptr.astype(source_ds.dtypes[FileNames.ROWPTR.value]).tobytes()) - with open(chunk_dir / FileNames.DATA.value, "wb") as f: - f.write(chunk_data.tobytes()) - with open(chunk_dir / FileNames.COLPTR.value, "wb") as f: - f.write(chunk_colptr.tobytes()) + with open(chunk_dir / FileNames.ROWPTR.value, "wb") as f: + f.write(chunk_rowptr.astype(source_ds.dtypes[FileNames.ROWPTR.value]).tobytes()) + with open(chunk_dir / FileNames.DATA.value, "wb") as f: + f.write(np.array(source_ds.data[data_start:data_end]).tobytes()) + with open(chunk_dir / FileNames.COLPTR.value, "wb") as f: + f.write(np.array(source_ds.col_index[data_start:data_end]).tobytes()) # Copy features and metadata for name in [FileNames.VAR_FEATURES.value, FileNames.OBS_FEATURES.value]: