diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py
index 33c9f8ed..604b9b4e 100644
--- a/tinker_cookbook/rl/train.py
+++ b/tinker_cookbook/rl/train.py
@@ -7,12 +7,15 @@
 import logging
 import os
 import time
-from typing import Any, Callable, List, Literal, Sequence, Iterator
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Callable, Iterator, List, Literal, Sequence
 
 import chz
 import numpy as np
 import tinker
 import torch
+
 from tinker_cookbook import checkpoint_utils
 from tinker_cookbook.completers import TinkerTokenCompleter
 from tinker_cookbook.display import colorize_example
@@ -39,9 +42,7 @@
 from tinker_cookbook.tokenizer_utils import Tokenizer
 from tinker_cookbook.utils import logtree, ml_log
 from tinker_cookbook.utils.misc_utils import safezip, split_list, timed
-from tinker_cookbook.utils.trace import scope, trace_init, get_scope_context
-from contextlib import contextmanager
-
+from tinker_cookbook.utils.trace import get_scope_context, scope, trace_init
 
 
 logger = logging.getLogger(__name__)
@@ -354,7 +355,7 @@ async def do_sync_training_with_stream_minibatch(
 ):
     # Samplers will produce trajectory groups asynchronously,
     # and the trainer will consume them as soon as they are ready
-    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
     env_group_builders_P = dataset.get_batch(i_batch)
 
     @scope
@@ -393,10 +394,7 @@ async def trajectory_group_worker_task(
         )
 
     # Run multiple optimizer substeps per training iteration
-    (
-        sampling_client,
-        full_batch_metrics,
-    ) = await do_train_step_streaming_and_get_sampling_client(
+    streaming_result = await do_train_step_streaming_and_get_sampling_client(
         cfg,
         i_batch,
         trajectory_groups_queue,
@@ -404,6 +402,10 @@
         service_client,
         tokenizer,
     )
+    if streaming_result is None:
+        logger.info("[do_sync_training_with_stream_minibatch] Received shutdown signal")
+        return
+    sampling_client, full_batch_metrics = streaming_result
 
     # Log metrics
     metrics.update(full_batch_metrics)
@@ -428,6 +430,22 @@ class WrappedTrajectoryGroup:
     metrics: dict[str, Any] = chz.field(default_factory=dict)
 
 
+@dataclass
+class Shutdown:
+    pass
+
+
+class AsyncCounter:
+    def __init__(self, start: int = 0):
+        self.value = start
+        self.lock = asyncio.Lock()
+
+    async def decrement_and_get(self) -> int:
+        async with self.lock:
+            self.value -= 1
+            return self.value
+
+
 @scope
 async def do_async_training(
     start_batch: int,
@@ -444,13 +462,12 @@
     """Implements async off-policy training, capped at K steps off policy."""
     assert cfg.async_config is not None
 
-    shutdown_event = asyncio.Event()
     # We will have groups_per_batch worker generating rollouts, so cap the
     # queue size to be groups_per_batch.
-    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | None](
+    env_group_builders_queue = asyncio.Queue[EnvGroupBuilder | Shutdown](
         maxsize=cfg.async_config.groups_per_batch
     )
-    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | None]()
+    trajectory_groups_queue = asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None]()
 
     # Initial sampling client to use
     path_dict = await checkpoint_utils.save_checkpoint_async(
@@ -461,38 +478,46 @@
         kind="both",
     )
 
+    # When the dataloader is out of data, we want to make sure all remaining samples
+    # are processed before terminating.
+    evaluation_loop_should_shutdown_event = asyncio.Event()
+    trajectory_group_worker_alive_counter = AsyncCounter(cfg.async_config.groups_per_batch)
+
     # This will be updated by the training loop
     sampling_client = training_client.create_sampling_client(path_dict["sampler_path"])
     sampling_client_step = start_batch
     sampling_client_updated_event = asyncio.Event()
     sampling_client_updated_event.set()
 
-    @scope
-    def shutdown_loops():
-        """Trigger all loops to shutdown"""
-        shutdown_event.set()
-        assert cfg.async_config is not None
-        for _ in range(cfg.async_config.groups_per_batch):
-            env_group_builders_queue.put_nowait(None)
-        sampling_client_updated_event.set()
-
     @scope
     async def dataloader_loop():
         """Gets the next set of env builders to run"""
        i_batch = start_batch
-        while not shutdown_event.is_set() and i_batch < end_batch:
+        while i_batch < end_batch:
             env_group_builders_P = dataset.get_batch(i_batch)
             for env_group_builder in env_group_builders_P:
                 await env_group_builders_queue.put(env_group_builder)
             i_batch += 1
 
+        # We are done with the data loader loop; enqueue sentinel values
+        # so the trajectory group worker loops can terminate.
+        logger.info("[dataloader_loop] No more data, shutting down trajectory group worker loops")
+        assert cfg.async_config is not None
+        for _ in range(cfg.async_config.groups_per_batch):
+            await env_group_builders_queue.put(Shutdown())
+        logger.info("[dataloader_loop] Data loader loop terminated")
+
     @scope
     async def trajectory_group_worker_loop():
         """Generates trajectories for a single env builder"""
