Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions src/litdata/streaming/item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,7 @@ def setup(
self._df: dict[int, Any] = {}
self._chunk_row_groups: dict[int, Any] = {}
self._chunk_row_group_item_read_count: dict[int, Any] = {}
self._chunk_row_group_offsets: dict[int, list[int]] = {}

def generate_intervals(self) -> list[Interval]:
intervals = []
Expand Down Expand Up @@ -712,18 +713,28 @@ def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_i
Returns:
Any: The dataframe row corresponding to the specified index.
"""
import bisect

import polars as pl
import pyarrow.parquet as pq

# Load the Parquet file metadata if not already loaded
if chunk_index not in self._df:
self._df[chunk_index] = pq.ParquetFile(chunk_filepath)

# Determine the row group and the row index within the row group
parquet_file = self._df[chunk_index]
num_rows_per_row_group = parquet_file.metadata.row_group(0).num_rows
row_group_index = row_index // num_rows_per_row_group
row_index_within_group = row_index % num_rows_per_row_group
parquet_file = pq.ParquetFile(chunk_filepath)
self._df[chunk_index] = parquet_file
# Precompute cumulative row offsets as a prefix-sum so lookup works for row groups of any size.
offsets = [0]
num_row_groups = parquet_file.metadata.num_row_groups
for i in range(num_row_groups):
num_rows = parquet_file.metadata.row_group(i).num_rows
offsets.append(offsets[-1] + num_rows)
self._chunk_row_group_offsets[chunk_index] = offsets

# Locate the row group containing row_index and the offset inside it.
offsets = self._chunk_row_group_offsets[chunk_index]
row_group_index = bisect.bisect_right(offsets, row_index) - 1
row_index_within_group = row_index - offsets[row_group_index]
row_group_size = offsets[row_group_index + 1] - offsets[row_group_index]

# Check if the row group is already loaded
if chunk_index in self._chunk_row_groups and row_group_index in self._chunk_row_groups[chunk_index]:
Expand All @@ -746,7 +757,7 @@ def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_i

# Check if the row group has been fully read and release memory if necessary
read_count = self._chunk_row_group_item_read_count[chunk_index][row_group_index]
if read_count >= num_rows_per_row_group:
if read_count >= row_group_size:
# Release memory for the fully read row group
del self._chunk_row_groups[chunk_index][row_group_index]
del self._chunk_row_group_item_read_count[chunk_index][row_group_index]
Expand Down Expand Up @@ -797,6 +808,8 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None:

if chunk_index in self._chunk_row_group_item_read_count:
del self._chunk_row_group_item_read_count[chunk_index]
if chunk_index in self._chunk_row_group_offsets:
del self._chunk_row_group_offsets[chunk_index]
if os.path.exists(chunk_filepath):
os.remove(chunk_filepath)
logger.debug(
Expand All @@ -820,5 +833,8 @@ def close(self, chunk_index: int) -> None:
if chunk_index in self._chunk_row_group_item_read_count:
del self._chunk_row_group_item_read_count[chunk_index]

if chunk_index in self._chunk_row_group_offsets:
del self._chunk_row_group_offsets[chunk_index]

def encode_data(self, data: list[bytes], sizes: list[int], flattened: list[Any]) -> Any:
pass
91 changes: 90 additions & 1 deletion tests/streaming/test_item_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from litdata.constants import _NUMPY_DTYPES_MAPPING, _TORCH_DTYPES_MAPPING
from litdata.streaming import Cache, item_loader
from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.item_loader import PyTreeLoader, TokensLoader
from litdata.streaming.item_loader import ParquetLoader, PyTreeLoader, TokensLoader
from litdata.streaming.writer import index_parquet_dataset


def test_serializer_setup():
Expand Down Expand Up @@ -89,3 +90,91 @@ def test_force_download(monkeypatch, tmpdir):

with pytest.raises(Exception, match="worked"):
loader.load_item_from_chunk(0, 0, "chunk_filepath", 0, 1)


def _write_parquet_with_row_groups(path, row_group_values):
    """Create a parquet file at ``path`` from a sequence of per-row-group value lists.

    Each element of ``row_group_values`` is turned into a one-column int64 table
    (column name ``"col"``) and handed to the writer through its own
    ``write_table`` call, so that it ends up as a separate row group.
    """
    import pyarrow as pa
    import pyarrow.parquet as pq

    schema = pa.schema([("col", pa.int64())])
    # Tables are built lazily, one per requested row group, while the writer is open.
    tables = (pa.table({"col": list(group)}, schema=schema) for group in row_group_values)
    with pq.ParquetWriter(path, schema) as writer:
        for table in tables:
            writer.write_table(table)


@pytest.mark.parametrize(
    "row_group_sizes",
    [
        [10, 5, 5],  # regression: uneven groups, shrinking
        [3, 7, 2, 8],  # uneven groups, varying
        [20],  # single group
        [1, 1, 1, 1, 1],  # many size-1 groups
        [5, 5, 5],  # uniform control case
    ],
)
@pytest.mark.parametrize("low_memory", [True, False])
def test_parquet_loader_row_group_sizes(tmp_path, row_group_sizes, low_memory):
    """ParquetLoader must correctly read every row regardless of row-group layout.

    Row group ``k`` is filled with the constant value ``k``, so reading the
    dataset in index order must yield each group's value repeated ``size`` times.
    """
    parquet_dir = tmp_path / "pq"
    parquet_dir.mkdir()

    # One list per row group; the group's position doubles as its fill value.
    row_group_values = [[value] * size for value, size in enumerate(row_group_sizes)]
    expected = []
    for values in row_group_values:
        expected.extend(values)
    _write_parquet_with_row_groups(parquet_dir / "data.parquet", row_group_values)

    index_parquet_dataset(str(parquet_dir))
    dataset = StreamingDataset(str(parquet_dir), item_loader=ParquetLoader(low_memory=low_memory))

    assert len(dataset) == sum(row_group_sizes)
    actual = [dataset[i]["col"] for i in range(len(dataset))]
    assert actual == expected


def test_parquet_loader_row_group_boundaries(tmp_path):
    """Check the first and last row of every row group for a [10, 5, 5] layout.

    Row group ``k`` holds the constant ``k``; the probed indices are the edges
    of each group (the modulo edges in the old implementation).
    """
    root = tmp_path / "pq"
    root.mkdir()

    sizes = [10, 5, 5]
    per_group_values = [[group_id] * count for group_id, count in enumerate(sizes)]
    _write_parquet_with_row_groups(root / "data.parquet", per_group_values)

    index_parquet_dataset(str(root))
    item_loader = ParquetLoader(low_memory=True)
    dataset = StreamingDataset(str(root), item_loader=item_loader)

    # (row index, value expected in "col") for each group's first and last row.
    edge_rows = (
        (0, 0),
        (9, 0),
        (10, 1),
        (14, 1),
        (15, 2),
        (19, 2),
    )
    for row, group_value in edge_rows:
        assert dataset[row]["col"] == group_value


def test_parquet_loader_cache_eviction_with_uneven_groups(tmp_path):
    """A fully consumed row group must not remain in the loader's in-memory cache.

    Writes a [10, 5, 5] row-group layout, reads every row once in index order
    through a low-memory ``ParquetLoader``, then inspects the loader's
    ``_chunk_row_groups`` mapping.
    """
    root = tmp_path / "pq"
    root.mkdir()

    sizes = [10, 5, 5]
    per_group_values = [[group_id] * count for group_id, count in enumerate(sizes)]
    _write_parquet_with_row_groups(root / "data.parquet", per_group_values)

    index_parquet_dataset(str(root))
    loader = ParquetLoader(low_memory=True)
    dataset = StreamingDataset(str(root), item_loader=loader)

    # Touch every row exactly once, in order; only the side effect of the read matters.
    total_rows = len(dataset)
    for row in range(total_rows):
        _ = dataset[row]

    # Whatever chunks the loader still tracks must hold no cached row groups.
    for chunk_index, groups in loader._chunk_row_groups.items():
        assert groups == {}, f"chunk {chunk_index} still holds row groups: {list(groups)}"
Loading