Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions miles/utils/chat_template_utils/tito_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,63 @@ def merge_tokens(
return prefix + incremental


# ---------------------------------------------------------------------------
# K2V3 family implementation
# ---------------------------------------------------------------------------


class K2V3TITOTokenizer(TITOTokenizer):
    """K2V3 family.

    The K2V3 chat template renders ``<|im_end|>\n`` at the end of every
    message (the jinja block whitespace after ``{{- '<|im_end|>' }}`` is
    kept under the default ``trim_blocks`` settings).  During generation,
    however, the model stops right at ``<|im_end|>`` and never emits the
    trailing ``\n``.  :meth:`merge_tokens` therefore re-inserts that
    newline before appending new non-assistant tokens, keeping the
    pretokenized buffer identical to what ``apply_chat_template`` would
    have produced.

    Empirical sanity check::

        apply_chat_template([user, assistant, user], tokenize=False)
        → '...hello<|im_end|>\n<|im_start|>user\n...'
                              ^^
    """

    # Marker prefixed to assistant turns in the K2V3 template.
    _default_assistant_start_str: str = "<|im_start|>assistant"

    def __init__(
        self,
        tokenizer: Any,
        chat_template_kwargs: dict[str, Any] | None = None,
        assistant_start_str: str | None = None,
        allowed_append_roles: list[str] | None = None,
    ):
        start_marker = assistant_start_str if assistant_start_str else self._default_assistant_start_str
        super().__init__(
            tokenizer,
            chat_template_kwargs,
            start_marker,
            allowed_append_roles=allowed_append_roles,
        )
        # Cache the token ids needed by merge_tokens.  The newline must
        # tokenize to exactly one id for the single-token splice to work.
        nl_ids = tokenizer.encode("\n", add_special_tokens=False)
        assert len(nl_ids) == 1, f"Expected single newline token, got {nl_ids}"
        self._newline_id: int = nl_ids[0]
        self._im_end_id: int = tokenizer.convert_tokens_to_ids("<|im_end|>")
        self.trailing_token_ids = frozenset({self._newline_id})

    def merge_tokens(
        self,
        old_messages: list[dict[str, Any]],
        new_messages: list[dict[str, Any]],
        pretokenized_token_ids: list[int],
        tools: list[dict[str, Any]] | None = None,
    ) -> list[int]:
        """Append newly-added non-assistant messages to the token buffer.

        If the buffer ends at ``<|im_end|>`` (i.e. generation stopped
        before the template's trailing newline), the missing newline id
        is spliced in first so the result matches the canonical template
        output.  Returns a new list; the input buffer is not mutated.
        """
        tail = self.tokenize_additional_non_assistant(old_messages, new_messages, tools)
        merged = list(pretokenized_token_ids)
        if merged and merged[-1] == self._im_end_id:
            merged.append(self._newline_id)
        merged.extend(tail)
        return merged


# ---------------------------------------------------------------------------
# Enum + Registry + Factory
# ---------------------------------------------------------------------------
Expand All @@ -348,12 +405,14 @@ class TITOTokenizerType(str, Enum):
DEFAULT = "default"
QWEN3 = "qwen3"
GLM47 = "glm47"
K2V3 = "k2v3"


# Maps each TITOTokenizerType member to its implementing class; consulted by
# the factory to construct the right tokenizer for a given family name.
_TOKENIZER_REGISTRY: dict[TITOTokenizerType, type[TITOTokenizer]] = {
    TITOTokenizerType.DEFAULT: TITOTokenizer,
    TITOTokenizerType.QWEN3: Qwen3TITOTokenizer,
    TITOTokenizerType.GLM47: GLM47TITOTokenizer,
    TITOTokenizerType.K2V3: K2V3TITOTokenizer,
}


Expand Down
Loading
Loading