diff --git a/CHANGELOG.md b/CHANGELOG.md index 08de8b00..96d44a06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change hiding first batch processing behind model checkpoint recovery. * Introduces `grain.experimental.multithread_prefetch` as an alternative to multiprocessing prefetch in free-threading Python. + * Switches to multithreading instead of multiprocessing in + `IterDataset.mp_prefetch` when free-threaded Python is detected. * Breaking changes: diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index 02e76211..c1f6e52d 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -50,6 +50,7 @@ from collections.abc import Awaitable, Callable, Iterable, Iterator, Sequence import functools import json +import sys from typing import Any, Generic, TypeVar, Union, cast, overload import warnings @@ -1290,6 +1291,9 @@ def mp_prefetch( prefetch workers, consider moving many-to-one and stateful transformations to after ``mp_prefetch`` or outside of the Grain pipeline. + NOTE: In free-threaded Python builds, this implementation switches to + multithreading, ignoring ``worker_init_fn``. + Args: options: options for the prefetching processes. ``options.num_workers`` must be greater than or equal to 0. If ``options.num_workers`` is 0, @@ -1312,6 +1316,18 @@ def mp_prefetch( # pylint: disable=g-import-not-at-top from grain._src.python.dataset.transformations import prefetch # pylint: enable=g-import-not-at-top + if is_in_free_threaded_python(): + if worker_init_fn is not None: + warnings.warn( + "Free-threaded Python is used: `mp_prefetch` falls back to" + " thread-based implementation and `worker_init_fn` is ignored." + ) + return prefetch.multithread_prefetch( + self, + num_threads=options.num_workers, + buffer_size=options.per_worker_buffer_size, + sequential_slice=sequential_slice, + ) return prefetch.MultiprocessPrefetchIterDataset( self, multiprocessing_options=options, @@ -1684,3 +1700,8 @@ def get_execution_summary( ) return execution_stats._get_execution_summary() # pylint: enable=protected-access + + +def is_in_free_threaded_python() -> bool: + """Returns whether Python is running in free-threaded mode.""" + return hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled() # pylint: disable=protected-access diff --git a/grain/_src/python/dataset/dataset_test.py b/grain/_src/python/dataset/dataset_test.py index 42a208cb..ce171393 100644 --- a/grain/_src/python/dataset/dataset_test.py +++ b/grain/_src/python/dataset/dataset_test.py @@ -31,6 +31,7 @@ from grain._src.python.dataset import base from grain._src.python.dataset import dataset from grain._src.python.dataset import stats as dataset_stats +from grain._src.python.dataset.transformations import prefetch import grain._src.python.testing.experimental as test_util from grain.proto import execution_summary_pb2 import numpy as np @@ -1182,6 +1183,21 @@ def test_apply(self, ds): ], ) + def test_mp_prefetch_switches_to_threads_for_free_threaded_python(self): + ds = dataset.MapDataset.range(15).to_iter_dataset() + prefetched_ds = ds.mp_prefetch() + is_free_threaded = ( + hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled() + ) + if is_free_threaded: + self.assertNotIsInstance( + prefetched_ds, prefetch.MultiprocessPrefetchIterDataset + ) + else: + self.assertIsInstance( + prefetched_ds, prefetch.MultiprocessPrefetchIterDataset + ) + class TfRandomMapAlwaysAddingOne(transforms.TfRandomMapTransform):