Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
and advance a `grain.DatasetIterator` to the given produced element index.

* Breaking changes:
* Custom implementations of `RandomAccessDataSource` should accept `int`
index in `__getitem__`. Legacy paths that handle `SupportsIndex` will still
work at runtime, but depending on the type checker in use, if you're
directly inheriting from `grain.RandomAccessDataSource` and call
`super().__getitem__` with `Supportsindex` you may see a type checking
error. Switch to `int` to fix it.

* Deprecations:
* Deprecates `grain.python.experimental.MultiprocessPrefetchIterDataset`,
Expand Down
27 changes: 22 additions & 5 deletions docs/data_sources.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,31 @@ could be in a file/storage system or generated on the fly. Data sources need to
implement the following protocol:

```python
class RandomAccessDataSource(Protocol, Generic[T]):
"""Interface for datasources where storage supports efficient random access."""
@typing.runtime_checkable
class RandomAccessDataSource(Protocol[T]):
"""Interface for datasets where storage supports efficient random access.

If used with `DataLoader`, `__repr__` has to be additionally implemented to
support checkpointing.

If used with multiprocessing, must be picklable.
"""

def __len__(self) -> int:
"""Number of records in the dataset."""
"""Returns the total number of records in the data source."""

def __getitem__(self, index: int) -> T:
"""Returns the value for the given index.

This method must be thread-safe and deterministic.

Arguments:
index: An integer in `[0, len(self)-1]`.

def __getitem__(self, record_key: SupportsIndex) -> T:
"""Retrieves record for the given record_key."""
Returns:
The corresponding record. File data sources often return the raw bytes but
records can be any Python object.
"""
```

## File Format
Expand Down
16 changes: 10 additions & 6 deletions docs/grain.sources.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
List of Members
---------------

.. autosummary::
:toctree: _autosummary
.. autoclass:: RandomAccessDataSource
:special-members: __len__, __getitem__,

ArrayRecordDataSource
SharedMemoryDataSource
RandomAccessDataSource
RangeDataSource
.. autoclass:: ArrayRecordDataSource
:special-members: __init__, __len__, __getitem__

.. autoclass:: SharedMemoryDataSource
:special-members: __init__, __len__, __getitem__

