diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 30d0bc71..41abd187 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -642,6 +642,7 @@ def setup( self._df: dict[int, Any] = {} self._chunk_row_groups: dict[int, Any] = {} self._chunk_row_group_item_read_count: dict[int, Any] = {} + self._chunk_row_group_offsets: dict[int, list[int]] = {} def generate_intervals(self) -> list[Interval]: intervals = [] @@ -712,18 +713,28 @@ def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_i Returns: Any: The dataframe row corresponding to the specified index. """ + import bisect + import polars as pl import pyarrow.parquet as pq # Load the Parquet file metadata if not already loaded if chunk_index not in self._df: - self._df[chunk_index] = pq.ParquetFile(chunk_filepath) - - # Determine the row group and the row index within the row group - parquet_file = self._df[chunk_index] - num_rows_per_row_group = parquet_file.metadata.row_group(0).num_rows - row_group_index = row_index // num_rows_per_row_group - row_index_within_group = row_index % num_rows_per_row_group + parquet_file = pq.ParquetFile(chunk_filepath) + self._df[chunk_index] = parquet_file + # Precompute cumulative row offsets as a prefix-sum so lookup works for row groups of any size. + offsets = [0] + num_row_groups = parquet_file.metadata.num_row_groups + for i in range(num_row_groups): + num_rows = parquet_file.metadata.row_group(i).num_rows + offsets.append(offsets[-1] + num_rows) + self._chunk_row_group_offsets[chunk_index] = offsets + + # Locate the row group containing row_index and the offset inside it. + offsets = self._chunk_row_group_offsets[chunk_index] + row_group_index = bisect.bisect_right(offsets, row_index) - 1 + row_index_within_group = row_index - offsets[row_group_index] + row_group_size = offsets[row_group_index + 1] - offsets[row_group_index] # Check if the row group is already loaded if chunk_index in self._chunk_row_groups and row_group_index in self._chunk_row_groups[chunk_index]: @@ -746,7 +757,7 @@ def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_i # Check if the row group has been fully read and release memory if necessary read_count = self._chunk_row_group_item_read_count[chunk_index][row_group_index] - if read_count >= num_rows_per_row_group: + if read_count >= row_group_size: # Release memory for the fully read row group del self._chunk_row_groups[chunk_index][row_group_index] del self._chunk_row_group_item_read_count[chunk_index][row_group_index] @@ -797,6 +808,8 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None: if chunk_index in self._chunk_row_group_item_read_count: del self._chunk_row_group_item_read_count[chunk_index] + if chunk_index in self._chunk_row_group_offsets: + del self._chunk_row_group_offsets[chunk_index] if os.path.exists(chunk_filepath): os.remove(chunk_filepath) logger.debug( @@ -820,5 +833,8 @@ def close(self, chunk_index: int) -> None: if chunk_index in self._chunk_row_group_item_read_count: del self._chunk_row_group_item_read_count[chunk_index] + if chunk_index in self._chunk_row_group_offsets: + del self._chunk_row_group_offsets[chunk_index] + def encode_data(self, data: list[bytes], sizes: list[int], flattened: list[Any]) -> Any: pass diff --git a/tests/streaming/test_item_loader.py b/tests/streaming/test_item_loader.py index 68555bc1..fb8b58a9 100644 --- a/tests/streaming/test_item_loader.py +++ b/tests/streaming/test_item_loader.py @@ -7,7 +7,8 @@ from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING from litdata.streaming import Cache, item_loader from litdata.streaming.dataset import StreamingDataset -from litdata.streaming.item_loader import PyTreeLoader, TokensLoader +from litdata.streaming.item_loader import ParquetLoader, PyTreeLoader, TokensLoader +from litdata.streaming.writer import index_parquet_dataset def test_serializer_setup(): @@ -89,3 +90,91 @@ def test_force_download(monkeypatch, tmpdir): with pytest.raises(Exception, match="worked"): loader.load_item_from_chunk(0, 0, "chunk_filepath", 0, 1) + + +def _write_parquet_with_row_groups(path, row_group_values): + """Write a parquet file where each element of row_group_values becomes its own row group.""" + import pyarrow as pa + import pyarrow.parquet as pq + + schema = pa.schema([("col", pa.int64())]) + with pq.ParquetWriter(path, schema) as writer: + for values in row_group_values: + writer.write_table(pa.table({"col": list(values)}, schema=schema)) + + +@pytest.mark.parametrize( + "row_group_sizes", + [ + [10, 5, 5], # regression: uneven groups, shrinking + [3, 7, 2, 8], # uneven groups, varying + [20], # single group + [1, 1, 1, 1, 1], # many size-1 groups + [5, 5, 5], # uniform control case + ], +) +@pytest.mark.parametrize("low_memory", [True, False]) +def test_parquet_loader_row_group_sizes(tmp_path, row_group_sizes, low_memory): + """ParquetLoader must correctly read every row regardless of row-group layout.""" + parquet_dir = tmp_path / "pq" + parquet_dir.mkdir() + + row_group_values = [] + expected = [] + + for value, size in enumerate(row_group_sizes): + row_group_values.append([value] * size) + expected.extend([value] * size) + value += 1 + _write_parquet_with_row_groups(parquet_dir / "data.parquet", row_group_values) + + index_parquet_dataset(str(parquet_dir)) + dataset = StreamingDataset(str(parquet_dir), item_loader=ParquetLoader(low_memory=low_memory)) + + assert len(dataset) == sum(row_group_sizes) + actual = [dataset[i]["col"] for i in range(len(dataset))] + assert actual == expected + + +def test_parquet_loader_row_group_boundaries(tmp_path): + """First and last row of each group (the modulo edges in the old implementation).""" + parquet_dir = tmp_path / "pq" + parquet_dir.mkdir() + + row_group_sizes = [10, 5, 5] + _write_parquet_with_row_groups( + parquet_dir / "data.parquet", + [[v] * s for v, s in enumerate(row_group_sizes)], + ) + + index_parquet_dataset(str(parquet_dir)) + dataset = StreamingDataset(str(parquet_dir), item_loader=ParquetLoader(low_memory=True)) + + boundaries = [0, 9, 10, 14, 15, 19] + expected = [0, 0, 1, 1, 2, 2] + for idx, exp in zip(boundaries, expected): + assert dataset[idx]["col"] == exp + + +def test_parquet_loader_cache_eviction_with_uneven_groups(tmp_path): + """After fully reading a row group, it must be evicted from the in-memory cache.""" + parquet_dir = tmp_path / "pq" + parquet_dir.mkdir() + + row_group_sizes = [10, 5, 5] + _write_parquet_with_row_groups( + parquet_dir / "data.parquet", + [[v] * s for v, s in enumerate(row_group_sizes)], + ) + + index_parquet_dataset(str(parquet_dir)) + loader = ParquetLoader(low_memory=True) + dataset = StreamingDataset(str(parquet_dir), item_loader=loader) + + # Iterate through the whole dataset sequentially. + for i in range(len(dataset)): + dataset[i] + + # After a sequential pass every row group in the chunk should have been evicted. + for chunk_index, groups in loader._chunk_row_groups.items(): + assert groups == {}, f"chunk {chunk_index} still holds row groups: {list(groups)}"