From d47a7f0000ce0574f8052aabe0c4c4add45bf5d7 Mon Sep 17 00:00:00 2001 From: Nithin Tatikonda Date: Tue, 10 Feb 2026 09:39:39 -0800 Subject: [PATCH] mp_prefetch uses multithreading when the GIL is disabled. PiperOrigin-RevId: 868188468 --- CHANGELOG.md | 2 ++ grain/_src/python/dataset/dataset.py | 25 ++++++++++++++++++++++- grain/_src/python/dataset/dataset_test.py | 18 ++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f43006174..1321a77dd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change errors. * Adds experimental support for `get_next_index` and `set_next_index` to fetch and advance a `grain.DatasetIterator` to the given produced element index. + * Switches to multithreading instead of multiprocessing in + `IterDataset.mp_prefetch` when free-threaded Python is detected. * Breaking changes: * Custom implementations of `RandomAccessDataSource` should accept `int` diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py index 2e77d65a1..e506508b1 100644 --- a/grain/_src/python/dataset/dataset.py +++ b/grain/_src/python/dataset/dataset.py @@ -50,6 +50,7 @@ from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Sequence import functools import json +import sys from typing import Any, Generic, TypeVar, Union, cast, overload import warnings @@ -1334,6 +1335,9 @@ def mp_prefetch( multiprocessing resources. We will by default run the cleanup on garbage collection, but GC and its sequence is not guaranteed in CPython. + NOTE: In free-threaded Python builds, this implementation switches to + multithreading, ignoring ``worker_init_fn``. + Args: options: options for the prefetching processes. ``options.num_workers`` must be greater than or equal to 0. If ``options.num_workers`` is 0, @@ -1353,10 +1357,24 @@ def mp_prefetch( """ options = options or grain_options.MultiprocessingOptions(num_workers=10) - # Loaded lazily due to a circular dependency (dataset <-> process_prefetch). + # Loaded lazily due to a circular dependency (dataset <-> process_prefetch) + # and (dataset <-> prefetch). # pylint: disable=g-import-not-at-top + from grain._src.python.dataset.transformations import prefetch from grain._src.python.dataset.transformations import process_prefetch # pylint: enable=g-import-not-at-top + if is_in_free_threaded_python(): + if worker_init_fn is not None: + warnings.warn( + "Free-threaded Python is used: `mp_prefetch` falls back to" + " thread-based implementation and `worker_init_fn` is ignored." + ) + return prefetch.multithread_prefetch( + self, + num_threads=options.num_workers, + buffer_size=options.per_worker_buffer_size, + sequential_slice=sequential_slice, + ) return process_prefetch.multiprocess_prefetch( self, num_workers=options.num_workers, @@ -1835,3 +1853,8 @@ def set_next_index(ds_iter: DatasetIterator, index: int) -> None: def get_next_index(ds_iter: DatasetIterator) -> int: """Returns the next index for the dataset iterator.""" return ds_iter._get_next_index() # pylint: disable=protected-access + + +def is_in_free_threaded_python() -> bool: + """Returns whether Python is running in free-threaded mode.""" + return hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled() # pylint: disable=protected-access diff --git a/grain/_src/python/dataset/dataset_test.py b/grain/_src/python/dataset/dataset_test.py index 54411432f..fb76f2078 100644 --- a/grain/_src/python/dataset/dataset_test.py +++ b/grain/_src/python/dataset/dataset_test.py @@ -32,6 +32,7 @@ from grain._src.python.dataset import dataset from grain._src.python.dataset import stats as dataset_stats from grain._src.python.dataset.transformations import prefetch +from grain._src.python.dataset.transformations import process_prefetch import grain._src.python.testing.experimental as test_util from grain.proto import execution_summary_pb2 import numpy as np @@ -1183,6 +1184,23 @@ def test_apply(self, ds): ], ) + @mock.patch.object(process_prefetch, "multiprocess_prefetch", autospec=True) + @mock.patch.object(prefetch, "multithread_prefetch", autospec=True) + def test_mp_prefetch_switches_to_threads_for_free_threaded_python( + self, mock_multithread_prefetch, mock_multiprocess_prefetch + ): + ds = dataset.MapDataset.range(15).to_iter_dataset() + ds.mp_prefetch() + is_free_threaded = ( + hasattr(sys, "_is_gil_enabled") and not sys._is_gil_enabled() + ) + if is_free_threaded: + mock_multithread_prefetch.assert_called_once() + mock_multiprocess_prefetch.assert_not_called() + else: + mock_multiprocess_prefetch.assert_called_once() + mock_multithread_prefetch.assert_not_called() + class TfRandomMapAlwaysAddingOne(transforms.TfRandomMap):