Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
hiding first batch processing behind model checkpoint recovery.
* Introduces `grain.experimental.multithread_prefetch` as an
alternative to multiprocessing prefetch in free-threading Python.
* Switches to multithreading instead of multiprocessing in
`IterDataset.mp_prefetch` when free-threaded Python is detected.

* Breaking changes:

Expand Down
21 changes: 21 additions & 0 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from collections.abc import Awaitable, Callable, Iterable, Iterator, Sequence
import functools
import json
import sys
from typing import Any, Generic, TypeVar, Union, cast, overload
import warnings

Expand Down Expand Up @@ -1290,6 +1291,9 @@ def mp_prefetch(
prefetch workers, consider moving many-to-one and stateful transformations
to after ``mp_prefetch`` or outside of the Grain pipeline.

NOTE: In free-threaded Python builds, this implementation switches to
multithreading, ignoring ``worker_init_fn``.

Args:
options: options for the prefetching processes. ``options.num_workers``
must be greater than or equal to 0. If ``options.num_workers`` is 0,
Expand All @@ -1312,6 +1316,18 @@ def mp_prefetch(
# pylint: disable=g-import-not-at-top
from grain._src.python.dataset.transformations import prefetch
# pylint: enable=g-import-not-at-top
if is_in_free_threaded_python():
if worker_init_fn is not None:
warnings.warn(
"Free-threaded Python is used: `mp_prefetch` falls back to"
" thread-based implementation and `worker_init_fn` is ignored."
)
return prefetch.multithread_prefetch(
self,
num_threads=options.num_workers,
buffer_size=options.per_worker_buffer_size,
sequential_slice=sequential_slice,
)
return prefetch.MultiprocessPrefetchIterDataset(
self,
multiprocessing_options=options,
Expand Down Expand Up @@ -1684,3 +1700,8 @@ def get_execution_summary(
)
return execution_stats._get_execution_summary()
# pylint: enable=protected-access


def is_in_free_threaded_python() -> bool:
"""Returns whether Python is running in free-threaded mode."""
return hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled() # pylint: disable=protected-access
16 changes: 16 additions & 0 deletions grain/_src/python/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats
from grain._src.python.dataset.transformations import prefetch
import grain._src.python.testing.experimental as test_util
from grain.proto import execution_summary_pb2
import numpy as np
Expand Down Expand Up @@ -1182,6 +1183,21 @@ def test_apply(self, ds):
],
)

def test_mp_prefetch_switches_to_threads_for_free_threaded_python(self):
ds = dataset.MapDataset.range(15).to_iter_dataset()
prefetched_ds = ds.mp_prefetch()
is_free_threaded = (
hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled()
)
if is_free_threaded:
self.assertNotIsInstance(
prefetched_ds, prefetch.MultiprocessPrefetchIterDataset
)
else:
self.assertIsInstance(
prefetched_ds, prefetch.MultiprocessPrefetchIterDataset
)


class TfRandomMapAlwaysAddingOne(transforms.TfRandomMapTransform):

Expand Down