From 6172d6191d8c15ab5830e57a92e4fec74f8ad043 Mon Sep 17 00:00:00 2001
From: Kenny Yu
Date: Fri, 21 Nov 2025 22:34:56 +0000
Subject: [PATCH 1/2] [tinker-cookbook] rl: avoid hanging in async runs when
 we run out of data

Previously, async RL runs could hang at shutdown if we ran out of data.
This fixes it to ensure a proper shutdown: the dataloader loop terminates
first, and all data remaining in the queues is drained.
---
 tinker_cookbook/rl/train.py | 128 +++++++++++++++++++++++++-----------
 1 file changed, 90 insertions(+), 38 deletions(-)

diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py
index 33c9f8ed..71fcae4e 100644
--- a/tinker_cookbook/rl/train.py
+++ b/tinker_cookbook/rl/train.py
@@ -7,12 +7,15 @@
 import logging
 import os
 import time
-from typing import Any, Callable, List, Literal, Sequence, Iterator
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Callable, Iterator, List, Literal, Sequence
 
 import chz
 import numpy as np
 import tinker
 import torch
+
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.completers import TinkerTokenCompleter
 from tinker_cookbook.display import colorize_example
@@ -39,9 +42,7 @@
 from tinker_cookbook.tokenizer_utils import Tokenizer
 from tinker_cookbook.utils import logtree, ml_log
 from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
-from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context
-from contextlib import contextmanager
-
+from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init
 
 logger = logging.getLogger(__name__)
 
@@ -354,7 +355,7 @@ async def do_sync_training_with_stream_minibatch(
 ):
     # Samplers will produce trajectory groups asynchronously,
     # and the trainer will consume them as soon as they are ready
-    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
     env_group_builders_P = dataset.get_batch(i_batch)
 
     @scope
@@ -393,10 +394,7 @@ async def trajectory_group_worker_task(
     )
 
     # Run multiple optimizer substeps per training iteration
-    (
-        sampling_client,
-        full_batch_metrics,
-    ) = await do_train_step_streaming_and_get_sampling_client(
+    streaming_result = await do_train_step_streaming_and_get_sampling_client(
         cfg,
         i_batch,
         trajectory_groups_queue,
@@ -404,6 +402,10 @@ async def trajectory_group_worker_task(
         service_client,
         tokenizer,
     )
+    if streaming_result is None:
+        logger.info("[do_sync_training_with_stream_minibatch] Received shutdown signal")
+        return
+    sampling_client, full_batch_metrics = streaming_result
 
     # Log metrics
     metrics.update(full_batch_metrics)
@@ -428,6 +430,22 @@ class WrappedTrajectoryGroup:
     metrics: dict[str, Any] = chz.field(default_factory=dict)
 
 
+@dataclass
+class Shutdown:
+    pass
+
+
+class AsyncCounter:
+    def __init__(self, start: int = 0):
+        self.value = start
+        self.lock = asyncio.Lock()
+
+    async def decrement_and_get(self) -> int:
+        async with self.lock:
+            self.value -= 1
+            return self.value
+
+
 @scope
 async def do_async_training(
     start_batch: int,
@@ -444,13 +462,12 @@
     """Implements async off-policy training, capped at K steps off policy."""
     assert cfg.async_config is not None
 
-    shutdown_event = asyncio.Event()
     # We will have groups_per_batch worker generating rollouts, so cap the
     # queue size to be groups_per_batch.
-    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | None](
+    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | Shutdown](
         maxsize=cfg.async_config.groups_per_batch
     )
-    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
 
     # Initial sampling client to use
     path_dict = await checkpoint_utils.save_checkpoint_async(
@@ -461,38 +478,46 @@ async def do_async_training(
         kind="both",
     )
 
+    # When the dataloader is out of data, we want to make sure all remaining samples
+    # are processed before terminating.
+    evaluation_loop_should_shutdown_event = asyncio.Event()
+    trajectory_group_worker_alive_counter = AsyncCounter(cfg.async_config.groups_per_batch)
+
     # This will be updated by the training loop
     sampling_client = training_client.create_sampling_client(path_dict["sampler_path"])
     sampling_client_step = start_batch
     sampling_client_updated_event = asyncio.Event()
     sampling_client_updated_event.set()
 
-    @scope
-    def shutdown_loops():
-        """Trigger all loops to shutdown"""
-        shutdown_event.set()
-        assert cfg.async_config is not None
-        for _ in range(cfg.async_config.groups_per_batch):
-            env_group_builders_queue.put_nowait(None)
-        sampling_client_updated_event.set()
-
     @scope
     async def dataloader_loop():
         """Gets the next set of env builders to run"""
         i_batch = start_batch
-        while not shutdown_event.is_set() and i_batch < end_batch:
+        while not i_batch < end_batch:
             env_group_builders_P = dataset.get_batch(i_batch)
             for env_group_builder in env_group_builders_P:
                 await env_group_builders_queue.put(env_group_builder)
             i_batch += 1
+        # We are done with the data loader loop, enqueue sentinel values
+        # to allow the trajectory group worker loops to terminate.
+ logger.info("[dataloader_loop] No more data, shutting down trajectory group worker loops") + assert cfg.async_config is not None + for _ in range(cfg.async_config.groups_per_batch): + await env_group_builders_queue.put(Shutdown()) + logger.info("[dataloader_loop] Data loader loop terminated") + @scope async def trajectory_group_worker_loop(): """Generates trajectories for a single env builder""" - while not shutdown_event.is_set(): + while True: env_group_builder = await env_group_builders_queue.get() - if env_group_builder is None: - break + match env_group_builder: + case EnvGroupBuilder(): + pass + case Shutdown(): + logger.info("[trajectory_group_worker_loop] Received shutdown signal") + break metrics = {} t_start = time.time() @@ -518,6 +543,14 @@ async def trajectory_group_worker_loop(): metrics=metrics, ) ) + num_alive_workers = await trajectory_group_worker_alive_counter.decrement_and_get() + if num_alive_workers == 0: + # All workers are done, enqueue a sentinel to terminate the training loop + logger.info( + "[trajectory_group_worker_loop] Last worker terminated, shutting down training loop" + ) + trajectory_groups_queue.put_nowait(Shutdown()) + logger.info("[trajectory_group_worker_loop] Trajectory group worker loop terminated") @scope async def training_loop(): @@ -530,9 +563,6 @@ async def training_loop(): i_batch = start_batch wrapped_trajectory_groups = [] while i_batch < end_batch: - wrapped_trajectory_group = await trajectory_groups_queue.get() - if wrapped_trajectory_group is None: - continue @scope def filter_stale_trajectory_group( @@ -567,10 +597,7 @@ def filter_stale_trajectory_group( nonlocal sampling_client nonlocal sampling_client_step if cfg.stream_minibatch_config is not None: - ( - sampling_client, - train_step_metrics, - ) = await do_train_step_streaming_and_get_sampling_client( + streaming_result = await do_train_step_streaming_and_get_sampling_client( cfg, i_batch, trajectory_groups_queue, @@ -579,7 +606,21 @@ def filter_stale_trajectory_group( tokenizer, filter_stale_trajectory_group, ) + if streaming_result is None: + logger.info("[training_loop] Received shutdown signal") + break + sampling_client, train_step_metrics = streaming_result else: + wrapped_trajectory_group = await trajectory_groups_queue.get() + match wrapped_trajectory_group: + case WrappedTrajectoryGroup(): + pass + case Shutdown(): + logger.info("[training_loop] Received shutdown signal") + break + case None: + continue + if not filter_stale_trajectory_group(wrapped_trajectory_group): continue @@ -618,7 +659,9 @@ def filter_stale_trajectory_group( i_batch += 1 wrapped_trajectory_groups = [] - shutdown_loops() + evaluation_loop_should_shutdown_event.set() + sampling_client_updated_event.set() + logger.info("[training_loop] Training loop terminated") @scope async def evaluation_loop(): @@ -626,7 +669,7 @@ async def evaluation_loop(): if len(evaluators) == 0 or cfg.eval_every == 0: return - while not shutdown_event.is_set(): + while not evaluation_loop_should_shutdown_event.is_set(): await sampling_client_updated_event.wait() sampling_client_updated_event.clear() @@ -643,6 +686,7 @@ async def evaluation_loop(): metrics.update({f"test/{k}": v for k, v in eval_metrics.items()}) metrics["time/evaluation_loop/total"] = time.time() - t_start ml_logger.log_metrics(metrics, step=sampling_client_eval_step) + logger.info("[evaluation_loop] Evaluation loop terminated") await asyncio.gather( asyncio.create_task(dataloader_loop(), name="dataloader_loop"), @@ -787,12 +831,12 @@ async def 
 async def do_train_step_streaming_and_get_sampling_client(
     cfg: Config,
     i_batch: int,
-    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | None],
+    trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None],
     training_client: tinker.TrainingClient,
     service_client: tinker.ServiceClient,
     tokenizer: Tokenizer,
     trajectory_group_filter: Callable[[WrappedTrajectoryGroup | None], bool] = lambda _: True,
-) -> tuple[tinker.SamplingClient, dict[str, Any]]:
+) -> tuple[tinker.SamplingClient, dict[str, Any]] | None:
     """
     As soon as we have enough trajectories for a minibatch, we will train on them.
     This allows us to overlap sampling and training.
@@ -825,8 +869,16 @@ async def do_train_step_streaming_and_get_sampling_client(
     i_minibatch = 0
     while i_minibatch < cfg.stream_minibatch_config.num_minibatches:
         wrapped_trajectory_group = await trajectory_groups_queue.get()
-        if not trajectory_group_filter(wrapped_trajectory_group):
-            continue
+        match wrapped_trajectory_group:
+            case WrappedTrajectoryGroup():
+                pass
+            case Shutdown():
+                logger.info(
+                    "[do_train_step_streaming_and_get_sampling_client] Received shutdown signal"
+                )
+                return None
+            case None:
+                continue
 
         wrapped_trajectory_groups.append(wrapped_trajectory_group)
         if len(wrapped_trajectory_groups) < groups_per_minibatch:

From 77d0c582e680860774ea1f51ba4aa5119391c863 Mon Sep 17 00:00:00 2001
From: John Schulman
Date: Sun, 23 Nov 2025 03:13:32 +0000
Subject: [PATCH 2/2] rl: fix inverted dataloader loop condition; run
 claude-review in a dedicated environment

---
 .github/workflows/claude-review.yml | 1 +
 tinker_cookbook/rl/train.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/claude-review.yml b/.github/workflows/claude-review.yml
index 4bb2f56a..4ee11dd9 100644
--- a/.github/workflows/claude-review.yml
+++ b/.github/workflows/claude-review.yml
@@ -14,6 +14,7 @@ on:
 jobs:
   claude_review:
     runs-on: ubuntu-latest
+    environment: claude-review
     steps:
       - name: Run Claude Code review
         uses: anthropics/claude-code-action@v1
diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py
index 71fcae4e..604b9b4e 100644
--- a/tinker_cookbook/rl/train.py
+++ b/tinker_cookbook/rl/train.py
@@ -493,7 +493,7 @@ async def do_async_training(
     async def dataloader_loop():
         """Gets the next set of env builders to run"""
         i_batch = start_batch
-        while not i_batch < end_batch:
+        while i_batch < end_batch:
             env_group_builders_P = dataset.get_batch(i_batch)
             for env_group_builder in env_group_builders_P:
                 await env_group_builders_queue.put(env_group_builder)
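
A minimal, self-contained sketch of the shutdown protocol patch 1 introduces,
in plain asyncio with none of the tinker_cookbook types (producer, worker,
consumer, and main are illustrative names, not the cookbook's API): the
producer enqueues one Shutdown sentinel per worker once it runs out of data,
each worker exits when it dequeues a sentinel, and the last live worker,
tracked by an AsyncCounter, enqueues a single sentinel for the downstream
consumer. Every item already in flight is therefore drained before anything
terminates.

    import asyncio
    from dataclasses import dataclass


    @dataclass
    class Shutdown:
        pass


    class AsyncCounter:
        """Counts live workers; the worker that reaches zero owns final cleanup."""

        def __init__(self, start: int = 0):
            self.value = start
            self.lock = asyncio.Lock()

        async def decrement_and_get(self) -> int:
            async with self.lock:
                self.value -= 1
                return self.value


    async def main(num_workers: int = 3, num_items: int = 10) -> None:
        work_queue: asyncio.Queue[int | Shutdown] = asyncio.Queue(maxsize=num_workers)
        results_queue: asyncio.Queue[int | Shutdown] = asyncio.Queue()
        alive = AsyncCounter(num_workers)

        async def producer() -> None:
            for item in range(num_items):
                await work_queue.put(item)
            # One sentinel per worker, so every worker loop observes exactly one.
            for _ in range(num_workers):
                await work_queue.put(Shutdown())

        async def worker() -> None:
            while True:
                item = await work_queue.get()
                match item:
                    case Shutdown():
                        break
                    case int():
                        await results_queue.put(item * item)
            if await alive.decrement_and_get() == 0:
                # Last worker out: all results are already enqueued, so one
                # sentinel is enough for the single downstream consumer.
                await results_queue.put(Shutdown())

        async def consumer() -> None:
            while True:
                match await results_queue.get():
                    case Shutdown():
                        return
                    case result:
                        print(result)

        await asyncio.gather(producer(), consumer(), *(worker() for _ in range(num_workers)))


    if __name__ == "__main__":
        asyncio.run(main())

The asymmetry is deliberate: N workers each need to see exactly one sentinel on
the inbound queue, while the lone consumer must see exactly one on the outbound
queue, and only after every worker has flushed its in-flight results. That is
why the counter, not the producer, signals the consumer.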
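
A second sketch covering the three-way protocol on the trajectory queue, which
is why the return type of do_train_step_streaming_and_get_sampling_client widens
to include None: a real group is consumed, a Shutdown sentinel makes the consumer
return None so its caller can break out cleanly, and a bare None marks a group
that an upstream staleness filter already dropped, so it must not count toward
the minibatch. drain_minibatch and its signature are assumptions for
illustration; the real function also runs a training substep as each minibatch
fills.

    import asyncio
    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class Shutdown:
        pass


    async def drain_minibatch(
        queue: asyncio.Queue[Any], groups_per_minibatch: int
    ) -> list[Any] | None:
        """Collect one minibatch, or return None if a shutdown sentinel arrives."""
        groups: list[Any] = []
        while len(groups) < groups_per_minibatch:
            match await queue.get():
                case Shutdown():
                    return None  # tell the caller to stop training cleanly
                case None:
                    continue  # stale, filtered-out group; skip without counting it
                case group:
                    groups.append(group)
        return groups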