1 change: 1 addition & 0 deletions grain/_src/python/dataset/BUILD
@@ -54,6 +54,7 @@ py_library(
"//grain/_src/python:checkpointing",
"//grain/_src/python:grain_logging",
"//grain/_src/python:grain_pool",
"//grain/_src/python:multiprocessing_common",
"//grain/_src/python:options",
"//grain/_src/python:shared_memory_array",
"//grain/proto:execution_summary_py_pb2",
185 changes: 101 additions & 84 deletions grain/_src/python/dataset/transformations/process_prefetch.py
@@ -16,7 +16,6 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
import copy
import functools
from multiprocessing import queues
from multiprocessing import synchronize
@@ -29,31 +28,20 @@
from grain._src.core.config import config
import multiprocessing as mp
from grain._src.python import grain_logging
from grain._src.python import multiprocessing_common
from grain._src.python import shared_memory_array
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats
from grain._src.python.dataset.transformations import interleave
from grain._src.python.dataset.transformations import prefetch


T = TypeVar("T")

# Type for the iterator state.
StateT = dict[str, Any]

# Minimal interval (in seconds) between consecutive state recordings in worker
# processes of `_ProcessPrefetchDatasetIterator`. We record the state
# periodically to reduce the overhead of sending the state from workers.
# Note that this is also an approximate upper bound on how long it is going to
# take to recover from a checkpointed state. Larger values will decrease the
# overhead of sending the updated state but will also make recovery from a
# checkpoint longer on average.
_RECORD_STATE_INTERVAL_S = 3

# Keys in `_ProcessPrefetchDatasetIterator` checkpoints.
_WORKER_STATE = "worker_state"
_ITERATIONS_TO_SKIP = "iterations_to_skip"

# Timeout for killing worker processes on iterator close.
_PROCESS_KILL_TIMEOUT_S = 10
# Interval to wait in the worker process when the parent iterator is exhausted
@@ -62,6 +50,8 @@
# Timeout for getting an element from the worker process.
_QUEUE_WAIT_TIMEOUT_S = 1

_is_in_worker_process = False


def _run_all(fns: Sequence[Callable[[], None]]):
for fn in fns:
@@ -90,27 +80,6 @@ def _get_dataset_options(ds: dataset.IterDataset) -> base.DatasetOptions:
return result


def _validate_no_nested_process_prefetch(
ds: dataset.MapDataset | dataset.IterDataset,
):
"""Checks that there are no nested process prefetch nodes."""
to_check: list[dataset.MapDataset | dataset.IterDataset] = [ds]
while to_check:
d = to_check.pop(0)
if isinstance(
d,
(
ProcessPrefetchIterDataset,
prefetch.MultiprocessPrefetchIterDataset,
),
):
raise ValueError(
"Nesting prefetching with processes is not allowed, but found "
f"{type(d).__name__} under a ProcessPrefetchIterDataset."
)
to_check.extend(d.parents)


def _check_picklable(
ds: dataset.IterDataset | dataset.MapDataset,
):
@@ -149,23 +118,44 @@ def _serialize_dataset(ds: dataset.IterDataset) -> bytes:
raise e


def _clear_queue_and_maybe_unlink_shm(q: queues.Queue[Any]) -> int:
count = 0
while True:
try:
shared_memory_array.unlink_shm(q.get_nowait())
count += 1
except queue.Empty:
return count


class ProcessPrefetchIterDataset(dataset.IterDataset[T]):
"""Iterable dataset that uses a background process for prefetching."""
"""Iterable dataset that uses a background process for prefetching.

This dataset transformation accepts an IterDataset and prefetches elements
from it in a separate process, buffering up to `buffer_size` elements.
"""

def __init__(
self,
parent: dataset.IterDataset[T],
buffer_size: int,
worker_init_fn: Callable[[], None] | None = None,
):
"""Initializes the ProcessPrefetchIterDataset.

Args:
parent: The dataset to prefetch from.
buffer_size: The size of the buffer used for prefetching.
worker_init_fn: An optional function to run in the worker process at
startup.
"""
if buffer_size <= 0:
raise ValueError(
f"`buffer_size` must be greater than 0, got {buffer_size}."
)
super().__init__(parent)
self._buffer_size = buffer_size
self._worker_init_fn = worker_init_fn
_validate_no_nested_process_prefetch(self._parent)

def __str__(self) -> str:
return f"ProcessPrefetchIterDataset(buffer_size={self._buffer_size})"
@@ -196,6 +186,8 @@ def _put_dataset_elements_in_buffer(
debug_flags: dict[str, Any],
):
"""Prefetches elements in a separate process."""
global _is_in_worker_process
_is_in_worker_process = True
try:
parse_debug_flags_fn = cloudpickle.loads(pickled_parse_debug_flags_fn)
parse_debug_flags_fn(debug_flags)
@@ -212,12 +204,13 @@
if set_state_event.is_set():
set_state_event.clear()
parent_exhausted = False
new_state, iterations_to_skip_after_set_state = set_state_queue.get()
new_state = set_state_queue.get()
if new_state is not None:
it.set_state(new_state)
for _ in range(iterations_to_skip_after_set_state):
_ = next(it)
buffer.put((_SetStateIsDone(), None, None))
if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
(_SetStateIsDone(), None, None), buffer, should_stop.is_set
):
continue
if parent_exhausted:
# Avoid busy-waiting when the parent iterator is exhausted due to an
# error. Wait until set_state_event or should_stop is set.
@@ -226,17 +219,32 @@
try:
element = it.__next__()
except Exception as e: # pylint: disable=broad-except
buffer.put((None, None, e))
multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
(None, None, e), buffer, should_stop.is_set
)
parent_exhausted = True
continue
element = shared_memory_array.copy_to_shm(element, min_size=min_shm_size)
# If the node is a prefetch node, the bytes produced are already recorded
# in its __next__ method.
if not it._stats._config.is_prefetch: # pylint: disable=protected-access
it._stats.record_bytes_produced(element) # pylint: disable=protected-access
buffer.put((element, it.get_state(), None))
if not multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
(element, it.get_state(), None), buffer, should_stop.is_set
):
# We failed to put the element into the output queue because the
# should_stop event was set. The element may contain a shared memory
# block reference that has to be cleaned up.
shared_memory_array.unlink_shm(element)
except Exception as e: # pylint: disable=broad-except
buffer.put((None, None, e))
_clear_queue_and_maybe_unlink_shm(buffer)
_clear_queue_and_maybe_unlink_shm(set_state_queue)
multiprocessing_common.add_element_to_queue( # pytype: disable=wrong-arg-types
(None, None, e), buffer, should_stop.is_set
)
return
_clear_queue_and_maybe_unlink_shm(buffer)
_clear_queue_and_maybe_unlink_shm(set_state_queue)
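
The producer loop above relies on multiprocessing_common.add_element_to_queue (the new BUILD dependency) so that a full buffer never blocks the worker forever. Its implementation is not part of this diff; the following is a plausible sketch of the contract the call sites rely on, with the polling interval and exact signature being assumptions.

import queue
from typing import Any, Callable

_PUT_TIMEOUT_S = 1  # Assumed polling interval.


def add_element_to_queue(
    element: Any,
    elements_queue: queue.Queue,
    should_stop: Callable[[], bool],
) -> bool:
  """Enqueues `element`, rechecking `should_stop` while the queue is full.

  Returns True once the element is enqueued, or False if `should_stop`
  became true first; in that case the caller owns any cleanup, such as
  unlinking shared memory referenced by the element.
  """
  while not should_stop():
    try:
      elements_queue.put(element, timeout=_PUT_TIMEOUT_S)
      return True
    except queue.Full:
      continue
  return False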


class _SetStateIsDone:
@@ -275,7 +283,6 @@ def __init__(
self._stats_in_queue = self._process_ctx.Queue(maxsize=5)
self._start_profiling_event = self._process_ctx.Event()
self._stop_profiling_event = self._process_ctx.Event()
self._iterations_to_skip = 0
self._set_state_count = 0
self._exhausted = False
self._prefetch_ds_iter = None
@@ -392,16 +399,12 @@ def __next__(self):
break
elif element is not None:
# Unlink shared memory for the discarded element.
shared_memory_array.open_from_shm(element)
shared_memory_array.unlink_shm(element)
if err is not None:
self._stop_prefetch()
self._exhausted = True
raise err
if state is None:
self._iterations_to_skip += 1
else:
self._iterations_to_skip = 0
self._state = state
self._state = state
with self._stats.record_self_time(offset_ns=timer.value()):
element = self._stats.record_bytes_produced(element)
return shared_memory_array.open_from_shm(element)
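
The paired copy_to_shm/open_from_shm/unlink_shm calls in this file exist because elements may reference multiprocessing.shared_memory blocks, which leak unless exactly one party unlinks them. A standalone illustration of the underlying stdlib pattern (not grain's shared_memory_array wrapper):

import numpy as np
from multiprocessing import shared_memory

# "Producer": copy an array into a named shared-memory block.
data = np.arange(8, dtype=np.int64)
shm = shared_memory.SharedMemory(create=True, size=data.nbytes)
np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)[:] = data

# "Consumer": attach to the same block by name and read it back.
attached = shared_memory.SharedMemory(name=shm.name)
view = np.ndarray((8,), dtype=np.int64, buffer=attached.buf)
print(view.sum())  # 28
attached.close()

# Exactly one party must unlink the block, which is what
# _clear_queue_and_maybe_unlink_shm guarantees for elements that are
# dropped without ever being consumed.
shm.close()
shm.unlink()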
@@ -411,57 +414,40 @@ def close(self):
self._closed = True
self._stop_prefetch()

def _clear_buffer(self):
while True:
try:
element, _, _ = self._buffer.get_nowait()
if element is not None and not isinstance(element, _SetStateIsDone):
shared_memory_array.open_from_shm(element)
except queue.Empty:
return

def _clear_set_state_queue(self):
try:
self._set_state_queue.get_nowait()
if _clear_queue_and_maybe_unlink_shm(self._set_state_queue):
self._set_state_count -= 1
except queue.Empty:
return

def _stop_prefetch(self):
"""Stops the prefetching process if it's currently running."""
if self._prefetch_process is None:
return

self._prefetch_should_stop.set()
# Remove entries from the buffer to unblock the producer, so that it checks
# should_stop.is_set() and exits.
self._clear_buffer()
self._prefetch_process.join(_PROCESS_KILL_TIMEOUT_S)

# Not joining here would leave the child processes as zombies after they
# finish; we need to join or call active_children().
self._prefetch_process.join(timeout=_PROCESS_KILL_TIMEOUT_S)

# In case all our attempts to terminate the worker process fail, we forcefully
# kill the child processes.
if self._prefetch_process.is_alive():
self._prefetch_process.kill()
self._prefetch_process = None
# Clear the buffer again in case the prefetch loop added more elements on
# exit.
self._clear_buffer()
_clear_queue_and_maybe_unlink_shm(self._buffer)
self._clear_set_state_queue()
self._set_state_count = 0

def get_state(self) -> StateT:
if self._state is None:
worker_state = self._iter_parent.__iter__().get_state()
else:
worker_state = self._state
return {
_WORKER_STATE: worker_state,
_ITERATIONS_TO_SKIP: self._iterations_to_skip,
}
return self._iter_parent.__iter__().get_state()
return self._state

def set_state(self, state: StateT):
self._state = state[_WORKER_STATE]
self._iterations_to_skip = state[_ITERATIONS_TO_SKIP]
self._state = state
# Remove any pending set_state calls.
self._clear_set_state_queue()
self._set_state_queue.put((self._state, self._iterations_to_skip))
self._set_state_queue.put(self._state)
# Signal the prefetch process to start processing set_state calls.
self._set_state_event.set()
# Increment the number of _SetStateIsDone that need to be skipped to
@@ -473,6 +459,35 @@ def __str__(self) -> str:
return f"ProcessPrefetchDatasetIterator(buffer_size={self._buffer_size})"


class _LazyWorkerSliceIterDataset(dataset.IterDataset[T]):
"""Applies slice to the parent dataset in the worker process."""

def __init__(
self,
parent: dataset.IterDataset[T],
sl: slice,
sequential_slice: bool,
):
super().__init__(parent)
self._slice = sl
self._sequential_slice = sequential_slice

def __iter__(self) -> dataset.DatasetIterator[T]:
if not _is_in_worker_process:
return self._parent.__iter__()
prefetch._set_slice_iter_dataset(
self._parent, self._slice, self._sequential_slice
)
return self._parent.__iter__()

@property
def _element_spec(self) -> Any:
return dataset.get_element_spec(self._parent)

def __str__(self) -> str:
return f"_LazyWorkerSliceIterDataset(slice={self._slice})"


def multiprocess_prefetch(
ds: dataset.IterDataset[T],
num_workers: int = 0,
@@ -507,10 +522,12 @@ def multiprocess_prefetch(
if num_workers == 1:
worker_ds = ds
else:
worker_ds = copy.deepcopy(ds)
prefetch._set_slice_iter_dataset( # pylint: disable=protected-access
worker_ds, slice(i, None, num_workers), sequential_slice
worker_ds = _LazyWorkerSliceIterDataset(
ds,
slice(i, None, num_workers),
sequential_slice,
)

worker_ds = prefetch._MpContextIterDataset( # pylint: disable=protected-access
worker_ds,
base.MultiprocessingContext(