Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 96 additions & 15 deletions rl_engine/executors/deepspeed_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,7 @@
TrainingStageResult,
objective_reference_logps,
)
from rl_engine.testing import (
compute_policy_ratio,
compute_reference_kl,
masked_mean,
selected_logprobs_reference,
)
from rl_engine.testing import compute_policy_ratio, compute_reference_kl, masked_mean

_TDestination = TypeVar("_TDestination", bound=dict[str, Any])

Expand Down Expand Up @@ -84,9 +79,9 @@ def __init__(
deepspeed = _load_deepspeed()
self._deepspeed = deepspeed
torch.manual_seed(self.config.seed)
self.model = torch.nn.Sequential(
torch.nn.Embedding(self.config.vocab_size, self.config.hidden_dim),
torch.nn.Linear(self.config.hidden_dim, self.config.vocab_size),
self.model = _EmbeddingLMHeadModel(
self.config.vocab_size,
self.config.hidden_dim,
).to(device=self.device)
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.lr)

Expand All @@ -101,17 +96,27 @@ def __init__(
engine_device = getattr(self.engine, "device", None)
if engine_device is not None:
self.device = torch.device(engine_device)
self._linear_logp_op = _linear_logp_op_for_device(self.device)
self.model.linear_logp_op = self._linear_logp_op

def train(self, rollout: RolloutStageResult) -> TrainingStageResult:
started_at = time.perf_counter()
batch, payload_metrics = self._batch_from_rollout_or_synthetic(rollout)

logits = _extract_logits(self.engine(batch.token_ids.long()))
current_logps = selected_logprobs_reference(
logits,
token_ids = _safe_token_ids(
batch.token_ids,
mask=batch.completion_mask,
output_dtype=torch.float32,
batch.completion_mask,
vocab_size=self.config.vocab_size,
)

current_logps = _extract_logps(
self.engine(
token_ids,
target_ids=token_ids,
)
).to(dtype=torch.float32)
current_logps = current_logps.masked_fill(
~batch.completion_mask.to(device=current_logps.device, dtype=torch.bool),
0.0,
)
old_logps = current_logps.detach() - 0.01
ref_logps = objective_reference_logps(current_logps, batch)
Expand Down Expand Up @@ -147,6 +152,8 @@ def train(self, rollout: RolloutStageResult) -> TrainingStageResult:
"training_device": str(self.device),
"deepspeed_engine": type(self.engine).__name__,
"deepspeed_zero_stage": self.config.zero_stage,
"lm_head_projection_path": "linear_logp",
"lm_head_projection_backend": type(self._linear_logp_op).__name__,
"active_advantage_mean_global": (
float(active_advantages.mean().detach().cpu().item())
if active_advantages.numel()
Expand Down Expand Up @@ -273,6 +280,67 @@ def _resolved_deepspeed_config(self) -> dict[str, Any]:
return _deep_merge(base, dict(self.config.deepspeed_config))


class _EmbeddingLMHeadModel(torch.nn.Sequential):
"""Tiny policy model with an explicit deterministic LM-head logprob path."""

def __init__(self, vocab_size: int, hidden_dim: int) -> None:
super().__init__(
torch.nn.Embedding(vocab_size, hidden_dim),
torch.nn.Linear(hidden_dim, vocab_size),
)
self.linear_logp_op: Optional[Any] = None

@property
def embedding(self) -> torch.nn.Embedding:
return self[0]

@property
def lm_head(self) -> torch.nn.Linear:
return self[1]

def forward(
self,
input_ids: torch.Tensor,
*,
target_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden = self.embedding(input_ids.long())
if target_ids is None:
return self.lm_head(hidden)
if self.linear_logp_op is None:
raise ValueError("target_ids scoring requires a linear_logp_op")
return self.linear_logp_op(hidden, self.lm_head.weight, target_ids, self.lm_head.bias)
Comment on lines +301 to +312

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify whether this PR has a ZeRO-3 train-path test and where the LM-head
# parameters are accessed directly.
rg -n -C3 'zero_stage\s*=\s*3|"zero_optimization": \{"stage": 3\}' tests rl_engine
rg -n -C3 'lm_head\.(weight|bias)|linear_logp_op' rl_engine/executors/deepspeed_trainer.py

Repository: RL-Align/RL-Kernel

Length of output: 3317


Direct parameter access in ZeRO-3 path bypasses module hooks

In rl_engine/executors/deepspeed_trainer.py (lines 301–312), the forward method accesses self.lm_head.weight and self.lm_head.bias directly when target_ids is provided. Under DeepSpeed ZeRO-3, model parameters are partitioned across ranks and must be gathered via the module’s forward method to trigger the necessary parameter-gather hooks. Direct attribute access bypasses these hooks, which can lead to runtime failures or silent incorrect results in distributed training.

Although tests exist for zero_stage=3, they do not explicitly verify the code path where target_ids is passed and linear_logp_op is invoked. To ensure ZeRO-3 compatibility, encapsulate the custom scoring logic inside a dedicated nn.Module (e.g., _LinearLogpHead) and route both standard and custom paths through its forward method. This guarantees parameter-gather hooks are invoked regardless of the execution path.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@rl_engine/executors/deepspeed_trainer.py` around lines 301 - 312, The
`forward` method in `deepspeed_trainer.py` is reading `self.lm_head.weight` and
`self.lm_head.bias` directly on the `target_ids` path, which bypasses ZeRO-3
parameter-gather hooks. Move the custom scoring logic into a dedicated
`nn.Module` such as `_LinearLogpHead` and call it through its `forward` method
from `forward`, so both the normal logits path and the `target_ids` scoring path
go through module execution and trigger DeepSpeed hooks correctly.



def _linear_logp_op_for_device(device: torch.device) -> Any:
    """Return the ``linear_logp`` op to use for ``device``.

    CPU devices get a freshly constructed ``NativeLinearLogpOp``; every other
    device type is resolved through ``kernel_registry.get_op("linear_logp")``.
    Both imports are deferred to call time.
    """
    if device.type != "cpu":
        from rl_engine.kernels.registry import kernel_registry

        return kernel_registry.get_op("linear_logp")

    from rl_engine.kernels.ops.pytorch.loss.linear_logp import NativeLinearLogpOp

    return NativeLinearLogpOp()


def _safe_token_ids(
token_ids: torch.Tensor,
mask: torch.Tensor,
*,
vocab_size: int,
) -> torch.Tensor:
active = mask.to(device=token_ids.device, dtype=torch.bool)
safe_ids = token_ids.long().masked_fill(~active, 0)
if bool(((safe_ids < 0) | (safe_ids >= vocab_size)).any()):
active_ids = safe_ids[active]
t_min = int(active_ids.min()) if active_ids.numel() else 0
t_max = int(active_ids.max()) if active_ids.numel() else 0
raise ValueError(
f"active token_ids out of range: expected [0, {vocab_size - 1}], "
f"got [{t_min}, {t_max}]"
)
return safe_ids


def _load_deepspeed() -> Any:
_configure_cuda_home_from_python_packages()
try:
Expand Down Expand Up @@ -374,6 +442,19 @@ def _extract_logits(model_output: Any) -> torch.Tensor:
raise TypeError(f"DeepSpeed model output does not expose logits: {type(model_output)!r}")


def _extract_logps(model_output: Any) -> torch.Tensor:
if isinstance(model_output, torch.Tensor):
return model_output
if isinstance(model_output, Mapping) and "logps" in model_output:
return model_output["logps"]
logps = getattr(model_output, "logps", None)
if logps is not None:
return logps
if isinstance(model_output, (tuple, list)) and model_output:
return _extract_logps(model_output[0])
raise TypeError(f"DeepSpeed model output does not expose logps: {type(model_output)!r}")


class _StateDictModule(torch.nn.Module):
def __init__(self, state_dict: Mapping[str, torch.Tensor]):
super().__init__()
Expand Down
2 changes: 2 additions & 0 deletions rl_engine/kernels/ops/rocm/attention/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2026 RL-Kernel Contributors

import os
from typing import Callable

import torch

Expand All @@ -22,6 +23,7 @@ class RocmFlashAttentionOp:
"""

def __init__(self):
self.op: Callable[..., torch.Tensor]
if torch.version.hip is None:
raise RuntimeError("RocmFlashAttentionOp requires a ROCm PyTorch build.")

Expand Down
92 changes: 92 additions & 0 deletions tests/test_deepspeed_training_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,41 @@ def _rollout(iteration=2, weight_version=9):
)


def _ragged_rollout(iteration=2, weight_version=9):
    """Build a rollout result whose two samples have 3 and 2 completion tokens."""
    ragged_outputs = [
        [{"token_ids": [3, 4, 5], "text": "abc"}],
        [{"token_ids": [6, 7], "text": "de"}],
    ]
    return RolloutStageResult(
        iteration=iteration,
        weight_version=weight_version,
        payload={"normalized_outputs": ragged_outputs},
        started_at=time.perf_counter(),
        finished_at=time.perf_counter(),
    )


class SpyLinearLogpOp:
    """Test double for the ``linear_logp`` op that records every invocation.

    Each call appends a dict describing the argument shapes, a CPU copy of
    ``target_ids`` and whether a bias was supplied, then returns the
    log-softmax of the linear projection gathered at ``target_ids``.
    """

    def __init__(self):
        self.calls = []

    def __call__(self, hidden, lm_head_weight, target_ids, bias=None):
        record = {
            "hidden_shape": tuple(hidden.shape),
            "weight_shape": tuple(lm_head_weight.shape),
            "target_shape": tuple(target_ids.shape),
            # Snapshot the ids so later mutation by the caller cannot alter the record.
            "target_ids": target_ids.detach().cpu().clone(),
            "has_bias": bias is not None,
        }
        self.calls.append(record)

        projected = torch.nn.functional.linear(hidden, lm_head_weight, bias)
        gather_index = target_ids.long().unsqueeze(-1)
        per_token = torch.log_softmax(projected.float(), dim=-1).gather(-1, gather_index)
        return per_token.squeeze(-1)


def test_importing_module_does_not_import_deepspeed(monkeypatch):
monkeypatch.delitem(sys.modules, "deepspeed", raising=False)

Expand Down Expand Up @@ -228,6 +263,63 @@ def test_deepspeed_training_worker_uses_engine_backward_and_step(monkeypatch):
assert result.metrics["active_advantage_std_global"] >= 0.0


def test_deepspeed_training_worker_routes_lm_head_through_linear_logp(monkeypatch):
    """The worker scores tokens through the injected ``linear_logp`` op exactly once."""
    _install_fake_deepspeed(monkeypatch)
    from rl_engine.executors import deepspeed_trainer
    from rl_engine.executors.deepspeed_trainer import (
        DeepSpeedTrainingConfig,
        DeepSpeedTrainingWorker,
    )

    # Swap the op factory for a spy before the worker is constructed.
    spy_op = SpyLinearLogpOp()
    monkeypatch.setattr(deepspeed_trainer, "_linear_logp_op_for_device", lambda device: spy_op)

    config = DeepSpeedTrainingConfig(
        num_prompts=1,
        samples_per_prompt=2,
        prompt_len=2,
        completion_len=3,
        vocab_size=16,
        hidden_dim=8,
        valid_density=1.0,
        seed=5,
    )
    worker = DeepSpeedTrainingWorker(config)

    result = worker.train(_ragged_rollout())

    assert len(spy_op.calls) == 1
    recorded = spy_op.calls[0]
    shape_fields = {name: entry for name, entry in recorded.items() if name != "target_ids"}
    assert shape_fields == {
        "hidden_shape": (2, 3, 8),
        "weight_shape": (16, 8),
        "target_shape": (2, 3),
        "has_bias": True,
    }
    # The shorter sample is padded with id 0 in the last position.
    expected_targets = torch.tensor([[3, 4, 5], [6, 7, 0]])
    assert torch.equal(recorded["target_ids"], expected_targets)

    metrics = result.metrics
    assert metrics["lm_head_projection_path"] == "linear_logp"
    assert metrics["lm_head_projection_backend"] == "SpyLinearLogpOp"
    assert math.isfinite(metrics["loss"])


def test_deepspeed_training_safe_token_ids_allow_masked_ignore_index():
    """An out-of-vocab id at a masked position is zeroed instead of rejected."""
    from rl_engine.executors.deepspeed_trainer import _safe_token_ids

    ids_with_ignore_index = torch.tensor([[1, -100, 3]])
    completion_mask = torch.tensor([[True, False, True]])

    sanitized = _safe_token_ids(ids_with_ignore_index, completion_mask, vocab_size=8)

    assert torch.equal(sanitized, torch.tensor([[1, 0, 3]]))


def test_deepspeed_training_safe_token_ids_reject_active_out_of_range():
    """An id >= vocab_size at an unmasked position raises ``ValueError``."""
    from rl_engine.executors.deepspeed_trainer import _safe_token_ids

    # Id 9 sits at an active position and exceeds vocab_size=8.
    ids_with_overflow = torch.tensor([[1, -100, 9]])
    completion_mask = torch.tensor([[True, False, True]])

    with pytest.raises(ValueError, match="active token_ids out of range"):
        _safe_token_ids(ids_with_overflow, completion_mask, vocab_size=8)


def test_deepspeed_training_worker_synthetic_fallback(monkeypatch):
_install_fake_deepspeed(monkeypatch)
from rl_engine.executors.deepspeed_trainer import (
Expand Down
Loading
Loading