1 change: 1 addition & 0 deletions docs/.nav.yml
@@ -17,6 +17,7 @@ nav:
- operators/grpo-loss.md
- operators/ratio-kl.md
- operators/sampling.md
- operators/embedding.md
- Developer Guide:
- contributing/README.md
- General:
1 change: 1 addition & 0 deletions docs/operators/README.md
@@ -24,4 +24,5 @@ Every operator page should include:
- [GRPO Loss](grpo-loss.md)
- [Policy Ratio + KL Penalty](ratio-kl.md)
- [Sampling](sampling.md)
- [Token Embedding](embedding.md)
- [Operator Doc Template](../contributing/operator-doc-template.md)
110 changes: 110 additions & 0 deletions docs/operators/embedding.md
@@ -0,0 +1,110 @@
# Token Embedding

The embedding operator maps integer token ids to their hidden-state rows — the first
layer of the Qwen3/Llama stack. It is a **WS1 ground-truth reference** (issue #108):
a pure-PyTorch definition of the "correct answer" that downstream fused CUDA/Triton
kernels are validated against.

- **Embedding** (`NativeEmbeddingOp`): `out = weight[token_ids]` — a plain row gather.

For Qwen3-8B the table is the input embedding `[vocab=151936, hidden=4096]`. It is
**not tied** to the lm_head weight (`tie_word_embeddings=false`): the two are separate
parameters.

## Entry Point
```python
from rl_engine.kernels.registry import kernel_registry

embedding = kernel_registry.get_op("embedding")

h = embedding(token_ids, weight) # [B, S], [vocab, hidden] -> [B, S, hidden]
```

The op exposes the WS1 dual-path contract:

- `forward(...)`: gathers in the weight's native dtype and returns the rows in that
  dtype. The trailing cast to `weight.dtype` is a no-op, kept for symmetry with
  `forward_fp32` (Axis-B accuracy candidate / dtype-behavior path).
- `forward_fp32(...)`: gathers in the native dtype, then upcasts the gathered rows to
  fp32 (the ground-truth golden path).
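
As a rough illustration of the two paths, the sketch below uses two stand-in functions (`embed` and `embed_fp32`, hypothetical names that are not part of the registry API) built directly on `F.embedding`. It assumes a toy `[vocab=16, hidden=8]` bf16 table rather than the real Qwen3-8B dimensions.

```python
import torch
import torch.nn.functional as F


def embed(token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for forward(): gather and keep the weight's dtype.
    return F.embedding(token_ids.long(), weight)


def embed_fp32(token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for forward_fp32(): same gather, then upcast only the result.
    return F.embedding(token_ids.long(), weight).to(torch.float32)


torch.manual_seed(0)
weight = torch.randn(16, 8).to(torch.bfloat16)      # toy table, not Qwen3-8B
token_ids = torch.tensor([[1, 5, 5], [0, 15, 2]])   # [B=2, S=3]

h = embed(token_ids, weight)        # dtype follows the weight
h32 = embed_fp32(token_ids, weight) # always float32
```

Both results have shape `[2, 3, 8]`; only the dtype differs, and casting `h32` back to bf16 reproduces `h` bit for bit.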

## Backends

| Backend | Wrapper | Native symbol | Status |
| --- | --- | --- | --- |
| PyTorch fallback | `NativeEmbeddingOp` | None | fp32 ground-truth reference; CPU and any GPU. |
| CUDA / ROCm / Triton | — | — | Planned: downstream fused kernels validate against this reference. |

## Tensor Contract

| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `token_ids` | `[B, S]` (any shape) | integer | Index dtype; cast to int64 internally. Values in `[0, vocab)`. |
| `weight` | `[vocab, hidden]` | float (fp16/bf16/fp32) | Embedding table (Qwen3-8B `[151936, 4096]`). |
| output | `token_ids.shape + (hidden,)` | `forward`: weight dtype · `forward_fp32`: float32 | Gathered rows. |

For `forward`, the output dtype follows `weight` (the float operand); `forward_fp32`
always returns float32. `token_ids` stay integer. The op is a pure function: no
randomness, no in-place mutation, and the output device follows the inputs.
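
The shape and dtype rules in the table can be checked with a small sketch. It calls `F.embedding` directly (the reference semantics given on this page) on a toy `[vocab=4, hidden=3]` table; the tensor names are illustrative only.

```python
import torch
import torch.nn.functional as F

# Toy fp16 table: row i is [3i, 3i+1, 3i+2].
weight = torch.arange(12, dtype=torch.float32).reshape(4, 3).to(torch.float16)

ids_1d = torch.tensor([3, 0], dtype=torch.int32)   # non-int64 ids are accepted
ids_3d = torch.zeros(2, 5, 7, dtype=torch.int16)   # any shape, not only [B, S]

# The .long() cast creates a new tensor; the caller's ids are left untouched.
out_1d = F.embedding(ids_1d.long(), weight)
out_3d = F.embedding(ids_3d.long(), weight)
```

The output shape is always `token_ids.shape + (hidden,)`, here `(2, 3)` and `(2, 5, 7, 3)`, and the inputs keep their original dtypes.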

## Dispatch Behavior

`kernel_registry.get_op("embedding")` resolves through the `OpBackend` priority map. On
`cuda` / `rocm` / `cpu` the only registered backend today is the PyTorch native op
(`PYTORCH_NATIVE_EMBEDDING`), so every device dispatches to this reference op. When fused
kernels land, they are prepended to the priority list and the native op becomes the fallback.
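
The priority-list idea can be sketched in isolation. The code below is a simplified model, not the real `KernelRegistry`: the `resolve` helper, the availability set, and the `TRITON_EMBEDDING` name are all hypothetical, and it assumes that an unavailable backend is simply skipped.

```python
from typing import Dict, List, Set


def resolve(op_name: str, priority: Dict[str, List[str]], available: Set[str]) -> str:
    """Return the first backend in the op's priority list that is available."""
    for backend in priority[op_name]:
        if backend in available:
            return backend
    raise LookupError(f"no available backend for {op_name!r}")


# Today: only the native op is registered for "embedding".
priority_today = {"embedding": ["PYTORCH_NATIVE_EMBEDDING"]}
today = resolve("embedding", priority_today, {"PYTORCH_NATIVE_EMBEDDING"})

# Later: a hypothetical fused backend is prepended; the native op stays last.
priority_later = {"embedding": ["TRITON_EMBEDDING", "PYTORCH_NATIVE_EMBEDDING"]}
with_fused = resolve(
    "embedding", priority_later, {"TRITON_EMBEDDING", "PYTORCH_NATIVE_EMBEDDING"}
)
without_fused = resolve("embedding", priority_later, {"PYTORCH_NATIVE_EMBEDDING"})
```

Under these assumptions, callers keep using the same op name and receive the fused backend where it is available and the native reference otherwise.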

## Accuracy

Reference semantics (`forward_fp32`):

```python
out = F.embedding(token_ids.long(), weight).to(torch.float32)
```

- **Ground truth**: `forward_fp32` gathers in the native dtype, then upcasts to fp32.
Because a gather is a lossless row copy, this is bitwise-identical to upcasting the
whole table first — but it never allocates a multi-GB fp32 copy of the full vocab
table for a tiny lookup; only the gathered rows are upcast.
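- **Dtype path**: `forward` runs the same gather and returns the rows in the weight
  dtype (the final cast is a no-op). It is bitwise-equal to
  `forward_fp32(...).to(weight.dtype)`, because upcasting fp16/bf16 to fp32 is exact and
  the round trip loses nothing.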
- **Lossless gather — no accuracy drift**: a row gather performs no reduction and no
floating-point accumulation, so the result is **bit-exact** at every dtype. There is no
Axis-B tolerance to calibrate; the gathered rows equal direct indexing exactly.
- **Axis A — batch invariance**: each token's row is independent, so the output is
bitwise-identical regardless of batch size or padding (`torch.equal`, `atol=0`).
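
The bitwise claims above can be spot-checked with `torch.equal`. This is a minimal sketch on a toy `[vocab=32, hidden=8]` fp16 table with illustrative variable names; it is not the project's test suite.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(32, 8).to(torch.float16)   # toy table
token_ids = torch.randint(0, 32, (4, 6))        # [B=4, S=6], int64

# Ground truth: gather in fp16, then upcast only the gathered rows.
gather_then_upcast = F.embedding(token_ids, weight).to(torch.float32)
# Alternative: upcast the whole table first, then gather.
upcast_then_gather = F.embedding(token_ids, weight.to(torch.float32))

# Axis A: the first sequence on its own vs. the same sequence inside the batch.
single = F.embedding(token_ids[:1], weight)
batched = F.embedding(token_ids, weight)[:1]

same_golden = torch.equal(gather_then_upcast, upcast_then_gather)
batch_invariant = torch.equal(single, batched)
```

Both comparisons use exact equality rather than a tolerance, which is the point: a gather copies rows and performs no arithmetic.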

## Performance Notes

Reference operator — no fused kernel or benchmark yet. Downstream fused kernels carry their
own benchmarks and are measured against this reference for correctness.

## Tests

```bash
python -m pytest tests/test_embedding.py -v
```

Covers: correctness vs direct indexing (bitwise), dtype paths, non-int64 id tolerance,
Axis-A batch invariance (slice + padding), input purity, gradient flow to `weight`
(including sparse-grad: unused rows stay zero), registry dispatch, and a GPU-only smoke
test at the real Qwen3-8B dims (`vocab=151936, hidden=4096`, boundary ids `0` and
`vocab-1`) that skips when CUDA or GPU memory is unavailable.

## Implementation Files

- `rl_engine/kernels/ops/pytorch/linear/embedding.py`
- `rl_engine/kernels/registry.py`
- `tests/test_embedding.py`

## Known Limitations

- PyTorch fallback only; no fused CUDA/Triton backend yet (downstream work).
- Out-of-range token ids are not validated; callers must keep ids in `[0, vocab)`.
- **GPU backward is bitwise-reproducible only under deterministic algorithms.** The
forward is a lossless gather (always reproducible), but `∂L/∂weight` is a scatter-add:
every repeated token id (padding, common tokens) accumulates into the same `weight.grad`
row. On CUDA that accumulation uses atomic adds, whose ordering is nondeterministic, so
the weight gradient is not bit-exact across runs when ids collide. PyTorch documents
`embedding` backward as a nondeterministic CUDA op for this reason. Since `forward_fp32`
is the backward golden source, callers that need a reproducible GPU gradient must enable
`torch.use_deterministic_algorithms(True)` (the gradient test does this). CPU backward is
always deterministic.
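
To make the scatter-add in the backward pass concrete, the sketch below runs a tiny CPU example in which some ids repeat. It assumes a toy `[vocab=5, hidden=2]` table and a plain `sum()` loss; on CPU the result is deterministic, so the row-wise counts can be read off directly.

```python
import torch
import torch.nn.functional as F

weight = torch.ones(5, 2, requires_grad=True)        # toy table
token_ids = torch.tensor([[1, 1, 3], [1, 0, 0]])     # id 1 x3, id 0 x2, id 3 x1

# forward_fp32 semantics: gather, then upcast the gathered rows.
out = F.embedding(token_ids, weight).to(torch.float32)
out.sum().backward()

# Each occurrence of an id adds 1.0 to that row of weight.grad;
# rows whose id never appears (2 and 4) stay at zero.
row_grad = weight.grad[:, 0]
```

Here `row_grad` is `[2, 3, 0, 1, 0]`: repeated ids accumulate into the same row, which is exactly the collision case where the order of accumulation matters on CUDA.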
2 changes: 2 additions & 0 deletions rl_engine/kernels/ops/pytorch/linear/__init__.py
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors
51 changes: 51 additions & 0 deletions rl_engine/kernels/ops/pytorch/linear/embedding.py
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

from __future__ import annotations

import torch
import torch.nn.functional as F


class NativeEmbeddingOp(torch.nn.Module):
    """
    Pure PyTorch native token-embedding reference.
        out = weight[token_ids]   (a plain row gather, no accumulation)

    Maps integer token ids to their hidden-state rows. For Qwen3-8B the
    weight is the input embedding table ``[vocab=151936, hidden=4096]`` and
    is *independent* from the lm_head weight (``tie_word_embeddings=false``).
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """
        Canonical entry: gather in the weight's native dtype, then cast the
        gathered rows to weight.dtype (a no-op here, kept for symmetry).
        This is the dtype-behavior path used as the Axis-B accuracy candidate.
        """
        return self._embedding(token_ids, weight, output_dtype=weight.dtype)

    def forward_fp32(self, token_ids: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        """Ground-truth: native-dtype gather, then upcast the result to fp32."""
        return self._embedding(token_ids, weight, output_dtype=torch.float32)

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _embedding(
        token_ids: torch.Tensor,
        weight: torch.Tensor,
        *,
        output_dtype: torch.dtype,
    ) -> torch.Tensor:
        # Embedding is a lossless row gather (pure indexing, no arithmetic), so
        # gathering in the weight's native dtype and upcasting the small result
        # is bitwise-identical to upcasting the whole table first -- but it never
        # allocates a multi-GB fp32 copy of the full vocab table just for a tiny
        # lookup. Only the gathered rows are upcast.
        out = F.embedding(token_ids.long(), weight)
        return out.to(output_dtype)
6 changes: 6 additions & 0 deletions rl_engine/kernels/registry.py
@@ -56,6 +56,9 @@ class OpBackend(Enum, metaclass=_KernelEnumMeta):
PYTORCH_NATIVE_SILU = "rl_engine.kernels.ops.pytorch.activation.swiglu.NativeSiLUOp"
PYTORCH_NATIVE_SWIGLU = "rl_engine.kernels.ops.pytorch.activation.swiglu.NativeSwiGLUOp"

# WS1 pure-PyTorch ground-truth embedding ops
PYTORCH_NATIVE_EMBEDDING = "rl_engine.kernels.ops.pytorch.linear.embedding.NativeEmbeddingOp"


class KernelRegistry:
"""
@@ -91,6 +94,7 @@ def __init__(self):
"grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS],
"linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP],
"ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL],
"embedding": [OpBackend.PYTORCH_NATIVE_EMBEDDING],
"silu": [OpBackend.PYTORCH_NATIVE_SILU],
"swiglu": [OpBackend.PYTORCH_NATIVE_SWIGLU],
# Default dispatch logic for new operators
@@ -105,6 +109,7 @@ def __init__(self):
"grpo_loss": [OpBackend.TRITON_GRPO_LOSS, OpBackend.PYTORCH_GRPO_LOSS],
"linear_logp": [OpBackend.TRITON_LINEAR_LOGP, OpBackend.PYTORCH_LINEAR_LOGP],
"ratio_kl": [OpBackend.TRITON_RATIO_KL, OpBackend.PYTORCH_RATIO_KL],
"embedding": [OpBackend.PYTORCH_NATIVE_EMBEDDING],
"silu": [OpBackend.PYTORCH_NATIVE_SILU],
"swiglu": [OpBackend.PYTORCH_NATIVE_SWIGLU],
},
@@ -114,6 +119,7 @@ def __init__(self):
"grpo_loss": [OpBackend.PYTORCH_GRPO_LOSS],
"linear_logp": [OpBackend.PYTORCH_LINEAR_LOGP],
"ratio_kl": [OpBackend.PYTORCH_RATIO_KL],
"embedding": [OpBackend.PYTORCH_NATIVE_EMBEDDING],
"silu": [OpBackend.PYTORCH_NATIVE_SILU],
"swiglu": [OpBackend.PYTORCH_NATIVE_SWIGLU],
},