From 45519e55d45cf061878373e42793e9dfd8078682 Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Mon, 22 Jun 2026 10:28:16 +0800 Subject: [PATCH 1/4] feat(ws1): add NativeEmbeddingOp pure-PyTorch reference WS1 ground-truth token-embedding op for issue #108 (Qwen3-8B input embedding table, vocab=151936 x hidden=4096, tie_word_embeddings=false): - NativeEmbeddingOp: out = weight[token_ids], a lossless row gather exposing the forward / forward_fp32 dual-path contract (fp32 ground truth + dtype-behavior path); pure function, no in-place mutation. - register PYTORCH_NATIVE_EMBEDDING in OpBackend and the cuda/rocm/cpu priority maps. - tests/test_embedding.py: bitwise correctness vs direct indexing, dtype paths, non-int64 id tolerance, Axis-A batch invariance (slice + padding), purity, sparse gradient flow to weight, registry dispatch, and a GPU-only real-shape smoke test (vocab=151936, boundary ids). - docs/operators/embedding.md + nav/index wiring. --- docs/.nav.yml | 1 + docs/operators/README.md | 1 + docs/operators/embedding.md | 97 ++++++++++ .../kernels/ops/pytorch/linear/__init__.py | 2 + .../kernels/ops/pytorch/linear/embedding.py | 48 +++++ rl_engine/kernels/registry.py | 6 + tests/test_embedding.py | 179 ++++++++++++++++++ 7 files changed, 334 insertions(+) create mode 100644 docs/operators/embedding.md create mode 100644 rl_engine/kernels/ops/pytorch/linear/__init__.py create mode 100644 rl_engine/kernels/ops/pytorch/linear/embedding.py create mode 100644 tests/test_embedding.py diff --git a/docs/.nav.yml b/docs/.nav.yml index 6ba2e50..af9e75f 100644 --- a/docs/.nav.yml +++ b/docs/.nav.yml @@ -14,6 +14,7 @@ nav: - operators/grpo-loss.md - operators/ratio-kl.md - operators/sampling.md + - operators/embedding.md - Developer Guide: - contributing/README.md - General: diff --git a/docs/operators/README.md b/docs/operators/README.md index 4bb7e9e..495970f 100644 --- a/docs/operators/README.md +++ b/docs/operators/README.md @@ -22,4 +22,5 @@ Every operator page should include: - [GRPO Loss](grpo-loss.md) - [Policy Ratio + KL Penalty](ratio-kl.md) - [Sampling](sampling.md) +- [Token Embedding](embedding.md) - [Operator Doc Template](../contributing/operator-doc-template.md) diff --git a/docs/operators/embedding.md b/docs/operators/embedding.md new file mode 100644 index 0000000..0afbc64 --- /dev/null +++ b/docs/operators/embedding.md @@ -0,0 +1,97 @@ +# Token Embedding + +The embedding operator maps integer token ids to their hidden-state rows — the first +layer of the Qwen3/Llama stack. It is a **WS1 ground-truth reference** (issue #108): +a pure-PyTorch definition of the "correct answer" that downstream fused CUDA/Triton +kernels are validated against. + +- **Embedding** (`NativeEmbeddingOp`): `out = weight[token_ids]` — a plain row gather. + +For Qwen3-8B the table is the input embedding `[vocab=151936, hidden=4096]` and is +**independent** from the lm_head weight (`tie_word_embeddings=false`) — the two weights +are not shared. + +## Entry Point +```python +from rl_engine.kernels.registry import kernel_registry + +embedding = kernel_registry.get_op("embedding") + +h = embedding(token_ids, weight) # [B, S], [vocab, hidden] -> [B, S, hidden] +``` + +The op exposes the WS1 dual-path contract: + +- `forward(...)` — gathers in fp32, casts back to the weight dtype (Axis-B accuracy + candidate / dtype-behavior path). +- `forward_fp32(...)` — gathers and returns fp32 (the ground-truth golden path). + +## Backends + +| Backend | Wrapper | Native symbol | Status | +| --- | --- | --- | --- | +| PyTorch fallback | `NativeEmbeddingOp` | None | fp32 ground-truth reference; CPU and any GPU. | +| CUDA / ROCm / Triton | — | — | Planned: downstream fused kernels validate against this reference. | + +## Tensor Contract + +| Argument | Shape | Dtype | Requirements | +| --- | --- | --- | --- | +| `token_ids` | `[B, S]` (any shape) | integer | Index dtype; cast to int64 internally. Values in `[0, vocab)`. | +| `weight` | `[vocab, hidden]` | float (fp16/bf16/fp32) | Embedding table (Qwen3-8B `[151936, 4096]`). | +| output | `token_ids.shape + (hidden,)` | `forward`: weight dtype · `forward_fp32`: float32 | Gathered rows. | + +Output dtype follows `weight` (the float operand); `token_ids` stay integer. Pure +function — no randomness, no in-place mutation, device/dtype follow the inputs. + +## Dispatch Behavior + +`kernel_registry.get_op("embedding")` resolves through the `OpBackend` priority map. On +`cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op +(`PYTORCH_NATIVE_EMBEDDING`), so every device dispatches to the fp32 reference. When fused +kernels land, they are prepended to the priority list and the native op becomes the fallback. + +## Accuracy + +Reference semantics (`forward_fp32`): + +```python +out = F.embedding(token_ids.long(), weight.float()) +``` + +- **Ground truth**: `forward_fp32` gathers in and returns fp32. +- **Dtype path**: `forward` runs the same gather, then casts back to the weight dtype; + it is bitwise-equal to `forward_fp32(...).to(dtype)`. +- **Lossless gather — no accuracy drift**: a row gather performs no reduction and no + floating-point accumulation, so the result is **bit-exact** at every dtype. There is no + Axis-B tolerance to calibrate; the gathered rows equal direct indexing exactly. +- **Axis A — batch invariance**: each token's row is independent, so the output is + bitwise-identical regardless of batch size or padding (`torch.equal`, `atol=0`). + +## Performance Notes + +Reference operator — no fused kernel or benchmark yet. Downstream fused kernels carry their +own benchmarks and are measured against this reference for correctness. + +## Tests + +```bash +python -m pytest tests/test_embedding.py -v +``` + +Covers: correctness vs direct indexing (bitwise), dtype paths, non-int64 id tolerance, +Axis-A batch invariance (slice + padding), input purity, gradient flow to `weight` +(including sparse-grad: unused rows stay zero), registry dispatch, and a GPU-only smoke +test at the real Qwen3-8B dims (`vocab=151936, hidden=4096`, boundary ids `0` and +`vocab-1`) that skips when CUDA or GPU memory is unavailable. + +## Implementation Files + +- `rl_engine/kernels/ops/pytorch/linear/embedding.py` +- `rl_engine/kernels/registry.py` +- `tests/test_embedding.py` + +## Known Limitations + +- PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work). +- Out-of-range token ids are not validated; callers must keep ids in `[0, vocab)`. diff --git a/rl_engine/kernels/ops/pytorch/linear/__init__.py b/rl_engine/kernels/ops/pytorch/linear/__init__.py new file mode 100644 index 0000000..86cf4c9 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/linear/__init__.py @@ -0,0 +1,2 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors diff --git a/rl_engine/kernels/ops/pytorch/linear/embedding.py b/rl_engine/kernels/ops/pytorch/linear/embedding.py new file mode 100644 index 0000000..a5e22e2 --- /dev/null +++ b/rl_engine/kernels/ops/pytorch/linear/embedding.py @@ -0,0 +1,48 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors + +from __future__ import annotations + +import torch +import torch.nn.functional as F + + +class NativeEmbeddingOp: + """ + Pure PyTorch native token-embedding reference. + out = weight[token_ids] (a plain row gather, no accumulation) + + Maps integer token ids to their hidden-state rows. For Qwen3-8B the + weight is the input embedding table ``[vocab=151936, hidden=4096]`` and + is *independent* from the lm_head weight (``tie_word_embeddings=false``). + """ + + def __init__(self) -> None: + pass + + def __call__(self, token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + return self.forward(token_ids, weight) + + def forward(self, token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """ + Canonical entry: gather in fp32, cast the result back to weight.dtype. + This is the dtype-behavior path used as the Axis-B accuracy candidate. + """ + return self._embedding(token_ids, weight, output_dtype=weight.dtype) + + def forward_fp32(self, token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: + """Ground-truth: gather in fp32 and force fp32 output.""" + return self._embedding(token_ids, weight, output_dtype=torch.float32) + + # ------------------------------------------------------------------ # + # Helpers + # ------------------------------------------------------------------ # + @staticmethod + def _embedding( + token_ids: torch.Tensor, + weight: torch.Tensor, + *, + output_dtype: torch.dtype, + ) -> torch.Tensor: + out = F.embedding(token_ids.long(), weight.float()) + return out.to(output_dtype) diff --git a/rl_engine/kernels/registry.py b/rl_engine/kernels/registry.py index 7aae08f..c2c7fd4 100644 --- a/rl_engine/kernels/registry.py +++ b/rl_engine/kernels/registry.py @@ -45,6 +45,9 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta): TRITON_GENERIC = "rl_engine.kernels.ops.triton.generic.TritonOp" PYTORCH_NATIVE = "rl_engine.kernels.ops.pytorch.loss.logp.NativeLogpOp" + # WS1 pure-PyTorch ground-truth embedding ops + PYTORCH_NATIVE_EMBEDDING = "rl_engine.kernels.ops.pytorch.linear.embedding.NativeEmbeddingOp" + class KernelRegistry: """ @@ -79,6 +82,7 @@ def __init__(self): "attn": [OpBackend.FLASH_ATTN, OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "embedding": [OpBackend.PYTORCH_NATIVE_EMBEDDING], # Default dispatch logic for new operators }, "rocm": { @@ -86,12 +90,14 @@ def __init__(self): "attn": [OpBackend.TRITON_GENERIC, OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS], "ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL], + "embedding": [OpBackend.PYTORCH_NATIVE_EMBEDDING], }, "cpu": { "logp": [OpBackend.PYTORCH_NATIVE], "attn": [OpBackend.PYTORCH_NATIVE], "grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS], "ratio_kl": [OpBackend.PYTORCH_RATIO_KL], + "embedding": [OpBackend.PYTORCH_NATIVE_EMBEDDING], }, } logger.info(f"KernelRegistry initialized for {device_ctx.device_type}") diff --git a/tests/test_embedding.py b/tests/test_embedding.py new file mode 100644 index 0000000..dac1288 --- /dev/null +++ b/tests/test_embedding.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 RL-Kernel Contributors +"""Tests for NativeEmbeddingOp (ISSUE #108 WS1 ground-truth baseline). + +Embedding is a pure row gather (weight[token_ids]) with no floating-point +accumulation, so unlike silu/swiglu there is no reduction-order drift: the +fp32 path and the dtype path differ only by a single cast. Correctness is +therefore asserted bitwise (torch.equal), not allclose. +""" + +import pytest +import torch + +from rl_engine.kernels.ops.pytorch.linear.embedding import NativeEmbeddingOp +from rl_engine.kernels.registry import kernel_registry + +# Qwen3-8B architecture (synthetic tensors, no weight download). Most tests use +# a shrunk vocab/hidden -- the logic is identical and the full-size table is +# pointless to materialize for every case. The real Qwen3-8B dims are exercised +# separately by the GPU smoke test below. +_VOCAB = 128 # shrunk; real value: _QWEN3_VOCAB +_HIDDEN = 64 # shrunk; real value: _QWEN3_HIDDEN + +# Real Qwen3-8B input-embedding table dims: 151936 x 4096 ~ 2.49 GB in fp32. +_QWEN3_VOCAB = 151936 +_QWEN3_HIDDEN = 4096 + + +# Shared helpers -- fixed-seed Generator for determinism / reproducibility. +def _rand_weight(vocab=_VOCAB, hidden=_HIDDEN, *, seed, dtype=torch.float32): + gen = torch.Generator().manual_seed(seed) + return torch.randn(vocab, hidden, generator=gen, dtype=dtype) + + +def _rand_ids(shape, *, seed, vocab=_VOCAB): + # token ids are indices: int64, values in [0, vocab). + gen = torch.Generator().manual_seed(seed) + return torch.randint(0, vocab, shape, generator=gen, dtype=torch.int64) + + +# Correctness: embedding == indexing weight by token_ids. All three dtypes +# tested. Output dtype follows *weight* (the float operand), never token_ids +# (which stay int64). The gather is lossless, so the fp32 reference equals +# weight.float()[token_ids] exactly. +@pytest.mark.parametrize("dtype", (torch.float32, torch.bfloat16, torch.float16)) +def test_native_embedding_matches_gather_reference(dtype: torch.dtype): + token_ids = _rand_ids((2, 5), seed=1) + weight = _rand_weight(seed=1, dtype=dtype) + + fp32_reference = weight.float()[token_ids] + result = NativeEmbeddingOp().forward(token_ids, weight) + + # forward: output dtype follows weight; lossless gather -> bitwise equal + # to the reference cast back down. + assert result.dtype == dtype + assert torch.equal(result, fp32_reference.to(dtype)) + # forward_fp32: forced fp32 output == ground truth. + assert torch.equal(NativeEmbeddingOp().forward_fp32(token_ids, weight), fp32_reference) + + +# Output shape must be token_ids.shape + (hidden,). +def test_native_embedding_output_shape(): + token_ids = _rand_ids((3, 7), seed=2) + weight = _rand_weight(seed=2) + out = NativeEmbeddingOp().forward(token_ids, weight) + assert out.shape == (3, 7, _HIDDEN) + + +# Non-int64 ids (e.g. int32) must be tolerated: the op casts via .long(). +def test_native_embedding_accepts_non_int64_ids(): + token_ids = _rand_ids((2, 4), seed=3).to(torch.int32) + weight = _rand_weight(seed=3) + out = NativeEmbeddingOp().forward(token_ids, weight) + assert torch.equal(out, weight.float()[token_ids.long()].to(weight.dtype)) + + +# Axis A -- batch invariance, bitwise (the WS1 "aligned" property). A token's +# embedding must not depend on how many other tokens share the batch. Compute +# on the full input once, then slice -- never compute a slice on its own (that +# would let the golden source carry its own batch dependence). Trivially true +# for a gather, but asserted explicitly to guard the contract. +def test_embedding_batch_invariance_slice(): + op = NativeEmbeddingOp() + token_ids = _rand_ids((8, 32), seed=4) + weight = _rand_weight(seed=4) + full = op.forward_fp32(token_ids, weight) # compute on full batch... + assert torch.equal(op.forward_fp32(token_ids[:1], weight), full[:1]) # ...then slice + assert torch.equal(op.forward_fp32(token_ids[3:5], weight), full[3:5]) + + +def test_embedding_batch_invariance_with_padding(): + """Padding extra positions must not perturb the real ones (bitwise). + + Mimics a variable-length batch: real token ids followed by padding ids + along the seq dim; the real prefix must match the no-padding result. + """ + op = NativeEmbeddingOp() + weight = _rand_weight(seed=5) + real = _rand_ids((4, 10), seed=5) + pad = _rand_ids((4, 6), seed=99) # 6 extra padding positions + padded = torch.cat([real, pad], dim=1) # concat along seq + assert torch.equal(op.forward_fp32(padded, weight)[:, :10], op.forward_fp32(real, weight)) + + +# Purity -- neither token_ids nor weight may be mutated in place. +def test_embedding_inputs_not_mutated(): + op = NativeEmbeddingOp() + token_ids = _rand_ids((2, 8), seed=6) + weight = _rand_weight(seed=6) + ids_c, w_c = token_ids.clone(), weight.clone() + op.forward(token_ids, weight) + op.forward_fp32(token_ids, weight) + assert torch.equal(token_ids, ids_c) and torch.equal(weight, w_c) + + +# Gradient flows (fp32 autograd = backward golden source). Gradient is only +# defined for weight (token_ids are integer indices). The weight gradient is +# sparse: only the rows that were indexed are non-zero; unused rows stay 0. +def test_embedding_gradient_flows_to_weight(): + op = NativeEmbeddingOp() + token_ids = _rand_ids((2, 4), seed=7, vocab=10) # small vocab -> some unused rows + weight = _rand_weight(vocab=10, seed=7).requires_grad_(True) + op.forward_fp32(token_ids, weight).sum().backward() + + assert torch.isfinite(weight.grad).all() + used = torch.unique(token_ids).tolist() + unused = torch.tensor([i for i in range(10) if i not in used]) + if len(unused): + assert torch.equal(weight.grad[unused], torch.zeros_like(weight.grad[unused])) + + +# Registry dispatch -- "embedding" resolves to NativeEmbeddingOp (matches the +# PYTORCH_NATIVE_EMBEDDING entry + the per-platform priority-map additions). +def test_registry_dispatches_native_embedding_op(): + assert isinstance(kernel_registry.get_op("embedding"), NativeEmbeddingOp) + + +# --------------------------------------------------------------------------- # +# Qwen3-8B real-shape smoke test +# --------------------------------------------------------------------------- # +# Exercises the actual embedding table dims (vocab=151936, hidden=4096). The +# fp32 weight is ~2.5 GB, so this is GPU-only and skips when CUDA is absent or +# there is not enough free memory. The shrunk-dim tests above already cover the +# logic; this one validates the real index range (incl. boundary ids 0 and +# vocab-1) and the real hidden width, with a small (batch, seq) load point. +def _enough_gpu_memory(num_bytes: int) -> bool: + if not torch.cuda.is_available(): + return False + free, _ = torch.cuda.mem_get_info() + return free > int(num_bytes * 1.5) # headroom for the gathered output + + +@pytest.mark.skipif( + not _enough_gpu_memory(_QWEN3_VOCAB * _QWEN3_HIDDEN * 4), + reason="needs a CUDA GPU with room for the ~2.5 GB fp32 Qwen3-8B embedding table", +) +def test_native_embedding_qwen3_8b_real_shape(): + device = torch.device("cuda") + op = NativeEmbeddingOp() + + # SMALL load point (batch=2, seq=16) at the real model dims. + gen = torch.Generator(device=device).manual_seed(0) + token_ids = torch.randint( + 0, _QWEN3_VOCAB, (2, 16), generator=gen, dtype=torch.int64, device=device + ) + # Pin boundary ids so the full vocab range is actually indexed. + token_ids[0, 0] = 0 + token_ids[0, 1] = _QWEN3_VOCAB - 1 + weight = torch.randn( + _QWEN3_VOCAB, _QWEN3_HIDDEN, generator=gen, dtype=torch.float32, device=device + ) + + out = op.forward_fp32(token_ids, weight) + assert out.shape == (2, 16, _QWEN3_HIDDEN) + assert out.dtype == torch.float32 + # Lossless gather: bitwise equal to direct indexing. + assert torch.equal(out, weight[token_ids]) + # Axis A: compute on full batch, then slice (no per-slice recompute). + assert torch.equal(op.forward_fp32(token_ids[:1], weight), out[:1]) From 2de30a703746adf5322616adc96f35e99336b25c Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Tue, 23 Jun 2026 00:25:04 +0800 Subject: [PATCH 2/4] fix(ws1): gather embedding in native dtype to avoid full-table fp32 copy Gathering with weight.float() upcast the entire vocab table to fp32 before the lookup, allocating a multi-GB fp32 copy of the Qwen3-8B embedding table just for a tiny row gather and risking OOM on the default fallback path. A row gather is lossless (pure indexing, no arithmetic), so gather in the weight's native dtype and upcast only the gathered rows -- bitwise-identical to the previous path. All 11 tests in tests/test_embedding.py still pass. --- docs/operators/embedding.md | 14 +++++++++----- rl_engine/kernels/ops/pytorch/linear/embedding.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/docs/operators/embedding.md b/docs/operators/embedding.md index 0afbc64..6efb113 100644 --- a/docs/operators/embedding.md +++ b/docs/operators/embedding.md @@ -22,9 +22,10 @@ h = embedding(token_ids, weight) # [B, S], [vocab, hidden] -> [B, S, hidden] The op exposes the WS1 dual-path contract: -- `forward(...)` — gathers in fp32, casts back to the weight dtype (Axis-B accuracy - candidate / dtype-behavior path). -- `forward_fp32(...)` — gathers and returns fp32 (the ground-truth golden path). +- `forward(...)` — gathers in the weight's native dtype, casts the gathered rows back to + the weight dtype (Axis-B accuracy candidate / dtype-behavior path). +- `forward_fp32(...)` — native-dtype gather, then upcasts the result to fp32 (the + ground-truth golden path). ## Backends @@ -56,10 +57,13 @@ kernels land, they are prepended to the priority list and the native op becomes Reference semantics (`forward_fp32`): ```python -out = F.embedding(token_ids.long(), weight.float()) +out = F.embedding(token_ids.long(), weight).to(torch.float32) ``` -- **Ground truth**: `forward_fp32` gathers in and returns fp32. +- **Ground truth**: `forward_fp32` gathers in the native dtype, then upcasts to fp32. + Because a gather is a lossless row copy, this is bitwise-identical to upcasting the + whole table first — but it never allocates a multi-GB fp32 copy of the full vocab + table for a tiny lookup; only the gathered rows are upcast. - **Dtype path**: `forward` runs the same gather, then casts back to the weight dtype; it is bitwise-equal to `forward_fp32(...).to(dtype)`. - **Lossless gather — no accuracy drift**: a row gather performs no reduction and no diff --git a/rl_engine/kernels/ops/pytorch/linear/embedding.py b/rl_engine/kernels/ops/pytorch/linear/embedding.py index a5e22e2..862b0d3 100644 --- a/rl_engine/kernels/ops/pytorch/linear/embedding.py +++ b/rl_engine/kernels/ops/pytorch/linear/embedding.py @@ -25,13 +25,14 @@ def __call__(self, token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tenso def forward(self, token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: """ - Canonical entry: gather in fp32, cast the result back to weight.dtype. + Canonical entry: gather in the weight's native dtype, then cast the + gathered rows to weight.dtype (a no-op here, kept for symmetry). This is the dtype-behavior path used as the Axis-B accuracy candidate. """ return self._embedding(token_ids, weight, output_dtype=weight.dtype) def forward_fp32(self, token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: - """Ground-truth: gather in fp32 and force fp32 output.""" + """Ground-truth: native-dtype gather, then upcast the result to fp32.""" return self._embedding(token_ids, weight, output_dtype=torch.float32) # ------------------------------------------------------------------ # @@ -44,5 +45,10 @@ def _embedding( *, output_dtype: torch.dtype, ) -> torch.Tensor: - out = F.embedding(token_ids.long(), weight.float()) + # Embedding is a lossless row gather (pure indexing, no arithmetic), so + # gathering in the weight's native dtype and upcasting the small result + # is bitwise-identical to upcasting the whole table first -- but it never + # allocates a multi-GB fp32 copy of the full vocab table just for a tiny + # lookup. Only the gathered rows are upcast. + out = F.embedding(token_ids.long(), weight) return out.to(output_dtype) From fab4da6eb0a9e896f125f6981ddd0ad5f6cc2514 Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Sat, 27 Jun 2026 16:58:06 +0800 Subject: [PATCH 3/4] =?UTF-8?q?fix(ws1):=20address=20PR=20#169=20review=20?= =?UTF-8?q?=E2=80=94=20embedding=20backward=20determinism?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Backward (∂L/∂weight) is a scatter-add; repeated token ids accumulate into the same grad row, and on CUDA that uses atomic adds (nondeterministic order), so the weight gradient is not bitwise reproducible when ids collide. Since forward_fp32 is the backward golden source: - docs: document the limitation under Known Limitations (embedding.md) - test: enable torch.use_deterministic_algorithms(True) in the gradient test and assert grad is bitwise identical across two independent backward passes --- docs/operators/embedding.md | 9 +++++++++ tests/test_embedding.py | 35 +++++++++++++++++++++++++++++------ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/docs/operators/embedding.md b/docs/operators/embedding.md index 6efb113..58d503b 100644 --- a/docs/operators/embedding.md +++ b/docs/operators/embedding.md @@ -99,3 +99,12 @@ test at the real Qwen3-8B dims (`vocab=151936, hidden=4096`, boundary ids `0` an - PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work). - Out-of-range token ids are not validated; callers must keep ids in `[0, vocab)`. +- **GPU backward is bitwise-reproducible only under deterministic algorithms.** The + forward is a lossless gather (always reproducible), but `∂L/∂weight` is a scatter-add: + every repeated token id (padding, common tokens) accumulates into the same `weight.grad` + row. On CUDA that accumulation uses atomic adds, whose ordering is nondeterministic, so + the weight gradient is not bit-exact across runs when ids collide. PyTorch documents + `embedding` backward as a nondeterministic CUDA op for this reason. Since `forward_fp32` + is the backward golden source, callers that need a reproducible GPU gradient must enable + `torch.use_deterministic_algorithms(True)` (the gradient test does this). CPU backward is + always deterministic. diff --git a/tests/test_embedding.py b/tests/test_embedding.py index dac1288..b8069f0 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -116,17 +116,40 @@ def test_embedding_inputs_not_mutated(): # Gradient flows (fp32 autograd = backward golden source). Gradient is only # defined for weight (token_ids are integer indices). The weight gradient is # sparse: only the rows that were indexed are non-zero; unused rows stay 0. +# +# Reproducibility caveat: embedding backward is a scatter-add -- repeated token +# ids accumulate into the same weight.grad row. On CUDA that uses atomic adds +# (nondeterministic ordering), so the weight gradient is bit-exact across runs +# only under torch.use_deterministic_algorithms(True). Since forward_fp32 is the +# backward golden source, we enable it here and assert the gradient is bitwise +# reproducible across two independent backward passes. CPU is always deterministic. def test_embedding_gradient_flows_to_weight(): op = NativeEmbeddingOp() - token_ids = _rand_ids((2, 4), seed=7, vocab=10) # small vocab -> some unused rows - weight = _rand_weight(vocab=10, seed=7).requires_grad_(True) - op.forward_fp32(token_ids, weight).sum().backward() - - assert torch.isfinite(weight.grad).all() + # Small vocab -> some rows unused, and repeated ids -> the scatter-add path + # (multiple contributions into one row) is actually exercised. + token_ids = _rand_ids((2, 4), seed=7, vocab=10) + + prev = torch.are_deterministic_algorithms_enabled() + torch.use_deterministic_algorithms(True) + try: + weight_a = _rand_weight(vocab=10, seed=7).requires_grad_(True) + op.forward_fp32(token_ids, weight_a).sum().backward() + grad_a = weight_a.grad + + weight_b = _rand_weight(vocab=10, seed=7).requires_grad_(True) + op.forward_fp32(token_ids, weight_b).sum().backward() + grad_b = weight_b.grad + finally: + torch.use_deterministic_algorithms(prev) + + assert torch.isfinite(grad_a).all() + # Backward golden source must be bitwise reproducible (Axis-A for gradients). + assert torch.equal(grad_a, grad_b) + # Sparse gradient: rows never indexed stay exactly zero. used = torch.unique(token_ids).tolist() unused = torch.tensor([i for i in range(10) if i not in used]) if len(unused): - assert torch.equal(weight.grad[unused], torch.zeros_like(weight.grad[unused])) + assert torch.equal(grad_a[unused], torch.zeros_like(grad_a[unused])) # Registry dispatch -- "embedding" resolves to NativeEmbeddingOp (matches the From d8b3d39eb57915aa0151c72b34a53645f45ed954 Mon Sep 17 00:00:00 2001 From: maxiaosong1124 Date: Sun, 28 Jun 2026 15:45:25 +0800 Subject: [PATCH 4/4] refactor(ws1): NativeEmbeddingOp inherits nn.Module per PR #169 review --- rl_engine/kernels/ops/pytorch/linear/embedding.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/rl_engine/kernels/ops/pytorch/linear/embedding.py b/rl_engine/kernels/ops/pytorch/linear/embedding.py index 862b0d3..cf01c30 100644 --- a/rl_engine/kernels/ops/pytorch/linear/embedding.py +++ b/rl_engine/kernels/ops/pytorch/linear/embedding.py @@ -7,7 +7,7 @@ import torch.nn.functional as F -class NativeEmbeddingOp: +class NativeEmbeddingOp(torch.nn.Module): """ Pure PyTorch native token-embedding reference. out = weight[token_ids] (a plain row gather, no accumulation) @@ -18,10 +18,7 @@ class NativeEmbeddingOp: """ def __init__(self) -> None: - pass - - def __call__(self, token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: - return self.forward(token_ids, weight) + super().__init__() def forward(self, token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: """