diff --git a/rl_engine/executors/deepspeed_trainer.py b/rl_engine/executors/deepspeed_trainer.py index 4a7acfb..21c7705 100644 --- a/rl_engine/executors/deepspeed_trainer.py +++ b/rl_engine/executors/deepspeed_trainer.py @@ -7,6 +7,7 @@ import os import sysconfig import time +from contextlib import nullcontext from dataclasses import dataclass, field from pathlib import Path from typing import Any, Mapping, Optional, TypeVar, overload @@ -26,12 +27,9 @@ TrainingStageResult, objective_reference_logps, ) -from rl_engine.testing import ( - compute_policy_ratio, - compute_reference_kl, - masked_mean, - selected_logprobs_reference, -) +from rl_engine.kernels.ops.pytorch.loss.linear_logp import NativeLinearLogpOp +from rl_engine.kernels.registry import kernel_registry +from rl_engine.testing import compute_policy_ratio, compute_reference_kl, masked_mean _TDestination = TypeVar("_TDestination", bound=dict[str, Any]) @@ -53,6 +51,25 @@ def __post_init__(self) -> None: raise ValueError("zero_stage must be >= 0") +class _EmbeddingLMHeadModel(torch.nn.Module): + def __init__( + self, + vocab_size: int, + hidden_dim: int, + *, + bias: bool = True, + tie_weights: bool = False, + ) -> None: + super().__init__() + self.embedding = torch.nn.Embedding(vocab_size, hidden_dim) + self.lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=bias) + if tie_weights: + self.lm_head.weight = self.embedding.weight + + def forward(self, token_ids: torch.Tensor) -> torch.Tensor: + return self.embedding(token_ids.long()) + + class DeepSpeedTrainingWorker(RolloutBatchMixin): """ Training worker implementation backed by a real DeepSpeed engine contract. @@ -84,43 +101,39 @@ def __init__( deepspeed = _load_deepspeed() self._deepspeed = deepspeed torch.manual_seed(self.config.seed) - self.model = torch.nn.Sequential( - torch.nn.Embedding(self.config.vocab_size, self.config.hidden_dim), - torch.nn.Linear(self.config.hidden_dim, self.config.vocab_size), + self.model = _EmbeddingLMHeadModel( + self.config.vocab_size, + self.config.hidden_dim, ).to(device=self.device) self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.config.lr) + self._deepspeed_config = self._resolved_deepspeed_config() + self._deepspeed_zero_stage = _resolved_zero_stage( + self._deepspeed_config, + fallback=self.config.zero_stage, + ) init_result = deepspeed.initialize( model=self.model, model_parameters=self.model.parameters(), optimizer=self.optimizer, - config=self._resolved_deepspeed_config(), + config=self._deepspeed_config, **dict(self.config.initialize_kwargs), ) self.engine = _first_initialize_result(init_result) engine_device = getattr(self.engine, "device", None) if engine_device is not None: self.device = torch.device(engine_device) + self._linear_logp = _linear_logp_op_for_device(self.device) def train(self, rollout: RolloutStageResult) -> TrainingStageResult: started_at = time.perf_counter() batch, payload_metrics = self._batch_from_rollout_or_synthetic(rollout) - - logits = _extract_logits(self.engine(batch.token_ids.long())) - current_logps = selected_logprobs_reference( - logits, + training_model = _unwrap_training_model(self.engine, self.model) + training_embedding = _embedding_layer(training_model) + _validate_model_input_token_ids( batch.token_ids, - mask=batch.completion_mask, - output_dtype=torch.float32, + vocab_size=training_embedding.num_embeddings, ) - old_logps = current_logps.detach() - 0.01 - ref_logps = objective_reference_logps(current_logps, batch) - ratio = compute_policy_ratio(current_logps, old_logps, batch.completion_mask) - unclipped = ratio * batch.advantages.float() - clipped = torch.clamp(ratio, 0.8, 1.2) * batch.advantages.float() - policy_loss = -torch.minimum(unclipped, clipped) - kl = compute_reference_kl(current_logps, ref_logps, batch.completion_mask) - loss = masked_mean(policy_loss + 0.01 * kl, batch.completion_mask) if hasattr(self.engine, "zero_grad"): try: @@ -129,7 +142,30 @@ def train(self, rollout: RolloutStageResult) -> TrainingStageResult: self.engine.zero_grad() elif hasattr(self.optimizer, "zero_grad"): self.optimizer.zero_grad(set_to_none=True) - self.engine.backward(loss) + + with _linear_logp_parameter_context( + self._deepspeed, + training_model, + zero_stage=self._deepspeed_zero_stage, + world_size=self._engine_world_size(), + ): + current_logps = _extract_logps( + self.engine(batch.token_ids.long()), + training_model, + batch.token_ids, + batch.completion_mask, + self._linear_logp, + output_dtype=torch.float32, + ) + old_logps = current_logps.detach() - 0.01 + ref_logps = objective_reference_logps(current_logps, batch) + ratio = compute_policy_ratio(current_logps, old_logps, batch.completion_mask) + unclipped = ratio * batch.advantages.float() + clipped = torch.clamp(ratio, 0.8, 1.2) * batch.advantages.float() + policy_loss = -torch.minimum(unclipped, clipped) + kl = compute_reference_kl(current_logps, ref_logps, batch.completion_mask) + loss = masked_mean(policy_loss + 0.01 * kl, batch.completion_mask) + self.engine.backward(loss) self.engine.step() finished_at = time.perf_counter() @@ -146,7 +182,9 @@ def train(self, rollout: RolloutStageResult) -> TrainingStageResult: "training_backend": "deepspeed", "training_device": str(self.device), "deepspeed_engine": type(self.engine).__name__, - "deepspeed_zero_stage": self.config.zero_stage, + "deepspeed_zero_stage": self._deepspeed_zero_stage, + "current_logp_path": "linear_logp", + "current_logp_backend": type(self._linear_logp).__name__, "active_advantage_mean_global": ( float(active_advantages.mean().detach().cpu().item()) if active_advantages.numel() @@ -180,14 +218,14 @@ def publish_weights( manifest_metadata = dict(metadata or {}) layout = { "kind": "full-state", - "zero_stage": self.config.zero_stage, + "zero_stage": self._deepspeed_zero_stage, "world_size": self._engine_world_size(), "rank": self._engine_rank(), } layout.update(dict(manifest_metadata.get("layout", {}))) manifest_metadata["layout"] = layout publish_model: torch.nn.Module = self.model - if self.config.zero_stage >= 3: + if self._deepspeed_zero_stage >= 3: publish_model, export_metadata = self._export_zero3_full_state_model() manifest_metadata["deepspeed_zero3_full_state_export"] = export_metadata return self.weight_bridge.publish( @@ -361,6 +399,15 @@ def _first_initialize_result(init_result: Any) -> Any: return init_result +def _resolved_zero_stage(config: Mapping[str, Any], *, fallback: int) -> int: + zero_config = config.get("zero_optimization") + if isinstance(zero_config, Mapping): + return int(zero_config.get("stage", fallback)) + if zero_config is False: + return 0 + return int(fallback) + + def _extract_logits(model_output: Any) -> torch.Tensor: if isinstance(model_output, torch.Tensor): return model_output @@ -374,6 +421,226 @@ def _extract_logits(model_output: Any) -> torch.Tensor: raise TypeError(f"DeepSpeed model output does not expose logits: {type(model_output)!r}") +def _extract_hidden_states( + model_output: Any, + *, + expected_hidden_dim: Optional[int] = None, +) -> torch.Tensor: + hidden = _coerce_hidden_tensor(model_output, expected_hidden_dim=expected_hidden_dim) + if hidden is None: + raise TypeError( + f"DeepSpeed model output does not expose a hidden-state tensor: {type(model_output)!r}" + ) + return hidden + + +def _linear_logp_op_for_device(device: torch.device | str) -> Any: + resolved = torch.device(device) + if resolved.type == "cpu": + return NativeLinearLogpOp() + return kernel_registry.get_op("linear_logp") + + +def _unwrap_training_model(engine: Any, fallback_model: torch.nn.Module) -> torch.nn.Module: + model = getattr(engine, "module", None) + if isinstance(model, torch.nn.Module): + return model + return fallback_model + + +def _embedding_layer(model: torch.nn.Module) -> torch.nn.Embedding: + embedding = getattr(model, "embedding", None) + if not isinstance(embedding, torch.nn.Embedding): + raise TypeError( + "DeepSpeed training model must expose an embedding torch.nn.Embedding for " + "model-input validation" + ) + return embedding + + +def _coerce_hidden_tensor( + candidate: Any, + *, + expected_hidden_dim: Optional[int] = None, +) -> Optional[torch.Tensor]: + if isinstance(candidate, torch.Tensor): + return candidate if _looks_like_hidden_tensor(candidate, expected_hidden_dim) else None + if isinstance(candidate, Mapping): + for key in ("last_hidden_state", "hidden"): + value = candidate.get(key) + hidden = _coerce_hidden_tensor(value, expected_hidden_dim=expected_hidden_dim) + if hidden is not None: + return hidden + hidden_states = candidate.get("hidden_states") + hidden = _last_hidden_state_tensor( + hidden_states, + expected_hidden_dim=expected_hidden_dim, + ) + if hidden is not None: + return hidden + return None + for attr in ("last_hidden_state", "hidden"): + if hasattr(candidate, attr): + hidden = _coerce_hidden_tensor( + getattr(candidate, attr), + expected_hidden_dim=expected_hidden_dim, + ) + if hidden is not None: + return hidden + if hasattr(candidate, "hidden_states"): + hidden = _last_hidden_state_tensor( + getattr(candidate, "hidden_states"), + expected_hidden_dim=expected_hidden_dim, + ) + if hidden is not None: + return hidden + if isinstance(candidate, (tuple, list)): + for item in candidate: + if _has_hidden_state_metadata(item): + hidden = _coerce_hidden_tensor(item, expected_hidden_dim=expected_hidden_dim) + if hidden is not None: + return hidden + tensor_candidates = [ + item + for item in candidate + if isinstance(item, torch.Tensor) + and _looks_like_hidden_tensor(item, expected_hidden_dim) + ] + if len(tensor_candidates) == 1: + return tensor_candidates[0] + if tensor_candidates: + max_ndim = max(tensor.ndim for tensor in tensor_candidates) + deepest = [tensor for tensor in tensor_candidates if tensor.ndim == max_ndim] + if len(deepest) == 1: + return deepest[0] + for item in candidate: + if isinstance(item, torch.Tensor): + continue + hidden = _coerce_hidden_tensor(item, expected_hidden_dim=expected_hidden_dim) + if hidden is not None: + return hidden + return None + + +def _has_hidden_state_metadata(candidate: Any) -> bool: + if isinstance(candidate, Mapping): + return any(key in candidate for key in ("last_hidden_state", "hidden", "hidden_states")) + return any( + hasattr(candidate, attr) for attr in ("last_hidden_state", "hidden", "hidden_states") + ) + + +def _looks_like_hidden_tensor( + tensor: torch.Tensor, + expected_hidden_dim: Optional[int], +) -> bool: + if tensor.ndim < 2: + return False + if expected_hidden_dim is not None and int(tensor.size(-1)) != int(expected_hidden_dim): + return False + return True + + +def _last_hidden_state_tensor( + candidate: Any, + *, + expected_hidden_dim: Optional[int] = None, +) -> Optional[torch.Tensor]: + if isinstance(candidate, torch.Tensor): + return candidate if _looks_like_hidden_tensor(candidate, expected_hidden_dim) else None + if isinstance(candidate, (tuple, list)): + for item in reversed(candidate): + hidden = _coerce_hidden_tensor(item, expected_hidden_dim=expected_hidden_dim) + if hidden is not None: + return hidden + return None + + +def _safe_token_ids(token_ids: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: + safe_token_ids = token_ids.long() + if mask is None: + return safe_token_ids + active_mask = mask.to(device=safe_token_ids.device, dtype=torch.bool) + if active_mask.shape != safe_token_ids.shape: + raise ValueError( + f"mask shape {tuple(active_mask.shape)} must match token_ids shape " + f"{tuple(safe_token_ids.shape)}" + ) + return safe_token_ids.masked_fill(~active_mask, 0) + + +def _validate_model_input_token_ids(token_ids: torch.Tensor, *, vocab_size: int) -> None: + invalid = (token_ids < 0) | (token_ids >= int(vocab_size)) + if bool(invalid.any().item()): + t_min = int(token_ids.min().item()) + t_max = int(token_ids.max().item()) + raise ValueError( + f"model input token_ids must be in [0, {int(vocab_size) - 1}], got " + f"[{t_min}, {t_max}]. Keep ignore-index / padding sentinels out of the model " + "input path and apply masking only at the logprob/loss stage." + ) + + +def _linear_logp_parameter_context( + deepspeed_runtime: Any, + model: torch.nn.Module, + *, + zero_stage: int, + world_size: int, +) -> Any: + if int(zero_stage) < 3 or int(world_size) <= 1: + return nullcontext() + + lm_head = getattr(model, "lm_head", None) + if not isinstance(lm_head, torch.nn.Linear): + raise TypeError( + "DeepSpeed training model must expose an lm_head torch.nn.Linear for ZeRO-3 " + "linear_logp gathering" + ) + + gathered_parameters = getattr( + getattr(deepspeed_runtime, "zero", None), + "GatheredParameters", + None, + ) + if not callable(gathered_parameters): + raise WeightBridgeUnavailableError( + "DeepSpeed ZeRO-3 linear_logp training requires deepspeed.zero.GatheredParameters " + "or an equivalent full-parameter gather API." + ) + + parameters = [lm_head.weight] + if lm_head.bias is not None: + parameters.append(lm_head.bias) + return gathered_parameters(parameters, modifier_rank=None) + + +def _extract_logps( + model_output: Any, + model: torch.nn.Module, + token_ids: torch.Tensor, + completion_mask: Optional[torch.Tensor], + linear_logp_op: Any, + *, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + lm_head = getattr(model, "lm_head", None) + if not isinstance(lm_head, torch.nn.Linear): + raise TypeError( + "DeepSpeed training model must expose an lm_head torch.nn.Linear for linear_logp" + ) + + hidden = _extract_hidden_states( + model_output, + expected_hidden_dim=int(lm_head.in_features), + ) + targets = _safe_token_ids(token_ids.to(device=hidden.device), completion_mask) + logps = linear_logp_op(hidden, lm_head.weight, targets, lm_head.bias) + if completion_mask is not None: + logps = logps.masked_fill(~completion_mask.to(device=logps.device, dtype=torch.bool), 0.0) + return logps.to(dtype=output_dtype) + + class _StateDictModule(torch.nn.Module): def __init__(self, state_dict: Mapping[str, torch.Tensor]): super().__init__() diff --git a/tests/test_deepspeed_training_worker.py b/tests/test_deepspeed_training_worker.py index e093444..35caefa 100644 --- a/tests/test_deepspeed_training_worker.py +++ b/tests/test_deepspeed_training_worker.py @@ -8,12 +8,15 @@ import os import sys import time +from dataclasses import replace import pytest import torch from rl_engine.executors.bridge import LocalTensorCopyBridge, WeightBridgeUnavailableError from rl_engine.executors.training_contract import RolloutStageResult +from rl_engine.kernels.ops.pytorch.loss.linear_logp import NativeLinearLogpOp +from rl_engine.testing import make_synthetic_rl_kernel_batch, selected_logprobs_reference class FakeDeepSpeedEngine: @@ -56,16 +59,25 @@ def initialize(self, **kwargs): class FakeGatheredParameters: calls = 0 + active = 0 + max_active = 0 + modifier_ranks = [] + parameter_counts = [] def __init__(self, parameters, modifier_rank=0): self.parameters = list(parameters) self.modifier_rank = modifier_rank + type(self).modifier_ranks.append(modifier_rank) + type(self).parameter_counts.append(len(self.parameters)) def __enter__(self): type(self).calls += 1 + type(self).active += 1 + type(self).max_active = max(type(self).max_active, type(self).active) return self.parameters def __exit__(self, exc_type, exc, traceback): + type(self).active -= 1 return False @@ -78,6 +90,10 @@ def _install_fake_deepspeed(monkeypatch): def _install_fake_deepspeed_with_gather(monkeypatch): fake = FakeDeepSpeedModule() FakeGatheredParameters.calls = 0 + FakeGatheredParameters.active = 0 + FakeGatheredParameters.max_active = 0 + FakeGatheredParameters.modifier_ranks = [] + FakeGatheredParameters.parameter_counts = [] fake.zero = type("FakeZeroNamespace", (), {"GatheredParameters": FakeGatheredParameters})() monkeypatch.setitem(sys.modules, "deepspeed", fake) return fake @@ -98,6 +114,24 @@ def _rollout(iteration=2, weight_version=9): ) +class SpyLinearLogpOp: + def __init__(self): + self.calls = [] + self._delegate = NativeLinearLogpOp() + + def __call__(self, hidden, lm_head_weight, target_ids, bias=None, **kwargs): + self.calls.append( + { + "hidden": hidden.detach().clone(), + "lm_head_weight": lm_head_weight.detach().clone(), + "target_ids": target_ids.detach().clone(), + "bias": None if bias is None else bias.detach().clone(), + "kwargs": dict(kwargs), + } + ) + return self._delegate(hidden, lm_head_weight, target_ids, bias, **kwargs) + + def test_importing_module_does_not_import_deepspeed(monkeypatch): monkeypatch.delitem(sys.modules, "deepspeed", raising=False) @@ -218,9 +252,12 @@ def test_deepspeed_training_worker_uses_engine_backward_and_step(monkeypatch): assert result.consumed_weight_version == 9 assert result.published_weight_version == 10 assert result.metrics["training_backend"] == "deepspeed" + assert result.metrics["deepspeed_zero_stage"] == 1 assert result.metrics["training_data_source"] == "rollout_payload" assert result.metrics["rollout_sequences"] == 2 assert result.metrics["rollout_tokens"] == 6 + assert result.metrics["current_logp_path"] == "linear_logp" + assert result.metrics["current_logp_backend"] == "NativeLinearLogpOp" assert math.isfinite(result.metrics["loss"]) assert "advantage_mean" not in result.metrics assert "advantage_std" not in result.metrics @@ -228,6 +265,309 @@ def test_deepspeed_training_worker_uses_engine_backward_and_step(monkeypatch): assert result.metrics["active_advantage_std_global"] >= 0.0 +def test_extract_logps_matches_masked_reference_with_ignore_index(): + from rl_engine.executors.deepspeed_trainer import _EmbeddingLMHeadModel, _extract_logps + + torch.manual_seed(2026) + model = _EmbeddingLMHeadModel(vocab_size=13, hidden_dim=7) + input_ids = torch.tensor([[4, 3, 2], [1, 0, 5]], dtype=torch.long) + token_ids = torch.tensor([[6, -100, 2], [-100, 1, 4]], dtype=torch.long) + mask = token_ids.ne(-100) + + hidden = model(input_ids) + actual = _extract_logps( + hidden, + model, + token_ids, + mask, + NativeLinearLogpOp(), + output_dtype=torch.float32, + ) + logits = torch.nn.functional.linear( + hidden.float(), + model.lm_head.weight.float(), + model.lm_head.bias.float(), + ) + expected = selected_logprobs_reference(logits, token_ids, mask=mask) + + assert torch.allclose(actual, expected, atol=1e-5) + assert actual[~mask].eq(0.0).all() + + +def test_extract_logps_uses_hidden_dim_to_disambiguate_tuple_logits(): + from rl_engine.executors.deepspeed_trainer import _EmbeddingLMHeadModel, _extract_logps + + torch.manual_seed(2027) + model = _EmbeddingLMHeadModel(vocab_size=13, hidden_dim=5) + input_ids = torch.tensor([[4, 3, 2]], dtype=torch.long) + token_ids = torch.tensor([[6, 1, 4]], dtype=torch.long) + mask = torch.ones_like(token_ids, dtype=torch.bool) + + hidden = model(input_ids) + logits = torch.randn(1, 3, model.lm_head.out_features) + actual = _extract_logps( + (torch.tensor(1.0), logits, hidden), + model, + token_ids, + mask, + NativeLinearLogpOp(), + output_dtype=torch.float32, + ) + expected_logits = torch.nn.functional.linear( + hidden.float(), + model.lm_head.weight.float(), + model.lm_head.bias.float(), + ) + expected = selected_logprobs_reference(expected_logits, token_ids, mask=mask) + + assert torch.allclose(actual, expected, atol=1e-5) + + +def test_extract_hidden_states_prefers_last_hidden_state_over_hidden_state_stack(): + from rl_engine.executors.deepspeed_trainer import _extract_hidden_states + + expected = torch.randn(2, 3, 5) + output = { + "hidden_states": ( + torch.randn(2, 3, 5), + torch.randn(2, 3, 5), + ), + "last_hidden_state": expected, + } + + actual = _extract_hidden_states(output) + + assert actual is expected + + +def test_extract_hidden_states_uses_last_tensor_from_hidden_state_stack(): + from rl_engine.executors.deepspeed_trainer import _extract_hidden_states + + layers = ( + torch.randn(2, 3, 5), + torch.randn(2, 3, 5), + torch.randn(2, 3, 5), + ) + + actual = _extract_hidden_states({"hidden_states": layers}) + + assert actual is layers[-1] + + +def test_extract_hidden_states_prefers_structured_hidden_over_tuple_logits(): + from rl_engine.executors.deepspeed_trainer import _extract_hidden_states + + logits = torch.randn(2, 3, 11) + expected = torch.randn(2, 3, 5) + output = (torch.tensor(1.0), logits, {"last_hidden_state": expected}) + + actual = _extract_hidden_states(output) + + assert actual is expected + + +def test_extract_hidden_states_rejects_ambiguous_multi_tensor_tuple(): + from rl_engine.executors.deepspeed_trainer import _extract_hidden_states + + with pytest.raises(TypeError, match="hidden-state tensor"): + _extract_hidden_states((torch.randn(2, 3, 11), torch.randn(2, 3, 5))) + + +def test_deepspeed_training_worker_routes_linear_logp_and_zeroes_masked_targets(monkeypatch): + _install_fake_deepspeed(monkeypatch) + from rl_engine.executors import deepspeed_trainer + + spy = SpyLinearLogpOp() + monkeypatch.setattr(deepspeed_trainer, "_linear_logp_op_for_device", lambda device: spy) + + worker = deepspeed_trainer.DeepSpeedTrainingWorker( + deepspeed_trainer.DeepSpeedTrainingConfig( + num_prompts=1, + samples_per_prompt=2, + prompt_len=1, + completion_len=4, + vocab_size=23, + hidden_dim=8, + seed=31, + ) + ) + batch = make_synthetic_rl_kernel_batch( + num_prompts=1, + samples_per_prompt=2, + prompt_len=1, + completion_len=4, + vocab_size=23, + valid_density=1.0, + device="cpu", + seed=32, + ) + completion_mask = torch.tensor( + [[True, False, True, False], [False, True, True, False]], + dtype=torch.bool, + ) + patched_batch = replace( + batch, + completion_mask=completion_mask, + valid_indices=completion_mask.reshape(-1).nonzero(as_tuple=False).squeeze(-1), + metadata={ + **batch.metadata, + "valid_density": float(completion_mask.float().mean().item()), + "valid_tokens": int(completion_mask.sum().item()), + }, + ) + monkeypatch.setattr( + worker, + "_batch_from_rollout_or_synthetic", + lambda rollout: ( + patched_batch, + { + "training_data_source": "patched_fixture", + "rollout_sequences": patched_batch.batch_size, + "rollout_tokens": int(completion_mask.sum().item()), + }, + ), + ) + + result = worker.train(_rollout()) + + assert len(spy.calls) == 1 + recorded_targets = spy.calls[0]["target_ids"] + assert torch.equal(recorded_targets[completion_mask], patched_batch.token_ids[completion_mask]) + assert torch.equal( + recorded_targets[~completion_mask], + torch.zeros_like(recorded_targets[~completion_mask]), + ) + assert result.metrics["training_data_source"] == "patched_fixture" + assert result.metrics["current_logp_path"] == "linear_logp" + assert result.metrics["current_logp_backend"] == "SpyLinearLogpOp" + assert math.isfinite(result.metrics["loss"]) + + +def test_deepspeed_training_worker_rejects_ignore_index_in_model_inputs(monkeypatch): + _install_fake_deepspeed(monkeypatch) + from rl_engine.executors import deepspeed_trainer + + worker = deepspeed_trainer.DeepSpeedTrainingWorker( + deepspeed_trainer.DeepSpeedTrainingConfig( + num_prompts=1, + samples_per_prompt=1, + prompt_len=1, + completion_len=3, + vocab_size=17, + hidden_dim=8, + seed=33, + ) + ) + batch = make_synthetic_rl_kernel_batch( + num_prompts=1, + samples_per_prompt=1, + prompt_len=1, + completion_len=3, + vocab_size=17, + valid_density=1.0, + device="cpu", + seed=34, + ) + broken_batch = replace( + batch, + token_ids=batch.token_ids.clone(), + ) + broken_batch.token_ids[0, 1] = -100 + monkeypatch.setattr( + worker, + "_batch_from_rollout_or_synthetic", + lambda rollout: (broken_batch, {"training_data_source": "patched_fixture"}), + ) + + with pytest.raises(ValueError, match="ignore-index"): + worker.train(_rollout()) + + +def test_deepspeed_zero3_training_gathers_lm_head_parameters_during_backward(monkeypatch): + _install_fake_deepspeed_with_gather(monkeypatch) + from rl_engine.executors import deepspeed_trainer + + worker = deepspeed_trainer.DeepSpeedTrainingWorker( + deepspeed_trainer.DeepSpeedTrainingConfig( + vocab_size=19, + hidden_dim=8, + zero_stage=3, + seed=35, + ) + ) + worker.engine.world_size = 2 + + active_during_backward = {"value": False} + original_backward = worker.engine.backward + + def wrapped_backward(loss): + active_during_backward["value"] = FakeGatheredParameters.active > 0 + return original_backward(loss) + + worker.engine.backward = wrapped_backward + + result = worker.train(_rollout()) + + assert result.metrics["current_logp_path"] == "linear_logp" + assert FakeGatheredParameters.calls == 1 + assert FakeGatheredParameters.parameter_counts == [2] + assert FakeGatheredParameters.modifier_ranks == [None] + assert FakeGatheredParameters.max_active == 1 + assert active_during_backward["value"] is True + assert FakeGatheredParameters.active == 0 + + +def test_deepspeed_zero3_training_without_gather_api_is_blocked(monkeypatch): + _install_fake_deepspeed(monkeypatch) + from rl_engine.executors import deepspeed_trainer + + worker = deepspeed_trainer.DeepSpeedTrainingWorker( + deepspeed_trainer.DeepSpeedTrainingConfig( + vocab_size=19, + hidden_dim=8, + zero_stage=3, + seed=36, + ) + ) + worker.engine.world_size = 2 + + with pytest.raises(WeightBridgeUnavailableError, match="linear_logp training requires"): + worker.train(_rollout()) + + +def test_deepspeed_config_zero3_override_controls_training_and_publish(monkeypatch): + fake = _install_fake_deepspeed_with_gather(monkeypatch) + from rl_engine.executors import deepspeed_trainer + + bridge = LocalTensorCopyBridge(source_worker="training", source_rank=0) + worker = deepspeed_trainer.DeepSpeedTrainingWorker( + deepspeed_trainer.DeepSpeedTrainingConfig( + vocab_size=19, + hidden_dim=8, + zero_stage=0, + deepspeed_config={"zero_optimization": {"stage": 3}}, + seed=37, + ), + weight_bridge=bridge, + ) + worker.engine.world_size = 2 + + result = worker.train(_rollout()) + manifest = worker.publish_weights(weight_version=41) + + assert fake.initialize_calls[0]["config"]["zero_optimization"]["stage"] == 3 + assert result.metrics["deepspeed_zero_stage"] == 3 + assert manifest.metadata["layout"]["zero_stage"] == 3 + assert manifest.metadata["deepspeed_zero3_full_state_export"]["method"] == ( + "deepspeed.zero.GatheredParameters" + ) + assert FakeGatheredParameters.calls == 2 + assert FakeGatheredParameters.parameter_counts == [2, 3] + assert FakeGatheredParameters.modifier_ranks == [None, 0] + + bridge.release(manifest.update_id) + + def test_deepspeed_training_worker_synthetic_fallback(monkeypatch): _install_fake_deepspeed(monkeypatch) from rl_engine.executors.deepspeed_trainer import ( diff --git a/tests/test_linear_logp.py b/tests/test_linear_logp.py index a6ce1b3..1f4c8da 100644 --- a/tests/test_linear_logp.py +++ b/tests/test_linear_logp.py @@ -4,7 +4,12 @@ import pytest import torch -from rl_engine.kernels.ops.pytorch.loss.linear_logp import NativeLinearLogpOp +from rl_engine.executors.deepspeed_trainer import _EmbeddingLMHeadModel, _safe_token_ids +from rl_engine.kernels.ops.pytorch.loss.linear_logp import ( + NativeLinearLogpOp, + chunked_linear_logp_backward, +) +from rl_engine.testing import selected_logprobs_reference try: import triton # noqa: F401 @@ -83,6 +88,48 @@ def _manual_reference(hidden, weight, target, bias): return sel.reshape(target.shape) +def _layout_inputs(base_hidden, base_target, base_mask, order, lead_shape): + order_t = torch.tensor(order, dtype=torch.long) + hidden = base_hidden.index_select(0, order_t).reshape(*lead_shape, base_hidden.size(-1)) + target = base_target.index_select(0, order_t).reshape(*lead_shape) + mask = base_mask.index_select(0, order_t).reshape(*lead_shape) + masked_target = target.masked_fill(~mask, -100) + return hidden, masked_target, mask + + +def _recover_canonical_rows(layout_values, order): + flat = layout_values.reshape( + layout_values.shape[0] * layout_values.shape[1], *layout_values.shape[2:] + ) + recovered = torch.empty_like(flat) + recovered[torch.tensor(order, dtype=torch.long)] = flat + return recovered + + +def _run_chunked_backward(hidden, weight, target, bias, grad_out, *, chunk_elems): + return chunked_linear_logp_backward( + grad_out, + hidden.reshape(-1, hidden.size(-1)).contiguous(), + weight, + target.reshape(-1).contiguous(), + hidden.reshape(-1, hidden.size(-1)).contiguous() if bias is None else bias, + has_bias=bias is not None, + lead_shape=target.shape, + hidden_dtype=hidden.dtype, + weight_dtype=weight.dtype, + bias_dtype=None if bias is None else bias.dtype, + chunk_elems=chunk_elems, + ) + + +def _run_autograd_linear_logp(hidden, weight, target, bias, grad_out): + h = hidden.detach().clone().requires_grad_(True) + w = weight.detach().clone().requires_grad_(True) + b = bias.detach().clone().requires_grad_(True) if bias is not None else None + NativeLinearLogpOp()(h, w, target, b).backward(grad_out) + return h.grad, w.grad, (None if b is None else b.grad) + + def test_native_matches_manual_reference(): native = NativeLinearLogpOp() hidden, weight, target, bias = _inputs(0, device="cpu") @@ -92,6 +139,146 @@ def test_native_matches_manual_reference(): assert torch.allclose(out, ref, atol=1e-5) +def test_linear_logp_handoff_matches_masked_reference_across_layouts(): + torch.manual_seed(2026) + op = NativeLinearLogpOp() + base_hidden = torch.randn(6, 5) + weight = torch.randn(17, 5) + bias = torch.randn(17) + base_target = torch.tensor([3, 7, 1, 9, 4, 6], dtype=torch.long) + base_mask = torch.tensor([True, False, True, True, False, True], dtype=torch.bool) + layouts = [ + ((2, 3), [0, 1, 2, 3, 4, 5]), + ((3, 2), [5, 1, 3, 0, 4, 2]), + ((1, 6), [2, 4, 1, 5, 0, 3]), + ] + + canonical = None + for lead_shape, order in layouts: + hidden, target, mask = _layout_inputs( + base_hidden, base_target, base_mask, order, lead_shape + ) + actual = op(hidden, weight, _safe_token_ids(target, mask), bias).masked_fill(~mask, 0.0) + logits = torch.nn.functional.linear(hidden.float(), weight.float(), bias.float()) + expected = selected_logprobs_reference(logits, target, mask=mask) + recovered = _recover_canonical_rows(actual.unsqueeze(-1), order).squeeze(-1) + + assert torch.allclose(actual, expected, atol=1e-5) + if canonical is None: + canonical = recovered + else: + assert torch.allclose(recovered, canonical, atol=1e-6) + + +@pytest.mark.parametrize("use_bias", [True, False]) +def test_chunked_linear_logp_backward_matches_autograd_and_layout_invariance(use_bias): + torch.manual_seed(2027) + weight = torch.randn(19, 7) + bias = torch.randn(19) if use_bias else None + base_hidden = torch.randn(6, 7) + base_target = torch.tensor([1, 7, 3, 5, 0, 9], dtype=torch.long) + base_mask = torch.tensor([True, False, True, True, False, True], dtype=torch.bool) + base_grad = torch.tensor([0.5, 0.0, -1.25, 0.75, 0.0, 1.5], dtype=torch.float32) + layouts = [ + ((2, 3), [0, 1, 2, 3, 4, 5]), + ((3, 2), [5, 2, 1, 0, 4, 3]), + ] + + canonical_hidden_grad = None + canonical_weight_grad = None + canonical_bias_grad = None + chunk_elems = weight.size(0) * 2 + + for lead_shape, order in layouts: + hidden, target, mask = _layout_inputs( + base_hidden, base_target, base_mask, order, lead_shape + ) + safe_target = _safe_token_ids(target, mask) + grad_out = base_grad[torch.tensor(order, dtype=torch.long)].reshape(lead_shape) + grad_out = grad_out.masked_fill(~mask, 0.0) + + grad_hidden, grad_weight, grad_bias = _run_chunked_backward( + hidden, + weight, + safe_target, + bias, + grad_out, + chunk_elems=chunk_elems, + ) + ref_hidden, ref_weight, ref_bias = _run_autograd_linear_logp( + hidden, + weight, + safe_target, + bias, + grad_out, + ) + recovered_hidden = _recover_canonical_rows(grad_hidden, order) + + assert torch.allclose(grad_hidden, ref_hidden, atol=1e-5) + assert torch.allclose(grad_weight, ref_weight, atol=1e-5) + if use_bias: + assert torch.allclose(grad_bias, ref_bias, atol=1e-5) + + if canonical_hidden_grad is None: + canonical_hidden_grad = recovered_hidden + canonical_weight_grad = grad_weight + canonical_bias_grad = grad_bias + else: + assert torch.allclose(recovered_hidden, canonical_hidden_grad, atol=1e-6) + assert torch.allclose(grad_weight, canonical_weight_grad, atol=1e-6) + if use_bias: + assert torch.allclose(grad_bias, canonical_bias_grad, atol=1e-6) + + +def test_tied_embedding_lm_head_shared_gradient_is_layout_invariant(): + torch.manual_seed(2028) + model = _EmbeddingLMHeadModel(vocab_size=13, hidden_dim=6, bias=False, tie_weights=True) + op = NativeLinearLogpOp() + base_input_ids = torch.tensor([2, 5, 1, 5, 2, 3], dtype=torch.long) + base_target = torch.tensor([4, 1, 0, 2, 6, 3], dtype=torch.long) + base_mask = torch.tensor([True, False, True, True, False, True], dtype=torch.bool) + base_upstream = torch.tensor([0.75, 0.0, -1.25, 0.5, 0.0, 1.0], dtype=torch.float32) + layouts = [ + ((2, 3), [0, 1, 2, 3, 4, 5]), + ((3, 2), [5, 2, 1, 0, 4, 3]), + ] + + assert model.lm_head.weight is model.embedding.weight + canonical_logps = None + canonical_grad = None + + for lead_shape, order in layouts: + order_t = torch.tensor(order, dtype=torch.long) + input_ids = base_input_ids.index_select(0, order_t).reshape(lead_shape) + target = base_target.index_select(0, order_t).reshape(lead_shape) + mask = base_mask.index_select(0, order_t).reshape(lead_shape) + masked_target = target.masked_fill(~mask, -100) + upstream = ( + base_upstream.index_select(0, order_t).reshape(lead_shape).masked_fill(~mask, 0.0) + ) + + model.zero_grad(set_to_none=True) + hidden = model(input_ids) + logps = op( + hidden, model.lm_head.weight, _safe_token_ids(masked_target, mask), model.lm_head.bias + ) + logps = logps.masked_fill(~mask, 0.0) + logits = torch.nn.functional.linear(hidden.float(), model.lm_head.weight.float(), None) + expected = selected_logprobs_reference(logits, masked_target, mask=mask) + (logps * upstream).sum().backward() + + recovered_logps = _recover_canonical_rows(logps.unsqueeze(-1), order).squeeze(-1) + shared_grad = model.embedding.weight.grad.detach().clone() + + assert torch.allclose(logps, expected, atol=1e-5) + if canonical_logps is None: + canonical_logps = recovered_logps + canonical_grad = shared_grad + else: + assert torch.allclose(recovered_logps, canonical_logps, atol=1e-6) + assert torch.allclose(shared_grad, canonical_grad, atol=1e-6) + + def test_native_rejects_shape_mismatch(): native = NativeLinearLogpOp() hidden, weight, _, bias = _inputs(0, device="cpu")