Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion meta/docker/Dockerfile.x86_64-cuda
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ RUN apt update && apt upgrade -y
WORKDIR /scratchpad

COPY . .
COPY --from=ghcr.io/xiaozheyao/sp-builder:v0.1.6-x86 /wheels /wheels
COPY --from=ghcr.io/xiaozheyao/sp-builder:v0.1.6-x86_64 /wheels /wheels

RUN pip install --no-cache-dir /wheels/flashinfer_python-0.2.3-cp38-abi3-linux_x86_64.whl && \
pip install --no-cache-dir /wheels/triteia-0.1.0-cp310-cp310-linux_x86_64.whl
Expand Down
2 changes: 1 addition & 1 deletion meta/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ nvidia-cuda-nvrtc-cu12
cuda-python
setproctitle
soundfile
triton==3.0.0
partial_json_parser
httpx
84 changes: 52 additions & 32 deletions scratchpad/constrained/base_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,66 @@ class BaseGrammarObject:
pass


INVALID_GRAMMAR_OBJ: BaseGrammarObject = BaseGrammarObject()


class BaseGrammarBackend:
def __init__(self):
self.executor = ThreadPoolExecutor()
self.cache = {}
self.cache_lock = Lock()

def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
with self.cache_lock:
if key in self.cache:
cache_hit = True
entry = self.cache[key]
else:
cache_hit = False
entry = CacheEntry(None, Event())
self.cache[key] = entry

if cache_hit:
entry.event.wait()
self.cache: Dict[Tuple[str, str], CacheEntry] = {}

def _not_supported(self, key_type: str, key_string: str) -> None:
logger.warning(f"Skip unsupported {key_type=}, {key_string=}")

def dispatch_fallback(
self, key_type: str, key_string: str
) -> Optional[BaseGrammarObject]:
"""
This function should not be reached in any case.
"""
raise ValueError(f"Invalid key_type: {key_type}={key_string}")

def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
return self._not_supported("json", key_string)

def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
return self._not_supported("regex", key_string)

def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
return self._not_supported("ebnf", key_string)

def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
return self._not_supported("structural_tag", key_string)

def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
key_type, key_string = key
if key_type == "json":
return self.dispatch_json(key_string)
elif key_type == "regex":
return self.dispatch_regex(key_string)
elif key_type == "ebnf":
return self.dispatch_ebnf(key_string)
elif key_type == "structural_tag":
return self.dispatch_structural_tag(key_string)
elif key_type == "structural_pattern":
return self.dispatch_structural_pattern(key_string)
else:
entry.value = self.init_value_impl(key)
entry.event.set()
return entry.value.copy() if entry.value else None

def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
raise NotImplementedError()
return self.dispatch_fallback(key_type, key_string)

def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
with self.cache_lock:
entry = self.cache.get(key)
if not entry or not entry.event.is_set():
return None
val = self.cache[key].value
return val.copy() if val else None
def get_cached_or_future_value(
self, key: Tuple[str, str]
) -> Optional[BaseGrammarObject]:
value = self.cache.get(key)
if value:
return value.copy(), True
value = self.executor.submit(self._init_value_dispatch, key)
return value, False

def get_future_value(self, key: Tuple[str, str]) -> Future:
return self.executor.submit(self.init_value, key)
def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject):
self.cache[key] = value

def reset(self):
with self.cache_lock:
self.cache.clear()
self.cache.clear()


def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
Expand Down
141 changes: 141 additions & 0 deletions scratchpad/constrained/triton_ops/bitmask_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Adapt from
# https://github.com/mlc-ai/xgrammar/blob/v0.1.17/python/xgrammar/kernels/apply_token_bitmask_inplace_triton.py

from typing import List, Optional, Union

import torch
import triton
import triton.language as tl

from scratchpad.utils import get_device_core_count


