Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 110 additions & 14 deletions atom/model_ops/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,18 @@
import torch
import triton
import triton.language as tl
from atom.utils import envs
from atom.utils.forward_context import SpecDecodeMetadata
from torch import nn

# Relaxed-MTP acceptance knobs. When ATOM_ENABLE_RELAXED_MTP is set, a draft
# token is accepted if it falls within the target's top-10 candidates whose
# probability is within 0.6 of the top-1 probability; otherwise acceptance
# degenerates to strict top-1 (argmax) matching.
ATOM_ENABLE_RELAXED_MTP = envs.ATOM_ENABLE_RELAXED_MTP
RELAXED_TOP_N = 10 if ATOM_ENABLE_RELAXED_MTP else 1
RELAXED_DELTA = 0.6 if ATOM_ENABLE_RELAXED_MTP else 0.0


class RejectionSampler(nn.Module):
def forward(
Expand Down Expand Up @@ -79,18 +88,42 @@ def rejection_sample(
)
num_bonus_tokens = torch.empty(batch_size, dtype=torch.int32, device=device)

# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
rejection_greedy_sample_kernel[(batch_size,)](
output_token_ids,
num_bonus_tokens,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_spec_steps,
num_warps=1,
)
if RELAXED_TOP_N <= 1:
    # Strict greedy path: a draft token is accepted only when it exactly
    # matches the target model's argmax at the same position.
    target_argmax = target_probs.argmax(dim=-1)
    rejection_greedy_sample_kernel[(batch_size,)](
        output_token_ids,
        num_bonus_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        target_argmax,
        bonus_token_ids,
        num_spec_steps,
        num_warps=1,
    )
else:
    # Relaxed acceptance path: accept if draft is among top-N
    # candidates with prob >= (top1_prob - delta).
    # NOTE(review): softmax is applied to `target_probs` — despite the
    # name this input appears to hold logits; confirm against the caller.
    probs = target_probs.softmax(dim=-1, dtype=torch.float32)
    topn_probs, topn_ids = torch.topk(probs, RELAXED_TOP_N, dim=-1)

    # Invalidate (-1) candidates whose probability trails top-1 by more
    # than RELAXED_DELTA. The top-1 slot always satisfies the mask, so
    # column 0 is never -1 and the kernel can use it as the fallback.
    top1_probs = topn_probs[:, 0:1]
    valid_mask = topn_probs >= (top1_probs - RELAXED_DELTA)
    # In-place write is safe: torch.topk returned a fresh tensor.
    topn_ids[~valid_mask] = -1
    topn_ids = topn_ids.to(torch.int32).contiguous()

    rejection_relaxed_sample_kernel[(batch_size,)](
        output_token_ids,
        num_bonus_tokens,
        cu_num_draft_tokens,
        draft_token_ids,
        topn_ids,
        bonus_token_ids,
        num_spec_steps,
        RELAXED_TOP_N,
        num_warps=1,
    )

return output_token_ids, num_bonus_tokens


Expand Down Expand Up @@ -125,7 +158,7 @@ def rejection_greedy_sample_kernel(
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
target_argmax_id = tl.cast(target_argmax_id, tl.int32)
if draft_token_id != target_argmax_id:
# Reject: the draft token diverges from the target argmax.
rejected = True
num_bonus_token += 1
tl.store(
Expand All @@ -136,7 +169,70 @@ def rejection_greedy_sample_kernel(
if rejected:
bonus_token_id = INVALID_TOKEN
else:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
num_bonus_token += 1
tl.store(
output_token_ids_ptr + req_idx * (num_spec_steps + 1) + num_draft_tokens,
bonus_token_id,
)
tl.store(num_bonus_tokens_ptr + req_idx, num_bonus_token)


@triton.jit(do_not_specialize=["num_spec_steps", "top_n"])
def rejection_relaxed_sample_kernel(
    output_token_ids_ptr,  # [batch_size, num_spec_steps + 1]
    num_bonus_tokens_ptr,  # [batch_size] — per-request accepted-token count
    cu_num_draft_tokens_ptr,  # [batch_size] — inclusive prefix sum of draft counts
    draft_token_ids_ptr,  # [num_tokens]
    topn_ids_ptr,  # [num_tokens, top_n] — candidate token ids, -1 = invalid
    bonus_token_ids_ptr,  # [batch_size]
    num_spec_steps,
    top_n,
):
    # One program per request: scan its draft tokens left to right; accept
    # each one iff it appears among the valid top-N candidates. The first
    # miss is replaced by the target's top-1 token and every later position
    # is padded with INVALID_TOKEN.
    req_idx = tl.program_id(0)

    # Request i owns flattened draft positions [start_idx, end_idx).
    if req_idx == 0:
        start_idx = 0
    else:
        start_idx = tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    rejected = False
    # Starts at -1 so the stored count ends up as (emitted tokens - 1),
    # matching rejection_greedy_sample_kernel's accounting — TODO confirm
    # against that kernel's (folded) initialization.
    num_bonus_token = -1
    INVALID_TOKEN: tl.constexpr = -1

    for pos in range(num_draft_tokens):
        if rejected:
            # A previous position already diverged; pad the remainder.
            output_id = INVALID_TOKEN
        else:
            draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)

            base_offset = (start_idx + pos) * top_n
            # Column 0 (top-1) is never -1: the host-side delta filter
            # always keeps the top-1 candidate.
            top1_id = tl.load(topn_ids_ptr + base_offset)

            # Linear scan over the top-N slots; -1 (invalid) slots can
            # never equal a real draft token id, so no extra check needed.
            found = False
            for k in range(top_n):
                candidate_id = tl.load(topn_ids_ptr + base_offset + k)
                if candidate_id == draft_token_id:
                    found = True

            if found:
                output_id = draft_token_id
            else:
                # First divergence: emit the target's top-1 instead.
                output_id = top1_id
                rejected = True

            # NOTE(review): increment placed inside the not-yet-rejected
            # branch (counts accepted tokens plus the one replacement),
            # mirroring the greedy kernel — confirm intended indentation.
            num_bonus_token += 1

        tl.store(
            output_token_ids_ptr + req_idx * (num_spec_steps + 1) + pos,
            output_id,
        )

    if rejected:
        bonus_token_id = INVALID_TOKEN
    else:
        # All draft tokens accepted: append the true bonus token.
        bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
        num_bonus_token += 1
tl.store(
Expand Down
3 changes: 3 additions & 0 deletions atom/utils/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@
"ATOM_DUAL_STREAM_MOE_TOKEN_THRESHOLD": lambda: int(
os.getenv("ATOM_DUAL_STREAM_MOE_TOKEN_THRESHOLD", "1024")
),
# --- MTP (relaxed acceptance criterion for quantized MTP) ---
"ATOM_ENABLE_RELAXED_MTP": lambda: os.getenv("ATOM_ENABLE_RELAXED_MTP", "0").lower()
== "1",
}


Expand Down
8 changes: 8 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"ATOM_DISABLE_VLLM_PLUGIN",
"ATOM_DISABLE_VLLM_PLUGIN_ATTENTION",
"ATOM_USE_CUSTOM_ALL_GATHER",
"ATOM_ENABLE_RELAXED_MTP",
]


Expand Down Expand Up @@ -84,6 +85,9 @@ def test_disable_vllm_plugin_default(self):
def test_disable_vllm_plugin_attention_default(self):
    """ATOM_DISABLE_VLLM_PLUGIN_ATTENTION defaults to False when unset."""
    assert _get_envs().ATOM_DISABLE_VLLM_PLUGIN_ATTENTION is False

def test_atom_enable_relaxed_mtp_default(self):
    """ATOM_ENABLE_RELAXED_MTP defaults to False when the env var is unset."""
    assert _get_envs().ATOM_ENABLE_RELAXED_MTP is False

def test_unknown_attr_raises(self):
    """Accessing an undeclared env attribute raises AttributeError."""
    with pytest.raises(AttributeError):
        _ = _get_envs().ATOM_NONEXISTENT_VAR
Expand Down Expand Up @@ -132,6 +136,10 @@ def test_disable_vllm_plugin_attention_enabled(self, monkeypatch):
monkeypatch.setenv("ATOM_DISABLE_VLLM_PLUGIN_ATTENTION", "1")
assert _get_envs().ATOM_DISABLE_VLLM_PLUGIN_ATTENTION is True

def test_atom_enable_relaxed_mtp_enabled(self, monkeypatch):
    """Setting ATOM_ENABLE_RELAXED_MTP=1 turns the flag on."""
    monkeypatch.setenv("ATOM_ENABLE_RELAXED_MTP", "1")
    assert _get_envs().ATOM_ENABLE_RELAXED_MTP is True


class TestIsSet:
"""Test the is_set() helper function."""
Expand Down
Loading