-        while not shutdown_event.is_set():
+        while True:
             env_group_builder = await env_group_builders_queue.get()
-            if env_group_builder is None:
-                break
+            match env_group_builder:
+                case EnvGroupBuilder():
+                    pass
+                case Shutdown():
+                    logger.info("[trajectory_group_worker_loop] Received shutdown signal")
+                    break
 
             metrics = {}
             t_start = time.time()
@@ -518,6 +543,14 @@ async def trajectory_group_worker_loop():
                     metrics=metrics,
                 )
             )
+        num_alive_workers = await trajectory_group_worker_alive_counter.decrement_and_get()
+        if num_alive_workers == 0:
+            # All workers are done; enqueue a sentinel to terminate the training loop
+            logger.info(
+                "[trajectory_group_worker_loop] Last worker terminated, shutting down training loop"
+            )
+            trajectory_groups_queue.put_nowait(Shutdown())
+        logger.info("[trajectory_group_worker_loop] Trajectory group worker loop terminated")
 
     @scope
     async def training_loop():
@@ -530,9 +563,6 @@
         i_batch = start_batch
         wrapped_trajectory_groups = []
         while i_batch < end_batch:
-            wrapped_trajectory_group = await trajectory_groups_queue.get()
-            if wrapped_trajectory_group is None:
-                continue
 
             @scope
             def filter_stale_trajectory_group(
@@ -567,10 +597,7 @@
             nonlocal sampling_client
             nonlocal sampling_client_step
             if cfg.stream_minibatch_config is not None:
-                (
-                    sampling_client,
-                    train_step_metrics,
-                ) = await do_train_step_streaming_and_get_sampling_client(
+                streaming_result = await do_train_step_streaming_and_get_sampling_client(
                     cfg,
                     i_batch,
                     trajectory_groups_queue,
@@ -579,7 +606,21 @@
                     tokenizer,
                     filter_stale_trajectory_group,
                 )
+                if streaming_result is None:
+                    logger.info("[training_loop] Received shutdown signal")
+                    break
+                sampling_client, train_step_metrics = streaming_result
             else:
+                wrapped_trajectory_group = await trajectory_groups_queue.get()
+                match wrapped_trajectory_group:
+                    case WrappedTrajectoryGroup():
+                        pass
+                    case Shutdown():
logger.info("[training_loop] Received shutdown signal") + break + case None: + continue + if not filter_stale_trajectory_group(wrapped_trajectory_group): continue @@ -618,7 +659,9 @@ def filter_stale_trajectory_group( i_batch += 1 wrapped_trajectory_groups = [] - shutdown_loops() + evaluation_loop_should_shutdown_event.set() + sampling_client_updated_event.set() + logger.info("[training_loop] Training loop terminated") @scope async def evaluation_loop(): @@ -626,7 +669,7 @@ async def evaluation_loop(): if len(evaluators) == 0 or cfg.eval_every == 0: return - while not shutdown_event.is_set(): + while not evaluation_loop_should_shutdown_event.is_set(): await sampling_client_updated_event.wait() sampling_client_updated_event.clear() @@ -643,6 +686,7 @@ async def evaluation_loop(): metrics.update({f"test/{k}": v for k, v in eval_metrics.items()}) metrics["time/evaluation_loop/total"] = time.time() - t_start ml_logger.log_metrics(metrics, step=sampling_client_eval_step) + logger.info("[evaluation_loop] Evaluation loop terminated") await asyncio.gather( asyncio.create_task(dataloader_loop(), name="dataloader_loop"), @@ -787,12 +831,12 @@ async def compute_full_batch_metrics_and_get_sampling_client( async def do_train_step_streaming_and_get_sampling_client( cfg: Config, i_batch: int, - trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | None], + trajectory_groups_queue: asyncio.Queue[WrappedTrajectoryGroup | Shutdown | None], training_client: tinker.TrainingClient, service_client: tinker.ServiceClient, tokenizer: Tokenizer, trajectory_group_filter: Callable[[WrappedTrajectoryGroup | None], bool] = lambda _: True, -) -> tuple[tinker.SamplingClient, dict[str, Any]]: +) -> tuple[tinker.SamplingClient, dict[str, Any]] | None: """ As soon as we have enough trajectories for a minibatch, we will train on them. This allows us to overlap sampling and training. @@ -825,8 +869,16 @@ async def do_train_step_streaming_and_get_sampling_client( i_minibatch = 0 while i_minibatch < cfg.stream_minibatch_config.num_minibatches: wrapped_trajectory_group = await trajectory_groups_queue.get() - if not trajectory_group_filter(wrapped_trajectory_group): - continue + match wrapped_trajectory_group: + case WrappedTrajectoryGroup(): + pass + case Shutdown(): + logger.info( + "[do_train_step_streaming_and_get_sampling_client] Received shutdown signal" + ) + return None + case None: + continue wrapped_trajectory_groups.append(wrapped_trajectory_group) if len(wrapped_trajectory_groups) < groups_per_minibatch: