From f917d026cf591287abe71074a55b18203c2573c7 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 11 May 2026 11:33:16 +0000
Subject: [PATCH 1/3] feat: support invalid_token_ids in sampling params

- Add InvalidTokenIds ctypes struct and shm field on SamplingParams,
  populated from request `logit_bias` keys.
- Plumb invalid_token_ids through py SamplingParams and
  InferSamplingParams, including vocab_size validation.
- Add apply_invalid_token_ids Triton kernel that masks given token ids
  to -inf, applied during sampling between penalty application and
  softmax.
- Move apply_penalty.py and apply_penalty_gpu_cache.py into a new
  triton_kernel/post_process/ subdirectory and add the new kernel there.
- Add unit test for the new kernel.
---
 .../triton_kernel/post_process/__init__.py    |  0
 .../post_process/apply_invalid_token.py       | 36 +++++++++++++
 .../{ => post_process}/apply_penalty.py       |  0
 .../apply_penalty_gpu_cache.py                |  0
 .../server/core/objs/py_sampling_params.py    |  4 ++
 lightllm/server/core/objs/sampling_params.py  | 28 +++++++++++
 .../server/router/model_infer/infer_batch.py  |  8 +++
 .../mode_backend/generic_post_process.py      | 38 +++++++++++++-
 .../triton_kernel/test_apply_invalid_token.py | 50 +++++++++++++++++++
 9 files changed, 162 insertions(+), 2 deletions(-)
 create mode 100644 lightllm/common/basemodel/triton_kernel/post_process/__init__.py
 create mode 100644 lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py
 rename lightllm/common/basemodel/triton_kernel/{ => post_process}/apply_penalty.py (100%)
 rename lightllm/common/basemodel/triton_kernel/{ => post_process}/apply_penalty_gpu_cache.py (100%)
 create mode 100644 unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py

diff --git a/lightllm/common/basemodel/triton_kernel/post_process/__init__.py b/lightllm/common/basemodel/triton_kernel/post_process/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py
new file mode 100644
index 0000000000..353affd8ed
--- /dev/null
+++ b/lightllm/common/basemodel/triton_kernel/post_process/apply_invalid_token.py
@@ -0,0 +1,36 @@
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel_apply_invalid_token(
+    Logits,
+    invalid_token_ids,
+    cu_invalid_token_num,
+    stride_logit_b,
+):
+    cur_batch = tl.program_id(0)
+    start_index = tl.load(cu_invalid_token_num + cur_batch)
+    end_index = tl.load(cu_invalid_token_num + cur_batch + 1)
+    for i in range(start_index, end_index):
+        cur_invalid_token_id = tl.load(invalid_token_ids + i)
+        cur_logit_ptr = Logits + cur_batch * stride_logit_b + cur_invalid_token_id
+        tl.store(cur_logit_ptr, float("-inf"))
+    return
+
+
+def apply_invalid_token_ids(
+    Logits: torch.Tensor,
+    invalid_token_ids: torch.Tensor,
+    cu_invalid_token_num: torch.Tensor,
+):
+    batch_size = Logits.shape[0]
+    grid = (batch_size,)
+    _fwd_kernel_apply_invalid_token[grid](
+        Logits=Logits,
+        invalid_token_ids=invalid_token_ids,
+        cu_invalid_token_num=cu_invalid_token_num,
+        stride_logit_b=Logits.stride(0),
+    )
+    return
diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py
similarity index 100%
rename from lightllm/common/basemodel/triton_kernel/apply_penalty.py
rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty.py
diff --git a/lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py b/lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py
similarity index 100%
rename from lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py
rename to lightllm/common/basemodel/triton_kernel/post_process/apply_penalty_gpu_cache.py
diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py
index bf7051abdc..5d3a511d21 100644
--- a/lightllm/server/core/objs/py_sampling_params.py
+++ b/lightllm/server/core/objs/py_sampling_params.py
@@ -54,6 +54,8 @@ def __init__(
         # processor which only retains scores for the given token ids. Defaults to None.
         # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
         allowed_token_ids: Optional[List[int]] = None,
+        # if provided, the invalid token ids will be ignored during generation
+        invalid_token_ids: Optional[List[int]] = None,
         # p d mode used params
         group_request_id: Optional[int] = None,
         # move kv to deocde node, only used in pd mode
@@ -89,6 +91,7 @@
         self.guided_grammar = guided_grammar
         self.guided_json = guided_json
         self.allowed_token_ids = allowed_token_ids
+        self.invalid_token_ids = invalid_token_ids
         self.group_request_id = group_request_id
         self.move_kv_to_decode_node = move_kv_to_decode_node
         self.suggested_dp_index = suggested_dp_index
@@ -269,6 +272,7 @@ def to_dict(self):
         ret["guided_grammar"] = self.guided_grammar
         ret["guided_json"] = self.guided_json
         ret["allowed_token_ids"] = self.allowed_token_ids
+        ret["invalid_token_ids"] = self.invalid_token_ids
         ret["move_kv_to_decode_node"] = self.move_kv_to_decode_node
         ret["seed"] = self.seed
         return ret
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index 707f8f31b5..c94f3c6957 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -17,6 +17,7 @@
 REGULAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_REGULAR_CONSTRAINT_MAX_LENGTH", 2048))
 GRAMMAR_CONSTRAINT_MAX_LENGTH = int(os.getenv("LIGHTLLM_GRAMMAR_CONSTRAINT_MAX_LENGTH", 2048))
 JSON_SCHEMA_MAX_LENGTH = int(os.getenv("LIGHTLLM_JSON_SCHEMA_MAX_LENGTH", 2048))
+INVALID_TOKEN_IDS_MAX_LENGTH = int(os.getenv("LIGHTLLM_INVALID_TOKEN_IDS_MAX_LENGTH", 10))
 
 
 class StopSequence(ctypes.Structure):
@@ -205,6 +206,25 @@ def to_list(self):
         return list(self.ids[: self.size])
 
 
+class InvalidTokenIds(ctypes.Structure):
+    _pack_ = 4
+    _fields_ = [
+        ("ids", ctypes.c_int * INVALID_TOKEN_IDS_MAX_LENGTH),
+        ("size", ctypes.c_int),
+    ]
+
+    def initialize(self, ids: List[int]):
+        self.size = len(ids)
+        assert (
+            self.size <= INVALID_TOKEN_IDS_MAX_LENGTH
+        ), f"Too many invalid token IDs {self.size} > {INVALID_TOKEN_IDS_MAX_LENGTH}."
+        self.ids[: self.size] = ids[:]
+        return
+
+    def to_list(self):
+        return list(self.ids[: self.size])
+
+
 class ExponentialDecayLengthPenalty(ctypes.Structure):
     _pack_ = 4
     _fields_ = [
@@ -304,6 +324,8 @@ class SamplingParams(ctypes.Structure):
         # processor which only retains scores for the given token ids. Defaults to None.
         # allowed_token_ids only can be used in "--output_constraint_mode outlines" started server.
("allowed_token_ids", AllowedTokenIds), + # if provided, the invalid token ids will be ignored during generation + ("invalid_token_ids", InvalidTokenIds), ("stop_sequences", StopSequenceGroups), ("exponential_decay_length_penalty", ExponentialDecayLengthPenalty), ("group_request_id", ctypes.c_int64), # p d mode used params @@ -394,6 +416,11 @@ def init(self, tokenizer, **kwargs): self.allowed_token_ids = AllowedTokenIds() self.allowed_token_ids.initialize(allowed_token_ids) + # Initialize invalid_token_ids + invalid_token_ids = map(int, kwargs.get("logit_bias", {}).keys()) + self.invalid_token_ids = InvalidTokenIds() + self.invalid_token_ids.initialize(list[int](invalid_token_ids)) + if self.do_sample is False: self.temperature = 1.0 self.top_p = 1.0 @@ -493,6 +520,7 @@ def to_dict(self): "guided_grammar": self.guided_grammar.to_str(), "guided_json": self.guided_json.to_str(), "allowed_token_ids": self.allowed_token_ids.to_list(), + "invalid_token_ids": self.invalid_token_ids.to_list(), "group_request_id": self.group_request_id, "move_kv_to_decode_node": self.move_kv_to_decode_node.to_dict(), "skip_special_tokens": self.skip_special_tokens, diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 3a27c082de..7c19b5748e 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -444,6 +444,9 @@ def __init__( if len(self.allowed_token_ids) == 0: self.allowed_token_ids = None + # if provided, invalid_token_ids are masked to -inf during sampling (see generic_post_process.sample) + self.invalid_token_ids = self.shm_param.invalid_token_ids.to_list() + # p d mode use params if self.shm_param.move_kv_to_decode_node.exists: self.move_kv_to_decode_node = self.shm_param.move_kv_to_decode_node.to_dict() @@ -456,6 +459,11 @@ def __init__( logger.error("allowed_token_ids contain tokenid >= vobsize, we remove these token ids") self.allowed_token_ids = [e for e in self.allowed_token_ids if e < vocab_size] + if len(self.invalid_token_ids) > 0: + if not all(e < vocab_size for e in self.invalid_token_ids): + logger.error("invalid_token_ids contain tokenid >= vobsize, we remove these token ids") + self.invalid_token_ids = [e for e in self.invalid_token_ids if e < vocab_size] + # nixl decode node information if self.shm_param.nixl_params.data_len > 0: self.nixl_decode_node: NIXLDecodeNodeInfo = pickle.loads(self.shm_param.nixl_params.get()) diff --git a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py index f3ad03662e..41e89da9ab 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_post_process.py @@ -1,7 +1,8 @@ import torch from typing import List, Tuple -from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty -from lightllm.common.basemodel.triton_kernel.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty import apply_penalty +from lightllm.common.basemodel.triton_kernel.post_process.apply_penalty_gpu_cache import apply_penalty_gpu_cache +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import apply_invalid_token_ids from lightllm.server.router.model_infer.infer_batch import InferReq, g_infer_context from lightllm.server.router.model_infer.pin_mem_manager import 
g_pin_mem_manager from lightllm.utils.envs_utils import get_env_start_args @@ -15,7 +16,10 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): b_top_ks, b_length_penalty_param, b_mask_eos_reqs, + invalid_token_ids, + cu_invalid_token_num, is_all_greedy, + has_invalid_token_ids, skip_top_k, skip_top_p, exist_req_use_random_seed, @@ -63,6 +67,14 @@ def sample(logits: torch.Tensor, reqs: List[InferReq], eos_id: List[int] = [2]): eos_ids=eos_ids, sampling_params_manager=sampling_params_manager, ) + + if has_invalid_token_ids: + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + logits.div_(b_temperatures.view((-1, 1))) probs = torch.softmax(logits, dim=-1) @@ -152,6 +164,12 @@ def _get_post_sample_tensors(reqs: List[InferReq]): skip_top_p = True exist_req_use_random_seed = False + # invalid token ids + invalid_token_ids: List[int] = [] + has_invalid_token_ids = False + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for i, req_obj in enumerate(reqs): sample_param = req_obj.sampling_param shm_param = sample_param.shm_param @@ -173,6 +191,11 @@ def _get_post_sample_tensors(reqs: List[InferReq]): if req_obj.generator is not None: exist_req_use_random_seed = True req_idxes.append(req_obj.req_idx) + invalid_token_num_start += len(req_obj.sampling_param.invalid_token_ids) + cu_invalid_token_num.append(invalid_token_num_start) + if len(req_obj.sampling_param.invalid_token_ids) > 0: + has_invalid_token_ids = True + invalid_token_ids.extend(req_obj.sampling_param.invalid_token_ids) req_idxes_cpu = g_pin_mem_manager.gen_from_list(key="req_idxes", data=req_idxes, dtype=torch.int32) temperatures_cpu = g_pin_mem_manager.gen_from_list(key="temperatures", data=temperatures, dtype=torch.float32) @@ -183,6 +206,14 @@ def _get_post_sample_tensors(reqs: List[InferReq]): ) mask_eos_reqs_cpu = g_pin_mem_manager.gen_from_list(key="mask_eos_reqs", data=mask_eos_reqs, dtype=torch.bool) + if has_invalid_token_ids: + invalid_token_ids_cpu = g_pin_mem_manager.gen_from_list( + key="invalid_token_ids", data=invalid_token_ids, dtype=torch.int32 + ) + cu_invalid_token_num_cpu = g_pin_mem_manager.gen_from_list( + key="cu_invalid_token_num", data=cu_invalid_token_num, dtype=torch.int32 + ) + return ( req_idxes_cpu.cuda(non_blocking=True), temperatures_cpu.cuda(non_blocking=True), @@ -190,7 +221,10 @@ def _get_post_sample_tensors(reqs: List[InferReq]): top_ks_cpu.cuda(non_blocking=True), length_penalty_param_cpu.cuda(non_blocking=True), mask_eos_reqs_cpu.cuda(non_blocking=True), + invalid_token_ids_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, + cu_invalid_token_num_cpu.cuda(non_blocking=True) if has_invalid_token_ids else None, is_all_greedy, + has_invalid_token_ids, skip_top_k, skip_top_p, exist_req_use_random_seed, diff --git a/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py new file mode 100644 index 0000000000..3b2f159f62 --- /dev/null +++ b/unit_tests/common/basemodel/triton_kernel/test_apply_invalid_token.py @@ -0,0 +1,50 @@ +import pytest +import torch + +from lightllm.common.basemodel.triton_kernel.post_process.apply_invalid_token import ( + apply_invalid_token_ids, +) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32]) +def test_apply_invalid_token_ids(dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for Triton kernels.") + + batch_size = 
4 + vocab_size = 32 + logits = torch.randn((batch_size, vocab_size), device="cuda", dtype=dtype) + expected = logits.clone() + + invalid_token_ids_per_batch = [ + [1, 3, 5], + [], + [0, 2, 31], + [7], + ] + + flat_ids = [] + cu_invalid_token_num = [0] + invalid_token_num_start = 0 + for ids in invalid_token_ids_per_batch: + flat_ids.extend(ids) + invalid_token_num_start += len(ids) + cu_invalid_token_num.append(invalid_token_num_start) + + invalid_token_ids = torch.tensor(flat_ids, device="cuda", dtype=torch.int32) + cu_invalid_token_num = torch.tensor(cu_invalid_token_num, device="cuda", dtype=torch.int32) + + for batch_idx, ids in enumerate(invalid_token_ids_per_batch): + if ids: + expected[batch_idx, ids] = float("-inf") + + apply_invalid_token_ids( + Logits=logits, + invalid_token_ids=invalid_token_ids, + cu_invalid_token_num=cu_invalid_token_num, + ) + assert torch.equal(logits, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) From ceebefdd68b3797231f0989d8e3db3096bec8946 Mon Sep 17 00:00:00 2001 From: shihaobai <1798930569@qq.com> Date: Mon, 11 May 2026 11:48:11 +0000 Subject: [PATCH 2/3] test: add invalid_token_ids API smoke test Hits /generate twice with a logit_bias map covering common English tokens (via the Qwen3.5 tokenizer) and asserts none of the blocked ids appear in the biased output, while the baseline produces them. --- test/test_api/test_invalid_token_ids.py | 129 ++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 test/test_api/test_invalid_token_ids.py diff --git a/test/test_api/test_invalid_token_ids.py b/test/test_api/test_invalid_token_ids.py new file mode 100644 index 0000000000..e85b786bdf --- /dev/null +++ b/test/test_api/test_invalid_token_ids.py @@ -0,0 +1,129 @@ +""" +Smoke test for the invalid_token_ids feature (logit_bias path). + +Hits the lightllm-native /generate endpoint, which forwards `logit_bias` keys +into the SamplingParams `invalid_token_ids` field. The kernel masks those +ids to -inf, so they must never appear in the output. + +Run: + python test/test_api/test_invalid_token_ids.py + +Assumes the server is up on http://localhost:8000 and the model tokenizer +is Qwen3.5 (matches the launch command in the PR description). +""" + +import json +import sys +from typing import Dict, List, Tuple + +import requests +from transformers import AutoTokenizer + + +URL = "http://localhost:8000/generate" +HEADERS = {"Content-Type": "application/json"} +MODEL_DIR = "/nvme/models/Qwen3.5-35B-A3B" + +# Stay under INVALID_TOKEN_IDS_MAX_LENGTH (default 10). 
+BLOCK_WORDS = ["the", " the", "The", " is", " a", " of", " and"] + + +def _post_generate(prompt: str, parameters: dict, timeout: int = 120) -> dict: + payload = {"inputs": prompt, "parameters": parameters} + resp = requests.post(URL, headers=HEADERS, data=json.dumps(payload), timeout=timeout) + if resp.status_code != 200: + raise RuntimeError(f"{resp.status_code} {resp.text}") + return resp.json() + + +def _generated_text(resp: dict) -> str: + text = resp["generated_text"] + return text[0] if isinstance(text, list) else text + + +def _token_ids_from_details(resp: dict) -> List[int]: + tokens = resp.get("tokens", []) + if tokens and isinstance(tokens[0], list): + tokens = tokens[0] + out: List[int] = [] + for tok in tokens: + tid = tok.get("id") + if tid is not None: + out.append(int(tid)) + return out + + +def _build_block_map(tokenizer) -> Tuple[Dict[int, float], Dict[int, str]]: + """Map token id -> bias (-100 = block) and id -> source word.""" + bias: Dict[int, float] = {} + source: Dict[int, str] = {} + for w in BLOCK_WORDS: + ids = tokenizer.encode(w, add_special_tokens=False) + for tid in ids: + bias.setdefault(tid, -100.0) + source.setdefault(tid, w) + return bias, source + + +def test_invalid_token_ids(): + print("[1/3] Loading tokenizer...", flush=True) + tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True) + + bias_map, source_map = _build_block_map(tokenizer) + blocked_ids = sorted(bias_map.keys()) + print(f" Blocking {len(blocked_ids)} token ids: {blocked_ids}") + for tid in blocked_ids: + print(f" {tid:6d} <- {source_map[tid]!r}") + + prompt = ( + "Write three short English sentences about San Francisco. " + "Mention the bay, the bridge and the weather." + ) + base_params = { + "do_sample": False, + "temperature": 1.0, + "max_new_tokens": 80, + "return_details": True, + } + + print("[2/3] Baseline request (no logit_bias)...", flush=True) + base_resp = _post_generate(prompt, dict(base_params)) + base_text = _generated_text(base_resp) + base_ids = _token_ids_from_details(base_resp) + print(f" text: {base_text!r}") + base_hits = [tid for tid in base_ids if tid in bias_map] + print(f" blocked-tokens that appeared in baseline: {len(base_hits)} ({base_hits[:10]})") + + print("[3/3] logit_bias request...", flush=True) + bias_params = dict(base_params) + bias_params["logit_bias"] = {str(k): v for k, v in bias_map.items()} + biased_resp = _post_generate(prompt, bias_params) + biased_text = _generated_text(biased_resp) + biased_ids = _token_ids_from_details(biased_resp) + print(f" text: {biased_text!r}") + biased_hits = [(tid, source_map[tid]) for tid in biased_ids if tid in bias_map] + print(f" blocked-tokens that appeared with bias: {len(biased_hits)} ({biased_hits[:10]})") + + failures = [] + if biased_hits: + failures.append(f"Blocked token ids leaked into biased output: {biased_hits}") + + # Sanity check: the baseline should have produced at least one of the blocked tokens. + # If it did not, the test is uninformative (but still passes the strict check above). 
+    if not base_hits:
+        print("    WARNING: baseline did not produce any of the target tokens; "
+              "the assertion below is trivially satisfied.")
+
+    if biased_text == base_text:
+        failures.append("Biased output is identical to baseline; bias may not be applied.")
+
+    if failures:
+        for f in failures:
+            print(f"FAIL: {f}", file=sys.stderr)
+        sys.exit(1)
+
+    print("PASS: invalid_token_ids correctly suppressed blocked tokens.")
+
+
+if __name__ == "__main__":
+    test_invalid_token_ids()

From 3e9fbf308655d257a35f9c1870ff65e11b8f92f5 Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 11 May 2026 12:25:14 +0000
Subject: [PATCH 3/3] reformat

---
 test/test_api/test_invalid_token_ids.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/test/test_api/test_invalid_token_ids.py b/test/test_api/test_invalid_token_ids.py
index e85b786bdf..82923f4613 100644
--- a/test/test_api/test_invalid_token_ids.py
+++ b/test/test_api/test_invalid_token_ids.py
@@ -12,6 +12,7 @@
 is Qwen3.5 (matches the launch command in the PR description).
 """
 
+
 import json
 import sys
 from typing import Dict, List, Tuple
@@ -75,10 +76,7 @@ def test_invalid_token_ids():
         print(f"    {tid:6d} <- {source_map[tid]!r}")
 
-    prompt = (
-        "Write three short English sentences about San Francisco. "
-        "Mention the bay, the bridge and the weather."
-    )
+    prompt = "Write three short English sentences about San Francisco. " "Mention the bay, the bridge and the weather."
     base_params = {
         "do_sample": False,
         "temperature": 1.0,
         "max_new_tokens": 80,
         "return_details": True,
     }
@@ -111,8 +109,10 @@
     # Sanity check: the baseline should have produced at least one of the blocked tokens.
    # If it did not, the test is uninformative (but still passes the strict check above).
     if not base_hits:
-        print("    WARNING: baseline did not produce any of the target tokens; "
-              "the assertion below is trivially satisfied.")
+        print(
+            "    WARNING: baseline did not produce any of the target tokens; "
+            "the assertion below is trivially satisfied."
+        )
 
     if biased_text == base_text:
         failures.append("Biased output is identical to baseline; bias may not be applied.")
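
Usage sketch: a minimal client-side example of the new path, assuming a lightllm server already running on http://localhost:8000 with the /generate endpoint used by the smoke test above; the two token ids below are hypothetical placeholders, not ids from any particular tokenizer. Note that this path reads only the keys of `logit_bias` (every listed id is masked to -inf regardless of the bias value) and accepts at most INVALID_TOKEN_IDS_MAX_LENGTH (default 10) ids per request.

    import requests

    payload = {
        "inputs": "Write one sentence about the ocean.",
        "parameters": {
            "max_new_tokens": 32,
            # keys are the token ids to ban; the values are ignored by this path
            "logit_bias": {"1734": -100.0, "279": -100.0},  # hypothetical ids
        },
    }
    resp = requests.post("http://localhost:8000/generate", json=payload, timeout=120)
    resp.raise_for_status()
    print(resp.json()["generated_text"])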