.. autoclass:: RangeDataSource
:special-members: __init__, __len__, __getitem__
5 changes: 3 additions & 2 deletions grain/_src/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ py_test(
}),
deps = [
":data_sources",
"//grain/_src/python/dataset:base",
"@abseil-py//absl/flags",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
Expand Down Expand Up @@ -154,7 +155,6 @@ py_library(
],
srcs_version = "PY3",
deps = [
":data_sources",
":operations",
":options",
":record",
Expand All @@ -166,6 +166,7 @@ py_library(
"//grain/_src/core:tree_lib",
"//grain/_src/python/checkpoint:base",
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:base",
"@abseil-py//absl/logging",
"@pypi//etils:pkg",
"@pypi//numpy:pkg",
Expand Down Expand Up @@ -257,12 +258,12 @@ py_library(
srcs_version = "PY3",
deps = [
":data_loader",
":data_sources",
":options",
":samplers",
"//grain/_src/core:monitoring",
"//grain/_src/core:sharding",
"//grain/_src/core:transforms",
"//grain/_src/python/dataset:base",
],
)

Expand Down
12 changes: 6 additions & 6 deletions grain/_src/python/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from grain._src.python import options
from grain._src.python import record
from grain._src.python.checkpoint import base as checkpoint_base
from grain._src.python.data_sources import RandomAccessDataSource
from grain._src.python.dataset import base as dataset_base
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import batch as batch_ds
from grain._src.python.dataset.transformations import flatmap
Expand Down Expand Up @@ -124,7 +124,7 @@ class _SamplerMapDataset(dataset.MapDataset[record.Record]):

def __init__(
self,
data_source: RandomAccessDataSource,
data_source: dataset_base.RandomAccessDataSource,
sampler: Sampler,
shard_options: sharding.ShardOptions,
):
Expand Down Expand Up @@ -217,7 +217,7 @@ def __init__(
shard_options: sharding.ShardOptions,
worker_count: int,
sampler: Sampler,
data_source: RandomAccessDataSource,
data_source: dataset_base.RandomAccessDataSource,
):
super().__init__(parent)
self._shard_options = shard_options
Expand All @@ -244,7 +244,7 @@ def __init__(
shard_options: sharding.ShardOptions | None,
worker_count: int,
sampler: Sampler,
data_source: RandomAccessDataSource,
data_source: dataset_base.RandomAccessDataSource,
):
super().__init__(parent)
self._shard_options = shard_options
Expand Down Expand Up @@ -349,7 +349,7 @@ class DataLoader:
def __init__(
self,
*,
data_source: RandomAccessDataSource,
data_source: dataset_base.RandomAccessDataSource,
sampler: Sampler,
operations: Sequence[transforms.Transformation | Operation] = (),
worker_count: Optional[int] = 0,
Expand Down Expand Up @@ -626,7 +626,7 @@ def __str__(self):
return f"PyGrainDatasetIterator(state={self.get_state().decode()})"


def _source_repr(source: RandomAccessDataSource) -> str:
def _source_repr(source: dataset_base.RandomAccessDataSource) -> str:
"""Returns a string representation of the source."""
# If the source has data in memory avoid printing the data itself.
if isinstance(source, (list, tuple, np.ndarray)):
Expand Down
32 changes: 1 addition & 31 deletions grain/_src/python/data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
import os
import threading
import time
import typing
from typing import Any, Generic, Optional, Protocol, SupportsIndex, TypeVar, Union
from typing import Any, Generic, Optional, SupportsIndex, TypeVar, Union

from absl import logging
from etils import epath
Expand Down Expand Up @@ -109,35 +108,6 @@ def paths(self) -> ArrayRecordDataSourcePaths:
return self._paths


@typing.runtime_checkable
class RandomAccessDataSource(Protocol, Generic[T]):
"""Interface for datasources where storage supports efficient random access.

Note that `__repr__` has to be additionally implemented to make checkpointing
work with this source.
"""

def __len__(self) -> int:
"""Returns the total number of records in the data source."""

def __getitem__(self, record_key: SupportsIndex) -> T:
"""Returns the value for the given record_key.

This method must be threadsafe. It's also expected to be deterministic.
When using multiprocessing (worker_count>0) PyGrain will pickle the data
source, which invokes __getstate__(), and send a copy to each worker
process, where __setstate__() is called. After that each worker process
has its own independent data source object.

Arguments:
record_key: This will be an integer in [0, len(self)-1].

Returns:
The corresponding record. File data sources often return the raw bytes but
records can be any Python object.
"""


class RangeDataSource:
"""Range data source, similar to python range() function."""

Expand Down
3 changes: 2 additions & 1 deletion grain/_src/python/data_sources_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from etils import epath
import multiprocessing as grain_multiprocessing
from grain._src.python import data_sources
from grain._src.python.dataset import base as dataset_base

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -143,7 +144,7 @@ class ArrayRecordDataSourceTest(DataSourceTest):

def test_array_record_data_implements_random_access(self):
assert issubclass(
data_sources.ArrayRecordDataSource, data_sources.RandomAccessDataSource
data_sources.ArrayRecordDataSource, dataset_base.RandomAccessDataSource
)

def test_array_record_source_empty_sequence(self):
Expand Down
27 changes: 24 additions & 3 deletions grain/_src/python/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,34 @@ class ShapeDtypeStruct(ShapeDtypeStructProtocol):

@typing.runtime_checkable
class RandomAccessDataSource(Protocol[T]):
"""Interface for datasets where storage supports efficient random access."""
"""Interface for datasets where storage supports efficient random access.

If used with `DataLoader`, `__repr__` has to be additionally implemented to
support checkpointing.

If used with multiprocessing, must be picklable.
"""

def __len__(self) -> int:
...
"""Returns the total number of records in the data source."""

def __getitem__(self, index: int) -> T:
...
"""Returns the value for the given index.

This method must be thread-safe and deterministic.

Note that a number of sources take `SupportsIndex` instead of `int` for
`index`. Such sources will still support `int` index and pass the
`isinstance` check with this protocol, but all new source implementations
should use `int` directly.

Arguments:
index: An integer in `[0, len(self)-1]`.

Returns:
The corresponding record. File data sources often return the raw bytes but
records can be any Python object.
"""


class SupportsBatchedReadRandomAccessDataSource(
Expand Down
4 changes: 2 additions & 2 deletions grain/_src/python/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from grain._src.core import sharding
from grain._src.core import transforms
from grain._src.python import data_loader
from grain._src.python import data_sources
from grain._src.python import options
from grain._src.python import samplers
from grain._src.python.dataset import base as dataset_base

from grain._src.core import monitoring

Expand All @@ -22,7 +22,7 @@


def load(
source: data_sources.RandomAccessDataSource,
source: dataset_base.RandomAccessDataSource,
*,
num_epochs: Optional[int] = None,
shuffle: bool = False,
Expand Down
6 changes: 4 additions & 2 deletions grain/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,12 @@
from grain._src.python.data_sources import (
ArrayRecordDataSource,
SharedMemoryDataSource as InMemoryDataSource,
RandomAccessDataSource,
RangeDataSource,
)
from grain._src.python.dataset.base import DatasetSelectionMap
from grain._src.python.dataset.base import (
DatasetSelectionMap,
RandomAccessDataSource,
)
from grain._src.python.dataset.dataset import (
MapDataset,
IterDataset,
Expand Down
2 changes: 1 addition & 1 deletion grain/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@
from grain._src.python.data_sources import (
ArrayRecordDataSource,
SharedMemoryDataSource,
RandomAccessDataSource,
RangeDataSource,
)
from grain._src.python.dataset.base import RandomAccessDataSource