Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
483 changes: 483 additions & 0 deletions docs/contributing/issue-108-session-log.md

Large diffs are not rendered by default.

35 changes: 27 additions & 8 deletions docs/operators/fused-logp.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,36 @@ logp_op = kernel_registry.get_op("logp")
output = logp_op(logits, token_ids)
```

The PyTorch native reference also exposes the Issue #108 interface:

```python
from rl_engine.kernels.ops.pytorch.loss.logp import NativeLogpOp

logp_ref = NativeLogpOp()
output = logp_ref.forward(logits, token_ids)
reference = logp_ref.forward_fp32(logits, token_ids)
```

`apply(...)` and `apply_fp32(...)` remain available as backward-compatible aliases.

## Backends

| Backend | Wrapper | Native symbol | Notes |
| --- | --- | --- | --- |
| CUDA SM90 | `FusedLogpSM90Op` | `_C.fused_logp_sm90` | TMA-oriented path for Hopper-class GPUs. |
| CUDA generic | `FusedLogpGenericOp` | `_C.fused_logp` | Generic compiled extension fallback. |
| PyTorch native | `NativeOp` | None | Baseline fallback path. |
| PyTorch native | `NativeLogpOp` | None | PyTorch baseline/reference path. |

## Tensor Contract

| Argument | Shape | Dtype | Requirements |
| --- | --- | --- | --- |
| `logits` | `[N, V]` | `bfloat16` for SM90 path | Contiguous, on the target device. |
| `token_ids` / `labels` | `[N]` | Converted to `int32` | Same logical device as `logits`. |
| Output | `[N]` | Backend-defined tensor dtype | One selected log probability per row. |
| `logits` | `[..., V]` | Floating point | Contiguous for fused CUDA paths; arbitrary leading dimensions. |
| `token_ids` / `labels` | `[...]` | Integer | Must match `logits.shape[:-1]`. |
| Output | `[...]` | See below | One selected log probability per row. |

For `NativeLogpOp`, `forward(...)` returns the input dtype and `forward_fp32(...)`
returns `torch.float32`.

## Reference Semantics

Expand All @@ -39,16 +54,20 @@ ref = torch.gather(ref, dim=-1, index=token_ids.unsqueeze(-1).long()).squeeze(-1
## Tests

```bash
python tests/test_op_accuracy.py
python -m pytest tests/test_logp.py -q
python -m pytest tests/test_op_accuracy.py -q
```

The current accuracy test compares the dispatched operator with a PyTorch reference and
uses a dtype-dependent threshold.
`tests/test_logp.py` covers the PyTorch reference contract, dtype behavior,
backward-compatible aliases, batch invariance, and registry dispatch. The existing
operator accuracy tests continue to validate native/CUDA fused API compatibility.

## Implementation Files

- `rl_engine/kernels/registry.py`
- `rl_engine/kernels/ops/cuda.py`
- `rl_engine/kernels/ops/pytorch/loss/logp.py`
- `rl_engine/kernels/ops/cuda/loss/logp.py`
- `csrc/ops.cpp`
- `csrc/fused_logp_kernel.cu`
- `csrc/cuda/fused_logp_sm90.cu`
- `tests/test_logp.py`
10 changes: 10 additions & 0 deletions rl_engine/kernels/gtest/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors

from .op_checks import CandidateSpec, OperatorCase, run_operator_suite

__all__ = [
"CandidateSpec",
"OperatorCase",
"run_operator_suite",
]
Loading