diff --git a/rl_engine/executors/deepspeed_trainer.py b/rl_engine/executors/deepspeed_trainer.py index 4a7acfb..40fdfa5 100644 --- a/rl_engine/executors/deepspeed_trainer.py +++ b/rl_engine/executors/deepspeed_trainer.py @@ -26,12 +26,7 @@ TrainingStageResult, objective_reference_logps, ) -from rl_engine.testing import ( - compute_policy_ratio, - compute_reference_kl, - masked_mean, - selected_logprobs_reference, -) +from rl_engine.testing import compute_policy_ratio, compute_reference_kl, masked_mean _TDestination = TypeVar("_TDestination", bound=dict[str, Any]) @@ -84,9 +79,9 @@ def __init__( deepspeed = _load_deepspeed() self._deepspeed = deepspeed torch.manual_seed(self.config.seed) - self.model = torch.nn.Sequential( - torch.nn.Embedding(self.config.vocab_size, self.config.hidden_dim), - torch.nn.Linear(self.config.hidden_dim, self.config.vocab_size), + self.model = _EmbeddingLMHeadModel( + self.config.vocab_size, + self.config.hidden_dim, ).to(device=self.device) self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.lr) @@ -101,17 +96,27 @@ def __init__( engine_device = getattr(self.engine, "device", None) if engine_device is not None: self.device = torch.device(engine_device) + self._linear_logp_op = _linear_logp_op_for_device(self.device) + self.model.linear_logp_op = self._linear_logp_op def train(self, rollout: RolloutStageResult) -> TrainingStageResult: started_at = time.perf_counter() batch, payload_metrics = self._batch_from_rollout_or_synthetic(rollout) - - logits = _extract_logits(self.engine(batch.token_ids.long())) - current_logps = selected_logprobs_reference( - logits, + token_ids = _safe_token_ids( batch.token_ids, - mask=batch.completion_mask, - output_dtype=torch.float32, + batch.completion_mask, + vocab_size=self.config.vocab_size, + ) + + current_logps = _extract_logps( + self.engine( + token_ids, + target_ids=token_ids, + ) + ).to(dtype=torch.float32) + current_logps = current_logps.masked_fill( + ~batch.completion_mask.to(device=current_logps.device, dtype=torch.bool), + 0.0, ) old_logps = current_logps.detach() - 0.01 ref_logps = objective_reference_logps(current_logps, batch) @@ -147,6 +152,8 @@ def train(self, rollout: RolloutStageResult) -> TrainingStageResult: "training_device": str(self.device), "deepspeed_engine": type(self.engine).__name__, "deepspeed_zero_stage": self.config.zero_stage, + "lm_head_projection_path": "linear_logp", + "lm_head_projection_backend": type(self._linear_logp_op).__name__, "active_advantage_mean_global": ( float(active_advantages.mean().detach().cpu().item()) if active_advantages.numel() @@ -273,6 +280,67 @@ def _resolved_deepspeed_config(self) -> dict[str, Any]: return _deep_merge(base, dict(self.config.deepspeed_config)) +class _EmbeddingLMHeadModel(torch.nn.Sequential): + """Tiny policy model with an explicit deterministic LM-head logprob path.""" + + def __init__(self, vocab_size: int, hidden_dim: int) -> None: + super().__init__( + torch.nn.Embedding(vocab_size, hidden_dim), + torch.nn.Linear(hidden_dim, vocab_size), + ) + self.linear_logp_op: Optional[Any] = None + + @property + def embedding(self) -> torch.nn.Embedding: + return self[0] + + @property + def lm_head(self) -> torch.nn.Linear: + return self[1] + + def forward( + self, + input_ids: torch.Tensor, + *, + target_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden = self.embedding(input_ids.long()) + if target_ids is None: + return self.lm_head(hidden) + if self.linear_logp_op is None: + raise ValueError("target_ids scoring requires a linear_logp_op") + return self.linear_logp_op(hidden, self.lm_head.weight, target_ids, self.lm_head.bias) + + +def _linear_logp_op_for_device(device: torch.device) -> Any: + if device.type == "cpu": + from rl_engine.kernels.ops.pytorch.loss.linear_logp import NativeLinearLogpOp + + return NativeLinearLogpOp() + from rl_engine.kernels.registry import kernel_registry + + return kernel_registry.get_op("linear_logp") + + +def _safe_token_ids( + token_ids: torch.Tensor, + mask: torch.Tensor, + *, + vocab_size: int, +) -> torch.Tensor: + active = mask.to(device=token_ids.device, dtype=torch.bool) + safe_ids = token_ids.long().masked_fill(~active, 0) + if bool(((safe_ids < 0) | (safe_ids >= vocab_size)).any()): + active_ids = safe_ids[active] + t_min = int(active_ids.min()) if active_ids.numel() else 0 + t_max = int(active_ids.max()) if active_ids.numel() else 0 + raise ValueError( + f"active token_ids out of range: expected [0, {vocab_size - 1}], " + f"got [{t_min}, {t_max}]" + ) + return safe_ids + + def _load_deepspeed() -> Any: _configure_cuda_home_from_python_packages() try: @@ -374,6 +442,19 @@ def _extract_logits(model_output: Any) -> torch.Tensor: raise TypeError(f"DeepSpeed model output does not expose logits: {type(model_output)!r}") +def _extract_logps(model_output: Any) -> torch.Tensor: + if isinstance(model_output, torch.Tensor): + return model_output + if isinstance(model_output, Mapping) and "logps" in model_output: + return model_output["logps"] + logps = getattr(model_output, "logps", None) + if logps is not None: + return logps + if isinstance(model_output, (tuple, list)) and model_output: + return _extract_logps(model_output[0]) + raise TypeError(f"DeepSpeed model output does not expose logps: {type(model_output)!r}") + + class _StateDictModule(torch.nn.Module): def __init__(self, state_dict: Mapping[str, torch.Tensor]): super().__init__() diff --git a/rl_engine/kernels/ops/rocm/attention/flash_attn.py b/rl_engine/kernels/ops/rocm/attention/flash_attn.py index a9781cf..252dc79 100644 --- a/rl_engine/kernels/ops/rocm/attention/flash_attn.py +++ b/rl_engine/kernels/ops/rocm/attention/flash_attn.py @@ -2,6 +2,7 @@ # Copyright (c) 2026 RL-Kernel Contributors import os +from typing import Callable import torch @@ -22,6 +23,7 @@ class RocmFlashAttentionOp: """ def __init__(self): + self.op: Callable[..., torch.Tensor] if torch.version.hip is None: raise RuntimeError("RocmFlashAttentionOp requires a ROCm PyTorch build.") diff --git a/tests/test_deepspeed_training_worker.py b/tests/test_deepspeed_training_worker.py index e093444..007cbc2 100644 --- a/tests/test_deepspeed_training_worker.py +++ b/tests/test_deepspeed_training_worker.py @@ -98,6 +98,41 @@ def _rollout(iteration=2, weight_version=9): ) +def _ragged_rollout(iteration=2, weight_version=9): + return RolloutStageResult( + iteration=iteration, + weight_version=weight_version, + payload={ + "normalized_outputs": [ + [{"token_ids": [3, 4, 5], "text": "abc"}], + [{"token_ids": [6, 7], "text": "de"}], + ] + }, + started_at=time.perf_counter(), + finished_at=time.perf_counter(), + ) + + +class SpyLinearLogpOp: + def __init__(self): + self.calls = [] + + def __call__(self, hidden, lm_head_weight, target_ids, bias=None): + self.calls.append( + { + "hidden_shape": tuple(hidden.shape), + "weight_shape": tuple(lm_head_weight.shape), + "target_shape": tuple(target_ids.shape), + "target_ids": target_ids.detach().cpu().clone(), + "has_bias": bias is not None, + } + ) + logits = torch.nn.functional.linear(hidden, lm_head_weight, bias) + log_probs = torch.log_softmax(logits.float(), dim=-1) + selected = log_probs.gather(-1, target_ids.long().unsqueeze(-1)).squeeze(-1) + return selected + + def test_importing_module_does_not_import_deepspeed(monkeypatch): monkeypatch.delitem(sys.modules, "deepspeed", raising=False) @@ -228,6 +263,63 @@ def test_deepspeed_training_worker_uses_engine_backward_and_step(monkeypatch): assert result.metrics["active_advantage_std_global"] >= 0.0 +def test_deepspeed_training_worker_routes_lm_head_through_linear_logp(monkeypatch): + _install_fake_deepspeed(monkeypatch) + from rl_engine.executors import deepspeed_trainer + from rl_engine.executors.deepspeed_trainer import ( + DeepSpeedTrainingConfig, + DeepSpeedTrainingWorker, + ) + + spy_op = SpyLinearLogpOp() + monkeypatch.setattr(deepspeed_trainer, "_linear_logp_op_for_device", lambda device: spy_op) + worker = DeepSpeedTrainingWorker( + DeepSpeedTrainingConfig( + num_prompts=1, + samples_per_prompt=2, + prompt_len=2, + completion_len=3, + vocab_size=16, + hidden_dim=8, + valid_density=1.0, + seed=5, + ) + ) + + result = worker.train(_ragged_rollout()) + + assert len(spy_op.calls) == 1 + assert {key: value for key, value in spy_op.calls[0].items() if key != "target_ids"} == { + "hidden_shape": (2, 3, 8), + "weight_shape": (16, 8), + "target_shape": (2, 3), + "has_bias": True, + } + assert torch.equal(spy_op.calls[0]["target_ids"], torch.tensor([[3, 4, 5], [6, 7, 0]])) + assert result.metrics["lm_head_projection_path"] == "linear_logp" + assert result.metrics["lm_head_projection_backend"] == "SpyLinearLogpOp" + assert math.isfinite(result.metrics["loss"]) + + +def test_deepspeed_training_safe_token_ids_allow_masked_ignore_index(): + from rl_engine.executors.deepspeed_trainer import _safe_token_ids + + token_ids = torch.tensor([[1, -100, 3]]) + mask = torch.tensor([[True, False, True]]) + + assert torch.equal(_safe_token_ids(token_ids, mask, vocab_size=8), torch.tensor([[1, 0, 3]])) + + +def test_deepspeed_training_safe_token_ids_reject_active_out_of_range(): + from rl_engine.executors.deepspeed_trainer import _safe_token_ids + + token_ids = torch.tensor([[1, -100, 9]]) + mask = torch.tensor([[True, False, True]]) + + with pytest.raises(ValueError, match="active token_ids out of range"): + _safe_token_ids(token_ids, mask, vocab_size=8) + + def test_deepspeed_training_worker_synthetic_fallback(monkeypatch): _install_fake_deepspeed(monkeypatch) from rl_engine.executors.deepspeed_trainer import ( diff --git a/tests/test_embedding_lookup_invariance.py b/tests/test_embedding_lookup_invariance.py new file mode 100644 index 0000000..b45c50b --- /dev/null +++ b/tests/test_embedding_lookup_invariance.py @@ -0,0 +1,187 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import pytest +import torch + +from rl_engine.testing import SyntheticRLKernelBatch, make_synthetic_rl_kernel_batch + +VOCAB_SIZE = 32_768 +HIDDEN_DIM = 256 +PROMPT_PROBE_POS = 1 +COMPLETION_PROBE_POS = 5 +PROMPT_PROBE_TOKEN = 12_345 +COMPLETION_PROBE_TOKEN = 23_456 + +BATCH_LAYOUTS = ( + dict( + num_prompts=1, + samples_per_prompt=2, + prompt_len=4, + completion_len=6, + vocab_size=VOCAB_SIZE, + valid_density=1.0, + seed=11, + ), + dict( + num_prompts=2, + samples_per_prompt=3, + prompt_len=4, + completion_len=8, + vocab_size=VOCAB_SIZE, + valid_density=0.5, + seed=12, + ), + dict( + num_prompts=3, + samples_per_prompt=4, + prompt_len=4, + completion_len=10, + vocab_size=VOCAB_SIZE, + valid_density=0.75, + seed=13, + ), +) + +CUDA_CASE = pytest.param( + "cuda", + torch.bfloat16, + marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available"), +) + + +def _make_embedding(*, device: str, dtype: torch.dtype, seed: int) -> torch.nn.Embedding: + embedding = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_DIM).to(device=device, dtype=dtype) + generator = torch.Generator(device=torch.device(device)) + generator.manual_seed(seed) + weights = torch.randn( + VOCAB_SIZE, + HIDDEN_DIM, + device=device, + dtype=torch.float32, + generator=generator, + ).to(dtype=dtype) + with torch.no_grad(): + embedding.weight.copy_(weights) + return embedding + + +def _stamp_probe_tokens(batch: SyntheticRLKernelBatch) -> SyntheticRLKernelBatch: + completion_offset = COMPLETION_PROBE_POS - batch.prompt_len + if completion_offset < 0 or completion_offset >= batch.completion_len: + raise ValueError("completion probe position must fall inside completion tokens") + + batch.input_ids[:, PROMPT_PROBE_POS] = PROMPT_PROBE_TOKEN + batch.input_ids[:, COMPLETION_PROBE_POS] = COMPLETION_PROBE_TOKEN + batch.token_ids[:, completion_offset] = COMPLETION_PROBE_TOKEN + batch.completion_mask[:, completion_offset] = True + batch.attention_mask[:, COMPLETION_PROBE_POS] = True + return batch + + +def _make_layout( + layout: dict[str, int | float], *, device: str, dtype: torch.dtype +) -> SyntheticRLKernelBatch: + batch = make_synthetic_rl_kernel_batch(device=device, dtype=dtype, **layout) + return _stamp_probe_tokens(batch) + + +def _permute_rows(batch: SyntheticRLKernelBatch, perm: torch.Tensor) -> SyntheticRLKernelBatch: + completion_mask = batch.completion_mask.index_select(0, perm) + return SyntheticRLKernelBatch( + input_ids=batch.input_ids.index_select(0, perm), + attention_mask=batch.attention_mask.index_select(0, perm), + prompt_mask=batch.prompt_mask.index_select(0, perm), + completion_mask=completion_mask, + token_ids=batch.token_ids.index_select(0, perm), + rewards=batch.rewards.index_select(0, perm), + advantages=batch.advantages.index_select(0, perm), + old_logps=batch.old_logps.index_select(0, perm), + ref_logps=batch.ref_logps.index_select(0, perm), + valid_indices=completion_mask.reshape(-1).nonzero(as_tuple=False).squeeze(-1), + metadata=dict(batch.metadata), + ) + + +def _assert_probe_vectors( + output: torch.Tensor, + *, + batch_size: int, + prompt_reference: torch.Tensor, + completion_reference: torch.Tensor, +) -> None: + assert torch.equal( + output[:, PROMPT_PROBE_POS, :], + prompt_reference.expand(batch_size, -1), + ) + assert torch.equal( + output[:, COMPLETION_PROBE_POS, :], + completion_reference.expand(batch_size, -1), + ) + + +@pytest.mark.parametrize("device,dtype", [("cpu", torch.float32), CUDA_CASE]) +def test_embedding_lookup_is_bitwise_identical_across_batch_layouts( + device: str, dtype: torch.dtype +) -> None: + embedding = _make_embedding(device=device, dtype=dtype, seed=2026) + prompt_reference = embedding.weight[PROMPT_PROBE_TOKEN].detach() + completion_reference = embedding.weight[COMPLETION_PROBE_TOKEN].detach() + + for layout in BATCH_LAYOUTS: + batch = _make_layout(layout, device=device, dtype=dtype) + output = embedding(batch.input_ids) + _assert_probe_vectors( + output, + batch_size=batch.batch_size, + prompt_reference=prompt_reference, + completion_reference=completion_reference, + ) + + +@pytest.mark.parametrize("device,dtype", [("cpu", torch.float32), CUDA_CASE]) +def test_embedding_lookup_is_row_order_invariant_under_permutation( + device: str, dtype: torch.dtype +) -> None: + embedding = _make_embedding(device=device, dtype=dtype, seed=2026) + batch = _make_layout(BATCH_LAYOUTS[2], device=device, dtype=dtype) + + perm = torch.arange(batch.batch_size - 1, -1, -1, device=torch.device(device)) + original = embedding(batch.input_ids) + permuted_batch = _permute_rows(batch, perm) + permuted = embedding(permuted_batch.input_ids) + + assert torch.equal(permuted, original.index_select(0, perm)) + + +@pytest.mark.parametrize("device,dtype", [("cpu", torch.float32), CUDA_CASE]) +def test_embedding_lookup_is_unaffected_by_padding_tail_mutations( + device: str, dtype: torch.dtype +) -> None: + embedding = _make_embedding(device=device, dtype=dtype, seed=2026) + batch = _make_layout(BATCH_LAYOUTS[1], device=device, dtype=dtype) + inactive = ~batch.attention_mask + assert bool(inactive.any()) + + mutated_input_ids = batch.input_ids.clone() + generator = torch.Generator(device=torch.device(device)) + generator.manual_seed(404) + random_tokens = torch.randint( + 0, + VOCAB_SIZE, + mutated_input_ids.shape, + device=device, + generator=generator, + dtype=torch.long, + ) + mutated_input_ids[inactive] = random_tokens[inactive] + + baseline = embedding(batch.input_ids) + candidate = embedding(mutated_input_ids) + + assert torch.equal( + batch.input_ids[batch.attention_mask], mutated_input_ids[batch.attention_mask] + ) + assert torch.equal(candidate[batch.attention_mask], baseline[batch.attention_mask])