From 2f47e102e8e80a6d143022975ceb011a15deca1d Mon Sep 17 00:00:00 2001
From: Zhenting Wang
Date: Thu, 7 May 2026 08:50:54 +0000
Subject: [PATCH] Add K2V3TITOTokenizer and its TITO contract test suite
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The K2V3 chat template emits `<|im_end|>\n` after every message (the
jinja block whitespace between `{{- '<|im_end|>' }}` and the next block
is preserved under the default `trim_blocks` setting). The model
autoregressively stops at `<|im_end|>` without producing the trailing
`\n`. Without a per-model fix, the rollout buffer ends with `<|im_end|>`
while the canonical chat-template render has `<|im_end|>\n` — diverging
by exactly one `\n` token, which trips `update_pretokenized_state`'s
prefix check.

`K2V3TITOTokenizer.merge_tokens` mirrors `Qwen3TITOTokenizer`'s
strategy: insert `\n` when `prefix[-1] == im_end_id`. It is a standalone
subclass (rather than an alias of Qwen3) so that future K2V3-specific
divergences have a clean hook.

`tests/fast/utils/chat_template_utils/test_tito_k2v3.py` is a
K2V3-focused contract test suite (54 parametrized cases) verifying that
the boundary fix is alive and effective, that the K2V3 chat template
round-trips cleanly through the real SGLang parsers, and that the full
rollout flow stays consistent end-to-end.

Coverage:

- 8 trajectory shapes (single_tool / multi_turn / multi_tool_single_turn
  / multi_tool_multi_turn, each with a native or synthesized thinking
  variant). Each runs a 2-phase check: finalized buffer vs canonical,
  plus a synthetic env follow-up that forces the boundary-fix path even
  on single-turn shapes (defeating the trim_trailing_ids shielding that
  would otherwise hide missing-fix bugs).

- 8 trajectories x 4 env append shapes = 32 cross-cases driving
  realistic <|im_end|>-terminated buffers through prepare_pretokenized
  -> merge_tokens with various env-delta patterns (single tool / user /
  system / alternating mixed). Each case also asserts that the env
  content markers appear in the incremental tokens in order.

- 8 trajectories run through the real SGLang ReasoningParser
  (deepseek-r1) + FunctionCallParser (hermes), verifying the chat
  template <-> parser structural round-trip on each shape's first
  assistant message.

- 4 boss-level integration flows: drive every assistant turn of a
  multi-turn trajectory through real parsers (so session.messages
  accumulates parser-derived parsed_msg across turns), then append a
  complex env follow-up. Catches integration regressions that only
  surface in the full flow.

- Sanity: the production prefix-check defense fires on an intentional
  violation, and the K2V3 enum value is correctly wired in
  _TOKENIZER_REGISTRY.

Reverse-validated: with the `\n` boundary insertion commented out, all
44 trajectory + append + boss cases fail; the 8 parser cases (which do
not drive merge_tokens) and the 2 fix-independent sanity cases still
pass. This confirms the suite genuinely exercises the boundary fix.
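Illustrative sketch of the one-token divergence (hypothetical token
ids chosen for illustration; the real values come from the K2V3
tokenizer, and this snippet is not part of the diff below):

    IM_END_ID, NEWLINE_ID = 7, 198            # assumed ids, illustration only
    canonical = [101, 102, IM_END_ID, NEWLINE_ID]  # chat-template render
    buffer = [101, 102, IM_END_ID]                 # autoregressive stop
    # merge_tokens' boundary fix, applied before appending env deltas:
    if buffer and buffer[-1] == IM_END_ID:
        buffer.append(NEWLINE_ID)
    # the rollout buffer is now a prefix of the canonical render again
    assert canonical[: len(buffer)] == buffer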
---
 .../chat_template_utils/tito_tokenizer.py |   59 +
 .../chat_template_utils/test_tito_k2v3.py | 1200 +++++++++++++++++
 2 files changed, 1259 insertions(+)
 create mode 100644 tests/fast/utils/chat_template_utils/test_tito_k2v3.py

diff --git a/miles/utils/chat_template_utils/tito_tokenizer.py b/miles/utils/chat_template_utils/tito_tokenizer.py
index 48a564314a..837be090e3 100644
--- a/miles/utils/chat_template_utils/tito_tokenizer.py
+++ b/miles/utils/chat_template_utils/tito_tokenizer.py
@@ -339,6 +339,63 @@ def merge_tokens(
         return prefix + incremental
 
 
+# ---------------------------------------------------------------------------
+# K2V3 family implementation
+# ---------------------------------------------------------------------------
+
+
+class K2V3TITOTokenizer(TITOTokenizer):
+    """K2V3 family.
+
+    The chat template emits ``<|im_end|>\\n`` after every message (the
+    jinja block whitespace between ``{{- '<|im_end|>' }}`` and the next
+    block is preserved under the default ``trim_blocks`` setting), but
+    the model autoregressively stops at ``<|im_end|>`` without
+    generating the trailing ``\\n``. ``merge_tokens`` inserts the missing
+    newline so the pretokenized buffer matches the canonical template
+    output.
+
+    Empirical sanity check::
+
+        apply_chat_template([user, assistant, user], tokenize=False)
+        → '...hello<|im_end|>\\n<|im_start|>user\\n...'
+                            ^^
+    """
+
+    _default_assistant_start_str: str = "<|im_start|>assistant"
+
+    def __init__(
+        self,
+        tokenizer: Any,
+        chat_template_kwargs: dict[str, Any] | None = None,
+        assistant_start_str: str | None = None,
+        allowed_append_roles: list[str] | None = None,
+    ):
+        super().__init__(
+            tokenizer,
+            chat_template_kwargs,
+            assistant_start_str or self._default_assistant_start_str,
+            allowed_append_roles=allowed_append_roles,
+        )
+        nl_ids = tokenizer.encode("\n", add_special_tokens=False)
+        assert len(nl_ids) == 1, f"Expected single newline token, got {nl_ids}"
+        self._newline_id: int = nl_ids[0]
+        self._im_end_id: int = tokenizer.convert_tokens_to_ids("<|im_end|>")
+        self.trailing_token_ids = frozenset({self._newline_id})
+
+    def merge_tokens(
+        self,
+        old_messages: list[dict[str, Any]],
+        new_messages: list[dict[str, Any]],
+        pretokenized_token_ids: list[int],
+        tools: list[dict[str, Any]] | None = None,
+    ) -> list[int]:
+        incremental = self.tokenize_additional_non_assistant(old_messages, new_messages, tools)
+        prefix = list(pretokenized_token_ids)
+        if prefix and prefix[-1] == self._im_end_id:
+            prefix.append(self._newline_id)
+        return prefix + incremental
+
+
 # ---------------------------------------------------------------------------
 # Enum + Registry + Factory
 # ---------------------------------------------------------------------------
@@ -348,12 +405,14 @@ class TITOTokenizerType(str, Enum):
     DEFAULT = "default"
     QWEN3 = "qwen3"
     GLM47 = "glm47"
+    K2V3 = "k2v3"
 
 
 _TOKENIZER_REGISTRY: dict[TITOTokenizerType, type[TITOTokenizer]] = {
     TITOTokenizerType.DEFAULT: TITOTokenizer,
     TITOTokenizerType.QWEN3: Qwen3TITOTokenizer,
     TITOTokenizerType.GLM47: GLM47TITOTokenizer,
+    TITOTokenizerType.K2V3: K2V3TITOTokenizer,
 }
diff --git a/tests/fast/utils/chat_template_utils/test_tito_k2v3.py b/tests/fast/utils/chat_template_utils/test_tito_k2v3.py
new file mode 100644
index 0000000000..b607f3df6f
--- /dev/null
+++ b/tests/fast/utils/chat_template_utils/test_tito_k2v3.py
@@ -0,0 +1,1200 @@
+"""TITO contract tests for the K2V3 family.
+
+Combines two test patterns from elsewhere in this repo:
+
+  * From ``test_tito_tokenizer_model_matrix.py`` —
+    breadth: realistic conversation shapes (single tool, multi-turn,
+    parallel tools, with/without thinking) drawn from
+    ``miles.utils.test_utils.mock_trajectories``.
+
+  * From the rollout-side principle test (originally drafted for the
+    agentic-rl integration) —
+    depth: the rollout buffer used as ``pretokenized`` is built
+    through ``LinearTrajectory.update_pretokenized_state`` with
+    completion_token_ids that mirror what SGLang's
+    ``output_token_logprobs`` carries on a real autoregressive emit
+    (i.e. a token sequence ending at ``<|im_end|>`` WITHOUT the
+    trailing ``\\n`` that the chat template's jinja whitespace adds).
+
+Why this matters for K2V3 specifically:
+
+The K2V3 chat template emits ``<|im_end|>\\n`` after every message
+(verified empirically; the ``\\n`` comes from jinja block whitespace
+between ``{{- '<|im_end|>' }}`` and the next template block). The model
+autoregressively stops at ``<|im_end|>`` without producing the trailing
+``\\n``. ``K2V3TITOTokenizer.merge_tokens`` inserts the missing newline
+to keep the buffer aligned with the canonical chat-template render.
+
+The existing ``test_tito_tokenizer_model_matrix.py`` does NOT exercise
+this fix because its ``pretokenized`` is computed with
+``apply_chat_template(old_messages, add_generation_prompt=False)`` —
+that render already includes the ``\\n``, so ``prefix[-1]`` is ``\\n``
+and the boundary fix never fires (and the test passes whether the fix
+exists or not). This file closes that gap by routing through
+``update_pretokenized_state``, which produces the realistic
+``prefix[-1] == <|im_end|>`` state that requires the fix.
+
+Skips at module level if the K2V3 checkpoint is not on this host.
+"""
+
+from __future__ import annotations
+
+import os
+from copy import deepcopy
+from dataclasses import dataclass
+
+import pytest
+from transformers import AutoTokenizer
+
+from miles.rollout.session.linear_trajectory import LinearTrajectory
+from miles.rollout.session.session_errors import TokenizationError
+from miles.utils.chat_template_utils import (
+    MismatchType,
+    apply_chat_template,
+    try_get_fixed_chat_template,
+)
+from miles.utils.chat_template_utils.tito_tokenizer import (
+    TITOTokenizerType,
+    get_tito_tokenizer,
+)
+from miles.utils.processing_utils import load_tokenizer
+from miles.utils.test_utils.mock_trajectories import (
+    LongChainThinkingTrajectory,
+    LongChainTrajectory,
+    MultiToolSingleTurnTrajectory,
+    MultiTurnThinkingTrajectory,
+    MultiTurnTrajectory,
+    SingleToolThinkingTrajectory,
+    SingleToolTrajectory,
+)
+
+
+# ---------------------------------------------------------------------------
+# Path + fixtures
+# ---------------------------------------------------------------------------
+
+K2V3_MODEL_PATH = os.environ.get(
+    "TITO_TEST_MODEL_PATH_K2V3",
+    "/mnt/weka/shrd/k2m/suqi.sun/bbq_image/bbq-8b-mid3-final",
+)
+_ALLOWED_APPEND_ROLES = ["tool", "user", "system"]
+
+# The K2V3 chat template's generation prompt depends on reasoning_effort
+# (the emitted prompt suffix differs between high, medium, and low; the
+# exact marker tokens are template-internal). Production runs with high
+# effort; pinning it here keeps the test deterministic regardless of any
+# future template-default change. Override via env if needed.
+_K2V3_REASONING_EFFORT = os.environ.get("TITO_TEST_REASONING_EFFORT_K2V3", "high")
+_K2V3_CHAT_TEMPLATE_KWARGS = {"reasoning_effort": _K2V3_REASONING_EFFORT}
+
+# Per-K2V3 SGLang parser names.
+# Defaults match the K2V3 production config:
+#   SGLANG_TOOL_PARSER=hermes
+#   SGLANG_REASONING_PARSER=deepseek-r1
+# Both rely on the `<think>...</think>` block (deepseek-r1) and the
+# hermes `<tool_call>\n{json}\n</tool_call>` shape that K2V3's chat
+# template emits.
+#
+# Older SGLang builds may register `hermes` under a different name (e.g.
+# the qwen25 detector handles the same shape). Override via env in those
+# environments — e.g. ``TITO_TEST_TOOL_PARSER_K2V3=qwen25``. If the
+# configured parser is not registered in this SGLang build, the parser
+# round-trip test skips with an explicit reason rather than silently
+# turning green.
+_K2V3_TOOL_PARSER = os.environ.get("TITO_TEST_TOOL_PARSER_K2V3", "hermes")
+_K2V3_REASONING_PARSER = os.environ.get("TITO_TEST_REASONING_PARSER_K2V3", "deepseek-r1")
+
+
+@pytest.fixture(scope="module")
+def tokenizer() -> AutoTokenizer:
+    if not os.path.isdir(K2V3_MODEL_PATH):
+        pytest.skip(f"K2V3 checkpoint not present on this host: {K2V3_MODEL_PATH}")
+    return load_tokenizer(
+        K2V3_MODEL_PATH,
+        chat_template_path=try_get_fixed_chat_template(K2V3_MODEL_PATH),
+        trust_remote_code=True,
+    )
+
+
+@pytest.fixture
+def tito_tok(tokenizer):
+    return get_tito_tokenizer(
+        tokenizer,
+        tokenizer_type=TITOTokenizerType.K2V3,
+        allowed_append_roles=_ALLOWED_APPEND_ROLES,
+        chat_template_kwargs=_K2V3_CHAT_TEMPLATE_KWARGS,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Trajectories — realistic conversation shapes from mock_trajectories
+# ---------------------------------------------------------------------------
+
+
+def _with_synthetic_thinking(
+    trajectory_cls: type,
+    reasoning: str = "Let me work through this step by step.",
+) -> type:
+    """Synthesize a thinking variant by injecting ``reasoning_content`` on
+    each assistant message of the trajectory.
+
+    Used to build coverage shapes that ``mock_trajectories`` doesn't ship
+    a native thinking variant for (e.g. multi-tool single-turn with
+    thinking — production exercises this combination but no native
+    fixture exists).
+    """
+    new_messages = deepcopy(trajectory_cls.MESSAGES)
+    for m in new_messages:
+        if m.get("role") == "assistant":
+            m["reasoning_content"] = reasoning
+
+    class _Synthesized:
+        TOOLS = deepcopy(getattr(trajectory_cls, "TOOLS", None))
+        MESSAGES = new_messages
+
+    _Synthesized.__name__ = trajectory_cls.__name__ + "_WithSyntheticThinking"
+    return _Synthesized
+
+
+# Native + synthetic-thinking-injected trajectories. Each entry exercises a
+# distinct rollout shape; the thinking variants additionally trigger the
+# K2V3 chat template's reasoning-block path
+# (<|im_start|>assistant\n<think>\n...</think>\ncontent<|im_end|>).
+CONVERSATIONS: list[tuple[str, type]] = [
+    # Single assistant turn — single tool call.
+    ("single_tool", SingleToolTrajectory),
+    ("single_tool_thinking", SingleToolThinkingTrajectory),
+    # Multiple assistant turns — single tool call per turn.
+    ("multi_turn", MultiTurnTrajectory),
+    ("multi_turn_thinking", MultiTurnThinkingTrajectory),
+    # Single assistant turn — multiple parallel tool calls.
+    ("multi_tool_single_turn", MultiToolSingleTurnTrajectory),
+    # No native thinking variant exists for parallel-tools-single-turn;
+    # synthesize by injecting reasoning_content into the assistant turn.
+    ("multi_tool_single_turn_thinking",
+     _with_synthetic_thinking(MultiToolSingleTurnTrajectory)),
+    # Multiple assistant turns AND tool calls (chain shape).
+ ("multi_tool_multi_turn", LongChainTrajectory), + ("multi_tool_multi_turn_thinking", LongChainThinkingTrajectory), +] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _render_text( + messages: list[dict], + tokenizer: AutoTokenizer, + tools: list[dict] | None, + *, + add_generation_prompt: bool, +) -> str: + """``apply_chat_template(...) → str`` with K2V3 chat_template_kwargs auto-applied.""" + return apply_chat_template( + messages, + tokenizer=tokenizer, + tools=tools, + add_generation_prompt=add_generation_prompt, + tokenize=False, + **_K2V3_CHAT_TEMPLATE_KWARGS, + ) + + +def _render_ids( + messages: list[dict], + tokenizer: AutoTokenizer, + tools: list[dict] | None, + *, + add_generation_prompt: bool, +) -> list[int]: + """``apply_chat_template(...) → list[int]`` with K2V3 chat_template_kwargs auto-applied.""" + return list( + apply_chat_template( + messages, + tokenizer=tokenizer, + tools=tools, + add_generation_prompt=add_generation_prompt, + tokenize=True, + **_K2V3_CHAT_TEMPLATE_KWARGS, + ) + ) + + +def _first_diff(a, b) -> str: + for i in range(min(len(a), len(b))): + if a[i] != b[i]: + return f"position {i}: a[{i}]={a[i]} b[{i}]={b[i]}" + return f"length differs (len(a)={len(a)} len(b)={len(b)})" + + +def _assistant_indices(messages: list[dict]) -> list[int]: + return [i for i, m in enumerate(messages) if m["role"] == "assistant"] + + +def _realistic_emit_ids( + request_messages: list[dict], + assistant_message: dict, + tools: list[dict] | None, + tokenizer: AutoTokenizer, +) -> list[int]: + """Synthesize completion_token_ids that mirror SGLang's autoregressive emit. + + The model emits starting from inside the assistant generation prompt + and stops at ``<|im_end|>`` (no trailing ``\\n``). We compute this by + diffing two chat-template renders: + + full = render(request + [assistant], add_generation_prompt=False) + prompt = render(request, add_generation_prompt=True) + emit_text = full[len(prompt):] # what model would emit + emit_text = emit_text.rstrip("\\n") # strip jinja's trailing \\n + assert emit_text.endswith("<|im_end|>") + emit_ids = tokenizer.encode(emit_text) + """ + full_text = _render_text( + request_messages + [assistant_message], tokenizer, tools, + add_generation_prompt=False, + ) + prompt_text = _render_text( + request_messages, tokenizer, tools, + add_generation_prompt=True, + ) + assert full_text.startswith(prompt_text), ( + "chat template not append-only: prompt-only render is not a prefix " + "of full render. TITO's premise breaks here." + ) + emit_text = full_text[len(prompt_text):] + # Strip the trailing newline(s) the jinja whitespace adds after + # `<|im_end|>`. The model autoregressively stops at the stop token + # without producing them. + emit_text_stop = emit_text.rstrip("\n") + assert emit_text_stop.endswith("<|im_end|>"), ( + f"unexpected emit_text shape (does not end with <|im_end|>): " + f"{emit_text_stop!r}" + ) + return list(tokenizer.encode(emit_text_stop, add_special_tokens=False)) + + +def _drive_session_through_trajectory( + session: LinearTrajectory, + tito_tok, + messages: list[dict], + tools: list[dict] | None, +) -> None: + """Drive ``session`` turn-by-turn using the trajectory's messages. + + For each assistant message in the trajectory, builds the realistic + emit_ids and calls ``update_pretokenized_state`` exactly as production + does. 
+    After this call, ``session.token_ids`` reflects what the rollout
+    buffer would hold mid-conversation.
+    """
+    for asst_idx in _assistant_indices(messages):
+        request_messages = messages[:asst_idx]
+        assistant_message = messages[asst_idx]
+
+        pre = session.prepare_pretokenized(request_messages, tools, tito_tokenizer=tito_tok)
+        if pre is None:
+            prompt_ids = _render_ids(
+                request_messages, tito_tok.tokenizer, tools,
+                add_generation_prompt=True,
+            )
+        else:
+            prompt_ids = list(pre["input_ids"])
+
+        emit_ids = _realistic_emit_ids(
+            request_messages, assistant_message, tools, tito_tok.tokenizer
+        )
+
+        session.update_pretokenized_state(
+            request_messages=request_messages,
+            assistant_message=assistant_message,
+            prompt_token_ids=prompt_ids,
+            completion_token_ids=emit_ids,
+            max_trim_tokens=tito_tok.max_trim_tokens,
+        )
+
+
+# ===========================================================================
+# Tests
+# ===========================================================================
+
+
+@pytest.mark.parametrize(
+    "name, trajectory_cls",
+    CONVERSATIONS,
+    ids=lambda x: x if isinstance(x, str) else None,
+)
+def test_buffer_matches_canonical_under_realistic_rollout(name, trajectory_cls, tito_tok):
+    """Drive the trajectory through ``update_pretokenized_state`` with
+    ``<|im_end|>``-terminated completion_ids (the realistic
+    autoregressive-stop shape), then compare ``session.token_ids`` to the
+    chat-template canonical via the per-model comparator.
+
+    The boundary fix in ``K2V3TITOTokenizer.merge_tokens`` (inserting
+    ``\\n`` when ``prefix[-1] == <|im_end|>``) IS exercised here, because
+    the ``pretokenized`` buffer the trajectory builds via
+    ``update_pretokenized_state`` actually ends at ``<|im_end|>`` between
+    turns.
+
+    Allows ``ASSISTANT_TEXT`` mismatches (BPE-merge noise from
+    autoregressive emission, classified non-severe by the comparator
+    itself). Fails on any ``SPECIAL_TOKEN_COUNT`` /
+    ``SPECIAL_TOKEN_TYPE`` / ``NON_ASSISTANT_TEXT`` mismatch — those
+    indicate the per-model boundary fix is wrong, not BPE noise.
+    """
+    messages = deepcopy(trajectory_cls.MESSAGES)
+    tools = deepcopy(getattr(trajectory_cls, "TOOLS", None))
+
+    session = LinearTrajectory()
+    _drive_session_through_trajectory(session, tito_tok, messages, tools)
+
+    comparator = tito_tok.create_comparator()
+
+    # Phase 1 — finalized buffer vs canonical (covers structural drift in
+    # the whole trajectory, but the comparator's ``trim_trailing_ids``
+    # hides end-of-sequence ``<|im_end|>`` vs ``<|im_end|>\n`` differences
+    # if the trajectory has only ONE assistant turn).
+    expected_final = _render_ids(
+        session.messages, tito_tok.tokenizer, tools,
+        add_generation_prompt=False,
+    )
+    actual_final = list(session.token_ids)
+    severe_final = [
+        m for m in comparator.compare_sequences(expected_final, actual_final)
+        if m.type != MismatchType.ASSISTANT_TEXT
+    ]
+    if severe_final:
+        details = "\n".join(
+            f"  {m.type.value} at segment {m.segment_index}: "
+            f"expected={m.expected_text!r} actual={m.actual_text!r}"
+            + (f" — {m.detail}" if m.detail else "")
+            for m in severe_final[:5]
+        )
+        pytest.fail(
+            f"K2V3 [{name}] phase-1 (finalized buffer) canonical mismatch.\n"
+            f"  first_diff: {_first_diff(expected_final, actual_final)}\n{details}"
+        )
+
+    # Phase 2 — force the boundary-fix path even for single-assistant-turn
+    # trajectories: simulate a NEXT-turn env append by calling
+    # ``prepare_pretokenized`` with one extra ``tool`` message.
+    # This triggers ``tito_tok.merge_tokens(...)`` against a buffer whose
+    # last token is ``<|im_end|>`` (the model's autoregressive stop),
+    # which is the production state the boundary fix exists for. The
+    # follow-up moves the ``<|im_end|>`` from end-of-sequence to
+    # mid-sequence, defeating ``trim_trailing_ids`` and surfacing
+    # missing-fix bugs that phase 1 would hide.
+    follow_up = {"role": "tool", "content": "[test] synthetic follow-up env"}
+    extended_messages = list(session.messages) + [follow_up]
+    pre = session.prepare_pretokenized(extended_messages, tools, tito_tokenizer=tito_tok)
+    assert pre is not None, (
+        f"K2V3 [{name}] phase-2 setup error: prepare_pretokenized returned "
+        f"None even though session has {len(session.messages)} stored messages"
+    )
+    merged = list(pre["input_ids"])
+    expected_next = _render_ids(
+        extended_messages, tito_tok.tokenizer, tools,
+        add_generation_prompt=True,
+    )
+    severe_next = [
+        m for m in comparator.compare_sequences(expected_next, merged)
+        if m.type != MismatchType.ASSISTANT_TEXT
+    ]
+    if severe_next:
+        details = "\n".join(
+            f"  {m.type.value} at segment {m.segment_index}: "
+            f"expected={m.expected_text!r} actual={m.actual_text!r}"
+            + (f" — {m.detail}" if m.detail else "")
+            for m in severe_next[:5]
+        )
+        pytest.fail(
+            f"K2V3 [{name}] phase-2 (next-turn merged input_ids) canonical "
+            f"mismatch — the per-model boundary fix is likely broken.\n"
+            f"  first_diff: {_first_diff(expected_next, merged)}\n{details}"
+        )
+
+
+# ===========================================================================
+# Append-case test — mirrors the breadth of
+# test_tito_tokenizer_model_matrix.py but routes through
+# ``update_pretokenized_state`` so the buffer used for ``merge_tokens``
+# has the realistic ``<|im_end|>``-end shape (defeats the comparator's
+# ``trim_trailing_ids`` shielding that hides missing-fix bugs in the
+# model_matrix variant).
+# ===========================================================================
+
+
+@dataclass(frozen=True)
+class _EnvAppendShape:
+    """Generic env append shape — the messages to be appended after the
+    session has been driven through some trajectory."""
+    name: str
+    appended_messages: list[dict]
+    required_contents: tuple[str, ...]
+
+
+# Generic append shapes. Each gets cross-producted with every trajectory in
+# CONVERSATIONS, so we exercise merge_tokens against many distinct buffer
+# end-states (single tool, parallel tools, multi-turn with thinking, etc.)
+# combined with each env shape (single tool / single user / single system /
+# alternating). Strings inside ``required_contents`` are unique markers so
+# the in-order check pinpoints exactly which env content the incremental
+# tokens dropped if the test fails.
+_ENV_APPEND_SHAPES: list[_EnvAppendShape] = [
+    _EnvAppendShape(
+        name="env_tool",
+        appended_messages=[
+            {"role": "tool", "tool_call_id": "call_test_xyz",
+             "content": "_marker_tool_xyz_42_"},
+        ],
+        required_contents=("_marker_tool_xyz_42_",),
+    ),
+    _EnvAppendShape(
+        name="env_user",
+        appended_messages=[
+            {"role": "user", "content": "_marker_user_abc_99_"},
+        ],
+        required_contents=("_marker_user_abc_99_",),
+    ),
+    _EnvAppendShape(
+        name="env_system",
+        appended_messages=[
+            {"role": "system", "content": "_marker_system_def_77_"},
+        ],
+        required_contents=("_marker_system_def_77_",),
+    ),
+    _EnvAppendShape(
+        name="env_alternating_user_tool",
+        appended_messages=[
+            {"role": "tool", "tool_call_id": "call_alt_1",
+             "content": "_marker_alt_tool1_aaa_"},
+            {"role": "user", "content": "_marker_alt_user1_bbb_"},
+            {"role": "tool", "tool_call_id": "call_alt_2",
+             "content": "_marker_alt_tool2_ccc_"},
+            {"role": "user", "content": "_marker_alt_user2_ddd_"},
+        ],
+        required_contents=(
+            "_marker_alt_tool1_aaa_",
+            "_marker_alt_user1_bbb_",
+            "_marker_alt_tool2_ccc_",
+            "_marker_alt_user2_ddd_",
+        ),
+    ),
+]
+
+
+@pytest.mark.parametrize(
+    "traj_name, traj_cls", CONVERSATIONS,
+    ids=lambda x: x if isinstance(x, str) else None,
+)
+@pytest.mark.parametrize(
+    "env_shape", _ENV_APPEND_SHAPES, ids=lambda s: s.name,
+)
+def test_append_via_realistic_buffer(traj_name, traj_cls, env_shape, tito_tok):
+    """Cross-product: each trajectory shape × each env append shape.
+
+    Drive the trajectory through ``update_pretokenized_state`` so the
+    stored buffer ends at ``<|im_end|>`` (the autoregressive-stop shape).
+    Then call ``prepare_pretokenized`` with the env append messages —
+    triggering ``merge_tokens`` on a realistic buffer.
+
+    The cross-product matters because ``merge_tokens``'s correctness
+    depends on BOTH the buffer's end-state shape (single-tool ending vs
+    parallel-tools ending vs thinking ending) AND the env shape
+    (tool / user / system / mixed). 8 trajectories × 4 env shapes = 32
+    distinct ``merge_tokens`` invocation contexts.
+
+    Verifies:
+      1. The merged ``input_ids`` match the ``apply_chat_template``
+         canonical (modulo BPE-noise ``ASSISTANT_TEXT`` mismatches).
+      2. Each ``required_content`` marker appears in the incremental
+         tokens IN ORDER — catching "merge_tokens dropped an env message
+         or scrambled order".
+ """ + messages = deepcopy(traj_cls.MESSAGES) + tools = deepcopy(getattr(traj_cls, "TOOLS", None)) + + session = LinearTrajectory() + _drive_session_through_trajectory(session, tito_tok, messages, tools) + + pretokenized_buffer = list(session.token_ids) + assert ( + pretokenized_buffer + and pretokenized_buffer[-1] == tito_tok._im_end_id + ), ( + f"K2V3 [{traj_name} + {env_shape.name}] setup error: pretokenized " + f"buffer should end at <|im_end|> after drive, got last token " + f"{pretokenized_buffer[-1] if pretokenized_buffer else 'EMPTY'}" + ) + + extended = list(session.messages) + list(env_shape.appended_messages) + pre = session.prepare_pretokenized(extended, tools, tito_tokenizer=tito_tok) + assert pre is not None, ( + f"K2V3 [{traj_name} + {env_shape.name}] setup error: " + f"prepare_pretokenized returned None despite stored token_ids of " + f"length {len(pretokenized_buffer)}" + ) + merged = list(pre["input_ids"]) + + expected = _render_ids( + extended, tito_tok.tokenizer, tools, + add_generation_prompt=True, + ) + + comparator = tito_tok.create_comparator() + severe = [ + m for m in comparator.compare_sequences(expected, merged) + if m.type != MismatchType.ASSISTANT_TEXT + ] + if severe: + details = "\n".join( + f" {m.type.value} at segment {m.segment_index}: " + f"expected={m.expected_text!r} actual={m.actual_text!r}" + + (f" — {m.detail}" if m.detail else "") + for m in severe[:5] + ) + pytest.fail( + f"K2V3 [{traj_name} + {env_shape.name}] merged-vs-canonical " + f"mismatch under realistic buffer.\n" + f" first_diff: {_first_diff(expected, merged)}\n{details}" + ) + + # required-contents-in-order check on the incremental segment. + incremental_text = tito_tok.tokenizer.decode( + merged[len(pretokenized_buffer):], skip_special_tokens=False + ) + cursor = 0 + for content in env_shape.required_contents: + found = incremental_text.find(content, cursor) + assert found >= 0, ( + f"K2V3 [{traj_name} + {env_shape.name}] required_content " + f"{content!r} missing from incremental tokens (or out of order). " + f"incremental_text={incremental_text!r}" + ) + cursor = found + len(content) + + +# =========================================================================== +# Real-SGLang-parser round-trip +# +# Production data flow (server-side parsing): +# +# model raw text +# → ReasoningParser → reasoning_content + remaining text +# → FunctionCallParser → content + tool_calls +# → structured assistant_message stored in session.messages +# → next turn's chat_template re-renders that structured message +# back into text — which feeds the canonical compare +# +# If parser output drifts from what chat_template would re-emit (whitespace +# stripping, reasoning-block boundaries, tool_call argument formatting), +# the structured message in history no longer round-trips, and either: +# (a) the chat_template renders it differently from the original raw +# emit → buffer-vs-canonical mismatch on subsequent turns, or +# (b) the chat_template raises (e.g. "tool_call.arguments must be a +# dict, not a string" — K2V3's chat template enforces this). +# =========================================================================== + + +# (Parser config is declared at the top of the file alongside K2V3_MODEL_PATH.) 
+_TEST_TOOL_DICT = {
+    "type": "function",
+    "function": {
+        "name": "multiply",
+        "description": "Multiply two integers and return the product.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "a": {"type": "integer"},
+                "b": {"type": "integer"},
+            },
+            "required": ["a", "b"],
+        },
+    },
+}
+
+
+def _load_sglang_parsers():
+    """Return (FunctionCallParser_cls, ReasoningParser_cls) — either may be
+    None if SGLang is missing the corresponding module. The caller decides
+    whether to skip."""
+    fcp_cls = None
+    try:
+        from sglang.srt.function_call.function_call_parser import FunctionCallParser
+        fcp_cls = FunctionCallParser
+    except ImportError:
+        pass
+    rp_cls = None
+    try:
+        from sglang.srt.parser.reasoning_parser import ReasoningParser
+        rp_cls = ReasoningParser
+    except ImportError:
+        try:
+            from sglang.srt.reasoning_parser import ReasoningParser  # older SGLang layout
+            rp_cls = ReasoningParser
+        except ImportError:
+            pass
+    return fcp_cls, rp_cls
+
+
+def _try_json_decode_tool_args(tool_calls: list[dict]) -> list[dict]:
+    """K2V3's chat template requires ``tool_call.arguments`` to be a dict.
+    The hermes parser returns it as a JSON string. Decode for template
+    compatibility — this mirrors what production agent loops do."""
+    import json
+    out = []
+    for tc in tool_calls:
+        fn = tc.get("function", {})
+        args = fn.get("arguments")
+        if isinstance(args, str):
+            try:
+                fn = {**fn, "arguments": json.loads(args)}
+            except Exception:
+                pass
+        out.append({**tc, "function": fn})
+    return out
+
+
+@pytest.mark.parametrize(
+    "traj_name, traj_cls", CONVERSATIONS,
+    ids=lambda x: x if isinstance(x, str) else None,
+)
+def test_chat_template_round_trip_through_real_sglang_parsers(traj_name, traj_cls, tito_tok):
+    """Verify K2V3's chat template round-trips cleanly through the real
+    SGLang ReasoningParser + FunctionCallParser, parametrized over every
+    trajectory shape in ``CONVERSATIONS``.
+
+    Per trajectory: take its first assistant message as the synthetic
+    ``truth_msg``, render it via chat_template to get the raw model emit,
+    run the real parsers on the raw emit to produce ``parsed_msg``, drive
+    a single-turn session with ``parsed_msg``, then verify that
+    ``session.token_ids`` matches the canonical render of the request
+    prefix + ``parsed_msg`` at the structural level.
+
+    Each trajectory's first assistant message exercises a distinct parser
+    shape: plain content / content + tool_calls / + reasoning /
+    + parallel tool_calls / etc. So the parser test naturally
+    cross-covers the same shapes the trajectory test does.
+
+    Why ``ASSISTANT_TEXT`` mismatches are still excluded as non-severe
+    (consistent with the trajectory tests): empirically the
+    ``deepseek-r1`` reasoning parser does not ``rstrip`` the extracted
+    reasoning content, so a re-render via chat_template inserts an extra
+    ``\\n`` before ``</think>`` — purely whitespace inside the assistant
+    content segment. Production tolerates this (the rollout buffer keeps
+    the raw emit; the trainer trains on the raw emit;
+    ``compute_session_mismatch`` classifies it as ``ASSISTANT_TEXT`` and
+    the strict CI check excludes it). The structural special-token /
+    non-assistant-text round-trip is what matters and is what this test
+    enforces.
+
+    Skips if the SGLang parser modules or the specific parser names are
+    unavailable in this environment.
+ """ + FCP, RP = _load_sglang_parsers() + if FCP is None: + pytest.skip("sglang.srt.function_call.function_call_parser not importable") + + tokenizer = tito_tok.tokenizer + messages = deepcopy(traj_cls.MESSAGES) + tools = deepcopy(getattr(traj_cls, "TOOLS", None)) + + # Pick the first assistant message — that's our parser-test ``truth_msg``. + # The messages preceding it (system + user typically) are kept as the + # request prefix so the chat template renders in correct context. + first_asst_idx = next(i for i, m in enumerate(messages) if m["role"] == "assistant") + request_messages = messages[:first_asst_idx] + truth_msg = messages[first_asst_idx] + has_reasoning = bool(truth_msg.get("reasoning_content")) + + # 1) Render truth_msg via chat_template — that is the raw emit shape. + full_text = _render_text( + request_messages + [truth_msg], tokenizer, tools, + add_generation_prompt=False, + ) + prompt_text = _render_text( + request_messages, tokenizer, tools, + add_generation_prompt=True, + ) + assert full_text.startswith(prompt_text), ( + f"K2V3 [{traj_name}] chat template not append-only: prompt-only " + f"render is not a prefix of full render." + ) + raw_assistant_emit = full_text[len(prompt_text):].rstrip("\n") + assert raw_assistant_emit.endswith("<|im_end|>"), ( + f"K2V3 [{traj_name}] unexpected raw_assistant_emit shape: " + f"{raw_assistant_emit!r}" + ) + + # 2) Run real ReasoningParser on the raw emit (only if the trajectory's + # truth_msg actually has reasoning_content — otherwise there's no + # ... to extract). + text_after_reasoning = raw_assistant_emit + parsed_reasoning = "" + if _K2V3_REASONING_PARSER and has_reasoning: + if RP is None: + pytest.skip("sglang reasoning parser not importable") + try: + rp = RP(model_type=_K2V3_REASONING_PARSER) + except Exception as e: + pytest.skip( + f"reasoning parser {_K2V3_REASONING_PARSER!r} unsupported " + f"by this SGLang build: {e}" + ) + r_out, n_out = rp.parse_non_stream(raw_assistant_emit) + parsed_reasoning = r_out or "" + text_after_reasoning = n_out if n_out is not None else "" + + # 3) Run real FunctionCallParser on the post-reasoning text. + try: + from sglang.srt.entrypoints.openai.protocol import Tool as SGLangTool + except ImportError as e: + pytest.skip(f"sglang.srt.entrypoints.openai.protocol.Tool not importable: {e}") + sglang_tools = [SGLangTool(**t) for t in (tools or [])] + try: + fcp = FCP(tools=sglang_tools, tool_call_parser=_K2V3_TOOL_PARSER) + except Exception as e: + pytest.skip( + f"tool parser {_K2V3_TOOL_PARSER!r} unsupported by this SGLang " + f"build: {e}" + ) + normal_text, tool_call_items = fcp.parse_non_stream(text_after_reasoning) + parsed_content = normal_text if normal_text is not None else "" + parsed_tool_calls = [ + { + "id": f"call_{i}", + "type": "function", + "function": {"name": item.name, "arguments": item.parameters}, + } + for i, item in enumerate(tool_call_items) + ] + # Hermes returns arguments as a JSON string; K2V3 chat template requires + # a dict. Decoding here mirrors what a production agent loop does + # before storing the assistant message. + parsed_tool_calls = _try_json_decode_tool_args(parsed_tool_calls) + + parsed_msg: dict = { + "role": "assistant", + "content": parsed_content, + "tool_calls": parsed_tool_calls, + } + if has_reasoning: + parsed_msg["reasoning_content"] = parsed_reasoning + + # 4) Drive session with parser-derived assistant_message. 
+    #    ``raw_assistant_emit`` already ends with ``<|im_end|>`` (the
+    #    model's autoregressive stop), so the tokenized form is the
+    #    complete emit. Do NOT append ``tokenizer.eos_token_id`` — for
+    #    K2V3 that is ``<|endoftext|>``, which the model never emits at a
+    #    turn boundary and which would create a spurious extra
+    #    special-token mismatch.
+    emit_ids = list(tokenizer.encode(raw_assistant_emit, add_special_tokens=False))
+    prompt_ids = _render_ids(
+        request_messages, tokenizer, tools, add_generation_prompt=True,
+    )
+    session = LinearTrajectory()
+    session.update_pretokenized_state(
+        request_messages=list(request_messages),
+        assistant_message=parsed_msg,
+        prompt_token_ids=prompt_ids,
+        completion_token_ids=emit_ids,
+        max_trim_tokens=tito_tok.max_trim_tokens,
+    )
+
+    # 5) Compare ``session.token_ids`` (rollout buffer with raw emit
+    #    tokens) against the ``apply_chat_template(session.messages)``
+    #    canonical (which re-renders parsed_msg back to text). Severe
+    #    types only.
+    expected = _render_ids(
+        session.messages, tokenizer, tools, add_generation_prompt=False,
+    )
+    actual = list(session.token_ids)
+    comparator = tito_tok.create_comparator()
+    mismatches = comparator.compare_sequences(expected, actual)
+    severe = [m for m in mismatches if m.type != MismatchType.ASSISTANT_TEXT]
+    if severe:
+        details = "\n".join(
+            f"  {m.type.value} at segment {m.segment_index}: "
+            f"expected={m.expected_text!r} actual={m.actual_text!r}"
+            + (f" — {m.detail}" if m.detail else "")
+            for m in severe[:8]
+        )
+        pytest.fail(
+            f"K2V3 [{traj_name}] chat-template ↔ SGLang parser structural "
+            f"round-trip mismatch (tool_parser={_K2V3_TOOL_PARSER!r}, "
+            f"reasoning_parser={_K2V3_REASONING_PARSER!r}). "
+            f"Severe types only — ASSISTANT_TEXT-only mismatches are "
+            f"tolerated (whitespace inside assistant content; production "
+            f"already classifies these as non-severe).\n"
+            f"{details}\n"
+            f"({len(severe)} severe mismatch(es) total; "
+            f"showing first {min(8, len(severe))}.)"
+        )
+
+
+# ===========================================================================
+# End-to-end "boss" smoke tests
+#
+# These chain everything together: real SGLang parsers running on each
+# assistant turn, parser-derived ``parsed_msg`` accumulating in
+# ``session.messages`` across multiple turns, and a complex env follow-up
+# at the end that triggers ``prepare_pretokenized → merge_tokens`` against
+# a session whose history has been touched by the parser.
+#
+# The focused tests above cover each invariant in isolation. These boss
+# tests exist to catch integration regressions that only surface in the
+# full flow — specifically:
+#
+#   - parser-derived ``parsed_msg`` (with whatever whitespace shifts the
+#     parser introduces) being stored in ``session.messages``,
+#   - the next-turn ``prepare_pretokenized`` then walking
+#     ``assert_messages_append_only_with_allowed_role`` against that
+#     parser-derived history,
+#   - and the final env follow-up driving ``merge_tokens`` over a buffer
+#     that has accumulated multi-turn parser-derived content.
+#
+# Each flow uses a different "most complex" combination: multi-turn +
+# thinking + parallel tools + various env follow-up shapes.
+# ===========================================================================
+
+
+@dataclass(frozen=True)
+class _BossFlow:
+    name: str
+    trajectory_cls: type
+    final_env: list[dict]
+
+
+# Build the synthesized thinking variant of the parallel-tools trajectory
+# at module load (so it's a stable type referenced in _BOSS_FLOWS).
+_MultiToolSingleTurnThinking = _with_synthetic_thinking(MultiToolSingleTurnTrajectory)
+
+
+_BOSS_FLOWS: list[_BossFlow] = [
+    _BossFlow(
+        name="multi_turn_thinking + tool_followup",
+        trajectory_cls=MultiTurnThinkingTrajectory,
+        final_env=[
+            {"role": "tool", "tool_call_id": "boss_call_1",
+             "content": "_boss_tool_followup_xyz_42_"},
+        ],
+    ),
+    _BossFlow(
+        name="multi_tool_multi_turn_thinking + alternating_user_tool_followup",
+        trajectory_cls=LongChainThinkingTrajectory,
+        final_env=[
+            {"role": "tool", "tool_call_id": "boss_call_2a",
+             "content": "_boss_alt_tool1_aaa_"},
+            {"role": "user", "content": "_boss_alt_user1_bbb_"},
+            {"role": "tool", "tool_call_id": "boss_call_2b",
+             "content": "_boss_alt_tool2_ccc_"},
+            {"role": "user", "content": "_boss_alt_user2_ddd_"},
+        ],
+    ),
+    _BossFlow(
+        name="multi_tool_single_turn_thinking + system_inject",
+        trajectory_cls=_MultiToolSingleTurnThinking,
+        final_env=[
+            {"role": "system",
+             "content": "_boss_system_inject_def_77_"},
+        ],
+    ),
+    _BossFlow(
+        name="multi_tool_multi_turn_thinking + complex_env_chain",
+        trajectory_cls=LongChainThinkingTrajectory,
+        final_env=[
+            {"role": "tool", "tool_call_id": "boss_call_4a",
+             "content": "_boss_chain_tool1_AAA_"},
+            {"role": "user", "content": "_boss_chain_user1_BBB_"},
+            {"role": "tool", "tool_call_id": "boss_call_4b",
+             "content": "_boss_chain_tool2_CCC_"},
+            {"role": "system", "content": "_boss_chain_system_DDD_"},
+            {"role": "tool", "tool_call_id": "boss_call_4c",
+             "content": "_boss_chain_tool3_EEE_"},
+        ],
+    ),
+]
+
+
+def _run_parsers_on_emit(
+    raw_emit: str,
+    tools: list[dict] | None,
+    *,
+    fcp_cls,
+    rp_cls,
+    has_reasoning: bool,
+) -> tuple[str, list[dict], str]:
+    """Invoke the real SGLang parsers on a raw assistant emit. Returns
+    ``(parsed_content, parsed_tool_calls, parsed_reasoning)``."""
+    text_after_reasoning = raw_emit
+    parsed_reasoning = ""
+    if has_reasoning and _K2V3_REASONING_PARSER:
+        if rp_cls is None:
+            pytest.skip("sglang reasoning parser not importable")
+        try:
+            rp = rp_cls(model_type=_K2V3_REASONING_PARSER)
+        except Exception as e:
+            pytest.skip(
+                f"reasoning parser {_K2V3_REASONING_PARSER!r} unsupported "
+                f"by this SGLang build: {e}"
+            )
+        r_out, n_out = rp.parse_non_stream(raw_emit)
+        parsed_reasoning = r_out or ""
+        text_after_reasoning = n_out if n_out is not None else ""
+
+    try:
+        from sglang.srt.entrypoints.openai.protocol import Tool as SGLangTool
+    except ImportError as e:
+        pytest.skip(f"sglang.srt.entrypoints.openai.protocol.Tool not importable: {e}")
+    sglang_tools = [SGLangTool(**t) for t in (tools or [])]
+    try:
+        fcp = fcp_cls(tools=sglang_tools, tool_call_parser=_K2V3_TOOL_PARSER)
+    except Exception as e:
+        pytest.skip(
+            f"tool parser {_K2V3_TOOL_PARSER!r} unsupported by this SGLang "
+            f"build: {e}"
+        )
+    normal_text, tool_call_items = fcp.parse_non_stream(text_after_reasoning)
+    parsed_content = normal_text if normal_text is not None else ""
+    parsed_tool_calls = [
+        {
+            "id": f"call_{i}",
+            "type": "function",
+            "function": {"name": item.name, "arguments": item.parameters},
+        }
+        for i, item in enumerate(tool_call_items)
+    ]
+    parsed_tool_calls = _try_json_decode_tool_args(parsed_tool_calls)
+    return parsed_content, parsed_tool_calls, parsed_reasoning
+
+
+def _drive_one_assistant_turn_through_real_parsers(
+    session: LinearTrajectory,
+    tito_tok,
+    *,
+    fcp_cls,
+    rp_cls,
+    request_messages: list[dict],
+    truth_assistant_msg: dict,
+    tools: list[dict] | None,
+) -> dict:
+    """Render ``truth_assistant_msg`` to raw_emit, parse it with the real
+    SGLang parsers, build ``parsed_msg`` from the parser output, and
+    drive the session with ``parsed_msg`` (NOT ``truth_assistant_msg`` —
+    production stores parser output in the messages history). Returns
+    ``parsed_msg``.
+    """
+    tokenizer = tito_tok.tokenizer
+
+    full_text = _render_text(
+        request_messages + [truth_assistant_msg], tokenizer, tools,
+        add_generation_prompt=False,
+    )
+    prompt_text = _render_text(
+        request_messages, tokenizer, tools,
+        add_generation_prompt=True,
+    )
+    assert full_text.startswith(prompt_text), (
+        f"chat template not append-only between "
+        f"render(request_messages) and render(request_messages + [truth_msg])"
+    )
+    raw_emit = full_text[len(prompt_text):].rstrip("\n")
+    assert raw_emit.endswith("<|im_end|>"), (
+        f"unexpected raw_emit shape: {raw_emit!r}"
+    )
+
+    has_reasoning = bool(truth_assistant_msg.get("reasoning_content"))
+    parsed_content, parsed_tool_calls, parsed_reasoning = _run_parsers_on_emit(
+        raw_emit, tools, fcp_cls=fcp_cls, rp_cls=rp_cls, has_reasoning=has_reasoning,
+    )
+
+    parsed_msg: dict = {
+        "role": "assistant",
+        "content": parsed_content,
+        "tool_calls": parsed_tool_calls,
+    }
+    if has_reasoning:
+        parsed_msg["reasoning_content"] = parsed_reasoning
+
+    pre = session.prepare_pretokenized(request_messages, tools, tito_tokenizer=tito_tok)
+    if pre is None:
+        prompt_ids = _render_ids(
+            request_messages, tokenizer, tools, add_generation_prompt=True,
+        )
+    else:
+        prompt_ids = list(pre["input_ids"])
+
+    emit_ids = list(tokenizer.encode(raw_emit, add_special_tokens=False))
+
+    session.update_pretokenized_state(
+        request_messages=list(request_messages),
+        assistant_message=parsed_msg,
+        prompt_token_ids=prompt_ids,
+        completion_token_ids=emit_ids,
+        max_trim_tokens=tito_tok.max_trim_tokens,
+    )
+    return parsed_msg
+
+
+@pytest.mark.parametrize("flow", _BOSS_FLOWS, ids=lambda f: f.name)
+def test_end_to_end_realistic_rollout_with_real_parsers(flow: _BossFlow, tito_tok):
+    """Boss-level integration: drive every assistant turn of a multi-turn
+    trajectory through the real SGLang parsers (so ``session.messages``
+    accumulates parser-derived ``parsed_msg`` across turns, not the
+    original truth messages), then append a complex env follow-up that
+    triggers ``prepare_pretokenized → merge_tokens`` over the
+    parser-tainted history. Verify the final state is structurally
+    consistent.
+
+    The 4 flows pick deliberately complex shapes:
+      - multi_turn_thinking + tool follow-up
+      - multi_tool_multi_turn_thinking + alternating user/tool chain
+      - multi_tool_single_turn_thinking (parallel tools + reasoning)
+        + system injection
+      - multi_tool_multi_turn_thinking + tool/user/system/tool chain
+
+    Skips if the SGLang parsers are not available (matching the per-shape
+    parser test's skip behavior).
+    """
+    FCP, RP = _load_sglang_parsers()
+    if FCP is None:
+        pytest.skip("sglang.srt.function_call.function_call_parser not importable")
+
+    messages = deepcopy(flow.trajectory_cls.MESSAGES)
+    tools = deepcopy(getattr(flow.trajectory_cls, "TOOLS", None))
+    asst_indices = _assistant_indices(messages)
+    assert asst_indices, f"boss flow {flow.name} has no assistant turns"
+
+    session = LinearTrajectory()
+
+    # Track the running messages — these become the request_messages
+    # prefix for each subsequent turn, with each prior turn's truth
+    # assistant replaced by its parser-derived parsed_msg.
+    running_messages: list[dict] = []
+
+    for k, asst_idx in enumerate(asst_indices):
+        if k == 0:
+            # Pre-first-assistant: typically [system, user].
+            request_messages = list(messages[:asst_idx])
+        else:
+            # Add the env messages from the trajectory between the
+            # previous assistant and this one (tool results, user
+            # follow-ups, etc.).
+            prev_asst_idx = asst_indices[k - 1]
+            env_between = list(messages[prev_asst_idx + 1 : asst_idx])
+            request_messages = list(running_messages) + env_between
+
+        truth_msg = messages[asst_idx]
+        parsed_msg = _drive_one_assistant_turn_through_real_parsers(
+            session, tito_tok,
+            fcp_cls=FCP, rp_cls=RP,
+            request_messages=request_messages,
+            truth_assistant_msg=truth_msg,
+            tools=tools,
+        )
+        running_messages = list(request_messages) + [parsed_msg]
+
+    # Final env follow-up — triggers prepare_pretokenized → merge_tokens
+    # over a session.messages that has been fully populated by parser-
+    # derived parsed_msg's.
+    extended = list(session.messages) + list(flow.final_env)
+    pre = session.prepare_pretokenized(extended, tools, tito_tokenizer=tito_tok)
+    assert pre is not None, (
+        f"K2V3 [boss/{flow.name}] setup error: prepare_pretokenized "
+        f"returned None even though session has "
+        f"{len(session.messages)} stored messages"
+    )
+    merged = list(pre["input_ids"])
+
+    expected = _render_ids(
+        extended, tito_tok.tokenizer, tools, add_generation_prompt=True,
+    )
+
+    comparator = tito_tok.create_comparator()
+    severe = [
+        m for m in comparator.compare_sequences(expected, merged)
+        if m.type != MismatchType.ASSISTANT_TEXT
+    ]
+    if severe:
+        details = "\n".join(
+            f"  {m.type.value} at segment {m.segment_index}: "
+            f"expected={m.expected_text!r} actual={m.actual_text!r}"
+            + (f" — {m.detail}" if m.detail else "")
+            for m in severe[:8]
+        )
+        pytest.fail(
+            f"K2V3 [boss/{flow.name}] integration mismatch: "
+            f"merged input_ids vs canonical render diverge after the "
+            f"multi-turn parser-driven flow.\n"
+            f"  first_diff: {_first_diff(expected, merged)}\n{details}\n"
+            f"({len(severe)} severe mismatch(es) total; "
+            f"showing first {min(8, len(severe))}.)"
+        )
+
+    # Required-content marker check on the incremental segment — ensures
+    # the final env chain's content (which includes user/tool/system
+    # markers) actually flows into the incremental tokens in order.
+    pretokenized_buffer = list(session.token_ids)
+    incremental_text = tito_tok.tokenizer.decode(
+        merged[len(pretokenized_buffer):], skip_special_tokens=False
+    )
+    cursor = 0
+    for env_msg in flow.final_env:
+        marker = env_msg.get("content", "")
+        if not marker:
+            continue
+        found = incremental_text.find(marker, cursor)
+        assert found >= 0, (
+            f"K2V3 [boss/{flow.name}] env marker {marker!r} missing "
+            f"from incremental tokens (or out of order). "
+            f"incremental_text={incremental_text!r}"
+        )
+        cursor = found + len(marker)
+
+
+def test_production_prefix_check_raises_on_intentional_violation(tito_tok):
+    """Validate that production's ``update_pretokenized_state`` prefix
+    check fires when fed prompt_token_ids that do not extend the stored
+    prefix.
+
+    If a refactor disables this check, this test fails — protecting the
+    runtime defense that catches the same class of bugs in real rollouts.
+    """
+    session = LinearTrajectory()
+    user_q = {"role": "user", "content": "Test."}
+    asst1 = {"role": "assistant", "content": "ok"}
+
+    # Seed: drive a single normal turn so the session has stored token_ids.
+    prompt_ids = _render_ids(
+        [user_q], tito_tok.tokenizer, tools=None, add_generation_prompt=True,
+    )
+    eos = getattr(tito_tok.tokenizer, "eos_token_id", None)
+    completion_ids = list(tito_tok.tokenizer.encode("ok", add_special_tokens=False))
+    if eos is not None and (not completion_ids or completion_ids[-1] != int(eos)):
+        completion_ids.append(int(eos))
+    session.update_pretokenized_state(
+        request_messages=[user_q],
+        assistant_message=asst1,
+        prompt_token_ids=prompt_ids,
+        completion_token_ids=completion_ids,
+        max_trim_tokens=tito_tok.max_trim_tokens,
+    )
+
+    # Now feed bogus prompt_ids — completely different from what's stored.
+    bogus_prompt = [99999] * (len(session.token_ids) + 5)
+    bogus_completion = [12345]
+    asst2 = {"role": "assistant", "content": "next"}
+    tool_msg = {"role": "tool", "content": "irrelevant"}
+
+    with pytest.raises(TokenizationError, match=r"pretokenized prefix mismatch"):
+        session.update_pretokenized_state(
+            request_messages=[user_q, asst1, tool_msg],
+            assistant_message=asst2,
+            prompt_token_ids=bogus_prompt,
+            completion_token_ids=bogus_completion,
+            max_trim_tokens=0,
+        )
+
+
+def test_k2v3_subclass_is_wired(tito_tok):
+    """Sanity: ``get_tito_tokenizer(..., TITOTokenizerType.K2V3)`` returns
+    the K2V3 subclass — not silently falling back to the base
+    ``TITOTokenizer``. Catches a future regression where the registry
+    entry is removed or pointed elsewhere."""
+    from miles.utils.chat_template_utils.tito_tokenizer import K2V3TITOTokenizer
+
+    assert isinstance(tito_tok, K2V3TITOTokenizer), (
+        f"expected K2V3TITOTokenizer, got {type(tito_tok).__name__}. "
+        f"_TOKENIZER_REGISTRY[TITOTokenizerType.K2V3] may be misregistered."
+    )