
feat: add FrozenLake multi-turn tool-call GRPO training example #168

Merged
benjibc merged 6 commits into main from bchen/frozen-lake-grpo-example
Mar 6, 2026
Conversation

@benjibc
Contributor

@benjibc benjibc commented Mar 5, 2026


Summary

  • Add a complete GRPO training example for FrozenLake with multi-turn tool calls, using the eval-protocol SDK for rollouts and Fireworks hosted trainer/deployment infrastructure
  • Implement per-position loss mask in GRPO loss so only model-generated completion tokens receive gradients (environment/tool-response tokens are masked out)
  • FrozenLake domain modules (env, schema, rollout processor) are self-contained in training/examples/frozen_lake/ — eval-protocol stays generic
  • Rollout processor uses the generic FireworksV1CompletionsClient with a pluggable tool_call_parser callback
  • GRPO loss reports granular metrics: active_tokens, mask_ratio, mean_adv_loss, mean_kl_penalty, inf_kld
  • Monotonic WandB step counter avoids step conflicts on filtered/skipped steps
  • Signal handling (SIGTERM/SIGINT) and robust cleanup: always delete auto-created deployments and trainer jobs on exit
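The per-position masking described above can be sketched roughly as follows. This is an illustrative implementation, not the PR's actual code: the function signature and tensor shapes are assumptions, but the metric names mirror the summary bullet.

```python
import torch

def masked_grpo_loss(logprob_ratio, advantages, kl_penalty, loss_mask, kl_coef=0.1):
    """Sketch of a per-position masked GRPO loss (illustrative, not the PR's code).

    `loss_mask` is 1.0 on model-generated completion tokens and 0.0 on
    environment/tool-response tokens, so only sampled tokens contribute
    gradients. All tensors are [batch, seq_len].
    """
    adv_term = -(logprob_ratio * advantages)        # policy-gradient surrogate
    kl_term = kl_coef * kl_penalty                  # per-token KL penalty
    active_tokens = loss_mask.sum().clamp(min=1.0)  # avoid divide-by-zero on empty masks
    loss = ((adv_term + kl_term) * loss_mask).sum() / active_tokens
    metrics = {
        "active_tokens": active_tokens.item(),
        "mask_ratio": (active_tokens / loss_mask.numel()).item(),
        "mean_adv_loss": (adv_term * loss_mask).sum().item() / active_tokens.item(),
        "mean_kl_penalty": (kl_term * loss_mask).sum().item() / active_tokens.item(),
    }
    return loss, metrics
```

Normalizing by `active_tokens` rather than sequence length keeps the loss scale comparable across episodes with very different amounts of environment text.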

Files

New: Frozen Lake example

  • training/examples/frozen_lake/train_frozen_lake.py — main GRPO training script
  • training/examples/frozen_lake/verify_rollout.py — single-rollout verification with eval-protocol UI
  • training/examples/frozen_lake/frozen_lake_env.py — deterministic FrozenLake environment
  • training/examples/frozen_lake/frozen_lake_schema.py — tool schema, action defs, XML parsing
  • training/examples/frozen_lake/frozen_lake_rollout.py — rollout processor wiring generic client with FrozenLake env
  • training/examples/frozen_lake/seeds.jsonl — reproducible seed contexts

Modified: training utilities

  • training/utils/rl/common.py — _get_loss_mask helper for per-position masking
  • training/utils/rl/grpo.py — apply loss_mask, emit granular metrics
  • training/utils/rl/train.py — emit rollout metrics on filtered/skipped steps
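The monotonic WandB step counter mentioned in the summary can be sketched like this. The class name and wrapper shape are assumptions; it only relies on the documented wandb behavior that a logged `step` must never be lower than the last one:

```python
class MonotonicStepLogger:
    """Illustrative monotonic step counter for wandb-style logging.

    When some training steps are filtered/skipped (e.g. every prompt group
    is removed), reusing the raw step index can collide with an earlier
    log call. This wrapper always advances, so filtered steps can still
    push rollout metrics without step conflicts.
    """

    def __init__(self, log_fn):
        self._log_fn = log_fn  # e.g. wandb.log
        self._step = -1

    def log(self, metrics, step=None):
        requested = step if step is not None else self._step + 1
        # Never repeat or regress: advance to at least last + 1.
        self._step = max(self._step + 1, requested)
        self._log_fn(metrics, step=self._step)
        return self._step
```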

Dependencies

Test plan

  • Run train_frozen_lake.py end-to-end with qwen3-4b and verify WandB logging
  • Run verify_rollout.py and confirm token debug view shows correct masking in eval-protocol UI
  • Verify cleanup: kill training process with SIGTERM and confirm deployment + trainer jobs are deleted

benjibc added 6 commits March 5, 2026 14:29
Add a complete GRPO training example for FrozenLake with multi-turn
tool calls using the eval-protocol SDK for rollouts and Fireworks
hosted trainer/deployment infrastructure.

Key changes:
- New frozen_lake example: train_frozen_lake.py, verify_rollout.py, seeds.jsonl
- Per-position loss_mask in GRPO loss for multi-turn episodes (only
  model-generated completion tokens receive gradients, environment/tool
  tokens are masked)
- Training shape support in infra.py (pass training_shape to server,
  clear manual accelerator settings to let server auto-configure)
- Signal handling and robust cleanup (always delete deployment + trainer
  jobs on exit, capture job IDs even on partial failure)
- Log rollout metrics to WandB even when all prompt groups are filtered
- Accept UPDATING deployment state in setup_deployment
- Compatibility with SDK versions that lack disable_speculative_decoding

Made-with: Cursor
…raining

Move domain-specific Frozen Lake modules (env, schema, rollout processor)
from eval-protocol into the cookbook, so eval-protocol stays generic and
the example is self-contained. Key improvements:

- FrozenLake rollout processor now uses the generic
  FireworksV1CompletionsClient with a pluggable tool_call_parser callback
- GRPO loss applies per-position loss_mask and reports granular metrics
  (active_tokens, mask_ratio, mean_adv_loss, mean_kl_penalty, inf_kld)
- Training script logs detailed step summaries and uses monotonic WandB
  step counter to avoid step conflicts on filtered/skipped steps
- Filtered steps still push rollout metrics to WandB for visibility

Made-with: Cursor
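The pluggable tool_call_parser callback described in this commit could look roughly like the following. The `<move>` tag, the accepted actions, and the return shape are assumptions for illustration; the real parsing lives in frozen_lake_schema.py:

```python
import re

def frozen_lake_tool_call_parser(completion_text):
    """Illustrative tool_call_parser callback (tag and return shape assumed).

    The generic FireworksV1CompletionsClient knows nothing about FrozenLake;
    the example plugs in a domain parser that extracts the chosen action
    from an XML-style tag in the raw completion, e.g. <move>left</move>.
    """
    match = re.search(r"<move>\s*(left|right|up|down)\s*</move>", completion_text)
    if match is None:
        return None  # no valid tool call; the caller can treat this as a failed turn
    return {"name": "move", "arguments": {"direction": match.group(1)}}
```

Keeping the parser as a callback is what lets eval-protocol stay generic: the client handles /v1/completions, and each example supplies its own domain-specific extraction.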
…ompletion

With enable_thinking unset, the Qwen3 template doesn't include
<think>\n\n</think>\n\n in the generation prompt, so the model
generates those tokens as part of its completion. This caused them
to receive loss_mask=1.0 and gradients during training.

Setting enable_thinking=False makes the template include the empty
thinking block in the prompt. The model's completion_ids then start
after </think>, correctly excluding template tokens from the loss.

Made-with: Cursor
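The prompt/completion boundary shift described in this commit can be illustrated with plain strings (the strings below are schematic, not the exact Qwen3 template output):

```python
# Schematic illustration of why the empty think block must live in the
# prompt, not the completion, when the loss mask is built from completion_ids.
PROMPT = "<|im_start|>assistant\n"
EMPTY_THINK = "<think>\n\n</think>\n\n"
COMPLETION = "I will move left. <move>left</move>"

# enable_thinking unset: the generation prompt stops at the assistant header,
# so the model itself emits the empty think block -> those tokens land in
# completion_ids and would get loss_mask = 1.0.
gen_prompt_default = PROMPT
completion_default = EMPTY_THINK + COMPLETION

# enable_thinking=False: the template appends the empty think block to the
# prompt, so completion_ids start right after </think> and the template
# tokens are excluded from the loss.
gen_prompt_no_think = PROMPT + EMPTY_THINK
completion_no_think = COMPLETION
```

The full rendered text is identical either way; only the boundary between prompt tokens and loss-bearing completion tokens moves.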
The frozen lake GRPO example imports from eval_protocol for the
generic /v1/completions client and rollout processor types.

Made-with: Cursor
Both training (loss_mask) and visualization (UI mask) now derive from
the same compute_model_output_spans() function, eliminating duplicated
turn-boundary logic. Tests verify the two masks agree on model-generated
positions after accounting for the logprob coordinate shift.

Made-with: Cursor
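The span-to-mask derivation and the coordinate shift mentioned in this commit can be sketched as follows. `compute_model_output_spans()` is the PR's real single source of truth; the helper names and the half-open span convention here are assumptions:

```python
def spans_to_loss_mask(spans, seq_len):
    """Build a per-position mask from model-output spans (illustrative).

    `spans` is a list of (start, end) half-open token-index ranges that the
    model generated; everything outside them (environment/tool tokens) is 0.
    """
    mask = [0.0] * seq_len
    for start, end in spans:
        for i in range(start, min(end, seq_len)):
            mask[i] = 1.0
    return mask

def shift_for_logprobs(mask):
    # The logprob for token i is produced at position i - 1, so the training
    # mask over logprob positions is the token-level (UI) mask shifted left
    # by one, padded with 0.0 at the end.
    return mask[1:] + [0.0]
```

Deriving both the training mask and the UI mask from the same spans, with the shift applied only on the training side, is what lets tests check that the two masks agree on model-generated positions.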
@benjibc
Contributor Author

benjibc commented Mar 6, 2026

Pushed a follow-up commit with the rebase fixes and validated a full run.

@benjibc benjibc force-pushed the bchen/frozen-lake-grpo-example branch from febbaa3 to 3f7adf6 on March 6, 2026 at 01:53
@benjibc benjibc merged commit ca23aaa into main Mar 6, 2026
@benjibc benjibc deleted the bchen/frozen-lake-grpo-example branch March 6, 2026 02:06
Hecate0821 pushed a commit that referenced this pull request Mar 6, 2026
PR #168 (FrozenLake example) incorrectly overwrote several shared
utilities with older/incompatible versions during merge:

- Restore TrainStepFns interface in train.py (reverts MinibatchTrainFns
  rewrite that broke rl_loop.py recipe)
- Remove _install_tinker_future_retrieve_compat() from client.py
  (workaround no longer needed, fixed server-side)
- Restore direct disable_speculative_decoding= in config.py
  (removes unnecessary inspect guard)
- Remove grad_accum param and restore apply_shape() in infra.py
- Update frozen_lake example to use TrainStepFns (1:1 loop)

Made-with: Cursor