feat: Support lora in dtensor grpo workflow [1/3]: sync and colocated setup #1748
Conversation
ℹ️ File Consistency Check (based on commit 02e55d3, PR #1748)
✅ DTensor Policy Worker Synchronization Check: both DTensor policy worker files were modified in this PR.
Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
The same check was repeated on subsequent commits 4e7aa40, 2a7975d, 6d7e746, 6e306c3, and 41d735d.
📝 Walkthrough

This PR introduces LoRA (Low-Rank Adaptation) support across the GRPO training framework, including configuration blocks, selective weight-refit logic with flags to control base-model versus LoRA weight updates, vLLM and DTensor integration with LoRA adapters, and functional/unit tests for validation.
Sequence Diagram(s)

sequenceDiagram
participant GRPO as GRPO Train
participant Policy as Policy (DTensor)
participant Worker as DTensorWorker
participant vLLM as vLLM Engine
participant IPC as IPC/ZMQ
GRPO->>Policy: refit_policy_generation(refit_base_model_weights=T, refit_lora_weights=F)
alt refit_base_model_weights=True
Policy->>Worker: stream_weights_via_ipc_zmq(refit_base_model_weights=T, refit_lora_weights=F)
Worker->>Worker: Filter weights (base_model only, skip LoRA)
Worker->>IPC: Send base_model weights
IPC->>vLLM: update_weights_via_ipc_zmq(refit_base_model_weights=T, refit_lora_weights=F)
vLLM->>vLLM: Load base_model weights
end
alt refit_lora_weights=True
Policy->>Worker: stream_weights_via_ipc_zmq(refit_base_model_weights=F, refit_lora_weights=T)
Worker->>Worker: Filter weights (LoRA only, skip base_model)
Worker->>IPC: Send LoRA weights + LoRA config
IPC->>vLLM: update_weights_via_ipc_zmq(lora_config, refit_lora_weights=T)
vLLM->>vLLM: Apply LoRA patches & load LoRA weights
end
GRPO->>GRPO: Post-refit: toggle flags (disable base, enable LoRA)
Note over GRPO,vLLM: Next refit cycle uses updated flags
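The flag handoff in the diagram above (a full base-model sync first, then LoRA-only refits) can be sketched as a small pure function. This is an illustration only; the function name `next_refit_flags` and its arguments are assumptions, not the actual API in nemo_rl/algorithms/grpo.py:

```python
# Hypothetical sketch of the post-refit flag toggle shown in the diagram.
# The real control flow lives in nemo_rl/algorithms/grpo.py.

def next_refit_flags(lora_enabled: bool, first_refit: bool) -> tuple[bool, bool]:
    """Return (refit_base_model_weights, refit_lora_weights) for the next cycle.

    On the first cycle the full base model is streamed; once LoRA-only
    updates are possible, base refits are skipped as an optimization.
    """
    if not lora_enabled:
        return True, False  # always a full base-model refit
    if first_refit:
        return True, False  # base weights must be synced once
    return False, True      # subsequent cycles send only LoRA weights


flags = [next_refit_flags(lora_enabled=True, first_refit=(i == 0)) for i in range(3)]
print(flags)  # [(True, False), (False, True), (False, True)]
```

The invariant the assertion in `refit_policy_generation` enforces (at least one flag True) holds for every branch here.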
sequenceDiagram
participant Config as Config (lora_cfg)
participant VLLMWorker as vLLM Worker Init
participant vLLM as vLLM Engine
participant LoRA as LoRA Module
Config->>VLLMWorker: Extract lora_cfg (enabled, dim, alpha, ...)
VLLMWorker->>VLLMWorker: Check if lora_cfg.enabled
alt LoRA Enabled
VLLMWorker->>LoRA: apply_lora_patches()
LoRA->>vLLM: Patch _load_adapter with patched_load_adapter
VLLMWorker->>vLLM: Inject enable_lora, max_loras, max_lora_rank
VLLMWorker->>VLLMWorker: Set self.lora_enabled=True
end
Note over VLLMWorker,vLLM: During generation, if lora_enabled:<br/>Construct LoRARequest and pass to generate()
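A rough illustration of the final note in the diagram: when `lora_enabled` is set, the worker builds a per-call LoRA request before invoking `generate()`. The real code would use `vllm.lora.request.LoRARequest`; the stub dataclass and names below are stand-ins so the sketch runs without vLLM installed:

```python
# Illustrative only: LoRARequestStub mirrors the (lora_name, lora_int_id,
# lora_path) fields of vllm.lora.request.LoRARequest.
from dataclasses import dataclass
from typing import Optional


@dataclass
class LoRARequestStub:
    lora_name: str
    lora_int_id: int
    lora_path: str


def maybe_lora_request(lora_enabled: bool, adapter_dir: str) -> Optional[LoRARequestStub]:
    # A single adapter slot is assumed here; this sketch streams one LoRA
    # state per refit cycle.
    if not lora_enabled:
        return None
    return LoRARequestStub(lora_name="grpo_lora", lora_int_id=1, lora_path=adapter_dir)


req = maybe_lora_request(True, "/tmp/adapter")
print(req)
```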
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs
Suggested labels
Suggested reviewers
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
nemo_rl/models/generation/vllm/vllm_worker_async.py (1)
989-1021: Check all worker results from collective_rpc, not just the first one.

The code calls collective_rpc to update weights across multiple internal vLLM workers (relevant when tensor_parallel_size > 1), but only checks worker_results[0]. The report_device_id method proves collective_rpc returns a list with results from all workers, not just the first. If any worker at index 1+ fails to update weights, the failure goes undetected.

Proposed fix

```diff
-        worker_result = worker_results[0]
-
-        if not worker_result:
-            print(
-                f"Error: Worker failed to update weights. Result: {worker_result}"
-            )
-            return False
-        return True
+        # Be robust to multi-worker returns from collective_rpc.
+        if isinstance(worker_results, list):
+            ok = all(bool(r) for r in worker_results)
+        else:
+            ok = bool(worker_results)
+
+        if not ok:
+            print(f"Error: One or more workers failed to update weights: {worker_results}")
+            return False
+        return True
```

nemo_rl/algorithms/grpo.py (1)
945-1018: Review nested refit logic and timer context.

The refactored refit flow introduces several layers of nesting and conditional logic:

- _perform_refit_weights is defined as a nested function but only called within the same function
- Lines 995-1018 have complex conditional logic with timer contexts
- The refit sequence (base first at 1007-1010, then LoRA at 1011-1017) is implicit

Concerns:

- If refit_base_model_weights=True but the refit fails, update_success is False, yet refit_lora_weights would still be attempted
- The update_success from line 1006 is overwritten by the LoRA refit at lines 1012-1017
- The timer context wraps both base and LoRA refits, making it hard to distinguish their individual timing

🔧 Suggested improvements for clarity and correctness

```diff
 with timer_context:
-    update_success = False
+    base_update_success = True  # Default to True if not refitting base
     if refit_base_model_weights:
-        update_success = _perform_refit_weights(
+        base_update_success = _perform_refit_weights(
             refit_base_model_weights=True, refit_lora_weights=False
         )
+
+    lora_update_success = True  # Default to True if not refitting LoRA
     if refit_lora_weights:
-        update_success = (
+        lora_update_success = (
             _perform_refit_weights(
                 refit_base_model_weights=False, refit_lora_weights=True
             )
-            and update_success
         )
+
+    # Both refits must succeed
+    update_success = base_update_success and lora_update_success
+    if not update_success:
+        raise RuntimeError(
+            "Weight refit failed. "
+            f"Base model: {'✓' if base_update_success else '✗'}, "
+            f"LoRA: {'✓' if lora_update_success else '✗'}"
+        )
```

This makes the success tracking clearer and provides better error messages.
🤖 Fix all issues with AI agents
In @nemo_rl/algorithms/grpo.py:
- Around lines 1093-1102: The refit logic after refit_policy_generation currently flips REFIT_BASE_MODEL_WEIGHTS to False whenever REFIT_LORA_WEIGHTS is True, without explanation or safeguards. Add a concise comment explaining that this is an optimization to skip base-model refits once LoRA-only updates are enabled. Also guard the edge cases: only disable the base refit when REFIT_LORA_WEIGHTS is True and there is no pending base-model structural change or checkpoint load (introduce or check a guard flag such as BASE_MODEL_CHANGED or FORCE_REFIT_BASE, plus the existing POLICY_GENERATION_STALE if relevant). Apply the change where REFIT_BASE_MODEL_WEIGHTS is set and referenced alongside refit_policy_generation, policy_generation, and colocated_inference, so future maintainers understand the intent and can override it when simultaneous refits are required.
- Around line 1060-1061: REFIT_LORA_WEIGHTS is set by directly accessing
policy.lora_enabled which can raise AttributeError for policies that don't
define that attribute (e.g., DTensorPolicyWorker, MegatronPolicyWorker); update
the assignment of REFIT_LORA_WEIGHTS to read lora_enabled from policy using
getattr with a default of False so missing attributes are handled defensively.
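A minimal sketch of the getattr fix suggested above, with hypothetical policy classes standing in for the real workers:

```python
# Policies that never define lora_enabled (e.g. a Megatron-style worker)
# should yield False rather than raise AttributeError.
class DTensorPolicyWithLora:
    lora_enabled = True


class PolicyWithoutLoraAttr:  # no lora_enabled attribute at all
    pass


def read_refit_lora_flag(policy) -> bool:
    # getattr with a default handles the missing attribute defensively.
    return getattr(policy, "lora_enabled", False)


print(read_refit_lora_flag(DTensorPolicyWithLora()))  # True
print(read_refit_lora_flag(PolicyWithoutLoraAttr()))  # False
```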
In @nemo_rl/models/generation/lora.py:
- Around line 1-13: The new module nemo_rl/models/generation/lora.py is missing
from the pyrefly.toml whitelist and must be added to the project-includes so CI
stops failing; open pyrefly.toml and add the relative path
"nemo_rl/models/generation/lora.py" (or the appropriate glob such as
"nemo_rl/models/generation/*.py") to the project-includes array, save the file,
and re-run the pipeline to verify the whitelist change fixes the failure.
In @nemo_rl/models/generation/vllm/vllm_backend.py:
- Around line 167-172: The function update_weights_via_ipc_zmq uses a mutable
default for lora_config ({}), which can lead to shared-state bugs; change the
signature to use lora_config: Optional[dict[str, Any]] = None and inside the
function set lora_config = {} if lora_config is None, and update any subsequent
checks that assume a dict (e.g., the conditional that currently checks
truthiness of lora_config) to correctly handle None/empty dict cases so behavior
is unchanged.
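The mutable-default hazard described above is the standard Python pitfall: the default dict is created once at function definition time and shared across calls. A compact demonstration (function names are illustrative):

```python
# `def f(cfg={})` shares one dict across every call, so state leaks between
# invocations; the None-sentinel pattern creates a fresh dict each time.
def bad(cfg={}):
    cfg.setdefault("calls", 0)
    cfg["calls"] += 1
    return cfg["calls"]


def good(cfg=None):
    if cfg is None:
        cfg = {}
    cfg.setdefault("calls", 0)
    cfg["calls"] += 1
    return cfg["calls"]


r_bad = (bad(), bad())
r_good = (good(), good())
print(r_bad)   # (1, 2)  <- state leaked across calls
print(r_good)  # (1, 1)  <- fresh dict each call
```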
In @nemo_rl/models/generation/vllm/vllm_generation.py:
- Around line 924-934: The return type annotation of get_model_state_dict() is
incorrect: update the signature and docstring to return list[dict[str, Any]]
(matching ray.get(futures) which returns results from all workers) in
nemo_rl/models/generation/vllm/vllm_generation.py, keeping the implementation
(worker_group.run_all_workers_single_data and ray.get) unchanged; optionally, if
you only need rank-0 results, call run_all_workers_single_data with
run_rank_0_only_axes=["tensor_parallel", "pipeline_parallel"] or use the
appropriate worker_group method to request rank-0 only to reduce object-store
overhead.
In @tests/functional/grpo_automodel_lora.sh:
- Line 4: The trap cleanup uses rm -rf /tmp/lora_sft_checkpoints but the script
actually configures the checkpoint directory as /tmp/lora_grpo_checkpoints (the
configured checkpoint variable on line ~38); fix by making the trap target match
the configured checkpoint directory—i.e., update the trap command string "rm -rf
/tmp/lora_sft_checkpoints" to "rm -rf /tmp/lora_grpo_checkpoints" (or
alternatively change the configured checkpoint path to
/tmp/lora_sft_checkpoints) so the exit cleanup removes the correct directory.
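One way to keep the trap and the configured path from drifting apart again, as the fix above suggests, is to derive both from a single variable. A sketch (the path mirrors the script's configured directory; everything else is illustrative):

```shell
# Sketch only: derive the trap target and the configured checkpoint dir from
# one variable so the exit cleanup always removes the directory actually used.
CKPT_DIR=/tmp/lora_grpo_checkpoints
trap "rm -rf ${CKPT_DIR}" EXIT

mkdir -p "${CKPT_DIR}"
test -d "${CKPT_DIR}" && echo "checkpoint dir ready"
```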
🧹 Nitpick comments (11)
nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)
1683-1689: Fix comparison style for consistency with Python conventions.

The assertion uses refit_lora_weights == False, which should be replaced with not refit_lora_weights for more Pythonic code.

♻️ Proposed fix

```diff
-        assert refit_base_model_weights and refit_lora_weights == False, (
+        assert refit_base_model_weights and not refit_lora_weights, (
             f"dtensor v1 not support lora. refit_lora_weights must be False, but got refit_lora_weights={refit_lora_weights} and refit_base_model_weights={refit_base_model_weights}"
         )
```

nemo_rl/models/generation/interfaces.py (1)
248-252: Document new parameters in the docstring.

The signature now includes refit_base_model_weights and refit_lora_weights parameters to support selective weight refitting, but the docstring doesn't describe their purpose or usage.

📝 Suggested docstring enhancement

```diff
     def update_weights_via_ipc_zmq(
         self, refit_base_model_weights: bool = True, refit_lora_weights: bool = False
     ) -> list[ray.ObjectRef]:
-        """Update the model weights from the given IPC handles."""
+        """Update the model weights from the given IPC handles.
+
+        Args:
+            refit_base_model_weights: Whether to update base model weights (default: True)
+            refit_lora_weights: Whether to update LoRA adapter weights (default: False)
+
+        Returns:
+            List of Ray object references for async completion tracking
+        """
         raise NotImplementedError
```

examples/configs/grpo_math_1B.yaml (1)
84-105: LoRA config block looks reasonable; please ensure schema/TypedDict coverage exists for these keys.

Since this introduces new public config keys under policy.dtensor_cfg.lora_cfg, make sure the corresponding config schema (TypedDict / validation) documents each key and treats optional keys as optional (with defaults living in YAML), per repo guidelines.

examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml (1)
1-27: Verify defaults syntax matches the repo's config loader expectations.

Many Hydra-style configs require defaults to be a list (often with _self_). If this repo expects list form, this scalar may silently fail or be treated as a plain key.

tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh (1)
31-41: Harden the "target step reached" jq gate against missing/empty metrics.

If train/loss is absent (or the file is incomplete), max can error. Consider making the jq query return -1 on empty/missing and only then run checks.

Proposed fix

```diff
-if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
+if [[ $(jq -r '
+  (to_entries
+   | map(select(.key == "train/loss"))
+   | .[0].value? // {}
+   | keys
+   | map(tonumber)
+   | (if length > 0 then max else -1 end)
+  )' "$JSON_METRICS") -ge $MAX_STEPS ]]; then
     uv run tests/check_metrics.py $JSON_METRICS \
         'mean(data["train/gen_kl_error"]) < 0.001' \
         'data["train/gen_kl_error"]["20"] < 0.001' \
         'mean(data["train/reward"]) > 0.56' \
         'mean(data["timing/train/total_step_time"], 2) < 50'
 fi
```

nemo_rl/models/generation/lora.py (2)
103-105: Use bare raise instead of raise e.

Per Python best practices, re-raising the same exception should use bare raise to preserve the original traceback.

♻️ Proposed fix

```diff
         except Exception as e:
             # For BadRequestError
-            raise e
+            raise
```
123-131: Resolve the TODO about file location.

The comment indicates uncertainty about whether this helper should live here or in nemo_rl/models/generation/vllm/utils.py. Consider moving it to the utils module since it's specifically for vLLM metadata, or remove the comment if the current location is intentional.

nemo_rl/models/generation/vllm/vllm_worker.py (1)
816-843: Remove unused variables a_shapes and b_shapes.

These variables are computed but never used. If they're intended for debugging or future use, consider removing them to clean up the code.

♻️ Proposed fix

```diff
         if isinstance(module, BaseLinearLayerWithLoRA):
-            a_shapes = [tuple(t.shape) for t in module.lora_a_stacked]
-            b_shapes = [tuple(t.shape) for t in module.lora_b_stacked]
             a_weights = [t for t in module.lora_a_stacked]
             b_weights = [t for t in module.lora_b_stacked]
             details.append(
```

nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
1764-1776: Clarify docstring to match actual behavior.

The docstring says "Only yields LoRA weights when LoRA is enabled," but the actual behavior is controlled by the refit_base_model_weights and refit_lora_weights flags, not LoRA enablement status. Consider updating the docstring to accurately describe the filtering behavior.

♻️ Suggested docstring update

```diff
     def dtensor_params_generator():
-        """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.
-
-        Only yields LoRA weights when LoRA is enabled, otherwise yields all weights.
-        """
+        """Generator that yields (name, tensor) pairs, converting DTensors to local tensors.
+
+        Filters weights based on refit_base_model_weights and refit_lora_weights flags.
+        """
```

tests/unit/models/generation/test_vllm_generation.py (1)
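The flag-driven filtering that the dtensor_params_generator docstring above should describe can be modeled as a small standalone generator. The `"lora_"` substring predicate is an assumption for this sketch; the worker's real predicate may differ:

```python
# Toy model of flag-based weight filtering: yield only the parameters
# selected by the two refit flags.
def filtered_params(named_params, refit_base_model_weights, refit_lora_weights):
    for name, tensor in named_params:
        is_lora = "lora_" in name  # assumed naming convention for this sketch
        if is_lora and refit_lora_weights:
            yield name, tensor
        elif not is_lora and refit_base_model_weights:
            yield name, tensor


params = [("layers.0.weight", 1), ("layers.0.lora_A", 2), ("layers.0.lora_B", 3)]
print([n for n, _ in filtered_params(params, True, False)])  # ['layers.0.weight']
print([n for n, _ in filtered_params(params, False, True)])  # ['layers.0.lora_A', 'layers.0.lora_B']
```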
964-970: Recommend adding LoRA test cases for non-colocated mode.

The non-colocated test parametrization doesn't include any enable_lora=True cases, while the colocated test does. According to the PR objectives, LoRA support is specifically for "synchronous colocated inference on the DTensor backend." However, it would be valuable to verify that non-colocated mode properly rejects or handles LoRA configuration.

Consider adding a test case to verify the LoRA restriction in non-colocated mode

```diff
 @pytest.mark.parametrize(
-    ("async_engine", "cpu_offload", "vllm_precision", "enable_lora"),
+    ("async_engine", "cpu_offload", "vllm_precision", "enable_lora", "expect_error"),
     [
-        (True, False, "bfloat16", False),
-        (False, True, "bfloat16", False),
-        (True, False, "fp8", False),
-        (False, True, "fp8", False),
+        (True, False, "bfloat16", False, False),
+        (False, True, "bfloat16", False, False),
+        (True, False, "fp8", False, False),
+        (False, True, "fp8", False, False),
+        # Test that LoRA is properly rejected in non-colocated mode
+        (False, False, "bfloat16", True, True),
     ],
 )
```

Then update the test body to check for the expected assertion based on line 469 in grpo.py.
nemo_rl/algorithms/grpo.py (1)
920-943: Review refit_policy_generation parameter logic.

The assertion at lines 941-943 ensures at least one of refit_base_model_weights or refit_lora_weights is True. However, the parameter defaults are:

- refit_base_model_weights: Optional[bool] = True
- refit_lora_weights: Optional[bool] = False

Since both have defaults and the assertion checks they can't both be False, the Optional type hint is misleading: these are effectively required booleans with default values.

Consider simplifying type hints for clarity

```diff
 def refit_policy_generation(
     policy: ColocatablePolicyInterface,
     policy_generation: GenerationInterface,
     colocated_inference: bool,
     _refit_buffer_size_gb: Optional[int] = None,
     timer: Optional[Timer] = None,
     kv_scales: Optional[dict[str, float]] = None,
-    refit_base_model_weights: Optional[bool] = True,
-    refit_lora_weights: Optional[bool] = False,
+    refit_base_model_weights: bool = True,
+    refit_lora_weights: bool = False,
 ) -> None:
```

This better reflects that these are boolean flags with defaults, not optional parameters that can be None.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (19)
examples/configs/grpo_math_1B.yaml
examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
nemo_rl/algorithms/grpo.py
nemo_rl/models/generation/interfaces.py
nemo_rl/models/generation/lora.py
nemo_rl/models/generation/vllm/vllm_backend.py
nemo_rl/models/generation/vllm/vllm_generation.py
nemo_rl/models/generation/vllm/vllm_worker.py
nemo_rl/models/generation/vllm/vllm_worker_async.py
nemo_rl/models/policy/lm_policy.py
nemo_rl/models/policy/workers/dtensor_policy_worker.py
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
nemo_rl/models/policy/workers/megatron_policy_worker.py
tests/functional/L1_Functional_Tests_GPU.sh
tests/functional/grpo_automodel_lora.sh
tests/functional/sft_automodel_lora.sh
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
tests/test_suites/nightly.txt
tests/unit/models/generation/test_vllm_generation.py
🧰 Additional context used
📓 Path-based instructions (9)
!(**/tests/**|**/test_*.py|**/test_*.sh)
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Add the NVIDIA copyright header to all Python files and shell scripts (excluding tests). The header should include the current year
Files:
examples/configs/grpo_math_1B.yaml
tests/functional/L1_Functional_Tests_GPU.sh
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
tests/functional/grpo_automodel_lora.sh
tests/test_suites/nightly.txt
nemo_rl/models/generation/vllm/vllm_worker_async.py
examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
nemo_rl/models/generation/interfaces.py
nemo_rl/models/generation/lora.py
nemo_rl/models/policy/workers/dtensor_policy_worker.py
nemo_rl/models/generation/vllm/vllm_backend.py
nemo_rl/models/generation/vllm/vllm_generation.py
nemo_rl/algorithms/grpo.py
nemo_rl/models/policy/lm_policy.py
nemo_rl/models/generation/vllm/vllm_worker.py
tests/unit/models/generation/test_vllm_generation.py
nemo_rl/models/policy/workers/megatron_policy_worker.py
**/*.sh
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.sh: Use uv run instead of python to execute scripts
Follow the Google Shell Style Guide for shell scripts
Files:
tests/functional/L1_Functional_Tests_GPU.sh
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
tests/functional/grpo_automodel_lora.sh
**/*.{py,sh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
The NVIDIA copyright header should appear at the top of all Python files and shell scripts (excluding tests)
Files:
tests/functional/L1_Functional_Tests_GPU.sh
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
tests/functional/grpo_automodel_lora.sh
nemo_rl/models/generation/vllm/vllm_worker_async.py
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
nemo_rl/models/generation/interfaces.py
nemo_rl/models/generation/lora.py
nemo_rl/models/policy/workers/dtensor_policy_worker.py
nemo_rl/models/generation/vllm/vllm_backend.py
nemo_rl/models/generation/vllm/vllm_generation.py
nemo_rl/algorithms/grpo.py
nemo_rl/models/policy/lm_policy.py
nemo_rl/models/generation/vllm/vllm_worker.py
tests/unit/models/generation/test_vllm_generation.py
nemo_rl/models/policy/workers/megatron_policy_worker.py
tests/test_suites/**/*.sh
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
tests/test_suites/**/*.sh: When adding support for a new model, create a corresponding driver shell script under tests/test_suites/ in the matching domain
Driver shell scripts should match the YAML base name with .sh extension and invoke training entrypoint with uv run
Files:
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
tests/test_suites/nightly.txt
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
When adding a nightly test for a new model, append the driver script path (relative to tests/test_suites/) to tests/test_suites/nightly.txt
Files:
tests/test_suites/nightly.txt
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Conform code to Python 3.12+
Indent code with 4 spaces. Do not use tabs
Use snake_case for file names
Use PascalCase for class names
Use snake_case for function and method names
Use snake_case for local variables
Prefix variable names that start with a number with 'k' (e.g., k_99th_percentile)
Use upper snake_case with 'G' prefix for global variables (e.g., G_MY_GLOBAL)
Use upper snake_case for constants
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
Prefer docstrings over comments for interfaces that may be used outside a file
Reserve comments for code within a function or interfaces that are local to a file
If a piece of code is commented out, include a comment describing its usage and why it's commented out. Remove debug comments before merging
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx
Avoid using reflection when functionality can be easily achieved without reflection
When using try-except blocks, limit the except clause to the smallest set of specific errors possible
When using try-except blocks for duck-typing, keep the body of the try as small as possible and use the else block for logic
YAML is the single source of truth for configuration defaults. Do not set non-None defaults in code for configuration values
For required configuration attributes, access config directly and expect presence (e.g., policy_cfg['precision']) without hidden defaults
Use typing.NotRequired to mark optional attributes in TypedDict for configuration
When adding a new config key to a TypedDict subclass, document the key's purpose, valid values/types, and recommended default, and reflect the default in exemplar YAMLs under examples/configs/*.yaml
Follow the Google Python Style Guide for Python code
Files:
nemo_rl/models/generation/vllm/vllm_worker_async.py
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
nemo_rl/models/generation/interfaces.py
nemo_rl/models/generation/lora.py
nemo_rl/models/policy/workers/dtensor_policy_worker.py
nemo_rl/models/generation/vllm/vllm_backend.py
nemo_rl/models/generation/vllm/vllm_generation.py
nemo_rl/algorithms/grpo.py
nemo_rl/models/policy/lm_policy.py
nemo_rl/models/generation/vllm/vllm_worker.py
tests/unit/models/generation/test_vllm_generation.py
nemo_rl/models/policy/workers/megatron_policy_worker.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
For any source file under nemo_rl/*.py that defines a class or function decorated with @ray.remote, add a coverage pragma (# pragma: no cover) because these run in separate Ray processes
Files:
nemo_rl/models/generation/vllm/vllm_worker_async.py
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
nemo_rl/models/generation/interfaces.py
nemo_rl/models/generation/lora.py
nemo_rl/models/policy/workers/dtensor_policy_worker.py
nemo_rl/models/generation/vllm/vllm_backend.py
nemo_rl/models/generation/vllm/vllm_generation.py
nemo_rl/algorithms/grpo.py
nemo_rl/models/policy/lm_policy.py
nemo_rl/models/generation/vllm/vllm_worker.py
nemo_rl/models/policy/workers/megatron_policy_worker.py
examples/configs/recipes/**/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
When adding support for a new model, create a recipe YAML under examples/configs/recipes/ in the appropriate domain subdirectory (llm, vlm, etc.)
Files:
examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
examples/configs/recipes/llm/*.yaml
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Recipe YAML files should follow the naming pattern: --ng-[-modifiers][-long][.vN].yaml for LLM recipes
Files:
examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
🧠 Learnings (11)
📚 Learning: 2025-10-30T20:50:44.126Z
Learnt from: adil-a
Repo: NVIDIA-NeMo/RL PR: 1440
File: examples/configs/sft_automodel.yaml:48-58
Timestamp: 2025-10-30T20:50:44.126Z
Learning: In DTensor configurations for MoE (Mixture of Experts) models, expert_parallel_size and data_parallel_size can be applied together without multiplying the GPU requirements. Expert Parallelism (EP) only applies to MoE layers, while Data Parallelism/FSDP applies to non-MoE layers. Therefore, configurations like expert_parallel_size: 8 and data_parallel_size: 8 are valid on an 8-GPU cluster for MoE models.
Applied to files:
examples/configs/grpo_math_1B.yaml
📚 Learning: 2025-10-12T14:46:57.171Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:6-11
Timestamp: 2025-10-12T14:46:57.171Z
Learning: Test scripts in tests/test_suites/llm/ follow a standard configuration pattern that includes NUM_NODES, STEPS_PER_RUN, MAX_STEPS, NUM_RUNS (calculated as `$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))`), and NUM_MINUTES. These variables are part of the test infrastructure's standard interface and should not be flagged as unused even if not directly referenced within the individual script, as they are consumed by external launch tooling or common.env.
Applied to files:
tests/functional/L1_Functional_Tests_GPU.sh
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to tests/test_suites/**/*.sh : Driver shell scripts should match the YAML base name with .sh extension and invoke training entrypoint with uv run
Applied to files:
tests/functional/L1_Functional_Tests_GPU.sh
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
tests/functional/grpo_automodel_lora.sh
📚 Learning: 2025-10-12T14:46:55.513Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1324
File: tests/test_suites/llm/distillation-qwen3-32b-to-1.7b-base-1n8g-megatron-tp2pp2cp2-pack.sh:16-30
Timestamp: 2025-10-12T14:46:55.513Z
Learning: In the NVIDIA-NeMo/RL repository, test scripts under tests/ follow a consistent pattern: use `cd $PROJECT_ROOT` without quotes or error handling, and pass arguments with `$@` unquoted. Maintain this consistency when adding new test scripts.
Applied to files:
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
tests/functional/grpo_automodel_lora.sh
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to tests/test_suites/**/*.sh : When adding support for a new model, create a corresponding driver shell script under tests/test_suites/ in the matching domain
Applied to files:
tests/functional/grpo_automodel_lora.sh
📚 Learning: 2025-09-19T07:28:29.887Z
Learnt from: shuo-nvidia
Repo: NVIDIA-NeMo/RL PR: 1006
File: tests/test_suites/llm/distillation-qwen3-32b-to-4b-base-2n8g-fsdp2tp2-long.v1.sh:1-4
Timestamp: 2025-09-19T07:28:29.887Z
Learning: The NVIDIA-NeMo/RL project prefers to maintain consistent formatting across test scripts rather than applying individual bash hardening improvements like `set -euo pipefail` or proper quoting for sourcing files.
Applied to files:
tests/functional/grpo_automodel_lora.sh
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to tests/test_suites/nightly.txt : When adding a nightly test for a new model, append the driver script path (relative to tests/test_suites/) to tests/test_suites/nightly.txt
Applied to files:
tests/test_suites/nightly.txt
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to examples/configs/recipes/llm/*.yaml : Recipe YAML files should follow the naming pattern: <algo>-<model>-<nodes>n<gpus>g-<strategy-and-params>[-modifiers][-long][.vN].yaml for LLM recipes
Applied to files:
examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to examples/configs/recipes/**/*.yaml : When adding support for a new model, create a recipe YAML under examples/configs/recipes/ in the appropriate domain subdirectory (llm, vlm, etc.)
Applied to files:
examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
📚 Learning: 2025-11-24T17:24:41.976Z
Learnt from: CR
Repo: NVIDIA-NeMo/RL PR: 0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-11-24T17:24:41.976Z
Learning: Applies to examples/configs/recipes/vlm/*.yaml : Recipe YAML files should follow the naming pattern: vlm_<algo>-<model>-<nodes>n<gpus>g-<strategy>[-modifiers][.vN].yaml for VLM recipes
Applied to files:
examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
📚 Learning: 2025-09-18T14:20:36.297Z
Learnt from: zpqiu
Repo: NVIDIA-NeMo/RL PR: 1006
File: examples/configs/recipes/llm/distillation-qwen3-32b-to-8b-base-2n8g-fsdp2tp2.v1.yaml:113-120
Timestamp: 2025-09-18T14:20:36.297Z
Learning: In distillation workflows, the teacher policy does not perform generation - it only does inference/logprob computation on sequences generated by the student policy. Therefore, teacher generation configuration mismatches (like vLLM tensor parallelism settings) and colocation concerns are not relevant.
Applied to files:
nemo_rl/algorithms/grpo.py
tests/unit/models/generation/test_vllm_generation.py
🧬 Code graph analysis (5)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
nemo_rl/utils/packed_tensor.py (1)
packed_broadcast_producer(39-95)
nemo_rl/models/generation/interfaces.py (3)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
update_weights_via_ipc_zmq (167-296)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
update_weights_via_ipc_zmq (752-783)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
update_weights_via_ipc_zmq (770-793)
nemo_rl/models/generation/vllm/vllm_backend.py (4)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
update_weights_via_ipc_zmq (752-783)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
update_weights_via_ipc_zmq (770-793)
nemo_rl/models/generation/interfaces.py (1)
update_weights_via_ipc_zmq (248-252)
nemo_rl/models/generation/lora.py (2)
LoRARequestWithCfgAndWeights (23-25)
get_vllm_lora_metadata (124-132)
nemo_rl/algorithms/grpo.py (7)
nemo_rl/models/generation/vllm/vllm_backend.py (1)
update_weights_via_ipc_zmq (167-296)
nemo_rl/models/generation/vllm/vllm_worker.py (1)
update_weights_via_ipc_zmq (752-783)
nemo_rl/models/generation/vllm/vllm_generation.py (1)
update_weights_via_ipc_zmq (770-793)
nemo_rl/models/policy/workers/dtensor_policy_worker.py (1)
offload_before_refit (1807-1814)
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (1)
offload_before_refit (1893-1900)
nemo_rl/models/policy/interfaces.py (1)
offload_before_refit (168-169)
nemo_rl/utils/timer.py (1)
time (110-123)
nemo_rl/models/policy/lm_policy.py (1)
nemo_rl/distributed/worker_groups.py (1)
run_all_workers_single_data(755-799)
🪛 GitHub Actions: CICD NeMo RL
nemo_rl/models/generation/lora.py
[error] 1-1: File nemo_rl/models/generation/lora.py has zero errors but is not in pyrefly.toml in the 'project-includes' list. Please add it to this whitelist.
🪛 Ruff (0.14.10)
nemo_rl/models/generation/lora.py
99-102: Avoid specifying long messages outside the exception class
(TRY003)
105-105: Use raise without specifying exception name
Remove exception name
(TRY201)
108-111: Avoid specifying long messages outside the exception class
(TRY003)
120-120: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
nemo_rl/models/policy/workers/dtensor_policy_worker.py
1687-1687: Avoid equality comparisons to False; use not refit_lora_weights: for false checks
Replace with not refit_lora_weights
(E712)
nemo_rl/models/generation/vllm/vllm_backend.py
169-169: Do not use mutable data structures for argument defaults
Replace with None; initialize within function
(B006)
269-271: Abstract raise to an inner function
(TRY301)
269-271: Avoid specifying long messages outside the exception class
(TRY003)
nemo_rl/models/generation/vllm/vllm_worker.py
829-829: Local variable a_shapes is assigned to but never used
Remove assignment to unused variable a_shapes
(F841)
830-830: Local variable b_shapes is assigned to but never used
Remove assignment to unused variable b_shapes
(F841)
nemo_rl/models/policy/workers/megatron_policy_worker.py
2192-2192: Unused method argument: refit_base_model_weights
(ARG002)
2193-2193: Unused method argument: refit_lora_weights
(ARG002)
🪛 Shellcheck (0.11.0)
tests/test_suites/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.sh
[warning] 6-6: NUM_NODES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 9-9: NUM_RUNS appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 10-10: NUM_MINUTES appears unused. Verify use (or export if used externally).
(SC2034)
[warning] 16-16: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.
(SC2164)
[error] 28-28: Double quote array expansions to avoid re-splitting elements.
(SC2068)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Docs_Tests
🔇 Additional comments (34)
tests/test_suites/nightly.txt (1)
62-64: LGTM! Nightly test entry added correctly. The new LoRA GRPO test entry follows the established pattern and is correctly placed in the GRPO section with proper formatting.
tests/functional/L1_Functional_Tests_GPU.sh (1)
30-30: LGTM! Functional test invocation follows established patterns. The new LoRA GRPO test invocation correctly uses uv run --no-sync as per coding guidelines and is properly sequenced with other functional tests.
nemo_rl/models/policy/workers/megatron_policy_worker.py (1)
2188-2194: Parameters added for interface consistency (currently unused). The refit_base_model_weights and refit_lora_weights parameters are added to match the updated interface but aren't utilized in the Megatron implementation yet. This is acceptable for maintaining consistency across policy worker implementations, though the parameters could be documented or removed if not planned for near-term use.
nemo_rl/models/generation/vllm/vllm_generation.py (1)
770-794: Forwarding refit flags through worker_group is good; double-check default semantics at call sites. Defaults preserve prior behavior (refit_base_model_weights=True, refit_lora_weights=False), but LoRA training paths likely need the opposite. Ensure upstream refit logic always passes the intended combination (especially when LoRA is enabled).
tests/functional/grpo_automodel_lora.sh (1)
24-45: LGTM! The test script correctly uses uv run for execution, includes appropriate LoRA configuration parameters, and follows the project's established testing patterns with metrics verification.
nemo_rl/models/policy/lm_policy.py (3)
89-91: LGTM! Clean initialization of the lora_enabled attribute with a sensible default, ensuring the attribute is always defined regardless of which backend path is taken.
115-129: LGTM! The LoRA configuration extraction and V1 worker assertion are well-implemented. The assertion provides a clear error message for unsupported configurations.
765-779: LGTM! The extended method signature with refit_base_model_weights and refit_lora_weights flags properly propagates to the worker group call, enabling selective weight updates.
nemo_rl/models/generation/lora.py (2)
23-25: LGTM! Clean extension of LoRARequest with optional config and weights fields for in-memory LoRA loading.
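The extension described in this comment can be sketched roughly as follows. The base class here is a stand-in, not vLLM's actual LoRARequest, and the field names are assumptions for illustration:

```python
from dataclasses import dataclass
from typing import Any, Optional

# Stand-in for the upstream request type (NOT vLLM's real LoRARequest).
@dataclass
class LoRARequestBase:
    lora_name: str
    lora_int_id: int
    lora_path: str = ""

# Extended request carrying optional in-memory config and weights, so an
# adapter can be loaded without touching disk.
@dataclass
class LoRARequestWithCfgAndWeights(LoRARequestBase):
    lora_cfg: Optional[dict[str, Any]] = None
    lora_weights: Optional[dict[str, Any]] = None
```

When the optional fields are left as None, the request behaves like a plain path-based request, which keeps the extension backward compatible.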
115-120: LGTM! The monkey-patching approach using setattr is appropriate here for patching vLLM's internal LoRA manager. While static analysis flags this pattern, it's the correct approach for runtime patching of third-party library internals.
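A minimal sketch of the runtime-patching pattern referred to here: replacing a method on a third-party class at import time. The class below is a stand-in, not vLLM's actual LoRA manager:

```python
# Stand-in third-party class (NOT vLLM's real LoRA manager).
class ThirdPartyLoRAManager:
    def load_adapter(self, path: str) -> str:
        return f"loaded from disk: {path}"

def _load_adapter_from_memory(self, path: str) -> str:
    # Patched behavior: serve the adapter from memory instead of disk I/O.
    return "loaded from memory"

# Ruff's B010 note: setattr(cls, "load_adapter", fn) and this plain attribute
# assignment are equivalent; the linter simply prefers the assignment form.
ThirdPartyLoRAManager.load_adapter = _load_adapter_from_memory
```

Every instance created after the patch picks up the replacement, which is why this pattern works for patching library internals without forking the library.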
142-142: LGTM! Proper extraction of LoRA configuration from vLLM config during worker initialization.
400-409: LGTM! Clean integration of LoRA patching and vLLM kwargs configuration. The max_loras=1 constraint is appropriately documented.
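As a sketch, the LoRA-related engine kwargs might be assembled like this before constructing the vLLM engine. enable_lora, max_loras, and max_lora_rank are real vLLM engine arguments; the cfg dict shape and the helper itself are assumptions, not the PR's actual code:

```python
def build_lora_vllm_kwargs(lora_cfg: dict) -> dict:
    """Assemble LoRA engine kwargs from a policy-side lora_cfg dict (illustrative)."""
    if not lora_cfg.get("enabled", False):
        return {}
    return {
        "enable_lora": True,
        # Only one adapter is live at a time during refit, hence max_loras=1.
        "max_loras": 1,
        "max_lora_rank": lora_cfg.get("dim", 8),
    }
```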
582-593: LGTM! The LoRA request construction for generation is properly gated by self.lora_enabled and uses the metadata helper correctly.
752-768: LGTM! The updated signature properly includes the refit flags and passes self.lora_cfg along with the flags to the collective RPC call.
845-855: LGTM! Simple and useful introspection helper for retrieving the model state dict via collective RPC.
nemo_rl/models/generation/vllm/vllm_backend.py (3)
128-154: LGTM! The map_param_name function correctly resolves packed LoRA modules to their canonical names and properly handles the base_layer path transformation for LoRA-enabled models.
156-164: LGTM! Clean helper function that applies the name mapping transformation to weight tuples.
229-271: LGTM on the logic flow. The branching logic for refit_base_model_weights vs refit_lora_weights is correct:
- Base model weights get name mapping when LoRA is enabled before loading
- LoRA weights are loaded via the LoRARequest mechanism with proper remove/add cycle
- The ValueError for both flags being False is appropriate validation
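The branching above can be reduced to a small sketch; the function and action names are stand-ins for the actual backend calls, not the PR's implementation:

```python
def plan_refit(refit_base_model_weights: bool, refit_lora_weights: bool,
               lora_enabled: bool = False) -> list[str]:
    """Illustrative reduction of the refit branching described above."""
    if not (refit_base_model_weights or refit_lora_weights):
        raise ValueError("at least one of base/LoRA weight refit must be requested")
    actions = []
    if refit_base_model_weights:
        # Base weights need name mapping when LoRA wraps the target modules.
        actions.append("load_base_mapped" if lora_enabled else "load_base")
    if refit_lora_weights:
        # Swap the live adapter: remove the stale one, then add the new weights.
        actions += ["remove_lora", "add_lora"]
    return actions
```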
nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py (3)
1706-1716: LGTM! Clean helper functions for identifying LoRA vs base model weights by suffix patterns. The logic is straightforward and matches the expected LoRA weight naming conventions.
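A hedged sketch of suffix-based classification like the helpers described above. The exact suffixes in the worker may differ; lora_A/lora_B are the PEFT-style adapter matrix names:

```python
# Assumed adapter-weight suffixes (PEFT convention); the worker's actual
# suffix list may differ.
LORA_SUFFIXES = ("lora_A.weight", "lora_B.weight")

def is_lora_weight(name: str) -> bool:
    return name.endswith(LORA_SUFFIXES)

def is_base_model_weight(name: str) -> bool:
    return not is_lora_weight(name)
```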
1743-1796: LGTM! The stream_weights_via_ipc_zmq method correctly implements selective weight streaming based on the refit flags, with proper DTensor-to-local conversion and dtype casting.
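The selective-streaming filter can be illustrated as a generator over a state dict that yields only the entries selected by the refit flags. Names and suffixes here are assumptions, not the worker's actual implementation:

```python
def iter_refit_weights(state_dict, refit_base_model_weights, refit_lora_weights):
    """Yield (name, tensor) pairs selected by the refit flags (illustrative)."""
    for name, tensor in state_dict.items():
        # Assumed PEFT-style adapter suffixes.
        is_lora = name.endswith(("lora_A.weight", "lora_B.weight"))
        if (is_lora and refit_lora_weights) or (not is_lora and refit_base_model_weights):
            yield name, tensor
```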
1798-1844: LGTM! The broadcast_weights_for_collective method mirrors the streaming logic with consistent filtering behavior. The _filtered_state_dict_iterator helper cleanly encapsulates the filtering logic.
tests/unit/models/generation/test_vllm_generation.py (9)
38-38: LGTM: LoRAConfig import added to support LoRA test scenarios. The import enables LoRA configuration for the test module.
73-73: LGTM: kv_cache_dtype configuration added to vLLM config. This allows tests to specify KV cache data type, supporting FP8 and other precision scenarios.
109-109: LGTM: _v2 flag added to dtensor_config. This flag enables the DTensor v2 backend, which is required for LoRA support as noted in the PR objectives.
132-143: LGTM: LoRA test configuration properly defined. The configuration provides sensible defaults for LoRA testing with enabled: False by default, allowing tests to explicitly enable it when needed.
898-906: LGTM: Test parametrization extended with LoRA scenarios. The parametrization properly covers:
- Existing non-LoRA tests remain unchanged
- New LoRA tests with async/sync and cpu_offload variations
- FP8 tests remain without LoRA (correct, as FP8+LoRA is incompatible)
927-940: LGTM: LoRA configuration properly wired in colocated test. The test correctly:
- Configures vLLM with LoRA settings
- Enables DTensor v2 when LoRA is enabled
- Propagates LoRA config to both vLLM and DTensor paths
1005-1007: LGTM: Proper assertion prevents incompatible FP8+LoRA configuration. This matches the guard at lines 517-519 in grpo.py and ensures tests catch this incompatibility.
1014-1017: LGTM: DTensor v2 requirement for LoRA properly documented. The comment and configuration correctly enforce that LoRA requires the DTensor v2 backend.
705-759: No action needed: the refit logic already correctly handles all specified edge cases. The implementation of refit_policy_generation in nemo_rl/algorithms/grpo.py (lines 1006-1017) explicitly manages each scenario you mentioned:
- Base weights only: refit_base_model_weights=True, refit_lora_weights=False
- Both base and LoRA weights: both parameters set to True, executed sequentially
- Sequential refitting: base weights are always refitted first (line 1008), followed by LoRA weights (line 1011)
The function includes an assertion preventing both parameters from being False, and the test at lines 753-758 correctly passes refit_base_model_weights=True with refit_lora_weights=enable_lora, which aligns with the intended behavior.
517-519: LGTM: Proper guard prevents FP8+LoRA incompatibility. This assertion correctly prevents an unsupported configuration and provides a clear error message.
1190-1200: LGTM: Consistent refit flag usage in training loop. The refit calls in the main training loop correctly use the runtime flags for selective weight updates.
1432-1441: LGTM: Consistent refit flag usage in validation. The refit calls before validation correctly use the runtime flags, maintaining consistency with the training loop.
463-474: No issues found. The DTensor LoRA configuration uses the same LoRAConfig TypedDict schema that vLLM expects, making the direct assignment type-safe and compatible. The code already includes appropriate validation through assertions that check for colocated inference and prevent async rollouts with DTensor LoRA.
Likely an incorrect or invalid review comment.
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Add file to pyrefly.toml whitelist.
The pipeline failure indicates this new file needs to be added to the pyrefly.toml project-includes list.
🧰 Tools
🪛 GitHub Actions: CICD NeMo RL
[error] 1-1: File nemo_rl/models/generation/lora.py has zero errors but is not in pyrefly.toml in the 'project-includes' list. Please add it to this whitelist.
🤖 Prompt for AI Agents
In @nemo_rl/models/generation/lora.py around lines 1 - 13, The new module
nemo_rl/models/generation/lora.py is missing from the pyrefly.toml whitelist and
must be added to the project-includes so CI stops failing; open pyrefly.toml and
add the relative path "nemo_rl/models/generation/lora.py" (or the appropriate
glob such as "nemo_rl/models/generation/*.py") to the project-includes array,
save the file, and re-run the pipeline to verify the whitelist change fixes the
failure.
#!/bin/bash

# clean up checkpoint directory on exit
trap "rm -rf /tmp/lora_sft_checkpoints" EXIT
Trap cleanup path mismatch.
The trap references /tmp/lora_sft_checkpoints but the actual checkpoint directory configured on line 38 is /tmp/lora_grpo_checkpoints. This means the cleanup won't remove the correct directory on exit.
🔧 Proposed fix
-trap "rm -rf /tmp/lora_sft_checkpoints" EXIT
+trap "rm -rf /tmp/lora_grpo_checkpoints" EXIT
🤖 Prompt for AI Agents
In @tests/functional/grpo_automodel_lora.sh at line 4, The trap cleanup uses rm
-rf /tmp/lora_sft_checkpoints but the script actually configures the checkpoint
directory as /tmp/lora_grpo_checkpoints (the configured checkpoint variable on
line ~38); fix by making the trap target match the configured checkpoint
directory—i.e., update the trap command string "rm -rf
/tmp/lora_sft_checkpoints" to "rm -rf /tmp/lora_grpo_checkpoints" (or
alternatively change the configured checkpoint path to
/tmp/lora_sft_checkpoints) so the exit cleanup removes the correct directory.
Force-pushed 59f553e to 48ac5b8
ℹ️ File Consistency Check
Check based on commit: 48ac5b8 (PR #1748)
✅ DTensor Policy Worker Synchronization Check
Both DTensor policy worker files were modified in this PR.
Please ensure that the changes are consistent between both files where applicable. This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.
yuki-97
left a comment
thanks @RayenTian a lot for supporting this! reviewed partly and left some comments.
Force-pushed 48ac5b8 to 6d75d86
yuki-97
left a comment
review finished, thanks again for supporting this!
activation_checkpointing: true
lora_cfg:
  enabled: True
  dim: 256
just curious why dim and alpha seem larger than those in SFT? is it b/c there's some setting we're comparing against that uses these values?
examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
Outdated
lora_cfg = policy_config.get("dtensor_cfg", {}).get("lora_cfg", {})
if lora_cfg.get("enabled", False):
    # Override the vLLM lora config with the DTensor lora config
    generation_config["vllm_cfg"]["lora_cfg"] = lora_cfg
it should be okay for this PR, but do you have any ideas on how to unify it when supporting GRPO LoRA in mcore?
also cc @vadam5
Signed-off-by: ruit <ruit@nvidia.com>
Signed-off-by: ruit <ruit@nvidia.com>
Signed-off-by: ruit <ruit@nvidia.com>
Signed-off-by: ruit <ruit@nvidia.com>
…de' argument Signed-off-by: ruit <ruit@nvidia.com>
Signed-off-by: ruit <ruit@nvidia.com>
Signed-off-by: ruit <ruit@nvidia.com>
…ary parameters Signed-off-by: ruit <ruit@nvidia.com>
…ntegration Signed-off-by: ruit <ruit@nvidia.com>
Force-pushed ec85ee9 to 65a8e24
…ases Signed-off-by: ruit <ruit@nvidia.com>
What does this PR do ?
Add LoRA Support for GRPO with Synchronous Colocated Inference
Core Features
refit_base_model_weights and refit_lora_weights parameters in the refit_policy_generation() function
New Components
nemo_rl/models/generation/lora.py: New module containing:
Testing
tests/functional/grpo_automodel_lora.sh to validate end-to-end LoRA GRPO workflow
examples/configs/recipes/llm/grpo-qwen3-8B-base-1n8g-fsdp2-lora.yaml
Issues
[1/3] of #1597
Subsequent PRs
Usage
Result
Models to Test
Co-located + Sync
Qwen/Qwen3-0.6B
Llama-3.2-3B-Instruct
Llama-3.1-8B
Qwen2.5-7B
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
New Examples
Tests