diff --git a/README.md b/README.md index 58bb8b1e..b89d97b2 100644 --- a/README.md +++ b/README.md @@ -350,6 +350,17 @@ PipelineRL is organized as a modular, Hydra-driven pipeline with 6 core componen - Pull a batch → call `rl_step(...)` (in `pipelinerl/finetune/rl/utils.py`) to compute policy-gradient (+ KL penalty if configured) → `optimizer.step()` → `lr_scheduler.step()`. - On rank 0, use `WeightUpdateManager.send_weight_update(version)` to gather model parameters, send `WeightUpdateRequest` to Actor LLMs (HTTP), broadcast tensors via NCCL, and write a `WeightUpdateSuccess` message to the update stream. +#### Fast-LLM trainer path (preview) + +When `use_fast_llm: true` (default in `conf/math.yaml`), the DeepSpeed ZeRO-3 trainer above is replaced with [Fast-LLM](https://github.com/ServiceNow/Fast-LLM) (FSDP + sequence-data-parallel) and the per-step weight update over HTTP is replaced with a persistent NCCL broadcast group: + +- Trainer: `fast_llm train gpt` launched via torchrun (`pipelinerl/launch.py:run_finetune`); rank 0 also serves the broadcast `TCPStore`. +- Fast-LLM's `StreamingTrainerCallback` gathers full-precision weights after each optimizer step and broadcasts them on a persistent NCCL group whose name is `WEIGHTS_BROADCAST_PG_NAME`. +- vLLM workers join the same group via `vllm1.init_actor_update_group(...)` and copy parameters into the model in place. +- Coordinated NCCL teardown (`pipelinerl/vllm1.py:484-547`) listens to a `training_finished` redis xadd from the trainer and destroys the process group on the vLLM side so `dist.destroy_process_group()` doesn't hang. + +This path is **WIP** — see [`docs/FAST_LLM_INTEGRATION.md`](docs/FAST_LLM_INTEGRATION.md) for known issues, configuration knobs, and example interactive-job scripts. + ### 6. Verifier - Entrypoint: `pipelinerl/entrypoints/verifier.py` - Serves a FastAPI app with: @@ -360,6 +371,7 @@ PipelineRL is organized as a modular, Hydra-driven pipeline with 6 core componen - Defined in `pipelinerl/streams.py`. - Implements `SingleStreamSpec` and `StreamRangeSpec` for file-system or Redis-based queues. - `write_to_streams(...)` and `read_stream(...)` provide a JSON-line protocol for inter-process messaging. +- Pass `shared=True` to these helpers when multiple actors must fan-in to a single Redis stream (e.g., ServiceNow/Fast-LLM trainer). The shared mode encodes payloads via `orjson`, tags them with a global index, and lets the trainer perform downstream sharding safely. - Available backends: - File system: default. - Redis: requires Redis server. @@ -371,3 +383,107 @@ PipelineRL is organized as a modular, Hydra-driven pipeline with 6 core componen - `training_data` stream (StreamRangeSpec(topic="training_data")): File- or Redis-backed stream used to transfer processed training micro-batches from the Preprocessor to the Trainer. Configured via `cfg.preprocess.output` and `cfg.finetune.input` (defaulting to "training_data") in `conf/base.yaml`. Written in `pipelinerl/run_preprocess.py` and consumed in `pipelinerl/run_finetune.py`. - `actor_test` and `stats_test` streams: analogous streams used for evaluation loops (test samples and test metrics). - `stats` stream (SingleStreamSpec(topic="stats")): produced by `ActorLoop.publish_stats` with sliding-window metrics; consumed by external monitoring (e.g. WANDB, logging viewers). + + + + +## Multi-Node Requirements + +PipelineRL can span multiple nodes, with actor (vLLM) and trainer roles on separate machines. 
Each role opens outbound TCP connections to other roles; every target port must be reachable from the source node. + +### Ports and config params + +| Port (default) | Config param | Direction | Purpose | +|---|---|---|---| +| `streams.port` (11000) | `conf/streams/redis.yaml` | all nodes → rank-0 node | Redis data streams (actor → preprocessor → trainer) | +| `world.actor_group_port` (9000) | `conf/base.yaml` | actor node → trainer node | Weight-broadcast process group (NCCL TCPStore rendezvous) | +| `world.environment_start_port` (7777) | `conf/base.yaml` | actor node → environment node | Remote environment HTTP server | +| `8080 + gpu_local_idx` | derived from GPU placement | trainer node → actor node | vLLM HTTP endpoints for weight updates, one per GPU | +| `MASTER_PORT` env var | set by your cluster launcher | trainer nodes ↔ each other | torchrun / accelerate rendezvous between finetune ranks | + +### What each node connects to + +**Trainer node** opens connections to: +- `{actor_node_ip}:{8080 + i}` for each vLLM GPU `i` — to POST updated weights after each optimizer step. +- `{rank_0_ip}:{streams.port}` — to read training batches from Redis (when `streams=redis`). + +**Actor node** opens connections to: +- `{rank_0_ip}:{streams.port}` — to publish rollout data to Redis. +- `{rank_0_ip}:{world.actor_group_port}` — to join the NCCL weight-broadcast process group (vLLM workers connect as clients; the trainer creates the TCPStore server on this port). +- `{env_node_ip}:{world.environment_start_port + i}` — to call remote environment servers (if `environments[*].mode=remote`). + +**All finetune nodes** connect to each other on `MASTER_PORT` for the distributed training rendezvous (rank-0 finetune node is the server). + +### Topology assumptions + +- With fast-llm (`use_fast_llm=true`), each component must occupy whole nodes — torchrun requires every finetune rank to see a complete, identical GPU set. +- With `world.preprocessor_fraction=0`, every node is either a pure actor node or a pure trainer node (no mixing). +- The DeepSpeed hostfile and `--deepspeed_inclusion_filter` use DNS/hostname names (not IPs), so the cluster rendezvous port (`MASTER_PORT`) must be reachable via those names. All other cross-node connections use IP addresses and are independent of DNS. + +### Running and resuming multi-node jobs + +**`world.run_id` is required for multi-node jobs.** It must be a string that is unique per job run. It namespaces the pod IP exchange directory on the shared NFS mount so that stale files from a previous run are never picked up by a new one. Any value that your cluster scheduler guarantees to be unique per job works — a job UUID, a replica-group ID, or the job's `MASTER_ADDR` (which is unique per torchrun launch): + +```bash +python -m pipelinerl.launch ... 'world.run_id=${MASTER_ADDR}' +``` + +**To resume a preempted run**, reuse the same `output_dir` as the original job. fast-LLM automatically finds the latest checkpoint in `output_dir/finetune/checkpoint/` and resumes from it. WandB also resumes the same run because fast-LLM persists the run ID in `output_dir/finetune/wandb_config.yaml` on the first launch and reloads it on every subsequent launch. + +Each resumed job must still use a fresh `world.run_id` (the new job's ID, not the original one), so the pod IP exchange directory is always clean. 
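+
+Putting these knobs together, a minimal multi-node launch sketch (illustrative only; the exact config name and override paths depend on your job, and the production launch scripts described in [`docs/FAST_LLM_INTEGRATION.md`](docs/FAST_LLM_INTEGRATION.md) are the authoritative reference):
+
+```bash
+# Assumes MASTER_ADDR is set by the cluster launcher; /mnt/shared stands in for your shared NFS mount.
+python -m pipelinerl.launch ... streams=redis output_dir=/mnt/shared/$USER/my_run 'world.run_id=${MASTER_ADDR}'
+```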
+ +# Install FastLLM+PipelineRL + +> **Status (2026-05-06):** This integration is WIP — see [`docs/FAST_LLM_INTEGRATION.md`](docs/FAST_LLM_INTEGRATION.md) for the full handover (architecture, known issues, TODO). + +### 1. Container image + +To **use**: reference the prebuilt image +``` +registry.toolkit-sp.yul201.service-now.com/snow.research.afm/interactive-toolkit:25.12-py3-vllm014rc1redis +``` +It bundles the redis server. + +To **build** (from the [`ServiceNow/research-interactive-toolkit`](https://github.com/ServiceNow/research-interactive-toolkit/tree/fml/pytorch_vllm014rc1) repo, branch `fml/pytorch_vllm014rc1` — SN-internal, link is gated): set `~/.research-interactive-env` and run the toolkit's build target. + +```shell +USE_ACCOUNT_REPO := 1 +BASE_IMAGE := nvcr.io/nvidia/pytorch:25.12-py3 +IMAGE_REVISION := 25.12-py3-vllm014rc1redis +EAI_PROFILE := yul201 +``` + +Base layer is `nvcr.io/nvidia/pytorch:25.12-py3`; the toolkit branch layers on vLLM 0.14.0rc1, redis, and the EAI helpers. + +### 2. Clone + venv + editable installs + +Inside a running interactive instance, install both Fast-LLM and PipelineRL into a single venv at `PipelineRL/.venv`: + +```shell +git clone git@github.com:ServiceNow/Fast-LLM.git +git clone git@github.com:ServiceNow/PipelineRL.git + +cd PipelineRL +/usr/bin/python3.12 -m venv --system-site-packages .venv +source .venv/bin/activate +export PIP_CONSTRAINT="" + +# Fast-LLM: GSPO branch is the one paired with the PipelineRL fast-llm branch +cd ../Fast-LLM +git submodule update --init --recursive +git checkout gspo +pip install --no-cache-dir --no-build-isolation -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV]" triton==3.5.1 + +# PipelineRL: fast-llm branch +cd ../PipelineRL +git checkout fast-llm +pip install --no-cache-dir -e ".[lora]" +``` + +### 3. Known caveats + +- **`pyproject.toml:81-87`** — `[tool.uv]` overrides `transformers>=4.51.0` and `accelerate>=1.7.0` because `tapeagents==0.1.16` pins them lower; the `[tapeagents]` extra is **broken at runtime** until tapeagents bumps support. Track this as a TODO; do not enable `[tapeagents]` on the fast-llm path. +- **`PIP_CONSTRAINT=""`** is required — the toolkit image sets a constraint file that conflicts with our pinned versions. +- **Triton must be `==3.5.1`** — newer triton breaks the fast-llm GSPO kernels. + + diff --git a/conf/base.yaml b/conf/base.yaml index 2faa4b94..47b01a0b 100644 --- a/conf/base.yaml +++ b/conf/base.yaml @@ -74,7 +74,7 @@ vllm_config: world: replicas: 1 - + actor_fraction: 4 preprocessor_fraction: 0 finetune_fraction: 4 @@ -83,6 +83,11 @@ world: actor_group_port: 9000 environment_start_port: 7777 + # Unique identifier for this job run, used to namespace the pod IP exchange + # directory so stale files from previous runs are never seen. + # Defaults to $MASTER_ADDR when null (suitable for EAI and torchrun jobs). + run_id: null + # this will be autocreated based on the config jobs: [] @@ -100,7 +105,7 @@ fsdp: reduce_dtype: fp32 buffer_dtype: fp32 -output_dir: ??? +output_dir: null force_restart: false pop_old_data: true max_lag: null @@ -111,6 +116,93 @@ debug: streams_from: null place_inference_workers: true use_existing_llms: false + log_data_pipeline: false + +# Fast-LLM integration: when true, fast-llm is used as the trainer. +# Data flows actors -> Redis (fast_llm_streaming) -> fast-llm training loop. +# Weight updates are broadcast via NCCL using fast-llm's streaming callback. 
+use_fast_llm: true +# Whether the trainer broadcasts updated weights to vLLM after each training step. +weight_broadcast: true + +# Pure fast-llm config written as-is to a YAML file at launch time. +# Fields set to null are populated by the launcher at runtime (source noted in the comment) — do not modify them here. +# This section is only used when use_fast_llm: true. +fast_llm: + training: + num_workers: 0 + train_iters: 100000 # Total number of optimizer steps (provided by pipelinerl) + wandb: + entity_name: null # cfg.wandb.wandb_entity_name (null disables wandb) + project_name: null # cfg.wandb.wandb_project_name + group_name: null # cfg.wandb.wandb_group + logs: + interval: 1 # Logging frequency in optimizer steps + checkpoint: + interval: 1000 + export: + interval: 1000 + format: ${fast_llm_finetune.model_format} + + schedule: + depth_first_micro_batches: 16 # Gradient accumulation steps (sequential, one sample at a time) + + data: + micro_batch_size: 18000 # Tokens per sample; also the max rollout length accepted + truncate_documents: false # Do not truncate RL rollouts + shuffle: disabled # Streaming dataset ignores shuffling + datasets: + training: + type: streaming # Redis-backed streaming dataset + host: null # cfg.streams.host + port: null # cfg.streams.port + + pretrained: + format: ${fast_llm_finetune.model_format} + path: null # cfg.model_path + model_weights: true + + model: + base_model: + head: + losses: + grpo: + type: grpo + epsilon_low: 0.2 + epsilon_high: 0.2 + multi_stage: + zero_stage: 2 + distributed: + compute_dtype: bf16 + tensor_parallel: 1 + pipeline_parallel: 1 + sequence_data_parallel: 1 + + run: + experiment_dir: null # exp_dir/finetune + experiment_name: null # derived from exp_dir relative to cfg.wandb.wandb_workspace_root + + # callbacks section is written only when weight_broadcast: true (removed by launcher otherwise) + callbacks: + streaming: + type: streaming + host: null # cfg.streams.host + port: null # cfg.streams.port + broadcast: + backend: nccl + external_world_size: null # world_map.weight_update_group_size - 1 + host: null # world_map.master_addr + port: null # cfg.world.actor_group_port + export: + format: ${fast_llm_finetune.model_format} + model_weights: true + optimizer_state: false + +# Launcher-specific fast-llm settings (not passed to fast-llm itself). +fast_llm_finetune: + model_type: gpt # fast-llm model type argument: fast-llm train + model_format: qwen2 # pretrained/export format; interpolated into fast_llm config + torchrun_port: 29500 # master port for torchrun rendezvous me: # Which job is this one? 
This will be autopopulated diff --git a/conf/counting.yaml b/conf/counting.yaml index 97f61a88..9eaff581 100644 --- a/conf/counting.yaml +++ b/conf/counting.yaml @@ -3,6 +3,24 @@ defaults: finetune: seq_length: 4000 gradient_accumulation_passes: 1024 +vllm_config: + vllm_kwargs: + max_model_len: 4000 +fast_llm: + training: + num_workers: 1 + schedule: + depth_first_micro_batches: 256 + model: + base_model: + head: + losses: + grpo: + epsilon_low: 0.2 + epsilon_high: 0.2 + optimizer: + learning_rate: + base: 1e-5 llm: parameters: max_tokens: 1000 diff --git a/conf/math.yaml b/conf/math.yaml index 25629454..59ff9218 100644 --- a/conf/math.yaml +++ b/conf/math.yaml @@ -2,6 +2,13 @@ defaults: - base - _self_ +use_fast_llm: true +weight_broadcast: true + +fast_llm: + data: + micro_batch_size: 18000 + actor: rollout_policy: pipelinerl.domains.math.generate_math_rollout system_prompt: Please reason step by step, and put your final answer within \boxed{}. diff --git a/docs/FAST_LLM_INTEGRATION.md b/docs/FAST_LLM_INTEGRATION.md new file mode 100644 index 00000000..a2803eb6 --- /dev/null +++ b/docs/FAST_LLM_INTEGRATION.md @@ -0,0 +1,428 @@ +# Fast-LLM Integration — Handover + +> **Status:** WIP. Last verified end-to-end on 2026-05-06 with a 2-step smoke run on a 4-node EAI job (both DeepSpeed PPO and Fast-LLM GSPO finished cleanly, all metrics in expected ranges). +> +> **Authoring history:** Denis Kocetkov (denis.kocetkov@servicenow.com) — leaving the integration project. This document is the canonical handover; the [PR description](#) on GitHub is the executive summary. + +## Table of contents + +1. [Why fast-llm](#1-why-fast-llm) +2. [Branch state](#2-branch-state) +3. [End-to-end install](#3-end-to-end-install) +4. [Architecture (fast-llm path)](#4-architecture-fast-llm-path) +5. [Per-file changes](#5-per-file-changes) +6. [Configuration knobs](#6-configuration-knobs) +7. [Glossary](#7-glossary) +8. [Known issues & bugs](#8-known-issues--bugs) +9. [Testing](#9-testing) +10. [Operations](#10-operations) +11. [Where data lives](#11-where-data-lives) +12. [Open questions / decisions for the successor](#12-open-questions--decisions-for-the-successor) + +--- + +## 1. Why fast-llm + +DeepSpeed ZeRO-3 is the default trainer in PipelineRL. It works, but: + +- Weight updates to vLLM go over **HTTP**, gathered to rank 0 and POSTed; that's a serialization+network bottleneck on every optimizer step. +- ZeRO-3 partitioning forces a parameter all-gather every forward pass. +- DeepSpeed's loss/gradient pipeline is harder to extend with custom RL loss kernels (GSPO, GRPO with advanced metrics). + +[Fast-LLM](https://github.com/ServiceNow/Fast-LLM) replaces the trainer with FSDP + sequence-data-parallel (SDP) and broadcasts weights to vLLM over a **persistent NCCL group** instead of HTTP. The integration also adds custom GSPO/GRPO loss kernels with full DS parity (see PR #502 in the Fast-LLM repo). + +Goals: + +- **Higher GPU utilization** by avoiding HTTP serialization on every step. +- **More on-policy data** because broadcasts can run concurrently with vLLM generation. +- **Custom RL losses** (GSPO, sequence-level IS-ratio clipping) that are first-class in fast-llm. + +## 2. Branch state + +| Repo | Branch | Status | +|---|---|---| +| `ServiceNow/PipelineRL` | `fast-llm` | WIP, this PR's source branch | +| `ServiceNow/Fast-LLM` | `gspo` | WIP, Fast-LLM PR [#502](https://github.com/ServiceNow/Fast-LLM/pull/502) | + +The two branches must be used together. 
Fast-LLM's `gspo` branch contains the GSPO loss kernels, the divisor² + SDP loss-math fix, the `metrics: GRPOMetricsLevel` enum (merged from `grpo-metrics`), and `fp32_lm_head` precision matching for vLLM. The PipelineRL `fast-llm` branch contains the launcher integration, weight-broadcast plumbing, multi-node fixes, and the test suite (`tests/test_vllm1_*`, `tests/test_world_multinode.py`, `tests/test_actor_error_handling.py`). + +### Active CI + +There is **no CI specific to the fast-llm path**. Unit tests in `tests/` exercise weight-broadcast and multi-node behavior but do not run a full pipeline. Verifying the path requires a live multi-node smoke (see [§9 Testing](#9-testing)). + +## 3. End-to-end install + +### Image + +**To use**, reference the prebuilt image directly: + +``` +registry.toolkit-sp.yul201.service-now.com/snow.research.afm/interactive-toolkit:25.12-py3-vllm014rc1redis +``` + +It bundles the redis server (used by `streams=redis`). + +**To build it yourself** (e.g. when bumping the PyTorch / vLLM version — see open question 6 below): clone the [`ServiceNow/research-interactive-toolkit`](https://github.com/ServiceNow/research-interactive-toolkit/tree/fml/pytorch_vllm014rc1) repo (SN-internal, link is gated), check out branch `fml/pytorch_vllm014rc1`, then set `~/.research-interactive-env` and run the toolkit's build target: + +```shell +USE_ACCOUNT_REPO := 1 +BASE_IMAGE := nvcr.io/nvidia/pytorch:25.12-py3 +IMAGE_REVISION := 25.12-py3-vllm014rc1redis +EAI_PROFILE := yul201 +``` + +Base layer is `nvcr.io/nvidia/pytorch:25.12-py3`; the branch layers on vLLM 0.14.0rc1, redis, and the EAI helpers. + +### Launching an interactive EAI dev session + +Interactive jobs are single-replica dev environments (typically 1-2 GPUs) — they're for editing code, running tests, and submitting production multi-node training jobs *from inside them*. They are **not** the 4-node training environment themselves. + +To start one: + +1. Clone the toolkit repo (one-time): `git clone git@github.com:ServiceNow/research-interactive-toolkit.git ~/code/research-interactive-toolkit`. For the vLLM 0.14.0rc1 image, check out branch `fml/pytorch_vllm014rc1`. +2. Configure `~/.research-interactive-env` per the block above (selects image revision and EAI profile, plus `CPU`/`GPU`/`MEM` for the dev environment — typically modest, e.g. `GPU := 2`). +3. From the toolkit repo, run `make launch` and attach via VSCode Remote-SSH (full instructions in the toolkit README). +4. Inside the running interactive container, follow [§3 End-to-end install](#3-end-to-end-install) → "Steps" to clone Fast-LLM + PipelineRL into the venv. From there you can submit 4-node training jobs with `bash submit_eai_math_7b_multinode.sh 4` (etc.) — see §"How to launch" below. + +The 4-node training jobs run in their own EAI batch jobs (not in your interactive session). You only need the interactive session as the launch console / dev env. 
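+
+For step 2, a hedged sketch of what `~/.research-interactive-env` might contain for such a dev session (variable names follow the toolkit conventions shown above; the `CPU`/`MEM` values are illustrative, and the toolkit README has the full list):
+
+```shell
+IMAGE_REVISION := 25.12-py3-vllm014rc1redis
+EAI_PROFILE := yul201
+GPU := 2    # modest dev-box sizing, per the note above
+CPU := 8    # illustrative
+MEM := 64   # illustrative (GiB)
+```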
+ +### Steps + +```bash +git clone git@github.com:ServiceNow/Fast-LLM.git +git clone git@github.com:ServiceNow/PipelineRL.git + +cd PipelineRL +/usr/bin/python3.12 -m venv --system-site-packages .venv +source .venv/bin/activate +export PIP_CONSTRAINT="" + +cd ../Fast-LLM +git submodule update --init --recursive +git checkout gspo +pip install --no-cache-dir --no-build-isolation \ + -e ".[CORE,OPTIONAL,HUGGINGFACE,SSM,VISION,GENERATION,STREAMING,DEV]" \ + triton==3.5.1 + +cd ../PipelineRL +git checkout fast-llm +pip install --no-cache-dir -e ".[lora]" +``` + +### Troubleshooting + +| Symptom | Cause | Fix | +|---|---|---| +| `pip` resolves wrong transformers / accelerate versions | `[tool.uv]` override in `pyproject.toml:81-87` only applies to uv | Stay on the listed versions; do not enable the `[tapeagents]` extra on this branch | +| Triton kernel compile errors on first GSPO step | Triton version drift | `pip install triton==3.5.1` (newer breaks GSPO kernels) | +| `pip install` killed mid-build | Default `TMPDIR=/tmp` ephemeral quota (16 GiB) on EAI | `export TMPDIR=$HOME/.tmp; mkdir -p $TMPDIR` before installing | +| `_GLIBCXX_USE_CXX11_ABI` mismatch when loading vLLM | PyTorch wheel C++ ABI mismatch | Check `python -c "import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI)"` and pick the right vLLM wheel (the toolkit image already matches) | +| `PIP_CONSTRAINT` errors | The toolkit image ships a constraints file conflicting with our pinned versions | `export PIP_CONSTRAINT=""` before any `pip install` | + +## 4. Architecture (fast-llm path) + +``` +┌─ orchestrator (pipelinerl.launch) ──────────────────────────────────┐ +│ │ +│ 1. pre-creates a TCPStore on world.actor_group_port (rank 0 only) │ +│ because torchrun sets TORCHELASTIC_USE_AGENT_STORE=True which │ +│ makes every rank a client by default → no server, no rendezvous│ +│ 2. launches actor (vLLM) and finetune (fast-llm) processes │ +│ │ +└─────────────────────────────────────────────────────────────────────┘ + │ │ + ▼ ▼ +┌─ vLLM (run_vllm1.py) ────────┐ ┌─ fast-llm trainer (torchrun) ──┐ +│ │ │ │ +│ init_actor_update_group( │ │ StreamingTrainerCallback: │ +│ group_name= │◄──►│ - gather weights │ +│ WEIGHTS_BROADCAST_PG_ │ │ - broadcast on NCCL group │ +│ NAME) │ │ - xadd "step_done" event │ +│ │ │ │ +│ on "training_finished": │ │ on final step: │ +│ destroy_actor_update_ │ │ xadd "training_finished" │ +│ group() │ │ │ +│ │ │ │ +└──────────────────────────────┘ └────────────────────────────────┘ + │ │ + └──── redis (streams=redis) ────────────────────┘ + ▲ + │ + ┌──────┴────────────┐ + │ actor processes │ + │ (rollouts in) │ + └───────────────────┘ +``` + +### Weight-broadcast NCCL group + +| Property | Value | Source | +|---|---|---| +| Group name | `WEIGHTS_BROADCAST_PG_NAME` | both sides use the same string → matching store prefixes | +| Init method | `tcp://:9000` | `world.actor_group_port` from `conf/base.yaml` | +| TCPStore server | rank 0 of orchestrator (master node) | `pipelinerl/launch.py:998-1019` (only when `use_fast_llm and weight_broadcast`) | +| Member processes | trainer rank 0 (writer) + every vLLM worker (readers) | trainer joins via `init_extra_process_group(group_name=WEIGHTS_BROADCAST_PG_NAME)`; vLLM joins via `vllm1.init_actor_update_group` (`pipelinerl/vllm1.py:86-145`) | + +### Why pre-create the TCPStore + +When fast-llm runs under torchrun, every fast-llm process inherits `TORCHELASTIC_USE_AGENT_STORE=True` (hardcoded in `StaticTCPRendezvous.use_agent_store` — there is no flag to disable it). 
PyTorch's `_create_c10d_store` then treats every rank as a client. If nobody pre-creates the server, both fast-llm rank 0 and the vLLM workers spin retrying connection-refused on port 9000. + +Fix in `pipelinerl/launch.py:998-1019`: on `world_map.my_rank == 0`, create a `TCPStore(is_master=True, wait_for_workers=False)` on `world_map.master_addr:actor_group_port` before launching child processes, and keep the `broadcast_store` local alive until `watch_processes_running` returns. fast-llm and vLLM both then connect as clients to this pre-existing server. + +### Coordinated NCCL teardown + +`dist.destroy_process_group()` is a collective; if one side calls it and the other doesn't, the calling side hangs. The trainer xadds `{"type": "training_finished"}` to the `fast_llm_events` redis stream (see `FAST_LLM_EVENTS_STREAM` in `pipelinerl/state.py:20`); vLLM's monitoring thread (`pipelinerl/vllm1.py:484-547`) handles the event by calling `self.destroy_actor_update_group()` and setting `_fast_llm_stop_event`. Both sides then hit the collective barrier simultaneously. + +## 5. Per-file changes + +This is the high-impact subset, not exhaustive. Use `git log origin/main..fast-llm` for the full list. + +### Orchestrator / launcher + +| File | What | Why | +|---|---|---| +| `pipelinerl/launch.py:55-57` | Reject the deprecated top-level `fp32_lm_head` knob | After PR #137, fp32 is always-on; passing the old knob now warns | +| `pipelinerl/launch.py:88, 211, 246, 331, 397, 434-460` | Branch on `cfg.use_fast_llm` for finetune launch, callbacks injection, and weight-broadcast wiring | Replaces DS-only paths | +| `pipelinerl/launch.py:454-460` | Inject `callbacks.streaming.broadcast.{host,port,external_world_size}` from `world_map` | Lets fast-llm find the TCPStore at runtime | +| `pipelinerl/launch.py:998-1019` | Pre-create the broadcast `TCPStore` on rank 0 | Workaround for torchrun client-only rendezvous behavior | + +### State / actor + +| File | What | Why | +|---|---|---| +| `pipelinerl/state.py:24-29` | `samples_processed=0` in fast-llm mode (was `None`) | `wait_for_processed_samples()` would block at startup otherwise | +| `pipelinerl/state.py:64-141` | Background thread reading the `fast_llm_events` redis stream | Polls fast-llm trainer progress (samples processed, training_finished) | +| `pipelinerl/state.py:153-...` | `wait_for_training_done(timeout)` helper | Used by orchestrator shutdown | +| `pipelinerl/actor.py:158, 613-614` | `samples_target = max_train_steps × train_batch_size × gradient_accumulation_passes` | **See [§8 actor overshoot bug](#actor-overshoot)** — this assumption is wrong for fast-llm and stops the actor too early | + +### vLLM v1 worker + +| File | What | Why | +|---|---|---| +| `pipelinerl/vllm1.py:86-145` | `init_actor_update_group(group_name=WEIGHTS_BROADCAST_PG_NAME)` for fast-llm; `group_name="actor"` for HTTP mode | Matching store prefixes for rendezvous | +| `pipelinerl/vllm1.py:147-180` | `destroy_actor_update_group()` callable | NCCL teardown | +| `pipelinerl/vllm1.py:462, 484-547` | Background thread that consumes `fast_llm_events`; on `training_finished` schedules `destroy_actor_update_group` | Coordinated teardown | +| `pipelinerl/vllm1.py:567-571` | Fallback: forces stop if `training_finished` never arrives | Defensive | + +### Async LLM client (rollout retries) + +| File | What | Why | +|---|---|---| +| `pipelinerl/async_llm.py:61, 137-146, 194` | Retryable abort detection + `attempt=1/2` retry | vLLM aborts in-flight completions when weights are updated; we 
retry once | + +### Configs + +| File | What | Why | +|---|---|---| +| `conf/math.yaml:5-6` | `use_fast_llm: true` and `weight_broadcast: true` defaults | This config is the one verified end-to-end | +| `conf/base.yaml:78-89` | `world.actor_fraction`, `world.finetune_fraction`, `world.run_id` | Multi-node knobs | +| `conf/base.yaml:185-202` | `fast_llm.callbacks.streaming.broadcast.*` block (placeholder values) | Gets filled in at launch time by the launcher (see launch.py:454-460) | + +### Tests + +The fast-llm branch adds `tests/test_vllm1_fast_llm_broadcast.py`, `tests/test_vllm1_integration.py`, `tests/test_world_multinode.py`, `tests/test_actor_error_handling.py`, plus helpers (`tests/{vllm_engine_helper,distributed_trainer_helper,fast_llm_trainer_helper,server_weight_update_utils}.py`). They exercise weight-broadcast on a single host with a fake trainer + 1-3 vLLM workers (TP=1 or TP=2). They do **not** run a full multi-node pipeline. + +## 6. Configuration knobs + +PipelineRL side (Hydra overrides at launch): + +| Knob | Default | Notes | +|---|---|---| +| `use_fast_llm` | `false` (true in `math.yaml`) | Switches finetune path between DS and fast-llm | +| `weight_broadcast` | `true` | Enables NCCL broadcast group; disabling falls back to per-step HTTP weight updates | +| `streams` | `files` | **Must be `redis`** with `use_fast_llm=true` (files-mode dataset isn't implemented for fast-llm — see [§8 streams=files](#streams-files-not-supported)) | +| `world.actor_fraction` | `1` | Number of nodes hosting actor (vLLM) processes | +| `world.finetune_fraction` | `0` | Number of nodes hosting fast-llm trainer | +| `world.run_id` | `null` | **Required for multi-node** — see README §"Running and resuming multi-node jobs" | +| `world.actor_group_port` | `9000` | Broadcast TCPStore port | + +Fast-LLM side (passed as `+fast_llm.=value`): + +| Knob | Default | Notes | +|---|---|---| +| `fast_llm.model.distributed.sequence_data_parallel` | `1` | Set to `2` for 7B-math; loss-math fix divides by `sdp_size` (Fast-LLM `loss/grpo.py`) | +| `fast_llm.schedule.docs_per_step` | (set per run) | Documents per training step (e.g. 1024 for 7B-math) | +| `fast_llm.model.base_model.head.fp32_lm_head` | `false` | **Must be `true`** to match vLLM's `bf16_last_layer_fp32` precision (otherwise IS ratios diverge) | +| `fast_llm.model.base_model.head.losses.grpo.policy_loss` | `grpo` | `gspo` for sequence-level geometric-mean clipping | +| `fast_llm.model.base_model.head.losses.grpo.epsilon_low/_high` | `0.2 / 0.2` | Clipping thresholds | +| `fast_llm.model.base_model.head.losses.grpo.normalize_by_documents` | `false` | **Must be `true`** to match DeepSpeed's `1/batch_size` token weighting | +| `fast_llm.model.base_model.head.losses.grpo.temperature` | `1.0` | Set to actor's sampling temperature (e.g. `0.7`) so IS ratios start near 1 | +| `fast_llm.model.base_model.head.losses.grpo.metrics` | `none` | `none`/`basic`/`with_entropy` (see Fast-LLM PR #494). Replaces the old `compute_extra_metrics`/`compute_entropy_metric` flags | + +## 7. Glossary + +- **GRPO** — Group Relative Policy Optimization. Per-token IS-ratio clipping policy-gradient loss. +- **GSPO** — Group Sequence-level Policy Optimization. Geometric-mean IS-ratio clipping over the whole sequence (all tokens get the same multiplier). +- **DP / FSDP** — Data Parallel / Fully Sharded DP. FSDP shards parameters and gathers them on demand. +- **SDP** — Sequence Data Parallel (Fast-LLM concept). 
A second axis of parallelism that splits the *sequence* dimension across ranks. Requires extra all-reductions inside the loss. +- **ZeRO Stage 3** — DeepSpeed's parameter sharding. Equivalent to FSDP-1. +- **Microbatch / docs_per_step** — `docs_per_step` is the trainer's logical step size in *documents*. Each step consumes that many rollout documents; gradient accumulation breaks this into microbatches. +- **Broadcast PG** — the NCCL process group used to push weights from trainer rank 0 to vLLM workers. Created once and reused for every weight update. +- **`bf16_last_layer_fp32`** — vLLM's option to keep the LM head in fp32 while the rest of the model runs bf16. The trainer must match this exactly or IS ratios drift. + +## 8. Known issues & bugs + +### Actor `_prefetch_to_doc_target` overshoot — premature run end + +- **Symptom:** Long fast-llm runs (50+ steps) end before the configured `max_train_steps`. Actor signals completion → trainer stalls on the next step → `TimeoutError: No document received after 600 seconds`. Trainer reaches step ~43 of 50, run ends. +- **Root cause:** `pipelinerl/actor.py:158, 613-614` computes `samples_target = max_train_steps × train_batch_size × gradient_accumulation_passes` assuming exactly 1024 docs/step. Fast-LLM's `_prefetch_to_doc_target` (in Fast-LLM `fast_llm/engine/training/trainer.py:160-179`) overshoots `docs_per_step` by ~5–17% because of `while total_docs < target`. At runtime each step actually consumes ~1197 docs vs the 1024 target. The actor sees `samples_processed` cross `samples_target` early, signals completion, stops producing. +- **Workaround:** bump `max_train_steps` by ~20% (e.g. 50 → 60) so the actor has headroom. +- **Real fix:** make `actor.py:613` overshoot-aware (e.g. multiply by `(1 + safety_margin)` derived from `_prefetch_to_doc_target` actual ratio) or have the trainer signal "done" instead of the actor inferring it. +- **Memory file:** `project_actor_samples_target_overshoot_bug.md`. + +### Rollout retry exhaustion — occasional hang on bursts + +- **Symptom:** Actor logs show `Retryable aborted completion ... attempt=2/2 reason=finish_reason=abort`. Sometimes the second retry also aborts (because another weight update fires before the rollout completes), the request is dropped, and the rollout sits in the actor's "in_progress" tracking forever, blocking that slot. +- **Root cause:** vLLM aborts in-flight requests during weight updates. `pipelinerl/async_llm.py:137-146` retries once. Under bursty weight updates a single rollout can hit two consecutive aborts. +- **Workaround:** none currently; happens infrequently. +- **Real fix:** allow more retries (config flag), or make the actor evict rollouts that are stuck without a final response after N seconds. +- **Memory file:** `project_stall_investigation.md` (related, has more context). + +### Reward lag vs DeepSpeed — lower `actor/reward_mean` + +- **Symptom:** Even with exact `grpo_new_logprobs` parity (DS step 50 = -0.105, fast-llm step 50 = -0.103), fast-llm's `actor/reward_mean` lags DS by 2–3 EMA points throughout training. By step 400, fast-llm's `no_answer_mean` is **51× DS** (3.1% vs 0.06%). +- **Root cause:** Unknown. The trained model receives identical gradients (newlp parity verified), so the gap is upstream of the trainer — most likely in the data pipeline or in run-to-run sampling variance. Needs investigation, not a known fix. +- **Memory file:** `project_fastllm_reward_lag_after_gspo_fix.md`. 
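+
+To make the first issue above (actor overshoot) concrete with its own numbers: the actor computes `samples_target = 50 steps × 1024 docs/step = 51,200` documents, while the trainer actually consumes ~1,197 documents per step, so that budget is exhausted after `51,200 / 1,197 ≈ 43` steps, exactly where the stall appears. The +20% workaround gives `60 × 1024 = 61,440 ≥ 50 × 1,197 = 59,850`, restoring headroom for all 50 steps.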
+ +### Current limitation: `streams=files` is not implemented for `use_fast_llm=true` + +Not a bug, just a current limitation: Fast-LLM only ships `RedisStreamingDataset`, so this branch requires `streams=redis`. If you launch with `use_fast_llm=true streams=files` you'll get an error from the launcher. **Memory file:** `project_streams_files_not_supported_fast_llm.md`. + +## 9. Testing + +### Unit tests (single host) + +```bash +cd /home/toolkit/code/PipelineRL +source .venv/bin/activate +pytest tests/test_vllm1_fast_llm_broadcast.py # weight broadcast +pytest tests/test_vllm1_integration.py # vLLM v1 path +pytest tests/test_world_multinode.py # topology / port assignment +pytest tests/test_actor_error_handling.py # rollout retry +``` + +These run on 1-3 GPUs (the helpers spawn TP=1 or TP=2 vLLM engines plus a fake trainer). + +### 4-node test results + +#### 2-step smoke (last verified 2026-05-06) + +Quick "everything launches" verification — temporarily set `max_train_steps=2` (and `train_iters=2` for fast-llm) in the submit script, launch, and look for the trainer's "Reached final step 2, stopping" / "Saving checkpoint at iteration 2" log line within ~10 minutes of `RUNNING`. Revert to 400 before committing. + +| Smoke | EAI Job | Step 1 grad_norm | Step 2 grad_norm | Step 1 newlp | Step 2 newlp | NaN | +|---|---|---|---|---|---|---| +| fast-llm GSPO | `59f3b62f` | 0.166 | 0.173 | -0.171 | -0.162 | 0 | +| DeepSpeed PPO | `084ef7d8` | 0.201 | 0.247 | -0.162 | -0.146 | 0 | + +#### 400-step training curves: fast-llm GSPO vs DeepSpeed GSPO + +Comparing fast-llm `math_7b_4node_fastllm_gspo_20260505_122944` (the divisor² + SDP fix run) against DeepSpeed `math_7b_ds_fastllm_4node_20260428_135427` (matching GSPO config: `policy_loss=gspo`, `epsilon_low=3e-3`, 400 steps). + +**`new_logprobs` — fast-llm matches DS step-by-step** (the GSPO loss math fix is correct): + +![new_logprobs fast-llm vs DS](images/new_logprobs.png) + +**`actor/reward_mean` — fast-llm lags DS by ~2 points at step 400** (the open issue, root cause unknown): + +![reward_mean fast-llm vs DS](images/reward_mean.png) + +### How to run 4-node tests + +#### Personalize + +Both submit launchers default to Denis's setup. Before running, override these env vars (or edit the defaults at the top of each script): + +| Env var | Default | What it is | +|---|---|---| +| `RESULTS_DIR` | `/mnt/shared/denis/math_7b_results` | Where outputs / checkpoints / logs land. Must be on a shared NFS readable by every node. | +| `WANDB_ENTITY` | `denisko-se` | Your wandb entity (user or org). | +| `WANDB_PROJECT` | `watermelon` | Your wandb project. | +| `EAI_HOME_DATA` | `snow.home.denis_kocetkov` | Your EAI home data object (mounted at `/home/toolkit` inside the container). | +| `EAI_SHARED_DATA` | `snow.research.afm.shared_fml` | Your shared NFS data object (mounted at `/mnt/shared`). | +| `MODEL_PATH` | `/home/toolkit/Qwen2.5-7B` | Path to the base model checkpoint inside the container. | + +The handover doc and PR description also mention `denisko-se/watermelon` runs and `/mnt/shared/denis/math_7b_results/` paths — those are pointers to Denis's historical runs and stay as-is for traceability; you don't need to edit them, just point your own runs to your own places. + +#### Reproduction scripts + +Two production launchers in the repo root reproduce the chart-baseline runs byte-for-byte. Each submits a 4-replica × 8-GPU EAI batch job. 
+ +| Script | What it reproduces | +|---|---| +| [`submit_eai_math_7b_multinode.sh`](../submit_eai_math_7b_multinode.sh) | Fast-llm GSPO 400-step run — produced `math_7b_4node_fastllm_gspo_20260505_122944` (the chart's fast-llm curve). | +| [`submit_eai_math_7b_multinode_ds_fastllm_branch.sh`](../submit_eai_math_7b_multinode_ds_fastllm_branch.sh) | DS GSPO 400-step run — produced `math_7b_ds_fastllm_4node_20260428_135427` (the chart's DS curve). | + +#### Launch + +You launch these from inside an interactive EAI dev session (see §"Launching an interactive EAI dev session" above) — that's the dev/console environment. Each `bash submit_eai_*.sh 4` call submits a *separate* 4-node EAI batch job that runs the actual training; your interactive session is just the launch console and stays free. + +Prereqs: + +1. Fast-LLM + PipelineRL installed in a shared venv — see [§3 End-to-end install → "Steps"](#3-end-to-end-install) above (clones both repos, checks out `gspo` and `fast-llm` branches, editable-installs). +2. `eai` CLI authenticated. Run `eai login` once if it isn't already. +3. Wandb credentials configured for the entity in `WANDB_ENTITY` (`~/.netrc` or `wandb login`). +4. The personalization env vars above exported (or edit the defaults in the script). +5. A 7B base model checkpoint at the path `MODEL_PATH` points to (default `/home/toolkit/Qwen2.5-7B`). + +```bash +# fast-llm GSPO (4 replicas × 8 GPUs = 32 GPUs total, ~9-14 h wall clock for 400 steps) +bash submit_eai_math_7b_multinode.sh 4 + +# DS GSPO (same compute footprint) +bash submit_eai_math_7b_multinode_ds_fastllm_branch.sh 4 +``` + +Each call returns a job ID and queues a 4-replica × 8-GPU EAI job. The job creates `${RESULTS_DIR}/${EXP_NAME}/` with `launch.log`, `finetune/stdout_node*.log`, `actor/info.log`, `actor_vllm_*/{stdout,stderr}.log`, and a `wandb_config.yaml` with the resumable wandb run id. WandB run name is set via `+wandb.wandb_run_name=...` and includes the timestamp. + +To monitor: `eai job logs ` or tail the log files directly on the shared NFS mount (`/mnt/shared/...`). To stop early: `eai job kill ` (sends SIGINT — orchestrator does the coordinated NCCL teardown). + +## 10. Operations + +### Where logs live + +For an EAI-launched job with `output_dir=/mnt/shared/.../`: + +| Log | Path | +|---|---| +| Orchestrator | `/launch.log` | +| fast-llm trainer | `/finetune/stdout_node{N}.log` (per-rank training metrics on stdout) | +| DeepSpeed trainer | `/finetune/stderr_node{N}.log` (`pipelinerl.finetune_loop - Completed steps N: {...}`) | +| Actor | `/actor/info.log` and `actor/debug.log` | +| vLLM workers | `/actor_vllm_/{stdout,stderr}.log` | +| Redis | `/redis/redis.log` | + +**Common gotcha:** fast-llm prints step metrics to **stdout**; DeepSpeed prints them to **stderr** as `pipelinerl.finetune_loop` log lines. Both are normal; don't grep one and assume the other is broken. + +### How to monitor a running EAI job + +```bash +eai job ls --account snow.research.afm | grep +eai job logs # streamed +eai job kill # graceful shutdown signal +``` + +For shutdown semantics, **always** SIGINT the launch process (don't `kill -9` the children) — the orchestrator's coordinated NCCL teardown depends on a clean signal path. + +### WandB + +- Project: `denisko-se/watermelon` +- Group: `eai_math7b_fastllm_gspo` (fast-llm) / `eai_math7b_ds_fastllm` (DS) +- Run name: set via `+wandb.wandb_run_name=...` + +## 11. 
Where data lives + +| What | Where | +|---|---| +| Shared NFS results dir | `/mnt/shared/denis/math_7b_results/` | +| Model checkpoints (Qwen2.5-7B) | `/home/toolkit/Qwen2.5-7B/` | +| Code (PipelineRL, Fast-LLM) | `/home/toolkit/code/{PipelineRL,Fast-LLM}/` | +| venv | `/home/toolkit/code/PipelineRL/.venv/` | + +## 12. Open questions / decisions for the successor + +1. **Fix or compensate the actor overshoot?** Cleanest is to make the trainer signal "done" instead of the actor computing a target. Workaround is a constant safety multiplier in `actor.py:613`. +2. **Reward lag root cause.** Need to identify where the gap comes from before deciding whether it's worth fixing on this branch. +3. **Should the GSPO loss math fix (Fast-LLM PR #502) be merged before this PipelineRL PR?** Yes — this PR pins to the `gspo` branch by name; once `gspo` merges to Fast-LLM `main` we should rev this branch's install instructions to use `main`. +4. **Resolve the commented-out `pyproject.toml` overrides** (`pyproject.toml:81-87`). The `[tool.uv]` block force-overrides `transformers>=4.51.0` / `accelerate>=1.7.0` because `tapeagents==0.1.16` pins them lower; the `[tapeagents]` extra is broken at runtime as a result. Either bump tapeagents (when upstream supports newer libs) or drop the extra altogether on this branch. +5. **Close metric gaps on the fast-llm finetune side**, e.g. `rl/ess` (effective sample size — diagnostic for data/policy drift). Diff DS's `rl/*` and `stats/*` against fast-llm's `training.*` and pick what's worth porting. +6. **Move off the interactive-toolkit base image and the vLLM 0.14.0rc1 pin.** Current image is `interactive-toolkit:25.12-py3-vllm014rc1redis` (PyTorch 25.12 + vLLM 0.14.0rc1 + bundled redis). Step up to the latest base PyTorch and vLLM versions that Fast-LLM and PipelineRL both support, then re-verify the smoke runs. diff --git a/docs/images/new_logprobs.png b/docs/images/new_logprobs.png new file mode 100644 index 00000000..a620bd6f Binary files /dev/null and b/docs/images/new_logprobs.png differ diff --git a/docs/images/reward_mean.png b/docs/images/reward_mean.png new file mode 100644 index 00000000..4bef318c Binary files /dev/null and b/docs/images/reward_mean.png differ diff --git a/monitor_jobs.sh b/monitor_jobs.sh new file mode 100755 index 00000000..57a06c84 --- /dev/null +++ b/monitor_jobs.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# Monitor two comparison jobs for failure/cancellation/preemption. 
+# Usage: bash monitor_jobs.sh + +DS_JOB="${1:-fe9561a0-5c66-4971-88b3-d38bcab0b6e4}" +FL_JOB="${2:-18baa4d1-8f91-4153-9d1c-0affb7d62536}" +DS_DIR="${3:-/mnt/shared/denis/math_7b_results/math_7b_ds_fastllm_4node_20260428_135427}" +FL_DIR="${4:-/mnt/shared/denis/math_7b_results/math_7b_4node_fastllm_gspo_20260428_135448}" + +BAD_STATES="FAILED CANCELLED PREEMPTED INTERRUPTED" +INTERVAL=120 # seconds between polls + +log() { echo "[$(date '+%H:%M:%S')] $*"; } + +check_job() { + local job_id="$1" + local label="$2" + local state + state=$(eai job get "$job_id" 2>/dev/null | awk 'NR==2{print $2}') + if [ -z "$state" ]; then + state="UNKNOWN" + fi + for bad in $BAD_STATES; do + if [ "$state" = "$bad" ]; then + log "ALERT: $label ($job_id) is $state" + return 1 + fi + done + log "$label ($job_id): $state" + return 0 +} + +check_dir() { + local dir="$1" + local label="$2" + local count + count=$(find "$dir" -maxdepth 1 -mindepth 1 2>/dev/null | wc -l) + log "$label dir has $count top-level entries" +} + +log "Monitoring DS job: $DS_JOB" +log "Monitoring FastLLM job: $FL_JOB" +log "DS dir: $DS_DIR" +log "FastLLM dir: $FL_DIR" +log "Polling every ${INTERVAL}s. Ctrl-C to stop." +echo "" + +ds_alive=1 +fl_alive=1 + +while true; do + if [ $ds_alive -eq 1 ]; then + check_job "$DS_JOB" "DS" || ds_alive=0 + fi + if [ $fl_alive -eq 1 ]; then + check_job "$FL_JOB" "FastLLM" || fl_alive=0 + fi + check_dir "$DS_DIR" "DS" + check_dir "$FL_DIR" "FastLLM" + echo "" + + if [ $ds_alive -eq 0 ] && [ $fl_alive -eq 0 ]; then + log "Both jobs ended. Exiting." + break + fi + + sleep "$INTERVAL" +done diff --git a/pipelinerl/actor.py b/pipelinerl/actor.py index d276747d..1a41020a 100644 --- a/pipelinerl/actor.py +++ b/pipelinerl/actor.py @@ -132,6 +132,19 @@ async def schedule_rollouts( """ loop = asyncio.get_running_loop() + # Diagnostic logging (Process B side) – enabled by debug.log_data_pipeline + _pb_log_file = None + if cfg.debug.get("log_data_pipeline", False): + import json as _json_b + import pathlib as _pathlib_b + _log_dir_b = _pathlib_b.Path(cfg.output_dir) / "actor" / "data_pipeline_log" + _log_dir_b.mkdir(parents=True, exist_ok=True) + # Use scheduler_name to distinguish multiple workers + _safe_name = scheduler_name.replace(" ", "_").replace("/", "_").replace(",", "") + _pb_log_file = open(_log_dir_b / f"process_b_{_safe_name}.jsonl", "a") + _pb_problem_queue_empty_count = 0 + _pb_llm_busy_count = 0 + # Track active tasks per LLM active_rollouts = [0] * len(llms) started_rollouts = 0 @@ -145,6 +158,7 @@ async def schedule_rollouts( samples_target = final_steps * cfg.finetune.train_batch_size * cfg.finetune.gradient_accumulation_passes retryable_rollout_exceptions = ( aiohttp.ServerTimeoutError, + aiohttp.ServerDisconnectedError, asyncio.TimeoutError, TimeoutError, RetryableAbortedCompletionError, @@ -180,7 +194,7 @@ async def rollout_and_maybe_produce_result( llm_index: int, session: aiohttp.ClientSession, ): - nonlocal started_rollouts, finished_rollouts + nonlocal started_rollouts, finished_rollouts, _pb_problem_queue_empty_count, _pb_llm_busy_count try: llm = llms[llm_index] model_version = trainer_state.propagated_weight_version @@ -192,21 +206,35 @@ async def rollout_and_maybe_produce_result( break except asyncio.CancelledError: raise - except Exception as exc: - is_retryable = isinstance(exc, retryable_rollout_exceptions) - can_retry = max_rollout_retries < 0 or retry_count < max_rollout_retries - if is_retryable and can_retry and not is_trainer_finished(): - retry_count += 1 - backoff_s = 
min(retry_max_delay_s, retry_initial_delay_s * (2 ** (retry_count - 1))) - if retry_count == 1 or retry_count % 10 == 0: - logger.warning( - f"{scheduler_name}: rollout {group_id}/{rollout_index} failed with " - f"{exc.__class__.__name__}, retry {retry_count}" - ) - await asyncio.sleep(backoff_s) - continue - handle_rollout_exception(exc) - return + except aiohttp.ClientResponseError as http_exc: + if 400 <= http_exc.status < 500: + logger.warning( + f"Rollout failed with HTTP {http_exc.status} for group {group_id}, " + f"skipping this rollout: {http_exc.message}" + ) + rollout_result = RolloutResult( + training_texts=[], + metrics=BaseMetrics(reward=0.0, success=False, no_error=False, no_answer=True), + latency=0.0, + ) + break + exc = http_exc + except Exception as exc_: + exc = exc_ + is_retryable = isinstance(exc, retryable_rollout_exceptions) + can_retry = max_rollout_retries < 0 or retry_count < max_rollout_retries + if is_retryable and can_retry and not is_trainer_finished(): + retry_count += 1 + backoff_s = min(retry_max_delay_s, retry_initial_delay_s * (2 ** (retry_count - 1))) + if retry_count == 1 or retry_count % 10 == 0: + logger.warning( + f"{scheduler_name}: rollout {group_id}/{rollout_index} failed with " + f"{exc.__class__.__name__}, retry {retry_count}" + ) + await asyncio.sleep(backoff_s) + continue + handle_rollout_exception(exc) + return rollout_result.model_version = model_version # Make a group id that will be different from groups made by another rollout maker full_group_id = f"{scheduler_name}_{group_id}" @@ -219,10 +247,35 @@ async def rollout_and_maybe_produce_result( sample.group_id = full_group_id group_rollouts[group_id].append(rollout_result) if len(group_rollouts[group_id]) == attempts: - # This is blocking call, but there's just one other thread reading from this queue. 
- random.shuffle(group_rollouts[group_id]) - result_queue.put(group_rollouts[group_id]) + # Filter out empty results (failed rollouts with no training data) + valid_results = [r for r in group_rollouts[group_id] if r.training_texts] + if not valid_results: + logger.warning( + f"Dropping group {group_id}: all {attempts} rollouts failed " + f"(no training samples produced)" + ) + del group_rollouts[group_id] + finished_rollouts += 1 + return + random.shuffle(valid_results) del group_rollouts[group_id] + _t_put_start = time.monotonic() + await asyncio.get_event_loop().run_in_executor(None, result_queue.put, valid_results) + _put_duration = time.monotonic() - _t_put_start + if _pb_log_file is not None: + _pb_log_file.write(_json_b.dumps({ + "wall": time.time(), + "event": "put", + "put_blocked_s": _put_duration, + "result_queue_depth_after": result_queue.qsize(), + "active_rollouts": sum(active_rollouts), + "groups_in_progress": len(group_rollouts), + "problem_queue_empty_since_last": _pb_problem_queue_empty_count, + "llm_busy_since_last": _pb_llm_busy_count, + }) + "\n") + _pb_log_file.flush() + _pb_problem_queue_empty_count = 0 + _pb_llm_busy_count = 0 finished_rollouts += 1 except Exception as e: handle_rollout_exception(e) @@ -259,6 +312,7 @@ async def rollout_and_maybe_produce_result( problem = problem_queue.get(block=False) except Empty: # give some quality time for other couroutines to work + _pb_problem_queue_empty_count += 1 await asyncio.sleep(0.01) continue group_id += 1 @@ -268,6 +322,7 @@ async def rollout_and_maybe_produce_result( next_llm = active_rollouts.index(min(active_rollouts)) if active_rollouts[next_llm] == cfg.actor.llm_max_rollouts: # all llms are busy, wait for one to finish + _pb_llm_busy_count += 1 await asyncio.sleep(0.01) continue active_rollouts[next_llm] += 1 @@ -284,6 +339,8 @@ async def rollout_and_maybe_produce_result( ) group_rollout_index += 1 logger.info("Rollout scheduler finished") + if _pb_log_file is not None: + _pb_log_file.close() def rollout_maker_entrypoint( @@ -294,7 +351,7 @@ def rollout_maker_entrypoint( llms: list[TrainableLLM], scheduler_name: str, ): - trainer_state = TrainerState(Path(cfg.output_dir)) + trainer_state = TrainerState(Path(cfg.output_dir), use_fast_llm=cfg.use_fast_llm, weight_broadcast=cfg.weight_broadcast) if cfg.debug.mode: trainer_state.propagated_weight_version = 0 else: @@ -476,6 +533,16 @@ def _run(self, dataset: list[tuple[str, dict]]): published_samples = 0 submitted_groups = 0 finished_groups = 0 + + # Diagnostic logging setup (enabled by debug.log_data_pipeline) + _pipeline_log_file = None + if self.is_training and self.cfg.debug.get("log_data_pipeline", False): + import json as _json + import pathlib as _pathlib + _log_dir = _pathlib.Path(self.cfg.output_dir) / "actor" / "data_pipeline_log" + _log_dir.mkdir(parents=True, exist_ok=True) + _pipeline_log_file = open(_log_dir / "process_a.jsonl", "a") + _last_publish_wall = None # wall clock of last successful publish expected_rollouts = -1 if self.is_training else len(dataset) if expected_rollouts > 0: logger.info(f"Will stop after {expected_rollouts} rollouts") @@ -583,14 +650,16 @@ def _run(self, dataset: list[tuple[str, dict]]): except queue.Empty: continue + _t_got = time.monotonic() + if isinstance(rollout_results, Exception): logger.error("Stop actor loop due to error") raise rollout_results assert isinstance(rollout_results, list) assert isinstance(rollout_results[0], RolloutResult) - assert len(rollout_results) == attempts, ( - f"Expected {attempts} rollouts, 
got {len(rollout_results)}" + assert 0 < len(rollout_results) <= attempts, ( + f"Expected 1-{attempts} rollouts, got {len(rollout_results)}" ) group_samples = sum(len(r.training_texts) for r in rollout_results) @@ -649,7 +718,9 @@ def _run(self, dataset: list[tuple[str, dict]]): for r in rollout_results: for text in r.training_texts: all_text_dumps.append(text.model_dump()) + _t_before_redis = time.monotonic() data_stream_writer.write(all_text_dumps) + _t_after_redis = time.monotonic() in_progress = submitted_groups - finished_groups logger.info( f"Published {group_samples} {'train' if self.is_training else 'test'} samples" @@ -663,10 +734,12 @@ def _run(self, dataset: list[tuple[str, dict]]): time_to_publish_train_stats = ( self.is_training and trainer_version_to_publish is not None - ) or self.debug_mode + ) or self.debug_mode time_to_publish_test_stats = finished_groups == expected_rollouts # Publish stats at every new model version or if all tapes are finished + _t_before_stats = None + _t_after_stats = None if time_to_publish_train_stats or time_to_publish_test_stats: if self.is_training: loop_stats = { @@ -674,7 +747,7 @@ def _run(self, dataset: list[tuple[str, dict]]): "problem_queue_size": self.problem_queue.qsize(), "result_queue_size": self.result_queue.qsize(), "finished_groups": finished_groups, - "trainer_model_version": trainer_version_to_publish, + "trainer_model_version": trainer_version_to_publish, "time_since_start": time.time() - loop_start_time, } trainer_version_to_publish = None @@ -683,16 +756,38 @@ def _run(self, dataset: list[tuple[str, dict]]): "trainer_model_version": last_trainer_version } + _t_before_stats = time.monotonic() self.publish_stats( stats_writer=stats_writer, loop_stats=loop_stats, ) + _t_after_stats = time.monotonic() + + if _pipeline_log_file is not None: + _now = time.monotonic() + _entry = { + "wall": time.time(), + "finished_groups": finished_groups, + "result_queue_depth": self.result_queue.qsize(), + "inter_publish_gap_s": _t_got - _last_publish_wall if _last_publish_wall is not None else None, + "process_s": _t_before_redis - _t_got, + "redis_write_s": _t_after_redis - _t_before_redis, + "stats_write_s": (_t_after_stats - _t_before_stats) if _t_before_stats is not None else None, + "total_cycle_s": _now - _t_got, + "group_samples": group_samples, + } + _pipeline_log_file.write(_json.dumps(_entry) + "\n") + _pipeline_log_file.flush() + _last_publish_wall = _t_got if finished_groups == expected_rollouts: logger.info(f"Finished {expected_rollouts} rollouts, stopping actor loop") break + if _pipeline_log_file is not None: + _pipeline_log_file.close() + def publish_stats(self, stats_writer: StreamWriter, loop_stats: Dict): split_name = "test_" if not self.is_training else "" @@ -842,7 +937,7 @@ def run_actor_loop(cfg: DictConfig): wait_for_inference_servers(llm_urls) wait_for_environments(cfg) - trainer_state = TrainerState(exp_path) + trainer_state = TrainerState(exp_path, use_fast_llm=cfg.use_fast_llm, weight_broadcast=cfg.weight_broadcast) if cfg.debug.mode: trainer_state.debug_mode_init() else: diff --git a/pipelinerl/async_llm.py b/pipelinerl/async_llm.py index b305c458..fd9e1137 100644 --- a/pipelinerl/async_llm.py +++ b/pipelinerl/async_llm.py @@ -1,3 +1,4 @@ +import asyncio import base64 import io import logging @@ -188,6 +189,11 @@ async def llm_async_generate( logger.exception(f"Failed to parse llm response: {response_data}") raise + if finish_reason == "abort": + raise asyncio.TimeoutError( + f"vLLM aborted request (weight update in 
progress); will retry" + ) + output = LLMOutput(content=content) llm_call = llm.log_output(prompt, output, count_tokens=False) llm_call.prompt_length_tokens = response_data["usage"]["prompt_tokens"] diff --git a/pipelinerl/launch.py b/pipelinerl/launch.py index 725f2e02..38ac3404 100644 --- a/pipelinerl/launch.py +++ b/pipelinerl/launch.py @@ -2,6 +2,7 @@ import math import os import shutil +import socket import subprocess import sys import time @@ -84,9 +85,15 @@ def validate_config(cfg: DictConfig): raise ValueError("value_loss_coef must be greater than 0 when using causal-language-modeling-with-value-head") # Check that model being tuned to the max length accepted by inference - if cfg.finetune.seq_length < cfg.vllm_config.vllm_kwargs.max_model_len: + if cfg.use_fast_llm: + max_seq_length = cfg.fast_llm.data.micro_batch_size + seq_length_label = "fast_llm.data.micro_batch_size" + else: + max_seq_length = cfg.finetune.seq_length + seq_length_label = "finetune.seq_length" + if max_seq_length < cfg.vllm_config.vllm_kwargs.max_model_len: raise ValueError( - f"seq_length {cfg.finetune.seq_length} must be greater than or equal to " + f"{seq_length_label} {max_seq_length} must be greater than or equal to " f"vllm_kwargs.max_model_len {cfg.vllm_config.vllm_kwargs.max_model_len}" ) @@ -155,7 +162,7 @@ def run_ref_llm(cfg: DictConfig, preprocessor_llm_idx: int, local_idx: int, gpus os.makedirs(log_dir, exist_ok=True) cmd = [ - "python", + sys.executable, "-m", "vllm.entrypoints.openai.api_server", "--model", @@ -201,8 +208,9 @@ def run_actor_llm( log_dir = exp_dir / f"actor_vllm_{actor_llm_idx}" os.makedirs(log_dir, exist_ok=True) entrypoint = "pipelinerl.entrypoints.run_vllm1" + broadcast_port = cfg.world.actor_group_port cmd = [ - "python", + sys.executable, "-m", entrypoint, "--model", @@ -216,7 +224,7 @@ def run_actor_llm( "--actor-llm-idx", str(actor_llm_idx), "--weight-update-group-init-method", - f"tcp://{world_map.master_addr}:{cfg.world.actor_group_port}", + f"tcp://{world_map.master_addr}:{broadcast_port}", "--weight-update-group-world-size", str(world_map.weight_update_group_size), ] @@ -224,18 +232,34 @@ def run_actor_llm( cmd.extend(_get_quantization_args(cfg)) kwargs = _get_vllm_kwargs(cfg) + # vLLM v1 rejects num-scheduler-steps; defensively drop it. + if "num-scheduler-steps" in kwargs: + kwargs.pop("num-scheduler-steps") if kwargs: _append_vllm_kwargs(cmd, kwargs) - if cfg.debug.mode: + if cfg.debug.mode or not cfg.weight_broadcast: cmd.append("--disable-weight-updates") + # Always tell the vLLM actor server which weight-update protocol to use, + # so its conditional init takes the right branch (HTTP vs fast-llm broadcast). + if cfg.use_fast_llm: + cmd += [ + "--weight-update-mode", "fast-llm", + "--redis-host", cfg.streams.host, + "--redis-port", str(cfg.streams.port), + ] + else: + cmd += ["--weight-update-mode", "http"] + gpu_str = ",".join([str(gpu) for gpu in gpus]) logger.info(f"Running actor_llm with command: {' '.join(cmd)} on gpus: {gpu_str}") save_command(log_dir, cmd) log_file_path = os.path.join(log_dir, "stdout.log") err_file_path = os.path.join(log_dir, "stderr.log") - env = {**os.environ, "CUDA_VISIBLE_DEVICES": gpu_str, **_get_quantization_env(cfg)} + # Give each actor a distinct base port so vLLM's get_open_port() race condition + # (TOCTOU: find-free-port then bind) doesn't cause EADDRINUSE when multiple servers start simultaneously. 
+ env = {**os.environ, "CUDA_VISIBLE_DEVICES": gpu_str, "VLLM_PORT": str(30000 + actor_llm_idx * 20), **_get_quantization_env(cfg)} with open(log_file_path, "a") as log_file, open(err_file_path, "a") as err_file: proc = _popen( cmd, @@ -252,7 +276,7 @@ def run_actor(world_map: WorldMap, actor_idx: int, exp_dir: Path): raise NotImplementedError("Can only do 1 actor yet") llm_urls = "+".join(world_map.get_actor_urls()) cmd = [ - "python", + sys.executable, "-m", "pipelinerl.entrypoints.run_actor", "--config-dir", @@ -276,7 +300,7 @@ def run_environment(cfg: DictConfig, job: Job): # run in a subprocess like in the rest of the code run_dir = Path(cfg.output_dir) / f"environment_{job.replica_idx}" cmd = [ - "python", + sys.executable, "-m", "pipelinerl.entrypoints.run_environment", "--config-dir", @@ -304,6 +328,13 @@ def run_environment(cfg: DictConfig, job: Job): def run_finetune(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir: Path): + if cfg.use_fast_llm: + yield from _run_finetune_fast_llm(cfg, world_map, gpus, exp_dir) + else: + yield from _run_finetune_deepspeed(cfg, world_map, gpus, exp_dir) + + +def _run_finetune_deepspeed(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir: Path): if cfg.use_fsdp and cfg.use_deepspeed: raise ValueError("Cannot use both FSDP and DeepSpeed") cmd = [ @@ -312,10 +343,10 @@ def run_finetune(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir: "accelerate.commands.launch", ] if world_map.world_size > 1: - # DeepSpeed multi-node args assert cfg.use_deepspeed - assert world_map.master_addr.startswith("dns-") and world_map.master_addr.endswith("-0") - hosts = [world_map.master_addr[:-2] + f"-{i}" for i in range(world_map.world_size)] + # Use original DNS names (pod IP exchange may have replaced address_map with IPs). 
+ dns_map = getattr(world_map, "dns_address_map", world_map.address_map) + hosts = [dns_map[i] for i in range(world_map.world_size)] filter_parts = [] for rank, job_list in world_map.job_map.items(): for job in job_list: @@ -323,34 +354,23 @@ def run_finetune(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir: filter_parts.append(f"{hosts[rank]}:{','.join(map(str, job.gpus))}") deepspeed_include_filter = "@".join(filter_parts) logger.info(f"Deepspeed include filter: {deepspeed_include_filter}") - # Orchestrator rank must have already created hostfile.txt hostfile_path = str(exp_dir / "hostfile.txt") cmd += [ - "--num_machines", - str(len(world_map.nodes_with_finetuning())), - "--machine_rank", - str(world_map.my_finetuning_rank()), - "--main_process_ip", - str(os.environ.get("MASTER_ADDR")), - "--main_process_port", - str(os.environ.get("MASTER_PORT")), - "--deepspeed_hostfile", - hostfile_path, - "--deepspeed_inclusion_filter", - deepspeed_include_filter, - "--deepspeed_multinode_launcher", - "nossh" + "--num_machines", str(len(world_map.nodes_with_finetuning())), + "--machine_rank", str(world_map.my_finetuning_rank()), + "--main_process_ip", str(os.environ.get("MASTER_ADDR")), + "--main_process_port", str(os.environ.get("MASTER_PORT")), + "--deepspeed_hostfile", hostfile_path, + "--deepspeed_inclusion_filter", deepspeed_include_filter, + "--deepspeed_multinode_launcher", "nossh", ] - # get path to this file this_file_path = Path(os.path.dirname(os.path.abspath(__file__))) if cfg.use_deepspeed: - # DeepSpeed single-node args cmd += [ "--use_deepspeed", "--deepspeed_config_file", str(this_file_path / f"../conf/deepspeed/{cfg.deepspeed_config}.json"), ] - # DeepSpeed and non-DeepSpeed args accelerate_config = cfg.accelerate_config if accelerate_config is None: if cfg.use_deepspeed: @@ -362,27 +382,18 @@ def run_finetune(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir: cmd += [ "--config_file", str(this_file_path / f"../conf/accelerate/{accelerate_config}.yaml"), - "--rdzv_backend", - "c10d", + "--rdzv_backend", "c10d", ] if gpus: gpus_str = str(",".join([str(gpu) for gpu in gpus])) if len(gpus) < world_map.node_size else "all" - cmd += [ - "--gpu-ids", - gpus_str, - ] + cmd += ["--gpu-ids", gpus_str] cmd += [ - "--num_processes", - str(world_map.total_finetune_gpus), - "pipelinerl/entrypoints/run_finetune.py", - "--config-dir", - f"{exp_dir}/conf", - "--config-name", - "exp_config", + "--num_processes", str(world_map.total_finetune_gpus), + str(this_file_path / "entrypoints/run_finetune.py"), + "--config-dir", f"{exp_dir}/conf", + "--config-name", "exp_config", f"output_dir={exp_dir}", f"hydra.run.dir={exp_dir}/finetune", - # TODO: figure out why we can't build WorldMap in run_finetune.py - # Current workaround: pass the essential information as follows: f"+me.weight_update_group_init_method=tcp://{world_map.master_addr}:{cfg.world.actor_group_port}", f"+me.weight_update_group_world_size={world_map.weight_update_group_size}", f"+me.llm_urls={'+'.join(world_map.get_actor_urls())}", @@ -390,21 +401,222 @@ def run_finetune(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir: if cfg.debug.mode in ["finetune", "open_loop", "finetune+preprocessor"]: cmd.append("finetune.send_weight_updates=False") - logger.info(f"Running finetune with command: {' '.join(cmd)}") - save_command(exp_dir / "finetune", cmd) + finetune_nodes = world_map.nodes_with_finetuning() + finetune_rank = world_map.my_finetuning_rank() + node_suffix = f"_node{finetune_rank}" if 
len(finetune_nodes) > 1 else "" + + logger.info(f"Running DeepSpeed finetune with command: {' '.join(cmd)}") + save_command(exp_dir / "finetune", cmd, suffix=node_suffix) env = dict(os.environ) env["DS_ENV_FILE"] = str(exp_dir / ".deepspeed_env") - proc = _popen(cmd, env=env) + save_dir = exp_dir / "finetune" + os.makedirs(save_dir, exist_ok=True) + log_file_path = save_dir / f"stdout{node_suffix}.log" + err_file_path = save_dir / f"stderr{node_suffix}.log" + with open(log_file_path, "a") as log_file, open(err_file_path, "a") as err_file: + proc = _popen(cmd, env=env, stdout=log_file, stderr=err_file) + if proc is not None: + yield LaunchedProcess(kind="finetune", handle=proc) + + +def _run_finetune_fast_llm(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir: Path): + save_dir = exp_dir / "finetune" + os.makedirs(save_dir, exist_ok=True) + + if not os.path.isdir(cfg.model_path): + raise ValueError( + f"fast-llm requires a local model path but got: {cfg.model_path!r}. " + "Download the model first and set model_path to its local directory." + ) + + # Build fast-llm config, stripping callbacks when weight broadcast is disabled or in debug mode. + fast_llm_cfg = OmegaConf.to_container(cfg.fast_llm, resolve=True, throw_on_missing=False) + if not cfg.weight_broadcast or bool(cfg.debug.mode): + fast_llm_cfg.pop("callbacks", None) + + # Derive experiment name for wandb from save_dir relative to workspace root. + root = cfg.wandb.wandb_workspace_root + save_dir_str = str(save_dir) + experiment_name = save_dir_str[len(root) + 1:] if root and save_dir_str.startswith(root + "/") else save_dir.name + + # Fill in all dynamic values so the saved config is fully functional. + fast_llm_cfg["pretrained"]["path"] = cfg.model_path + fast_llm_cfg["run"]["experiment_dir"] = str(save_dir) + fast_llm_cfg["run"]["experiment_name"] = experiment_name + fast_llm_cfg["data"]["datasets"]["training"]["host"] = cfg.streams.host + fast_llm_cfg["data"]["datasets"]["training"]["port"] = cfg.streams.port + if cfg.debug.log_data_pipeline: + fast_llm_cfg["data"]["datasets"]["training"]["log_data_pipeline"] = True + fast_llm_cfg.setdefault("schedule", {})["log_data_pipeline"] = True + fast_llm_cfg["training"]["wandb"]["entity_name"] = cfg.wandb.wandb_entity_name + fast_llm_cfg["training"]["wandb"]["project_name"] = cfg.wandb.wandb_project_name + fast_llm_cfg["training"]["wandb"]["group_name"] = cfg.wandb.wandb_group + if cfg.weight_broadcast and not bool(cfg.debug.mode): + fast_llm_cfg["callbacks"]["streaming"]["host"] = cfg.streams.host + fast_llm_cfg["callbacks"]["streaming"]["port"] = cfg.streams.port + # fast-llm runs on node 0 (same node as the TCPStore server); use localhost + # to avoid DNS self-resolution issues. vLLM (on node 1) uses master_addr. + fast_llm_cfg["callbacks"]["streaming"]["broadcast"]["host"] = "localhost" + fast_llm_cfg["callbacks"]["streaming"]["broadcast"]["port"] = cfg.world.actor_group_port + fast_llm_cfg["callbacks"]["streaming"]["broadcast"]["external_world_size"] = world_map.weight_update_group_size - 1 + + # Use per-node suffixes for all output files to avoid NFS write races when multiple + # finetune nodes share the same experiment directory. 
+ model_type = cfg.fast_llm_finetune.model_type + torchrun_port = cfg.fast_llm_finetune.torchrun_port + finetune_nodes = world_map.nodes_with_finetuning() + finetune_rank = world_map.my_finetuning_rank() + node_suffix = f"_node{finetune_rank}" if len(finetune_nodes) > 1 else "" + + config_path = save_dir / f"fast_llm_config{node_suffix}.yaml" + OmegaConf.save(OmegaConf.create(fast_llm_cfg), config_path) + + if len(finetune_nodes) > 1: + finetune_master = world_map.address_map[finetune_nodes[0]] + cmd = [ + "torchrun", + f"--nproc_per_node={len(gpus)}", + f"--nnodes={len(finetune_nodes)}", + f"--node_rank={finetune_rank}", + "--rdzv_backend=static", + "--rdzv_id=0", + f"--rdzv_endpoint={finetune_master}:{torchrun_port}", + "--rdzv_conf=timeout=3600", + "--max_restarts=0", + "--no_python", + str(Path(sys.executable).parent / "fast-llm"), + "train", + model_type, + "--config", + str(config_path), + ] + else: + cmd = [ + "torchrun", + f"--nproc_per_node={len(gpus)}", + f"--master_port={torchrun_port}", + "--no_python", + str(Path(sys.executable).parent / "fast-llm"), + "train", + model_type, + "--config", + str(config_path), + ] + + logger.info(f"Running finetune with command: {' '.join(cmd)}") + save_command(save_dir, cmd, suffix=node_suffix) + env = dict(os.environ) + env["PYTHONHASHSEED"] = "42" + env["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu) for gpu in gpus) + log_file_path = save_dir / f"stdout{node_suffix}.log" + err_file_path = save_dir / f"stderr{node_suffix}.log" + with open(log_file_path, "a") as log_file, open(err_file_path, "a") as err_file: + proc = _popen(cmd, env=env, stdout=log_file, stderr=err_file) if proc is not None: yield LaunchedProcess(kind="finetune", handle=proc) +# def run_finetune(cfg: DictConfig, world_map: WorldMap, gpus: list[int], exp_dir: Path): +# if cfg.use_fsdp and cfg.use_deepspeed: +# raise ValueError("Cannot use both FSDP and DeepSpeed") +# cmd = [ +# "python", +# "-m", +# "accelerate.commands.launch", +# ] +# if world_map.world_size > 1: +# # DeepSpeed multi-node args +# assert cfg.use_deepspeed +# assert world_map.master_addr.startswith("dns-") and world_map.master_addr.endswith("-0") +# hosts = [world_map.master_addr[:-2] + f"-{i}" for i in range(world_map.world_size)] +# filter_parts = [] +# for rank, job_list in world_map.job_map.items(): +# for job in job_list: +# if job.kind == "finetune": +# filter_parts.append(f"{hosts[rank]}:{','.join(map(str, job.gpus))}") +# deepspeed_include_filter = "@".join(filter_parts) +# logger.info(f"Deepspeed include filter: {deepspeed_include_filter}") +# # Orchestrator rank must have already created hostfile.txt +# hostfile_path = str(exp_dir / "hostfile.txt") +# cmd += [ +# "--num_machines", +# str(len(world_map.nodes_with_finetuning())), +# "--machine_rank", +# str(world_map.my_finetuning_rank()), +# "--main_process_ip", +# str(os.environ.get("MASTER_ADDR")), +# "--main_process_port", +# str(os.environ.get("MASTER_PORT")), +# "--deepspeed_hostfile", +# hostfile_path, +# "--deepspeed_inclusion_filter", +# deepspeed_include_filter, +# "--deepspeed_multinode_launcher", +# "nossh" +# ] +# # get path to this file +# this_file_path = Path(os.path.dirname(os.path.abspath(__file__))) +# if cfg.use_deepspeed: +# # DeepSpeed single-node args +# cmd += [ +# "--use_deepspeed", +# "--deepspeed_config_file", +# str(this_file_path / f"../conf/deepspeed/{cfg.deepspeed_config}.json"), +# ] +# # DeepSpeed and non-DeepSpeed args +# accelerate_config = cfg.accelerate_config +# if accelerate_config is None: +# if cfg.use_deepspeed: +# 
accelerate_config = "deepspeed" +# elif cfg.use_fsdp: +# accelerate_config = "fsdp_mp" +# else: +# accelerate_config = "base_mp" +# cmd += [ +# "--config_file", +# str(this_file_path / f"../conf/accelerate/{accelerate_config}.yaml"), +# "--rdzv_backend", +# "c10d", +# ] +# if gpus: +# gpus_str = str(",".join([str(gpu) for gpu in gpus])) if len(gpus) < world_map.node_size else "all" +# cmd += [ +# "--gpu-ids", +# gpus_str, +# ] +# cmd += [ +# "--num_processes", +# str(world_map.total_finetune_gpus), +# "pipelinerl/entrypoints/run_finetune.py", +# "--config-dir", +# f"{exp_dir}/conf", +# "--config-name", +# "exp_config", +# f"output_dir={exp_dir}", +# f"hydra.run.dir={exp_dir}/finetune", +# # TODO: figure out why we can't build WorldMap in run_finetune.py +# # Current workaround: pass the essential information as follows: +# f"+me.weight_update_group_init_method=tcp://{world_map.master_addr}:{cfg.world.actor_group_port}", +# f"+me.weight_update_group_world_size={world_map.weight_update_group_size}", +# f"+me.llm_urls={'+'.join(world_map.get_actor_urls())}", +# ] +# if cfg.debug.mode in ["finetune", "open_loop", "finetune+preprocessor"]: +# cmd.append("finetune.send_weight_updates=False") + +# logger.info(f"Running finetune with command: {' '.join(cmd)}") +# save_command(exp_dir / "finetune", cmd) +# env = dict(os.environ) +# env["DS_ENV_FILE"] = str(exp_dir / ".deepspeed_env") +# proc = _popen(cmd, env=env) +# if proc is not None: +# yield LaunchedProcess(kind="finetune", handle=proc) + def run_preprocess(world_map: WorldMap, preprocessor_idx: int, exp_dir: Path): if preprocessor_idx != 0: raise NotImplementedError("Can only do 1 preprocessor yet") llm_urls = "+".join(world_map.get_preprocessor_urls()) cmd = [ - "python", + sys.executable, "-m", "pipelinerl.entrypoints.run_preprocess", "--config-dir", @@ -426,7 +638,11 @@ def run_preprocess(world_map: WorldMap, preprocessor_idx: int, exp_dir: Path): def run_redis(cfg: DictConfig): - # Launch redis-server + # Launch redis-server. Resolve paths to absolutes because redis-server + # chdir's to --dir before opening --logfile, which breaks relative paths. 
+ output_dir = Path(cfg.output_dir).resolve() + redis_dir = output_dir / "redis" + os.makedirs(redis_dir, exist_ok=True) cmd = [ "redis-server", "--bind", @@ -434,11 +650,15 @@ def run_redis(cfg: DictConfig): "--port", str(cfg.streams.port), "--dir", - str(cfg.output_dir), + str(output_dir), "--protected-mode", "no", "--save", cfg.streams.save, + "--logfile", + str(redis_dir / "redis.log"), + "--loglevel", + "verbose", ] logger.info(f"Running redis with command: {' '.join(cmd)}") save_command(Path(cfg.output_dir) / "redis", cmd) @@ -447,9 +667,9 @@ def run_redis(cfg: DictConfig): yield LaunchedProcess(kind="redis", handle=proc) -def save_command(script_dir: Path, cmd): +def save_command(script_dir: Path, cmd, suffix: str = ""): os.makedirs(script_dir, exist_ok=True) - script_path = script_dir / "start.sh" + script_path = script_dir / f"start{suffix}.sh" with open(script_path, "w") as f: f.write("#!/bin/bash\n") # Properly quote arguments for the shell script @@ -468,7 +688,6 @@ def clean_up(exp_dir, force_restart): os.remove(f"{exp_dir}/streams") if os.path.exists(f"{exp_dir}/dump.rdb"): os.remove(f"{exp_dir}/dump.rdb") - if force_restart: if os.path.exists(f"{exp_dir}/finetune"): logger.info("Cleaning up finetune directory") @@ -486,9 +705,9 @@ def is_inference_process(proc: LaunchedProcess) -> bool: return proc.kind in {"actor_llm", "preprocessor_llm"} -def watch_processes_running(exp_path: Path, processes: List[LaunchedProcess], debug_mode: bool = False): +def watch_processes_running(exp_path: Path, processes: List[LaunchedProcess], debug_mode: bool = False, use_fast_llm: bool = False, weight_broadcast: bool = True): if not debug_mode: - trainer_state = TrainerState(exp_path) + trainer_state = TrainerState(exp_path, use_fast_llm=use_fast_llm, weight_broadcast=weight_broadcast) trainer_state.start_listening() else: trainer_state = None @@ -605,6 +824,80 @@ def setup_logging(log_file: Path): logger.info("Logging setup complete") +def _get_pod_ip() -> str: + """Return this pod's primary IP (bypasses Kubernetes Service kube-proxy).""" + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) + return s.getsockname()[0] + finally: + s.close() + + +def _exchange_pod_ips(world_map: "WorldMap", exp_dir: Path, run_id: str) -> None: + """Exchange pod IPs across replicas via the shared NFS mount. + + Kubernetes Services only expose the declared master port; all other ports + (Redis, vLLM HTTP, TCPStore) are silently dropped for Service ClusterIPs. + Using pod IPs bypasses kube-proxy and gives full port access. + + After the exchange, all Job.url and Job.hostname fields are updated to use + pod IPs so every cross-node HTTP/TCP connection bypasses the Service. + """ + # Save DNS names before overwriting so DeepSpeed hostfile can use them. + world_map.dns_address_map = dict(world_map.address_map) + + ip_dir = exp_dir / ".pod_ips" / run_id + my_ip = _get_pod_ip() + + if world_map.my_rank == 0: + if ip_dir.exists(): + raise RuntimeError( + f"Pod IP exchange directory already exists for run_id={run_id!r}. " + "world.run_id must be unique per job run." 
+ ) + ip_dir.mkdir(parents=True) + else: + waited = 0 + while not ip_dir.exists(): + time.sleep(0.5) + waited += 0.5 + if waited % 10 == 0: + logger.info(f"Waiting for rank 0 to create pod IP dir ({waited:.0f}s)...") + + ip_file = ip_dir / f"rank_{world_map.my_rank}.txt" + ip_file.write_text(my_ip) + logger.info(f"Pod IP exchange: rank {world_map.my_rank} pod IP = {my_ip}") + + pod_ips = {} + for rank in range(world_map.world_size): + peer_file = ip_dir / f"rank_{rank}.txt" + waited = 0 + while not peer_file.exists(): + time.sleep(0.5) + waited += 0.5 + if waited % 10 == 0: + logger.info(f"Waiting for pod IP from rank {rank} ({waited:.0f}s)...") + pod_ip = peer_file.read_text().strip() + pod_ips[rank] = pod_ip + world_map.address_map[rank] = pod_ip + logger.info(f"Pod IP exchange: rank {rank} → {pod_ip}") + + world_map.master_addr = pod_ips[0] + logger.info(f"Updated master_addr to pod IP: {world_map.master_addr}") + + # Update all Job URLs and hostnames to pod IPs so cross-node connections + # bypass the Kubernetes Service (which only exposes declared ports). + for node, jobs in world_map.job_map.items(): + pod_ip = pod_ips[node] + dns_name = world_map.dns_address_map[node] + for job in jobs: + job.hostname = pod_ip + if job.url: + job.url = job.url.replace(dns_name, pod_ip) + logger.info("Updated all job URLs to pod IPs for direct pod-to-pod connectivity.") + + @hydra.main( config_path="../conf/", config_name="base", @@ -620,6 +913,18 @@ def main(cfg: DictConfig): log_file = exp_dir / "launcher" / f"launcher_{os.environ.get('RANK', 0)}.log" setup_logging(log_file) world_map = WorldMap(cfg, verbose=True) + + # In multi-node EAI jobs the `dns--` names are Kubernetes Services + # that expose only the declared master port. Connecting to those Service IPs + # on any other port (Redis, vLLM HTTP, TCPStore) gets SYN-dropped by kube-proxy. + # Pod IPs bypass kube-proxy and have all ports open, so we exchange pod IPs via + # a shared NFS file and update address_map before any TCP connections are made. + if world_map.world_size > 1: + run_id = cfg.world.get("run_id") + if not run_id: + raise ValueError("world.run_id must be set for multi-node jobs (use a unique value per job run)") + _exchange_pod_ips(world_map, exp_dir, run_id) + cfg.jobs = [job.model_dump() for job in world_map.get_all_jobs()] group = str(exp_dir) @@ -639,7 +944,15 @@ def main(cfg: DictConfig): ) cfg.finetune.gradient_accumulation_passes = new_accum_passes if cfg.streams.backend == "redis": - cfg.streams.host = world_map.master_addr + if world_map.world_size > 1: + # Multi-node: use the pod IP of rank 0 (world_map.master_addr after pod IP + # exchange). Pod-to-pod connections are unrestricted on all ports, so rank 0 + # can reach its own Redis via its pod IP, and rank 1 via the cross-node pod IP. + # Using the pod IP (not localhost or a DNS name) also ensures the saved + # exp_config.yaml has a reachable address for DeepSpeed workers on node 1. + cfg.streams.host = world_map.master_addr + else: + cfg.streams.host = "localhost" set_streams_backend(**cfg.streams) processes = [] @@ -657,8 +970,9 @@ def main(cfg: DictConfig): redis.flushall() if world_map.world_size > 1: - assert world_map.master_addr.startswith("dns-") and world_map.master_addr.endswith("-0") - hosts = [world_map.master_addr[:-2] + f"-{i}" for i in range(world_map.world_size)] + # Use original DNS names (pod IP exchange may have replaced address_map with IPs). 
+ dns_map = getattr(world_map, "dns_address_map", world_map.address_map) + hosts = [dns_map[i] for i in range(world_map.world_size)] hostfile_lines = [f"{host} slots=8" for host in hosts] deepspeed_hostfile_content = "\n".join(hostfile_lines) hostfile_path = str(exp_dir / "hostfile.txt") @@ -681,6 +995,28 @@ def main(cfg: DictConfig): raise ValueError(f"Expected {init_msg}, got {msg}") logger.info(f"Orchestrator {world_map.my_rank} heard that the exp folder is ready.") + # Pre-create the broadcast rendezvous TCPStore on actor_group_port so that + # fast-llm (launched via torchrun) can connect as a client. Torchrun sets + # TORCHELASTIC_USE_AGENT_STORE=True which makes PyTorch treat ALL ranks as + # clients in _create_c10d_store; without a pre-existing server the port is + # never opened and both fast-llm and vLLM hang forever. Only the master + # node (my_rank == 0) hosts the server; vLLM workers connect via master_addr. + broadcast_store = None + if cfg.use_fast_llm and cfg.weight_broadcast and world_map.my_rank == 0: + from torch.distributed import TCPStore + broadcast_store = TCPStore( + host_name="0.0.0.0", + port=cfg.world.actor_group_port, + world_size=world_map.weight_update_group_size, + is_master=True, + wait_for_workers=False, + ) + logger.info( + f"Broadcast TCPStore server started on " + f"{world_map.master_addr}:{cfg.world.actor_group_port} " + f"(world_size={world_map.weight_update_group_size})" + ) + if cfg.debug.mode == "finetune": processes.extend(launch_jobs(cfg, world_map, ["finetune"])) elif cfg.debug.mode == "actor": @@ -699,7 +1035,7 @@ def main(cfg: DictConfig): if os.environ.get("DRY_RUN", "0") == "1": assert not processes return - watch_processes_running(exp_dir, processes, bool(cfg.debug.mode)) + watch_processes_running(exp_dir, processes, bool(cfg.debug.mode), cfg.use_fast_llm, cfg.weight_broadcast) if __name__ == "__main__": diff --git a/pipelinerl/preprocess.py b/pipelinerl/preprocess.py index 0a6015e4..38d5c7cc 100644 --- a/pipelinerl/preprocess.py +++ b/pipelinerl/preprocess.py @@ -347,6 +347,65 @@ def write_micro_batch_slices( data_writer.write(micro_batch, lead_trainer_id) +def convert_to_fast_llm_format(entry: dict) -> dict: + """Convert a preprocessed sample entry to Fast-LLM streaming format. + + Fast-LLM RedisDocument fields: + - tokens: list of token IDs (full sequence: prompt + completion) + - loss_masking_spans: list of (start, end) spans where loss IS computed (completion only) + - advantage: scalar float (per-rollout GRPO advantage) + - old_log_probabilities: list of floats, full sequence length (zeros for prompt tokens) + """ + input_ids = entry["input_ids"] + tokens = input_ids.tolist() if hasattr(input_ids, "tolist") else list(input_ids) + + result: dict = {"tokens": tokens} + + # loss_masking_spans: contiguous spans where label == -100 (prompt tokens to mask out). + # fast-llm sets labels to -100 at these positions, so only completion tokens contribute to loss. 
+ if "labels" in entry: + labels = entry["labels"] + labels = labels.tolist() if hasattr(labels, "tolist") else list(labels) + + spans = [] + in_span = False + span_start = 0 + for i, label in enumerate(labels): + if label == -100 and not in_span: + in_span = True + span_start = i + elif label != -100 and in_span: + spans.append((span_start, i)) + in_span = False + if in_span: + spans.append((span_start, len(labels))) + + if spans: + result["loss_masking_spans"] = spans + + # advantage: scalar per rollout (populate_rl_data stores a list of per-step scalars; + # for single-step tasks like math there is exactly one element) + if "advantages" in entry: + advantages = entry["advantages"] + if advantages: + result["advantage"] = float(advantages[0]) + + # old_log_probabilities: full sequence length, zeros for prompt tokens + # (prepare_rl_fields pads with zeros on the left to match len(input_ids)) + if "old_logprobs" in entry: + old_logprobs = entry["old_logprobs"] + old_logprobs = old_logprobs.tolist() if hasattr(old_logprobs, "tolist") else list(old_logprobs) + result["old_log_probabilities"] = [float(x) for x in old_logprobs] + + return result + + +def write_sample_for_fast_llm(data_writer: StreamWriter, entry: dict): + """Write a single sample to the stream in Fast-LLM format.""" + fast_llm_sample = convert_to_fast_llm_format(entry) + data_writer.write(fast_llm_sample) + + def run_preprocessing_loop( cfg: DictConfig, @@ -373,13 +432,27 @@ def run_preprocessing_loop( wait_for_inference_servers(llm_urls) input_stream = SingleStreamSpec(exp_path=exp_root_dir, topic=cfg.preprocess.input) - output_stream = StreamRangeSpec( - exp_path=exp_root_dir, - topic=cfg.preprocess.output, - partition_range=(0, max(world_map.total_finetune_gpus, 1)), - ) + # For Fast-LLM: use SingleStreamSpec with shared=True (uses orjson serialization) + # For standard PipelineRL: use StreamRangeSpec with partitions per GPU + if cfg.use_fast_llm: + from fast_llm.data.dataset.config import REDIS_DATA_STREAM as _FAST_LLM_DATA_STREAM + fast_llm_stream_name = _FAST_LLM_DATA_STREAM + output_stream = SingleStreamSpec( + exp_path=exp_root_dir, + topic=cfg.preprocess.output, + partition=0, + ) + use_shared_stream = True + else: + fast_llm_stream_name = None + output_stream = StreamRangeSpec( + exp_path=exp_root_dir, + topic=cfg.preprocess.output, + partition_range=(0, max(world_map.total_finetune_gpus, 1)), + ) + use_shared_stream = False stats_streams = SingleStreamSpec(exp_path=exp_root_dir, topic="preprocessor_stats") - logger.info("Streams initialized") + logger.info(f"Streams initialized (shared={use_shared_stream})") raw_chunk_queue = Queue(cfg.preprocess.raw_queue_size) rl_config = RLConfig(**cfg.finetune.rl) @@ -397,7 +470,7 @@ def run_preprocessing_loop( dataset_loader_thread.start() # Initialize TrainerState - trainer_state = TrainerState(exp_root_dir) + trainer_state = TrainerState(exp_root_dir, use_fast_llm=cfg.use_fast_llm, weight_broadcast=cfg.weight_broadcast) if cfg.debug.mode == "preprocessor": logger.info("Debug mode: preprocessor") trainer_state.debug_mode_init() @@ -462,7 +535,16 @@ def run_preprocessing_loop( # Per-trainer sample tracking (similar to finetune_loop.py) total_filtered_out = 0 # Track total filtered samples across all batches - with write_to_streams(output_stream) as data_writer, write_to_streams(stats_streams) as stats_writer: + pipeline_log_file = None + if cfg.use_fast_llm and cfg.debug.get("log_data_pipeline", False): + import json as _json + import pathlib as _pathlib + # Write alongside 
fast-llm rank files: {exp_dir}/finetune/data_pipeline_log/ + _log_dir = _pathlib.Path(cfg.output_dir) / "finetune" / "data_pipeline_log" + _log_dir.mkdir(parents=True, exist_ok=True) + pipeline_log_file = open(_log_dir / "preprocessor.jsonl", "a") + + with write_to_streams(output_stream, shared=use_shared_stream, stream_name_override=fast_llm_stream_name, pipelinerl_metadata=not cfg.use_fast_llm) as data_writer, write_to_streams(stats_streams) as stats_writer: with SharedMemoryManager() as smm: # Create shared memory queues without the manager parameter input_queue = SharedMemoryQueue(smm, cfg.preprocess.input_queue_size, cfg.preprocess.shared_memory_entry_size) @@ -495,6 +577,7 @@ def run_preprocessing_loop( fetching_took = 0 writing_took = 0 num_filtered_out = 0 + last_backpressure_log = 0.0 while True: if ( trainer_state.samples_processed is not None @@ -567,13 +650,43 @@ def run_preprocessing_loop( assert isinstance(trainer_state.samples_processed, int) if published_samples - trainer_state.samples_processed > max_unconsumed_samples: # wait for the finetune loop to finish processing data + now = time.time() + if now - last_backpressure_log >= 10.0: + last_backpressure_log = now + logger.info( + f"Back-pressure: published={published_samples} consumed={trainer_state.samples_processed}" + f" unconsumed={published_samples - trainer_state.samples_processed} > max={max_unconsumed_samples}, waiting" + ) continue batch_done = False start_writing = time.time() while (len(processed_entries_queue) > 0 and not batch_done) or (cfg.preprocess.dataset_buffer_size and not batch_done): logger.debug(f"[inner loop] trainer {trainer_id} has {samples_per_trainer[trainer_id]} samples, target is {target_samples_per_lead}") - if cfg.finetune.seq_packing: + + # Fast-LLM path: write individual samples directly (Fast-LLM does its own packing) + if cfg.use_fast_llm: + write_start = time.time() if pipeline_log_file else None + write_samples = 0 + write_tokens = 0 + while len(processed_entries_queue) > 0: + entry = processed_entries_queue.popleft() + if pipeline_log_file is not None: + write_samples += 1 + write_tokens += len(entry.get("input_ids", [])) + write_sample_for_fast_llm(data_writer, entry) + published_samples += 1 + if pipeline_log_file is not None and write_samples > 0: + pipeline_log_file.write(_json.dumps({ + "event": "WRITE", + "t_start": round(write_start, 3), + "t_end": round(time.time(), 3), + "samples": write_samples, + "tokens": write_tokens, + }) + "\n") + pipeline_log_file.flush() + batch_done = True # Always mark done for Fast-LLM (no batching) + elif cfg.finetune.seq_packing: if samples_per_trainer[trainer_id] == target_samples_per_lead: logger.debug(f"[inner loop] trainer {trainer_id} has all {target_samples_per_lead} samples, creating sentinel batch") sentinel_batch = create_sentinel_batch( @@ -615,8 +728,11 @@ def run_preprocessing_loop( current_length = 0 logger.debug(f"[inner loop] Packed microbatch with {len(current_batch)} samples for trainer {trainer_id}") else: + # Unpacked path: need a full micro-batch before collating. 
+ if len(processed_entries_queue) < cfg.finetune.train_batch_size: + break # wait for more data; outer loop will refill the queue batch_entries = [] - for _ in range(cfg.finetune.train_batch_size ): + for _ in range(cfg.finetune.train_batch_size): batch_entries.append(processed_entries_queue.popleft()) batch_encoding = collate(batch_entries, tokenizer=tokenizer) write_micro_batch_slices(trainer_id, data_writer, batch_encoding, cfg.finetune.seq_parallel) @@ -667,7 +783,8 @@ def run_preprocessing_loop( logger.info( f"Processed {processed_samples} samples (filtered out {num_filtered_out}) in {processing_took:.3f}s" f" (fetching took {fetching_took:.3f} and writing took {writing_took:.3f})" - f" and wrote to {output_stream}, total {published_samples} samples so far," + f" and wrote to {output_stream}, total {published_samples} samples so far" + f" (trainer consumed {trainer_state.samples_processed}, unconsumed {published_samples - trainer_state.samples_processed})," f" {samples_in_output_queue} samples in output queue, max output queue entry size {output_queue.max_actual_entry_size()} bytes" ) start_processing = time.time() diff --git a/pipelinerl/state.py b/pipelinerl/state.py index 944b5114..a9539812 100644 --- a/pipelinerl/state.py +++ b/pipelinerl/state.py @@ -16,12 +16,17 @@ logger = logging.getLogger(__name__) +# Fast-LLM event stream name (must match fast-llm config events.redis.stream_key) +FAST_LLM_EVENTS_STREAM = "fast_llm_events" + class TrainerState: - def __init__(self, exp_path: Path): + def __init__(self, exp_path: Path, use_fast_llm: bool = False, weight_broadcast: bool = True): self.exp_path = exp_path - self.propagated_weight_version: int | None = None - self.samples_processed: int | None = None + self.use_fast_llm = use_fast_llm + self.weight_broadcast = weight_broadcast + self.propagated_weight_version: int | None = None if weight_broadcast else 0 + self.samples_processed: int | None = None if weight_broadcast else 0 self.training_done: bool = False self._training_done_event = threading.Event() @@ -32,6 +37,13 @@ def debug_mode_init(self): self._training_done_event.set() def start_listening(self): + if self.use_fast_llm: + self._start_listening_fast_llm() + else: + self._start_listening_legacy() + + def _start_listening_legacy(self): + """Listen to legacy PipelineRL trainer messages.""" stream = SingleStreamSpec(exp_path=self.exp_path, topic=TRAINER_TOPIC) def listen(): @@ -49,6 +61,95 @@ def listen(): self._thread = threading.Thread(target=listen, daemon=True) self._thread.start() + def _start_listening_fast_llm(self): + """Listen to Fast-LLM trainer events directly from Redis.""" + import orjson + from pipelinerl.streams import RedisConfig, _backend, connect_to_redis + + from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_GROUP_NAME + + # Fast-LLM event stream config (must match fast-llm config) + stream_key = FAST_LLM_EVENTS_STREAM # "fast_llm_events" + payload_key = b"event" # Fast-LLM uses "event" as payload key + + # Initialize to 0 so wait_for_processed_samples() doesn't block at startup. + # The lag thread below will update this once the data stream/consumer group exists. 
+ self.samples_processed = 0 + + def listen_events(): + assert isinstance(_backend, RedisConfig) + r = connect_to_redis(_backend) + last_id = "0-0" + + logger.info(f"Listening for Fast-LLM events on Redis stream '{stream_key}'") + + while True: + result = r.xread({stream_key: last_id}, count=1, block=1000) + + if not result: + continue + + for stream_name, messages in result: + for msg_id, msg_data in messages: + last_id = msg_id + + if payload_key not in msg_data: + logger.warning(f"Fast-LLM event missing '{payload_key.decode()}' field: {msg_data}") + continue + + try: + event = orjson.loads(msg_data[payload_key]) + except Exception as e: + logger.error(f"Failed to parse Fast-LLM event: {e}") + continue + + event_type = event.get("type") + step = event.get("step") + + if event_type == "weights_ready": + logger.info(f"Received weights_ready event: step={step}") + self.propagated_weight_version = step + elif event_type == "training_finished": + logger.info("Received training_finished event") + self.training_done = True + self._training_done_event.set() + else: + logger.warning(f"Unknown Fast-LLM event type: {event_type}") + + def poll_lag(): + assert isinstance(_backend, RedisConfig) + r = connect_to_redis(_backend) + lag_check_interval = 0.5 # seconds + + while True: + try: + stream_info = r.xinfo_stream(REDIS_DATA_STREAM) + total_len = stream_info.get("length", 0) + groups = r.xinfo_groups(REDIS_DATA_STREAM) + for group in groups: + gname = group.get("name", "") + if isinstance(gname, bytes): + gname = gname.decode() + if gname == REDIS_GROUP_NAME: + entries_read = group.get("entries-read") + if entries_read is None: + lag = group.get("lag", 0) or 0 + entries_read = total_len - lag + self.samples_processed = int(entries_read) + logger.info( + f"Fast-LLM lag check: stream_len={total_len} entries_read={entries_read} " + f"samples_processed={self.samples_processed}" + ) + break + except Exception as e: + logger.debug(f"Fast-LLM lag check failed (stream/group not yet created?): {e}") + time.sleep(lag_check_interval) + + self._event_thread = threading.Thread(target=listen_events, daemon=True) + self._lag_thread = threading.Thread(target=poll_lag, daemon=True) + self._event_thread.start() + self._lag_thread.start() + def wait_for_training_done(self, timeout: float | None = None) -> bool: return self._training_done_event.wait(timeout=timeout) diff --git a/pipelinerl/streams.py b/pipelinerl/streams.py index 632b760e..d6f8f79d 100644 --- a/pipelinerl/streams.py +++ b/pipelinerl/streams.py @@ -192,6 +192,130 @@ def read(self): yield pickle.loads(entry[b"data"]) +class RedisSharedStreamWriter(StreamWriter): + """Redis writer that supports multiple producers appending to a single stream.""" + + def __init__( + self, + stream: SingleStreamSpec, + mode: Literal["w", "a"] = "a", + *, + writer_id: str | None = None, + maxlen: int = 1_000_000, + stream_name_override: str | None = None, + pipelinerl_metadata: bool = True, + ): + self.stream = stream + assert isinstance(_backend, RedisConfig) + self._redis = connect_to_redis(_backend) + self._stream_name = stream_name_override if stream_name_override is not None else str(self.stream) + self._counter_key = f"stream:{self._stream_name}:next_index" + self._writer_id = str(writer_id) if writer_id is not None else None + self._maxlen = maxlen + self._pipelinerl_metadata = pipelinerl_metadata + + if mode not in {"w", "a"}: + raise ValueError(f"Invalid mode: {mode}. 
Only 'w' and 'a' are supported.") + + if mode == "w": + last_entry = self._redis.xrevrange(self._stream_name, count=1) + if last_entry: + raise ValueError(f"Stream {self.stream} already exists. Cannot overwrite it.") + self._redis.delete(self._counter_key) + self._redis.set(self._counter_key, -1) + else: + if not self._redis.exists(self._counter_key): + last_entry = self._redis.xrevrange(self._stream_name, count=1) + if last_entry: + _, entry = last_entry[0] + raw_index = entry.get(b"index") + next_index = int(raw_index.decode("utf-8")) + 1 if raw_index else 0 + else: + next_index = 0 + self._redis.set(self._counter_key, next_index - 1) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._redis.close() + + def write(self, data, partition: int | None = None): + # Note: partition is ignored for shared streams - all data goes to a single stream + # This is intentional for Fast-LLM integration where Fast-LLM handles its own sharding + serialized = _serialize_with_orjson(data) + if self._pipelinerl_metadata: + entry_index = self._redis.incr(self._counter_key) + record: dict[str, Any] = { + "index": str(entry_index), + "data": serialized, + "ts": f"{time.time():.6f}", + } + if self._writer_id is not None: + record["writer"] = self._writer_id + else: + record = {"data": serialized} + self._redis.xadd(self._stream_name, record, maxlen=self._maxlen, approximate=True) + + +class RedisSharedStreamReader(StreamReader): + """Redis reader that validates fan-in ordering for a shared stream.""" + + def __init__(self, stream: SingleStreamSpec, *, fail_on_gap: bool = True): + self.stream = stream + assert isinstance(_backend, RedisConfig) + self._redis = connect_to_redis(_backend) + self._stream_name = str(self.stream) + self._last_id = 0 + self._expected_index: int | None = None + self._fail_on_gap = fail_on_gap + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self._redis.close() + + def _update_expected_index(self, entry: dict[bytes, bytes]): + raw_index = entry.get(b"index") + if raw_index is None: + return + + index_value = int(raw_index.decode("utf-8")) + if self._expected_index is None: + self._expected_index = index_value + elif index_value != self._expected_index: + message = ( + f"Index mismatch for shared stream {self.stream}: expected {self._expected_index}, got {index_value}" + ) + if self._fail_on_gap: + raise ValueError(message) + logger.warning(message) + self._expected_index = index_value + + self._expected_index += 1 + + def read(self): + block = int(_REREAD_DELAY * 1000) + while True: + response = self._redis.xread({self._stream_name: self._last_id}, count=1, block=block) + if not response: + continue + + stream_name, result = response[0] + assert stream_name.decode("utf-8") == self._stream_name + assert isinstance(result, list) and len(result) == 1 + entry_id, entry = result[0] + self._last_id = entry_id + self._update_expected_index(entry) + + payload = entry.get(b"data") + if payload is None: + raise ValueError(f"Shared stream entry missing 'data' field: {entry}") + + yield orjson.loads(payload) + + class RoundRobinRedisStreamWriter(StreamWriter): # TODO: share the connection across writers @@ -246,6 +370,32 @@ def stream_file(stream_dir: Path, shard_id: int) -> Path: StreamSpec = SingleStreamSpec | StreamRangeSpec +def _to_json_ready(value: Any) -> Any: + if isinstance(value, BaseModel): + value = value.model_dump() + + if isinstance(value, torch.Tensor): + return 
value.detach().cpu().numpy() + + if isinstance(value, numpy.ndarray): + return value + + if isinstance(value, numpy.generic): + return value.item() + + if isinstance(value, dict): + return {key: _to_json_ready(item) for key, item in value.items()} + + if isinstance(value, (list, tuple)): + return [_to_json_ready(item) for item in value] + + return value + + +def _serialize_with_orjson(data: Any) -> bytes: + return orjson.dumps(_to_json_ready(data), option=orjson.OPT_SERIALIZE_NUMPY) + + class FileStreamWriter(StreamWriter): def __init__(self, stream: SingleStreamSpec, mode: Literal["w", "a"] = "a"): self.stream = stream @@ -266,13 +416,8 @@ def write(self, data, partition: int | None = None): if partition is not None: raise ValueError() # Textual streams are so useful, that we try hard to jsonify the given object. - if isinstance(data, BaseModel): - data_dict = data.model_dump() - for key, value in data_dict.items(): - if isinstance(value, torch.Tensor): - data_dict[key] = value.numpy() - data = data_dict - self._file.write(orjson.dumps(data, option=orjson.OPT_SERIALIZE_NUMPY).decode("utf-8")) + payload = _serialize_with_orjson(data) + self._file.write(payload.decode("utf-8")) self._file.write("\n") self._file.flush() @@ -387,32 +532,62 @@ def write(self, data, partition: int | None = None): # Below are the public stream APIs. Easy to replace files with Redis or another pubsub system. -def read_stream(stream: SingleStreamSpec) -> StreamReader: - """Start reading the stream from the beginning""" +def read_stream(stream: SingleStreamSpec, *, shared: bool = False, fail_on_gap: bool = True) -> StreamReader: + """Start reading the stream from the beginning. + + When ``shared`` is True, multiple producers are assumed to append to the same + Redis stream and the reader will validate ordering using the stored index + metadata. + """ raise_if_backend_not_set() if not isinstance(stream, SingleStreamSpec): raise ValueError(f"Invalid stream spec: {stream}") if isinstance(_backend, RedisConfig): + if shared: + return RedisSharedStreamReader(stream, fail_on_gap=fail_on_gap) return RedisStreamReader(stream) elif _backend == "files": + if shared: + raise ValueError("Shared stream mode is only supported with the Redis backend") return FileStreamReader(stream) else: assert False -def write_to_streams(streams: StreamSpec, mode: Literal["w", "a"] = "a") -> StreamWriter: - """Append to the end of the stream.""" +def write_to_streams( + streams: StreamSpec, + mode: Literal["w", "a"] = "a", + *, + shared: bool = False, + writer_id: str | None = None, + stream_name_override: str | None = None, + pipelinerl_metadata: bool = True, +) -> StreamWriter: + """Append to the end of the stream. + + Set ``shared`` to True when multiple producers must append to the same Redis + stream and ServiceNow/Fast-LLM will perform downstream sharding. + + ``stream_name_override`` bypasses the stream spec naming and writes directly + to the given Redis key. Only supported for shared Redis streams. 
+ """ raise_if_backend_not_set() if not isinstance(streams, (SingleStreamSpec, StreamRangeSpec)): raise ValueError(f"Invalid stream spec: {streams}") if isinstance(_backend, RedisConfig): if isinstance(streams, SingleStreamSpec): + if shared: + return RedisSharedStreamWriter(streams, mode, writer_id=writer_id, stream_name_override=stream_name_override, pipelinerl_metadata=pipelinerl_metadata) return RedisStreamWriter(streams, mode) elif isinstance(streams, StreamRangeSpec): + if shared: + raise ValueError("Shared Redis streams only support SingleStreamSpec inputs") return RoundRobinRedisStreamWriter(streams, mode) else: assert False elif _backend == "files": + if shared: + raise ValueError("Shared stream mode is only supported with the Redis backend") if isinstance(streams, SingleStreamSpec): return FileStreamWriter(streams, mode) elif isinstance(streams, StreamRangeSpec): diff --git a/pipelinerl/utils.py b/pipelinerl/utils.py index 1cd2300b..8eb8bea3 100644 --- a/pipelinerl/utils.py +++ b/pipelinerl/utils.py @@ -210,7 +210,14 @@ def init_wandb( python_env = {} for dist in distributions(): - python_env[dist.metadata["Name"]] = dist.version + if dist.metadata is None: + continue + try: + name = dist.metadata["Name"] + if name is not None: + python_env[name] = dist.version + except Exception as e: + logger.warning(f"Accessing {dist} resulted in error {e}") config_for_wandb["python_env"] = python_env if cfg.wandb.wandb_resume == "always": diff --git a/pipelinerl/vllm1.py b/pipelinerl/vllm1.py index 788364b4..855a992d 100644 --- a/pipelinerl/vllm1.py +++ b/pipelinerl/vllm1.py @@ -1,4 +1,5 @@ import asyncio +import inspect import logging import os import signal @@ -26,11 +27,15 @@ from vllm.v1.worker.gpu_model_runner import GPUModelRunner -from pipelinerl.finetune_loop import WeightUpdateRequest +from pipelinerl.finetune_loop import WeightUpdateRequest, ParameterInfo from pipelinerl.vllm_quantization import string_to_dtype # reuse mapping +from typing import Any, Protocol, runtime_checkable, Dict, Optional +from fastapi import BackgroundTasks +import pipelinerl.torch_utils from pipelinerl.torch_utils import stateless_init_process_group -from typing import Any, Protocol, runtime_checkable import pipelinerl.vllm_quantization # Register bf16_last_layer_fp32 quantization config +from vllm.distributed import cleanup_dist_env_and_memory +from contextlib import asynccontextmanager try: from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -38,14 +43,22 @@ from vllm.tool_parsers import ToolParserManager logger = logging.getLogger(__name__) -# configure this logger individually, in order to avoid messign +# configure this logger individually, in order to avoid messing # with the default vllm logger configuration -logger.setLevel(logging.INFO) +# Check environment variable to enable DEBUG logging (for tests) +import os + +log_level = logging.DEBUG if os.getenv("PIPELINERL_DEBUG") else logging.INFO +logger.setLevel(log_level) handler = logging.StreamHandler() -handler.setLevel(logging.INFO) -formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +handler.setLevel(log_level) +formatter = logging.Formatter( + "[%(asctime)s] [VLLM-%(levelname)s] %(message)s", datefmt="%H:%M:%S" +) handler.setFormatter(formatter) logger.addHandler(handler) +# Prevent propagation to vLLM's loggers to avoid double logging +logger.propagate = False @runtime_checkable @@ -60,6 +73,15 @@ class LikeWorker(Protocol): class WorkerExtension: + def is_extension_loaded(self: LikeWorker) -> 
int: + """Simple method to verify the extension is loaded on workers. + + Returns: + PID of the worker process + """ + import os + + return os.getpid() def init_actor_update_group( self: LikeWorker, @@ -67,6 +89,7 @@ def init_actor_update_group( actor_ngpus: int, weight_update_group_init_method: str, weight_update_group_world_size: int, + weight_update_mode: str = "http", ): self.pg_rank = 1 + actor_idx * actor_ngpus + self.rank # log all you know @@ -77,7 +100,7 @@ def init_actor_update_group( ) logger.info( prefix - + f"Weight update group init method: {weight_update_group_init_method}, world size: {weight_update_group_world_size}" + + f"Weight update group init method: {weight_update_group_init_method}, world size: {weight_update_group_world_size}, mode: {weight_update_mode}" ) batch_invariant_env = os.getenv("VLLM_BATCH_INVARIANT", "0") @@ -98,33 +121,193 @@ def init_actor_update_group( ): os.environ.pop(_k, None) - # Use vLLM's StatelessProcessGroup instead of torch.distributed - self.model_update_group = stateless_init_process_group( - init_method=weight_update_group_init_method, - rank=self.pg_rank, - world_size=weight_update_group_world_size, - device=self.device, - ) + if weight_update_mode == 'http': + # HTTP mode uses vLLM's StatelessProcessGroup to match the trainer, + # which in pipelinerl/finetune_loop.py uses torch_utils.stateless_init_process_group. + self.model_update_group = stateless_init_process_group( + init_method=weight_update_group_init_method, + rank=self.pg_rank, + world_size=weight_update_group_world_size, + device=self.device, + ) + else: + from fast_llm.engine.distributed.config import DistributedBackend + from fast_llm.engine.distributed.distributed import ProcessGroupPool + + self.model_update_group = ProcessGroupPool( + rank=self.pg_rank, + world_size=weight_update_group_world_size, + local_world_size=1, + init_method=weight_update_group_init_method, + backend=DistributedBackend.nccl, + ).get_process_group(range(weight_update_group_world_size), self.pg_rank) + self._process_group_destroyed = False logger.info(prefix + "Actor update process group initialized") + def destroy_actor_update_group(self: LikeWorker): + self._process_group_destroyed = True + if isinstance(self.model_update_group, torch.distributed.ProcessGroup): + torch.distributed.destroy_process_group(self.model_update_group) + elif hasattr(self.model_update_group, "shutdown"): + self.model_update_group.shutdown() + # StatelessProcessGroup has no shutdown method; rely on GC. 
+ + def is_actor_update_group_destroyed(self: LikeWorker) -> bool: + return getattr(self, "_process_group_destroyed", False) + def receive_weight_update(self: LikeWorker, request_json: str): request = WeightUpdateRequest.model_validate_json(request_json) torch.cuda.synchronize(self.device) - logger.info("Start receiving weight update") + logger.info( + f"Start receiving weight update: {len(request.parameters_info)} parameters" + ) expected_dtypes = (torch.bfloat16, torch.float32, torch.float16) - for info in request.parameters_info: + for i, info in enumerate(request.parameters_info): + logger.debug( + f"[{i+1}/{len(request.parameters_info)}] Preparing to receive: {info.name}" + ) + logger.debug(f" - shape: {info.shape}, dtype: {info.dtype}") + target_dtype = string_to_dtype(info.dtype) if target_dtype not in expected_dtypes: logger.warning(f"Unexpected dtype for {info.name}: {info.dtype}") - buffer = torch.empty(tuple(info.shape), dtype=target_dtype, device=self.device) - self.model_update_group.broadcast(buffer, src=0, stream=torch.cuda.current_stream()) - loaded_params = self.model_runner.model.load_weights(weights=[(info.name, buffer)]) # type: ignore - if len(loaded_params) != 1: - raise ValueError(f"model {info.name} not found in model state dict") + + logger.debug(f" - Creating buffer for {info.name}") + buffer = torch.empty( + tuple(info.shape), dtype=target_dtype, device=self.device + ) + logger.debug( + f" - Buffer created: shape={buffer.shape}, dtype={buffer.dtype}, device={buffer.device}" + ) + + logger.debug(f" - Calling broadcast for {info.name}...") + # StatelessProcessGroup exposes .broadcast(); torch.distributed.ProcessGroup + # (fast-llm path) uses the functional torch.distributed.broadcast. + if isinstance(self.model_update_group, torch.distributed.ProcessGroup): + torch.distributed.broadcast(buffer, src=0, group=self.model_update_group) + else: + self.model_update_group.broadcast(buffer, src=0, stream=torch.cuda.current_stream()) + logger.debug(f" - Broadcast received for {info.name}") + + logger.debug(f" - Loading weights for {info.name}...") + try: + loaded_params = self.model_runner.model.load_weights(weights=[(info.name, buffer)]) # type: ignore + if len(loaded_params) == 0: + # Parameter doesn't exist in vLLM model - this is an error + logger.error(f" - ERROR: {info.name} not found in vLLM model") + raise ValueError( + f"Parameter {info.name} not found in vLLM model state dict" + ) + elif len(loaded_params) == 1: + logger.debug(f" - Weights loaded for {info.name}") + else: + logger.error( + f" - ERROR: load_weights returned {len(loaded_params)} params for {info.name}" + ) + raise ValueError( + f"Unexpected number of parameters loaded for {info.name}" + ) + except Exception as e: + logger.error(f" - ERROR loading weights for {info.name}: {e}") + raise + + if (i + 1) % 10 == 0: + logger.info(f"Received {i+1}/{len(request.parameters_info)} parameters") pipelinerl.vllm_quantization.invalidate_fp32_cache() - logger.info("Weight update received") + logger.info("Weight update received - all parameters processed") + + def receive_weight_update_fast_llm(self: LikeWorker): + """Receive weight update via Fast-LLM broadcast protocol. + + Called via collective_rpc_async from the main-process monitoring thread, + so it runs in each worker's main thread — serialized with inference, + identical concurrency model to receive_weight_update (HTTP path). + + Protocol: + 1. Loop: receive metadata via broadcast_object_list + 2. Receive tensor via broadcast + 3. 
Call model.load_weights() for each parameter + 4. Exit when metadata is [None] (end signal) + """ + torch.cuda.synchronize(self.device) + logger.info(f"[Worker rank={self.rank}] Start receiving Fast-LLM weight update") + + expected_dtypes = (torch.bfloat16, torch.float32, torch.float16) + param_count = 0 + + from fast_llm.core.distributed import broadcast as _broadcast, broadcast_object as _broadcast_object + + while True: + # Receive metadata + logger.debug(f"[Worker rank={self.rank}] Waiting for metadata broadcast...") + meta = _broadcast_object(None, self.model_update_group, src=0) + logger.debug(f"[Worker rank={self.rank}] Received metadata: {meta}") + + # Check for end signal + if meta is None: + logger.info( + f"[Worker rank={self.rank}] Received end signal, finished receiving {param_count} parameters" + ) + break + + # Parse metadata: (shard_name, layer_name, shape, dtype) + # shard_name is a category label ("weights", "grads", etc.), not part of the HF param name + shard_name, layer_name, shape, dtype = meta + param_name = layer_name + + # Convert dtype to torch dtype + target_dtype = string_to_dtype(str(dtype)) + + # Allocate buffer and receive tensor (must happen for every broadcast to stay in sync) + buffer = torch.empty(tuple(shape), dtype=target_dtype, device=self.device) + _broadcast(buffer, 0, self.model_update_group) + + # Only load weight shards (skip grads, optimizer state, etc.) + if shard_name != "weights": + continue + + param_count += 1 + logger.debug( + f"[{param_count}] Receiving: {param_name}, shape={shape}, dtype={dtype}" + ) + + if target_dtype not in expected_dtypes: + logger.warning(f"Unexpected dtype for {param_name}: {dtype}") + + logger.debug(f"[{param_count}] Received tensor for {param_name}") + + # Load weights + try: + loaded_params = self.model_runner.model.load_weights( + weights=[(param_name, buffer)] + ) + if len(loaded_params) == 0: + logger.error(f"ERROR: {param_name} not found in vLLM model") + raise ValueError( + f"Parameter {param_name} not found in vLLM model state dict" + ) + elif len(loaded_params) == 1: + logger.debug(f"[{param_count}] Loaded {param_name}") + else: + logger.error( + f"ERROR: load_weights returned {len(loaded_params)} params for {param_name}" + ) + raise ValueError( + f"Unexpected number of parameters loaded for {param_name}" + ) + except Exception as e: + logger.error(f"ERROR loading {param_name}: {e!r}", exc_info=True) + raise + + if param_count % 10 == 0: + logger.info(f"[Worker rank={self.rank}] Received {param_count} parameters") + + pipelinerl.vllm_quantization.invalidate_fp32_cache() + logger.info( + f"[Worker rank={self.rank}] Fast-LLM weight update complete - {param_count} parameters processed" + ) def close_communicator(self): """Closes the communicator when weight synchronization is no longer needed.""" @@ -134,30 +317,62 @@ def close_communicator(self): logger.info("Weight update communicator closed") -class WeightUpdateManager: - def __init__(self, args, engine: AsyncLLM, engine_client: AsyncMPClient): +async def _pause_generation(engine: AsyncLLM) -> None: + """Pause generation without draining in-flight requests. + + Adapts to the installed vLLM version at runtime: newer builds expose + pause_generation(mode=) while older ones use wait_for_inflight_requests=. 
+ """ + if 'mode' in inspect.signature(engine.pause_generation).parameters: + await engine.pause_generation(mode="keep", clear_cache=False) + else: + await engine.pause_generation(wait_for_inflight_requests=False, clear_cache=False) + + +class EngineManager: + def __init__(self, args, engine: AsyncLLM, engine_config: Any): self.args = args self.engine = engine - self.engine_client = engine_client + self.engine_config = engine_config self.update_lock = asyncio.Lock() - async def input_process_groups(self): - await self.engine_client.collective_rpc_async( + async def is_extension_loaded(self): + return await self.engine.engine_core.collective_rpc_async( + "is_extension_loaded", + args=(), + ) + + async def init_actor_update_group(self): + await self.engine.engine_core.collective_rpc_async( "init_actor_update_group", args=( self.args.actor_llm_idx, torch.cuda.device_count(), self.args.weight_update_group_init_method, self.args.weight_update_group_world_size, + getattr(self.args, "weight_update_mode", "http"), ), ) + async def destroy_actor_update_group(self): + await self.engine.engine_core.collective_rpc_async( + "destroy_actor_update_group", + args=(), + ) + + async def is_actor_update_group_destroyed(self) -> bool: + results = await self.engine.engine_core.collective_rpc_async( + "is_actor_update_group_destroyed", + args=(), + ) + return all(results) + async def receive_weight_update(self, request: WeightUpdateRequest): async with self.update_lock: version = getattr(request, "version", "unknown") pause_started_at = time.perf_counter() logger.info(f"Pausing generation for weight update version={version}") - await self.engine.pause_generation(mode="keep", clear_cache=False) + await _pause_generation(self.engine) logger.info( f"Generation paused for weight update version={version} " f"in {time.perf_counter() - pause_started_at:.3f}s" @@ -165,7 +380,7 @@ async def receive_weight_update(self, request: WeightUpdateRequest): try: update_started_at = time.perf_counter() logger.info(f"Starting weight update version={version}") - await self.engine_client.collective_rpc_async( + await self.engine.engine_core.collective_rpc_async( "receive_weight_update", args=(request.model_dump_json(),) ) logger.info( @@ -183,7 +398,276 @@ async def receive_weight_update(self, request: WeightUpdateRequest): async def close_communicator(self): """Closes the communicator when weight synchronization is no longer needed.""" - await self.engine_client.collective_rpc_async("close_communicator") + await self.engine.engine_core.collective_rpc_async("close_communicator") + + async def init_fast_llm_receiver(self): + """Store Redis connection info for the main-process monitoring thread.""" + self._redis_host = self.args.redis_host + self._redis_port = self.args.redis_port + logger.info( + f"Fast-LLM receiver initialized (Redis {self._redis_host}:{self._redis_port})" + ) + + async def receive_weight_update_fast_llm(self): + """Run a fast-llm broadcast weight update paused-for-the-duration. + + Pause/resume wraps the collective RPC symmetrically with the HTTP path + so that in-flight generation cannot interleave with a mid-broadcast + parameter swap (the source of logprob drift PR #137 closed). + + NOTE: this must NOT be used for the very first weights_ready event + after process startup, because at that point the actor has not yet + begun issuing rollouts (it's blocked in wait_for_model_version) and + pause_generation will deadlock waiting for an in-flight-decode state + that never arrives. 
The monitor thread gates this accordingly. + """ + async with self.update_lock: + pause_started_at = time.perf_counter() + logger.info("Pausing generation for fast-llm weight update") + await _pause_generation(self.engine) + logger.info( + f"Generation paused for fast-llm weight update " + f"in {time.perf_counter() - pause_started_at:.3f}s" + ) + try: + update_started_at = time.perf_counter() + await self.engine.engine_core.collective_rpc_async( + "receive_weight_update_fast_llm", args=() + ) + logger.info( + f"Fast-llm weight update processed " + f"in {time.perf_counter() - update_started_at:.3f}s" + ) + finally: + resume_started_at = time.perf_counter() + logger.info("Resuming generation after fast-llm weight update") + await self.engine.resume_generation() + logger.info( + f"Generation resumed after fast-llm weight update " + f"in {time.perf_counter() - resume_started_at:.3f}s" + ) + + async def start_fast_llm_monitoring(self): + """Start a single Redis monitoring thread in the main process. + + When weights_ready arrives the thread calls + collective_rpc_async("receive_weight_update_fast_llm") which runs in + each worker's main thread — blocking inference during the update, + identical concurrency to the HTTP path. training_finished is handled + the same way via destroy_actor_update_group(). + """ + import asyncio + import threading + + self._fast_llm_stop_event = threading.Event() + loop = asyncio.get_event_loop() + + def monitor_redis_stream(): + import redis + import orjson + import time + + r = redis.Redis(host=self._redis_host, port=self._redis_port) + stream_key = "fast_llm_events" + payload_key = b"event" + last_id = "0-0" + # First weights_ready event since this vLLM process started is the + # initial broadcast (step can be 0 on fresh start or k>0 on resume). + # Actor is still blocked in wait_for_model_version at this point, so + # vLLM has zero in-flight requests — pause_generation would deadlock. + # Take the raw RPC path for the first event; wrap with pause/resume + # thereafter, matching PR #137's guard against mid-rollout weight swaps. 
+ first_weights_ready_seen = False + + logger.info("[FastLLM] Main-process Redis monitoring started") + + while not self._fast_llm_stop_event.is_set(): + try: + result = r.xread({stream_key: last_id}, count=1, block=1000) + if not result: + continue + + for _stream_name, messages in result: + for msg_id, msg_data in messages: + last_id = msg_id + + if payload_key not in msg_data: + logger.warning( + f"[FastLLM] Event missing 'event' field: {msg_data}" + ) + continue + + try: + event = orjson.loads(msg_data[payload_key]) + except Exception as e: + logger.error(f"[FastLLM] Failed to parse event: {e}") + continue + + event_type = event.get("type") + step = event.get("step") + + if event_type == "weights_ready": + if not first_weights_ready_seen: + logger.info( + f"[FastLLM] weights_ready step={step} (initial broadcast — no pause wrap)" + ) + coro = self.engine.engine_core.collective_rpc_async( + "receive_weight_update_fast_llm", args=() + ) + first_weights_ready_seen = True + else: + logger.info( + f"[FastLLM] weights_ready step={step}, dispatching to workers" + ) + coro = self.receive_weight_update_fast_llm() + try: + future = asyncio.run_coroutine_threadsafe(coro, loop) + future.result() + logger.info( + f"[FastLLM] Weight update complete: step={step}" + ) + except Exception as e: + logger.error( + f"[FastLLM] Error receiving weight update: {e}" + ) + + elif event_type == "training_finished": + logger.info( + "[FastLLM] training_finished received, destroying process group" + ) + try: + future = asyncio.run_coroutine_threadsafe( + self.destroy_actor_update_group(), loop + ) + future.result() + except Exception as e: + logger.error( + f"[FastLLM] Error destroying process group: {e}" + ) + self._fast_llm_stop_event.set() + + except Exception as e: + logger.error(f"[FastLLM] Error in Redis monitor: {e}") + if not self._fast_llm_stop_event.is_set(): + time.sleep(1) + + logger.info("[FastLLM] Main-process Redis monitoring stopped") + r.close() + + self._fast_llm_monitor_thread = threading.Thread( + target=monitor_redis_stream, + daemon=True, + name="FastLLMMonitor", + ) + self._fast_llm_monitor_thread.start() + logger.info("[FastLLM] Main-process monitoring thread started") + + async def stop_fast_llm_monitoring(self): + """Stop the main-process Fast-LLM monitoring thread.""" + if not hasattr(self, "_fast_llm_stop_event"): + return + if not self._fast_llm_stop_event.is_set(): + logger.warning("[FastLLM] training_finished was not received; forcing stop") + self._fast_llm_stop_event.set() + if hasattr(self, "_fast_llm_monitor_thread"): + self._fast_llm_monitor_thread.join(timeout=5) + logger.info("[FastLLM] Main-process monitoring thread stopped") + + @asynccontextmanager + @staticmethod + async def create_engine( + args: Any, + cleanup: bool = True, + ): + """Create vLLM AsyncLLM engine with automatic cleanup. + + This is an async context manager that ensures proper engine lifecycle + management with automatic cleanup on exit. + + Usage: + # Simple usage (tests) + async with create_engine(args) as (engine, engine_config): + # Use engine for generation + async for output in engine.generate(...): + ... + # Automatic cleanup happens here + + # Or unpack only what you need + async with create_engine(args) as (engine, _): + # Use engine, ignore config + ... + + # Server usage (no cleanup) + async with create_engine(args, cleanup=False) as (engine, engine_config): + # Use both engine and config + await init_app_state(engine, engine_config, ...) + ... 
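+
+        # Attribute-access form (matches the `yield manager` below and the
+        # fixtures in tests/conftest.py):
+        async with create_engine(args) as manager:
+            engine = manager.engine                # AsyncLLM instance
+            engine_config = manager.engine_config  # VllmConfig
+            ...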
+ + Args: + args: Arguments object with vLLM engine configuration. + Must be compatible with AsyncEngineArgs.from_cli_args(). + Required attributes: model + Optional attributes: tensor_parallel_size, disable_log_stats, + disable_log_requests, etc. + cleanup: Whether to cleanup engine on exit (default: True). + Set to False for server usage where engine runs indefinitely. + + Yields: + Tuple of (engine, engine_config): + - engine: AsyncLLM engine instance + - engine_config: VllmConfig for init_app_state + """ + engine_args = AsyncEngineArgs.from_cli_args(args) + engine_args.worker_extension_cls = "pipelinerl.vllm1.WorkerExtension" + engine_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER) + + logger.info(f"Creating vLLM engine with model={args.model}") + engine = AsyncLLM.from_vllm_config( + vllm_config=engine_config, + usage_context=UsageContext.OPENAI_API_SERVER, + disable_log_stats=engine_args.disable_log_stats, + enable_log_requests=engine_args.enable_log_requests, + ) + + logger.info("vLLM engine created successfully") + + try: + assert isinstance(engine.engine_core, AsyncMPClient) + manager = EngineManager(args, engine, engine_config) + weight_update_mode = getattr(args, "weight_update_mode", "http") + if not args.disable_weight_updates: + await manager.init_actor_update_group() + + # Initialize Fast-LLM mode if enabled + if weight_update_mode == "fast-llm": + await manager.init_fast_llm_receiver() + await manager.start_fast_llm_monitoring() + logger.info("Fast-LLM weight update mode enabled") + + yield manager + finally: + if not args.disable_weight_updates: + # Stop Fast-LLM monitoring if enabled + if weight_update_mode == "fast-llm": + await manager.stop_fast_llm_monitoring() + + if not await manager.is_actor_update_group_destroyed(): + logger.warning( + "training_finished was not called before shutdown; " + "NCCL process group was not destroyed — potential resource leak" + ) + if cleanup: + logger.info("Cleaning up vLLM engine") + # Clear manager reference to engine first + manager.engine = None + manager.engine_config = None + # Delete engine and force immediate garbage collection + del engine + del manager + import gc + + gc.collect() + cleanup_dist_env_and_memory() async def run_server(args, **uvicorn_kwargs) -> None: @@ -219,65 +703,77 @@ def signal_handler(*_) -> None: signal.signal(signal.SIGTERM, signal_handler) - engine_args = AsyncEngineArgs.from_cli_args(args) - engine_args.worker_extension_cls = "pipelinerl.vllm1.WorkerExtension" - engine_config = engine_args.create_engine_config(UsageContext.OPENAI_API_SERVER) - engine = AsyncLLM.from_vllm_config( - vllm_config=engine_config, - usage_context=UsageContext.OPENAI_API_SERVER, - disable_log_stats=engine_args.disable_log_stats, - enable_log_requests=engine_args.enable_log_requests, - ) - assert isinstance(engine.engine_core, AsyncMPClient) + # Create engine (cleanup=False since server runs indefinitely) + async with EngineManager.create_engine(args, cleanup=False) as manager: + # Run HTTP server + sock_addr = (args.host or "", args.port) + sock = create_server_socket(sock_addr) + # vLLM 0.18.1+ requires supported_tasks to build the app and app state; + # older vllm (e.g. 0.14.x) has 1-arg build_app / 3-arg init_app_state. 
+ import inspect as _inspect + _build_app_params = _inspect.signature(build_app).parameters + if "supported_tasks" in _build_app_params and hasattr(manager.engine, "get_supported_tasks"): + supported_tasks = await manager.engine.get_supported_tasks() + logger.info(f"Supported tasks: {supported_tasks}") + app = build_app(args, supported_tasks) + else: + supported_tasks = None + app = build_app(args) + + # Register HTTP endpoint only if using HTTP mode + if getattr(args, "weight_update_mode", "http") == "http": + @app.post("/receive_weight_update") + async def _receive_weight_update(request: WeightUpdateRequest): + await manager.receive_weight_update(request) + return {"status": "ok"} + + @app.post("/training_finished") + async def _training_finished(background_tasks: BackgroundTasks): + logger.info("Received /training_finished, scheduling NCCL process group teardown") + background_tasks.add_task(manager.destroy_actor_update_group) + return {"status": "ok"} + + logger.info("HTTP weight update endpoint registered") + else: + logger.info("Fast-LLM mode: using Redis stream (no HTTP endpoint registered)") + + if "supported_tasks" in _inspect.signature(init_app_state).parameters: + await init_app_state(manager.engine, app.state, args, supported_tasks) + else: + await init_app_state(manager.engine, app.state, args) + shutdown_task = await serve_http( + app, + sock, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + # increase timeout + timeout_keep_alive=60, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) - weight_update_manager = WeightUpdateManager(args, engine, engine.engine_core) - if not args.disable_weight_updates: - await weight_update_manager.input_process_groups() + # NB: Await server shutdown only after the backend context is exited + await shutdown_task - # Run HTTP server - sock_addr = (args.host or "", args.port) - sock = create_server_socket(sock_addr) - supported_tasks = await engine.get_supported_tasks() - logger.info(f"Supported tasks: {supported_tasks}") - app = build_app(args, supported_tasks) - - @app.post("/receive_weight_update") - async def _receive_weight_update(request: WeightUpdateRequest): - # Blocking: wait for weight update to complete before returning - logger.info("Received weight update request") - await weight_update_manager.receive_weight_update(request) - return {"status": "ok"} - - await init_app_state(engine, app.state, args, supported_tasks) - shutdown_task = await serve_http( - app, - sock, - host=args.host, - port=args.port, - log_level=args.uvicorn_log_level, - # increase timeout - timeout_keep_alive=60, - ssl_keyfile=args.ssl_keyfile, - ssl_certfile=args.ssl_certfile, - ssl_ca_certs=args.ssl_ca_certs, - ssl_cert_reqs=args.ssl_cert_reqs, - **uvicorn_kwargs, - ) - - # NB: Await server shutdown only after the backend context is exited - await shutdown_task + sock.close() - # Cleanup - if not args.disable_weight_updates: - await weight_update_manager.close_communicator() - sock.close() + # NOTE: weight-broadcast process group teardown must be coordinated with the trainer — + # the trainer sends training_finished, then the engine manager destroys its side here. def run_llm(): - parser = FlexibleArgumentParser(description="vLLM OpenAI-Compatible RESTful API server.") + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server." 
+ ) parser = make_arg_parser(parser) parser.add_argument( - "--disable-weight-updates", action="store_true", help="Whether to receive weight updates from the trainer" + "--disable-weight-updates", + action="store_true", + help="Whether to receive weight updates from the trainer", ) parser.add_argument( "--actor-llm-idx", @@ -291,6 +787,25 @@ def run_llm(): "--weight-update-group-world-size", type=int, ) + parser.add_argument( + "--weight-update-mode", + type=str, + choices=["http", "fast-llm"], + default="http", + help="Weight update protocol: 'http' (HTTP POST) or 'fast-llm' (Redis+broadcast)", + ) + parser.add_argument( + "--redis-host", + type=str, + default="localhost", + help="Redis host for Fast-LLM mode", + ) + parser.add_argument( + "--redis-port", + type=int, + default=6379, + help="Redis port for Fast-LLM mode", + ) args = parser.parse_args() validate_parsed_serve_args(args) diff --git a/pipelinerl/world.py b/pipelinerl/world.py index 517634c7..8173e987 100644 --- a/pipelinerl/world.py +++ b/pipelinerl/world.py @@ -57,7 +57,7 @@ def __init__(self, cfg: DictConfig, verbose: bool = False): tp = llm_kwargs.get("tensor-parallel-size", 1) pp = llm_kwargs.get("pipeline-parallel-size", 1) self.gpus_per_llm = tp * pp - self.node_size = 8 if self.world_size > 1 else torch.cuda.device_count() + self.node_size = int(os.environ.get("GPUS_PER_NODE", torch.cuda.device_count())) place_inference_jobs = not cfg.debug.mode or cfg.debug.place_inference_workers if place_inference_jobs: @@ -152,6 +152,26 @@ def _split_gpus_by_purpose(self, cfg): max(int(total_gpus * preprocessor_fraction), self.gpus_per_llm) if cfg.world.preprocessor_fraction else 0 ) desired_finetune_gpu_share = total_gpus - desired_actor_gpu_share - desired_preprocessor_gpu_share + + # For multi-node fast-llm spanning more than one node, every component + # must occupy whole nodes so torchrun's rdzv gets a clean full-node GPU + # set. Snap all three components; actor takes whatever remains. + # When fast-llm lands on a single node proportional allocation is fine. + if self.world_size > 1 and cfg.get("use_fast_llm", False): + finetune_frac = cfg.world.finetune_fraction / fraction_sum + finetune_nodes = max(1, round(self.world_size * finetune_frac)) + preprocessor_nodes = ( + max(1, round(self.world_size * preprocessor_fraction)) + if cfg.world.preprocessor_fraction else 0 + ) + actor_nodes = self.world_size - finetune_nodes - preprocessor_nodes + if cfg.world.actor_fraction > 0 and actor_nodes < 1: + finetune_nodes -= 1 + actor_nodes += 1 + if finetune_nodes > 1: + desired_finetune_gpu_share = finetune_nodes * self.node_size + desired_preprocessor_gpu_share = preprocessor_nodes * self.node_size + desired_actor_gpu_share = actor_nodes * self.node_size self._log_info( f"Desired GPU share: {desired_actor_gpu_share} for actors," f"{desired_preprocessor_gpu_share} for preprocessors, {desired_finetune_gpu_share} for finetune" diff --git a/submit_eai_math_7b_multinode.sh b/submit_eai_math_7b_multinode.sh new file mode 100755 index 00000000..7639f2ee --- /dev/null +++ b/submit_eai_math_7b_multinode.sh @@ -0,0 +1,128 @@ +#!/bin/bash +# Multi-node fast-llm finetuner math run with DS-matched params (GSPO, docs_per_step). +# Topology: actor_fraction=4 (16 GPUs / 2 nodes) + finetune_fraction=4 (16 GPUs / 2 nodes). 
+# Usage: bash submit_eai_math_7b_multinode.sh [NODES] [TIMESTAMP] +# Example (fresh): bash submit_eai_math_7b_multinode.sh 4 +# Example (resume): bash submit_eai_math_7b_multinode.sh 4 20260428_132330 +# Run `eai login` before executing this script. + +IMAGE="registry.toolkit-sp.yul201.service-now.com/snow.research.afm/interactive-toolkit:25.12-py3-vllm014rc1redis" + +# === PERSONALIZE THESE BEFORE RUNNING (or override via env vars) === +RESULTS_DIR="${RESULTS_DIR:-/mnt/shared/denis/math_7b_results}" # your shared NFS results dir +WANDB_ENTITY="${WANDB_ENTITY:-denisko-se}" # your wandb entity +WANDB_PROJECT="${WANDB_PROJECT:-watermelon}" # your wandb project +EAI_HOME_DATA="${EAI_HOME_DATA:-snow.home.denis_kocetkov}" # your EAI home data object +EAI_SHARED_DATA="${EAI_SHARED_DATA:-snow.research.afm.shared_fml}" # your shared NFS data object +# =================================================================== + +MODEL_PATH="${MODEL_PATH:-/home/toolkit/Qwen2.5-7B}" +NODES="${1:-4}" +TIMESTAMP="${2:-$(date +%Y%m%d_%H%M%S)}" + +EXP_NAME="math_7b_${NODES}node_fastllm_gspo_${TIMESTAMP}" +EXP_DIR="${RESULTS_DIR}/${EXP_NAME}" + +if [ -n "${2:-}" ]; then + RESUME_TS=$(date +%Y%m%d_%H%M%S) + JOB_NAME="${EXP_NAME}_resume_${RESUME_TS}" + echo "RESUMING: ${EXP_DIR} (job: ${JOB_NAME})" +else + JOB_NAME="${EXP_NAME}" +fi + +echo "Config: ${NODES} nodes, actor_fraction=4, finetune_fraction=4, docs_per_step=1024, max_train_steps=400" + +CMD=" +set -e +mkdir -p ${EXP_DIR} +cd /home/toolkit/code/PipelineRL +source /home/toolkit/code/PipelineRL/.venv/bin/activate +PYTHONHASHSEED=42 python -m pipelinerl.launch \ + --config-path /home/toolkit/code/PipelineRL/conf \ + --config-name math \ + streams=redis \ + world.actor_fraction=4 \ + world.preprocessor_fraction=0 \ + world.finetune_fraction=4 \ + world.run_id=\${MASTER_ADDR} \ + model_path=${MODEL_PATH} \ + output_dir=${EXP_DIR} \ + force_restart=true \ + actor.llm_max_rollouts=128 \ + finetune.attempts=8 \ + finetune.max_train_steps=400 \ + '+finetune.rl.filter_zero_advantage_groups=true' \ + eval_every_n_versions=0 \ + wandb.wandb_workspace_root=${RESULTS_DIR} \ + "wandb.wandb_entity_name=${WANDB_ENTITY}" \ + "wandb.wandb_project_name=${WANDB_PROJECT}" \ + wandb.wandb_group=eai_math7b_fastllm_gspo \ + '+wandb.wandb_run_name=math7b_fastllm_gspo_${NODES}node_${TIMESTAMP}' \ + 'vllm_config.vllm_kwargs.gpu-memory-utilization=0.85' \ + 'vllm_config.vllm_kwargs.max-num-batched-tokens=8192' \ + 'vllm_config.vllm_kwargs.max_model_len=20000' \ + 'llm.parameters.max_tokens=16000' \ + 'llm.parameters.temperature=0.7' \ + 'test_llm.parameters.max_tokens=16000' \ + 'test_llm.parameters.temperature=0.7' \ + 'fast_llm.data.micro_batch_size=20000' \ + '+fast_llm.schedule.docs_per_step=1024' \ + 'fast_llm.training.train_iters=400' \ + 'fast_llm.training.num_workers=1' \ + 'fast_llm.training.checkpoint.interval=20' \ + 'fast_llm.model.distributed.sequence_data_parallel=2' \ + '+fast_llm.model.distributed.timeout=3600' \ + '+fast_llm.model.base_model.decoder.block.mlp.recompute_level=full' \ + '+fast_llm.model.base_model.head.fp32_lm_head=true' \ + '+fast_llm.model.base_model.head.losses.grpo.policy_loss=gspo' \ + 'fast_llm.model.base_model.head.losses.grpo.epsilon_low=3e-3' \ + 'fast_llm.model.base_model.head.losses.grpo.epsilon_high=4e-3' \ + '+fast_llm.model.base_model.head.losses.grpo.normalize_by_documents=true' \ + '+fast_llm.model.base_model.head.losses.grpo.temperature=0.7' \ + '+fast_llm.model.base_model.head.losses.grpo.metrics=with_entropy' \ + 
'+fast_llm.optimizer.learning_rate.base=1e-6' \ + '+fast_llm.optimizer.learning_rate.warmup_iterations=50' \ + '+fast_llm.optimizer.learning_rate.decay_style=cosine' \ + '+fast_llm.optimizer.learning_rate.decay_iterations=400' \ + '+fast_llm.optimizer.gradient_norm_clipping=0.3' +" + +SPEC_YAML=$(mktemp /tmp/eai_job_spec_XXXXXX.yaml) +cat > "$SPEC_YAML" << 'YAML_EOF' +options: + internal-dns: + name: "" + ports: + - port: 29501 + - port: 11000 + - port: 9000 + - port: 7777 + - port: 8080 + - port: 8081 + - port: 8082 + - port: 8083 + - port: 8084 + - port: 8085 + - port: 8086 + - port: 8087 +YAML_EOF + +eai job new \ + --file "$SPEC_YAML" \ + --non-preemptable \ + --replicas "$NODES" \ + --gpu 8 \ + --cpu 128 \ + --mem 800 \ + --name "$JOB_NAME" \ + -i "$IMAGE" \ + --data "${EAI_HOME_DATA}:/home/toolkit:rw" \ + --data "${EAI_SHARED_DATA}:/mnt/shared:rw" \ + --env "HOME=/home/toolkit" \ + --env "GPUS_PER_NODE=8" \ + --env "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" \ + --env "TRITON_CACHE_DIR=/tmp/triton_cache" \ + -- /bin/bash -c "$CMD" + +rm -f "$SPEC_YAML" diff --git a/submit_eai_math_7b_multinode_ds_fastllm_branch.sh b/submit_eai_math_7b_multinode_ds_fastllm_branch.sh new file mode 100644 index 00000000..d330b351 --- /dev/null +++ b/submit_eai_math_7b_multinode_ds_fastllm_branch.sh @@ -0,0 +1,104 @@ +#!/bin/bash +# Multi-node EAI DeepSpeed GSPO math run on the fast-llm branch +# (use_fast_llm=false; DS trainer + vLLM v1 + GSPO loss; eps_low=3e-3, +# eps_high=4e-3). Reproduces the DS curve in the fast-llm vs DS comparison +# charts (docs/FAST_LLM_INTEGRATION.md). +# Topology: 1 actor node (vLLM) + (NODES-1) DeepSpeed trainer nodes. +# Usage: bash submit_eai_math_7b_multinode_ds_fastllm_branch.sh [NODES] +# Run `eai login` before executing this script. 
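+# Example with env-var overrides (placeholder values):
+#   RESULTS_DIR=/mnt/shared/$USER/results WANDB_ENTITY=my-entity bash submit_eai_math_7b_multinode_ds_fastllm_branch.sh 4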
+ +IMAGE="registry.toolkit-sp.yul201.service-now.com/snow.research.afm/interactive-toolkit:25.12-py3-vllm014rc1redis" + +# === PERSONALIZE THESE BEFORE RUNNING (or override via env vars) === +RESULTS_DIR="${RESULTS_DIR:-/mnt/shared/denis/math_7b_results}" # your shared NFS results dir +WANDB_ENTITY="${WANDB_ENTITY:-denisko-se}" # your wandb entity +WANDB_PROJECT="${WANDB_PROJECT:-watermelon}" # your wandb project +EAI_HOME_DATA="${EAI_HOME_DATA:-snow.home.denis_kocetkov}" # your EAI home data object +EAI_SHARED_DATA="${EAI_SHARED_DATA:-snow.research.afm.shared_fml}" # your shared NFS data object +# =================================================================== + +MODEL_PATH="${MODEL_PATH:-/home/toolkit/Qwen2.5-7B}" +NODES="${1:-4}" + +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +EXP_NAME="math_7b_ds_fastllmbranch_${NODES}node_${TIMESTAMP}" +EXP_DIR="${RESULTS_DIR}/${EXP_NAME}" +JOB_NAME="${EXP_NAME}" + +echo "Config: ${NODES} nodes, DS on fast-llm branch, actor_fraction=4, finetune_fraction=4, max_train_steps=400" + +CMD=" +set -e +mkdir -p ${EXP_DIR} +cd /home/toolkit/code/PipelineRL +source /home/toolkit/code/PipelineRL/.venv/bin/activate +PYTHONHASHSEED=42 python -m pipelinerl.launch \ + --config-path /home/toolkit/code/PipelineRL/conf \ + --config-name math \ + output_dir=${EXP_DIR} \ + wandb.wandb_workspace_root=${RESULTS_DIR} \ + "wandb.wandb_entity_name=${WANDB_ENTITY}" \ + "wandb.wandb_project_name=${WANDB_PROJECT}" \ + wandb.wandb_group=eai_math7b_ds_fastllmbranch \ + '+wandb.wandb_run_name=math7b_ds_fastllmbranch_${NODES}node_${TIMESTAMP}' \ + use_fast_llm=false \ + actor.llm_max_rollouts=128 \ + force_restart=true \ + finetune.learning_rate=1e-6 \ + finetune.attempts=8 \ + finetune.rl.policy_loss=gspo \ + finetune.rl.epsilon_low=3e-3 \ + finetune.rl.epsilon_high=4e-3 \ + '+finetune.rl.filter_zero_advantage_groups=true' \ + finetune.max_train_steps=400 \ + finetune.seq_length=20000 \ + finetune.gradient_accumulation_passes=1024 \ + 'vllm_config.vllm_kwargs.max_model_len=20000' \ + 'llm.parameters.max_tokens=16000' \ + 'llm.parameters.temperature=0.7' \ + 'test_llm.parameters.max_tokens=16000' \ + 'test_llm.parameters.temperature=0.7' \ + world.actor_fraction=4 \ + world.preprocessor_fraction=0 \ + world.finetune_fraction=4 \ + world.run_id=\${MASTER_ADDR} \ + streams=files \ + eval_every_n_versions=0 \ + model_path=${MODEL_PATH} +" + +SPEC_YAML=$(mktemp /tmp/eai_job_spec_XXXXXX.yaml) +cat > "$SPEC_YAML" << 'YAML_EOF' +options: + internal-dns: + name: "" + ports: + - port: 29501 + - port: 9000 + - port: 7777 + - port: 8080 + - port: 8081 + - port: 8082 + - port: 8083 + - port: 8084 + - port: 8085 + - port: 8086 + - port: 8087 +YAML_EOF + +eai job new \ + --file "$SPEC_YAML" \ + --non-preemptable \ + --replicas "$NODES" \ + --gpu 8 \ + --cpu 128 \ + --mem 800 \ + --name "$JOB_NAME" \ + -i "$IMAGE" \ + --data "${EAI_HOME_DATA}:/home/toolkit:rw" \ + --data "${EAI_SHARED_DATA}:/mnt/shared:rw" \ + --env "HOME=/home/toolkit" \ + --env "PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True" \ + -- /bin/bash -c "$CMD" + +rm -f "$SPEC_YAML" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..9b491dd0 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for PipelineRL vLLM integration.""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..e33b8261 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,196 @@ +"""Pytest configuration and fixtures for vllm1 tests.""" + +import os +import pytest +import torch +import tempfile 
+from pathlib import Path +import subprocess +import sys + +from pipelinerl.vllm1 import EngineManager + + +@pytest.fixture(scope="session") +def model_name(): + """Model to use for testing.""" + return "Qwen/Qwen2.5-0.5B-Instruct" + + +@pytest.fixture(scope="session") +def sample_prompts(): + """Sample prompts for generation testing.""" + return [ + "Write a haiku about coding:", + "The capital of France is", + "In a galaxy far away,", + ] + + +@pytest.fixture(scope="session") +def simple_prompt(): + """Single simple prompt for deterministic testing.""" + return "The capital of France is" + + +@pytest.fixture(scope="session") +def num_gpus(): + """Number of GPUs available.""" + return torch.cuda.device_count() + + +@pytest.fixture(scope="session") +def require_2_gpus(num_gpus): + """Skip test if less than 2 GPUs available.""" + if num_gpus < 2: + pytest.skip("Test requires at least 2 GPUs") + + +@pytest.fixture(scope="session") +def require_gpu(): + """Skip test if no GPU available.""" + if not torch.cuda.is_available(): + pytest.skip("Test requires GPU") + + +@pytest.fixture +def temp_dir(): + """Temporary directory for test files.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture(scope="session") +def shared_test_dir(): + """Session-scoped shared directory for test data that persists across tests. + + Use this for data that needs to be shared between tests (like perturbed weights). + """ + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def distributed_init_method(temp_dir): + """File-based init method for distributed testing.""" + return f"file://{temp_dir}/dist_init" + + +@pytest.fixture(scope="session") +def shared_distributed_init_method(shared_test_dir): + """Session-scoped file-based init method for tests that share data.""" + return f"file://{shared_test_dir}/dist_init" + + +@pytest.fixture(scope="session") +def cache_dir(): + """Directory for caching downloaded models.""" + cache_path = Path(os.environ.get("HF_HOME", Path.home() / ".cache" / "huggingface")) + cache_path.mkdir(parents=True, exist_ok=True) + return cache_path + + +@pytest.fixture +def vllm_server_port(): + """Port for vLLM server in tests.""" + # Use a high port to avoid conflicts + return 8765 + + +@pytest.fixture +def generation_config(): + """Configuration for deterministic generation.""" + return { + "temperature": 0.0, + "top_p": 1.0, + "max_tokens": 50, + "seed": 42, + } + + +@pytest.fixture +def vllm_engine_factory_2gpu(model_name): + """Factory fixture that defaults to 2 GPUs. + + Usage: + async with vllm_engine_factory_2gpu() as manager: + # Uses 2 GPUs by default + # Access engine via manager.engine + ... + """ + def _factory(tensor_parallel_size: int = 2, **kwargs): + """Create engine with 2 GPUs by default.""" + import argparse + + args = argparse.Namespace( + model=model_name, + tensor_parallel_size=tensor_parallel_size, + disable_log_stats=True, + enable_log_requests=False, + **kwargs + ) + + return EngineManager.create_engine(args) + + return _factory + + +@pytest.fixture +def vllm_engine_factory(model_name): + """Factory fixture for creating vLLM engines. + + Usage in tests: + async with vllm_engine_factory() as manager: + # use manager.engine for generation + ... + # automatic cleanup + + Or with custom config: + async with vllm_engine_factory(tensor_parallel_size=2) as manager: + # use manager.engine with 2 GPUs + ... 
+ + Or if you need engine_config: + async with vllm_engine_factory() as manager: + # access manager.engine, manager.engine_config, manager.args + ... + """ + def _factory(tensor_parallel_size: int = 1, **kwargs): + """Create engine context manager with test defaults. + + Args: + tensor_parallel_size: Number of GPUs + **kwargs: Additional attributes for args object + + Returns: + Async context manager for EngineManager + """ + import argparse + + # Create minimal args object with required attributes for AsyncEngineArgs.from_cli_args() + args = argparse.Namespace( + model=model_name, + tensor_parallel_size=tensor_parallel_size, + disable_log_stats=True, + enable_log_requests=False, + # Apply any additional kwargs + **kwargs + ) + + print("args: ", args) + + return EngineManager.create_engine(args) + + return _factory + + +@pytest.fixture +def distributed_trainer_helper(): + """Path to the distributed trainer helper script.""" + return Path(__file__).parent / "distributed_trainer_helper.py" + + +@pytest.fixture +def vllm_engine_helper(): + """Path to the vLLM engine helper script.""" + return Path(__file__).parent / "vllm_engine_helper.py" diff --git a/tests/distributed_trainer_helper.py b/tests/distributed_trainer_helper.py new file mode 100755 index 00000000..7e50decb --- /dev/null +++ b/tests/distributed_trainer_helper.py @@ -0,0 +1,650 @@ +#!/usr/bin/env python3 +"""Helper script for distributed trainer process. + +This script is run as a separate process with CUDA_VISIBLE_DEVICES set, +allowing proper GPU isolation for distributed tests. +""" + +import sys +import argparse +import logging +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) +from trainer_test_utils import ( + _resolve_model_path, + _load_state_dict, + _create_perturbed_state_dict, + _init_actor_process_group, + _broadcast_tensors, + _wait_for_servers_ready, +) + +# Setup debug logging +logging.basicConfig( + level=logging.DEBUG, + format="[%(asctime)s] [TRAINER-%(levelname)s] %(message)s", + datefmt="%H:%M:%S", +) +logger = logging.getLogger(__name__) + + +def _wait_all_actors(sync_path, name: str, num_actors: int, timeout: float = 120): + """Wait for all actors to signal a named sync point. + + Each actor signals ``{name}_actor_{i}`` for i in range(num_actors). + """ + from pathlib import Path + sys.path.insert(0, str(Path(__file__).parent)) + from sync_helper import SyncPoint + + for i in range(num_actors): + SyncPoint(sync_path, f"{name}_actor_{i}").wait(timeout=timeout) + + +def _broadcast_via_server( + state_dict: dict, + server_urls: list, + version: int, + process_group, + label: str = "", +): + """Broadcast weights to one or more running vLLM servers via HTTP POST + NCCL. + + One POST thread is started per server URL (all in parallel) before the + NCCL broadcast so that all servers are ready to receive simultaneously. 
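+    A short sleep after the POSTs gives every server time to reach its receive
+    path before the NCCL broadcast begins.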
+ """ + import threading + import time + import requests + from weight_update_utils import create_weight_update_request_from_state_dict + + label_str = f" {label}" if label else "" + print(f"[Trainer] Broadcasting {len(state_dict)}{label_str} parameters to {len(server_urls)} server(s)") + + request = create_weight_update_request_from_state_dict(state_dict, version=version) + + errors = [] + threads = [] + + for url in server_urls: + err = {"error": None} + errors.append(err) + + def _post(server_url=url, post_result=err): + try: + print(f"[Trainer] POSTing weight update request to {server_url}...") + resp = requests.post( + f"{server_url}/receive_weight_update", + json=request.model_dump(), + timeout=600, + ) + if resp.status_code != 200: + post_result["error"] = ( + f"POST to {server_url} failed with status {resp.status_code}: {resp.text}" + ) + else: + print(f"[Trainer] Server {server_url} acknowledged weight update") + except Exception as e: + post_result["error"] = f"POST to {server_url} failed: {e}" + + t = threading.Thread(target=_post, daemon=False) + threads.append(t) + t.start() + + time.sleep(0.5) # Give all servers a moment to start receiving + + _broadcast_tensors(state_dict, process_group) + + for t in threads: + t.join(timeout=60) + + failed = [e["error"] for e in errors if e["error"]] + if failed: + raise RuntimeError(f"Weight update POST(s) failed: {failed}") + + print(f"[Trainer] Broadcast{label_str} complete") + + +# --------------------------------------------------------------------------- +# Public command functions +# --------------------------------------------------------------------------- + +def init_process_group(init_method: str, rank: int, world_size: int): + """Initialize a distributed process group and wait.""" + import torch.distributed as dist + import time + + process_group = _init_actor_process_group(init_method, rank, world_size) + print(f"[Trainer rank={rank}] Process group initialized successfully") + + # Wait for coordination + time.sleep(3) + + print(f"[Trainer rank={rank}] Destroying process group") + dist.destroy_process_group(process_group) + print(f"[Trainer rank={rank}] Process group destroyed") + + +def save_model_to_dir(state_dict: dict, output_dir: str, model_name: str): + """Save state_dict to a directory as safetensors with config. 
+ + Args: + state_dict: Model state dict to save + output_dir: Directory to save model + model_name: Original model name to copy config from + """ + from pathlib import Path + from safetensors.torch import save_file + import shutil + + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Save weights as safetensors + safetensors_path = output_path / "model.safetensors" + save_file(state_dict, str(safetensors_path)) + print(f"[Trainer] Saved model weights to {safetensors_path}") + + # Copy config.json from original model + original_path = _resolve_model_path(model_name) + + config_src = original_path / "config.json" + config_dst = output_path / "config.json" + shutil.copy(config_src, config_dst) + print(f"[Trainer] Copied config.json to {config_dst}") + + # Copy tokenizer files + for filename in [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "vocab.json", + "merges.txt", + "tokenizer.model", + ]: + src = original_path / filename + if src.exists(): + dst = output_path / filename + shutil.copy(src, dst) + print(f"[Trainer] Copied {filename}") + + return str(output_path) + + +def broadcast_weights( + init_method: str, model_name: str, perturb: bool = False, sync_dir: str = None +): + """Load model and broadcast weights to vLLM worker.""" + import torch + import torch.distributed as dist + from pathlib import Path + + # Setup sync points if provided + if sync_dir: + sys.path.insert(0, str(Path(__file__).parent)) + from sync_helper import SyncPoint, write_weight_update_request + + sync_path = Path(sync_dir) + baseline_done = SyncPoint(sync_path, "baseline_done") + ready_to_receive = SyncPoint(sync_path, "ready_to_receive") + request_ready = SyncPoint(sync_path, "request_ready") + receiving_started = SyncPoint(sync_path, "receiving_started") + broadcast_done = SyncPoint(sync_path, "broadcast_done") + + # IMPORTANT: Initialize process group FIRST (before any waiting) + process_group = _init_actor_process_group(init_method, rank=0, world_size=2) + + # Now wait for vLLM to finish baseline and be ready to receive + if sync_dir: + print("[Trainer] Waiting for vLLM to finish baseline generation...") + baseline_done.wait(timeout=60) + print("[Trainer] Baseline done") + + print("[Trainer] Waiting for vLLM to be ready to receive weights...") + ready_to_receive.wait(timeout=60) + print("[Trainer] vLLM ready, starting weight broadcast") + + print(f"[Trainer] Loading tensors from safetensors for {model_name}") + state_dict, _ = _load_state_dict(model_name) + + params_to_broadcast = state_dict + print(f"[Trainer] Will broadcast {len(params_to_broadcast)} parameters") + + # Create and send WeightUpdateRequest to vLLM + if sync_dir: + from weight_update_utils import create_weight_update_request_from_state_dict + + print("[Trainer] Creating WeightUpdateRequest...") + request = create_weight_update_request_from_state_dict( + params_to_broadcast, version=1 + ) + write_weight_update_request(sync_path, request) + request_ready.signal() + print( + f"[Trainer] Sent WeightUpdateRequest with {len(request.parameters_info)} parameters" + ) + + # Wait for vLLM to start receiving before we broadcast + print("[Trainer] Waiting for vLLM to start receiving...") + receiving_started.wait(timeout=60) + print("[Trainer] vLLM is receiving, starting broadcast") + + print(f"[Trainer] Broadcasting {len(params_to_broadcast)} parameters") + + # Optionally perturb weights - add noise to ALL tensors + if perturb: + params_to_broadcast = 
_create_perturbed_state_dict(params_to_broadcast) + + # Broadcast each weight with detailed logging + logger.info(f"Starting broadcast of {len(params_to_broadcast)} parameters") + for i, (name, tensor) in enumerate(params_to_broadcast.items()): + logger.debug(f"[{i+1}/{len(state_dict)}] Preparing to broadcast: {name}") + logger.debug( + f" - shape: {tensor.shape}, dtype: {tensor.dtype}, device: {tensor.device}" + ) + if tensor.device.type != "cuda": + logger.debug(f" - Moving {name} to CUDA") + tensor = tensor.cuda(0) + logger.debug(f" - {name} now on device: {tensor.device}") + logger.debug(f" - Calling dist.broadcast for {name}...") + dist.broadcast(tensor, src=0, group=process_group) + logger.debug(f" - Broadcast complete for {name}") + if (i + 1) % 10 == 0: + logger.info(f"Broadcasted {i+1}/{len(params_to_broadcast)} parameters") + + print(f"[Trainer] All {len(params_to_broadcast)} parameters broadcasted") + + # Signal broadcast complete BEFORE destroying process group + if sync_dir: + broadcast_done.signal() + print("[Trainer] Signaled broadcast complete") + + dist.destroy_process_group(process_group) + print("[Trainer] Process group destroyed") + + +def broadcast_cross_validation( + init_method: str, model_name: str, sync_dir: str, temp_dir: str +): + """Cross-validation test: broadcast perturbed, then original weights. + + Also saves perturbed model to disk for vLLM to load. + """ + import torch.distributed as dist + from pathlib import Path + + sys.path.insert(0, str(Path(__file__).parent)) + from sync_helper import SyncPoint, write_weight_update_request + from weight_update_utils import create_weight_update_request_from_state_dict + + sync_path = Path(sync_dir) + baseline_done = SyncPoint(sync_path, "baseline_done") + perturbed_model_saved = SyncPoint(sync_path, "perturbed_model_saved") + ready_to_receive_perturbed = SyncPoint(sync_path, "ready_to_receive_perturbed") + perturbed_broadcast_done = SyncPoint(sync_path, "perturbed_broadcast_done") + mod1_done = SyncPoint(sync_path, "mod1_done") + first_engine_destroyed = SyncPoint(sync_path, "first_engine_destroyed") + engine_recreated = SyncPoint(sync_path, "engine_recreated") + ready_to_receive_original = SyncPoint(sync_path, "ready_to_receive_original") + original_broadcast_done = SyncPoint(sync_path, "original_broadcast_done") + + process_group = _init_actor_process_group(init_method, rank=0, world_size=2) + + print("[Trainer] Waiting for vLLM baseline generation...") + baseline_done.wait(timeout=120) + + print(f"[Trainer] Loading original model {model_name}") + original_state_dict, model_path = _load_state_dict(model_name) + + perturbed_state_dict = _create_perturbed_state_dict(original_state_dict) + + # Save perturbed model to disk + perturbed_model_dir = Path(temp_dir) / "perturbed_model" + print(f"[Trainer] Saving perturbed model to {perturbed_model_dir}") + saved_path = save_model_to_dir( + perturbed_state_dict, str(perturbed_model_dir), str(model_path) + ) + + path_file = sync_path / "perturbed_model_path.txt" + path_file.write_text(saved_path) + perturbed_model_saved.signal() + print(f"[Trainer] Signaled perturbed model saved at: {saved_path}") + + # Broadcast perturbed weights + print("[Trainer] Waiting for vLLM to be ready for perturbed broadcast...") + ready_to_receive_perturbed.wait(timeout=120) + + print(f"[Trainer] Broadcasting {len(perturbed_state_dict)} perturbed parameters") + request = create_weight_update_request_from_state_dict(perturbed_state_dict, version=1) + write_weight_update_request(sync_path, request) + 
_broadcast_tensors(perturbed_state_dict, process_group) + + perturbed_broadcast_done.signal() + print("[Trainer] Perturbed weights broadcast complete") + + print("[Trainer] Waiting for vLLM to finish res_mod_1...") + mod1_done.wait(timeout=120) + + print("[Trainer] Destroying process group for first broadcast") + dist.destroy_process_group(process_group) + + print("[Trainer] Waiting for vLLM to destroy first engine...") + first_engine_destroyed.wait(timeout=120) + + print("[Trainer] Recreating process group for second broadcast") + process_group = _init_actor_process_group(init_method, rank=0, world_size=2) + print("[Trainer] Process group recreated, waiting at rendezvous...") + + print("[Trainer] Waiting for vLLM to recreate engine...") + engine_recreated.wait(timeout=300) # 5 minutes - engine creation can be slow + print("[Trainer] vLLM engine recreated, both in new process group") + + # Broadcast original weights + print("[Trainer] Waiting for vLLM to be ready for original broadcast...") + ready_to_receive_original.wait(timeout=120) + + print(f"[Trainer] Broadcasting {len(original_state_dict)} original parameters") + request = create_weight_update_request_from_state_dict(original_state_dict, version=2) + write_weight_update_request(sync_path, request) + _broadcast_tensors(original_state_dict, process_group) + + original_broadcast_done.signal() + print("[Trainer] Original weights broadcast complete") + + dist.destroy_process_group(process_group) + print("[Trainer] Process group destroyed") + + +def broadcast_back_and_forth( + init_method: str, + model_name: str, + sync_dir: str, + num_actors: int = 1, + world_size: int = 2, +): + """Back-and-forth test: broadcast perturbed → original → perturbed again. + + Tests that we can switch between weight sets multiple times. + Supports multiple actors: waits for all actors to signal readiness before + each broadcast, then sends a single shared completion signal. 
+ """ + import torch.distributed as dist + from pathlib import Path + + sys.path.insert(0, str(Path(__file__).parent)) + from sync_helper import SyncPoint, write_weight_update_request + from weight_update_utils import create_weight_update_request_from_state_dict + + sync_path = Path(sync_dir) + perturbed1_done = SyncPoint(sync_path, "perturbed1_done") + original_done = SyncPoint(sync_path, "original_done") + perturbed2_done = SyncPoint(sync_path, "perturbed2_done") + + process_group = _init_actor_process_group(init_method, rank=0, world_size=world_size) + + print(f"[Trainer] Waiting for {num_actors} actor(s) to finish baseline generation...") + _wait_all_actors(sync_path, "baseline_done", num_actors, timeout=120) + + print(f"[Trainer] Loading model {model_name}") + original_state_dict, model_path = _load_state_dict(model_name) + + perturbed_state_dict = _create_perturbed_state_dict(original_state_dict) + + # Save perturbed weights for reuse in server tests + perturbed_weights_dir = Path(sync_dir) / "perturbed_weights" + print(f"[Trainer] Saving perturbed weights to {perturbed_weights_dir}") + saved_path = save_model_to_dir( + perturbed_state_dict, str(perturbed_weights_dir), str(model_path) + ) + print(f"[Trainer] Perturbed weights saved to {saved_path}") + + # Broadcast 1: Perturbed weights + print(f"[Trainer] Waiting for {num_actors} actor(s) to be ready for first perturbed broadcast...") + _wait_all_actors(sync_path, "ready_for_perturbed1", num_actors, timeout=120) + + print(f"[Trainer] Broadcasting perturbed weights (1st time) to {num_actors} actor(s)") + request = create_weight_update_request_from_state_dict(perturbed_state_dict, version=1) + write_weight_update_request(sync_path, request) + _broadcast_tensors(perturbed_state_dict, process_group) + + perturbed1_done.signal() + print("[Trainer] First perturbed broadcast complete") + + # Broadcast 2: Original weights + print(f"[Trainer] Waiting for {num_actors} actor(s) to be ready for original broadcast...") + _wait_all_actors(sync_path, "ready_for_original", num_actors, timeout=120) + + print(f"[Trainer] Broadcasting original weights to {num_actors} actor(s)") + request = create_weight_update_request_from_state_dict(original_state_dict, version=2) + write_weight_update_request(sync_path, request) + _broadcast_tensors(original_state_dict, process_group) + + original_done.signal() + print("[Trainer] Original broadcast complete") + + # Broadcast 3: Perturbed weights again (same as first) + print(f"[Trainer] Waiting for {num_actors} actor(s) to be ready for second perturbed broadcast...") + _wait_all_actors(sync_path, "ready_for_perturbed2", num_actors, timeout=120) + + print(f"[Trainer] Broadcasting perturbed weights (2nd time) to {num_actors} actor(s)") + request = create_weight_update_request_from_state_dict(perturbed_state_dict, version=3) + write_weight_update_request(sync_path, request) + _broadcast_tensors(perturbed_state_dict, process_group) + + perturbed2_done.signal() + print("[Trainer] Second perturbed broadcast complete") + + dist.destroy_process_group(process_group) + print("[Trainer] Process group destroyed") + + +def timed_broadcast_server_test( + init_method: str, + model_name: str, + server_urls: list, + world_size: int = 2, +): + """Timed broadcast for server tests: perturbed → original → perturbed with delays. + + This simulates a real-world scenario where weight updates happen while + the server is running and serving requests. 
+ + Pattern: original (server default) → perturbed → original → perturbed + + Args: + init_method: Distributed init method + model_name: Model name to load + server_urls: List of base URLs of vLLM servers (e.g., ["http://127.0.0.1:8000"]) + world_size: Total world size (trainer rank 0 + all vLLM workers) + """ + import torch.distributed as dist + import time + import requests + + process_group = _init_actor_process_group(init_method, rank=0, world_size=world_size) + + _wait_for_servers_ready(server_urls, extra_wait_secs=10) + + print(f"[Trainer] Loading original weights from {model_name}") + original_state_dict, _ = _load_state_dict(model_name) + + perturbed_state_dict = _create_perturbed_state_dict(original_state_dict) + + # Broadcast 1: Perturbed weights + _broadcast_via_server(perturbed_state_dict, server_urls, version=1, process_group=process_group, label="perturbed") + + print("[Trainer] Waiting 5 seconds before broadcasting original weights...") + time.sleep(5) + + # Broadcast 2: Original weights + _broadcast_via_server(original_state_dict, server_urls, version=2, process_group=process_group, label="original") + + print("[Trainer] Waiting 5 seconds before broadcasting perturbed weights again...") + time.sleep(5) + + # Broadcast 3: Perturbed weights again (same as first) + _broadcast_via_server(perturbed_state_dict, server_urls, version=3, process_group=process_group, label="perturbed (2nd time)") + + # Wait to allow generation with the last broadcast before tearing down + print("[Trainer] Waiting 5 seconds for generation with final weights...") + time.sleep(5) + + # Signal training is finished so vLLM servers destroy their side of the process group + for url in server_urls: + print(f"[Trainer] Sending training_finished signal to {url}...") + requests.post(f"{url}/training_finished", timeout=10) + + # Cleanup — destroy_process_group now resolves because vLLM servers respond to /training_finished + dist.destroy_process_group(process_group) + print("[Trainer] Process group destroyed, exiting") + + +def rapid_broadcast_cycles( + init_method: str, + model_name: str, + server_urls: list, + world_size: int = 2, + n_cycles: int = 6, +): + """Hybrid broadcast designed to catch transition/garbage generations. + + Structure: + 1. Slow broadcast: perturbed (5 s wait after) — establishes text_B + 2. Slow broadcast: original (5 s wait after) — re-establishes text_A + 3. n_cycles rapid pairs: perturbed → original (1 s between each) + 4. Slow broadcast: perturbed (5 s wait after) — end on text_B so the + overall A→B→A→B pattern remains detectable + + The slow initial cycles give the generation loop enough stable time to + identify text_A and text_B by frequency. The rapid cycles create many + short broadcast windows where mid-broadcast (garbage) generations are + likely to be caught by a zero-interval generation loop. 
+ """ + import torch.distributed as dist + import time + import requests + + process_group = _init_actor_process_group(init_method, rank=0, world_size=world_size) + + _wait_for_servers_ready(server_urls, extra_wait_secs=10) + + print(f"[Trainer] Loading weights from {model_name}") + original_state_dict, _ = _load_state_dict(model_name) + perturbed_state_dict = _create_perturbed_state_dict(original_state_dict) + + version = 1 + + # --- Slow cycle: establish text_B and text_A clearly --- + print("[Trainer] Slow broadcast 1: perturbed (establishing text_B)...") + _broadcast_via_server(perturbed_state_dict, server_urls, version=version, process_group=process_group, label="perturbed (slow)") + version += 1 + time.sleep(5) + + print("[Trainer] Slow broadcast 2: original (re-establishing text_A)...") + _broadcast_via_server(original_state_dict, server_urls, version=version, process_group=process_group, label="original (slow)") + version += 1 + time.sleep(5) + + # --- Rapid cycles: 1 s between broadcasts --- + for i in range(n_cycles): + print(f"[Trainer] Rapid cycle {i + 1}/{n_cycles}: perturbed...") + _broadcast_via_server(perturbed_state_dict, server_urls, version=version, process_group=process_group, label=f"perturbed (rapid {i + 1})") + version += 1 + time.sleep(1) + + print(f"[Trainer] Rapid cycle {i + 1}/{n_cycles}: original...") + _broadcast_via_server(original_state_dict, server_urls, version=version, process_group=process_group, label=f"original (rapid {i + 1})") + version += 1 + time.sleep(1) + + # --- Final slow broadcast: end on perturbed so ABAB pattern holds --- + print("[Trainer] Final slow broadcast: perturbed (ending on text_B)...") + _broadcast_via_server(perturbed_state_dict, server_urls, version=version, process_group=process_group, label="perturbed (final)") + time.sleep(5) + + for url in server_urls: + print(f"[Trainer] Sending training_finished signal to {url}...") + requests.post(f"{url}/training_finished", timeout=10) + + dist.destroy_process_group(process_group) + print("[Trainer] Process group destroyed, exiting") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Distributed trainer helper") + parser.add_argument("command", choices=["init", "broadcast", "cross_validation", "back_and_forth", "timed_broadcast_server_test", "rapid_broadcast_cycles"]) + parser.add_argument("--init-method", required=True) + parser.add_argument("--rank", type=int, default=0) + parser.add_argument("--world-size", type=int, default=2) + parser.add_argument("--model-name", type=str) + parser.add_argument("--perturb", action="store_true") + parser.add_argument("--sync-dir", type=str, help="Directory for sync files") + parser.add_argument( + "--temp-dir", type=str, help="Temporary directory for saving models" + ) + parser.add_argument( + "--server-urls", nargs="+", help="Base URL(s) of vLLM server(s) (e.g., http://127.0.0.1:8000)" + ) + parser.add_argument("--num-actors", type=int, default=1, help="Number of vLLM actor processes") + parser.add_argument("--n-cycles", type=int, default=6, help="Number of rapid broadcast cycles (rapid_broadcast_cycles command)") + + args = parser.parse_args() + + try: + if args.command == "init": + init_process_group(args.init_method, args.rank, args.world_size) + elif args.command == "broadcast": + if not args.model_name: + print("Error: --model-name required for broadcast command") + sys.exit(1) + broadcast_weights( + args.init_method, args.model_name, args.perturb, args.sync_dir + ) + elif args.command == "cross_validation": + if not 
args.model_name or not args.sync_dir or not args.temp_dir: + print( + "Error: --model-name, --sync-dir, and --temp-dir required for cross_validation" + ) + sys.exit(1) + broadcast_cross_validation( + args.init_method, args.model_name, args.sync_dir, args.temp_dir + ) + elif args.command == "back_and_forth": + if not args.model_name or not args.sync_dir: + print("Error: --model-name and --sync-dir required for back_and_forth") + sys.exit(1) + broadcast_back_and_forth( + args.init_method, + args.model_name, + args.sync_dir, + num_actors=args.num_actors, + world_size=args.world_size, + ) + elif args.command == "timed_broadcast_server_test": + if not args.model_name or not args.server_urls: + print("Error: --model-name and --server-urls required for timed_broadcast_server_test") + sys.exit(1) + timed_broadcast_server_test( + args.init_method, + args.model_name, + args.server_urls, + world_size=args.world_size, + ) + elif args.command == "rapid_broadcast_cycles": + if not args.model_name or not args.server_urls: + print("Error: --model-name and --server-urls required for rapid_broadcast_cycles") + sys.exit(1) + rapid_broadcast_cycles( + args.init_method, + args.model_name, + args.server_urls, + world_size=args.world_size, + n_cycles=args.n_cycles, + ) + except Exception as e: + print(f"[Trainer] Error: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) diff --git a/tests/fast_llm_trainer_helper.py b/tests/fast_llm_trainer_helper.py new file mode 100644 index 00000000..6ff173e2 --- /dev/null +++ b/tests/fast_llm_trainer_helper.py @@ -0,0 +1,271 @@ +"""Helper functions for Fast-LLM weight broadcast testing.""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent)) +from trainer_test_utils import ( + _load_state_dict, + _create_perturbed_state_dict, + _wait_for_servers_ready, + _init_actor_process_group, +) + + +def timed_broadcast_fast_llm( + init_method: str, + model_name: str, + server_urls: list, + redis_host: str = "localhost", + redis_port: int = 6379, + world_size: int = 2, +): + """Timed broadcast using Fast-LLM protocol: perturbed → original → perturbed with delays. + + This simulates Fast-LLM's weight broadcast protocol where weight updates are signaled + via Redis stream and broadcast using broadcast_object_list + broadcast. 
+ + Pattern: original (server default) → perturbed → original → perturbed + + Args: + init_method: Distributed init method + model_name: Model name to load + server_urls: Base URLs of vLLM server(s) (for health check only) + redis_host: Redis host address + redis_port: Redis port number + world_size: Total NCCL world size (trainer rank 0 + all vLLM workers) + """ + import torch + import torch.distributed as dist + import time + import redis + import orjson + + from fast_llm.engine.distributed.config import DistributedBackend + from fast_llm.engine.distributed.distributed import ProcessGroupPool + + print(f"[Trainer] Initializing process group as rank 0 (world_size={world_size})") + process_group = ProcessGroupPool( + rank=0, + world_size=world_size, + local_world_size=1, + init_method=init_method, + backend=DistributedBackend.nccl, + ).get_process_group(range(world_size), 0) + print("[Trainer] Process group initialized") + + # Connect to Redis + print(f"[Trainer] Connecting to Redis at {redis_host}:{redis_port}") + r = redis.Redis(host=redis_host, port=redis_port) + stream_key = "fast_llm_events" + payload_key = "event" + print(f"[Trainer] Connected to Redis, will write to stream '{stream_key}'") + + _wait_for_servers_ready(server_urls, extra_wait_secs=15) + + # Load weights + print(f"[Trainer] Loading original weights from {model_name}") + original_state_dict, _ = _load_state_dict(model_name) + perturbed_state_dict = _create_perturbed_state_dict(original_state_dict) + + from fast_llm.core.distributed import broadcast as _broadcast, broadcast_object as _broadcast_object + + # Helper function to broadcast weights using Fast-LLM protocol + def broadcast_weights_fast_llm(state_dict, step): + """Broadcast weights using Fast-LLM protocol. + + Protocol: + 1. Send Redis event: {type: "weights_ready", step: N} + 2. For each parameter: + - broadcast_object((shard_name, layer_name, shape, dtype)) + - broadcast(tensor) + 3. 
Send end signal: broadcast_object(None) + """ + # Send Redis stream event + event = {"type": "weights_ready", "step": step} + r.xadd(stream_key, {payload_key: orjson.dumps(event)}) + print(f"[Trainer] Sent Redis event to '{stream_key}': {event}") + + # Broadcast each parameter + for i, (name, tensor) in enumerate(state_dict.items()): + if tensor.device.type != "cuda": + tensor = tensor.cuda(0) + + _broadcast_object(("weights", name, list(tensor.shape), str(tensor.dtype)), process_group, src=0) + _broadcast(tensor, 0, process_group) + + if (i + 1) % 50 == 0: + print(f"[Trainer] Broadcasted {i+1}/{len(state_dict)} parameters") + + # Send end signal + _broadcast_object(None, process_group, src=0) + print(f"[Trainer] Sent end signal, broadcast complete") + + # Broadcast 1: Perturbed weights + print(f"[Trainer] Broadcasting {len(perturbed_state_dict)} perturbed parameters") + broadcast_weights_fast_llm(perturbed_state_dict, step=1) + print("[Trainer] Perturbed weights broadcast complete") + + print("[Trainer] Waiting 5 seconds before broadcasting original weights...") + time.sleep(5) + + # Broadcast 2: Original weights + print(f"[Trainer] Broadcasting {len(original_state_dict)} original parameters") + broadcast_weights_fast_llm(original_state_dict, step=2) + print("[Trainer] Original weights broadcast complete") + + print("[Trainer] Waiting 5 seconds before broadcasting perturbed weights again...") + time.sleep(5) + + # Broadcast 3: Perturbed weights again (same as first) + print(f"[Trainer] Broadcasting {len(perturbed_state_dict)} perturbed parameters (2nd time)") + broadcast_weights_fast_llm(perturbed_state_dict, step=3) + print("[Trainer] Perturbed weights broadcast complete (2nd time)") + + # Wait to allow generation with the last broadcast before tearing down + print("[Trainer] Waiting 5 seconds for generation with final weights...") + time.sleep(5) + + # Signal training is finished so vLLM workers destroy their side of the process group + print("[Trainer] Sending training_finished signal...") + r.xadd(stream_key, {payload_key: orjson.dumps({"type": "training_finished"})}) + + # Cleanup — destroy_process_group now resolves because vLLM workers respond to training_finished + r.close() + process_group.shutdown() + print("[Trainer] Redis connection closed, process group destroyed, exiting") + + +def rapid_broadcast_cycles_fast_llm( + init_method: str, + model_name: str, + server_urls: list, + redis_host: str = "localhost", + redis_port: int = 6379, + world_size: int = 2, + n_cycles: int = 6, +): + """Hybrid Fast-LLM broadcast designed to catch transition/garbage generations. + + Structure: + 1. Slow broadcast: perturbed (5 s wait after) — establishes text_B + 2. Slow broadcast: original (5 s wait after) — re-establishes text_A + 3. n_cycles rapid pairs: perturbed → original (1 s between each) + 4. 
Slow broadcast: perturbed (5 s wait after) — end on text_B so the + overall A→B→A→B pattern remains detectable + """ + import torch.distributed as dist + import time + import redis as redis_lib + import orjson + + from fast_llm.engine.distributed.config import DistributedBackend + from fast_llm.engine.distributed.distributed import ProcessGroupPool + + print(f"[Trainer] Initializing process group as rank 0 (world_size={world_size})") + process_group = ProcessGroupPool( + rank=0, + world_size=world_size, + local_world_size=1, + init_method=init_method, + backend=DistributedBackend.nccl, + ).get_process_group(range(world_size), 0) + print("[Trainer] Process group initialized") + + r = redis_lib.Redis(host=redis_host, port=redis_port) + stream_key = "fast_llm_events" + payload_key = "event" + + _wait_for_servers_ready(server_urls, extra_wait_secs=15) + + print(f"[Trainer] Loading weights from {model_name}") + original_state_dict, _ = _load_state_dict(model_name) + perturbed_state_dict = _create_perturbed_state_dict(original_state_dict) + + step = 1 + + def broadcast_weights(state_dict, label): + nonlocal step + import torch + event = {"type": "weights_ready", "step": step} + r.xadd(stream_key, {payload_key: orjson.dumps(event)}) + print(f"[Trainer] Sent weights_ready step={step} ({label})") + step += 1 + + from fast_llm.core.distributed import broadcast as _broadcast, broadcast_object as _broadcast_object + + for name, tensor in state_dict.items(): + if tensor.device.type != "cuda": + tensor = tensor.cuda(0) + _broadcast_object(("weights", name, list(tensor.shape), str(tensor.dtype)), process_group, src=0) + _broadcast(tensor, 0, process_group) + + _broadcast_object(None, process_group, src=0) + print(f"[Trainer] Broadcast complete ({label})") + + # --- Slow cycle: establish text_B and text_A clearly --- + print("[Trainer] Slow broadcast 1: perturbed (establishing text_B)...") + broadcast_weights(perturbed_state_dict, "perturbed slow") + time.sleep(5) + + print("[Trainer] Slow broadcast 2: original (re-establishing text_A)...") + broadcast_weights(original_state_dict, "original slow") + time.sleep(5) + + # --- Rapid cycles: 1 s between broadcasts --- + for i in range(n_cycles): + print(f"[Trainer] Rapid cycle {i + 1}/{n_cycles}: perturbed...") + broadcast_weights(perturbed_state_dict, f"perturbed rapid {i + 1}") + time.sleep(1) + + print(f"[Trainer] Rapid cycle {i + 1}/{n_cycles}: original...") + broadcast_weights(original_state_dict, f"original rapid {i + 1}") + time.sleep(1) + + # --- Final slow broadcast: end on perturbed so ABAB pattern holds --- + print("[Trainer] Final slow broadcast: perturbed (ending on text_B)...") + broadcast_weights(perturbed_state_dict, "perturbed final") + time.sleep(5) + + print("[Trainer] Sending training_finished signal...") + r.xadd(stream_key, {payload_key: orjson.dumps({"type": "training_finished"})}) + + r.close() + process_group.shutdown() + print("[Trainer] Redis connection closed, process group destroyed, exiting") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Fast-LLM trainer helper") + parser.add_argument("--init-method", required=True, help="Distributed init method") + parser.add_argument("--model", required=True, help="Model name") + parser.add_argument("--server-urls", nargs="+", required=True, help="Server URL(s)") + parser.add_argument("--world-size", type=int, default=2, help="Total distributed world size") + parser.add_argument("--redis-host", default="localhost", help="Redis host") + 
parser.add_argument("--redis-port", type=int, default=6379, help="Redis port") + parser.add_argument("--n-cycles", type=int, default=0, + help="If > 0, run rapid_broadcast_cycles with this many rapid pairs") + + args = parser.parse_args() + + if args.n_cycles > 0: + rapid_broadcast_cycles_fast_llm( + init_method=args.init_method, + model_name=args.model, + server_urls=args.server_urls, + redis_host=args.redis_host, + redis_port=args.redis_port, + world_size=args.world_size, + n_cycles=args.n_cycles, + ) + else: + timed_broadcast_fast_llm( + init_method=args.init_method, + model_name=args.model, + server_urls=args.server_urls, + redis_host=args.redis_host, + redis_port=args.redis_port, + world_size=args.world_size, + ) diff --git a/tests/server_weight_update_utils.py b/tests/server_weight_update_utils.py new file mode 100644 index 00000000..a4ade92a --- /dev/null +++ b/tests/server_weight_update_utils.py @@ -0,0 +1,653 @@ +"""Shared utilities for server weight update integration tests.""" + +import asyncio +import requests +import time +from pathlib import Path +import subprocess +import sys +import os + + +async def wait_for_server_ready(server_url: str, server_proc, trainer_proc, timeout_seconds: int = 300): + """Wait for server to be ready by polling health endpoint. + + Args: + server_url: Base URL of server (e.g., "http://127.0.0.1:8000") + server_proc: Server subprocess + trainer_proc: Trainer subprocess + timeout_seconds: Maximum time to wait + + Returns: + True if server is ready + + Raises: + RuntimeError: If server or trainer process terminates + TimeoutError: If server doesn't become ready within timeout + """ + print("[Main] Waiting for server to be ready...") + for i in range(timeout_seconds): + # Check if server process crashed + if server_proc.poll() is not None: + print(f"[Main] Server process terminated with code {server_proc.returncode}") + raise RuntimeError(f"Server process terminated with code {server_proc.returncode}") + + # Check if trainer process crashed + if trainer_proc.poll() is not None: + print(f"[Main] Trainer process terminated with code {trainer_proc.returncode}") + raise RuntimeError(f"Trainer process terminated with code {trainer_proc.returncode}") + + try: + resp = requests.get(f"{server_url}/health", timeout=1) + if resp.status_code == 200: + print("[Main] Server is ready!") + return True + except requests.exceptions.RequestException: + pass + + if i % 10 == 0: + print(f"[Main] Still waiting for server... ({i} seconds)") + await asyncio.sleep(1) + + raise TimeoutError(f"Server did not become ready within {timeout_seconds} seconds") + + +def _build_phases(generations): + """Collapse a generation list into (text, items) phase tuples.""" + phases = [] + current_text = None + current_phase = [] + for ts, text in generations: + if text != current_text: + if current_phase: + phases.append((current_text, current_phase)) + current_text = text + current_phase = [(ts, text)] + else: + current_phase.append((ts, text)) + if current_phase: + phases.append((current_text, current_phase)) + return phases + + +def _identify_stable_texts(phases, min_stable_gens=5): + """Return (text_a, text_b) identified from the first two stable phases. + + Iterates phases in order, skipping any with fewer than ``min_stable_gens`` + generations (transition artifacts). The first stable phase gives text_A; + the first stable phase with a different text gives text_B. + + Returns (text_a, text_b) or (None, None) if two distinct stable texts + cannot be found. 
+ """ + text_a = None + text_b = None + for text, items in phases: + if len(items) < min_stable_gens: + continue + if text_a is None: + text_a = text + elif text != text_a: + text_b = text + break + if text_a is None or text_b is None: + return None, None + return text_a, text_b + + +def _find_abab_pattern(phases, min_stable_gens=5): + """Search for the A→B→A→B pattern. + + text_A and text_B are identified from the first two *stable* phases — + phases with at least ``min_stable_gens`` generations. Short transition + phases (1–few gens) produced while an NCCL broadcast is in-flight are + automatically skipped during identification. + + The test is designed so that the server always starts with a long run of + text_A (original weights, typically hundreds of gens) followed by a long + run of text_B (first perturbed broadcast, tens of gens), making them + unambiguous even with transition artifacts in between. + + After identifying text_A and text_B the full A→B→A→B subsequence is + located in the phase list (transition phases between the four anchors are + silently skipped). + + Returns (phase_a, phase_b, phase_a2, phase_b2) or None. + """ + if len(phases) < 4: + return None + + text_a, text_b = _identify_stable_texts(phases, min_stable_gens) + + if text_a is None or text_b is None: + return None + + texts = [t for t, _ in phases] + + # Find ABAB as a subsequence in the phase list + first_a = next((i for i, t in enumerate(texts) if t == text_a), None) + if first_a is None: + return None + + first_b = next((i for i in range(first_a + 1, len(phases)) if texts[i] == text_b), None) + if first_b is None: + return None + + second_a = next((i for i in range(first_b + 1, len(phases)) if texts[i] == text_a), None) + if second_a is None: + return None + + second_b = next((i for i in range(second_a + 1, len(phases)) if texts[i] == text_b), None) + if second_b is None: + return None + + return phases[first_a], phases[first_b], phases[second_a], phases[second_b] + + +def check_pattern_detected(generations): + """Check whether the full A→B→A→B pattern is present in the generation history. + + This is a **post-hoc analysis helper** (e.g. for assertions after the + generation loop ends). It is intentionally *not* used as an early-stop + signal inside the generation loops. + + Why not early-stop? Any transition artifact text T that happens to appear + with several consecutive identical generations (possible when NCCL broadcasts + are slow) is indistinguishable from the real perturbed text B at generation + time. False positives would cut the loop short before the final stable B + phase accumulates. The generation loops instead rely on the trainer process + exiting (``trainer_proc.poll() is not None``) as their sole reliable + termination signal — the trainer exits within milliseconds of completing its + last broadcast, so no significant extra generation happens. + + Args: + generations: List of (timestamp, text) tuples + + Returns: + True if the A→B→A→B pattern is present + """ + if len(generations) < 4: + return False + phases = _build_phases(generations) + return _find_abab_pattern(phases) is not None + + +async def run_generation_loop( + server_url: str, + model_name: str, + simple_prompt: str, + generation_config: dict, + trainer_proc, + max_duration: int = 120, + generation_interval: float = 0.5, +): + """Run continuous generation loop until pattern is detected or timeout. 
+ + Args: + server_url: Base URL of server + model_name: Model name for API request + simple_prompt: Prompt to generate from + generation_config: Config dict with max_tokens, etc. + trainer_proc: Trainer subprocess to monitor + max_duration: Maximum duration in seconds + generation_interval: Time between generations + + Returns: + List of (timestamp, generated_text) tuples + """ + print("[Main] Starting continuous generation loop...") + generations = [] + start_time = time.time() + + while time.time() - start_time < max_duration: + # Check if trainer is still running + trainer_poll = trainer_proc.poll() + if trainer_poll is not None: + print(f"[Main] Trainer exited with code {trainer_poll}") + break + + try: + # Generate via HTTP API + payload = { + "model": model_name, + "prompt": simple_prompt, + "max_tokens": generation_config["max_tokens"], + "temperature": 0.0, # Deterministic + "top_p": 1.0, + "seed": 42, + } + + resp = requests.post( + f"{server_url}/v1/completions", + json=payload, + timeout=30, + ) + + if resp.status_code == 200: + result = resp.json() + generated_text = result["choices"][0]["text"] + timestamp = time.time() - start_time + generations.append((timestamp, generated_text)) + print(f"[Main] [{timestamp:.1f}s] Generated: '{generated_text}'") + else: + print(f"[Main] Generation failed with status {resp.status_code}") + + except requests.exceptions.RequestException as e: + print(f"[Main] Request failed: {e}") + + await asyncio.sleep(generation_interval) + + return generations + + +def analyze_and_verify_pattern(generations): + """Analyze generation sequence and verify the expected A→B→A→B pattern. + + Tolerates transition-artifact phases (e.g. a single generation produced + while an NCCL broadcast was in-flight) by searching for the pattern as + a subsequence rather than requiring it at exactly positions [0,1,2,3]. + + Args: + generations: List of (timestamp, text) tuples + + Returns: + Tuple of (text_a, text_b) — the original and perturbed texts. + + Raises: + AssertionError: If pattern is not as expected + """ + print("\n" + "=" * 60) + print("GENERATION SEQUENCE ANALYSIS") + print("=" * 60) + print(f"Total generations: {len(generations)}") + + for i, (ts, text) in enumerate(generations): + print(f"[{ts:5.1f}s] Gen {i+1}: '{text[:80]}...'") + + assert len(generations) >= 4, ( + f"Not enough generations to verify pattern (need at least 4, got {len(generations)})" + ) + + phases = _build_phases(generations) + + _GRAY = "\033[90m" + _RESET = "\033[0m" + stable_a, stable_b = _identify_stable_texts(phases) + stable_texts = {t for t in (stable_a, stable_b) if t is not None} + print("\n" + "=" * 60) + print(f"Detected {len(phases)} phase(s):") + for i, (text, items) in enumerate(phases): + line = f"Phase {i+1}: {len(items)} generation(s) - '{text[:60]}...'" + if text not in stable_texts: + print(f"{_GRAY}{line} ← transition{_RESET}") + else: + print(line) + print("=" * 60) + + result = _find_abab_pattern(phases) + assert result is not None, ( + f"Could not find A→B→A→B pattern in {len(phases)} phase(s). 
" + f"Phases: {[(text[:40], len(items)) for text, items in phases]}" + ) + + (phase_a_text, phase_a_items), (phase_b_text, phase_b_items), \ + (phase_a2_text, phase_a2_items), (phase_b2_text, phase_b2_items) = result + + # These hold by construction from _find_abab_pattern, but assert for clarity + assert phase_a_text != phase_b_text, "Phase A and Phase B should be different" + assert phase_a2_text == phase_a_text, "Second A should match first A (original weights restored)" + assert phase_b2_text == phase_b_text, "Second B should match first B (perturbed weights reapplied)" + + skipped = len(phases) - 4 + skip_note = f" ({skipped} transition phase(s) skipped)" if skipped else "" + print(f"\n✓ Pattern verified{skip_note}:") + print(f" Phase A (original): {len(phase_a_items)} generation(s)") + print(f" Phase B (perturbed): {len(phase_b_items)} generation(s)") + print(f" Phase A2 (original): {len(phase_a2_items)} generation(s) ← matches A ✓") + print(f" Phase B2 (perturbed): {len(phase_b2_items)} generation(s) ← matches B ✓") + + return phase_a_text, phase_b_text + + +def analyze_and_verify_pattern_multi(per_server_generations): + """Verify A→B→A→B pattern independently per server, then check consistency. + + Each server's generation history is checked independently (since weight + updates are not coordinated with requests, servers can transiently disagree). + After all pass, we assert that every server converged on the same text A + and text B. + + Args: + per_server_generations: List of per-server generation lists, each a + list of (timestamp, text) tuples (as returned by + run_generation_loop_multi). + + Raises: + AssertionError: If any server fails its pattern check or servers + disagree on text A / text B. + """ + patterns = [] + for i, generations in enumerate(per_server_generations): + print(f"\n{'=' * 60}") + print(f"Actor {i} pattern analysis") + text_a, text_b = analyze_and_verify_pattern(generations) + patterns.append((text_a, text_b)) + + unique_a = set(t_a for t_a, _ in patterns) + unique_b = set(t_b for _, t_b in patterns) + assert len(unique_a) == 1, ( + f"Servers disagree on text A (original weights): " + f"{[t_a[:40] for t_a, _ in patterns]}" + ) + assert len(unique_b) == 1, ( + f"Servers disagree on text B (perturbed weights): " + f"{[t_b[:40] for _, t_b in patterns]}" + ) + print(f"\n✓ All {len(patterns)} actor(s) agree on text A and text B") + + +def extract_transition_phases(generations, text_a, text_b): + """Return phases that are neither text_a nor text_b. + + These are mid-broadcast 'garbage' generations produced while an NCCL + weight update was in flight and the model had partially updated weights. + + Args: + generations: List of (timestamp, text) tuples + text_a: The original-weights text (established first) + text_b: The perturbed-weights text + + Returns: + List of (text, items) phase tuples where text is neither text_a nor text_b + """ + phases = _build_phases(generations) + return [(text, items) for text, items in phases if text != text_a and text != text_b] + + +def analyze_and_verify_transitions(generations, n_cycles): + """Verify A→B→A→B pattern and assert that transition generations were caught. + + The ``rapid_broadcast_cycles`` trainer command performs: + - 1 startup A phase (server starts on original weights) + - 1 slow perturbed broadcast → text_B + - 1 slow original broadcast → text_A + - n_cycles rapid pairs → text_B, text_A each cycle + - 1 final slow perturbed → text_B + + This gives exactly ``4 + 2 * n_cycles`` stable phases. 
Seeing fewer + means a broadcast was missed entirely (timing/sync bug). + + Args: + generations: List of (timestamp, text) tuples + n_cycles: Number of rapid broadcast pairs (passed as ``--n-cycles`` + to the trainer helper). + + Raises: + AssertionError: If ABAB pattern is not found, stable phase count is + wrong, or no transition generations were caught. + """ + text_a, text_b = analyze_and_verify_pattern(generations) + + phases = _build_phases(generations) + stable_phases = [(text, items) for text, items in phases if text == text_a or text == text_b] + expected_stable = 4 + 2 * n_cycles + assert len(stable_phases) == expected_stable, ( + f"Expected {expected_stable} stable phases (4 + 2×{n_cycles} cycles) " + f"but found {len(stable_phases)}. " + f"A broadcast may have been missed or merged. " + f"Stable phase counts: {[len(items) for _, items in stable_phases]}" + ) + print(f"\n✓ Stable phase count correct: {len(stable_phases)} (expected {expected_stable})") + + transition_phases = extract_transition_phases(generations, text_a, text_b) + + print("\n" + "=" * 60) + print(f"TRANSITION / GARBAGE GENERATIONS: {len(transition_phases)} phase(s)") + print("=" * 60) + if transition_phases: + for i, (text, items) in enumerate(transition_phases): + ts_start = items[0][0] + ts_end = items[-1][0] + print(f" [{i + 1}] {len(items)} gen(s) @ {ts_start:.2f}s–{ts_end:.2f}s: '{text[:120]}'") + else: + print(" (none)") + + assert len(transition_phases) > 0, ( + "No transition generations were caught. " + "Try increasing --n-cycles or verify generation_interval=0.0 is set." + ) + print(f"\n✓ Caught {len(transition_phases)} transition phase(s)") + + +def start_vllm_server( + model_name: str, + server_port: int, + distributed_init_method: str, + stream_process_output_fn, + extra_args: list = None, + gpu_ids: str = "0", + actor_llm_idx: int = 0, + world_size: int = 2, + tensor_parallel_size: int = 1, +): + """Start vLLM HTTP server subprocess. 
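+
+    Launches ``pipelinerl/entrypoints/run_vllm1.py`` as a subprocess with
+    ``CUDA_VISIBLE_DEVICES`` pinned to ``gpu_ids`` and
+    ``VLLM_ALLOW_INSECURE_SERIALIZATION=1`` set in its environment.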
+ + Args: + model_name: Model to load + server_port: Port to bind to + distributed_init_method: Distributed initialization method + stream_process_output_fn: Function to stream process output + extra_args: Additional CLI arguments (e.g., ["--weight-update-mode", "fast-llm"]) + gpu_ids: CUDA_VISIBLE_DEVICES value (e.g., "0" or "0,1") + actor_llm_idx: Actor index for this vLLM instance + world_size: Total distributed world size + tensor_parallel_size: Number of GPUs for tensor parallelism + + Returns: + Tuple of (server_proc, stdout_thread, stderr_thread) + """ + vllm_env = os.environ.copy() + vllm_env["CUDA_VISIBLE_DEVICES"] = gpu_ids + vllm_env["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + + print(f"[Main] Starting vLLM HTTP server on port {server_port} (GPU(s) {gpu_ids}, actor_idx={actor_llm_idx}, TP={tensor_parallel_size})") + vllm_entry_point = Path(__file__).parent.parent / "pipelinerl" / "entrypoints" / "run_vllm1.py" + + cmd = [ + sys.executable, + str(vllm_entry_point), + "--model", model_name, + "--port", str(server_port), + "--host", "127.0.0.1", + "--actor-llm-idx", str(actor_llm_idx), + "--weight-update-group-init-method", distributed_init_method, + "--weight-update-group-world-size", str(world_size), + "--tensor-parallel-size", str(tensor_parallel_size), + ] + + if extra_args: + cmd.extend(extra_args) + + server_proc = subprocess.Popen( + cmd, + env=vllm_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + print("[Main] Starting server output streaming...") + stdout_thread, stderr_thread = stream_process_output_fn(server_proc, f"vLLM Server (actor {actor_llm_idx})") + + return server_proc, stdout_thread, stderr_thread + + +async def wait_for_all_servers_ready( + server_urls: list, + server_procs: list, + trainer_proc, + timeout_seconds: int = 300, +): + """Wait for all servers to be ready by polling their health endpoints. + + Args: + server_urls: List of server base URLs + server_procs: List of server subprocesses (same order as server_urls) + trainer_proc: Trainer subprocess + timeout_seconds: Maximum time to wait per server + + Returns: + True if all servers are ready + + Raises: + RuntimeError: If any process terminates unexpectedly + TimeoutError: If any server doesn't become ready within timeout + """ + for url, proc in zip(server_urls, server_procs): + await wait_for_server_ready(url, proc, trainer_proc, timeout_seconds) + return True + + +async def run_generation_loop_multi( + server_urls: list, + model_name: str, + simple_prompt: str, + generation_config: dict, + trainer_proc, + max_duration: int = 120, + generation_interval: float = 0.5, +): + """Run continuous generation loop querying all servers each round. + + Each server is tracked independently because weight updates and requests + are not coordinated — different actors can temporarily return different + results while a broadcast is in flight. Pattern checking is therefore + done per-server after the loop (see analyze_and_verify_pattern_multi). + + Args: + server_urls: List of server base URLs + model_name: Model name for API request + simple_prompt: Prompt to generate from + generation_config: Config dict with max_tokens, etc. + trainer_proc: Trainer subprocess to monitor + max_duration: Maximum duration in seconds + generation_interval: Time between generation rounds + + Returns: + List of per-server generation lists, each a list of + (timestamp, generated_text) tuples (same order as server_urls). 
+ """ + print(f"[Main] Starting continuous generation loop across {len(server_urls)} server(s)...") + per_server = [[] for _ in server_urls] + start_time = time.time() + + payload = { + "model": model_name, + "prompt": simple_prompt, + "max_tokens": generation_config["max_tokens"], + "temperature": 0.0, + "top_p": 1.0, + "seed": 42, + } + + while time.time() - start_time < max_duration: + # Check if trainer is still running + trainer_poll = trainer_proc.poll() + if trainer_poll is not None: + print(f"[Main] Trainer exited with code {trainer_poll}") + break + + for i, url in enumerate(server_urls): + try: + resp = requests.post( + f"{url}/v1/completions", + json=payload, + timeout=30, + ) + if resp.status_code == 200: + text = resp.json()["choices"][0]["text"] + timestamp = time.time() - start_time + per_server[i].append((timestamp, text)) + print(f"[Main] [{timestamp:.1f}s] Actor {i}: '{text}'") + else: + print(f"[Main] Generation from actor {i} ({url}) failed with status {resp.status_code}") + except requests.exceptions.RequestException as e: + print(f"[Main] Request to actor {i} ({url}) failed: {e}") + + await asyncio.sleep(generation_interval) + + return per_server + + +def start_trainer_process( + trainer_helper_path: Path, + distributed_init_method: str, + model_name: str, + server_urls: list, + stream_process_output_fn, + extra_args: list = None, + gpu_id: str = "1", + world_size: int = 2, + command: str = "timed_broadcast_server_test", +): + """Start trainer subprocess. + + Args: + trainer_helper_path: Path to trainer helper script + distributed_init_method: Distributed initialization method + model_name: Model name + server_urls: List of server URLs (one per actor) + stream_process_output_fn: Function to stream process output + extra_args: Additional CLI arguments (e.g., ["--n-cycles", "6"]) + gpu_id: CUDA_VISIBLE_DEVICES value for the trainer GPU + world_size: Total distributed world size + command: Positional command for distributed_trainer_helper.py + (ignored for fast_llm_trainer_helper.py which uses --init-method style) + + Returns: + Tuple of (trainer_proc, stdout_thread, stderr_thread) + """ + trainer_env = os.environ.copy() + trainer_env["CUDA_VISIBLE_DEVICES"] = gpu_id + + print(f"[Main] Starting trainer process (GPU {gpu_id}) for process group rendezvous") + + cmd = [ + sys.executable, + str(trainer_helper_path), + ] + + # Check which trainer helper is being used by the script name + if "fast_llm" in str(trainer_helper_path): + # fast_llm_trainer_helper.py uses argparse with --init-method, --model, etc. 
+ cmd.extend([ + "--init-method", distributed_init_method, + "--model", model_name, + "--world-size", str(world_size), + "--server-urls", + ] + list(server_urls)) + else: + # distributed_trainer_helper.py uses positional command + flags + cmd.extend([ + command, + "--init-method", distributed_init_method, + "--model-name", model_name, + "--world-size", str(world_size), + "--server-urls", + ] + list(server_urls)) + + if extra_args: + cmd.extend(extra_args) + + trainer_proc = subprocess.Popen( + cmd, + env=trainer_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + print("[Main] Starting trainer output streaming...") + stdout_thread, stderr_thread = stream_process_output_fn(trainer_proc, "Trainer") + + return trainer_proc, stdout_thread, stderr_thread diff --git a/tests/sync_helper.py b/tests/sync_helper.py new file mode 100644 index 00000000..73ebed82 --- /dev/null +++ b/tests/sync_helper.py @@ -0,0 +1,110 @@ +"""Simple file-based synchronization for distributed test processes.""" + +import time +from pathlib import Path + + +class SyncPoint: + """File-based synchronization point for coordinating subprocesses.""" + + def __init__(self, sync_dir: Path, name: str): + """Create a sync point. + + Args: + sync_dir: Directory for sync files + name: Name of this sync point (e.g., "baseline_done") + """ + self.sync_file = sync_dir / f"{name}.sync" + self.sync_dir = sync_dir + + def signal(self): + """Signal that this point is reached.""" + self.sync_file.touch() + # Force filesystem sync to ensure file is visible immediately + import os + fd = os.open(str(self.sync_file.parent), os.O_RDONLY) + os.fsync(fd) + os.close(fd) + print(f"[Sync] Signaled: {self.sync_file.name}") + + def wait(self, timeout: float = 60): + """Wait for this point to be signaled. + + Args: + timeout: Maximum time to wait in seconds + + Raises: + TimeoutError: If sync point not reached within timeout + """ + start = time.time() + while not self.sync_file.exists(): + if time.time() - start > timeout: + raise TimeoutError( + f"Timeout waiting for sync point: {self.sync_file.name}" + ) + time.sleep(0.1) + print(f"[Sync] Reached: {self.sync_file.name}") + + def clear(self): + """Clear this sync point.""" + if self.sync_file.exists(): + self.sync_file.unlink() + + +def create_sync_dir(base_dir: Path) -> Path: + """Create a directory for sync files. + + Args: + base_dir: Base temporary directory + + Returns: + Path to sync directory + """ + sync_dir = base_dir / "sync" + sync_dir.mkdir(exist_ok=True) + return sync_dir + + +def write_weight_update_request(sync_dir: Path, request): + """Write WeightUpdateRequest to JSON file. + + Args: + sync_dir: Sync directory + request: WeightUpdateRequest object + """ + import json + + request_file = sync_dir / "weight_update_request.json" + with open(request_file, "w") as f: + json.dump(request.model_dump(), f) + print(f"[Sync] Wrote weight update request to {request_file.name}") + + +def read_weight_update_request(sync_dir: Path): + """Read WeightUpdateRequest from JSON file. 
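+
+    Blocks for up to 60 seconds waiting for the request file to appear
+    before reading it.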
+ + Args: + sync_dir: Sync directory + + Returns: + WeightUpdateRequest object + """ + import json + from pipelinerl.finetune_loop import WeightUpdateRequest + + request_file = sync_dir / "weight_update_request.json" + + # Wait for file to exist + import time + timeout = 60 + start = time.time() + while not request_file.exists(): + if time.time() - start > timeout: + raise TimeoutError(f"Timeout waiting for {request_file.name}") + time.sleep(0.1) + + with open(request_file, "r") as f: + data = json.load(f) + + print(f"[Sync] Read weight update request from {request_file.name}") + return WeightUpdateRequest(**data) diff --git a/tests/test_actor_error_handling.py b/tests/test_actor_error_handling.py new file mode 100644 index 00000000..61adc5eb --- /dev/null +++ b/tests/test_actor_error_handling.py @@ -0,0 +1,290 @@ +"""Test that actor rollout error handling doesn't crash the entire actor. + +Specifically tests that: +1. HTTP 4xx errors from vLLM (e.g., max_tokens too large) are handled gracefully +2. Groups where ALL rollouts fail are dropped (not submitted) +3. Groups where SOME rollouts fail submit only valid results +4. HTTP 5xx errors still propagate as fatal +""" + +import asyncio +import queue +from unittest.mock import MagicMock, AsyncMock, patch + +import aiohttp +import pytest +from omegaconf import OmegaConf + +from pipelinerl.rollouts import BaseMetrics, RolloutResult, TrainingText + + +# --------------------------------------------------------------------------- +# Helpers – lightweight stand-ins for heavy classes used by schedule_rollouts +# --------------------------------------------------------------------------- + +class FakeQueue: + """Minimal stand-in for SharedMemoryQueue (no shared memory needed).""" + + def __init__(self): + self._q = queue.Queue() + + def put(self, item, block=True, timeout=None): + self._q.put(item) + + def get(self, block=True, timeout=None): + return self._q.get(block=block, timeout=timeout) + + def qsize(self): + return self._q.qsize() + + def max_actual_entry_size(self): + return 0 + + def get_memory_size(self): + return 0 + + +class FakeTrainerState: + def __init__(self): + self.propagated_weight_version = 1 + self.samples_processed = 0 + + +def make_good_result() -> RolloutResult: + """A valid rollout result with one training sample.""" + return RolloutResult( + training_texts=[ + TrainingText( + text="prompt output", + n_predicted=6, + reward=1.0, + input_ids=[1, 2, 3], + labels=[-100, 2, 3], + finished=True, + prompt_tokens=5, + output_tokens=6, + ) + ], + metrics=BaseMetrics(reward=1.0, success=True, no_error=True, no_answer=False), + latency=0.5, + ) + + +def make_client_response_error(status: int, message: str = "Bad Request"): + """Create an aiohttp.ClientResponseError.""" + mock_req = MagicMock() + mock_req.url = "http://localhost:8080/v1/chat/completions" + return aiohttp.ClientResponseError( + request_info=mock_req, + history=(), + status=status, + message=message, + ) + + +# --------------------------------------------------------------------------- +# Core test: exercise rollout_and_maybe_produce_result + group completion +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_all_rollouts_fail_group_dropped(): + """When all rollouts in a group fail with 4xx, the group should be dropped.""" + attempts = 4 + problem_q = FakeQueue() + result_q = FakeQueue() + trainer_state = FakeTrainerState() + + # Put one problem in the queue + problem_q.put({"task": "What is 2+2?", 
"answer": "4"}) + + call_count = 0 + + async def failing_rollout_policy(cfg, llm, problem, session): + nonlocal call_count + call_count += 1 + raise make_client_response_error(400, "max_tokens too large") + + cfg = OmegaConf.create({ + "actor": { + "rollout_policy": "not_used", # we patch it + "llm_max_rollouts": 64, + }, + "finetune": { + "train_batch_size": 1000, + "gradient_accumulation_passes": 1, + "train_iters": 100, + "interrupt_train_steps": None, + }, + "debug": {}, + }) + + llms = [MagicMock()] # 1 LLM + + # We can't easily run schedule_rollouts (too many dependencies), + # so we directly test the inner logic by reimplementing the key parts. + # This mirrors rollout_and_maybe_produce_result + group completion. + + group_rollouts = {} + group_id = 0 + group_rollouts[group_id] = [] + finished_rollouts = 0 + warnings_logged = [] + + for rollout_index in range(attempts): + try: + rollout_result = await failing_rollout_policy(cfg, llms[0], {"task": "x"}, None) + except aiohttp.ClientResponseError as e: + if 400 <= e.status < 500: + warnings_logged.append(str(e.status)) + rollout_result = RolloutResult( + training_texts=[], + metrics=BaseMetrics(reward=0.0, success=False, no_error=False, no_answer=True), + latency=0.0, + ) + else: + raise + + rollout_result.model_version = 1 + rollout_result.group_id = f"test_{group_id}" + group_rollouts[group_id].append(rollout_result) + + # Now check group completion logic + assert len(group_rollouts[group_id]) == attempts + valid_results = [r for r in group_rollouts[group_id] if r.training_texts] + + # All failed → group should be dropped + assert len(valid_results) == 0, "Expected all results to be empty" + assert call_count == attempts + assert len(warnings_logged) == attempts + + # In real code: del group_rollouts[group_id], don't put in result_q + del group_rollouts[group_id] + assert result_q.qsize() == 0, "No group should be in the result queue" + + +@pytest.mark.asyncio +async def test_partial_failure_submits_valid_only(): + """When some rollouts fail but others succeed, submit only valid ones.""" + attempts = 4 + result_q = FakeQueue() + + call_count = 0 + + async def mixed_rollout_policy(cfg, llm, problem, session): + nonlocal call_count + call_count += 1 + # First 2 calls fail, last 2 succeed + if call_count <= 2: + raise make_client_response_error(400, "max_tokens too large") + return make_good_result() + + group_rollouts = {} + group_id = 0 + group_rollouts[group_id] = [] + + for rollout_index in range(attempts): + try: + rollout_result = await mixed_rollout_policy(None, None, {"task": "x"}, None) + except aiohttp.ClientResponseError as e: + if 400 <= e.status < 500: + rollout_result = RolloutResult( + training_texts=[], + metrics=BaseMetrics(reward=0.0, success=False, no_error=False, no_answer=True), + latency=0.0, + ) + else: + raise + + rollout_result.model_version = 1 + rollout_result.group_id = f"test_{group_id}" + group_rollouts[group_id].append(rollout_result) + + assert len(group_rollouts[group_id]) == attempts + + valid_results = [r for r in group_rollouts[group_id] if r.training_texts] + + # 2 failed, 2 succeeded + assert len(valid_results) == 2, f"Expected 2 valid results, got {len(valid_results)}" + + # In real code: result_queue.put(valid_results) + result_q.put(valid_results) + got = result_q.get(block=False) + assert len(got) == 2 + assert all(len(r.training_texts) > 0 for r in got) + + +@pytest.mark.asyncio +async def test_5xx_errors_still_propagate(): + """HTTP 5xx errors should NOT be caught — they indicate server 
failure.""" + + async def server_error_policy(cfg, llm, problem, session): + raise make_client_response_error(500, "Internal Server Error") + + with pytest.raises(aiohttp.ClientResponseError) as exc_info: + try: + await server_error_policy(None, None, {"task": "x"}, None) + except aiohttp.ClientResponseError as e: + if 400 <= e.status < 500: + pass # Would be caught in real code + else: + raise # 5xx re-raised + + assert exc_info.value.status == 500 + + +@pytest.mark.asyncio +async def test_all_succeed_normal_path(): + """When all rollouts succeed, the full group is submitted.""" + attempts = 4 + result_q = FakeQueue() + + async def good_policy(cfg, llm, problem, session): + return make_good_result() + + group_rollouts = {} + group_id = 0 + group_rollouts[group_id] = [] + + for rollout_index in range(attempts): + try: + rollout_result = await good_policy(None, None, {"task": "x"}, None) + except aiohttp.ClientResponseError as e: + if 400 <= e.status < 500: + rollout_result = RolloutResult( + training_texts=[], + metrics=BaseMetrics(reward=0.0, success=False, no_error=False, no_answer=True), + latency=0.0, + ) + else: + raise + + rollout_result.model_version = 1 + rollout_result.group_id = f"test_{group_id}" + group_rollouts[group_id].append(rollout_result) + + valid_results = [r for r in group_rollouts[group_id] if r.training_texts] + assert len(valid_results) == attempts, "All rollouts should be valid" + + result_q.put(valid_results) + got = result_q.get(block=False) + assert len(got) == attempts + + +@pytest.mark.asyncio +async def test_consumer_assertion_accepts_partial_group(): + """The consumer-side assertion should accept groups with fewer than `attempts` results.""" + attempts = 8 + # Simulate a partial group with 5 valid results + partial_count = 5 + + results = [make_good_result() for _ in range(partial_count)] + + # This mirrors the relaxed assertion in actor.py + assert isinstance(results, list) + assert isinstance(results[0], RolloutResult) + assert 0 < len(results) <= attempts, ( + f"Expected 1-{attempts} rollouts, got {len(results)}" + ) + + group_samples = sum(len(r.training_texts) for r in results) + assert group_samples == partial_count diff --git a/tests/test_vllm1_fast_llm_broadcast.py b/tests/test_vllm1_fast_llm_broadcast.py new file mode 100644 index 00000000..f7cc410b --- /dev/null +++ b/tests/test_vllm1_fast_llm_broadcast.py @@ -0,0 +1,590 @@ +"""Integration tests for vllm1 with Fast-LLM weight broadcast protocol.""" + +import asyncio +import pytest +import tempfile +from pathlib import Path +from typing import Dict, List +import time +import os +import subprocess +import sys +import signal + +# torch is needed at top level for pytest.mark.skipif decorators +import torch + +# Import shared utilities +from .server_weight_update_utils import ( + wait_for_server_ready, + wait_for_all_servers_ready, + run_generation_loop, + run_generation_loop_multi, + analyze_and_verify_pattern, + analyze_and_verify_pattern_multi, + analyze_and_verify_transitions, + start_vllm_server, + start_trainer_process, +) + +try: + import psutil + + HAS_PSUTIL = True +except ImportError: + HAS_PSUTIL = False + print("WARNING: psutil not available, process tree cleanup will be limited") + + +def stream_process_output(proc, name): + """Start background threads to continuously stream process stdout/stderr. 
+ + Args: + proc: subprocess.Popen object + name: Name for logging prefix (e.g., "vLLM Server", "Trainer") + + Returns: + Tuple of (stdout_thread, stderr_thread) + """ + import threading + + def read_stream(stream, prefix): + """Read from stream and print with prefix.""" + try: + for line in iter(stream.readline, ""): + if line: + print(f"{prefix} {line.rstrip()}", flush=True) + except Exception as e: + print(f"{prefix} [Stream read error: {e}]", flush=True) + + stdout_thread = threading.Thread( + target=read_stream, + args=(proc.stdout, f"[{name} OUT]"), + daemon=True, + ) + stderr_thread = threading.Thread( + target=read_stream, + args=(proc.stderr, f"[{name} ERR]"), + daemon=True, + ) + + stdout_thread.start() + stderr_thread.start() + + return stdout_thread, stderr_thread + + +def kill_process_tree(pid, sig=signal.SIGKILL): + """Kill a process and all its children/grandchildren. + + Args: + pid: Process ID to kill + sig: Signal to send (default SIGKILL) + """ + if not HAS_PSUTIL: + # Fallback: just kill the main process + try: + os.kill(pid, sig) + except ProcessLookupError: + pass + return + + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + return + + # Get all children recursively + children = parent.children(recursive=True) + + # Kill children first + for child in children: + try: + print(f"[Kill] Killing child process {child.pid}") + child.send_signal(sig) + except psutil.NoSuchProcess: + pass + + # Kill parent + try: + parent.send_signal(sig) + except psutil.NoSuchProcess: + pass + + +@pytest.fixture +def fast_llm_trainer_helper(): + """Path to Fast-LLM trainer helper script.""" + return Path(__file__).parent / "fast_llm_trainer_helper.py" + + +@pytest.fixture +def redis_server(): + """Start a Redis server for testing and stop it after the test. 
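+
+    The test is skipped if ``redis-server`` is not found on PATH. The server
+    is started on a dynamically chosen free port with persistence, AOF, and
+    protected mode disabled.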
+ + Returns: + Tuple of (host, port) for the Redis server + """ + import shutil + import socket + + # Check if redis-server is available + redis_server_bin = shutil.which("redis-server") + if not redis_server_bin: + pytest.skip("redis-server not found in PATH") + + # Find an available port + def find_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + s.listen(1) + port = s.getsockname()[1] + return port + + redis_port = find_free_port() + redis_host = "localhost" + + print(f"[Redis] Starting Redis server on {redis_host}:{redis_port}") + + # Start Redis server with minimal config + redis_proc = subprocess.Popen( + [ + redis_server_bin, + "--port", str(redis_port), + "--bind", redis_host, + "--save", "", # Disable persistence + "--appendonly", "no", # Disable AOF + "--protected-mode", "no", # Allow connections without password + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Start streaming Redis output + redis_stdout_thread, redis_stderr_thread = stream_process_output(redis_proc, "Redis") + + # Wait for Redis to be ready + import redis + r = redis.Redis(host=redis_host, port=redis_port) + for i in range(30): + try: + r.ping() + print(f"[Redis] Server ready on {redis_host}:{redis_port}") + break + except redis.ConnectionError: + if redis_proc.poll() is not None: + raise RuntimeError(f"Redis server failed to start (exit code {redis_proc.returncode})") + time.sleep(0.1) + else: + redis_proc.kill() + raise TimeoutError("Redis server did not start within 3 seconds") + + try: + yield (redis_host, redis_port) + finally: + # Cleanup + print(f"[Redis] Stopping Redis server (PID {redis_proc.pid})") + redis_proc.terminate() + try: + redis_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + print("[Redis] Redis did not stop gracefully, killing...") + redis_proc.kill() + redis_proc.wait() + print("[Redis] Redis server stopped") + + +# --------------------------------------------------------------------------- +# Module-level helper shared by all Fast-LLM test variants +# --------------------------------------------------------------------------- + +async def _run_fast_llm_server_test( + model_name, + simple_prompt, + generation_config, + init_method, + fast_llm_trainer_helper, + redis_host, + redis_port, + vllm_server_configs, + trainer_gpu, + world_size, + timeout=2400, +): + """Run Fast-LLM server weight-update pattern test with one or more vLLM servers. + + Args: + vllm_server_configs: List of dicts, each with keys: + - port: int + - gpu_ids: str + - actor_llm_idx: int + - tensor_parallel_size: int + trainer_gpu: str, e.g. 
"1" or "2" + world_size: total NCCL world size (trainer + all vLLM workers) + redis_host: Redis host address + redis_port: Redis port number + """ + server_procs = [] + server_urls = [] + + fast_llm_server_args = [ + "--weight-update-mode", "fast-llm", + "--redis-host", redis_host, + "--redis-port", str(redis_port), + ] + + for cfg in vllm_server_configs: + port = cfg["port"] + url = f"http://127.0.0.1:{port}" + server_urls.append(url) + + server_proc, _, _ = start_vllm_server( + model_name=model_name, + server_port=port, + distributed_init_method=init_method, + stream_process_output_fn=stream_process_output, + extra_args=fast_llm_server_args, + gpu_ids=cfg.get("gpu_ids", "0"), + actor_llm_idx=cfg.get("actor_llm_idx", 0), + world_size=world_size, + tensor_parallel_size=cfg.get("tensor_parallel_size", 1), + ) + server_procs.append(server_proc) + + await asyncio.sleep(1) + + trainer_proc, _, _ = start_trainer_process( + trainer_helper_path=fast_llm_trainer_helper, + distributed_init_method=init_method, + model_name=model_name, + server_urls=server_urls, + stream_process_output_fn=stream_process_output, + extra_args=[ + "--redis-host", redis_host, + "--redis-port", str(redis_port), + ], + gpu_id=trainer_gpu, + world_size=world_size, + ) + + try: + await wait_for_all_servers_ready(server_urls, server_procs, trainer_proc) + + if len(server_urls) == 1: + generations = await run_generation_loop( + server_url=server_urls[0], + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + trainer_proc=trainer_proc, + ) + else: + per_server_generations = await run_generation_loop_multi( + server_urls=server_urls, + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + trainer_proc=trainer_proc, + ) + + # Wait for trainer to finish + print("[Main] Waiting for trainer to finish...") + for _ in range(30): + if trainer_proc.poll() is not None: + break + await asyncio.sleep(1) + + if len(server_urls) == 1: + analyze_and_verify_pattern(generations) + else: + analyze_and_verify_pattern_multi(per_server_generations) + print(f"\n✓ Fast-LLM server weight update pattern test PASSED ({len(server_urls)} server(s))") + + finally: + print("[Main] Cleaning up processes...") + for proc in server_procs: + if proc: + kill_process_tree(proc.pid) + if trainer_proc: + kill_process_tree(trainer_proc.pid) + + +class TestFastLLMServerIntegration: + """Test Fast-LLM weight broadcast with vLLM HTTP server — 2 GPUs (baseline).""" + + @pytest.mark.timeout(2400) # 40 minutes for server test + @pytest.mark.asyncio + @pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs" + ) + async def test_server_fast_llm_broadcast_pattern( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + fast_llm_trainer_helper, + redis_server, + temp_dir, + ): + """Server integration test: verify Fast-LLM weight broadcast pattern with HTTP API. + + Validates the Fast-LLM protocol where: + - Redis server signals weight updates + - vLLM server receives weights via broadcast_object_list + broadcast + - Server responses change as expected (original → perturbed → original → perturbed) + + Topology: 1 vLLM server on GPU 0, trainer on GPU 1 (world_size=2). 
+ """ + print("\n" + "=" * 60) + print("Starting Fast-LLM server weight update pattern test (TP=1, 1 actor, 2 GPUs)") + print("=" * 60) + + redis_host, redis_port = redis_server + + await _run_fast_llm_server_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + fast_llm_trainer_helper=fast_llm_trainer_helper, + redis_host=redis_host, + redis_port=redis_port, + vllm_server_configs=[{"port": 8000, "gpu_ids": "0", "actor_llm_idx": 0, "tensor_parallel_size": 1}], + trainer_gpu="1", + world_size=2, + timeout=2400, + ) + + @pytest.mark.timeout(2400) + @pytest.mark.asyncio + @pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs" + ) + async def test_fast_llm_server_catch_transitions( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + fast_llm_trainer_helper, + redis_server, + temp_dir, + ): + """Diagnostic test: catch garbage generations during Fast-LLM weight broadcasts. + + The trainer runs a slow initial cycle (perturbed → original, 5 s each) + to firmly establish text_A and text_B, then fires N rapid back-to-back + broadcast cycles (perturbed → original) with no inter-broadcast delay. + The generation loop runs with generation_interval=0.0 to maximise the + chance of hitting a mid-broadcast state. + + Assertions: + 1. The A→B→A→B pattern is still detected (broadcasts actually worked). + 2. At least one transition/garbage phase was captured. + + Topology: 1 vLLM server on GPU 0, trainer on GPU 1 (world_size=2). + """ + print("\n" + "=" * 60) + print("Starting Fast-LLM transition-capture test (TP=1, 1 actor, 2 GPUs)") + print("=" * 60) + + redis_host, redis_port = redis_server + server_url = "http://127.0.0.1:8000" + + server_proc, _, _ = start_vllm_server( + model_name=model_name, + server_port=8000, + distributed_init_method=distributed_init_method, + stream_process_output_fn=stream_process_output, + extra_args=[ + "--weight-update-mode", "fast-llm", + "--redis-host", redis_host, + "--redis-port", str(redis_port), + ], + gpu_ids="0", + actor_llm_idx=0, + world_size=2, + tensor_parallel_size=1, + ) + + await asyncio.sleep(1) + + trainer_proc, _, _ = start_trainer_process( + trainer_helper_path=fast_llm_trainer_helper, + distributed_init_method=distributed_init_method, + model_name=model_name, + server_urls=[server_url], + stream_process_output_fn=stream_process_output, + extra_args=[ + "--redis-host", redis_host, + "--redis-port", str(redis_port), + "--n-cycles", "6", + ], + gpu_id="1", + world_size=2, + ) + + try: + await wait_for_server_ready(server_url, server_proc, trainer_proc) + + generations = await run_generation_loop( + server_url=server_url, + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + trainer_proc=trainer_proc, + generation_interval=0.0, + ) + + print("[Main] Waiting for trainer to finish...") + for _ in range(30): + if trainer_proc.poll() is not None: + break + await asyncio.sleep(1) + + analyze_and_verify_transitions(generations, n_cycles=6) + print("\n✓ Fast-LLM transition-capture test PASSED") + + finally: + print("[Main] Cleaning up processes...") + if server_proc: + kill_process_tree(server_proc.pid) + if trainer_proc: + kill_process_tree(trainer_proc.pid) + + +class TestFastLLMServerTP2: + """Test Fast-LLM weight broadcast with tensor-parallel (TP=2) — needs 3 GPUs.""" + + @pytest.mark.timeout(2400) + @pytest.mark.asyncio + @pytest.mark.skipif( + 
torch.cuda.device_count() < 3, reason="Requires at least 3 GPUs" + ) + async def test_server_fast_llm_broadcast_pattern_tp2( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + fast_llm_trainer_helper, + redis_server, + temp_dir, + ): + """Fast-LLM server test with TP=2: one server on GPUs 0+1, trainer on GPU 2. + + Verifies that tensor-parallel vLLM correctly receives Fast-LLM weight + updates when multiple GPU workers share the same NCCL process group. + """ + print("\n" + "=" * 60) + print("Starting Fast-LLM server weight update pattern test (TP=2, 1 actor, 3 GPUs)") + print("=" * 60) + + redis_host, redis_port = redis_server + + await _run_fast_llm_server_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + fast_llm_trainer_helper=fast_llm_trainer_helper, + redis_host=redis_host, + redis_port=redis_port, + vllm_server_configs=[{"port": 8001, "gpu_ids": "0,1", "actor_llm_idx": 0, "tensor_parallel_size": 2}], + trainer_gpu="2", + world_size=3, + timeout=2400, + ) + + +class TestFastLLMServerMultiActor: + """Test Fast-LLM weight broadcast with multiple independent vLLM actors.""" + + @pytest.mark.timeout(2400) + @pytest.mark.asyncio + @pytest.mark.skipif( + torch.cuda.device_count() < 3, reason="Requires at least 3 GPUs" + ) + async def test_server_fast_llm_broadcast_pattern_2actors( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + fast_llm_trainer_helper, + redis_server, + temp_dir, + ): + """Fast-LLM server test with 2 actors: servers on GPUs 0 and 1, trainer on GPU 2. + + Verifies that two separate vLLM servers simultaneously receive the same + Fast-LLM weight broadcast and produce identical generation results. + """ + print("\n" + "=" * 60) + print("Starting Fast-LLM server weight update pattern test (TP=1, 2 actors, 3 GPUs)") + print("=" * 60) + + redis_host, redis_port = redis_server + + await _run_fast_llm_server_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + fast_llm_trainer_helper=fast_llm_trainer_helper, + redis_host=redis_host, + redis_port=redis_port, + vllm_server_configs=[ + {"port": 8000, "gpu_ids": "0", "actor_llm_idx": 0, "tensor_parallel_size": 1}, + {"port": 8001, "gpu_ids": "1", "actor_llm_idx": 1, "tensor_parallel_size": 1}, + ], + trainer_gpu="2", + world_size=3, + timeout=2400, + ) + + @pytest.mark.timeout(2400) + @pytest.mark.asyncio + @pytest.mark.skipif( + torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs" + ) + async def test_server_fast_llm_broadcast_pattern_3actors( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + fast_llm_trainer_helper, + redis_server, + temp_dir, + ): + """Fast-LLM server test with 3 actors: servers on GPUs 0/1/2, trainer on GPU 3. + + Verifies that three separate vLLM servers simultaneously receive the same + Fast-LLM weight broadcast and produce identical generation results. 
+ """ + print("\n" + "=" * 60) + print("Starting Fast-LLM server weight update pattern test (TP=1, 3 actors, 4 GPUs)") + print("=" * 60) + + redis_host, redis_port = redis_server + + await _run_fast_llm_server_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + fast_llm_trainer_helper=fast_llm_trainer_helper, + redis_host=redis_host, + redis_port=redis_port, + vllm_server_configs=[ + {"port": 8000, "gpu_ids": "0", "actor_llm_idx": 0, "tensor_parallel_size": 1}, + {"port": 8001, "gpu_ids": "1", "actor_llm_idx": 1, "tensor_parallel_size": 1}, + {"port": 8002, "gpu_ids": "2", "actor_llm_idx": 2, "tensor_parallel_size": 1}, + ], + trainer_gpu="3", + world_size=4, + timeout=2400, + ) diff --git a/tests/test_vllm1_integration.py b/tests/test_vllm1_integration.py new file mode 100644 index 00000000..6c2d68f0 --- /dev/null +++ b/tests/test_vllm1_integration.py @@ -0,0 +1,1313 @@ +"""Integration tests for vllm1 with actual distributed setup.""" + +import asyncio +import pytest +import tempfile +from pathlib import Path +from typing import Dict, List +import time +import os +import subprocess +import sys +import signal + +# torch is needed at top level for pytest.mark.skipif decorators +import torch + +# Import shared utilities +from .server_weight_update_utils import ( + wait_for_server_ready, + wait_for_all_servers_ready, + run_generation_loop, + run_generation_loop_multi, + analyze_and_verify_pattern, + analyze_and_verify_pattern_multi, + analyze_and_verify_transitions, + start_vllm_server, + start_trainer_process, +) + +try: + import psutil + HAS_PSUTIL = True +except ImportError: + HAS_PSUTIL = False + print("WARNING: psutil not available, process tree cleanup will be limited") + + +def stream_process_output(proc, name): + """Start background threads to continuously stream process stdout/stderr. + + Args: + proc: subprocess.Popen object + name: Name for logging prefix (e.g., "vLLM Server", "Trainer") + + Returns: + Tuple of (stdout_thread, stderr_thread) + """ + import threading + + def read_stream(stream, prefix): + """Read from stream and print with prefix.""" + try: + for line in iter(stream.readline, ''): + if line: + print(f"{prefix} {line.rstrip()}", flush=True) + except Exception as e: + print(f"{prefix} [Stream read error: {e}]", flush=True) + + stdout_thread = threading.Thread( + target=read_stream, + args=(proc.stdout, f"[{name} OUT]"), + daemon=True, + ) + stderr_thread = threading.Thread( + target=read_stream, + args=(proc.stderr, f"[{name} ERR]"), + daemon=True, + ) + + stdout_thread.start() + stderr_thread.start() + + return stdout_thread, stderr_thread + + +def kill_process_tree(pid, sig=signal.SIGKILL): + """Kill a process and all its children/grandchildren. + + Args: + pid: Process ID to kill + sig: Signal to send (default SIGKILL) + """ + if not HAS_PSUTIL: + # Fallback: just kill the main process + try: + os.kill(pid, sig) + except ProcessLookupError: + pass + return + + try: + parent = psutil.Process(pid) + except psutil.NoSuchProcess: + return + + # Get all children recursively + children = parent.children(recursive=True) + + # Kill children first + for child in children: + try: + print(f"[Kill] Killing child process {child.pid}") + child.send_signal(sig) + except psutil.NoSuchProcess: + pass + + # Kill parent + try: + parent.send_signal(sig) + except psutil.NoSuchProcess: + pass + + +def force_kill_process(proc, name): + """Forcefully kill a process tree and collect output. 
+ + SIGKILL always kills the process. If communicate() hangs, it's the PIPES + that are stuck, not the process. We handle this with retries and timeouts. + + Returns: + Tuple of (stdout, stderr, returncode) + """ + # If already dead, try to get output + if proc.poll() is not None: + try: + stdout, stderr = proc.communicate(timeout=2) + return stdout, stderr, proc.returncode + except subprocess.TimeoutExpired: + print(f"[Kill] {name} already dead but pipes hung, closing...") + proc.stdout.close() if proc.stdout else None + proc.stderr.close() if proc.stderr else None + return "", "", proc.returncode + + # Kill entire process tree (including vLLM workers, trainer subprocesses, etc) + print(f"[Kill] Killing {name} process tree (PID {proc.pid})...") + kill_process_tree(proc.pid, signal.SIGKILL) + + # Wait for main process to actually die + try: + proc.wait(timeout=2) + print(f"[Kill] {name} process tree killed") + except subprocess.TimeoutExpired: + print(f"[Kill] WARNING: {name} didn't die after SIGKILL") + + # Try to read output from pipes (this is what usually hangs) + for attempt, timeout_val in enumerate([1, 2, 3], start=1): + try: + stdout, stderr = proc.communicate(timeout=timeout_val) + print(f"[Kill] {name} output collected (attempt {attempt})") + return stdout, stderr, proc.returncode + except subprocess.TimeoutExpired: + print(f"[Kill] {name} communicate() timed out (attempt {attempt})") + continue + + # Pipes are stuck - force close them + print(f"[Kill] {name} pipes stuck, force closing...") + try: + proc.stdout.close() if proc.stdout else None + proc.stderr.close() if proc.stderr else None + proc.stdin.close() if proc.stdin else None + except Exception as e: + print(f"[Kill] Error closing pipes: {e}") + + return "", "", proc.returncode if proc.returncode else -999 + + +async def wait_for_processes(processes_with_names, check_interval=0.5, timeout=60): + """Wait for multiple subprocesses to complete, printing output in real-time. 
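+
+    If any process exits with a non-zero code, the remaining processes are
+    killed and a RuntimeError is raised; on timeout all processes are killed
+    as well.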
+ + Args: + processes_with_names: List of (subprocess.Popen, name) tuples + check_interval: How often to check process status (seconds) + timeout: Maximum time to wait for all processes (seconds) + + Raises: + RuntimeError: If any process fails or timeout is reached + """ + start_time = time.time() + + # Create async readers for each process's stdout and stderr + async def read_stream(stream, prefix): + """Read from a stream line-by-line and print with prefix.""" + loop = asyncio.get_event_loop() + try: + while True: + line = await loop.run_in_executor(None, stream.readline) + if not line: + break + print(f"{prefix} {line.rstrip()}", flush=True) + except Exception as e: + print(f"{prefix} [Read error: {e}]", flush=True) + + # Start readers for all processes + reader_tasks = [] + for proc, name in processes_with_names: + reader_tasks.append(asyncio.create_task(read_stream(proc.stdout, f"[{name} OUT]"))) + reader_tasks.append(asyncio.create_task(read_stream(proc.stderr, f"[{name} ERR]"))) + + try: + while True: + # Check if timeout exceeded + if time.time() - start_time > timeout: + print(f"\n{'='*60}", flush=True) + print("TIMEOUT: Killing all processes", flush=True) + print(f"{'='*60}\n", flush=True) + + # Kill all processes forcefully + for proc, name in processes_with_names: + if proc.poll() is None: + print(f"[Main] Killing {name}...", flush=True) + kill_process_tree(proc.pid, signal.SIGKILL) + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + pass + + raise RuntimeError(f"Timeout after {timeout} seconds waiting for processes") + + # Check each process + crashed_proc = None + crashed_name = None + + for proc, name in processes_with_names: + returncode = proc.poll() + if returncode is not None and returncode != 0: + crashed_proc = proc + crashed_name = name + print(f"\n{'='*60}", flush=True) + print(f"{name} process CRASHED with exit code {returncode}", flush=True) + print(f"{'='*60}\n", flush=True) + break + + # If a process crashed, kill the others + if crashed_proc is not None: + # Kill all other processes + for proc, name in processes_with_names: + if proc != crashed_proc and proc.poll() is None: + print(f"[Main] Killing {name}...", flush=True) + kill_process_tree(proc.pid, signal.SIGKILL) + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + pass + + raise RuntimeError( + f"{crashed_name} process failed with exit code {crashed_proc.returncode}" + ) + + # Check if all processes completed successfully + all_done = all(proc.poll() is not None for proc, _ in processes_with_names) + if all_done: + # Wait for readers to finish draining pipes + print("[Main] All processes completed, waiting for output to finish...", flush=True) + await asyncio.sleep(1) # Give readers time to finish + + print(f"\n{'='*60}", flush=True) + print("✓ All processes completed successfully", flush=True) + print(f"{'='*60}\n", flush=True) + return + + # Sleep before next check + await asyncio.sleep(check_interval) + finally: + # Cancel reader tasks + for task in reader_tasks: + if not task.done(): + task.cancel() + # Wait for cancellation + await asyncio.gather(*reader_tasks, return_exceptions=True) + + +# --------------------------------------------------------------------------- +# Module-level helpers shared by all test variants +# --------------------------------------------------------------------------- + +def _compare_actor_results(sync_dir: Path, num_actors: int): + """Assert that all actors produced identical generation results. 
+ + Each actor writes ``sync_dir/results_actor_{i}.json`` with keys + res_or_1, res_mod_1, res_or_2, res_mod_2. + """ + import json + + results = [ + json.loads((sync_dir / f"results_actor_{i}.json").read_text()) + for i in range(num_actors) + ] + for key in results[0]: + texts = [r[key] for r in results] + assert len(set(texts)) == 1, ( + f"Actors disagree on '{key}': {texts}" + ) + + +async def _run_back_and_forth_engine_test( + model_name, + simple_prompt, + generation_config, + init_method, + distributed_trainer_helper, + vllm_engine_helper, + sync_dir, + vllm_configs, + trainer_gpu, + world_size, + timeout=1800, +): + """Run back-and-forth engine test with one or more vLLM actor processes. + + Args: + vllm_configs: List of dicts, each with keys: + - cuda_devices: str, e.g. "0" or "0,1" + - actor_llm_idx: int + - tensor_parallel_size: int + trainer_gpu: str, e.g. "1" or "2" + world_size: total NCCL world size (all vLLM workers + trainer) + """ + from .sync_helper import create_sync_dir + + num_actors = len(vllm_configs) + all_procs = [] + + # Start all vLLM actor subprocesses + for cfg in vllm_configs: + vllm_env = os.environ.copy() + vllm_env["CUDA_VISIBLE_DEVICES"] = cfg["cuda_devices"] + vllm_env["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + vllm_env["PIPELINERL_DEBUG"] = "1" + + actor_idx = cfg["actor_llm_idx"] + tp = cfg.get("tensor_parallel_size", 1) + print(f"[Main] Starting vLLM actor {actor_idx} (GPU(s) {cfg['cuda_devices']}, TP={tp})") + + vllm_proc = subprocess.Popen( + [ + sys.executable, + str(vllm_engine_helper), + "back_and_forth", + "--model-name", model_name, + "--init-method", init_method, + "--actor-llm-idx", str(actor_idx), + "--world-size", str(world_size), + "--tensor-parallel-size", str(tp), + "--prompt", simple_prompt, + "--max-tokens", str(generation_config["max_tokens"]), + "--sync-dir", str(sync_dir), + ], + env=vllm_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + all_procs.append((vllm_proc, f"vLLM Actor {actor_idx}")) + + await asyncio.sleep(1) + + # Start trainer subprocess + trainer_env = os.environ.copy() + trainer_env["CUDA_VISIBLE_DEVICES"] = trainer_gpu + trainer_env["PIPELINERL_DEBUG"] = "1" + + print(f"[Main] Starting trainer (GPU {trainer_gpu}, {num_actors} actor(s), world_size={world_size})") + trainer_proc = subprocess.Popen( + [ + sys.executable, + str(distributed_trainer_helper), + "back_and_forth", + "--init-method", init_method, + "--model-name", model_name, + "--sync-dir", str(sync_dir), + "--num-actors", str(num_actors), + "--world-size", str(world_size), + ], + env=trainer_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + all_procs.append((trainer_proc, "Trainer")) + + await wait_for_processes(all_procs, timeout=timeout) + + # Verify all actors produced the same results + _compare_actor_results(sync_dir, num_actors) + print(f"\n✓ Back-and-forth test PASSED ({num_actors} actor(s), world_size={world_size})") + + +async def _run_server_weight_update_test( + model_name, + simple_prompt, + generation_config, + init_method, + distributed_trainer_helper, + vllm_server_configs, + trainer_gpu, + world_size, + timeout=2400, +): + """Run server weight-update pattern test with one or more vLLM servers. + + Args: + vllm_server_configs: List of dicts, each with keys: + - port: int + - gpu_ids: str + - actor_llm_idx: int + - tensor_parallel_size: int + trainer_gpu: str, e.g. 
"1" or "2" + world_size: total NCCL world size + """ + server_procs = [] + server_urls = [] + + for cfg in vllm_server_configs: + port = cfg["port"] + url = f"http://127.0.0.1:{port}" + server_urls.append(url) + + server_proc, _, _ = start_vllm_server( + model_name=model_name, + server_port=port, + distributed_init_method=init_method, + stream_process_output_fn=stream_process_output, + extra_args=None, + gpu_ids=cfg.get("gpu_ids", "0"), + actor_llm_idx=cfg.get("actor_llm_idx", 0), + world_size=world_size, + tensor_parallel_size=cfg.get("tensor_parallel_size", 1), + ) + server_procs.append(server_proc) + + await asyncio.sleep(1) + + trainer_proc, _, _ = start_trainer_process( + trainer_helper_path=distributed_trainer_helper, + distributed_init_method=init_method, + model_name=model_name, + server_urls=server_urls, + stream_process_output_fn=stream_process_output, + extra_args=None, + gpu_id=trainer_gpu, + world_size=world_size, + ) + + try: + await wait_for_all_servers_ready(server_urls, server_procs, trainer_proc) + + if len(server_urls) == 1: + generations = await run_generation_loop( + server_url=server_urls[0], + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + trainer_proc=trainer_proc, + ) + else: + per_server_generations = await run_generation_loop_multi( + server_urls=server_urls, + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + trainer_proc=trainer_proc, + ) + + # Wait for trainer to finish + print("[Main] Waiting for trainer to finish...") + for _ in range(30): + if trainer_proc.poll() is not None: + break + await asyncio.sleep(1) + + if len(server_urls) == 1: + analyze_and_verify_pattern(generations) + else: + analyze_and_verify_pattern_multi(per_server_generations) + print(f"\n✓ Server weight update pattern test PASSED ({len(server_urls)} server(s))") + + finally: + print("[Main] Cleaning up processes...") + for proc in server_procs: + if proc: + kill_process_tree(proc.pid) + if trainer_proc: + kill_process_tree(trainer_proc.pid) + + +class TestBasicGeneration: + """Test basic vLLM generation with worker extension.""" + + @pytest.mark.asyncio + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU") + async def test_load_model_and_generate(self, vllm_engine_factory, simple_prompt, generation_config): + """Test loading model and generating text.""" + from vllm import SamplingParams + + async with vllm_engine_factory(disable_weight_updates=True) as manager: + # Generate text + sampling_params = SamplingParams( + temperature=generation_config["temperature"], + top_p=generation_config["top_p"], + max_tokens=generation_config["max_tokens"], + seed=generation_config["seed"], + ) + + request_id = "test_request_1" + async for output in manager.engine.generate( + simple_prompt, + sampling_params=sampling_params, + request_id=request_id, + ): + final_output = output + + assert final_output is not None + assert len(final_output.outputs) > 0 + assert len(final_output.outputs[0].text) > 0 + + print(f"Generated text: {final_output.outputs[0].text}") + + @pytest.mark.asyncio + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU") + async def test_deterministic_generation(self, vllm_engine_factory, simple_prompt, generation_config): + """Test that generation is deterministic with same seed and temperature=0.""" + from vllm import SamplingParams + + async with vllm_engine_factory(disable_weight_updates=True) as manager: + sampling_params = SamplingParams( + 
temperature=generation_config["temperature"], + top_p=generation_config["top_p"], + max_tokens=generation_config["max_tokens"], + seed=generation_config["seed"], + ) + + # Generate twice with same parameters + outputs = [] + for i in range(2): + request_id = f"test_request_{i}" + async for output in manager.engine.generate( + simple_prompt, + sampling_params=sampling_params, + request_id=request_id, + ): + final_output = output + outputs.append(final_output.outputs[0].text) + + # Outputs should be identical + assert outputs[0] == outputs[1], f"Outputs differ: '{outputs[0]}' vs '{outputs[1]}'" + + +class TestWorkerExtension: + """Test WorkerExtension loading and methods.""" + + @pytest.mark.asyncio + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires GPU") + async def test_extension_loaded(self, vllm_engine_factory): + """Test that WorkerExtension is properly loaded.""" + from vllm.v1.engine.core_client import AsyncMPClient + + async with vllm_engine_factory(disable_weight_updates=True) as manager: + # Check that engine has the extension methods + assert isinstance(manager.engine.engine_core, AsyncMPClient) + + # Test that we can call the extension method + # This verifies the extension is loaded on workers + # collective_rpc_async returns a list of results (one per worker) + results = await manager.is_extension_loaded() + # Extension should be loaded on all workers + assert isinstance(results, list) + assert len(results) > 0 # At least one worker + # Results are PIDs (integers > 0) + assert all(isinstance(r, int) and r > 0 for r in results), f"Expected PIDs, got: {results}" + print(f"WorkerExtension successfully loaded on {len(results)} worker(s)") + print(f"Worker PIDs: {results}") + print(f"Unique PIDs: {len(set(results))} (indicates {len(set(results))} separate processes)") + + +class TestWeightUpdateDistributed: + """Test weight updates with 2-GPU distributed setup.""" + + @pytest.mark.timeout(300) # 5 minutes for init test + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") + async def test_init_actor_update_group( + self, + model_name, + distributed_init_method, + distributed_trainer_helper, + vllm_engine_helper, + ): + """Test initializing actor update group with 2 GPUs. + + This test verifies that the process group can be initialized correctly: + - vLLM engine runs on GPU 0 as rank 1 (in subprocess) + - Dummy trainer process runs on GPU 1 as rank 0 (in subprocess) + + Both run in subprocesses to ensure proper CUDA_VISIBLE_DEVICES isolation. 
+ """ + print("\n" + "="*60) + print("Starting distributed process group initialization test") + print("="*60) + + # Step 1: Start trainer subprocess FIRST with CUDA_VISIBLE_DEVICES=1 + trainer_env = os.environ.copy() + trainer_env["CUDA_VISIBLE_DEVICES"] = "1" + trainer_env["PIPELINERL_DEBUG"] = "1" + + print("[Main] Starting trainer process (rank 0, GPU 1)") + trainer_proc = subprocess.Popen( + [ + sys.executable, + str(distributed_trainer_helper), + "init", + "--init-method", distributed_init_method, + "--rank", "0", + "--world-size", "2", + ], + env=trainer_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Give trainer a moment to start and begin initializing + await asyncio.sleep(1) + + # Step 2: Start vLLM engine subprocess with CUDA_VISIBLE_DEVICES=0 + vllm_env = os.environ.copy() + vllm_env["CUDA_VISIBLE_DEVICES"] = "0" + vllm_env["PIPELINERL_DEBUG"] = "1" + + print("[Main] Starting vLLM engine process (rank 1, GPU 0)") + vllm_proc = subprocess.Popen( + [ + sys.executable, + str(vllm_engine_helper), + "init", # Command argument + "--model-name", model_name, + "--init-method", distributed_init_method, + "--actor-llm-idx", "0", + "--world-size", "2", + ], + env=vllm_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Step 3: Wait for both processes, killing all if one crashes + await wait_for_processes([ + (trainer_proc, "Trainer"), + (vllm_proc, "vLLM Engine"), + ], timeout=180) # Init test is faster, but give it 3 minutes to be safe + + @pytest.mark.timeout(1000) # 1000 seconds for broadcasting 291 parameters + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") + async def test_weight_update_same_weights( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + vllm_engine_helper, + temp_dir, + ): + """Test that updating with same weights produces same output. + + This test: + 1. vLLM engine generates baseline output (in subprocess on GPU 0) + 2. Trainer waits for baseline, then broadcasts weights (in subprocess on GPU 1) + 3. vLLM engine receives update and generates again + 4. vLLM engine verifies outputs are identical + + Both run in subprocesses for proper CUDA_VISIBLE_DEVICES isolation. + Uses file-based sync points for coordination. 
+ """ + from .sync_helper import create_sync_dir + + print("\n" + "="*60) + print("Starting weight update test (same weights)") + print("="*60) + + # Create sync directory for coordination + sync_dir = create_sync_dir(temp_dir) + print(f"[Main] Sync directory: {sync_dir}") + + # Step 1: Start vLLM engine subprocess with weight_update command + vllm_env = os.environ.copy() + vllm_env["CUDA_VISIBLE_DEVICES"] = "0" + # NOTE: needed to pass WeightUpdateRequest to collective + vllm_env["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + # Enable DEBUG logging in vllm1.py + vllm_env["PIPELINERL_DEBUG"] = "1" + + print("[Main] Starting vLLM engine process (GPU 0)") + vllm_proc = subprocess.Popen( + [ + sys.executable, + str(vllm_engine_helper), + "weight_update", + "--model-name", model_name, + "--init-method", distributed_init_method, + "--actor-llm-idx", "0", + "--world-size", "2", + "--prompt", simple_prompt, + "--max-tokens", str(generation_config["max_tokens"]), + "--sync-dir", str(sync_dir), + ], + env=vllm_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Give vLLM engine a moment to start + await asyncio.sleep(1) + + # Step 2: Start trainer subprocess (will wait for baseline_done sync point) + trainer_env = os.environ.copy() + trainer_env["CUDA_VISIBLE_DEVICES"] = "1" + # Enable DEBUG logging in vllm1.py (for consistency) + trainer_env["PIPELINERL_DEBUG"] = "1" + + print("[Main] Starting trainer process (GPU 1)") + trainer_proc = subprocess.Popen( + [ + sys.executable, + str(distributed_trainer_helper), + "broadcast", + "--init-method", distributed_init_method, + "--model-name", model_name, + "--sync-dir", str(sync_dir), + ], + env=trainer_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Step 3: Wait for both processes, killing all if one crashes + # 291 parameters takes ~600 seconds, so use 900s (15 min) to be safe + await wait_for_processes([ + (vllm_proc, "vLLM Engine"), + (trainer_proc, "Trainer"), + ], timeout=900) + + @pytest.mark.timeout(1000) # 1000 seconds for broadcasting 290 parameters + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") + async def test_weight_update_different_weights( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + vllm_engine_helper, + temp_dir, + ): + """Test that updating with perturbed weights produces different output. + + This test: + 1. vLLM engine generates baseline output (in subprocess on GPU 0) + 2. Trainer broadcasts PERTURBED weights (in subprocess on GPU 1) + 3. vLLM engine receives update and generates again + 4. vLLM engine verifies outputs are DIFFERENT (perturbed weights changed output) + + Both run in subprocesses for proper CUDA_VISIBLE_DEVICES isolation. + Uses file-based sync points for coordination. 
+ """ + from .sync_helper import create_sync_dir + + print("\n" + "="*60) + print("Starting weight update test (perturbed weights)") + print("="*60) + + # Create sync directory for coordination + sync_dir = create_sync_dir(temp_dir) + print(f"[Main] Sync directory: {sync_dir}") + + # Step 1: Start vLLM engine subprocess with weight_update command + vllm_env = os.environ.copy() + vllm_env["CUDA_VISIBLE_DEVICES"] = "0" + # NOTE: needed to pass WeightUpdateRequest to collective + vllm_env["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + # Enable DEBUG logging in vllm1.py + vllm_env["PIPELINERL_DEBUG"] = "1" + + print("[Main] Starting vLLM engine process (GPU 0)") + vllm_proc = subprocess.Popen( + [ + sys.executable, + str(vllm_engine_helper), + "weight_update", + "--model-name", model_name, + "--init-method", distributed_init_method, + "--actor-llm-idx", "0", + "--world-size", "2", + "--prompt", simple_prompt, + "--max-tokens", str(generation_config["max_tokens"]), + "--sync-dir", str(sync_dir), + "--expect-different", # Flag to expect different outputs + ], + env=vllm_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Give vLLM engine a moment to start + await asyncio.sleep(1) + + # Step 2: Start trainer subprocess with --perturb flag + trainer_env = os.environ.copy() + trainer_env["CUDA_VISIBLE_DEVICES"] = "1" + # Enable DEBUG logging in vllm1.py (for consistency) + trainer_env["PIPELINERL_DEBUG"] = "1" + + print("[Main] Starting trainer process (GPU 1) with --perturb") + trainer_proc = subprocess.Popen( + [ + sys.executable, + str(distributed_trainer_helper), + "broadcast", + "--init-method", distributed_init_method, + "--model-name", model_name, + "--sync-dir", str(sync_dir), + "--perturb", # Perturb weights to test different outputs + ], + env=trainer_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Step 3: Wait for both processes, killing all if one crashes + # 290 parameters takes ~600 seconds, so use 900s (15 min) to be safe + await wait_for_processes([ + (vllm_proc, "vLLM Engine"), + (trainer_proc, "Trainer"), + ], timeout=900) + + + @pytest.mark.timeout(2000) # 2000 seconds - this test does 2 full broadcasts + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") + async def test_weight_update_cross_validation( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + vllm_engine_helper, + temp_dir, + ): + """Cross-validation test: verify broadcast = load from disk. + + This test validates that: + 1. Broadcasting weights produces same results as loading from disk + 2. 
Round-trip works: original → modified → original + + Flow: + - vLLM: Load original, generate res_un_1 + - Trainer: Save perturbed model to disk, broadcast perturbed weights + - vLLM: Receive perturbed, generate res_mod_1 + - vLLM: Recreate engine with perturbed model from disk, generate res_mod_2 + - Trainer: Broadcast original weights + - vLLM: Receive original, generate res_un_2 + + Assertions: + - res_un_1 == res_un_2 (original weights produce same output) + - res_mod_1 == res_mod_2 (broadcast = load from disk) + """ + from .sync_helper import create_sync_dir + + print("\n" + "="*60) + print("Starting cross-validation test") + print("="*60) + + # Create sync directory for coordination + sync_dir = create_sync_dir(temp_dir) + print(f"[Main] Sync directory: {sync_dir}") + print(f"[Main] Temp directory: {temp_dir}") + + # Step 1: Start vLLM engine subprocess + vllm_env = os.environ.copy() + vllm_env["CUDA_VISIBLE_DEVICES"] = "0" + vllm_env["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1" + vllm_env["PIPELINERL_DEBUG"] = "1" + + print("[Main] Starting vLLM engine process (GPU 0)") + vllm_proc = subprocess.Popen( + [ + sys.executable, + str(vllm_engine_helper), + "cross_validation", + "--model-name", model_name, + "--init-method", distributed_init_method, + "--actor-llm-idx", "0", + "--world-size", "2", + "--prompt", simple_prompt, + "--max-tokens", str(generation_config["max_tokens"]), + "--sync-dir", str(sync_dir), + ], + env=vllm_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Give vLLM engine a moment to start + await asyncio.sleep(1) + + # Step 2: Start trainer subprocess + trainer_env = os.environ.copy() + trainer_env["CUDA_VISIBLE_DEVICES"] = "1" + trainer_env["PIPELINERL_DEBUG"] = "1" + + print("[Main] Starting trainer process (GPU 1)") + trainer_proc = subprocess.Popen( + [ + sys.executable, + str(distributed_trainer_helper), + "cross_validation", + "--init-method", distributed_init_method, + "--model-name", model_name, + "--sync-dir", str(sync_dir), + "--temp-dir", str(temp_dir), + ], + env=trainer_env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + # Step 3: Wait for both processes + # This test does 2 broadcasts, so double the timeout + await wait_for_processes([ + (vllm_proc, "vLLM Engine"), + (trainer_proc, "Trainer"), + ], timeout=1800) # 30 minutes + + + @pytest.mark.timeout(2000) # 2000 seconds - this test does 3 broadcasts + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") + async def test_weight_update_back_and_forth( + self, + model_name, + simple_prompt, + generation_config, + shared_distributed_init_method, + distributed_trainer_helper, + vllm_engine_helper, + shared_test_dir, + ): + """Back-and-forth test: switch between original and perturbed weights. + + Validates that we can update weights multiple times and the results + are deterministic and reproducible. 
+ """ + from .sync_helper import create_sync_dir + + print("\n" + "="*60) + print("Starting back-and-forth test (TP=1, 1 actor, 2 GPUs)") + print("="*60) + + sync_dir = create_sync_dir(shared_test_dir) + await _run_back_and_forth_engine_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=shared_distributed_init_method, + distributed_trainer_helper=distributed_trainer_helper, + vllm_engine_helper=vllm_engine_helper, + sync_dir=sync_dir, + vllm_configs=[{"cuda_devices": "0", "actor_llm_idx": 0, "tensor_parallel_size": 1}], + trainer_gpu="1", + world_size=2, + timeout=1800, + ) + + @pytest.mark.timeout(2400) # 40 minutes for server test + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") + async def test_server_weight_update_pattern( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + temp_dir, + ): + """Server integration test: verify weight update pattern with HTTP API. + + Validates the real-world scenario where a vLLM HTTP server receives + weight updates from a trainer while serving requests. + """ + print("\n" + "="*60) + print("Starting server weight update pattern test (TP=1, 1 actor, 2 GPUs)") + print("="*60) + + await _run_server_weight_update_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + distributed_trainer_helper=distributed_trainer_helper, + vllm_server_configs=[{"port": 8000, "gpu_ids": "0", "actor_llm_idx": 0, "tensor_parallel_size": 1}], + trainer_gpu="1", + world_size=2, + timeout=2400, + ) + + @pytest.mark.timeout(2400) + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs") + async def test_server_weight_update_catch_transitions( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + temp_dir, + ): + """Diagnostic test: catch garbage generations produced during NCCL weight broadcasts. + + The trainer runs a slow initial cycle (perturbed → original, 5 s each) + to firmly establish text_A and text_B, then fires N rapid back-to-back + broadcast cycles (perturbed → original) with no inter-broadcast delay. + The generation loop runs with generation_interval=0.0 (back-to-back + requests) to maximise the chance of hitting a mid-broadcast state. + + Assertions: + 1. The A→B→A→B pattern is still detected (broadcasts actually worked). + 2. At least one transition/garbage phase was captured. + + Topology: 1 vLLM server on GPU 0, trainer on GPU 1 (world_size=2). 
+ """ + print("\n" + "=" * 60) + print("Starting transition-capture test (TP=1, 1 actor, 2 GPUs)") + print("=" * 60) + + server_url = "http://127.0.0.1:8000" + + server_proc, _, _ = start_vllm_server( + model_name=model_name, + server_port=8000, + distributed_init_method=distributed_init_method, + stream_process_output_fn=stream_process_output, + gpu_ids="0", + actor_llm_idx=0, + world_size=2, + tensor_parallel_size=1, + ) + + await asyncio.sleep(1) + + trainer_proc, _, _ = start_trainer_process( + trainer_helper_path=distributed_trainer_helper, + distributed_init_method=distributed_init_method, + model_name=model_name, + server_urls=[server_url], + stream_process_output_fn=stream_process_output, + extra_args=["--n-cycles", "6"], + gpu_id="1", + world_size=2, + command="rapid_broadcast_cycles", + ) + + try: + await wait_for_server_ready(server_url, server_proc, trainer_proc) + + generations = await run_generation_loop( + server_url=server_url, + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + trainer_proc=trainer_proc, + generation_interval=0.0, + ) + + print("[Main] Waiting for trainer to finish...") + for _ in range(30): + if trainer_proc.poll() is not None: + break + await asyncio.sleep(1) + + analyze_and_verify_transitions(generations, n_cycles=6) + print("\n✓ Transition-capture test PASSED") + + finally: + print("[Main] Cleaning up processes...") + if server_proc: + kill_process_tree(server_proc.pid) + if trainer_proc: + kill_process_tree(trainer_proc.pid) + + +class TestWeightUpdateTP2: + """Test weight updates with tensor-parallel (TP=2) vLLM — needs 3 GPUs.""" + + @pytest.mark.timeout(2000) + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 3, reason="Requires at least 3 GPUs") + async def test_weight_update_back_and_forth_tp2( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + vllm_engine_helper, + temp_dir, + ): + """Back-and-forth test with TP=2: one vLLM instance on GPUs 0+1, trainer on GPU 2.""" + from .sync_helper import create_sync_dir + + print("\n" + "="*60) + print("Starting back-and-forth test (TP=2, 1 actor, 3 GPUs)") + print("="*60) + + sync_dir = create_sync_dir(temp_dir) + await _run_back_and_forth_engine_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + distributed_trainer_helper=distributed_trainer_helper, + vllm_engine_helper=vllm_engine_helper, + sync_dir=sync_dir, + vllm_configs=[{"cuda_devices": "0,1", "actor_llm_idx": 0, "tensor_parallel_size": 2}], + trainer_gpu="2", + world_size=3, + timeout=1800, + ) + + @pytest.mark.timeout(2400) + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 3, reason="Requires at least 3 GPUs") + async def test_server_weight_update_pattern_tp2( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + temp_dir, + ): + """Server weight update test with TP=2: one server on GPUs 0+1, trainer on GPU 2.""" + print("\n" + "="*60) + print("Starting server weight update pattern test (TP=2, 1 actor, 3 GPUs)") + print("="*60) + + await _run_server_weight_update_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + distributed_trainer_helper=distributed_trainer_helper, + vllm_server_configs=[{"port": 8001, "gpu_ids": "0,1", "actor_llm_idx": 0, 
"tensor_parallel_size": 2}], + trainer_gpu="2", + world_size=3, + timeout=2400, + ) + + +class TestWeightUpdateMultiActor: + """Test weight updates with multiple independent vLLM actors.""" + + @pytest.mark.timeout(2000) + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 3, reason="Requires at least 3 GPUs") + async def test_weight_update_back_and_forth_2actors( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + vllm_engine_helper, + temp_dir, + ): + """Back-and-forth test with 2 actors: vLLM on GPU 0 and GPU 1, trainer on GPU 2.""" + from .sync_helper import create_sync_dir + + print("\n" + "="*60) + print("Starting back-and-forth test (TP=1, 2 actors, 3 GPUs)") + print("="*60) + + sync_dir = create_sync_dir(temp_dir) + await _run_back_and_forth_engine_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + distributed_trainer_helper=distributed_trainer_helper, + vllm_engine_helper=vllm_engine_helper, + sync_dir=sync_dir, + vllm_configs=[ + {"cuda_devices": "0", "actor_llm_idx": 0, "tensor_parallel_size": 1}, + {"cuda_devices": "1", "actor_llm_idx": 1, "tensor_parallel_size": 1}, + ], + trainer_gpu="2", + world_size=3, + timeout=1800, + ) + + @pytest.mark.timeout(2000) + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs") + async def test_weight_update_back_and_forth_3actors( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + vllm_engine_helper, + temp_dir, + ): + """Back-and-forth test with 3 actors: vLLM on GPUs 0/1/2, trainer on GPU 3.""" + from .sync_helper import create_sync_dir + + print("\n" + "="*60) + print("Starting back-and-forth test (TP=1, 3 actors, 4 GPUs)") + print("="*60) + + sync_dir = create_sync_dir(temp_dir) + await _run_back_and_forth_engine_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + distributed_trainer_helper=distributed_trainer_helper, + vllm_engine_helper=vllm_engine_helper, + sync_dir=sync_dir, + vllm_configs=[ + {"cuda_devices": "0", "actor_llm_idx": 0, "tensor_parallel_size": 1}, + {"cuda_devices": "1", "actor_llm_idx": 1, "tensor_parallel_size": 1}, + {"cuda_devices": "2", "actor_llm_idx": 2, "tensor_parallel_size": 1}, + ], + trainer_gpu="3", + world_size=4, + timeout=1800, + ) + + @pytest.mark.timeout(2400) + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 3, reason="Requires at least 3 GPUs") + async def test_server_weight_update_pattern_2actors( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + temp_dir, + ): + """Server weight update test with 2 actors: servers on GPUs 0 and 1, trainer on GPU 2.""" + print("\n" + "="*60) + print("Starting server weight update pattern test (TP=1, 2 actors, 3 GPUs)") + print("="*60) + + await _run_server_weight_update_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + distributed_trainer_helper=distributed_trainer_helper, + vllm_server_configs=[ + {"port": 8000, "gpu_ids": "0", "actor_llm_idx": 0, "tensor_parallel_size": 1}, + {"port": 8001, "gpu_ids": "1", "actor_llm_idx": 1, "tensor_parallel_size": 1}, + ], + trainer_gpu="2", + world_size=3, + 
timeout=2400, + ) + + @pytest.mark.timeout(2400) + @pytest.mark.asyncio + @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs") + async def test_server_weight_update_pattern_3actors( + self, + model_name, + simple_prompt, + generation_config, + distributed_init_method, + distributed_trainer_helper, + temp_dir, + ): + """Server weight update test with 3 actors: servers on GPUs 0/1/2, trainer on GPU 3.""" + print("\n" + "="*60) + print("Starting server weight update pattern test (TP=1, 3 actors, 4 GPUs)") + print("="*60) + + await _run_server_weight_update_test( + model_name=model_name, + simple_prompt=simple_prompt, + generation_config=generation_config, + init_method=distributed_init_method, + distributed_trainer_helper=distributed_trainer_helper, + vllm_server_configs=[ + {"port": 8000, "gpu_ids": "0", "actor_llm_idx": 0, "tensor_parallel_size": 1}, + {"port": 8001, "gpu_ids": "1", "actor_llm_idx": 1, "tensor_parallel_size": 1}, + {"port": 8002, "gpu_ids": "2", "actor_llm_idx": 2, "tensor_parallel_size": 1}, + ], + trainer_gpu="3", + world_size=4, + timeout=2400, + ) diff --git a/tests/test_world_multinode.py b/tests/test_world_multinode.py new file mode 100644 index 00000000..14d1474a --- /dev/null +++ b/tests/test_world_multinode.py @@ -0,0 +1,842 @@ +"""Tests for multi-node WorldMap topology and fast-llm torchrun command assembly.""" + +import os +import sys +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock + +import pytest +from omegaconf import OmegaConf + + +def _make_cfg( + actor_fraction=1, + finetune_fraction=1, + preprocessor_fraction=0, + replicas=1, + use_fast_llm=True, + tp=1, + pp=1, + seq_parallel=1, +): + """Minimal config for WorldMap construction.""" + return OmegaConf.create({ + "world": { + "actor_fraction": actor_fraction, + "finetune_fraction": finetune_fraction, + "preprocessor_fraction": preprocessor_fraction, + "replicas": replicas, + "actor_group_port": 9000, + "environment_start_port": 7777, + }, + "vllm_config": { + "vllm_kwargs": { + "tensor-parallel-size": tp, + "pipeline-parallel-size": pp, + } + }, + "finetune": {"seq_parallel": seq_parallel}, + "use_fast_llm": use_fast_llm, + "debug": {"mode": "", "place_inference_workers": True}, + }) + + +def _make_world_map(cfg, world_size, rank=0, master_addr="dns-test-0"): + from pipelinerl.world import WorldMap + env = { + "WORLD_SIZE": str(world_size), + "RANK": str(rank), + "MASTER_ADDR": master_addr, + } + with patch.dict(os.environ, env, clear=False): + # collect_environment_specs needs cfg fields that don't exist in minimal cfg; + # patch it out to avoid AttributeError. 
+ with patch("pipelinerl.world.WorldMap._place_environments"): + with patch("pipelinerl.utils.collect_environment_specs", return_value=[]): + return WorldMap(cfg, verbose=False) + + +# --------------------------------------------------------------------------- +# WorldMap topology tests +# --------------------------------------------------------------------------- + +class TestWorldMapMultiNode: + + def test_2node_1actor_1finetune_whole_nodes(self): + """2 nodes: 1 actor node + 1 finetune node — each gets all 8 GPUs.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1) + wm = _make_world_map(cfg, world_size=2) + + assert wm.total_finetune_gpus == 8, "finetune should get exactly 1 full node" + assert wm.total_finetune_gpus % wm.node_size == 0 + assert len(wm.nodes_with_finetuning()) == 1 + + def test_4node_1actor_3finetune_whole_nodes(self): + """4 nodes: 1 actor node + 3 finetune nodes.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=3) + wm = _make_world_map(cfg, world_size=4) + + assert wm.total_finetune_gpus == 24, "finetune should get exactly 3 full nodes" + assert wm.total_finetune_gpus % wm.node_size == 0 + assert len(wm.nodes_with_finetuning()) == 3 + + def test_4node_2actor_2finetune_whole_nodes(self): + """4 nodes: 2 actor nodes + 2 finetune nodes.""" + cfg = _make_cfg(actor_fraction=2, finetune_fraction=2) + wm = _make_world_map(cfg, world_size=4) + + assert wm.total_finetune_gpus == 16 + assert wm.total_finetune_gpus % wm.node_size == 0 + assert len(wm.nodes_with_finetuning()) == 2 + + def test_finetune_always_at_least_one_node(self): + """Even with a large actor fraction, finetune gets at least 1 full node.""" + cfg = _make_cfg(actor_fraction=3, finetune_fraction=1) + wm = _make_world_map(cfg, world_size=4) + + assert len(wm.nodes_with_finetuning()) >= 1 + assert wm.total_finetune_gpus >= wm.node_size + assert wm.total_finetune_gpus % wm.node_size == 0 + + def test_actors_never_exceed_world_size_minus_one(self): + """Actor nodes never consume all nodes — at least 1 reserved for finetune.""" + cfg = _make_cfg(actor_fraction=10, finetune_fraction=1) + wm = _make_world_map(cfg, world_size=4) + + finetune_nodes = len(wm.nodes_with_finetuning()) + assert finetune_nodes >= 1 + assert finetune_nodes < 4 + + def test_single_node_unchanged(self): + """Single-node path is not affected by the multi-node rounding.""" + cfg = _make_cfg(actor_fraction=2, finetune_fraction=6) + # Single-node: world_size=1, node_size = actual device count (mocked) + with patch("torch.cuda.device_count", return_value=8): + with patch("pipelinerl.utils.collect_environment_specs", return_value=[]): + with patch("pipelinerl.world.WorldMap._place_environments"): + from pipelinerl.world import WorldMap + wm = WorldMap(cfg, verbose=False) + assert wm.total_finetune_gpus == 6 + assert wm.world_size == 1 + + def test_nodes_with_finetuning_returns_sorted_ranks(self): + """nodes_with_finetuning() returns a sorted list of node ranks.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=3) + wm = _make_world_map(cfg, world_size=4) + + fn = wm.nodes_with_finetuning() + assert fn == sorted(fn) + + def test_my_finetuning_rank_on_finetune_node(self): + """my_finetuning_rank() returns 0 for the first finetune node.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1) + # With 2 nodes, finetune is on node 0 (actor on node 1 due to reversed placement) + wm = _make_world_map(cfg, world_size=2, rank=0) + + finetune_nodes = wm.nodes_with_finetuning() + # my_rank=0 should be a finetune node + assert 0 in 
finetune_nodes + assert wm.my_finetuning_rank() == finetune_nodes.index(0) + + def test_4node_with_preprocessor_all_whole_nodes(self): + """4 nodes, actor=1, preprocessor=1, finetune=6: all three get whole nodes.""" + cfg = _make_cfg(actor_fraction=1, preprocessor_fraction=1, finetune_fraction=6) + wm = _make_world_map(cfg, world_size=4) + + assert wm.total_finetune_gpus % wm.node_size == 0, "finetune must be whole nodes" + # preprocessor and actor GPU shares should also be multiples of node_size + total = wm.world_size * wm.node_size + actor_gpus = total - wm.total_finetune_gpus - wm.gpus_per_preprocessor * cfg.world.replicas + assert actor_gpus % wm.node_size == 0, "actor must be whole nodes" + assert (wm.gpus_per_preprocessor * cfg.world.replicas) % wm.node_size == 0, "preprocessor must be whole nodes" + assert wm.total_finetune_gpus + actor_gpus + wm.gpus_per_preprocessor * cfg.world.replicas == total + + def test_3node_with_preprocessor_all_whole_nodes(self): + """3 nodes, actor=1, preprocessor=1, finetune=1: each component gets 1 node.""" + cfg = _make_cfg(actor_fraction=1, preprocessor_fraction=1, finetune_fraction=1) + wm = _make_world_map(cfg, world_size=3) + + assert wm.total_finetune_gpus % wm.node_size == 0 + total = wm.world_size * wm.node_size + actor_gpus = total - wm.total_finetune_gpus - wm.gpus_per_preprocessor * cfg.world.replicas + assert actor_gpus % wm.node_size == 0 + assert (wm.gpus_per_preprocessor * cfg.world.replicas) % wm.node_size == 0 + + def test_address_map_derived_from_master_addr(self): + """address_map entries follow the dns-- pattern.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + + assert wm.address_map[0] == "dns-abc123-0" + assert wm.address_map[1] == "dns-abc123-1" + + +# --------------------------------------------------------------------------- +# torchrun command assembly test +# --------------------------------------------------------------------------- + +class TestTorchrunCommand: + + def _capture_cmd(self, world_map, cfg_extra=None): + """Run _run_finetune_fast_llm with mocked I/O and capture the torchrun command.""" + from pipelinerl.launch import _run_finetune_fast_llm + + cfg = OmegaConf.create({ + "model_path": "/tmp/fake_model", + "weight_broadcast": False, + "debug": {"mode": "", "log_data_pipeline": False}, + "streams": {"host": "localhost", "port": 11000}, + "wandb": { + "wandb_workspace_root": "/tmp", + "wandb_entity_name": "test", + "wandb_project_name": "test", + "wandb_group": "test", + }, + "fast_llm": { + "training": { + "train_iters": 10, + "wandb": {"entity_name": None, "project_name": None, "group_name": None}, + }, + "data": {"datasets": {"training": {"type": "streaming", "host": None, "port": None}}}, + "pretrained": {"format": "llama", "path": None, "model_weights": True}, + "run": {"experiment_dir": None, "experiment_name": None}, + "callbacks": {}, + }, + "fast_llm_finetune": { + "model_type": "llama", + "torchrun_port": 29500, + "model_format": "llama", + }, + }) + if cfg_extra: + cfg = OmegaConf.merge(cfg, OmegaConf.create(cfg_extra)) + + captured_cmd = [] + + def mock_popen(cmd, **kwargs): + captured_cmd.extend(cmd) + return None # no process spawned + + with tempfile.TemporaryDirectory() as tmp: + exp_dir = Path(tmp) + # Patch os.path.isdir to pass the model_path check + with patch("pipelinerl.launch._popen", side_effect=mock_popen): + with patch("pipelinerl.launch.save_command"): + with patch("os.path.isdir", return_value=True): + 
list(_run_finetune_fast_llm(cfg, world_map, gpus=[0, 1, 2, 3], exp_dir=exp_dir)) + + return captured_cmd + + def test_single_node_uses_master_port(self): + """Single-node torchrun uses --master_port, no rdzv flags.""" + cfg = _make_cfg(actor_fraction=2, finetune_fraction=6) + with patch("torch.cuda.device_count", return_value=8): + with patch("pipelinerl.utils.collect_environment_specs", return_value=[]): + with patch("pipelinerl.world.WorldMap._place_environments"): + from pipelinerl.world import WorldMap + wm = WorldMap(cfg, verbose=False) + + cmd = self._capture_cmd(wm) + assert "--master_port=29500" in cmd + assert "--rdzv_backend=static" not in cmd + assert "--nnodes=6" not in cmd + + def test_2node_1finetune_uses_single_node_torchrun(self): + """2-node job with 1 actor + 1 finetune node: fast-llm spans 1 node → single-node torchrun.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1) + wm = _make_world_map(cfg, world_size=2, rank=0, master_addr="dns-abc-0") + + assert len(wm.nodes_with_finetuning()) == 1, "only 1 finetune node in 2-node job" + cmd = self._capture_cmd(wm) + # Should use simple --master_port, not rdzv + assert "--master_port=29500" in cmd + assert "--rdzv_backend=static" not in cmd + + def test_multi_node_uses_static_rdzv(self): + """Fast-llm spanning multiple nodes uses static rdzv with correct nnodes and node_rank.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=3) + wm = _make_world_map(cfg, world_size=4, rank=0, master_addr="dns-abc-0") + + assert len(wm.nodes_with_finetuning()) == 3 + cmd = self._capture_cmd(wm) + assert "--rdzv_backend=static" in cmd + assert "--rdzv_id=0" in cmd + assert "--max_restarts=0" in cmd + finetune_count = len(wm.nodes_with_finetuning()) + assert f"--nnodes={finetune_count}" in cmd + assert f"--node_rank={wm.my_finetuning_rank()}" in cmd + finetune_master = wm.address_map[wm.nodes_with_finetuning()[0]] + assert any(f"--rdzv_endpoint={finetune_master}:29500" in arg for arg in cmd) + assert not any("--master_port" in arg for arg in cmd) + + def test_multi_node_4nodes_correct_nnodes(self): + """4-node job: torchrun nnodes = 3 (finetune nodes only).""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=3) + wm = _make_world_map(cfg, world_size=4, rank=0) + + cmd = self._capture_cmd(wm) + finetune_count = len(wm.nodes_with_finetuning()) + assert finetune_count == 3 + assert f"--nnodes={finetune_count}" in cmd + + +# --------------------------------------------------------------------------- +# DeepSpeed regression: snapping must NOT apply when use_fast_llm=False +# --------------------------------------------------------------------------- + +class TestWorldMapDeepSpeed: + + def test_deepspeed_single_node_fractional_split(self): + """Single-node DeepSpeed split is unchanged — 2 actor GPUs + 6 finetune GPUs.""" + cfg = _make_cfg(actor_fraction=2, finetune_fraction=6, use_fast_llm=False) + with patch("torch.cuda.device_count", return_value=8): + with patch("pipelinerl.utils.collect_environment_specs", return_value=[]): + with patch("pipelinerl.world.WorldMap._place_environments"): + from pipelinerl.world import WorldMap + wm = WorldMap(cfg, verbose=False) + + assert wm.total_finetune_gpus == 6 + assert wm.world_size == 1 + + def test_deepspeed_multinode_no_rounding(self): + """Multi-node DeepSpeed: no whole-node snapping (handled by DeepSpeed itself).""" + # 2 nodes, actor_fraction=1, finetune_fraction=1 → 8 finetune GPUs (happens to be whole node) + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1, use_fast_llm=False) 
+ wm = _make_world_map(cfg, world_size=2) + # Should still compute correctly without triggering fast-llm rounding path + assert wm.total_finetune_gpus > 0 + assert wm.world_size == 2 + + def test_fast_llm_single_node_unchanged(self): + """Single-node fast-llm: fractional split within one node is preserved.""" + cfg = _make_cfg(actor_fraction=2, finetune_fraction=6, use_fast_llm=True) + with patch("torch.cuda.device_count", return_value=8): + with patch("pipelinerl.utils.collect_environment_specs", return_value=[]): + with patch("pipelinerl.world.WorldMap._place_environments"): + from pipelinerl.world import WorldMap + wm = WorldMap(cfg, verbose=False) + + assert wm.total_finetune_gpus == 6 + assert wm.world_size == 1 + + +# --------------------------------------------------------------------------- +# Pod IP exchange: dns_address_map, job URL rewriting, DeepSpeed/fast-llm compat +# --------------------------------------------------------------------------- + +def _simulate_pod_ip_exchange(wm, pod_ips: dict): + """Simulate _exchange_pod_ips without NFS I/O. + + Sets dns_address_map to original DNS names, updates address_map and job + URLs/hostnames to pod IPs — mirrors the real function's side-effects. + """ + from pipelinerl.launch import _exchange_pod_ips as real_fn # noqa: F401 (not called) + # Save DNS names first (matches the real implementation order) + wm.dns_address_map = dict(wm.address_map) + # Overwrite address_map with pod IPs + for rank, ip in pod_ips.items(): + wm.address_map[rank] = ip + wm.master_addr = pod_ips[0] + # Rewrite job URLs/hostnames + for node, jobs in wm.job_map.items(): + dns_name = wm.dns_address_map[node] + pod_ip = pod_ips[node] + for job in jobs: + job.hostname = pod_ip + if job.url: + job.url = job.url.replace(dns_name, pod_ip) + + +class TestPodIPExchange: + + def test_dns_address_map_holds_original_dns_names(self): + """After pod IP exchange, dns_address_map contains original DNS names, not pod IPs.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + + pod_ips = {0: "10.0.0.1", 1: "10.0.0.2"} + _simulate_pod_ip_exchange(wm, pod_ips) + + assert wm.dns_address_map[0] == "dns-abc123-0" + assert wm.dns_address_map[1] == "dns-abc123-1" + + def test_address_map_updated_to_pod_ips(self): + """After pod IP exchange, address_map and master_addr hold pod IPs.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + + pod_ips = {0: "10.0.0.1", 1: "10.0.0.2"} + _simulate_pod_ip_exchange(wm, pod_ips) + + assert wm.address_map[0] == "10.0.0.1" + assert wm.address_map[1] == "10.0.0.2" + assert wm.master_addr == "10.0.0.1" + + def test_job_urls_rewritten_to_pod_ips(self): + """After pod IP exchange, actor_llm job URLs use pod IPs, not DNS names.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + + # Verify that actor_llm jobs have DNS-based URLs before exchange + actor_urls_before = [job.url for job in wm.get_all_jobs() if job.kind == "actor_llm"] + assert all("dns-abc123-1" in u for u in actor_urls_before) + + pod_ips = {0: "10.0.0.1", 1: "10.0.0.2"} + _simulate_pod_ip_exchange(wm, pod_ips) + + actor_urls_after = [job.url for job in wm.get_all_jobs() if job.kind == "actor_llm"] + assert all("10.0.0.2" in u for u in actor_urls_after), f"Expected pod IP in URLs: {actor_urls_after}" + assert all("dns-abc123" not in u for u in 
actor_urls_after) + + def test_no_dns_address_map_without_exchange(self): + """Without pod IP exchange, dns_address_map is not set (no AttributeError).""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + assert not hasattr(wm, "dns_address_map") + + +# --------------------------------------------------------------------------- +# DeepSpeed command assembly: hostfile and inclusion filter use DNS names +# --------------------------------------------------------------------------- + +class TestDeepSpeedCommand: + + def _make_ds_cfg(self): + return OmegaConf.create({ + "use_deepspeed": True, + "use_fsdp": False, + "deepspeed_config": "zero2", + "accelerate_config": None, + "world": {"actor_group_port": 9000}, + "debug": {"mode": ""}, + }) + + def _capture_ds_cmd(self, world_map, cfg_extra=None): + """Run _run_finetune_deepspeed with mocked I/O and capture the command.""" + from pipelinerl.launch import _run_finetune_deepspeed + + cfg = self._make_ds_cfg() + if cfg_extra: + cfg = OmegaConf.merge(cfg, OmegaConf.create(cfg_extra)) + + captured_cmd = [] + + def mock_popen(cmd, **kwargs): + captured_cmd.extend(cmd) + return None + + with tempfile.TemporaryDirectory() as tmp: + exp_dir = Path(tmp) + (exp_dir / "hostfile.txt").write_text("") # pre-create + with patch("pipelinerl.launch._popen", side_effect=mock_popen): + with patch("pipelinerl.launch.save_command"): + with patch.dict(os.environ, {"MASTER_ADDR": "dns-test-0", "MASTER_PORT": "29501"}): + list(_run_finetune_deepspeed(cfg, world_map, gpus=[0, 1, 2, 3], exp_dir=exp_dir)) + + return captured_cmd + + def test_deepspeed_multinode_uses_dns_names_without_exchange(self): + """DeepSpeed 2-node without pod IP exchange: inclusion filter uses DNS names.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1, use_fast_llm=False) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + + cmd = self._capture_ds_cmd(wm) + # The deepspeed_inclusion_filter should contain the DNS hostname for the finetune node + filter_arg = next((c for c in cmd if "dns-abc123" in c), None) + assert filter_arg is not None, f"Expected DNS name in cmd, got: {cmd}" + + def test_deepspeed_multinode_after_pod_ip_exchange_uses_dns_names(self): + """After pod IP exchange, DeepSpeed inclusion filter still uses DNS names (not pod IPs).""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1, use_fast_llm=False) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + + # Simulate pod IP exchange + _simulate_pod_ip_exchange(wm, {0: "10.0.0.1", 1: "10.0.0.2"}) + + cmd = self._capture_ds_cmd(wm) + # Inclusion filter must still use DNS names, not pod IPs + filter_arg = next((c for c in cmd if "dns-abc123" in c), None) + assert filter_arg is not None, f"Expected DNS name in DS filter after pod IP exchange, got: {cmd}" + # Pod IPs must NOT appear in the inclusion filter + assert not any("10.0.0" in c for c in cmd if "--deepspeed_inclusion_filter" not in c and "@" in c), \ + f"Pod IP leaked into DS filter: {cmd}" + + def test_deepspeed_single_node_no_pod_ip_exchange(self): + """Single-node DeepSpeed: no world_size>1 branch, pod IP exchange never runs.""" + cfg = _make_cfg(actor_fraction=2, finetune_fraction=6, use_fast_llm=False) + with patch("torch.cuda.device_count", return_value=8): + with patch("pipelinerl.utils.collect_environment_specs", return_value=[]): + with patch("pipelinerl.world.WorldMap._place_environments"): + from pipelinerl.world import WorldMap + wm = 
WorldMap(cfg, verbose=False) + + assert wm.world_size == 1 + assert not hasattr(wm, "dns_address_map") + # Should not crash even without dns_address_map + cmd = self._capture_ds_cmd(wm) + assert "--num_machines" not in cmd # single-node, no multi-machine flags + + +# --------------------------------------------------------------------------- +# Hostfile creation in main(): uses dns_address_map after pod IP exchange +# --------------------------------------------------------------------------- + +class TestHostfileCreation: + + def test_hostfile_uses_dns_names_after_pod_ip_exchange(self): + """The DeepSpeed hostfile written by main() uses DNS names even after pod IP exchange.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1, use_fast_llm=False) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + + # Simulate pod IP exchange + _simulate_pod_ip_exchange(wm, {0: "10.0.0.1", 1: "10.0.0.2"}) + + dns_map = getattr(wm, "dns_address_map", wm.address_map) + hosts = [dns_map[i] for i in range(wm.world_size)] + + assert hosts[0] == "dns-abc123-0" + assert hosts[1] == "dns-abc123-1" + assert "10.0.0" not in hosts[0] + assert "10.0.0" not in hosts[1] + + def test_hostfile_uses_address_map_without_exchange(self): + """Without pod IP exchange, dns_address_map is absent — falls back to address_map.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1, use_fast_llm=False) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + + dns_map = getattr(wm, "dns_address_map", wm.address_map) + hosts = [dns_map[i] for i in range(wm.world_size)] + + assert hosts[0] == "dns-abc123-0" + assert hosts[1] == "dns-abc123-1" + + +# --------------------------------------------------------------------------- +# Redis host in saved exp_config.yaml for multi-node (DeepSpeed + Redis) +# --------------------------------------------------------------------------- + +class TestRedisHostMultiNode: + + def _compute_streams_host(self, world_map, my_rank: int) -> str: + """Mirror the launch.py logic for cfg.streams.host selection.""" + if world_map.world_size > 1: + return world_map.master_addr + return "localhost" + + def test_single_node_redis_host_is_localhost(self): + """Single-node: Redis host is localhost regardless of pod IP exchange.""" + cfg = _make_cfg(actor_fraction=2, finetune_fraction=6, use_fast_llm=False) + with patch("torch.cuda.device_count", return_value=8): + with patch("pipelinerl.utils.collect_environment_specs", return_value=[]): + with patch("pipelinerl.world.WorldMap._place_environments"): + from pipelinerl.world import WorldMap + wm = WorldMap(cfg, verbose=False) + + host = self._compute_streams_host(wm, my_rank=0) + assert host == "localhost" + + def test_multinode_rank0_redis_host_is_pod_ip(self): + """Multi-node rank 0: Redis host is pod IP (not localhost) after exchange. + + This ensures the saved exp_config.yaml has a reachable address for + DeepSpeed workers on other nodes. 
+ """ + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1, use_fast_llm=False) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + _simulate_pod_ip_exchange(wm, {0: "10.0.0.1", 1: "10.0.0.2"}) + + host = self._compute_streams_host(wm, my_rank=0) + assert host == "10.0.0.1", "rank 0 should use pod IP so saved config is reachable cross-node" + assert host != "localhost" + + def test_multinode_rank1_redis_host_is_pod_ip(self): + """Multi-node rank 1: Redis host is pod IP of rank 0.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1, use_fast_llm=False) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0", rank=1) + _simulate_pod_ip_exchange(wm, {0: "10.0.0.1", 1: "10.0.0.2"}) + + host = self._compute_streams_host(wm, my_rank=1) + assert host == "10.0.0.1", "rank 1 should use rank 0's pod IP to reach Redis" + + def test_multinode_both_ranks_same_redis_host(self): + """Both ranks in a 2-node job resolve to the same Redis host (pod IP of rank 0).""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1, use_fast_llm=False) + wm0 = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0", rank=0) + wm1 = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0", rank=1) + + _simulate_pod_ip_exchange(wm0, {0: "10.0.0.1", 1: "10.0.0.2"}) + _simulate_pod_ip_exchange(wm1, {0: "10.0.0.1", 1: "10.0.0.2"}) + + host0 = self._compute_streams_host(wm0, my_rank=0) + host1 = self._compute_streams_host(wm1, my_rank=1) + + assert host0 == host1 == "10.0.0.1" + + def test_multinode_without_pod_ip_exchange_uses_master_addr(self): + """Without pod IP exchange, multi-node uses master_addr (DNS name) for Redis. + + This is a fallback; the pod IP exchange should always run in practice + but the code must not crash without it. + """ + cfg = _make_cfg(actor_fraction=1, finetune_fraction=1, use_fast_llm=False) + wm = _make_world_map(cfg, world_size=2, master_addr="dns-abc123-0") + + # No pod IP exchange — master_addr is still a DNS name + assert wm.master_addr == "dns-abc123-0" + host = self._compute_streams_host(wm, my_rank=0) + assert host == "dns-abc123-0" # DNS name (port filtering may apply, but code doesn't crash) + + +# --------------------------------------------------------------------------- +# DeepSpeed run_finetune.py path: must be absolute (not relative to CWD) +# --------------------------------------------------------------------------- + +class TestDeepSpeedEntrypointPath: + + def _capture_ds_cmd(self, world_map): + from pipelinerl.launch import _run_finetune_deepspeed + from omegaconf import OmegaConf + + cfg = OmegaConf.create({ + "use_deepspeed": True, + "use_fsdp": False, + "deepspeed_config": "zero2", + "accelerate_config": None, + "world": {"actor_group_port": 9000}, + "debug": {"mode": ""}, + }) + captured_cmd = [] + + def mock_popen(cmd, **kwargs): + captured_cmd.extend(cmd) + return None + + with tempfile.TemporaryDirectory() as tmp: + exp_dir = Path(tmp) + with patch("pipelinerl.launch._popen", side_effect=mock_popen): + with patch("pipelinerl.launch.save_command"): + with patch.dict(os.environ, {"MASTER_ADDR": "dns-test-0", "MASTER_PORT": "29501"}): + list(_run_finetune_deepspeed(cfg, world_map, gpus=[0, 1, 2, 3], exp_dir=exp_dir)) + + return captured_cmd + + def test_run_finetune_path_is_absolute(self): + """run_finetune.py must be an absolute path so it works regardless of CWD. + + When EAI starts the pod, CWD is /home/toolkit (not the repo root). 
A relative + path like 'pipelinerl/entrypoints/run_finetune.py' resolves to + '/home/toolkit/pipelinerl/...' which doesn't exist. + """ + cfg = _make_cfg(actor_fraction=2, finetune_fraction=6, use_fast_llm=False) + with patch("torch.cuda.device_count", return_value=8): + with patch("pipelinerl.utils.collect_environment_specs", return_value=[]): + with patch("pipelinerl.world.WorldMap._place_environments"): + from pipelinerl.world import WorldMap + wm = WorldMap(cfg, verbose=False) + + cmd = self._capture_ds_cmd(wm) + + # Find the run_finetune.py argument + finetune_script = next((c for c in cmd if "run_finetune.py" in c), None) + assert finetune_script is not None, f"run_finetune.py not found in cmd: {cmd}" + assert Path(finetune_script).is_absolute(), ( + f"run_finetune.py path must be absolute but got: {finetune_script!r}. " + "A relative path resolves against CWD which is /home/toolkit in EAI pods." + ) + assert Path(finetune_script).exists(), ( + f"run_finetune.py absolute path must exist: {finetune_script!r}" + ) + + +# --------------------------------------------------------------------------- +# Per-node file naming: fast-llm and DeepSpeed avoid NFS write races +# --------------------------------------------------------------------------- + +class TestPerNodeFileNaming: + """Verify that multinode fast-llm and DeepSpeed finetune runs write separate + output files per node (config, start.sh, stdout, stderr) to avoid NFS races.""" + + def _capture_fast_llm_files(self, world_map, gpus=None): + """Run _run_finetune_fast_llm and return captured file suffix info.""" + from pipelinerl.launch import _run_finetune_fast_llm + + cfg = OmegaConf.create({ + "model_path": "/tmp/fake_model", + "weight_broadcast": False, + "debug": {"mode": "", "log_data_pipeline": False}, + "streams": {"host": "localhost", "port": 11000}, + "wandb": { + "wandb_workspace_root": "/tmp", + "wandb_entity_name": "test", + "wandb_project_name": "test", + "wandb_group": "test", + }, + "fast_llm": { + "training": { + "train_iters": 10, + "wandb": {"entity_name": None, "project_name": None, "group_name": None}, + }, + "data": {"datasets": {"training": {"type": "streaming", "host": None, "port": None}}}, + "pretrained": {"format": "llama", "path": None, "model_weights": True}, + "run": {"experiment_dir": None, "experiment_name": None}, + "callbacks": {}, + }, + "fast_llm_finetune": { + "model_type": "llama", + "torchrun_port": 29500, + "model_format": "llama", + }, + }) + + written_files = {} + + real_open = open + + def mock_popen(cmd, **kwargs): + written_files["stdout"] = str(kwargs.get("stdout", {}).name if hasattr(kwargs.get("stdout"), "name") else "") + written_files["stderr"] = str(kwargs.get("stderr", {}).name if hasattr(kwargs.get("stderr"), "name") else "") + return None + + captured_save = {} + + def mock_save_command(script_dir, cmd, suffix=""): + captured_save["suffix"] = suffix + captured_save["dir"] = str(script_dir) + + captured_config = {} + + real_omegaconf_save = None + + with tempfile.TemporaryDirectory() as tmp: + exp_dir = Path(tmp) + with patch("pipelinerl.launch._popen", side_effect=mock_popen): + with patch("pipelinerl.launch.save_command", side_effect=mock_save_command): + with patch("os.path.isdir", return_value=True): + with patch("omegaconf.OmegaConf.save") as mock_cfg_save: + list(_run_finetune_fast_llm(cfg, world_map, gpus=gpus or [0, 1, 2, 3], exp_dir=exp_dir)) + if mock_cfg_save.call_args: + # OmegaConf.save(cfg, path) — second positional arg is path + args = mock_cfg_save.call_args[0] + 
captured_config["path"] = str(args[1]) if len(args) > 1 else "" + + return { + "config_path": captured_config.get("path", ""), + "save_suffix": captured_save.get("suffix", ""), + "stdout": written_files.get("stdout", ""), + "stderr": written_files.get("stderr", ""), + } + + def _capture_deepspeed_files(self, world_map, gpus=None): + """Run _run_finetune_deepspeed and return captured file suffix.""" + from pipelinerl.launch import _run_finetune_deepspeed + + cfg = OmegaConf.create({ + "use_deepspeed": True, + "use_fsdp": False, + "deepspeed_config": "zero2", + "accelerate_config": None, + "world": {"actor_group_port": 9000}, + "debug": {"mode": ""}, + }) + + captured_save = {} + written_files = {} + + def mock_popen(cmd, **kwargs): + written_files["stdout"] = str(kwargs.get("stdout", {}).name if hasattr(kwargs.get("stdout"), "name") else "") + written_files["stderr"] = str(kwargs.get("stderr", {}).name if hasattr(kwargs.get("stderr"), "name") else "") + return None + + def mock_save_command(script_dir, cmd, suffix=""): + captured_save["suffix"] = suffix + + with tempfile.TemporaryDirectory() as tmp: + exp_dir = Path(tmp) + with patch("pipelinerl.launch._popen", side_effect=mock_popen): + with patch("pipelinerl.launch.save_command", side_effect=mock_save_command): + with patch.dict(os.environ, {"MASTER_ADDR": "dns-test-0", "MASTER_PORT": "29501"}): + list(_run_finetune_deepspeed(cfg, world_map, gpus=gpus or [0, 1, 2, 3], exp_dir=exp_dir)) + + return { + "save_suffix": captured_save.get("suffix", ""), + "stdout": written_files.get("stdout", ""), + "stderr": written_files.get("stderr", ""), + } + + # --- fast-llm single-node: no suffix --- + + def test_fast_llm_single_node_no_suffix(self): + """Single-node fast-llm: no _node0 suffix — backward compat.""" + cfg = _make_cfg(actor_fraction=2, finetune_fraction=6) + with patch("torch.cuda.device_count", return_value=8): + with patch("pipelinerl.utils.collect_environment_specs", return_value=[]): + with patch("pipelinerl.world.WorldMap._place_environments"): + from pipelinerl.world import WorldMap + wm = WorldMap(cfg, verbose=False) + + result = self._capture_fast_llm_files(wm) + assert result["save_suffix"] == "", f"Single-node must have no suffix, got: {result['save_suffix']!r}" + assert "_node" not in result["config_path"], f"Single-node config must have no _node suffix: {result['config_path']}" + + # --- fast-llm multinode: each node gets its own suffix --- + + def test_fast_llm_multinode_node0_suffix(self): + """4-node fast-llm, finetune node 0: files get _node0 suffix. 
+ Actor takes the last node (rank 3), so ranks 0/1/2 are finetune.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=3) + wm = _make_world_map(cfg, world_size=4, rank=0) # rank 0 = first finetune node + + result = self._capture_fast_llm_files(wm) + assert result["save_suffix"] == "_node0", f"Expected _node0, got: {result['save_suffix']!r}" + assert "_node0" in result["config_path"], f"Config path must contain _node0: {result['config_path']}" + + def test_fast_llm_multinode_node1_suffix(self): + """4-node fast-llm, finetune node 1: files get _node1 suffix.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=3) + wm = _make_world_map(cfg, world_size=4, rank=1) # rank 1 = second finetune node + + result = self._capture_fast_llm_files(wm) + assert result["save_suffix"] == "_node1", f"Expected _node1, got: {result['save_suffix']!r}" + assert "_node1" in result["config_path"] + + def test_fast_llm_multinode_node2_suffix(self): + """4-node fast-llm, finetune node 2: files get _node2 suffix.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=3) + wm = _make_world_map(cfg, world_size=4, rank=2) # rank 2 = third finetune node + + result = self._capture_fast_llm_files(wm) + assert result["save_suffix"] == "_node2", f"Expected _node2, got: {result['save_suffix']!r}" + + # --- DeepSpeed single-node: no suffix --- + + def test_deepspeed_single_node_no_suffix(self): + """Single-node DeepSpeed: no _node suffix.""" + cfg = _make_cfg(actor_fraction=2, finetune_fraction=6, use_fast_llm=False) + with patch("torch.cuda.device_count", return_value=8): + with patch("pipelinerl.utils.collect_environment_specs", return_value=[]): + with patch("pipelinerl.world.WorldMap._place_environments"): + from pipelinerl.world import WorldMap + wm = WorldMap(cfg, verbose=False) + + result = self._capture_deepspeed_files(wm) + assert result["save_suffix"] == "", f"Single-node must have no suffix, got: {result['save_suffix']!r}" + + # --- DeepSpeed multinode: each node gets its own suffix --- + + def test_deepspeed_multinode_node0_suffix(self): + """4-node DeepSpeed, finetune node 0: save_command gets _node0 suffix. 
+ Actor takes the last node (rank 3), so ranks 0/1/2 are finetune.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=3, use_fast_llm=False) + wm = _make_world_map(cfg, world_size=4, rank=0) # rank 0 = first finetune node + + result = self._capture_deepspeed_files(wm) + assert result["save_suffix"] == "_node0", f"Expected _node0, got: {result['save_suffix']!r}" + + def test_deepspeed_multinode_node2_suffix(self): + """4-node DeepSpeed, finetune node 2: save_command gets _node2 suffix.""" + cfg = _make_cfg(actor_fraction=1, finetune_fraction=3, use_fast_llm=False) + wm = _make_world_map(cfg, world_size=4, rank=2) # rank 2 = third finetune node + + result = self._capture_deepspeed_files(wm) + assert result["save_suffix"] == "_node2", f"Expected _node2, got: {result['save_suffix']!r}" diff --git a/tests/trainer_test_utils.py b/tests/trainer_test_utils.py new file mode 100644 index 00000000..d6de57e5 --- /dev/null +++ b/tests/trainer_test_utils.py @@ -0,0 +1,128 @@ +"""Shared utilities for trainer helper scripts (both HTTP and fast-llm variants).""" + + +def _resolve_model_path(model_name: str): + """Resolve model name to a local Path, downloading from HuggingFace if needed.""" + from pathlib import Path + from huggingface_hub import snapshot_download + + model_path = Path(model_name) + if not model_path.exists(): + print(f"[Trainer] Downloading model from HuggingFace Hub: {model_name}") + model_path = Path(snapshot_download(model_name)) + return model_path + + +def _load_state_dict(model_name: str, device: str = "cuda:0") -> tuple: + """Load model state dict from safetensors files. + + Returns: + (state_dict, model_path) + """ + import json + from safetensors.torch import load_file + + model_path = _resolve_model_path(model_name) + index_file = model_path / "model.safetensors.index.json" + + if index_file.exists(): + print(f"[Trainer] Found index file, loading sharded model") + with open(index_file) as f: + index = json.load(f) + weight_map = index["weight_map"] + + file_to_params = {} + for param_name, filename in weight_map.items(): + file_to_params.setdefault(filename, []).append(param_name) + + state_dict = {} + for filename, param_names in file_to_params.items(): + file_path = model_path / filename + print(f"[Trainer] Loading {len(param_names)} parameters from {filename}") + tensors = load_file(str(file_path), device=device) + for param_name in param_names: + state_dict[param_name] = tensors[param_name] + else: + safetensors_file = model_path / "model.safetensors" + print(f"[Trainer] Loading from single file: {safetensors_file}") + state_dict = load_file(str(safetensors_file), device=device) + + print(f"[Trainer] Loaded {len(state_dict)} parameters from safetensors") + return state_dict, model_path + + +def _create_perturbed_state_dict( + state_dict: dict, seed: int = 42, noise_scale: float = 0.001 +) -> dict: + """Return a new state dict with Gaussian noise added to all tensors.""" + import torch + + print(f"[Trainer] Creating perturbed weights (all tensors) with seed={seed}...") + torch.manual_seed(seed) + perturbed = {} + for name, tensor in state_dict.items(): + perturbed_tensor = tensor.clone() + perturbed_tensor.add_(torch.randn_like(perturbed_tensor) * noise_scale) + perturbed[name] = perturbed_tensor + print( + f"[Trainer] Perturbed all {len(perturbed)} tensors with noise={noise_scale}, seed={seed}" + ) + return perturbed + + +def _init_actor_process_group(init_method: str, rank: int = 0, world_size: int = 2, group_name: str = "actor"): + """Initialize the actor NCCL 
process group and return it.""" + import pipelinerl.torch_utils + + print(f"[Trainer] Initializing process group as rank {rank} (group_name={group_name!r})") + process_group = pipelinerl.torch_utils.init_extra_process_group( + group_name=group_name, + backend="nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + ) + print("[Trainer] Process group initialized") + return process_group + + +def _broadcast_tensors(state_dict: dict, process_group, log_interval: int = 50): + """Broadcast every tensor in state_dict via NCCL (src=0).""" + import torch.distributed as dist + + total = len(state_dict) + for i, (name, tensor) in enumerate(state_dict.items()): + if tensor.device.type != "cuda": + tensor = tensor.cuda(0) + dist.broadcast(tensor, src=0, group=process_group) + if (i + 1) % log_interval == 0: + print(f"[Trainer] Broadcasted {i+1}/{total} parameters") + print(f"[Trainer] All {total} parameters broadcasted") + + +def _wait_for_servers_ready(server_urls: list, extra_wait_secs: int = 10): + """Poll /health on each server until all respond 200, then sleep extra_wait_secs.""" + import time + import requests + + for server_url in server_urls: + print(f"[Trainer] Waiting for server {server_url} to be ready...") + server_ready = False + for i in range(120): # up to 2 minutes + try: + resp = requests.get(f"{server_url}/health", timeout=1) + if resp.status_code == 200: + server_ready = True + print(f"[Trainer] Server {server_url} is ready (took {i} seconds)") + break + except requests.exceptions.RequestException: + pass + time.sleep(1) + if not server_ready: + raise TimeoutError(f"Server {server_url} did not become ready within 2 minutes") + + if extra_wait_secs > 0: + print( + f"[Trainer] Waiting additional {extra_wait_secs} seconds for server(s) to fully initialize..." + ) + time.sleep(extra_wait_secs) diff --git a/tests/vllm_engine_helper.py b/tests/vllm_engine_helper.py new file mode 100755 index 00000000..798743bc --- /dev/null +++ b/tests/vllm_engine_helper.py @@ -0,0 +1,617 @@ +#!/usr/bin/env python3 +"""Helper script for running vLLM engine in a subprocess with proper CUDA isolation. + +This script is run as a separate process with CUDA_VISIBLE_DEVICES set, +ensuring the engine only sees the intended GPU. +""" + +import sys +import argparse +import asyncio + + +async def init_engine_and_process_group( + model_name: str, + init_method: str, + actor_llm_idx: int, + world_size: int, +): + """Initialize vLLM engine and process group. + + create_engine() automatically calls init_actor_update_group() when + disable_weight_updates=False, and calls destroy_actor_update_group() + on context manager exit. + """ + from pipelinerl.vllm1 import EngineManager + import argparse as ap + + print("[vLLM Engine] Starting engine initialization") + + # Create args for engine with process group params + args = ap.Namespace( + model=model_name, + tensor_parallel_size=1, + disable_log_stats=True, + enable_log_requests=False, + disable_weight_updates=False, + # Process group params - needed for automatic init_actor_update_group() + actor_llm_idx=actor_llm_idx, + weight_update_group_init_method=init_method, + weight_update_group_world_size=world_size, + ) + + print(f"[vLLM Engine] Creating engine with model={model_name}") + + # create_engine automatically: + # 1. Creates engine and manager + # 2. Calls manager.init_actor_update_group() (rank 1) + # 3. 
On exit, calls manager.destroy_actor_update_group() + async with EngineManager.create_engine(args) as manager: + print("[vLLM Engine] Engine and process group created successfully") + + # Keep engine alive until trainer completes its work + print("[vLLM Engine] Process group active, waiting for trainer...") + await asyncio.sleep(5) + + # Context manager exit automatically cleans up process group + print("[vLLM Engine] Engine and process group cleaned up") + + +async def test_weight_update( + model_name: str, + init_method: str, + actor_llm_idx: int, + world_size: int, + prompt: str, + max_tokens: int, + sync_dir: str, + expect_different: bool = False, +): + """Test weight update with generation before and after. + + This mode: + 1. Creates engine and initializes process group + 2. Generates baseline output + 3. Signals baseline_done, waits for broadcast_done + 4. Receives weight update + 5. Generates again with same prompt + 6. Prints both outputs for comparison + """ + from pipelinerl.vllm1 import EngineManager + from vllm import SamplingParams + from pathlib import Path + import argparse as ap + # Import sync helper from same directory + sys.path.insert(0, str(Path(__file__).parent)) + from sync_helper import SyncPoint + + print("[vLLM Engine] Starting weight update test") + + # Create sync points + sync_path = Path(sync_dir) + baseline_done = SyncPoint(sync_path, "baseline_done") + ready_to_receive = SyncPoint(sync_path, "ready_to_receive") + request_ready = SyncPoint(sync_path, "request_ready") + receiving_started = SyncPoint(sync_path, "receiving_started") + broadcast_done = SyncPoint(sync_path, "broadcast_done") + + # Create args for engine with process group params + args = ap.Namespace( + model=model_name, + tensor_parallel_size=1, + disable_log_stats=True, + enable_log_requests=False, + disable_weight_updates=False, + actor_llm_idx=actor_llm_idx, + weight_update_group_init_method=init_method, + weight_update_group_world_size=world_size, + ) + + print(f"[vLLM Engine] Creating engine with model={model_name}") + + async with EngineManager.create_engine(args) as manager: + print("[vLLM Engine] Engine and process group created successfully") + + # Step 1: Generate baseline + sampling_params = SamplingParams( + temperature=0.0, + top_p=1.0, + max_tokens=max_tokens, + seed=42, + ) + + print(f"[vLLM Engine] Generating baseline with prompt: '{prompt}'") + async for output in manager.engine.generate( + prompt, + sampling_params=sampling_params, + request_id="baseline", + ): + baseline_output = output + + baseline_text = baseline_output.outputs[0].text + print(f"[vLLM Engine] Baseline output: '{baseline_text}'") + + # Step 2: Signal baseline done and ready to receive + baseline_done.signal() + ready_to_receive.signal() + + # Step 3: Wait for trainer to send WeightUpdateRequest + print("[vLLM Engine] Waiting for trainer to send weight update request...") + request_ready.wait(timeout=60) + + # Step 4: Read WeightUpdateRequest from trainer + from sync_helper import read_weight_update_request + request = read_weight_update_request(sync_path) + print(f"[vLLM Engine] Received request with {len(request.parameters_info)} parameters") + + # Step 5: Signal we're about to start receiving, then call receive_weight_update + receiving_started.signal() + print("[vLLM Engine] Signaled receiving_started, calling receive_weight_update...") + print("[vLLM Engine] (This will block until trainer broadcasts all weights)") + await manager.receive_weight_update(request) + print("[vLLM Engine] Weight update 
received!") + + # Step 6: Wait for trainer to signal broadcast complete + broadcast_done.wait(timeout=60) + print("[vLLM Engine] Trainer confirmed broadcast complete") + + # Step 7: Generate again with same prompt + print(f"[vLLM Engine] Generating after update with prompt: '{prompt}'") + async for output in manager.engine.generate( + prompt, + sampling_params=sampling_params, + request_id="after_update", + ): + updated_output = output + + updated_text = updated_output.outputs[0].text + print(f"[vLLM Engine] Updated output: '{updated_text}'") + + # Step 8: Compare outputs + if expect_different: + # Perturbed weights - expect different outputs + if baseline_text != updated_text: + print("[vLLM Engine] ✓ Outputs differ (as expected for perturbed weights)") + print(f"[vLLM Engine] Baseline: '{baseline_text}'") + print(f"[vLLM Engine] Updated: '{updated_text}'") + else: + print("[vLLM Engine] ✗ Outputs are the same!") + print(f"[vLLM Engine] Both: '{baseline_text}'") + print("[vLLM Engine] ERROR: Perturbed weights should have changed the output") + sys.exit(1) + else: + # Same weights - expect same outputs + if baseline_text == updated_text: + print("[vLLM Engine] ✓ Outputs match (as expected for same weights)") + else: + print("[vLLM Engine] ✗ Outputs differ!") + print(f"[vLLM Engine] Baseline: '{baseline_text}'") + print(f"[vLLM Engine] Updated: '{updated_text}'") + sys.exit(1) + + print("[vLLM Engine] Engine and process group cleaned up") + + +async def test_cross_validation( + model_name: str, + init_method: str, + actor_llm_idx: int, + world_size: int, + prompt: str, + max_tokens: int, + sync_dir: str, +): + """Cross-validation test for weight updates. + + Tests that broadcasting weights produces same results as loading from disk. + Flow: + 1. Generate with original model → res_un_1 + 2. Receive perturbed weights, generate → res_mod_1 + 3. Recreate engine with perturbed model from disk, generate → res_mod_2 + 4. Receive original weights, generate → res_un_2 + 5. 
Verify: res_un_1 == res_un_2 and res_mod_1 == res_mod_2 + """ + from pipelinerl.vllm1 import EngineManager + from vllm import SamplingParams + from pathlib import Path + import argparse as ap + sys.path.insert(0, str(Path(__file__).parent)) + from sync_helper import SyncPoint, read_weight_update_request + + print("[vLLM Engine] Starting cross-validation test") + + # Create sync points + sync_path = Path(sync_dir) + baseline_done = SyncPoint(sync_path, "baseline_done") + perturbed_model_saved = SyncPoint(sync_path, "perturbed_model_saved") + ready_to_receive_perturbed = SyncPoint(sync_path, "ready_to_receive_perturbed") + perturbed_broadcast_done = SyncPoint(sync_path, "perturbed_broadcast_done") + mod1_done = SyncPoint(sync_path, "mod1_done") + first_engine_destroyed = SyncPoint(sync_path, "first_engine_destroyed") + engine_recreated = SyncPoint(sync_path, "engine_recreated") + ready_to_receive_original = SyncPoint(sync_path, "ready_to_receive_original") + original_broadcast_done = SyncPoint(sync_path, "original_broadcast_done") + + sampling_params = SamplingParams( + temperature=0.0, + top_p=1.0, + max_tokens=max_tokens, + seed=42, + ) + + # Step 1: Generate with original model + args = ap.Namespace( + model=model_name, + tensor_parallel_size=1, + disable_log_stats=True, + enable_log_requests=False, + disable_weight_updates=False, + actor_llm_idx=actor_llm_idx, + weight_update_group_init_method=init_method, + weight_update_group_world_size=world_size, + ) + + print(f"[vLLM Engine] Step 1: Creating engine with original model: {model_name}") + async with EngineManager.create_engine(args) as manager: + print(f"[vLLM Engine] Generating res_un_1 with prompt: '{prompt}'") + async for output in manager.engine.generate( + prompt, + sampling_params=sampling_params, + request_id="res_un_1", + ): + res_un_1_output = output + res_un_1 = res_un_1_output.outputs[0].text + print(f"[vLLM Engine] res_un_1: '{res_un_1}'") + + baseline_done.signal() + + # Wait for perturbed model to be saved + print("[vLLM Engine] Waiting for trainer to save perturbed model...") + perturbed_model_saved.wait(timeout=180) + + # Step 2: Receive perturbed weights and generate + ready_to_receive_perturbed.signal() + print("[vLLM Engine] Waiting for perturbed weight update request...") + + # Wait a moment for request file to be written + import time + time.sleep(0.5) + + request = read_weight_update_request(sync_path) + print(f"[vLLM Engine] Received perturbed request with {len(request.parameters_info)} parameters") + + print("[vLLM Engine] Receiving perturbed weights...") + await manager.receive_weight_update(request) + + perturbed_broadcast_done.wait(timeout=900) + print("[vLLM Engine] Perturbed weights received") + + print(f"[vLLM Engine] Generating res_mod_1 with prompt: '{prompt}'") + async for output in manager.engine.generate( + prompt, + sampling_params=sampling_params, + request_id="res_mod_1", + ): + res_mod_1_output = output + res_mod_1 = res_mod_1_output.outputs[0].text + print(f"[vLLM Engine] res_mod_1: '{res_mod_1}'") + + mod1_done.signal() + + # Engine destroyed here (context manager exit) + print("[vLLM Engine] First engine destroyed") + first_engine_destroyed.signal() + + # Step 3: Recreate engine with perturbed model from disk + perturbed_model_path = (sync_path / "perturbed_model_path.txt").read_text().strip() + print(f"[vLLM Engine] Step 3: Recreating engine with perturbed model from: {perturbed_model_path}") + + args_perturbed = ap.Namespace( + model=perturbed_model_path, + tensor_parallel_size=1, + 
disable_log_stats=True, + enable_log_requests=False, + disable_weight_updates=False, + actor_llm_idx=actor_llm_idx, + weight_update_group_init_method=init_method, + weight_update_group_world_size=world_size, + ) + + async with EngineManager.create_engine(args_perturbed) as manager: + # Signal immediately after engine is created + engine_recreated.signal() + print("[vLLM Engine] Engine recreated, signaled to trainer") + + print(f"[vLLM Engine] Generating res_mod_2 with prompt: '{prompt}'") + async for output in manager.engine.generate( + prompt, + sampling_params=sampling_params, + request_id="res_mod_2", + ): + res_mod_2_output = output + res_mod_2 = res_mod_2_output.outputs[0].text + print(f"[vLLM Engine] res_mod_2: '{res_mod_2}'") + + # Step 4: Receive original weights and generate + ready_to_receive_original.signal() + print("[vLLM Engine] Waiting for original weight update request...") + + time.sleep(0.5) + request = read_weight_update_request(sync_path) + print(f"[vLLM Engine] Received original request with {len(request.parameters_info)} parameters") + + print("[vLLM Engine] Receiving original weights...") + await manager.receive_weight_update(request) + + original_broadcast_done.wait(timeout=900) + print("[vLLM Engine] Original weights received") + + print(f"[vLLM Engine] Generating res_un_2 with prompt: '{prompt}'") + async for output in manager.engine.generate( + prompt, + sampling_params=sampling_params, + request_id="res_un_2", + ): + res_un_2_output = output + res_un_2 = res_un_2_output.outputs[0].text + print(f"[vLLM Engine] res_un_2: '{res_un_2}'") + + # Step 5: Verify + print("\n" + "="*60) + print("CROSS-VALIDATION RESULTS") + print("="*60) + print(f"res_un_1: '{res_un_1}'") + print(f"res_un_2: '{res_un_2}'") + print(f"res_mod_1: '{res_mod_1}'") + print(f"res_mod_2: '{res_mod_2}'") + print("="*60) + + # Check assertions + success = True + if res_un_1 == res_un_2: + print("✓ res_un_1 == res_un_2 (original weights produce same output)") + else: + print("✗ res_un_1 != res_un_2 (FAILED)") + success = False + + if res_mod_1 == res_mod_2: + print("✓ res_mod_1 == res_mod_2 (broadcast = load from disk)") + else: + print("✗ res_mod_1 != res_mod_2 (FAILED)") + success = False + + if not success: + sys.exit(1) + + print("\n✓ Cross-validation test PASSED") + + +async def test_back_and_forth( + model_name: str, + init_method: str, + actor_llm_idx: int, + world_size: int, + prompt: str, + max_tokens: int, + sync_dir: str, + tensor_parallel_size: int = 1, +): + """Back-and-forth test: switch between original and perturbed weights. + + Flow: + 1. Generate with original → res_or_1 + 2. Receive perturbed, generate → res_mod_1 + 3. Receive original, generate → res_or_2 + 4. Receive perturbed again, generate → res_mod_2 + 5. Verify: res_or_1 == res_or_2 and res_mod_1 == res_mod_2 + """ + from pipelinerl.vllm1 import EngineManager + from vllm import SamplingParams + from pathlib import Path + import argparse as ap + sys.path.insert(0, str(Path(__file__).parent)) + from sync_helper import SyncPoint, read_weight_update_request + + print("[vLLM Engine] Starting back-and-forth test") + + # Create sync points — actor-signaled names use per-actor suffix; + # completion signals (trainer→actors) stay unadorned and are shared. 
+ sync_path = Path(sync_dir) + suffix = f"_actor_{actor_llm_idx}" + baseline_done = SyncPoint(sync_path, f"baseline_done{suffix}") + ready_for_perturbed1 = SyncPoint(sync_path, f"ready_for_perturbed1{suffix}") + perturbed1_done = SyncPoint(sync_path, "perturbed1_done") + ready_for_original = SyncPoint(sync_path, f"ready_for_original{suffix}") + original_done = SyncPoint(sync_path, "original_done") + ready_for_perturbed2 = SyncPoint(sync_path, f"ready_for_perturbed2{suffix}") + perturbed2_done = SyncPoint(sync_path, "perturbed2_done") + + sampling_params = SamplingParams( + temperature=0.0, + top_p=1.0, + max_tokens=max_tokens, + seed=42, + ) + + # Create engine args + args = ap.Namespace( + model=model_name, + tensor_parallel_size=tensor_parallel_size, + disable_log_stats=True, + enable_log_requests=False, + disable_weight_updates=False, + actor_llm_idx=actor_llm_idx, + weight_update_group_init_method=init_method, + weight_update_group_world_size=world_size, + ) + + print(f"[vLLM Engine] Creating engine with model: {model_name}") + async with EngineManager.create_engine(args) as manager: + # Step 1: Generate with original weights + print(f"[vLLM Engine] Step 1: Generating res_or_1") + async for output in manager.engine.generate( + prompt, sampling_params=sampling_params, request_id="res_or_1" + ): + res_or_1 = output.outputs[0].text + print(f"[vLLM Engine] res_or_1: '{res_or_1}'") + baseline_done.signal() + + # Step 2: Receive perturbed weights, generate + ready_for_perturbed1.signal() + import time + time.sleep(0.5) + request = read_weight_update_request(sync_path) + print(f"[vLLM Engine] Step 2: Receiving perturbed weights (1st time)") + await manager.receive_weight_update(request) + perturbed1_done.wait(timeout=900) + + print(f"[vLLM Engine] Generating res_mod_1") + async for output in manager.engine.generate( + prompt, sampling_params=sampling_params, request_id="res_mod_1" + ): + res_mod_1 = output.outputs[0].text + print(f"[vLLM Engine] res_mod_1: '{res_mod_1}'") + + # Step 3: Receive original weights, generate + ready_for_original.signal() + time.sleep(0.5) + request = read_weight_update_request(sync_path) + print(f"[vLLM Engine] Step 3: Receiving original weights") + await manager.receive_weight_update(request) + original_done.wait(timeout=900) + + print(f"[vLLM Engine] Generating res_or_2") + async for output in manager.engine.generate( + prompt, sampling_params=sampling_params, request_id="res_or_2" + ): + res_or_2 = output.outputs[0].text + print(f"[vLLM Engine] res_or_2: '{res_or_2}'") + + # Step 4: Receive perturbed weights again, generate + ready_for_perturbed2.signal() + time.sleep(0.5) + request = read_weight_update_request(sync_path) + print(f"[vLLM Engine] Step 4: Receiving perturbed weights (2nd time)") + await manager.receive_weight_update(request) + perturbed2_done.wait(timeout=900) + + print(f"[vLLM Engine] Generating res_mod_2") + async for output in manager.engine.generate( + prompt, sampling_params=sampling_params, request_id="res_mod_2" + ): + res_mod_2 = output.outputs[0].text + print(f"[vLLM Engine] res_mod_2: '{res_mod_2}'") + + # Step 5: Save results to per-actor file for multi-actor comparison + import json + results_file = sync_path / f"results_actor_{actor_llm_idx}.json" + actor_results = { + "res_or_1": res_or_1, + "res_mod_1": res_mod_1, + "res_or_2": res_or_2, + "res_mod_2": res_mod_2, + } + with open(results_file, "w") as f: + json.dump(actor_results, f, indent=2) + print(f"[vLLM Engine] Saved results for actor {actor_llm_idx} to {results_file}") + + # 
Step 6: Verify + print("\n" + "="*60) + print("BACK-AND-FORTH TEST RESULTS") + print("="*60) + print(f"res_or_1: '{res_or_1}'") + print(f"res_or_2: '{res_or_2}'") + print(f"res_mod_1: '{res_mod_1}'") + print(f"res_mod_2: '{res_mod_2}'") + print("="*60) + + # Check assertions + success = True + if res_or_1 == res_or_2: + print("✓ res_or_1 == res_or_2 (can switch back to original)") + else: + print("✗ res_or_1 != res_or_2 (FAILED)") + success = False + + if res_mod_1 == res_mod_2: + print("✓ res_mod_1 == res_mod_2 (perturbed weights consistent)") + else: + print("✗ res_mod_1 != res_mod_2 (FAILED)") + success = False + + if not success: + sys.exit(1) + + print("\n✓ Back-and-forth test PASSED") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="vLLM engine helper") + parser.add_argument("command", choices=["init", "weight_update", "cross_validation", "back_and_forth"]) + parser.add_argument("--model-name", required=True) + parser.add_argument("--init-method", required=True) + parser.add_argument("--actor-llm-idx", type=int, default=0) + parser.add_argument("--world-size", type=int, default=2) + # For weight_update command + parser.add_argument("--prompt", type=str, default="The capital of France is") + parser.add_argument("--max-tokens", type=int, default=50) + parser.add_argument("--sync-dir", type=str, help="Directory for sync files") + parser.add_argument("--expect-different", action="store_true", help="Expect outputs to be different (for perturbed weights)") + parser.add_argument("--tensor-parallel-size", type=int, default=1, help="Tensor parallel size for engine") + + args = parser.parse_args() + + try: + if args.command == "init": + asyncio.run(init_engine_and_process_group( + args.model_name, + args.init_method, + args.actor_llm_idx, + args.world_size, + )) + elif args.command == "weight_update": + if not args.sync_dir: + print("Error: --sync-dir required for weight_update command") + sys.exit(1) + asyncio.run(test_weight_update( + args.model_name, + args.init_method, + args.actor_llm_idx, + args.world_size, + args.prompt, + args.max_tokens, + args.sync_dir, + args.expect_different, + )) + elif args.command == "cross_validation": + if not args.sync_dir: + print("Error: --sync-dir required for cross_validation command") + sys.exit(1) + asyncio.run(test_cross_validation( + args.model_name, + args.init_method, + args.actor_llm_idx, + args.world_size, + args.prompt, + args.max_tokens, + args.sync_dir, + )) + elif args.command == "back_and_forth": + if not args.sync_dir: + print("Error: --sync-dir required for back_and_forth command") + sys.exit(1) + asyncio.run(test_back_and_forth( + args.model_name, + args.init_method, + args.actor_llm_idx, + args.world_size, + args.prompt, + args.max_tokens, + args.sync_dir, + tensor_parallel_size=args.tensor_parallel_size, + )) + except Exception as e: + print(f"[vLLM Engine] Error: {e}") + import traceback + traceback.print_exc() + sys.exit(1) diff --git a/tests/weight_update_utils.py b/tests/weight_update_utils.py new file mode 100644 index 00000000..3e36a1e6 --- /dev/null +++ b/tests/weight_update_utils.py @@ -0,0 +1,53 @@ +"""Utility functions for weight update testing.""" + +from typing import Dict +import torch +from pipelinerl.finetune_loop import WeightUpdateRequest, ParameterInfo + + +def dtype_to_string(dtype: torch.dtype) -> str: + """Convert torch dtype to string format expected by vLLM. 
+ + Args: + dtype: PyTorch dtype + + Returns: + String representation (e.g., 'bfloat16', 'float32') + """ + dtype_str = str(dtype).replace("torch.", "") + return dtype_str + + +def create_weight_update_request_from_state_dict( + state_dict: Dict[str, torch.Tensor], + version: int = 0, +) -> WeightUpdateRequest: + """Create a WeightUpdateRequest from a model state dict. + + This helper function is useful for testing and for creating weight + update requests from saved model checkpoints. + + Args: + state_dict: Dictionary mapping parameter names to tensors + version: Version number for this weight update + + Returns: + WeightUpdateRequest object ready to be sent to workers + + Example: + >>> state_dict = torch.load('model.pt') + >>> request = create_weight_update_request_from_state_dict(state_dict, version=1) + >>> # Send request to vLLM server via HTTP endpoint + """ + parameters_info = [] + for name, tensor in state_dict.items(): + if isinstance(tensor, torch.Tensor): + parameters_info.append( + ParameterInfo( + name=name, + shape=list(tensor.shape), + dtype=dtype_to_string(tensor.dtype), + ) + ) + + return WeightUpdateRequest(version=version, parameters_info=parameters_info)
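
Taken together, these helpers form the trainer side of the weight-update tests: load a model's weights, describe them as a `WeightUpdateRequest`, join the broadcast group as rank 0, and push every tensor to the vLLM engine, which joins the same group as rank 1 via `vllm_engine_helper.py`. The sketch below is a minimal, hypothetical driver that wires the utilities from `tests/trainer_test_utils.py` and `tests/weight_update_utils.py` together. It is illustrative only: it assumes it is run from the `tests/` directory (so the modules import by name), and it omits the sync-file handshake that the real test scripts use to coordinate with the engine process.

```python
"""Hypothetical trainer-side driver (illustrative sketch, not part of the patch above)."""
import argparse

from trainer_test_utils import (
    _broadcast_tensors,
    _create_perturbed_state_dict,
    _init_actor_process_group,
    _load_state_dict,
)
from weight_update_utils import create_weight_update_request_from_state_dict


def main() -> None:
    parser = argparse.ArgumentParser(description="Broadcast perturbed weights to a vLLM actor")
    parser.add_argument("--model-name", required=True)
    parser.add_argument("--init-method", required=True)  # e.g. a tcp:// rendezvous address
    args = parser.parse_args()

    # Load the reference weights and derive a perturbed copy (seed fixed inside the helper).
    state_dict, _model_path = _load_state_dict(args.model_name, device="cuda:0")
    perturbed = _create_perturbed_state_dict(state_dict)

    # Describe what will be broadcast; in the tests this request is handed to the
    # engine via the sync-file helpers rather than sent over HTTP.
    request = create_weight_update_request_from_state_dict(perturbed, version=1)
    print(f"[Trainer] Prepared request with {len(request.parameters_info)} parameters")

    # Join the shared NCCL group as rank 0 and push every tensor to the engine (rank 1).
    group = _init_actor_process_group(args.init_method, rank=0, world_size=2)
    _broadcast_tensors(perturbed, group)


if __name__ == "__main__":
    main()
```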