# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Chunk-aware sampler for efficient iteration over chunked SCDL datasets."""

import random
import time
import warnings
from pathlib import Path
from typing import Iterator, List, Optional

import torch.utils.data
from torch.utils.data import Sampler

from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset


def remote_worker_init_fn(worker_id: int) -> None:
    """Initialize per-worker cache directories for remote datasets.

    Use this with DataLoader's worker_init_fn to prevent cache conflicts
    when using multiple workers with remote chunked SCDL:

        DataLoader(
            dataset,
            num_workers=4,
            worker_init_fn=remote_worker_init_fn,
            ...
        )

    Each worker gets its own cache directory: {base_cache}_worker_{id}

    Args:
        worker_id: Index of the DataLoader worker being initialized.
    """
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is None:
        return

    dataset = worker_info.dataset
    if hasattr(dataset, "_chunk_loader") and dataset._chunk_loader is not None:
        # Give this worker its own cache directory so workers never write
        # the same local chunk files concurrently.
        base_cache = dataset._chunk_loader.cache_dir
        worker_cache = Path(str(base_cache) + f"_worker_{worker_id}")
        worker_cache.mkdir(parents=True, exist_ok=True)
        dataset._chunk_loader.cache_dir = worker_cache
        # Drop cache entries inherited from the parent process: they point at
        # files in the parent's cache directory. NOTE: the loader's cache
        # attribute is `_cache` (an OrderedDict); the previous name
        # `_cached_chunks` does not exist and raised AttributeError.
        dataset._chunk_loader._cache.clear()


class ChunkAwareSampler(Sampler[int]):
    """Sampler that iterates by chunks for efficient access patterns.

    This sampler ensures all rows from a chunk window are accessed together
    before moving to the next window. This is optimal for:
    - Local: memory locality (chunk data stays in cache)
    - Remote: prefetching (download chunks once, use all rows)

    Args:
        dataset: A chunked SingleCellMemMapDataset.
        shuffle_chunks: Whether to shuffle chunk order each epoch.
        shuffle_within_window: Whether to shuffle rows within each chunk window.
        chunks_per_window: Number of chunks to load together (more = better randomness).
        seed: Random seed for reproducibility.
    """

    def __init__(
        self,
        dataset: SingleCellMemMapDataset,
        shuffle_chunks: bool = True,
        shuffle_within_window: bool = True,
        chunks_per_window: int = 1,
        seed: Optional[int] = None,
    ):
        """Initialize the chunk aware sampler.

        Raises:
            ValueError: If ``dataset`` is not a chunked dataset.
        """
        if not dataset._is_chunked:
            raise ValueError("ChunkAwareSampler requires a chunked dataset")

        self.dataset = dataset
        self.shuffle_chunks = shuffle_chunks
        self.shuffle_within_window = shuffle_within_window
        # Clamp to at least one chunk per window; warnings below use the
        # clamped value so the reported numbers match actual behavior.
        self.chunks_per_window = max(1, chunks_per_window)
        self.rng = random.Random(seed)
        self.chunked_info = dataset.header.chunked_info

        # Warn if chunks_per_window exceeds cache size (causes thrashing)
        if dataset._chunk_loader and self.chunks_per_window > dataset._chunk_loader.max_cached_chunks:
            warnings.warn(
                f"chunks_per_window ({self.chunks_per_window}) > max_cached_chunks "
                f"({dataset._chunk_loader.max_cached_chunks}). This causes cache thrashing. "
                f"Increase max_cached_chunks or decrease chunks_per_window."
            )
        # Warn if cache too small for effective prefetching (need room for the
        # current window plus the next one being downloaded in the background).
        if dataset._chunk_loader and 2 * self.chunks_per_window > dataset._chunk_loader.max_cached_chunks:
            warnings.warn(
                f"max_cached_chunks ({dataset._chunk_loader.max_cached_chunks}) < 2 * chunks_per_window "
                f"({2 * self.chunks_per_window}). Prefetching disabled - no room for next window. "
                f"Set max_cached_chunks >= {2 * self.chunks_per_window} for prefetching."
            )

    def _preload_memmaps_for_chunks(self, chunk_ids: List[int]) -> None:
        """Preload memmaps for chunks into the dataset's _remote_chunk_cache.

        This should be called after prefetch completes to avoid creating
        memmaps during iteration. Files are already on disk; get_chunk just
        returns the cached memmap objects.
        """
        if not hasattr(self.dataset, "_remote_chunk_cache"):
            return

        for chunk_id in chunk_ids:
            if chunk_id not in self.dataset._remote_chunk_cache:
                # get_chunk returns the (data, rowptr, colptr) memmap tuple
                # directly (it is NOT a path); the chunk was already
                # downloaded by prefetch, so this is a cache hit.
                memmaps = self.dataset._chunk_loader.get_chunk(chunk_id)
                self.dataset._remote_chunk_cache[chunk_id] = memmaps

    def __iter__(self) -> Iterator[int]:
        """Yield row indices, grouped by chunk window.

        In multi-worker mode, each worker handles a disjoint subset of chunks.
        This avoids cache conflicts and ensures efficient parallel loading.
        """
        chunk_ids = list(range(self.chunked_info.num_chunks))

        if self.shuffle_chunks:
            self.rng.shuffle(chunk_ids)

        # Multi-worker support: each worker handles a subset of chunks
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Split chunks among workers: worker i gets chunks[i::num_workers]
            chunk_ids = chunk_ids[worker_info.id :: worker_info.num_workers]

        # Build list of windows from this worker's chunks
        windows = []
        for i in range(0, len(chunk_ids), self.chunks_per_window):
            windows.append(chunk_ids[i : i + self.chunks_per_window])

        # Check if we have room for prefetching (need 2x window size in cache)
        can_prefetch = (
            self.dataset._chunk_loader
            and 2 * self.chunks_per_window <= self.dataset._chunk_loader.max_cached_chunks
        )

        iteration_start = time.perf_counter()

        # Process windows with pipelined prefetching (1 window ahead)
        for window_idx, window_chunks in enumerate(windows):
            if self.dataset._chunk_loader:
                # For first window, prefetch synchronously since we need it immediately
                if window_idx == 0:
                    cold_start = time.perf_counter()
                    self.dataset._chunk_loader.prefetch_chunks(window_chunks)
                    elapsed = time.perf_counter() - cold_start
                    self.dataset._chunk_loader.stats.record_cold_start(elapsed)
                else:
                    # Wait for async prefetch started in previous iteration
                    self.dataset._chunk_loader.wait_for_prefetch()

                # Preload memmaps for current window (now that files are downloaded)
                # This avoids creating memmaps during iteration
                self._preload_memmaps_for_chunks(window_chunks)

                # Start async prefetch of NEXT window while we process this one
                if can_prefetch and window_idx + 1 < len(windows):
                    self.dataset._chunk_loader.prefetch_chunks_async(windows[window_idx + 1])

            # Gather all indices from this window
            all_indices = []
            for chunk_id in window_chunks:
                start = chunk_id * self.chunked_info.chunk_size
                end = min(start + self.chunked_info.chunk_size, self.chunked_info.total_rows)
                all_indices.extend(range(start, end))

            if self.shuffle_within_window:
                self.rng.shuffle(all_indices)

            yield from all_indices

            # Mark this window's chunks as safe to evict (prefer these over current/next window)
            if self.dataset._chunk_loader:
                self.dataset._chunk_loader.mark_chunks_done(window_chunks)

        # Record total wall-clock time
        if self.dataset._chunk_loader:
            total_time = time.perf_counter() - iteration_start
            self.dataset._chunk_loader.stats.record_wall_clock(total_time)

    def __len__(self) -> int:
        """Return total number of samples.

        Note: In multi-worker mode, each worker sees only its subset,
        but PyTorch DataLoader handles combining results correctly.
        """
        return self.chunked_info.total_rows
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Remote chunk loader with LRU caching for chunked SCDL datasets.

NOTE: This is a simple POC implementation. For production multi-worker/multi-node use:
- Add file locking for shared cache (filelock)
- Add reference counting to prevent evicting in-use chunks
- Use DistributedChunkSampler to shard chunks across nodes
"""

import io
import logging
import shutil
import tempfile
import threading
import time
from collections import OrderedDict
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from pathlib import Path
from typing import ClassVar, Dict, List, Optional, Tuple

import fsspec
import numpy as np

from bionemo.scdl.util.scdl_constants import FileNames


logger = logging.getLogger(__name__)


@dataclass
class DownloadStats:
    """Statistics for download timing."""

    total_download_time: float = 0.0  # Cumulative time across all download threads
    total_wait_time: float = 0.0  # Time spent blocked waiting for downloads
    cold_start_time: float = 0.0  # Time to download first window
    wall_clock_time: float = 0.0  # Total wall-clock iteration time
    total_bytes_downloaded: int = 0
    download_count: int = 0
    cache_hits: int = 0
    _lock: threading.Lock = field(default_factory=threading.Lock)

    def record_download(self, duration: float, bytes_downloaded: int = 0) -> None:
        """Record a download completion."""
        with self._lock:
            self.total_download_time += duration
            self.total_bytes_downloaded += bytes_downloaded
            self.download_count += 1

    def record_cold_start(self, duration: float) -> None:
        """Record cold start time (first window download)."""
        # Take the lock for consistency with the other recorders, even though
        # this is normally called from a single (iteration) thread.
        with self._lock:
            self.cold_start_time = duration

    def record_wall_clock(self, duration: float) -> None:
        """Record total wall-clock time."""
        with self._lock:
            self.wall_clock_time = duration

    def record_wait(self, duration: float) -> None:
        """Record time spent waiting for a chunk."""
        with self._lock:
            self.total_wait_time += duration

    def record_cache_hit(self) -> None:
        """Record a cache hit."""
        with self._lock:
            self.cache_hits += 1

    def summary(self) -> dict:
        """Return a summary of download statistics."""
        total_mb = self.total_bytes_downloaded / 1e6
        effective_throughput = total_mb / self.wall_clock_time if self.wall_clock_time > 0 else 0
        per_thread_throughput = total_mb / self.total_download_time if self.total_download_time > 0 else 0

        return {
            "cold_start_time_sec": round(self.cold_start_time, 2),
            "wall_clock_time_sec": round(self.wall_clock_time, 2),
            "total_wait_time_sec": round(self.total_wait_time, 2),
            "total_bytes_downloaded_mb": round(total_mb, 2),
            "download_count": self.download_count,
            "cache_hits": self.cache_hits,
            "throughput_mbps": round(effective_throughput, 2),
            "per_thread_throughput_mbps": round(per_thread_throughput, 2),
        }


class RemoteChunkLoader:
    """Downloads and caches chunks from remote storage with LRU eviction.

    Args:
        remote_path: Remote path (s3://bucket/path, gs://bucket/path, etc.)
        cache_dir: Local directory for caching chunks. If None, uses temp directory.
        max_cached_chunks: Maximum number of chunks to keep in cache.
        storage_options: Optional dict of options passed to fsspec (e.g., endpoint_url for S3).
    """

    # Files in each chunk directory (uncompressed format)
    CHUNK_FILES: ClassVar[List[str]] = [FileNames.DATA.value, FileNames.ROWPTR.value, FileNames.COLPTR.value]
    # Compressed format (single file)
    COMPRESSED_FILE: ClassVar[str] = "chunk.npz"

    # Type alias for cached chunk data: (file_path, (data, rowptr, colptr) memmaps)
    CacheEntry = Tuple[Path, Tuple[np.ndarray, np.ndarray, np.ndarray]]

    def __init__(
        self,
        remote_path: str,
        cache_dir: Optional[Path] = None,
        max_cached_chunks: int = 2,
        storage_options: Optional[dict] = None,
        batch_download_size: int = 10,
        use_async_downloads: bool = True,
        dtypes: Optional[Dict[str, str]] = None,
    ):
        """Initialize the remote chunk loader.

        Args:
            remote_path: Remote path (s3://bucket/path, gs://bucket/path, etc.)
            cache_dir: Local directory for caching chunks. If None, uses temp directory.
            max_cached_chunks: Maximum number of chunks to keep in cache.
            storage_options: Optional dict of options passed to fsspec.
            batch_download_size: Number of chunks to download at once (unused, for API compat).
            use_async_downloads: Whether to use async downloads (unused, for API compat).
            dtypes: Optional dict mapping file names to dtypes for memmap loading.
        """
        self.remote_path = remote_path.rstrip("/")
        self.cache_dir = Path(cache_dir) if cache_dir else Path(tempfile.mkdtemp(prefix="scdl_cache_"))
        self.max_cached_chunks = max_cached_chunks
        self.batch_download_size = batch_download_size
        self._use_async = use_async_downloads
        self.dtypes = dtypes or {}
        # Cache stores both path (for cleanup) and memmaps (for access)
        self._cache: OrderedDict[int, "RemoteChunkLoader.CacheEntry"] = OrderedDict()
        self._cache_lock = threading.Lock()  # Protect cache access
        self._downloading: set = set()  # Chunks currently being downloaded
        self._download_complete = threading.Condition(self._cache_lock)
        self._evictable_chunks: set = set()  # Chunks marked as "done" - prefer evicting these

        # Initialize filesystem with optional storage options
        protocol = remote_path.split("://")[0] if "://" in remote_path else "file"
        self._fs = fsspec.filesystem(protocol, **(storage_options or {}))

        # Ensure cache directory exists
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # Async prefetching. Eviction gets its own single worker so that
        # get_chunk's background deletions never queue behind (and then block
        # on) an in-flight prefetch on the prefetch executor.
        self._prefetch_executor = ThreadPoolExecutor(max_workers=1)
        self._eviction_executor = ThreadPoolExecutor(max_workers=1)
        self._prefetch_future: Optional[Future] = None
        self._prefetch_lock = threading.Lock()

        # Download statistics
        self.stats = DownloadStats()

    def set_dtypes(self, dtypes: Dict[str, str]) -> None:
        """Set dtypes for memmap loading (can be called after header is loaded).

        Args:
            dtypes: Dict mapping file names to dtypes.
        """
        self.dtypes = dtypes

    def mark_chunks_done(self, chunk_ids: List[int]) -> None:
        """Mark chunks as done/evictable.

        Call this when you've finished processing a window of chunks and won't
        need them again until the next epoch. Eviction will prefer these chunks.

        Args:
            chunk_ids: List of chunk IDs that can be safely evicted.
        """
        with self._cache_lock:
            self._evictable_chunks.update(chunk_ids)

    def _find_evictable_chunk(self) -> Optional[Tuple[int, "RemoteChunkLoader.CacheEntry"]]:
        """Find a chunk to evict, preferring chunks marked as done.

        Caller must hold the cache lock.

        Returns:
            Tuple of (chunk_id, (path, memmaps)) or None if no chunks available.
        """
        # First, try to evict chunks marked as "done"
        for chunk_id in list(self._cache.keys()):
            if chunk_id in self._evictable_chunks:
                entry = self._cache.pop(chunk_id)
                self._evictable_chunks.discard(chunk_id)
                return chunk_id, entry

        # Fallback: evict oldest chunk (LRU) - shouldn't happen if sized correctly
        if self._cache:
            return self._cache.popitem(last=False)
        return None

    def _evict_chunk(self, chunk_id: int, cache_entry: "RemoteChunkLoader.CacheEntry") -> None:
        """Evict a chunk from cache, releasing memmaps before deleting files."""
        self._evictable_chunks.discard(chunk_id)
        chunk_path, memmaps = cache_entry
        # Release our reference to the memmaps first (closes file handles when
        # this was the last reference; callers may still hold views).
        del memmaps
        # Now safe to delete files
        shutil.rmtree(chunk_path, ignore_errors=True)

    def get_chunk(self, chunk_id: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Get chunk as memory-mapped arrays, downloading if needed.

        Args:
            chunk_id: The chunk index to retrieve.

        Returns:
            Tuple of (data, rowptr, colptr) as memory-mapped numpy arrays.
        """
        start_time = time.perf_counter()

        # Collect chunks to evict inside lock, delete outside to avoid blocking
        chunks_to_evict = []

        with self._download_complete:
            # Wait if prefetch is downloading this chunk
            while chunk_id in self._downloading and chunk_id not in self._cache:
                self._download_complete.wait(timeout=0.1)

            # Check cache - return memmaps directly
            if chunk_id in self._cache:
                self._cache.move_to_end(chunk_id)
                self.stats.record_cache_hit()
                _path, memmaps = self._cache[chunk_id]
                return memmaps

            # Evict chunks if at capacity (prefer evicting "done" chunks)
            while len(self._cache) >= self.max_cached_chunks and self._cache:
                evict_result = self._find_evictable_chunk()
                if evict_result is None:
                    break
                old_id, old_entry = evict_result
                chunks_to_evict.append((old_id, old_entry))

            # Mark as downloading
            self._downloading.add(chunk_id)

        # Delete evicted chunks in background thread while we download.
        # Uses the dedicated eviction executor, not the prefetch executor,
        # so result() below cannot stall behind an in-flight prefetch.
        eviction_future = None
        if chunks_to_evict:
            eviction_future = self._eviction_executor.submit(
                lambda: [self._evict_chunk(old_id, old_entry) for old_id, old_entry in chunks_to_evict]
            )

        try:
            # Download chunk (outside lock)
            local_path = self._download_chunk(chunk_id)

            # Load memmaps
            memmaps = self._load_chunk_memmaps(local_path)

            # Wait for eviction to complete (if any)
            if eviction_future:
                eviction_future.result()

            with self._download_complete:
                self._cache[chunk_id] = (local_path, memmaps)
                self._downloading.discard(chunk_id)
                self._download_complete.notify_all()

            # Record wait time (includes download time for non-prefetched chunks)
            wait_time = time.perf_counter() - start_time
            self.stats.record_wait(wait_time)

            return memmaps
        except Exception:
            with self._download_complete:
                self._downloading.discard(chunk_id)
                self._download_complete.notify_all()
            raise

    def _load_chunk_memmaps(self, chunk_path: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Load memmaps for a single chunk directory.

        Requires dtypes to be set (raw memmap files carry no dtype metadata).

        Raises:
            ValueError: If dtypes have not been set via __init__ or set_dtypes.
        """
        if not self.dtypes:
            raise ValueError("Dtypes are not set")
        return (
            np.memmap(chunk_path / FileNames.DATA.value, dtype=self.dtypes.get(FileNames.DATA.value), mode="r"),
            np.memmap(chunk_path / FileNames.ROWPTR.value, dtype=self.dtypes.get(FileNames.ROWPTR.value), mode="r"),
            np.memmap(chunk_path / FileNames.COLPTR.value, dtype=self.dtypes.get(FileNames.COLPTR.value), mode="r"),
        )

    def _download_chunk(self, chunk_id: int) -> Path:
        """Download a chunk from remote storage.

        Supports both compressed (.npz) and uncompressed formats.
        Compressed files are downloaded and extracted to uncompressed format for fast memmap access.
        """
        start_time = time.perf_counter()
        bytes_downloaded = 0

        chunk_name = f"chunk_{chunk_id:05d}"
        logger.debug("Starting download of chunk %d (%s)", chunk_id, chunk_name)
        remote_chunk = f"{self.remote_path}/{chunk_name}"
        local_chunk = self.cache_dir / chunk_name

        # Ensure directory exists
        local_chunk.mkdir(parents=True, exist_ok=True)

        # Check if compressed format exists (single .npz file)
        remote_compressed = f"{remote_chunk}/{self.COMPRESSED_FILE}"
        try:
            # Try compressed format first (1 HTTP request instead of 3)
            with self._fs.open(remote_compressed, "rb") as src:
                compressed_data = src.read()
            bytes_downloaded = len(compressed_data)

            # Load compressed data from memory
            npz = np.load(io.BytesIO(compressed_data))

            # Extract as RAW binary files. _load_chunk_memmaps reads these with
            # np.memmap (no .npy header parsing), and the uncompressed download
            # path stores raw bytes, so we must use ndarray.tofile here:
            # np.save would prepend an .npy header (and append a .npy suffix),
            # which np.memmap would misread as array data.
            npz["data"].tofile(local_chunk / FileNames.DATA.value)
            npz["row_ptr"].tofile(local_chunk / FileNames.ROWPTR.value)
            npz["col_ptr"].tofile(local_chunk / FileNames.COLPTR.value)

            # Record download stats
            duration = time.perf_counter() - start_time
            self.stats.record_download(duration, bytes_downloaded)

            return local_chunk

        except Exception as e:
            # Fall back to uncompressed format (3 files). Log at debug level so
            # genuine problems (permissions, corrupt npz) are not fully silent.
            logger.debug("Compressed chunk %d unavailable (%s); trying uncompressed format", chunk_id, e)

        # Download uncompressed files (original format)
        def download_file(fname: str) -> int:
            remote_file = f"{remote_chunk}/{fname}"
            local_file = local_chunk / fname
            local_file.parent.mkdir(parents=True, exist_ok=True)
            try:
                with self._fs.open(remote_file, "rb") as src:
                    content = src.read()
            except FileNotFoundError:
                raise FileNotFoundError(
                    f"Chunk {chunk_id} not found at {remote_file} (compressed also not found at {remote_compressed})"
                )
            with open(local_file, "wb") as dst:
                dst.write(content)
            return len(content)

        with ThreadPoolExecutor(max_workers=3) as executor:
            file_sizes = list(executor.map(download_file, self.CHUNK_FILES))

        # Record download stats
        bytes_downloaded = sum(file_sizes)
        duration = time.perf_counter() - start_time
        self.stats.record_download(duration, bytes_downloaded)

        return local_chunk

    def prefetch_chunks(self, chunk_ids: List[int], max_parallel: int = 16) -> None:
        """Prefetch multiple chunks in parallel (synchronous)."""
        # Filter out already cached chunks
        with self._cache_lock:
            to_download = [cid for cid in chunk_ids if cid not in self._cache and cid not in self._downloading]

        if not to_download:
            return

        # Collect chunks to evict inside lock, delete outside to avoid blocking
        chunks_to_evict = []
        with self._download_complete:
            needed = len(to_download)
            while len(self._cache) + needed > self.max_cached_chunks and self._cache:
                evict_result = self._find_evictable_chunk()
                if evict_result is None:
                    break
                old_id, old_entry = evict_result
                chunks_to_evict.append((old_id, old_entry))

            # Mark all as downloading
            for cid in to_download:
                self._downloading.add(cid)

        try:

            def download_and_load(chunk_id: int) -> Tuple[int, Path, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
                local_path = self._download_chunk(chunk_id)
                memmaps = self._load_chunk_memmaps(local_path)
                return chunk_id, local_path, memmaps

            # Run eviction and downloads in parallel
            with ThreadPoolExecutor(
                max_workers=min(max_parallel, len(to_download) + len(chunks_to_evict))
            ) as executor:
                # Submit eviction tasks (fire and forget)
                for old_id, old_entry in chunks_to_evict:
                    executor.submit(self._evict_chunk, old_id, old_entry)
                # Download all chunks
                results = list(executor.map(download_and_load, to_download))

            # Add to cache
            with self._download_complete:
                for chunk_id, local_path, memmaps in results:
                    self._cache[chunk_id] = (local_path, memmaps)
                    self._downloading.discard(chunk_id)
                self._download_complete.notify_all()

        except Exception:
            with self._download_complete:
                for cid in to_download:
                    self._downloading.discard(cid)
                self._download_complete.notify_all()
            raise

    def prefetch_chunks_async(self, chunk_ids: List[int]) -> None:
        """Start prefetching chunks in background thread."""
        with self._prefetch_lock:
            self._prefetch_future = self._prefetch_executor.submit(self.prefetch_chunks, chunk_ids)

    def wait_for_prefetch(self) -> None:
        """Wait for any ongoing prefetch to complete."""
        with self._prefetch_lock:
            if self._prefetch_future is not None:
                try:
                    self._prefetch_future.result(timeout=300)  # 5 min timeout
                except Exception:
                    # Best-effort: get_chunk will retry the download on demand.
                    logger.debug("Prefetch failed; chunks will be fetched on demand", exc_info=True)
                self._prefetch_future = None

    def _remote_exists(self, remote_path: str) -> bool:
        """Check if a remote path exists (uses ls instead of exists for compatibility)."""
        try:
            # Use ls instead of exists() because some S3-compatible storage
            # doesn't support HeadObject which exists() relies on
            parent, _, name = remote_path.rpartition("/")
            files = self._fs.ls(parent, detail=False)
            # Compare the exact basename: a plain endswith() would also match
            # e.g. "xheader.sch" when looking for "header.sch".
            return any(f.rstrip("/").rpartition("/")[2] == name for f in files)
        except Exception:
            return False

    def get_metadata(self) -> Path:
        """Download and return path to metadata files (header, features, etc.)."""
        metadata_dir = self.cache_dir / "_metadata"
        if metadata_dir.exists():
            return metadata_dir

        metadata_dir.mkdir(parents=True, exist_ok=True)

        # Download header and feature indices (header.sch is the SCDL header format)
        for name in ["header.sch", "version.json", "metadata.json"]:
            remote_file = f"{self.remote_path}/{name}"
            if self._remote_exists(remote_file):
                self._fs.get(remote_file, str(metadata_dir / name))

        # Download feature directories
        for name in ["var_features", "obs_features"]:
            remote_dir = f"{self.remote_path}/{name}"
            if self._remote_exists(remote_dir):
                local_dir = metadata_dir / name
                self._fs.get(remote_dir, str(local_dir), recursive=True)

        return metadata_dir

    def cleanup(self):
        """Delete all cached data and stop background workers."""
        # Shut down executors so worker threads do not outlive the loader.
        self._prefetch_executor.shutdown(wait=False)
        self._eviction_executor.shutdown(wait=False)
        shutil.rmtree(self.cache_dir, ignore_errors=True)
FLOAT_ORDER, INT_ORDER, FileNames, Mode, NeighborSamplingStrategy @@ -128,6 +131,9 @@ def __init__( self.data_path: str = data_path self.header: SCDLHeader = None self.mode: Mode = mode + self._is_chunked: bool = False + self._chunks: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = [] + self._chunk_loader = None # For remote chunked datasets self.paginated_load_cutoff = paginated_load_cutoff self.load_block_row_size = load_block_row_size self.var_feature_index_name = var_feature_index_name @@ -260,6 +266,56 @@ def version(self) -> str: """ return self._version + @classmethod + def from_remote( + cls, + remote_path: str, + cache_dir: Optional[str] = None, + max_cached_chunks: int = 2, + storage_options: Optional[Dict] = None, + batch_download_size: int = 10, + use_async_downloads: bool = True, + ): + """Load a chunked dataset from remote storage (S3, GCS, HTTP). + + Args: + remote_path: Remote path (s3://bucket/path, gs://bucket/path, etc.) + cache_dir: Local directory for caching chunks. If None, uses temp directory. + max_cached_chunks: Maximum number of chunks to keep in cache. + storage_options: Options passed to fsspec (e.g., {"endpoint_url": "https://..."} for S3). + batch_download_size: Number of chunks to download in parallel. + use_async_downloads: Whether to use async downloads (currently unused). 
+ """ + loader = RemoteChunkLoader( + remote_path, + Path(cache_dir) if cache_dir else None, + max_cached_chunks, + storage_options, + batch_download_size=batch_download_size, + use_async_downloads=use_async_downloads, + ) + metadata_path = loader.get_metadata() + ds = cls.__new__(cls) + # Initialize essential attributes that __init__ would set + ds._version = importlib.metadata.version("bionemo.scdl") + ds._chunk_loader = loader + ds.data_path = remote_path + ds.header = None + ds.mode = Mode.READ_APPEND + ds._is_chunked = False + ds._chunks = [] + ds.dtypes = {} # Will be populated from header in load() + ds._var_feature_index = None + ds._obs_feature_index = None + + ds.load(str(metadata_path)) + + # Pass dtypes to loader after they're loaded from header + if ds.dtypes: + loader.set_dtypes(ds.dtypes) + + return ds + def _extract_neighbor_data(self, adata) -> bool: """Extracts neighbor data from AnnData.obsp object and saves to memmap files. @@ -436,10 +492,21 @@ def get_row( List[np.ndarray]: optional, corresponding variable (column) features. List[np.ndarray]: optional, corresponding observed (row) features. 
""" - start = self.row_index[index] - end = self.row_index[index + 1] - values = self.data[start:end] - columns = self.col_index[start:end] + if self._is_chunked: + chunk_id, local_idx = self.header.chunked_info.get_chunk_for_row(index) + if self._chunk_loader: + # Remote: loader returns memmaps directly (handles caching internally) + data, rowptr, colptr = self._chunk_loader.get_chunk(chunk_id) + else: + # Local: use pre-loaded chunk memmaps + data, rowptr, colptr = self._chunks[chunk_id] + start, end = rowptr[local_idx], rowptr[local_idx + 1] + values, columns = data[start:end], colptr[start:end] + else: + start = self.row_index[index] + end = self.row_index[index + 1] + values = self.data[start:end] + columns = self.col_index[start:end] ret = (values, columns) var_features = ( self._var_feature_index.lookup(index, select_features=var_feature_names)[0] @@ -681,41 +748,71 @@ def load(self, stored_path: str) -> None: if self.header is not None and hasattr(self.header, "arrays"): # Map from FileNames.value to dtype string for array_info in self.header.arrays: - if FileNames[array_info.name].value not in self.dtypes: - raise ValueError(f"Array name {FileNames[array_info.name].value} not found in dtypes") self.dtypes[FileNames[array_info.name].value] = array_info.dtype.numpy_dtype_string - # Metadata is required, so we must check if it exists and fail if not. - if not os.path.exists(f"{self.data_path}/{FileNames.METADATA.value}"): - raise FileNotFoundError( - f"Error: the metadata file {self.data_path}/{FileNames.METADATA.value} does not exist." 
- ) - - with open(f"{self.data_path}/{FileNames.METADATA.value}", Mode.READ_APPEND.value) as mfi: - self.metadata = json.load(mfi) + # Load metadata if exists + metadata_path = f"{self.data_path}/{FileNames.METADATA.value}" + if os.path.exists(metadata_path): + with open(metadata_path, Mode.READ_APPEND.value) as mfi: + self.metadata = json.load(mfi) + # Load feature indices if os.path.exists(f"{self.data_path}/{FileNames.VAR_FEATURES.value}"): self._var_feature_index = VariableFeatureIndex.load(f"{self.data_path}/{FileNames.VAR_FEATURES.value}") - elif os.path.exists( - f"{self.data_path}/{FileNames.FEATURES.value}" - ): # Backward compatibility with old features file + elif os.path.exists(f"{self.data_path}/{FileNames.FEATURES.value}"): self._var_feature_index = VariableFeatureIndex.load(f"{self.data_path}/{FileNames.FEATURES.value}") if os.path.exists(f"{self.data_path}/{FileNames.OBS_FEATURES.value}"): self._obs_feature_index = ObservedFeatureIndex.load(f"{self.data_path}/{FileNames.OBS_FEATURES.value}") - # mmap the existing arrays - self.data = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"] - ) - self.row_index = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"] - ) - self.col_index = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"] - ) - # Load neighbor data - if self.load_neighbors: - self._load_neighbor_memmaps() + # Load data arrays - chunked vs monolithic + if self.header is not None and self.header.backend == Backend.CHUNKED_MEMMAP_V0: + self._is_chunked = True + self._load_chunk_memmaps() + else: + self.data = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"] + ) + self.row_index = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.ROWPTR.value}", 
dtype=self.dtypes[f"{FileNames.ROWPTR.value}"] + ) + self.col_index = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"] + ) + if self.load_neighbors: + self._load_neighbor_memmaps() + + def _load_chunk_memmaps(self) -> None: + """Preload all chunk memmaps (lazy - just file handles, no RAM). + + For local datasets, loads from data_path directly. + For remote datasets, this is skipped - chunks are loaded on demand. + """ + # For remote datasets, don't preload - chunks are fetched on demand via get_row() + if self._chunk_loader is not None: + return + # Local: preload all chunk paths + for chunk_id in range(self.header.chunked_info.num_chunks): + chunk_path = Path(self.data_path) / f"chunk_{chunk_id:05d}" + self._chunks.append(self._load_chunk_from_path(chunk_path)) + + def _load_chunk_from_path(self, chunk_path: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Load memmaps for a single chunk directory. + + Uses np.memmap when dtypes are known (faster), falls back to np.load otherwise. 
+ """ + # Use np.memmap when dtypes are available (faster than np.load) + if self.dtypes: + return ( + np.memmap(chunk_path / FileNames.DATA.value, dtype=self.dtypes[FileNames.DATA.value], mode="r"), + np.memmap(chunk_path / FileNames.ROWPTR.value, dtype=self.dtypes[FileNames.ROWPTR.value], mode="r"), + np.memmap(chunk_path / FileNames.COLPTR.value, dtype=self.dtypes[FileNames.COLPTR.value], mode="r"), + ) + # Fallback to np.load which auto-detects dtype from .npy header + return ( + np.load(chunk_path / FileNames.DATA.value, mmap_mode="r"), + np.load(chunk_path / FileNames.ROWPTR.value, mmap_mode="r"), + np.load(chunk_path / FileNames.COLPTR.value, mmap_mode="r"), + ) def _write_metadata(self) -> None: with open(f"{self.data_path}/{FileNames.METADATA.value}", f"{Mode.CREATE.value}") as mfi: @@ -1218,6 +1315,8 @@ def number_of_rows(self) -> int: ValueError if the length of the number of rows in the feature index does not correspond to the number of stored rows. """ + if self._is_chunked: + return self.header.chunked_info.total_rows if len(self._var_feature_index) > 0 and self._var_feature_index.number_of_rows() != self.row_index.size - 1: raise ValueError( f"""The number of rows in the feature index {self._var_feature_index.number_of_rows()} @@ -1445,3 +1544,32 @@ def concat( mode=Mode.READ_APPEND.value, ) self.save() + + def to_chunked( + self, output_path: Optional[str] = None, chunk_size: int = 100_000, delete_original: bool = False + ) -> "SingleCellMemMapDataset": + """Convert this dataset to a chunked format for efficient remote access. + + Args: + output_path: Path where the chunked dataset will be created. If None, replaces in-place. + chunk_size: Number of rows per chunk (default: 100,000). + delete_original: If True and output_path is set, delete the original after conversion. + + Returns: + A new SingleCellMemMapDataset instance pointing to the chunked data. 
+ """ + if self._is_chunked: + raise ValueError("Dataset is already chunked") + + src = Path(self.data_path) + if output_path is None: + # In-place: partition to temp, then swap + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) / "chunked" + partition_scdl(src, tmp_path, chunk_size=chunk_size) + shutil.rmtree(src) + shutil.move(str(tmp_path), str(src)) + return SingleCellMemMapDataset(str(src)) + + partition_scdl(src, Path(output_path), chunk_size=chunk_size, delete_original=delete_original) + return SingleCellMemMapDataset(output_path) diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py index 1affa2a596..e1f1578858 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py @@ -20,7 +20,6 @@ import numpy as np -from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset from bionemo.scdl.schema.header import ChunkedInfo, SCDLHeader from bionemo.scdl.util.scdl_constants import Backend, FileNames @@ -29,8 +28,11 @@ def partition_scdl( input_path: Path, output_path: Path, chunk_size: int = 100_000, + delete_original: bool = False, ) -> SCDLHeader: """Partition an SCDL dataset into chunks.""" + from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset + input_path, output_path = Path(input_path), Path(output_path) if not input_path.exists(): @@ -44,7 +46,11 @@ def partition_scdl( source_ds = SingleCellMemMapDataset(str(input_path)) total_rows = len(source_ds) rowptr = source_ds.row_index - num_chunks = (total_rows + chunk_size - 1) // chunk_size + if chunk_size <= 0: + raise ValueError(f"Chunk size must be greater than 0, got {chunk_size}") + if total_rows <= 0: + raise ValueError(f"Total rows must be greater than 0, got {total_rows}") + num_chunks = max(1, (total_rows + chunk_size - 1) // chunk_size) # Create chunks for 
chunk_id in range(num_chunks): @@ -78,4 +84,8 @@ def partition_scdl( header.chunked_info = ChunkedInfo(chunk_size=chunk_size, num_chunks=num_chunks, total_rows=total_rows) header.save(str(output_path / FileNames.HEADER.value)) + if delete_original: + del source_ds # Release memmap handles + shutil.rmtree(input_path) + return header diff --git a/sub-packages/bionemo-scdl/test_remote_loading.py b/sub-packages/bionemo-scdl/test_remote_loading.py new file mode 100644 index 0000000000..e1cf367a12 --- /dev/null +++ b/sub-packages/bionemo-scdl/test_remote_loading.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script for remote chunked SCDL loading with ChunkAwareSampler. 
+ +Usage: + python test_remote_loading.py --remote_path s3://my-bucket/chunked_scdl + python test_remote_loading.py --remote_path gs://my-bucket/chunked_scdl + python test_remote_loading.py --remote_path s3://bucket/path --cache-dir /tmp/cache --max-chunks 3 +""" + +import argparse + +from torch.utils.data import DataLoader + +from bionemo.scdl.io.chunk_sampler import ChunkAwareSampler +from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset +from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch + + +def main(): + parser = argparse.ArgumentParser(description="Test remote chunked SCDL loading") + parser.add_argument( + "--remote_path", + default="s3://general-purpose/polina/chunked", + help="Remote path (s3://..., gs://..., https://...)", + ) + parser.add_argument("--endpoint-url", default="https://pbss.s8k.io", help="S3 endpoint URL (for non-AWS S3)") + parser.add_argument("--cache-dir", default="/tmp/scdl_cache", help="Local cache directory") + parser.add_argument("--max-chunks", type=int, default=3, help="Max chunks to cache") + parser.add_argument("--chunks-per-window", type=int, default=2, help="Chunks per sampling window") + parser.add_argument("--batch-size", type=int, default=1, help="Batch size") + parser.add_argument("--num-batches", type=int, default=10, help="Number of batches to iterate") + args = parser.parse_args() + + print(f"Loading remote dataset: {args.remote_path}") + print(f" Endpoint: {args.endpoint_url}") + print(f" Cache dir: {args.cache_dir}") + print(f" Max cached chunks: {args.max_chunks}") + + # Build storage_options for S3-compatible storage + # For s3fs, endpoint_url must be in client_kwargs + storage_options = {} + if args.endpoint_url: + storage_options["client_kwargs"] = {"endpoint_url": args.endpoint_url} + + # 1. 
Load from remote + ds = SingleCellMemMapDataset.from_remote( + args.remote_path, + cache_dir=args.cache_dir, + max_cached_chunks=args.max_chunks, + storage_options=storage_options if storage_options else None, + ) + print(f" Rows: {len(ds)}") + print(f" Chunks: {ds.header.chunked_info.num_chunks}") + print(f" Chunk size: {ds.header.chunked_info.chunk_size}") + + # 2. Create sampler + print(f"\nCreating ChunkAwareSampler (chunks_per_window={args.chunks_per_window})...") + sampler = ChunkAwareSampler( + ds, + shuffle_chunks=True, + shuffle_within_window=True, + chunks_per_window=args.chunks_per_window, + seed=42, + ) + + # 3. Create DataLoader + print(f"Creating DataLoader (batch_size={args.batch_size})...") + loader = DataLoader(ds, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_sparse_matrix_batch) + + # 4. Iterate batches + print(f"\nIterating {args.num_batches} batches...") + for i, batch in enumerate(loader): + if i >= args.num_batches: + break + print(f" Batch {i}: shape={batch.shape}") + + print("\nSuccess! Remote chunked loading works.") + + # 5. 
Cleanup (optional) + if ds._chunk_loader: + print(f"\nCleaning up cache at {args.cache_dir}...") + ds._chunk_loader.cleanup() + + +if __name__ == "__main__": + main() diff --git a/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py b/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py index 3b8e934471..7152d7e53f 100644 --- a/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py +++ b/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py @@ -199,13 +199,24 @@ def _make(tmp_path): @pytest.fixture def make_h5ad_with_raw(make_random_csr): - """Factory to create an h5ad with uniquely randomized data for the fields .raw.X and .X""" + """Factory to create an h5ad with uniquely randomized data for .raw.X, .X, obs, and var.""" def _make(tmp_path): - X = make_random_csr(total_nnz=100, n_cols=50, seed=42) - X_raw = make_random_csr(total_nnz=100, n_cols=50, seed=43) + n_rows, n_cols = 100, 50 + X = make_random_csr(total_nnz=n_rows, n_cols=n_cols, seed=42) + X_raw = make_random_csr(total_nnz=n_rows, n_cols=n_cols, seed=43) + + obs = pd.DataFrame( + {"cell_type": [f"type_{i % 3}" for i in range(n_rows)]}, + index=[f"cell_{i}" for i in range(n_rows)], + ) + var = pd.DataFrame( + {"gene_name": [f"gene_{i}" for i in range(n_cols)]}, + index=[f"ENSG{i:08d}" for i in range(n_cols)], + ) + h = tmp_path / "var.h5ad" - ad.AnnData(X=X, var=pd.DataFrame(index=np.arange(X.shape[1])), raw={"X": X_raw}).write_h5ad(h) + ad.AnnData(X=X, obs=obs, var=var, raw={"X": X_raw}).write_h5ad(h) return h return _make