feat: add speculative decoding during post-training #1785
base: main
Conversation
Signed-off-by: hiso <hiso@nvidia.com>
terrykong left a comment
📝 Walkthrough
This change introduces speculative decoding metrics instrumentation throughout the GRPO training pipeline and vLLM generation infrastructure. New utilities aggregate and compute speculative decoding metrics from worker groups, vLLM generation classes expose metric collection methods, and GRPO training captures counter snapshots before and after generation to track speculative decoding performance.
Sequence Diagram

```mermaid
sequenceDiagram
participant GRPO as GRPO Training Loop
participant PolicyGen as PolicyGeneration
participant Worker as vLLM Worker
participant Metrics as Metric Aggregation
GRPO->>PolicyGen: policy_generation.get_metrics()
PolicyGen->>Worker: RPC get_metrics() to rank 0 workers
Worker-->>PolicyGen: return spec_counters dict
PolicyGen-->>GRPO: aggregated worker_metrics list
GRPO->>Metrics: spec_counters_start = aggregate_spec_decode_counters()
Note over GRPO,Worker: Generation Phase
GRPO->>PolicyGen: run generation step
PolicyGen->>Worker: generate tokens with spec decode
GRPO->>PolicyGen: policy_generation.get_metrics()
PolicyGen->>Worker: RPC get_metrics() to rank 0 workers
Worker-->>PolicyGen: return updated spec_counters dict
PolicyGen-->>GRPO: aggregated worker_metrics list
GRPO->>Metrics: spec_counters_end = aggregate_spec_decode_counters()
GRPO->>Metrics: compute_spec_decode_metrics(start, end)
Metrics-->>GRPO: spec_metrics (deltas, derived metrics)
    GRPO->>GRPO: merge spec_metrics into training metrics
```
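For orientation, a minimal Python sketch of the snapshot/delta flow shown above. It assumes the `get_metrics`, `aggregate_spec_decode_counters`, and `compute_spec_decode_metrics` names introduced by this PR; the wrapper function and the `generate()` call signature are illustrative, not the actual GRPO loop code.

```python
from nemo_rl.algorithms.utils import (
    aggregate_spec_decode_counters,
    compute_spec_decode_metrics,
)


def generate_with_spec_metrics(policy_generation, batch, metrics: dict):
    """Sketch of the counter-snapshot pattern: delta engine counters around generation."""
    # Snapshot cumulative speculative-decoding counters before generation.
    spec_counters_start = aggregate_spec_decode_counters(policy_generation.get_metrics())

    # Generation phase: vLLM workers produce tokens with speculative decoding enabled.
    outputs = policy_generation.generate(batch)

    # Snapshot again and derive per-step deltas / ratios (acceptance rate, etc.).
    spec_counters_end = aggregate_spec_decode_counters(policy_generation.get_metrics())
    spec_metrics = compute_spec_decode_metrics(spec_counters_start, spec_counters_end)

    # Merge into the step's training metrics so they are logged alongside loss/rewards.
    metrics.update(spec_metrics)
    return outputs
```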
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@nemo_rl/algorithms/grpo.py`:
- Around line 1152-1154: Calls to policy_generation.get_metrics() can raise for
backends (e.g., megatron) that don’t implement get_metrics; guard those calls by
checking capability and defaulting spec_metrics to {}. Update the places using
aggregate_spec_decode_counters(policy_generation.get_metrics()) (e.g., where
spec_counters_start is assigned and the other occurrences around lines
referenced) to first check if hasattr(policy_generation, "get_metrics") or
callable(getattr(policy_generation, "get_metrics", None)); if present call it
and pass the result to aggregate_spec_decode_counters, otherwise pass an empty
dict so spec_metrics/spec_counters_start defaults to {} and training won’t
break.
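A minimal sketch of this guard, assuming the `aggregate_spec_decode_counters` helper from `nemo_rl/algorithms/utils.py`; the wrapper name `safe_spec_counters` is illustrative, not part of the PR:

```python
from nemo_rl.algorithms.utils import aggregate_spec_decode_counters


def safe_spec_counters(policy_generation) -> dict:
    """Return aggregated spec-decode counters, or {} if the backend has no get_metrics()."""
    get_metrics = getattr(policy_generation, "get_metrics", None)
    if not callable(get_metrics):
        # e.g. the megatron generation backend does not implement get_metrics()
        return {}
    return aggregate_spec_decode_counters(get_metrics())
```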
🧹 Nitpick comments (2)
nemo_rl/models/generation/__init__.py (1)
44-45: Guard against missing `vllm_kwargs` for custom configs. If a user config omits `vllm_kwargs` (or sets it to `None`), this line will raise. A small defensive guard keeps behavior identical while avoiding a hard failure in edge configs.
♻️ Suggested guard
```diff
- is_spec = "speculative_config" in config["vllm_kwargs"]
+ vllm_kwargs = config.get("vllm_kwargs") or {}
+ is_spec = "speculative_config" in vllm_kwargs
  config["vllm_cfg"]["load_format"] = "auto" if is_eval or is_spec else "dummy"
```
nemo_rl/models/generation/vllm/vllm_worker.py (1)
284-318: Make patch logging reflect whether it actually applied. Right now the log says “Successfully patched…” even if the snippet wasn’t found (newer vLLM or already patched). Returning a boolean and logging accordingly avoids confusion during upgrades.
♻️ Suggested change
```diff
-def _patch_vllm_speculative_decoding_post_step():
+def _patch_vllm_speculative_decoding_post_step() -> bool:
@@
-    if new_snippet in content or old_snippet not in content:
-        return
+    if new_snippet in content or old_snippet not in content:
+        return False
@@
-    with open(file_to_patch, "w") as f:
-        f.write(content)
+    with open(file_to_patch, "w") as f:
+        f.write(content)
+    return True
@@
-_patch_vllm_speculative_decoding_post_step()
-logger.info("Successfully patched vllm speculative decoding post_step.")
+if _patch_vllm_speculative_decoding_post_step():
+    logger.info("Successfully patched vllm speculative decoding post_step.")
+else:
+    logger.info(
+        "Skipped vllm speculative decoding post_step patch (already patched or incompatible version)."
+    )
```
Also applies to: 325-326
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- nemo_rl/algorithms/grpo.py
- nemo_rl/algorithms/utils.py
- nemo_rl/models/generation/__init__.py
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/models/generation/vllm/vllm_worker.py
🧰 Additional context used
📓 Path-based instructions (4)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
- nemo_rl/models/generation/__init__.py
- nemo_rl/algorithms/utils.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/algorithms/grpo.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with
@ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
- nemo_rl/models/generation/__init__.py
- nemo_rl/algorithms/utils.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/algorithms/grpo.py
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
- nemo_rl/models/generation/__init__.py
- nemo_rl/algorithms/utils.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/algorithms/grpo.py
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
- nemo_rl/models/generation/__init__.py
- nemo_rl/algorithms/utils.py
- nemo_rl/models/generation/vllm/vllm_worker.py
- nemo_rl/models/generation/vllm/vllm_generation.py
- nemo_rl/algorithms/grpo.py
🧬 Code graph analysis (2)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
get_metrics(384-397)
nemo_rl/algorithms/grpo.py (4)
nemo_rl/algorithms/utils.py (3)
aggregate_spec_decode_counters(775-810)
calculate_baseline_and_std_per_prompt(80-157)
compute_spec_decode_metrics(813-879)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
get_metrics(384-397)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
get_metrics(526-546)
nemo_rl/data/packing/metrics.py (1)
update(52-91)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Lint check
- GitHub Check: Post submodule check comment / Comment on PR
- GitHub Check: Post automodel integration comment / Comment on PR
🔇 Additional comments (6)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
384-398: Looks good.
nemo_rl/algorithms/grpo.py (1)
41-44: Imports look fine.
nemo_rl/models/generation/vllm/vllm_worker.py (2)
456-456: LGTM.
526-546: Nice addition for metrics visibility.
nemo_rl/algorithms/utils.py (2)
18-18: No issues here.
775-879: Spec‑decode aggregation utilities look solid.
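As a rough illustration of what such aggregation/derivation utilities compute (a sketch only; the counter names and metric keys here are illustrative, not the exact ones in `utils.py`):

```python
def compute_spec_decode_metrics_sketch(start: dict, end: dict) -> dict:
    """Illustrative only: derive per-step spec-decode metrics from two counter snapshots."""
    # Per-step deltas of cumulative engine counters.
    num_draft = end.get("num_draft_tokens", 0) - start.get("num_draft_tokens", 0)
    num_accepted = end.get("num_accepted_tokens", 0) - start.get("num_accepted_tokens", 0)
    if num_draft <= 0:
        return {}
    return {
        # Fraction of proposed draft tokens the target model accepted this step.
        "spec_decode/acceptance_rate": num_accepted / num_draft,
        "spec_decode/num_draft_tokens": num_draft,
        "spec_decode/num_accepted_tokens": num_accepted,
    }
```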
✏️ Tip: You can disable this entire section by setting review_details to false in your review settings.
What does this PR do?
Enable speculative decoding support in NeMo-RL using the vLLM backend during post-training (GRPO).
This PR integrates vLLM's speculative decoding capabilities into NeMo-RL, allowing for faster generation during the post-training phase. It includes necessary patches for vLLM to ensure correct metric collection and provides utility functions to track and report speculative decoding performance (e.g., acceptance rates) during training.
Key changes:
- Patches `vllm.v1.engine.core_client` to properly call `post_step`, which is essential for speculative decoding to function correctly in the v1 engine when `VLLM_ENABLE_V1_MULTIPROCESSING=0`. This is fixed upstream in vllm-project/vllm#30319 but not yet in a released version.
- Extends `VllmGenerationWorker` and `VllmGeneration` to collect speculative decoding counters (draft tokens, accepted tokens, etc.) from the underlying vLLM engine.
- Adds utilities in `nemo_rl/algorithms/utils.py` to aggregate these metrics across multiple workers and compute derived metrics like "acceptance rate" and "draft efficiency".
- Sets `load_format="auto"` in `VllmConfig` when `speculative_config` is detected, ensuring the model weights are loaded correctly for speculative execution.

Issues
List issues that this PR closes:
N/A
Usage
To enable speculative decoding, include the `speculative_model` and related parameters in your `vllm_kwargs` configuration:
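As a rough illustration only, a fragment matching the shape this PR's code checks for (a `speculative_config` entry inside `vllm_kwargs`); the inner keys follow vLLM's speculative decoding options and are illustrative, not prescribed by this PR:

```python
# Illustrative only: inner keys mirror vLLM's speculative decoding options and may
# differ between vLLM versions; check the vLLM docs for your installed version.
generation_config = {
    "vllm_kwargs": {
        "speculative_config": {
            "model": "your-org/draft-model",  # hypothetical draft model
            "num_speculative_tokens": 5,      # draft tokens proposed per step
        },
    },
}
```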
Warning
Limitation: When using speculative decoding with vLLM < 0.12.0, generation log probabilities will be returned as 0. This means `use_importance_sampling` cannot be used. This is fixed in vllm-project/vllm#29223 and will be available in vLLM v0.12.0+.
Before your PR is "Ready for review"
Pre checks:
Additional Information
Speculative decoding metrics are reported under the `spec_decode/` prefix if enabled.
Summary by CodeRabbit
New Features
Improvements
✏️ Tip: You can customize this high-level summary in your review settings.