From 51fd0e3fdcaa3614d6b9c7092e3661df0e2fa8ea Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Sun, 18 Jan 2026 17:35:04 -0800 Subject: [PATCH 1/4] more chunked implementation --- .../scdl/io/single_cell_memmap_dataset.py | 112 +++++++++++++----- .../src/bionemo/scdl/util/partition_scdl.py | 14 ++- .../tests/bionemo/scdl/conftest.py | 19 ++- 3 files changed, 111 insertions(+), 34 deletions(-) diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py index 89fa3f9d61..8f9a181f5f 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py @@ -18,6 +18,7 @@ import logging import os import shutil +import tempfile import warnings from pathlib import Path from typing import Dict, List, Optional, Tuple, Union @@ -41,6 +42,7 @@ determine_dtype, smallest_uint_dtype, ) +from bionemo.scdl.util.partition_scdl import partition_scdl from bionemo.scdl.util.scdl_constants import FLOAT_ORDER, INT_ORDER, FileNames, Mode, NeighborSamplingStrategy @@ -128,6 +130,8 @@ def __init__( self.data_path: str = data_path self.header: SCDLHeader = None self.mode: Mode = mode + self._is_chunked: bool = False + self._chunks: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = [] self.paginated_load_cutoff = paginated_load_cutoff self.load_block_row_size = load_block_row_size self.var_feature_index_name = var_feature_index_name @@ -436,10 +440,16 @@ def get_row( List[np.ndarray]: optional, corresponding variable (column) features. List[np.ndarray]: optional, corresponding observed (row) features. """ - start = self.row_index[index] - end = self.row_index[index + 1] - values = self.data[start:end] - columns = self.col_index[start:end] + if self._is_chunked: + chunk_id, local_idx = self.header.chunked_info.get_chunk_for_row(index) + data, rowptr, colptr = self._chunks[chunk_id] + start, end = rowptr[local_idx], rowptr[local_idx + 1] + values, columns = data[start:end], colptr[start:end] + else: + start = self.row_index[index] + end = self.row_index[index + 1] + values = self.data[start:end] + columns = self.col_index[start:end] ret = (values, columns) var_features = ( self._var_feature_index.lookup(index, select_features=var_feature_names)[0] @@ -685,37 +695,52 @@ def load(self, stored_path: str) -> None: raise ValueError(f"Array name {FileNames[array_info.name].value} not found in dtypes") self.dtypes[FileNames[array_info.name].value] = array_info.dtype.numpy_dtype_string - # Metadata is required, so we must check if it exists and fail if not. - if not os.path.exists(f"{self.data_path}/{FileNames.METADATA.value}"): - raise FileNotFoundError( - f"Error: the metadata file {self.data_path}/{FileNames.METADATA.value} does not exist." - ) - - with open(f"{self.data_path}/{FileNames.METADATA.value}", Mode.READ_APPEND.value) as mfi: - self.metadata = json.load(mfi) + # Load metadata if exists + metadata_path = f"{self.data_path}/{FileNames.METADATA.value}" + if os.path.exists(metadata_path): + with open(metadata_path, Mode.READ_APPEND.value) as mfi: + self.metadata = json.load(mfi) + # Load feature indices if os.path.exists(f"{self.data_path}/{FileNames.VAR_FEATURES.value}"): self._var_feature_index = VariableFeatureIndex.load(f"{self.data_path}/{FileNames.VAR_FEATURES.value}") - elif os.path.exists( - f"{self.data_path}/{FileNames.FEATURES.value}" - ): # Backward compatibility with old features file + elif os.path.exists(f"{self.data_path}/{FileNames.FEATURES.value}"): self._var_feature_index = VariableFeatureIndex.load(f"{self.data_path}/{FileNames.FEATURES.value}") if os.path.exists(f"{self.data_path}/{FileNames.OBS_FEATURES.value}"): self._obs_feature_index = ObservedFeatureIndex.load(f"{self.data_path}/{FileNames.OBS_FEATURES.value}") - # mmap the existing arrays - self.data = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"] - ) - self.row_index = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"] - ) - self.col_index = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"] - ) - # Load neighbor data - if self.load_neighbors: - self._load_neighbor_memmaps() + # Load data arrays - chunked vs monolithic + if self.header is not None and self.header.backend == Backend.CHUNKED_MEMMAP_V0: + self._is_chunked = True + self._load_chunk_memmaps() + else: + self.data = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"] + ) + self.row_index = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"] + ) + self.col_index = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"] + ) + if self.load_neighbors: + self._load_neighbor_memmaps() + + def _load_chunk_memmaps(self) -> None: + """Preload all chunk memmaps (lazy - just file handles, no RAM).""" + for chunk_id in range(self.header.chunked_info.num_chunks): + chunk_path = Path(self.data_path) / f"chunk_{chunk_id:05d}" + self._chunks.append( + ( + np.memmap(chunk_path / FileNames.DATA.value, dtype=self.dtypes[FileNames.DATA.value], mode="r"), + np.memmap( + chunk_path / FileNames.ROWPTR.value, dtype=self.dtypes[FileNames.ROWPTR.value], mode="r" + ), + np.memmap( + chunk_path / FileNames.COLPTR.value, dtype=self.dtypes[FileNames.COLPTR.value], mode="r" + ), + ) + ) def _write_metadata(self) -> None: with open(f"{self.data_path}/{FileNames.METADATA.value}", f"{Mode.CREATE.value}") as mfi: @@ -1218,6 +1243,8 @@ def number_of_rows(self) -> int: ValueError if the length of the number of rows in the feature index does not correspond to the number of stored rows. """ + if self._is_chunked: + return self.header.chunked_info.total_rows if len(self._var_feature_index) > 0 and self._var_feature_index.number_of_rows() != self.row_index.size - 1: raise ValueError( f"""The number of rows in the feature index {self._var_feature_index.number_of_rows()} @@ -1445,3 +1472,32 @@ def concat( mode=Mode.READ_APPEND.value, ) self.save() + + def to_chunked( + self, output_path: Optional[str] = None, chunk_size: int = 100_000, delete_original: bool = False + ) -> "SingleCellMemMapDataset": + """Convert this dataset to a chunked format for efficient remote access. + + Args: + output_path: Path where the chunked dataset will be created. If None, replaces in-place. + chunk_size: Number of rows per chunk (default: 100,000). + delete_original: If True and output_path is set, delete the original after conversion. + + Returns: + A new SingleCellMemMapDataset instance pointing to the chunked data. + """ + if self._is_chunked: + raise ValueError("Dataset is already chunked") + + src = Path(self.data_path) + if output_path is None: + # In-place: partition to temp, then swap + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) / "chunked" + partition_scdl(src, tmp_path, chunk_size=chunk_size) + shutil.rmtree(src) + shutil.move(str(tmp_path), str(src)) + return SingleCellMemMapDataset(str(src)) + + partition_scdl(src, Path(output_path), chunk_size=chunk_size, delete_original=delete_original) + return SingleCellMemMapDataset(output_path) diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py index 1affa2a596..e1f1578858 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py @@ -20,7 +20,6 @@ import numpy as np -from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset from bionemo.scdl.schema.header import ChunkedInfo, SCDLHeader from bionemo.scdl.util.scdl_constants import Backend, FileNames @@ -29,8 +28,11 @@ def partition_scdl( input_path: Path, output_path: Path, chunk_size: int = 100_000, + delete_original: bool = False, ) -> SCDLHeader: """Partition an SCDL dataset into chunks.""" + from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset + input_path, output_path = Path(input_path), Path(output_path) if not input_path.exists(): @@ -44,7 +46,11 @@ def partition_scdl( source_ds = SingleCellMemMapDataset(str(input_path)) total_rows = len(source_ds) rowptr = source_ds.row_index - num_chunks = (total_rows + chunk_size - 1) // chunk_size + if chunk_size <= 0: + raise ValueError(f"Chunk size must be greater than 0, got {chunk_size}") + if total_rows <= 0: + raise ValueError(f"Total rows must be greater than 0, got {total_rows}") + num_chunks = max(1, (total_rows + chunk_size - 1) // chunk_size) # Create chunks for chunk_id in range(num_chunks): @@ -78,4 +84,8 @@ def partition_scdl( header.chunked_info = ChunkedInfo(chunk_size=chunk_size, num_chunks=num_chunks, total_rows=total_rows) header.save(str(output_path / FileNames.HEADER.value)) + if delete_original: + del source_ds # Release memmap handles + shutil.rmtree(input_path) + return header diff --git a/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py b/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py index 3b8e934471..7152d7e53f 100644 --- a/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py +++ b/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py @@ -199,13 +199,24 @@ def _make(tmp_path): @pytest.fixture def make_h5ad_with_raw(make_random_csr): - """Factory to create an h5ad with uniquely randomized data for the fields .raw.X and .X""" + """Factory to create an h5ad with uniquely randomized data for .raw.X, .X, obs, and var.""" def _make(tmp_path): - X = make_random_csr(total_nnz=100, n_cols=50, seed=42) - X_raw = make_random_csr(total_nnz=100, n_cols=50, seed=43) + n_rows, n_cols = 100, 50 + X = make_random_csr(total_nnz=n_rows, n_cols=n_cols, seed=42) + X_raw = make_random_csr(total_nnz=n_rows, n_cols=n_cols, seed=43) + + obs = pd.DataFrame( + {"cell_type": [f"type_{i % 3}" for i in range(n_rows)]}, + index=[f"cell_{i}" for i in range(n_rows)], + ) + var = pd.DataFrame( + {"gene_name": [f"gene_{i}" for i in range(n_cols)]}, + index=[f"ENSG{i:08d}" for i in range(n_cols)], + ) + h = tmp_path / "var.h5ad" - ad.AnnData(X=X, var=pd.DataFrame(index=np.arange(X.shape[1])), raw={"X": X_raw}).write_h5ad(h) + ad.AnnData(X=X, obs=obs, var=var, raw={"X": X_raw}).write_h5ad(h) return h return _make From 2919a918e1ff0b1a224c90a455fe14c86544a6df Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Sun, 18 Jan 2026 18:12:41 -0800 Subject: [PATCH 2/4] adding test file --- .../bionemo/scdl/io/test_chunked_dataset.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 sub-packages/bionemo-scdl/tests/bionemo/scdl/io/test_chunked_dataset.py diff --git a/sub-packages/bionemo-scdl/tests/bionemo/scdl/io/test_chunked_dataset.py b/sub-packages/bionemo-scdl/tests/bionemo/scdl/io/test_chunked_dataset.py new file mode 100644 index 0000000000..c4594e1544 --- /dev/null +++ b/sub-packages/bionemo-scdl/tests/bionemo/scdl/io/test_chunked_dataset.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for chunked SingleCellMemMapDataset functionality.""" + +import numpy as np +import pytest + +from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset +from bionemo.scdl.util.scdl_constants import Backend + + +def test_to_chunked(tmp_path, make_h5ad_with_raw): + """Convert to chunked, verify data and features match.""" + h5ad_path = make_h5ad_with_raw(tmp_path) + original = SingleCellMemMapDataset(tmp_path / "orig", h5ad_path=h5ad_path) + chunked = original.to_chunked(str(tmp_path / "chunked"), chunk_size=30) + + # Basic properties + assert chunked._is_chunked + assert chunked.header.backend == Backend.CHUNKED_MEMMAP_V0 + assert len(chunked) == len(original) + + # Data matches + for idx in range(len(original)): + (orig_vals, orig_cols), _, _ = original.get_row(idx) + (chunk_vals, chunk_cols), _, _ = chunked.get_row(idx) + np.testing.assert_array_equal(orig_vals, chunk_vals) + np.testing.assert_array_equal(orig_cols, chunk_cols) + + # Features preserved + assert len(chunked._var_feature_index) == len(original._var_feature_index) + assert chunked._obs_feature_index.number_of_rows() == original._obs_feature_index.number_of_rows() + + +def test_to_chunked_inplace(tmp_path, make_h5ad_with_raw): + """In-place conversion replaces original with chunked.""" + h5ad_path = make_h5ad_with_raw(tmp_path) + scdl_path = tmp_path / "scdl" + SingleCellMemMapDataset(scdl_path, h5ad_path=h5ad_path) + + chunked = SingleCellMemMapDataset(scdl_path).to_chunked(chunk_size=30) + + assert chunked._is_chunked + assert chunked.data_path == str(scdl_path) + + +def test_to_chunked_already_chunked_raises(tmp_path, make_h5ad_with_raw): + """Cannot chunk an already chunked dataset.""" + h5ad_path = make_h5ad_with_raw(tmp_path) + original = SingleCellMemMapDataset(tmp_path / "orig", h5ad_path=h5ad_path) + chunked = original.to_chunked(str(tmp_path / "chunked"), chunk_size=30) + + with pytest.raises(ValueError, match="already chunked"): + chunked.to_chunked() From ab32f82363fbc7d90622cf839b7221738b304870 Mon Sep 17 00:00:00 2001 From: Polina Binder Date: Thu, 22 Jan 2026 18:40:30 -0800 Subject: [PATCH 3/4] Restore chunked_scdl_benchmark.py --- .../examples/chunked_scdl_benchmark.py | 526 ++++++++++++++++++ 1 file changed, 526 insertions(+) create mode 100644 sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py diff --git a/sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py b/sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py new file mode 100644 index 0000000000..3dacf83dd8 --- /dev/null +++ b/sub-packages/bionemo-scspeedtest/examples/chunked_scdl_benchmark.py @@ -0,0 +1,526 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Benchmark comparing Regular SCDL vs Chunked SCDL vs Remote Chunked SCDL. + +Usage (with defaults): + python chunked_scdl_benchmark.py + +Custom paths: + python chunked_scdl_benchmark.py \ + --scdl-path /path/to/scdl/ \ + --chunked-path /path/to/chunked/ \ + --remote-path s3://bucket/chunked \ + --endpoint-url https://your-s3-endpoint.com + +This script benchmarks: +1. Regular SCDL - Standard DataLoader with shuffle (baseline) +2. Chunked SCDL (local) - Pre-converted chunked dataset with ChunkAwareSampler +3. Remote Chunked SCDL - S3/GCS with LRU caching and ChunkAwareSampler +""" + +import argparse +import os +import time +from datetime import datetime + +from torch.utils.data import DataLoader + +from bionemo.scdl.io.chunk_sampler import ChunkAwareSampler +from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset +from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch +from bionemo.scspeedtest import benchmark_dataloaders_with_configs, print_comparison + + +def create_regular_scdl_factory( + batch_size: int = 64, shuffle: bool = True, data_path: str | None = None, num_workers: int = 0 +): + """Create a regular SCDL dataloader factory (baseline).""" + + def factory(): + dataset = SingleCellMemMapDataset(data_path) + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + drop_last=False, + collate_fn=collate_sparse_matrix_batch, + num_workers=num_workers, + ) + + return factory + + +def create_chunked_scdl_preconverted_factory( + batch_size: int = 64, + chunked_path: str | None = None, + num_workers: int = 0, + shuffle_chunks: bool = True, + shuffle_within_window: bool = True, + chunks_per_window: int = 2, +): + """Create a chunked SCDL dataloader factory from pre-converted chunked dataset.""" + + def factory(): + dataset = SingleCellMemMapDataset(chunked_path) + + sampler = ChunkAwareSampler( + dataset, + shuffle_chunks=shuffle_chunks, + shuffle_within_window=shuffle_within_window, + chunks_per_window=chunks_per_window, + ) + + return DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + collate_fn=collate_sparse_matrix_batch, + num_workers=num_workers, + ) + + return factory + + +def create_chunked_scdl_random_factory( + batch_size: int = 64, + chunked_path: str | None = None, + num_workers: int = 0, +): + """Create a chunked SCDL dataloader with random shuffle (no ChunkAwareSampler).""" + start_time = time.perf_counter() + + def factory(): + dataset = SingleCellMemMapDataset(chunked_path) + + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, # Standard random shuffle - worst case for I/O locality + collate_fn=collate_sparse_matrix_batch, + num_workers=num_workers, + ) + + end_time = time.perf_counter() + print(f"Time taken to instantiate chunked SCDL dataset: {end_time - start_time:.2f} seconds") + return factory + + +class RemoteDataloaderFactory: + """Factory that tracks the last created dataset for stats access.""" + + def __init__( + self, + batch_size: int = 64, + remote_path: str | None = None, + cache_dir: str | None = None, + max_cached_chunks: int = 3, + storage_options: dict | None = None, + num_workers: int = 0, + shuffle_chunks: bool = True, + shuffle_within_window: bool = True, + chunks_per_window: int = 2, + batch_download_size: int = 30, + ): + """Initialize the remote dataloader factory with configuration.""" + self.batch_size = batch_size + self.remote_path = remote_path + self.cache_dir = cache_dir + self.max_cached_chunks = max_cached_chunks + self.storage_options = storage_options + self.num_workers = num_workers + self.shuffle_chunks = shuffle_chunks + self.shuffle_within_window = shuffle_within_window + self.chunks_per_window = chunks_per_window + self.batch_download_size = batch_download_size + self.last_dataset = None # Store reference for stats access + # Timing breakdown + self.cache_clear_time = 0.0 + self.dataset_init_time = 0.0 + self.sampler_init_time = 0.0 + self.dataloader_init_time = 0.0 + self.total_init_time = 0.0 + + def __call__(self): + """Create a new dataloader.""" + import shutil + + # Clear cache before each run to measure streaming performance + import time + + total_start = time.perf_counter() + + # 1. Clear cache + if self.cache_dir: + print(f"Clearing cache directory: {self.cache_dir}") + t0 = time.perf_counter() + shutil.rmtree(self.cache_dir, ignore_errors=True) + t1 = time.perf_counter() + self.cache_clear_time = t1 - t0 + print(f"Cache clear time: {self.cache_clear_time:.3f} sec") + + # 2. Create dataset + print("Instantiating SingleCellMemMapDataset.from_remote ...") + t2 = time.perf_counter() + self.last_dataset = SingleCellMemMapDataset.from_remote( + self.remote_path, + cache_dir=self.cache_dir, + max_cached_chunks=self.max_cached_chunks, + storage_options=self.storage_options, + batch_download_size=self.batch_download_size, + use_async_downloads=True, + ) + t3 = time.perf_counter() + self.dataset_init_time = t3 - t2 + print(f"SingleCellMemMapDataset.from_remote time: {self.dataset_init_time:.3f} sec") + + # 3. Create sampler + print("Instantiating ChunkAwareSampler ...") + t4 = time.perf_counter() + sampler = ChunkAwareSampler( + self.last_dataset, + shuffle_chunks=self.shuffle_chunks, + shuffle_within_window=self.shuffle_within_window, + chunks_per_window=self.chunks_per_window, + ) + t5 = time.perf_counter() + self.sampler_init_time = t5 - t4 + print(f"ChunkAwareSampler instantiation time: {self.sampler_init_time:.3f} sec") + + # 4. Create DataLoader + print("Instantiating DataLoader ...") + t6 = time.perf_counter() + dataloader = DataLoader( + self.last_dataset, + sampler=sampler, + batch_size=self.batch_size, + collate_fn=collate_sparse_matrix_batch, + num_workers=self.num_workers, + ) + t7 = time.perf_counter() + self.dataloader_init_time = t7 - t6 + self.total_init_time = t7 - total_start + print(f"DataLoader instantiation time: {self.dataloader_init_time:.3f} sec") + print(f"Total init time: {self.total_init_time:.3f} sec") + + return dataloader + + def get_download_stats(self) -> dict | None: + """Get download statistics from the last created dataset.""" + if self.last_dataset and hasattr(self.last_dataset, "_chunk_loader"): + return self.last_dataset._chunk_loader.stats.summary() + return None + + def print_download_stats(self): + """Print download statistics.""" + stats = self.get_download_stats() + if stats: + iteration_time = stats["wall_clock_time_sec"] + total_time = self.total_init_time + iteration_time + + print("\n" + "=" * 50) + print("REMOTE DOWNLOAD STATS") + print("=" * 50) + print(f"Init: {self.total_init_time:.1f}s | Iteration: {iteration_time:.1f}s | Total: {total_time:.1f}s") + print(f"Cold start: {stats['cold_start_time_sec']:.1f}s | Wait time: {stats['total_wait_time_sec']:.1f}s") + print( + f"Downloaded: {stats['total_bytes_downloaded_mb']:.0f} MB ({stats['download_count']} chunks, {stats['cache_hits']} cache hits)" + ) + print( + f"Throughput: {stats['throughput_mbps']:.1f} MB/s effective, {stats['per_thread_throughput_mbps']:.1f} MB/s per-thread" + ) + + +def create_remote_chunked_scdl_factory( + batch_size: int = 64, + remote_path: str | None = None, + cache_dir: str | None = None, + max_cached_chunks: int = 3, + storage_options: dict | None = None, + num_workers: int = 0, + shuffle_chunks: bool = True, + shuffle_within_window: bool = True, + chunks_per_window: int = 2, + batch_download_size: int = 30, +) -> RemoteDataloaderFactory: + """Create a remote chunked SCDL dataloader factory with ChunkAwareSampler.""" + return RemoteDataloaderFactory( + batch_size=batch_size, + remote_path=remote_path, + cache_dir=cache_dir, + max_cached_chunks=max_cached_chunks, + storage_options=storage_options, + num_workers=num_workers, + shuffle_chunks=shuffle_chunks, + shuffle_within_window=shuffle_within_window, + chunks_per_window=chunks_per_window, + batch_download_size=batch_download_size, + ) + + +def chunked_scdl_benchmarking( + num_epochs: int = 1, + num_runs: int = 1, + scdl_path: str | None = None, + chunked_path: str | None = None, + remote_path: str | None = None, + endpoint_url: str | None = None, + cache_dir: str = "/tmp/scdl_cache", + max_cached_chunks: int = 3, + chunks_per_window: int = 2, + max_time_seconds: float = 120.0, + warmup_time_seconds: float = 30.0, + batch_size: int = 64, +): + """Run benchmarks comparing regular SCDL vs chunked SCDL. + + Args: + num_epochs: Number of epochs per configuration + num_runs: Number of runs per configuration + scdl_path: Path to regular (non-chunked) SCDL dataset + chunked_path: Path to pre-converted chunked SCDL dataset (optional) + remote_path: Remote path to chunked dataset (s3://, gs://, etc.) + endpoint_url: Custom S3 endpoint URL (for non-AWS S3) + cache_dir: Local cache directory for remote chunks + max_cached_chunks: Max chunks to keep in LRU cache + chunks_per_window: Chunks per sampling window + max_time_seconds: Max time per configuration + warmup_time_seconds: Warmup time per configuration + batch_size: Batch size for dataloaders + """ + print("=" * 80) + print("CHUNKED SCDL BENCHMARKING") + print("=" * 80) + print() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + configurations = [] + + # Configuration 1: Regular SCDL baseline + if scdl_path: + print(f"Adding Regular SCDL baseline: {scdl_path}") + configurations.append( + { + "name": "Regular_SCDL_Baseline", + "dataloader_factory": create_regular_scdl_factory( + batch_size=batch_size, shuffle=True, data_path=scdl_path, num_workers=0 + ), + "num_epochs": num_epochs, + "max_time_seconds": max_time_seconds, + "warmup_time_seconds": warmup_time_seconds, + "data_path": scdl_path, + "num_runs": num_runs, + } + ) + + # Configuration 2: Pre-converted chunked SCDL with ChunkAwareSampler + if chunked_path: + print(f"Adding Chunked SCDL + ChunkAwareSampler: {chunked_path}") + configurations.append( + { + "name": f"Chunked_SCDL_ChunkAware_window{chunks_per_window}", + "dataloader_factory": create_chunked_scdl_preconverted_factory( + batch_size=batch_size, + chunked_path=chunked_path, + num_workers=0, + shuffle_chunks=True, + shuffle_within_window=True, + chunks_per_window=chunks_per_window, + ), + "num_epochs": num_epochs, + "max_time_seconds": max_time_seconds, + "warmup_time_seconds": warmup_time_seconds, + "data_path": chunked_path, + "num_runs": num_runs, + } + ) + + # Configuration 3: Chunked SCDL with random shuffle (no ChunkAwareSampler) + print(f"Adding Chunked SCDL + Random Shuffle: {chunked_path}") + configurations.append( + { + "name": "Chunked_SCDL_RandomShuffle", + "dataloader_factory": create_chunked_scdl_random_factory( + batch_size=batch_size, + chunked_path=chunked_path, + num_workers=0, + ), + "num_epochs": num_epochs, + "max_time_seconds": max_time_seconds, + "warmup_time_seconds": warmup_time_seconds, + "data_path": chunked_path, + "num_runs": num_runs, + } + ) + + # Configuration 4: Remote chunked SCDL + remote_factory = None + if remote_path: + storage_options = { + "default_fill_cache": False, # Don't cache file contents in memory + "default_cache_type": "none", # No block caching + "config_kwargs": {"max_pool_connections": 100}, # More parallel connections + } + if endpoint_url: + storage_options["client_kwargs"] = {"endpoint_url": endpoint_url} + + print(f"Adding Remote Chunked SCDL: {remote_path}") + if endpoint_url: + print(f" Endpoint: {endpoint_url}") + print(f" Cache dir: {cache_dir}") + print(f" Max cached chunks: {max_cached_chunks}") + remote_factory = create_remote_chunked_scdl_factory( + batch_size=batch_size, + remote_path=remote_path, + cache_dir=cache_dir, + max_cached_chunks=max_cached_chunks, + storage_options=storage_options, + num_workers=0, + shuffle_chunks=True, + shuffle_within_window=True, + chunks_per_window=chunks_per_window, + batch_download_size=max_cached_chunks, # Download batch = cache size + ) + + configurations.append( + { + "name": f"Chunked_SCDL_Remote_cache{max_cached_chunks}_window{chunks_per_window}", + "dataloader_factory": remote_factory, + "num_epochs": num_epochs, + "max_time_seconds": max_time_seconds, + "warmup_time_seconds": warmup_time_seconds, + "data_path": remote_path, + "num_runs": num_runs, + } + ) + + if not configurations: + print("ERROR: No configurations to benchmark. Provide --scdl-path, --chunked-path, or --remote-path") + return + + print() + print(f"Running {len(configurations)} configuration(s)...") + print() + + results = benchmark_dataloaders_with_configs( + dataloader_configs=configurations, + shared_dataset_factory=None, # Each config loads its own dataset + output_prefix=f"chunked_scdl_benchmark_{timestamp}", + ) + + print() + print("=" * 80) + print("RESULTS SUMMARY") + print("=" * 80) + print_comparison(results) + + # Print remote download statistics if available + if remote_factory: + # Record wall_clock_time from benchmark results (since sampler may not complete naturally) + for result in results: + if "Remote" in result.name and remote_factory.last_dataset and remote_factory.last_dataset._chunk_loader: + remote_factory.last_dataset._chunk_loader.stats.record_wall_clock(result.total_time_seconds) + break + remote_factory.print_download_stats() + + print() + print(f"Results saved to: chunked_scdl_benchmark_{timestamp}_detailed_breakdown.csv") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark Regular SCDL vs Chunked SCDL with ChunkAwareSampler", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Benchmark local regular vs local chunked + %(prog)s --scdl-path /data/scdl/ --chunked-path /data/chunked_scdl/ + + # Benchmark remote chunked dataset + %(prog)s --remote-path s3://bucket/chunked --endpoint-url https://s3.example.com + + # Full comparison + %(prog)s --scdl-path /data/scdl/ --chunked-path /data/chunked/ --remote-path s3://bucket/chunked + """, + ) + + # Data paths - with sensible defaults + parser.add_argument( + "--scdl-path", + type=str, + default="", # "/home/pbinder/bionemo-framework/small_tahoe_format", + help="Path to regular SCDL dataset (baseline)", + ) + parser.add_argument( + "--chunked-path", + type=str, + default="", # "/home/pbinder/bionemo-framework/sub-packages/bionemo-scspeedtest/example_data/tahoe_chunked", + help="Path to pre-converted chunked SCDL dataset", + ) + parser.add_argument( + "--remote-path", + type=str, + default="s3://general-purpose/polina/tahoe_chunked", + help="Remote path to chunked dataset (s3://, gs://)", + ) + parser.add_argument("--endpoint-url", type=str, default="https://pbss.s8k.io", help="Custom S3 endpoint URL") + + # Cache settings + parser.add_argument("--cache-dir", type=str, default="/tmp/scdl_cache", help="Local cache for remote chunks") + parser.add_argument( + "--max-cached-chunks", + type=int, + default=2, + help="Max chunks in LRU cache (>= 2x chunks-per-window for prefetching)", + ) + + # Chunking settings + parser.add_argument("--chunks-per-window", type=int, default=1, help="Chunks per sampling window") + + # Benchmark settings + parser.add_argument("--num-epochs", type=int, default=1, help="Epochs per configuration") + parser.add_argument("--num-runs", type=int, default=1, help="Runs per configuration") + parser.add_argument("--max-time", type=float, default=120, help="Max time per config (seconds)") + parser.add_argument("--warmup-time", type=float, default=0, help="Warmup time per config (seconds)") + parser.add_argument("--batch-size", type=int, default=64, help="Batch size") + + args = parser.parse_args() + + # Validate paths exist (for local paths) + if args.scdl_path and not os.path.exists(args.scdl_path): + print(f"Warning: SCDL path not found: {args.scdl_path}, skipping...") + args.scdl_path = None + if args.chunked_path and not os.path.exists(args.chunked_path): + print(f"Warning: Chunked path not found: {args.chunked_path}, skipping...") + args.chunked_path = None + + # Check at least one valid config + if not any([args.scdl_path, args.chunked_path, args.remote_path]): + parser.error("No valid data paths found. Provide --scdl-path, --chunked-path, or --remote-path") + + chunked_scdl_benchmarking( + num_epochs=args.num_epochs, + num_runs=args.num_runs, + scdl_path=args.scdl_path, + chunked_path=args.chunked_path, + remote_path=args.remote_path, + endpoint_url=args.endpoint_url, + cache_dir=args.cache_dir, + max_cached_chunks=args.max_cached_chunks, + chunks_per_window=args.chunks_per_window, + max_time_seconds=args.max_time, + warmup_time_seconds=args.warmup_time, + batch_size=args.batch_size, + ) From 401b84a018d153846a6e622f090787b23dd36352 Mon Sep 17 00:00:00 2001 From: polinabinder1 Date: Thu, 22 Jan 2026 18:42:09 -0800 Subject: [PATCH 4/4] Delete sub-packages/bionemo-scdl/tests/bionemo/scdl/io/test_chunked_dataset.py Signed-off-by: polinabinder1 --- .../bionemo/scdl/io/test_chunked_dataset.py | 67 ------------------- 1 file changed, 67 deletions(-) delete mode 100644 sub-packages/bionemo-scdl/tests/bionemo/scdl/io/test_chunked_dataset.py diff --git a/sub-packages/bionemo-scdl/tests/bionemo/scdl/io/test_chunked_dataset.py b/sub-packages/bionemo-scdl/tests/bionemo/scdl/io/test_chunked_dataset.py deleted file mode 100644 index c4594e1544..0000000000 --- a/sub-packages/bionemo-scdl/tests/bionemo/scdl/io/test_chunked_dataset.py +++ /dev/null @@ -1,67 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-Apache2 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for chunked SingleCellMemMapDataset functionality.""" - -import numpy as np -import pytest - -from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset -from bionemo.scdl.util.scdl_constants import Backend - - -def test_to_chunked(tmp_path, make_h5ad_with_raw): - """Convert to chunked, verify data and features match.""" - h5ad_path = make_h5ad_with_raw(tmp_path) - original = SingleCellMemMapDataset(tmp_path / "orig", h5ad_path=h5ad_path) - chunked = original.to_chunked(str(tmp_path / "chunked"), chunk_size=30) - - # Basic properties - assert chunked._is_chunked - assert chunked.header.backend == Backend.CHUNKED_MEMMAP_V0 - assert len(chunked) == len(original) - - # Data matches - for idx in range(len(original)): - (orig_vals, orig_cols), _, _ = original.get_row(idx) - (chunk_vals, chunk_cols), _, _ = chunked.get_row(idx) - np.testing.assert_array_equal(orig_vals, chunk_vals) - np.testing.assert_array_equal(orig_cols, chunk_cols) - - # Features preserved - assert len(chunked._var_feature_index) == len(original._var_feature_index) - assert chunked._obs_feature_index.number_of_rows() == original._obs_feature_index.number_of_rows() - - -def test_to_chunked_inplace(tmp_path, make_h5ad_with_raw): - """In-place conversion replaces original with chunked.""" - h5ad_path = make_h5ad_with_raw(tmp_path) - scdl_path = tmp_path / "scdl" - SingleCellMemMapDataset(scdl_path, h5ad_path=h5ad_path) - - chunked = SingleCellMemMapDataset(scdl_path).to_chunked(chunk_size=30) - - assert chunked._is_chunked - assert chunked.data_path == str(scdl_path) - - -def test_to_chunked_already_chunked_raises(tmp_path, make_h5ad_with_raw): - """Cannot chunk an already chunked dataset.""" - h5ad_path = make_h5ad_with_raw(tmp_path) - original = SingleCellMemMapDataset(tmp_path / "orig", h5ad_path=h5ad_path) - chunked = original.to_chunked(str(tmp_path / "chunked"), chunk_size=30) - - with pytest.raises(ValueError, match="already chunked"): - chunked.to_chunked()