@triton.jit
def apply_token_bitmask_inplace_kernel(
    logits_ptr,
    bitmask_ptr,
    indices_ptr,
    num_rows,
    vocab_size,
    logits_strides,
    bitmask_strides,
    NUM_SMS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply a bitmask to logits in-place using Triton. The bitmask is a 01 bitwise compressed tensor,
    where 0 means the token is masked and 1 means the token is not masked. After applying the bitmask,
    the masked logits will be set to -inf.

    Parameters
    ----------
    logits_ptr : tl.tensor
        Pointer to the logits tensor to apply the bitmask to.

    bitmask_ptr : tl.tensor
        Pointer to the bitmask tensor to apply.

    indices_ptr : Optional[tl.tensor]
        Optional pointer to indices tensor specifying which rows to apply the mask to.

    num_rows : int
        Number of rows to process. If indices_ptr is provided, this is the number of unique indices.

    vocab_size : int
        Size of the vocabulary dimension. If the logits does not have a vocab padding, this is the
        same as the logits's second dimension. Otherwise, this is the actual size of the vocabulary.

    logits_strides : int
        Stride between rows in the logits tensor.

    bitmask_strides : int
        Stride between rows in the bitmask tensor.

    NUM_SMS : int
        Number of streaming multiprocessors to use.

    BLOCK_SIZE : int
        Size of processing blocks.
    """

    pid = tl.program_id(0)
    num_blocks = tl.cdiv(vocab_size, BLOCK_SIZE)
    # Persistent work loop: program `pid` processes work items
    # pid, pid + NUM_SMS, pid + 2*NUM_SMS, ... so a grid of NUM_SMS programs
    # covers all num_rows * num_blocks (row, vocab-chunk) pairs.
    for work_id in tl.range(pid, num_rows * num_blocks, NUM_SMS):
        row_id = work_id // num_blocks
        block_offset = (work_id % num_blocks) * BLOCK_SIZE
        # Indirect row lookup when only a subset of logits rows is masked;
        # otherwise row_id indexes the logits batch directly.
        batch_id = row_id if indices_ptr is None else tl.load(indices_ptr + row_id)
        offsets = block_offset + tl.arange(0, BLOCK_SIZE)
        # Each int32 word of the bitmask packs 32 token bits, hence the //32.
        bitmask_offsets = block_offset // 32 + tl.arange(0, BLOCK_SIZE // 32)
        vocab_mask = offsets < vocab_size
        packed_bitmask_mask = bitmask_offsets < bitmask_strides
        packed_bitmask = tl.load(
            bitmask_ptr + batch_id * bitmask_strides + bitmask_offsets,
            packed_bitmask_mask,
        )
        # Unpack: shift each word right by 0..31 and test the low bit.
        # bit == 0 means "token masked", so `bitmask` is True where -inf must be written.
        bitmask = ((packed_bitmask[:, None] >> (tl.arange(0, 32)[None, :])) & 1) == 0
        # Flatten (BLOCK_SIZE // 32, 32) -> (BLOCK_SIZE,) to align with `offsets`.
        bitmask = bitmask.reshape(BLOCK_SIZE)

        # Store -inf only at positions that are both inside the vocab and masked out.
        tl.store(
            logits_ptr + batch_id * logits_strides + offsets,
            -float("inf"),
            vocab_mask & bitmask,
        )


def apply_token_bitmask_inplace_triton(
    logits: torch.Tensor,
    bitmask: torch.Tensor,
    indices: Optional[Union[List[int], torch.Tensor]] = None,
):
    """Set masked-out logits to ``-inf`` in place using a packed int32 bitmask.

    Each int32 word of ``bitmask`` packs 32 tokens (bit == 1 keeps the token,
    bit == 0 masks it), matching ``apply_token_bitmask_inplace_kernel``.

    Parameters
    ----------
    logits : torch.Tensor
        1-D ``(vocab,)`` or 2-D ``(batch, vocab)`` logits tensor, modified in place.
    bitmask : torch.Tensor
        int32 tensor, 1-D ``(words,)`` or 2-D ``(batch, words)``, with at most
        ``ceil(vocab / 32)`` words per row.
    indices : Optional[Union[List[int], torch.Tensor]]
        Optional row indices into ``logits`` selecting which rows to mask.
        When omitted, ``logits`` and ``bitmask`` must have the same batch size.

    Raises
    ------
    AssertionError
        On wrong bitmask dtype, oversized bitmask width, or batch-size mismatch.
    """
    NUM_SMS = get_device_core_count()
    BLOCK_SIZE = 4096
    BITS_PER_BLOCK = 32

    # The kernel reinterprets the mask as int32 words; any other dtype would mis-read it.
    assert bitmask.dtype == torch.int32, "bitmask must be of type int32"

    # Normalize 1-D inputs to a logical (1, width) shape for the checks below;
    # the tensors themselves are passed to the kernel unchanged.
    logits_shape = logits.shape
    bitmask_shape = bitmask.shape
    if logits.ndim == 1:
        logits_shape = (1, logits_shape[0])
    if bitmask.ndim == 1:
        bitmask_shape = (1, bitmask_shape[0])

    # Ceil-div: number of int32 words needed to cover the logits width.
    required_bitmask_width = (logits_shape[1] + BITS_PER_BLOCK - 1) // BITS_PER_BLOCK
    assert required_bitmask_width >= bitmask_shape[1], (
        f"Bitmask width too large: allow at most {required_bitmask_width} int32s for "
        f"logits' width {logits_shape[1]}, but got {bitmask_shape[1]}"
    )

    # Only tokens covered by both the logits row and the bitmask can be masked.
    vocab_size = min(logits_shape[1], bitmask_shape[1] * BITS_PER_BLOCK)

    if isinstance(indices, (list, torch.Tensor)):
        # NOTE(review): torch.tensor on an already-tensor `indices` copies and
        # emits a UserWarning; torch.as_tensor would avoid both — kept as-is to
        # preserve the original's copy semantics.
        indices = torch.tensor(indices, dtype=torch.int32, device=logits.device)
        num_rows = indices.shape[0]
    else:
        assert (
            logits_shape[0] == bitmask_shape[0]
        ), f"batch size mismatch: logits {logits_shape[0]} vs bitmask {bitmask_shape[0]}"
        num_rows = logits_shape[0]

    if NUM_SMS > 0:
        # Persistent launch: one program per SM, each striding over work items.
        grid = (NUM_SMS,)
    else:
        # Core count unknown: launch one program per work item, and make the
        # kernel's loop stride (NUM_SMS) at least the grid size so each
        # program executes its loop body exactly once.
        num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
        grid = (num_rows * num_blocks,)
        NUM_SMS = triton.next_power_of_2(grid[0])

    apply_token_bitmask_inplace_kernel[grid](
        logits,
        bitmask,
        indices,
        num_rows,
        vocab_size,
        logits_shape[1],
        bitmask_shape[1],
        NUM_SMS,
        BLOCK_SIZE,
        num_warps=BLOCK_SIZE // 32 // (16 // logits.element_size()),
        num_stages=3,
    )
Loading
Loading