From 8ad2b553f6fb94c676237616eada9c78767c353e Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Mon, 4 May 2026 14:46:21 -0700 Subject: [PATCH 1/8] initial commit --- eval_protocol/adapters/fireworks_tracing.py | 22 +- eval_protocol/adapters/r3_deserializer.py | 189 +++++++ .../pytest/remote_rollout_processor.py | 7 +- eval_protocol/pytest/tracing_utils.py | 17 +- .../types/remote_rollout_processor.py | 1 + pyproject.toml | 1 + tests/adapters/test_r3_deserializer.py | 479 ++++++++++++++++++ uv.lock | 10 +- 8 files changed, 714 insertions(+), 12 deletions(-) create mode 100644 eval_protocol/adapters/r3_deserializer.py create mode 100644 tests/adapters/test_r3_deserializer.py diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 2d8316d2..45fc2697 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -100,13 +100,29 @@ def convert_trace_dict_to_evaluation_row( ): break # Break early if we've found all the metadata we need + # Extract router replay payloads when present + payloads = trace.get("payloads") + if isinstance(payloads, dict): + router_replay = payloads.get("router_replay") + if isinstance(router_replay, dict) and router_replay.get("data"): + try: + from .r3_deserializer import decompress_and_parse_r3 + + matrices, r3_meta = decompress_and_parse_r3(router_replay["data"]) + if execution_metadata.extra is None: + execution_metadata.extra = {} + execution_metadata.extra["routing_matrices"] = matrices + execution_metadata.extra["routing_metadata"] = r3_meta + except Exception as e: + logger.warning("Failed to decompress R3 payload for trace %s: %s", trace.get("id"), e) + return EvaluationRow( messages=messages, tools=tools, input_metadata=InputMetadata( row_id=row_id, session_data={ - "langfuse_trace_id": trace.get("id"), # Store the trace ID here + "langfuse_trace_id": trace.get("id"), }, ), execution_metadata=execution_metadata, @@ -426,6 +442,7 @@ def get_evaluation_rows( max_retries: int = 3, span_name: Optional[str] = None, converter: Optional[TraceDictConverter] = None, + include_payloads: bool = False, ) -> List[EvaluationRow]: """Pull traces from Langfuse via proxy and convert to EvaluationRow format. @@ -449,6 +466,8 @@ def get_evaluation_rows( max_retries: Max retry attempts used by proxy (default: 3) converter: Optional custom converter implementing TraceDictConverter protocol. If provided, this will be used instead of the default conversion logic. + include_payloads: If True, request payload data (e.g., router replay) + from the gateway and decompress it into the returned EvaluationRows. Returns: List[EvaluationRow]: Converted evaluation rows @@ -479,6 +498,7 @@ def get_evaluation_rows( "to_timestamp": to_timestamp.isoformat() if to_timestamp else None, "sleep_between_gets": sleep_between_gets, "max_retries": max_retries, + "include_payloads": include_payloads if include_payloads else None, } # Remove None values diff --git a/eval_protocol/adapters/r3_deserializer.py b/eval_protocol/adapters/r3_deserializer.py new file mode 100644 index 00000000..1b70d9f6 --- /dev/null +++ b/eval_protocol/adapters/r3_deserializer.py @@ -0,0 +1,189 @@ +"""R3/v1 binary deserializer for router-replay payloads. + +Implements the inverse of the packed binary format produced by the tracing +gateway's ``r3_serializer.serialize_r3``. See that module for the full +header specification. + +The main entry point is :func:`decompress_and_parse_r3`, which accepts the +base64-encoded compressed blob returned by the gateway's +``/v1/traces/pointwise?include_payloads=true`` endpoint and produces +per-token routing matrices in the same ``List[Optional[str]]`` format used +by the direct inference path (``DeploymentSampler.sample_with_tokens()``). +""" + +from __future__ import annotations + +import base64 +import logging +import math +import struct +from enum import IntEnum +from typing import Any, Dict, List, Optional, Tuple + +import zstandard as zstd + +logger = logging.getLogger(__name__) + +MAGIC = b"R3V1" +HEADER_FORMAT = "<4sBBBBIIHHIIQ" +HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 36 bytes + + +class _SelectorMode(IntEnum): + ALL = 0 + SUFFIX = 1 + BITMAP = 2 + + +class _RoutingDtype(IntEnum): + UINT8 = 1 + UINT16 = 2 + + @property + def byte_width(self) -> int: + return self.value + + +_SELECTOR_MODE_NAMES = {v: v.name.lower() for v in _SelectorMode} +_ROUTING_DTYPE_NAMES = {v: v.name.lower() for v in _RoutingDtype} + + +def _parse_header(raw: bytes) -> Dict[str, Any]: + if len(raw) < HEADER_SIZE: + raise ValueError( + f"Payload too short for r3/v1 header: {len(raw)} < {HEADER_SIZE}" + ) + + ( + magic, + version, + selector_mode, + routing_dtype, + flags, + total_token_count, + replayed_token_count, + num_moe_layers, + top_k, + replay_start_token, + selector_byte_length, + matrix_byte_length, + ) = struct.unpack(HEADER_FORMAT, raw[:HEADER_SIZE]) + + if magic != MAGIC: + raise ValueError(f"Bad R3 magic: {magic!r}") + if version != 1: + raise ValueError(f"Unsupported R3 header version: {version}") + + return { + "selector_mode": selector_mode, + "routing_dtype": routing_dtype, + "flags": flags, + "total_token_count": total_token_count, + "replayed_token_count": replayed_token_count, + "num_moe_layers": num_moe_layers, + "top_k": top_k, + "replay_start_token": replay_start_token, + "selector_byte_length": selector_byte_length, + "matrix_byte_length": matrix_byte_length, + } + + +def _read_bitmap_positions( + selector_bytes: bytes, total_token_count: int +) -> List[int]: + """Return sorted token indices where the bitmap bit is set.""" + positions: List[int] = [] + for i in range(total_token_count): + byte_idx = i >> 3 + bit_idx = i & 7 + if byte_idx < len(selector_bytes) and (selector_bytes[byte_idx] >> bit_idx) & 1: + positions.append(i) + return positions + + +def decompress_and_parse_r3( + data_b64: str, +) -> Tuple[List[Optional[str]], Dict[str, Any]]: + """Decompress and unpack an R3/v1 payload into per-token routing matrices. + + Args: + data_b64: Base64-encoded zstd-compressed R3 binary blob, as returned + by the tracing gateway in ``payloads.router_replay.data``. + + Returns: + A tuple of ``(routing_matrices, metadata)`` where: + + - ``routing_matrices`` is a ``List[Optional[str]]`` of length + ``total_token_count``. Each present position contains a + base64-encoded routing matrix (matching the format returned by + the direct inference path); absent positions are ``None``. + - ``metadata`` is a dict with keys ``num_moe_layers``, ``top_k``, + ``routing_dtype``, ``selector_mode``, ``total_token_count``, + ``replayed_token_count``, ``replay_start_token``. + """ + compressed = base64.b64decode(data_b64) + + decompressor = zstd.ZstdDecompressor() + raw = decompressor.decompress(compressed, max_output_size=len(compressed) * 20) + + header = _parse_header(raw) + + selector_mode = header["selector_mode"] + routing_dtype = header["routing_dtype"] + total_token_count = header["total_token_count"] + replayed_token_count = header["replayed_token_count"] + num_moe_layers = header["num_moe_layers"] + top_k = header["top_k"] + replay_start_token = header["replay_start_token"] + selector_byte_length = header["selector_byte_length"] + matrix_byte_length = header["matrix_byte_length"] + + dtype_byte_width = _RoutingDtype(routing_dtype).byte_width + matrix_elem_size = num_moe_layers * top_k * dtype_byte_width + + body = raw[HEADER_SIZE:] + selector_bytes = body[:selector_byte_length] + matrix_bytes = body[selector_byte_length : selector_byte_length + matrix_byte_length] + + if matrix_elem_size == 0: + replayed_positions: List[int] = [] + elif selector_mode == _SelectorMode.ALL: + replayed_positions = list(range(total_token_count)) + elif selector_mode == _SelectorMode.SUFFIX: + replayed_positions = list( + range(replay_start_token, replay_start_token + replayed_token_count) + ) + elif selector_mode == _SelectorMode.BITMAP: + replayed_positions = _read_bitmap_positions(selector_bytes, total_token_count) + else: + raise ValueError(f"Unknown selector_mode: {selector_mode}") + + # Split matrix bytes into per-token chunks and base64-encode each one + matrices: List[Optional[str]] = [None] * total_token_count + for idx, pos in enumerate(replayed_positions): + start = idx * matrix_elem_size + end = start + matrix_elem_size + if end > len(matrix_bytes): + logger.warning( + "R3 matrix data truncated at token %d (position %d): " + "expected %d bytes but only %d remaining", + idx, pos, matrix_elem_size, len(matrix_bytes) - start, + ) + break + matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii") + + metadata: Dict[str, Any] = { + "num_moe_layers": num_moe_layers, + "top_k": top_k, + "routing_dtype": _ROUTING_DTYPE_NAMES.get( + _RoutingDtype(routing_dtype), str(routing_dtype) + ), + "selector_mode": _SELECTOR_MODE_NAMES.get( + _SelectorMode(selector_mode), str(selector_mode) + ), + "total_token_count": total_token_count, + "replayed_token_count": replayed_token_count, + "replay_start_token": replay_start_token, + } + + return matrices, metadata diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 66e888ce..05e49b38 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -35,11 +35,13 @@ def __init__( model_base_url: str = "https://tracing.fireworks.ai", poll_interval: float = 1.0, timeout_seconds: float = 120.0, + include_payloads: bool = False, ): # Prefer constructor-provided configuration. These can be overridden via # config.kwargs at call time for backward compatibility. self._remote_base_url = remote_base_url self._model_base_url = model_base_url + self._include_payloads = include_payloads if os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL"): self._remote_base_url = os.getenv("EP_REMOTE_ROLLOUT_PROCESSOR_BASE_URL") _ep_model_base_url = os.getenv("EP_MODEL_BASE_URL") @@ -194,7 +196,10 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time def _update_with_trace() -> None: - return update_row_with_remote_trace(row, default_fireworks_output_data_loader, model_base_url) + return update_row_with_remote_trace( + row, default_fireworks_output_data_loader, model_base_url, + include_payloads=self._include_payloads, + ) await asyncio.to_thread(_update_with_trace) # Update row with remote trace in-place return row diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 7d6b1714..9ac8d501 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -22,7 +22,11 @@ def fetch_traces() -> List[EvaluationRow]: # Use EP_REMOTE_API_KEY for fetching remote traces, falling back to FIREWORKS_API_KEY api_key = os.environ.get("EP_REMOTE_API_KEY") or os.environ.get("FIREWORKS_API_KEY") adapter = FireworksTracingAdapter(base_url=base_url, api_key=api_key) - return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5) + return adapter.get_evaluation_rows( + tags=[f"rollout_id:{config.rollout_id}"], + max_retries=5, + include_payloads=config.include_payloads, + ) return DynamicDataLoader(generators=[fetch_traces], preprocess_fn=filter_longest_conversation) @@ -148,13 +152,20 @@ def build_init_request( def update_row_with_remote_trace( - row: EvaluationRow, output_data_loader: Callable[[DataLoaderConfig], DynamicDataLoader], model_base_url: str + row: EvaluationRow, + output_data_loader: Callable[[DataLoaderConfig], DynamicDataLoader], + model_base_url: str, + include_payloads: bool = False, ) -> None: """Update row with remote trace data using output_data_loader (shared logic).""" if not row.execution_metadata.rollout_id: return None - loader_config = DataLoaderConfig(rollout_id=row.execution_metadata.rollout_id, model_base_url=model_base_url) + loader_config = DataLoaderConfig( + rollout_id=row.execution_metadata.rollout_id, + model_base_url=model_base_url, + include_payloads=include_payloads, + ) data_loader = output_data_loader(loader_config) results = data_loader.load() output_rows: List[EvaluationRow] = [r for result in results for r in result.rows] diff --git a/eval_protocol/types/remote_rollout_processor.py b/eval_protocol/types/remote_rollout_processor.py index 03104c4a..3d86cbec 100644 --- a/eval_protocol/types/remote_rollout_processor.py +++ b/eval_protocol/types/remote_rollout_processor.py @@ -39,6 +39,7 @@ class DataLoaderConfig(BaseModel): rollout_id: str model_base_url: Optional[str] = None + include_payloads: bool = False class InitRequest(BaseModel): diff --git a/pyproject.toml b/pyproject.toml index e3d162a7..55dd3289 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "deepdiff>=6.0.0", "websockets>=15.0.1", "fastapi>=0.116.1", + "zstandard>=0.19.0", ] [project.urls] diff --git a/tests/adapters/test_r3_deserializer.py b/tests/adapters/test_r3_deserializer.py new file mode 100644 index 00000000..26970466 --- /dev/null +++ b/tests/adapters/test_r3_deserializer.py @@ -0,0 +1,479 @@ +"""Tests for R3/v1 binary deserializer.""" + +from __future__ import annotations + +import base64 +import math +import struct +from typing import List, Optional + +import pytest +import zstandard as zstd + +from eval_protocol.adapters.r3_deserializer import ( + HEADER_FORMAT, + HEADER_SIZE, + MAGIC, + _SelectorMode, + _RoutingDtype, + _parse_header, + _read_bitmap_positions, + decompress_and_parse_r3, +) + + +def _make_raw_r3( + *, + selector_mode: int = _SelectorMode.ALL, + routing_dtype: int = _RoutingDtype.UINT8, + total_token_count: int = 4, + replayed_token_count: int = 4, + num_moe_layers: int = 2, + top_k: int = 2, + replay_start_token: int = 0, + selector_bytes: bytes = b"", + matrix_data: Optional[bytes] = None, +) -> bytes: + """Build a raw (uncompressed) R3/v1 payload for testing.""" + dtype_byte_width = _RoutingDtype(routing_dtype).byte_width + matrix_elem_size = num_moe_layers * top_k * dtype_byte_width + + if matrix_data is None: + matrix_data = bytes(range(matrix_elem_size)) * replayed_token_count + + header = struct.pack( + HEADER_FORMAT, + MAGIC, + 1, # version + selector_mode, + routing_dtype, + 0x01, # flags: little-endian + total_token_count, + replayed_token_count, + num_moe_layers, + top_k, + replay_start_token, + len(selector_bytes), + len(matrix_data), + ) + return header + selector_bytes + matrix_data + + +def _compress_and_b64(raw: bytes) -> str: + compressor = zstd.ZstdCompressor() + compressed = compressor.compress(raw) + return base64.b64encode(compressed).decode("ascii") + + +class TestParseHeader: + def test_valid_header(self): + raw = _make_raw_r3(total_token_count=10, replayed_token_count=5) + hdr = _parse_header(raw) + assert hdr["total_token_count"] == 10 + assert hdr["replayed_token_count"] == 5 + assert hdr["selector_mode"] == _SelectorMode.ALL + assert hdr["routing_dtype"] == _RoutingDtype.UINT8 + + def test_bad_magic(self): + raw = b"XXXX" + b"\x00" * (HEADER_SIZE - 4) + with pytest.raises(ValueError, match="Bad R3 magic"): + _parse_header(raw) + + def test_too_short(self): + with pytest.raises(ValueError, match="too short"): + _parse_header(b"\x00" * 10) + + def test_unsupported_version(self): + raw = struct.pack( + HEADER_FORMAT, + MAGIC, 99, 0, 1, 0, 4, 4, 2, 2, 0, 0, 16, + ) + with pytest.raises(ValueError, match="Unsupported R3 header version"): + _parse_header(raw) + + +class TestReadBitmapPositions: + def test_all_set(self): + bitmap = bytes([0xFF]) + positions = _read_bitmap_positions(bitmap, 8) + assert positions == list(range(8)) + + def test_none_set(self): + bitmap = bytes([0x00]) + positions = _read_bitmap_positions(bitmap, 8) + assert positions == [] + + def test_sparse(self): + # Bit 0 and bit 2 set => positions [0, 2] + bitmap = bytes([0b00000101]) + positions = _read_bitmap_positions(bitmap, 8) + assert positions == [0, 2] + + def test_multi_byte(self): + # 16 tokens: first byte has bits 0,7 set; second byte has bit 1 (token 9) set + bitmap = bytes([0b10000001, 0b00000010]) + positions = _read_bitmap_positions(bitmap, 16) + assert positions == [0, 7, 9] + + +class TestDecompressAndParseR3: + def test_all_mode_uint8(self): + num_moe_layers = 2 + top_k = 2 + total_tokens = 4 + matrix_elem_size = num_moe_layers * top_k # 4 bytes per token + + matrices_raw = [] + for i in range(total_tokens): + matrices_raw.append(bytes([i * 10 + j for j in range(matrix_elem_size)])) + matrix_data = b"".join(matrices_raw) + + raw = _make_raw_r3( + total_token_count=total_tokens, + replayed_token_count=total_tokens, + num_moe_layers=num_moe_layers, + top_k=top_k, + matrix_data=matrix_data, + ) + blob = _compress_and_b64(raw) + + matrices, metadata = decompress_and_parse_r3(blob) + + assert len(matrices) == total_tokens + assert metadata["num_moe_layers"] == num_moe_layers + assert metadata["top_k"] == top_k + assert metadata["routing_dtype"] == "uint8" + assert metadata["selector_mode"] == "all" + assert metadata["total_token_count"] == total_tokens + assert metadata["replayed_token_count"] == total_tokens + + for i in range(total_tokens): + assert matrices[i] is not None + decoded = base64.b64decode(matrices[i]) + assert decoded == matrices_raw[i] + + def test_suffix_mode(self): + num_moe_layers = 2 + top_k = 2 + total_tokens = 8 + replayed = 3 + start_token = 5 + matrix_elem_size = num_moe_layers * top_k + + matrices_raw = [] + for i in range(replayed): + matrices_raw.append(bytes([(start_token + i) * 10 + j for j in range(matrix_elem_size)])) + matrix_data = b"".join(matrices_raw) + + raw = _make_raw_r3( + selector_mode=_SelectorMode.SUFFIX, + total_token_count=total_tokens, + replayed_token_count=replayed, + num_moe_layers=num_moe_layers, + top_k=top_k, + replay_start_token=start_token, + matrix_data=matrix_data, + ) + blob = _compress_and_b64(raw) + + matrices, metadata = decompress_and_parse_r3(blob) + + assert len(matrices) == total_tokens + assert metadata["selector_mode"] == "suffix" + assert metadata["replay_start_token"] == start_token + + # Positions before start_token should be None + for i in range(start_token): + assert matrices[i] is None + + # Positions from start_token to start_token+replayed should have data + for i in range(replayed): + pos = start_token + i + assert matrices[pos] is not None + decoded = base64.b64decode(matrices[pos]) + assert decoded == matrices_raw[i] + + def test_bitmap_mode(self): + num_moe_layers = 2 + top_k = 2 + total_tokens = 8 + matrix_elem_size = num_moe_layers * top_k + + # Replay tokens at positions 1, 3, 6 + replayed_positions = [1, 3, 6] + replayed = len(replayed_positions) + + # Build bitmap + bitmap = bytearray(math.ceil(total_tokens / 8)) + for pos in replayed_positions: + bitmap[pos >> 3] |= 1 << (pos & 7) + selector_bytes = bytes(bitmap) + + matrices_raw = [] + for idx, pos in enumerate(replayed_positions): + matrices_raw.append(bytes([pos * 10 + j for j in range(matrix_elem_size)])) + matrix_data = b"".join(matrices_raw) + + raw = _make_raw_r3( + selector_mode=_SelectorMode.BITMAP, + total_token_count=total_tokens, + replayed_token_count=replayed, + num_moe_layers=num_moe_layers, + top_k=top_k, + selector_bytes=selector_bytes, + matrix_data=matrix_data, + ) + blob = _compress_and_b64(raw) + + matrices, metadata = decompress_and_parse_r3(blob) + + assert len(matrices) == total_tokens + assert metadata["selector_mode"] == "bitmap" + assert metadata["replayed_token_count"] == replayed + + for i in range(total_tokens): + if i in replayed_positions: + assert matrices[i] is not None + idx = replayed_positions.index(i) + decoded = base64.b64decode(matrices[i]) + assert decoded == matrices_raw[idx] + else: + assert matrices[i] is None + + def test_uint16_dtype(self): + num_moe_layers = 2 + top_k = 2 + total_tokens = 2 + matrix_elem_size = num_moe_layers * top_k * 2 # 2 bytes per element for uint16 + + matrices_raw = [] + for i in range(total_tokens): + matrices_raw.append(bytes([i * 10 + j for j in range(matrix_elem_size)])) + matrix_data = b"".join(matrices_raw) + + raw = _make_raw_r3( + routing_dtype=_RoutingDtype.UINT16, + total_token_count=total_tokens, + replayed_token_count=total_tokens, + num_moe_layers=num_moe_layers, + top_k=top_k, + matrix_data=matrix_data, + ) + blob = _compress_and_b64(raw) + + matrices, metadata = decompress_and_parse_r3(blob) + + assert metadata["routing_dtype"] == "uint16" + assert len(matrices) == total_tokens + for i in range(total_tokens): + decoded = base64.b64decode(matrices[i]) + assert decoded == matrices_raw[i] + + def test_zero_replayed_tokens(self): + raw = _make_raw_r3( + total_token_count=10, + replayed_token_count=0, + matrix_data=b"", + ) + blob = _compress_and_b64(raw) + + matrices, metadata = decompress_and_parse_r3(blob) + + assert len(matrices) == 10 + assert all(m is None for m in matrices) + assert metadata["replayed_token_count"] == 0 + + +class TestRoundTrip: + """Round-trip test using the gateway's serializer and EP's deserializer.""" + + def test_round_trip_with_serializer(self): + """Verify that data serialized by the gateway's r3_serializer can be + deserialized by EP's r3_deserializer and produce the original per-token + matrices.""" + import sys + import os + + # Add the tracing gateway code to the path so we can import the serializer + serializer_dir = os.path.join( + os.path.dirname(__file__), "..", "..", "..", "mono", "eval-py" + ) + serializer_dir = os.path.normpath(serializer_dir) + + if not os.path.isdir(serializer_dir): + pytest.skip(f"Serializer source not available at {serializer_dir}") + + sys.path.insert(0, serializer_dir) + try: + from litellm_proxy_config.proxy_core.r3_serializer import ( + serialize_r3, + compress_and_chunk, + ) + from litellm_proxy_config.proxy_core.models import RouterReplayData + except ImportError: + pytest.skip("r3_serializer or models not importable") + finally: + sys.path.pop(0) + + num_moe_layers = 4 + top_k = 8 + total_tokens = 16 + + # Build per-token matrices as Optional[bytes], like the gateway produces + original_matrices: List[Optional[bytes]] = [] + original_b64: List[Optional[str]] = [] + matrix_elem_size = num_moe_layers * top_k # uint8 + for i in range(total_tokens): + if i < 4: + # Prompt tokens: no routing data + original_matrices.append(None) + original_b64.append(None) + else: + mat = bytes([(i * 7 + j) % 256 for j in range(matrix_elem_size)]) + original_matrices.append(mat) + original_b64.append(base64.b64encode(mat).decode("ascii")) + + data = RouterReplayData( + routing_matrices=original_matrices, + total_token_count=total_tokens, + num_moe_layers=num_moe_layers, + top_k=top_k, + routing_dtype="uint8", + ) + + raw_payload = serialize_r3(data) + chunks = compress_and_chunk(raw_payload, chunk_size=1024 * 1024) + assembled = b"".join(chunks) + blob_b64 = base64.b64encode(assembled).decode("ascii") + + # Now deserialize with EP + matrices, metadata = decompress_and_parse_r3(blob_b64) + + assert len(matrices) == total_tokens + assert metadata["num_moe_layers"] == num_moe_layers + assert metadata["top_k"] == top_k + + for i in range(total_tokens): + if original_b64[i] is None: + assert matrices[i] is None, f"Token {i} should be None" + else: + assert matrices[i] is not None, f"Token {i} should have data" + assert matrices[i] == original_b64[i], f"Token {i} data mismatch" + + +class TestConvertTraceDictWithPayloads: + """Test that convert_trace_dict_to_evaluation_row extracts R3 payloads.""" + + def test_trace_with_router_replay_payload(self): + from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row + + num_moe_layers = 2 + top_k = 2 + total_tokens = 4 + matrix_elem_size = num_moe_layers * top_k + + matrices_raw = [] + for i in range(total_tokens): + matrices_raw.append(bytes([i * 10 + j for j in range(matrix_elem_size)])) + matrix_data = b"".join(matrices_raw) + + raw = _make_raw_r3( + total_token_count=total_tokens, + replayed_token_count=total_tokens, + num_moe_layers=num_moe_layers, + top_k=top_k, + matrix_data=matrix_data, + ) + blob = _compress_and_b64(raw) + + trace = { + "id": "test-trace-123", + "input": { + "messages": [ + {"role": "user", "content": "hello"}, + ] + }, + "output": { + "choices": [ + {"message": {"role": "assistant", "content": "hi"}} + ] + }, + "tags": ["rollout_id:r1", "run_id:run1"], + "payloads": { + "router_replay": { + "manifest": { + "PayloadVersion": "r3/v1", + "Compression": "zstd", + }, + "data": blob, + } + }, + } + + row = convert_trace_dict_to_evaluation_row(trace) + assert row is not None + assert row.execution_metadata.extra is not None + assert "routing_matrices" in row.execution_metadata.extra + assert "routing_metadata" in row.execution_metadata.extra + + rm = row.execution_metadata.extra["routing_matrices"] + assert len(rm) == total_tokens + for i in range(total_tokens): + assert rm[i] is not None + decoded = base64.b64decode(rm[i]) + assert decoded == matrices_raw[i] + + meta = row.execution_metadata.extra["routing_metadata"] + assert meta["num_moe_layers"] == num_moe_layers + assert meta["top_k"] == top_k + + def test_trace_without_payloads(self): + from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row + + trace = { + "id": "test-trace-no-payload", + "input": { + "messages": [ + {"role": "user", "content": "hello"}, + ] + }, + "output": { + "choices": [ + {"message": {"role": "assistant", "content": "hi"}} + ] + }, + "tags": [], + } + + row = convert_trace_dict_to_evaluation_row(trace) + assert row is not None + assert row.execution_metadata.extra is None + + def test_trace_with_empty_payload_data(self): + from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row + + trace = { + "id": "test-trace-empty-payload", + "input": { + "messages": [ + {"role": "user", "content": "hello"}, + ] + }, + "output": { + "choices": [ + {"message": {"role": "assistant", "content": "hi"}} + ] + }, + "tags": [], + "payloads": { + "router_replay": { + "manifest": {}, + "data": "", + } + }, + } + + row = convert_trace_dict_to_evaluation_row(trace) + assert row is not None + # Empty data string should be skipped (no crash) + assert row.execution_metadata.extra is None diff --git a/uv.lock b/uv.lock index ae420524..048760dc 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.13'", @@ -1185,6 +1185,7 @@ dependencies = [ { name = "toml" }, { name = "uvicorn" }, { name = "websockets" }, + { name = "zstandard" }, ] [package.optional-dependencies] @@ -1382,6 +1383,7 @@ requires-dist = [ { name = "versioneer", marker = "extra == 'dev'", specifier = ">=0.20" }, { name = "websockets", specifier = ">=15.0.1" }, { name = "werkzeug", marker = "extra == 'dev'", specifier = ">=2.0.0" }, + { name = "zstandard", specifier = ">=0.19.0" }, ] provides-extras = ["dev", "trl", "openevals", "box2d", "langfuse", "huggingface", "langsmith", "bigquery", "svgbench", "pydantic", "supabase", "chinook", "langchain", "braintrust", "openenv", "dspy", "klavis", "langgraph", "langgraph-tools", "proxy"] @@ -1917,7 +1919,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/6a/33d1702184d94106d3cdd7bfb788e19723206fce152e303473ca3b946c7b/greenlet-3.3.0-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:6f8496d434d5cb2dce025773ba5597f71f5410ae499d5dd9533e0653258cdb3d", size = 273658, upload-time = "2025-12-04T14:23:37.494Z" }, { url = "https://files.pythonhosted.org/packages/d6/b7/2b5805bbf1907c26e434f4e448cd8b696a0b71725204fa21a211ff0c04a7/greenlet-3.3.0-cp310-cp310-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b96dc7eef78fd404e022e165ec55327f935b9b52ff355b067eb4a0267fc1cffb", size = 574810, upload-time = "2025-12-04T14:50:04.154Z" }, { url = "https://files.pythonhosted.org/packages/94/38/343242ec12eddf3d8458c73f555c084359883d4ddc674240d9e61ec51fd6/greenlet-3.3.0-cp310-cp310-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:73631cd5cccbcfe63e3f9492aaa664d278fda0ce5c3d43aeda8e77317e38efbd", size = 586248, upload-time = "2025-12-04T14:57:39.35Z" }, - { url = "https://files.pythonhosted.org/packages/f0/d0/0ae86792fb212e4384041e0ef8e7bc66f59a54912ce407d26a966ed2914d/greenlet-3.3.0-cp310-cp310-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b299a0cb979f5d7197442dccc3aee67fce53500cd88951b7e6c35575701c980b", size = 597403, upload-time = "2025-12-04T15:07:10.831Z" }, { url = "https://files.pythonhosted.org/packages/b6/a8/15d0aa26c0036a15d2659175af00954aaaa5d0d66ba538345bd88013b4d7/greenlet-3.3.0-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7dee147740789a4632cace364816046e43310b59ff8fb79833ab043aefa72fd5", size = 586910, upload-time = "2025-12-04T14:25:59.705Z" }, { url = "https://files.pythonhosted.org/packages/e1/9b/68d5e3b7ccaba3907e5532cf8b9bf16f9ef5056a008f195a367db0ff32db/greenlet-3.3.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:39b28e339fc3c348427560494e28d8a6f3561c8d2bcf7d706e1c624ed8d822b9", size = 1547206, upload-time = "2025-12-04T15:04:21.027Z" }, { url = "https://files.pythonhosted.org/packages/66/bd/e3086ccedc61e49f91e2cfb5ffad9d8d62e5dc85e512a6200f096875b60c/greenlet-3.3.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:b3c374782c2935cc63b2a27ba8708471de4ad1abaa862ffdb1ef45a643ddbb7d", size = 1613359, upload-time = "2025-12-04T14:27:26.548Z" }, @@ -1925,7 +1926,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/cb/48e964c452ca2b92175a9b2dca037a553036cb053ba69e284650ce755f13/greenlet-3.3.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:e29f3018580e8412d6aaf5641bb7745d38c85228dacf51a73bd4e26ddf2a6a8e", size = 274908, upload-time = "2025-12-04T14:23:26.435Z" }, { url = "https://files.pythonhosted.org/packages/28/da/38d7bff4d0277b594ec557f479d65272a893f1f2a716cad91efeb8680953/greenlet-3.3.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a687205fb22794e838f947e2194c0566d3812966b41c78709554aa883183fb62", size = 577113, upload-time = "2025-12-04T14:50:05.493Z" }, { url = "https://files.pythonhosted.org/packages/3c/f2/89c5eb0faddc3ff014f1c04467d67dee0d1d334ab81fadbf3744847f8a8a/greenlet-3.3.0-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4243050a88ba61842186cb9e63c7dfa677ec146160b0efd73b855a3d9c7fcf32", size = 590338, upload-time = "2025-12-04T14:57:41.136Z" }, - { url = "https://files.pythonhosted.org/packages/80/d7/db0a5085035d05134f8c089643da2b44cc9b80647c39e93129c5ef170d8f/greenlet-3.3.0-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:670d0f94cd302d81796e37299bcd04b95d62403883b24225c6b5271466612f45", size = 601098, upload-time = "2025-12-04T15:07:11.898Z" }, { url = "https://files.pythonhosted.org/packages/dc/a6/e959a127b630a58e23529972dbc868c107f9d583b5a9f878fb858c46bc1a/greenlet-3.3.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6cb3a8ec3db4a3b0eb8a3c25436c2d49e3505821802074969db017b87bc6a948", size = 590206, upload-time = "2025-12-04T14:26:01.254Z" }, { url = "https://files.pythonhosted.org/packages/48/60/29035719feb91798693023608447283b266b12efc576ed013dd9442364bb/greenlet-3.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2de5a0b09eab81fc6a382791b995b1ccf2b172a9fec934747a7a23d2ff291794", size = 1550668, upload-time = "2025-12-04T15:04:22.439Z" }, { url = "https://files.pythonhosted.org/packages/0a/5f/783a23754b691bfa86bd72c3033aa107490deac9b2ef190837b860996c9f/greenlet-3.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4449a736606bd30f27f8e1ff4678ee193bc47f6ca810d705981cfffd6ce0d8c5", size = 1615483, upload-time = "2025-12-04T14:27:28.083Z" }, @@ -1933,7 +1933,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/0a/a3871375c7b9727edaeeea994bfff7c63ff7804c9829c19309ba2e058807/greenlet-3.3.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:b01548f6e0b9e9784a2c99c5651e5dc89ffcbe870bc5fb2e5ef864e9cc6b5dcb", size = 276379, upload-time = "2025-12-04T14:23:30.498Z" }, { url = "https://files.pythonhosted.org/packages/43/ab/7ebfe34dce8b87be0d11dae91acbf76f7b8246bf9d6b319c741f99fa59c6/greenlet-3.3.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:349345b770dc88f81506c6861d22a6ccd422207829d2c854ae2af8025af303e3", size = 597294, upload-time = "2025-12-04T14:50:06.847Z" }, { url = "https://files.pythonhosted.org/packages/a4/39/f1c8da50024feecd0793dbd5e08f526809b8ab5609224a2da40aad3a7641/greenlet-3.3.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e8e18ed6995e9e2c0b4ed264d2cf89260ab3ac7e13555b8032b25a74c6d18655", size = 607742, upload-time = "2025-12-04T14:57:42.349Z" }, - { url = "https://files.pythonhosted.org/packages/77/cb/43692bcd5f7a0da6ec0ec6d58ee7cddb606d055ce94a62ac9b1aa481e969/greenlet-3.3.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c024b1e5696626890038e34f76140ed1daf858e37496d33f2af57f06189e70d7", size = 622297, upload-time = "2025-12-04T15:07:13.552Z" }, { url = "https://files.pythonhosted.org/packages/75/b0/6bde0b1011a60782108c01de5913c588cf51a839174538d266de15e4bf4d/greenlet-3.3.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:047ab3df20ede6a57c35c14bf5200fcf04039d50f908270d3f9a7a82064f543b", size = 609885, upload-time = "2025-12-04T14:26:02.368Z" }, { url = "https://files.pythonhosted.org/packages/49/0e/49b46ac39f931f59f987b7cd9f34bfec8ef81d2a1e6e00682f55be5de9f4/greenlet-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d9ad37fc657b1102ec880e637cccf20191581f75c64087a549e66c57e1ceb53", size = 1567424, upload-time = "2025-12-04T15:04:23.757Z" }, { url = "https://files.pythonhosted.org/packages/05/f5/49a9ac2dff7f10091935def9165c90236d8f175afb27cbed38fb1d61ab6b/greenlet-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83cd0e36932e0e7f36a64b732a6f60c2fc2df28c351bae79fbaf4f8092fe7614", size = 1636017, upload-time = "2025-12-04T14:27:29.688Z" }, @@ -1941,7 +1940,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/2f/28592176381b9ab2cafa12829ba7b472d177f3acc35d8fbcf3673d966fff/greenlet-3.3.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:a1e41a81c7e2825822f4e068c48cb2196002362619e2d70b148f20a831c00739", size = 275140, upload-time = "2025-12-04T14:23:01.282Z" }, { url = "https://files.pythonhosted.org/packages/2c/80/fbe937bf81e9fca98c981fe499e59a3f45df2a04da0baa5c2be0dca0d329/greenlet-3.3.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f515a47d02da4d30caaa85b69474cec77b7929b2e936ff7fb853d42f4bf8808", size = 599219, upload-time = "2025-12-04T14:50:08.309Z" }, { url = "https://files.pythonhosted.org/packages/c2/ff/7c985128f0514271b8268476af89aee6866df5eec04ac17dcfbc676213df/greenlet-3.3.0-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7d2d9fd66bfadf230b385fdc90426fcd6eb64db54b40c495b72ac0feb5766c54", size = 610211, upload-time = "2025-12-04T14:57:43.968Z" }, - { url = "https://files.pythonhosted.org/packages/79/07/c47a82d881319ec18a4510bb30463ed6891f2ad2c1901ed5ec23d3de351f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:30a6e28487a790417d036088b3bcb3f3ac7d8babaa7d0139edbaddebf3af9492", size = 624311, upload-time = "2025-12-04T15:07:14.697Z" }, { url = "https://files.pythonhosted.org/packages/fd/8e/424b8c6e78bd9837d14ff7df01a9829fc883ba2ab4ea787d4f848435f23f/greenlet-3.3.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:087ea5e004437321508a8d6f20efc4cfec5e3c30118e1417ea96ed1d93950527", size = 612833, upload-time = "2025-12-04T14:26:03.669Z" }, { url = "https://files.pythonhosted.org/packages/b5/ba/56699ff9b7c76ca12f1cdc27a886d0f81f2189c3455ff9f65246780f713d/greenlet-3.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ab97cf74045343f6c60a39913fa59710e4bd26a536ce7ab2397adf8b27e67c39", size = 1567256, upload-time = "2025-12-04T15:04:25.276Z" }, { url = "https://files.pythonhosted.org/packages/1e/37/f31136132967982d698c71a281a8901daf1a8fbab935dce7c0cf15f942cc/greenlet-3.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5375d2e23184629112ca1ea89a53389dddbffcf417dad40125713d88eb5f96e8", size = 1636483, upload-time = "2025-12-04T14:27:30.804Z" }, @@ -1949,7 +1947,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/7c/f0a6d0ede2c7bf092d00bc83ad5bafb7e6ec9b4aab2fbdfa6f134dc73327/greenlet-3.3.0-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:60c2ef0f578afb3c8d92ea07ad327f9a062547137afe91f38408f08aacab667f", size = 275671, upload-time = "2025-12-04T14:23:05.267Z" }, { url = "https://files.pythonhosted.org/packages/44/06/dac639ae1a50f5969d82d2e3dd9767d30d6dbdbab0e1a54010c8fe90263c/greenlet-3.3.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a5d554d0712ba1de0a6c94c640f7aeba3f85b3a6e1f2899c11c2c0428da9365", size = 646360, upload-time = "2025-12-04T14:50:10.026Z" }, { url = "https://files.pythonhosted.org/packages/e0/94/0fb76fe6c5369fba9bf98529ada6f4c3a1adf19e406a47332245ef0eb357/greenlet-3.3.0-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3a898b1e9c5f7307ebbde4102908e6cbfcb9ea16284a3abe15cab996bee8b9b3", size = 658160, upload-time = "2025-12-04T14:57:45.41Z" }, - { url = "https://files.pythonhosted.org/packages/93/79/d2c70cae6e823fac36c3bbc9077962105052b7ef81db2f01ec3b9bf17e2b/greenlet-3.3.0-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:dcd2bdbd444ff340e8d6bdf54d2f206ccddbb3ccfdcd3c25bf4afaa7b8f0cf45", size = 671388, upload-time = "2025-12-04T15:07:15.789Z" }, { url = "https://files.pythonhosted.org/packages/b8/14/bab308fc2c1b5228c3224ec2bf928ce2e4d21d8046c161e44a2012b5203e/greenlet-3.3.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5773edda4dc00e173820722711d043799d3adb4f01731f40619e07ea2750b955", size = 660166, upload-time = "2025-12-04T14:26:05.099Z" }, { url = "https://files.pythonhosted.org/packages/4b/d2/91465d39164eaa0085177f61983d80ffe746c5a1860f009811d498e7259c/greenlet-3.3.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:ac0549373982b36d5fd5d30beb8a7a33ee541ff98d2b502714a09f1169f31b55", size = 1615193, upload-time = "2025-12-04T15:04:27.041Z" }, { url = "https://files.pythonhosted.org/packages/42/1b/83d110a37044b92423084d52d5d5a3b3a73cafb51b547e6d7366ff62eff1/greenlet-3.3.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:d198d2d977460358c3b3a4dc844f875d1adb33817f0613f663a656f463764ccc", size = 1683653, upload-time = "2025-12-04T14:27:32.366Z" }, @@ -1957,7 +1954,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/66/bd6317bc5932accf351fc19f177ffba53712a202f9df10587da8df257c7e/greenlet-3.3.0-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:d6ed6f85fae6cdfdb9ce04c9bf7a08d666cfcfb914e7d006f44f840b46741931", size = 282638, upload-time = "2025-12-04T14:25:20.941Z" }, { url = "https://files.pythonhosted.org/packages/30/cf/cc81cb030b40e738d6e69502ccbd0dd1bced0588e958f9e757945de24404/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d9125050fcf24554e69c4cacb086b87b3b55dc395a8b3ebe6487b045b2614388", size = 651145, upload-time = "2025-12-04T14:50:11.039Z" }, { url = "https://files.pythonhosted.org/packages/9c/ea/1020037b5ecfe95ca7df8d8549959baceb8186031da83d5ecceff8b08cd2/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:87e63ccfa13c0a0f6234ed0add552af24cc67dd886731f2261e46e241608bee3", size = 654236, upload-time = "2025-12-04T14:57:47.007Z" }, - { url = "https://files.pythonhosted.org/packages/69/cc/1e4bae2e45ca2fa55299f4e85854606a78ecc37fead20d69322f96000504/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2662433acbca297c9153a4023fe2161c8dcfdcc91f10433171cf7e7d94ba2221", size = 662506, upload-time = "2025-12-04T15:07:16.906Z" }, { url = "https://files.pythonhosted.org/packages/57/b9/f8025d71a6085c441a7eaff0fd928bbb275a6633773667023d19179fe815/greenlet-3.3.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3c6e9b9c1527a78520357de498b0e709fb9e2f49c3a513afd5a249007261911b", size = 653783, upload-time = "2025-12-04T14:26:06.225Z" }, { url = "https://files.pythonhosted.org/packages/f6/c7/876a8c7a7485d5d6b5c6821201d542ef28be645aa024cfe1145b35c120c1/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:286d093f95ec98fdd92fcb955003b8a3d054b4e2cab3e2707a5039e7b50520fd", size = 1614857, upload-time = "2025-12-04T15:04:28.484Z" }, { url = "https://files.pythonhosted.org/packages/4f/dc/041be1dff9f23dac5f48a43323cd0789cb798342011c19a248d9c9335536/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c10513330af5b8ae16f023e8ddbfb486ab355d04467c4679c5cfe4659975dd9", size = 1676034, upload-time = "2025-12-04T14:27:33.531Z" }, From 6d10467366f8972d1c86f14f6fb842f911f0e728 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Mon, 4 May 2026 19:06:01 -0700 Subject: [PATCH 2/8] refactor: drop num_moe_layers/top_k from r3/v1 deserializer Mirrors the gateway-side r3_serializer change: the per-token matrix shape (num_moe_layers, top_k) is no longer required and is no longer written into the r3/v1 binary header. Per-token matrix byte size is recovered as matrix_byte_length / replayed_token_count. - HEADER_FORMAT: "<4sBBBBIIHHIIQ" (36 bytes) -> "<4sBBBBIIIIQ" (32 bytes). - Drop num_moe_layers/top_k from _parse_header() and the metadata dict returned by decompress_and_parse_r3(). - Compute matrix_elem_size from matrix_byte_length / replayed_token_count with a divisibility check that surfaces malformed payloads early. - Update unit tests to use matrix_elem_size as the parameter and drop assertions on the removed header fields; round-trip test no longer passes num_moe_layers/top_k to RouterReplayData. Co-authored-by: Cursor --- eval_protocol/adapters/r3_deserializer.py | 34 ++++++------- tests/adapters/test_r3_deserializer.py | 58 +++++++---------------- 2 files changed, 35 insertions(+), 57 deletions(-) diff --git a/eval_protocol/adapters/r3_deserializer.py b/eval_protocol/adapters/r3_deserializer.py index 1b70d9f6..2709c234 100644 --- a/eval_protocol/adapters/r3_deserializer.py +++ b/eval_protocol/adapters/r3_deserializer.py @@ -15,7 +15,6 @@ import base64 import logging -import math import struct from enum import IntEnum from typing import Any, Dict, List, Optional, Tuple @@ -25,8 +24,8 @@ logger = logging.getLogger(__name__) MAGIC = b"R3V1" -HEADER_FORMAT = "<4sBBBBIIHHIIQ" -HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 36 bytes +HEADER_FORMAT = "<4sBBBBIIIIQ" +HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes class _SelectorMode(IntEnum): @@ -62,8 +61,6 @@ def _parse_header(raw: bytes) -> Dict[str, Any]: flags, total_token_count, replayed_token_count, - num_moe_layers, - top_k, replay_start_token, selector_byte_length, matrix_byte_length, @@ -80,8 +77,6 @@ def _parse_header(raw: bytes) -> Dict[str, Any]: "flags": flags, "total_token_count": total_token_count, "replayed_token_count": replayed_token_count, - "num_moe_layers": num_moe_layers, - "top_k": top_k, "replay_start_token": replay_start_token, "selector_byte_length": selector_byte_length, "matrix_byte_length": matrix_byte_length, @@ -117,9 +112,9 @@ def decompress_and_parse_r3( ``total_token_count``. Each present position contains a base64-encoded routing matrix (matching the format returned by the direct inference path); absent positions are ``None``. - - ``metadata`` is a dict with keys ``num_moe_layers``, ``top_k``, - ``routing_dtype``, ``selector_mode``, ``total_token_count``, - ``replayed_token_count``, ``replay_start_token``. + - ``metadata`` is a dict with keys ``routing_dtype``, + ``selector_mode``, ``total_token_count``, ``replayed_token_count``, + ``replay_start_token``. """ compressed = base64.b64decode(data_b64) @@ -132,14 +127,23 @@ def decompress_and_parse_r3( routing_dtype = header["routing_dtype"] total_token_count = header["total_token_count"] replayed_token_count = header["replayed_token_count"] - num_moe_layers = header["num_moe_layers"] - top_k = header["top_k"] replay_start_token = header["replay_start_token"] selector_byte_length = header["selector_byte_length"] matrix_byte_length = header["matrix_byte_length"] - dtype_byte_width = _RoutingDtype(routing_dtype).byte_width - matrix_elem_size = num_moe_layers * top_k * dtype_byte_width + # Per-token matrix byte size is implicit in the payload: all replayed + # tokens share the same matrix length, so we can recover it from the + # matrix section total length divided by the replayed-token count. + if replayed_token_count > 0: + if matrix_byte_length % replayed_token_count != 0: + raise ValueError( + f"matrix_byte_length ({matrix_byte_length}) is not a multiple of " + f"replayed_token_count ({replayed_token_count}); cannot split " + "into per-token matrices" + ) + matrix_elem_size = matrix_byte_length // replayed_token_count + else: + matrix_elem_size = 0 body = raw[HEADER_SIZE:] selector_bytes = body[:selector_byte_length] @@ -173,8 +177,6 @@ def decompress_and_parse_r3( matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii") metadata: Dict[str, Any] = { - "num_moe_layers": num_moe_layers, - "top_k": top_k, "routing_dtype": _ROUTING_DTYPE_NAMES.get( _RoutingDtype(routing_dtype), str(routing_dtype) ), diff --git a/tests/adapters/test_r3_deserializer.py b/tests/adapters/test_r3_deserializer.py index 26970466..4cb1502b 100644 --- a/tests/adapters/test_r3_deserializer.py +++ b/tests/adapters/test_r3_deserializer.py @@ -28,17 +28,20 @@ def _make_raw_r3( routing_dtype: int = _RoutingDtype.UINT8, total_token_count: int = 4, replayed_token_count: int = 4, - num_moe_layers: int = 2, - top_k: int = 2, + matrix_elem_size: Optional[int] = None, replay_start_token: int = 0, selector_bytes: bytes = b"", matrix_data: Optional[bytes] = None, ) -> bytes: - """Build a raw (uncompressed) R3/v1 payload for testing.""" - dtype_byte_width = _RoutingDtype(routing_dtype).byte_width - matrix_elem_size = num_moe_layers * top_k * dtype_byte_width + """Build a raw (uncompressed) R3/v1 payload for testing. + ``matrix_elem_size`` is the per-token matrix byte length; when not given + and no explicit ``matrix_data`` is supplied, defaults to 4 bytes/token + (a minimal placeholder for tests that don't care about shape). + """ if matrix_data is None: + if matrix_elem_size is None: + matrix_elem_size = 4 matrix_data = bytes(range(matrix_elem_size)) * replayed_token_count header = struct.pack( @@ -50,8 +53,6 @@ def _make_raw_r3( 0x01, # flags: little-endian total_token_count, replayed_token_count, - num_moe_layers, - top_k, replay_start_token, len(selector_bytes), len(matrix_data), @@ -86,7 +87,7 @@ def test_too_short(self): def test_unsupported_version(self): raw = struct.pack( HEADER_FORMAT, - MAGIC, 99, 0, 1, 0, 4, 4, 2, 2, 0, 0, 16, + MAGIC, 99, 0, 1, 0, 4, 4, 0, 0, 16, ) with pytest.raises(ValueError, match="Unsupported R3 header version"): _parse_header(raw) @@ -118,10 +119,8 @@ def test_multi_byte(self): class TestDecompressAndParseR3: def test_all_mode_uint8(self): - num_moe_layers = 2 - top_k = 2 + matrix_elem_size = 4 # e.g. 2 MoE layers * 2 top-k * 1 byte (uint8) total_tokens = 4 - matrix_elem_size = num_moe_layers * top_k # 4 bytes per token matrices_raw = [] for i in range(total_tokens): @@ -131,8 +130,6 @@ def test_all_mode_uint8(self): raw = _make_raw_r3( total_token_count=total_tokens, replayed_token_count=total_tokens, - num_moe_layers=num_moe_layers, - top_k=top_k, matrix_data=matrix_data, ) blob = _compress_and_b64(raw) @@ -140,8 +137,6 @@ def test_all_mode_uint8(self): matrices, metadata = decompress_and_parse_r3(blob) assert len(matrices) == total_tokens - assert metadata["num_moe_layers"] == num_moe_layers - assert metadata["top_k"] == top_k assert metadata["routing_dtype"] == "uint8" assert metadata["selector_mode"] == "all" assert metadata["total_token_count"] == total_tokens @@ -153,12 +148,10 @@ def test_all_mode_uint8(self): assert decoded == matrices_raw[i] def test_suffix_mode(self): - num_moe_layers = 2 - top_k = 2 + matrix_elem_size = 4 total_tokens = 8 replayed = 3 start_token = 5 - matrix_elem_size = num_moe_layers * top_k matrices_raw = [] for i in range(replayed): @@ -169,8 +162,6 @@ def test_suffix_mode(self): selector_mode=_SelectorMode.SUFFIX, total_token_count=total_tokens, replayed_token_count=replayed, - num_moe_layers=num_moe_layers, - top_k=top_k, replay_start_token=start_token, matrix_data=matrix_data, ) @@ -194,10 +185,8 @@ def test_suffix_mode(self): assert decoded == matrices_raw[i] def test_bitmap_mode(self): - num_moe_layers = 2 - top_k = 2 + matrix_elem_size = 4 total_tokens = 8 - matrix_elem_size = num_moe_layers * top_k # Replay tokens at positions 1, 3, 6 replayed_positions = [1, 3, 6] @@ -218,8 +207,6 @@ def test_bitmap_mode(self): selector_mode=_SelectorMode.BITMAP, total_token_count=total_tokens, replayed_token_count=replayed, - num_moe_layers=num_moe_layers, - top_k=top_k, selector_bytes=selector_bytes, matrix_data=matrix_data, ) @@ -241,10 +228,8 @@ def test_bitmap_mode(self): assert matrices[i] is None def test_uint16_dtype(self): - num_moe_layers = 2 - top_k = 2 + matrix_elem_size = 8 # e.g. 2 MoE layers * 2 top-k * 2 bytes (uint16) total_tokens = 2 - matrix_elem_size = num_moe_layers * top_k * 2 # 2 bytes per element for uint16 matrices_raw = [] for i in range(total_tokens): @@ -255,8 +240,6 @@ def test_uint16_dtype(self): routing_dtype=_RoutingDtype.UINT16, total_token_count=total_tokens, replayed_token_count=total_tokens, - num_moe_layers=num_moe_layers, - top_k=top_k, matrix_data=matrix_data, ) blob = _compress_and_b64(raw) @@ -336,8 +319,6 @@ def test_round_trip_with_serializer(self): data = RouterReplayData( routing_matrices=original_matrices, total_token_count=total_tokens, - num_moe_layers=num_moe_layers, - top_k=top_k, routing_dtype="uint8", ) @@ -350,8 +331,7 @@ def test_round_trip_with_serializer(self): matrices, metadata = decompress_and_parse_r3(blob_b64) assert len(matrices) == total_tokens - assert metadata["num_moe_layers"] == num_moe_layers - assert metadata["top_k"] == top_k + assert metadata["total_token_count"] == total_tokens for i in range(total_tokens): if original_b64[i] is None: @@ -367,10 +347,8 @@ class TestConvertTraceDictWithPayloads: def test_trace_with_router_replay_payload(self): from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row - num_moe_layers = 2 - top_k = 2 + matrix_elem_size = 4 total_tokens = 4 - matrix_elem_size = num_moe_layers * top_k matrices_raw = [] for i in range(total_tokens): @@ -380,8 +358,6 @@ def test_trace_with_router_replay_payload(self): raw = _make_raw_r3( total_token_count=total_tokens, replayed_token_count=total_tokens, - num_moe_layers=num_moe_layers, - top_k=top_k, matrix_data=matrix_data, ) blob = _compress_and_b64(raw) @@ -424,8 +400,8 @@ def test_trace_with_router_replay_payload(self): assert decoded == matrices_raw[i] meta = row.execution_metadata.extra["routing_metadata"] - assert meta["num_moe_layers"] == num_moe_layers - assert meta["top_k"] == top_k + assert meta["routing_dtype"] == "uint8" + assert meta["total_token_count"] == total_tokens def test_trace_without_payloads(self): from eval_protocol.adapters.fireworks_tracing import convert_trace_dict_to_evaluation_row From aa0762ac3e8a9134bc70ad526bccb12b44d44b8c Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Tue, 5 May 2026 16:05:57 -0700 Subject: [PATCH 3/8] test --- eval_protocol/pytest/tracing_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 9ac8d501..1bbb4824 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -133,7 +133,7 @@ def build_init_request( # Build final model base URL with tracing metadata final_model_base_url = model_base_url - if model_base_url and ("tracing.fireworks.ai" in model_base_url or model_base_url.startswith("http://localhost")): + if model_base_url and ("tracing.fireworks.ai" in model_base_url or model_base_url.startswith("http://localhost") or "litellm-gateway" in model_base_url): final_model_base_url = build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url) # Extract API key from environment or completion_params From 3393c04753e5e01023b2044babd698ee0ba1b927 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Tue, 5 May 2026 18:36:45 -0700 Subject: [PATCH 4/8] fix: drop arbitrary 20x cap on r3/v1 decompressed size ZstdCompressor.compress() (used by the gateway-side r3_serializer) embeds the uncompressed size in the frame header, so passing max_output_size=len(compressed)*20 was both unnecessary and incorrect: highly compressible router-replay payloads (e.g. tokens routing to a small subset of experts) routinely exceed a 20:1 ratio, and would have failed deserialization with ZstdError. Removing the cap lets the library auto-allocate from the embedded content size. Verified locally: a 64 KiB zero-filled matrix payload compresses to ~35 bytes (>1800x ratio) and now deserializes cleanly. Adds a regression test covering the high-compression case. Co-authored-by: Cursor --- eval_protocol/adapters/r3_deserializer.py | 4 +++- tests/adapters/test_r3_deserializer.py | 24 +++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/eval_protocol/adapters/r3_deserializer.py b/eval_protocol/adapters/r3_deserializer.py index 2709c234..6925ff30 100644 --- a/eval_protocol/adapters/r3_deserializer.py +++ b/eval_protocol/adapters/r3_deserializer.py @@ -118,8 +118,10 @@ def decompress_and_parse_r3( """ compressed = base64.b64decode(data_b64) + # ZstdCompressor.compress() embeds the uncompressed size in the frame + # header by default, so the library can auto-allocate the output buffer. decompressor = zstd.ZstdDecompressor() - raw = decompressor.decompress(compressed, max_output_size=len(compressed) * 20) + raw = decompressor.decompress(compressed) header = _parse_header(raw) diff --git a/tests/adapters/test_r3_deserializer.py b/tests/adapters/test_r3_deserializer.py index 4cb1502b..348cb0be 100644 --- a/tests/adapters/test_r3_deserializer.py +++ b/tests/adapters/test_r3_deserializer.py @@ -266,6 +266,30 @@ def test_zero_replayed_tokens(self): assert all(m is None for m in matrices) assert metadata["replayed_token_count"] == 0 + def test_high_compression_ratio_payload(self): + """Highly compressible payloads (e.g. tokens routing to the same + experts) can compress much better than 20:1; the deserializer must + not impose an arbitrary cap on the decompressed size.""" + # 64 KiB of zeros compresses to ~35 bytes (>1000x ratio). + total_tokens = 1024 + matrix_elem_size = 64 # bytes/token + matrix_data = b"\x00" * (total_tokens * matrix_elem_size) + + raw = _make_raw_r3( + total_token_count=total_tokens, + replayed_token_count=total_tokens, + matrix_data=matrix_data, + ) + blob = _compress_and_b64(raw) + # Sanity: compression really is >> 20x for this case. + assert len(base64.b64decode(blob)) * 20 < len(raw) + + matrices, metadata = decompress_and_parse_r3(blob) + assert len(matrices) == total_tokens + assert metadata["replayed_token_count"] == total_tokens + for m in matrices: + assert base64.b64decode(m) == b"\x00" * matrix_elem_size + class TestRoundTrip: """Round-trip test using the gateway's serializer and EP's deserializer.""" From 5b09aeca4d96e1743eab00450074dc853092ae16 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Tue, 5 May 2026 18:42:51 -0700 Subject: [PATCH 5/8] fix: do not construct IntEnum for unknown dtype/selector_mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _RoutingDtype(int) and _SelectorMode(int) raise ValueError for any value not in the enum, so the .get() fallback was unreachable: a future routing_dtype=3 in the header would crash metadata construction before str(int) could run. Look up names by raw int instead — IntEnum keys hash-equal their int values, so known modes resolve to their lowercase name and unknown ones fall back to str(int) without ever constructing the enum. Adds a regression test exercising routing_dtype=99. Co-authored-by: Cursor --- eval_protocol/adapters/r3_deserializer.py | 8 ++------ tests/adapters/test_r3_deserializer.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/eval_protocol/adapters/r3_deserializer.py b/eval_protocol/adapters/r3_deserializer.py index 6925ff30..9f589712 100644 --- a/eval_protocol/adapters/r3_deserializer.py +++ b/eval_protocol/adapters/r3_deserializer.py @@ -179,12 +179,8 @@ def decompress_and_parse_r3( matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii") metadata: Dict[str, Any] = { - "routing_dtype": _ROUTING_DTYPE_NAMES.get( - _RoutingDtype(routing_dtype), str(routing_dtype) - ), - "selector_mode": _SELECTOR_MODE_NAMES.get( - _SelectorMode(selector_mode), str(selector_mode) - ), + "routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)), + "selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)), "total_token_count": total_token_count, "replayed_token_count": replayed_token_count, "replay_start_token": replay_start_token, diff --git a/tests/adapters/test_r3_deserializer.py b/tests/adapters/test_r3_deserializer.py index 348cb0be..31f058f6 100644 --- a/tests/adapters/test_r3_deserializer.py +++ b/tests/adapters/test_r3_deserializer.py @@ -266,6 +266,20 @@ def test_zero_replayed_tokens(self): assert all(m is None for m in matrices) assert metadata["replayed_token_count"] == 0 + def test_unknown_routing_dtype_falls_back_to_str(self): + """Unknown routing_dtype ints (e.g. a future dtype=3) must not crash + metadata construction; the dtype is surfaced as its string repr.""" + raw = _make_raw_r3( + routing_dtype=99, # not in _RoutingDtype + total_token_count=2, + replayed_token_count=2, + matrix_data=b"\x00" * 8, + ) + blob = _compress_and_b64(raw) + + _, metadata = decompress_and_parse_r3(blob) + assert metadata["routing_dtype"] == "99" + def test_high_compression_ratio_payload(self): """Highly compressible payloads (e.g. tokens routing to the same experts) can compress much better than 20:1; the deserializer must From 6639c01cb8e3a55ec6c0c8774b20600c721ccaa5 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Tue, 5 May 2026 18:47:35 -0700 Subject: [PATCH 6/8] chore: drop unused _RoutingDtype.byte_width decompress_and_parse_r3 now derives matrix_elem_size from matrix_byte_length / replayed_token_count, so the dtype's per-element byte width is no longer referenced anywhere. Removing dead code. Co-authored-by: Cursor --- eval_protocol/adapters/r3_deserializer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/eval_protocol/adapters/r3_deserializer.py b/eval_protocol/adapters/r3_deserializer.py index 9f589712..05a51106 100644 --- a/eval_protocol/adapters/r3_deserializer.py +++ b/eval_protocol/adapters/r3_deserializer.py @@ -38,10 +38,6 @@ class _RoutingDtype(IntEnum): UINT8 = 1 UINT16 = 2 - @property - def byte_width(self) -> int: - return self.value - _SELECTOR_MODE_NAMES = {v: v.name.lower() for v in _SelectorMode} _ROUTING_DTYPE_NAMES = {v: v.name.lower() for v in _RoutingDtype} From 4e60711f303b219a5074f9b7ad38ef6ec8f667ad Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Fri, 8 May 2026 14:34:04 -0700 Subject: [PATCH 7/8] simplify 1 aspect --- eval_protocol/adapters/r3_deserializer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/eval_protocol/adapters/r3_deserializer.py b/eval_protocol/adapters/r3_deserializer.py index 05a51106..3297d3ed 100644 --- a/eval_protocol/adapters/r3_deserializer.py +++ b/eval_protocol/adapters/r3_deserializer.py @@ -26,6 +26,7 @@ MAGIC = b"R3V1" HEADER_FORMAT = "<4sBBBBIIIIQ" HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes +BITS_PER_BYTE = 8 class _SelectorMode(IntEnum): @@ -85,8 +86,8 @@ def _read_bitmap_positions( """Return sorted token indices where the bitmap bit is set.""" positions: List[int] = [] for i in range(total_token_count): - byte_idx = i >> 3 - bit_idx = i & 7 + byte_idx = i // BITS_PER_BYTE + bit_idx = i % BITS_PER_BYTE if byte_idx < len(selector_bytes) and (selector_bytes[byte_idx] >> bit_idx) & 1: positions.append(i) return positions From f1e393df155985c596fe9da2298f3ab9ec925fa1 Mon Sep 17 00:00:00 2001 From: Sandeep Singh Date: Fri, 8 May 2026 14:45:42 -0700 Subject: [PATCH 8/8] early return --- eval_protocol/adapters/r3_deserializer.py | 63 ++++++++++++----------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/eval_protocol/adapters/r3_deserializer.py b/eval_protocol/adapters/r3_deserializer.py index 3297d3ed..1e3c1a8c 100644 --- a/eval_protocol/adapters/r3_deserializer.py +++ b/eval_protocol/adapters/r3_deserializer.py @@ -14,15 +14,12 @@ from __future__ import annotations import base64 -import logging import struct from enum import IntEnum from typing import Any, Dict, List, Optional, Tuple import zstandard as zstd -logger = logging.getLogger(__name__) - MAGIC = b"R3V1" HEADER_FORMAT = "<4sBBBBIIIIQ" HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes @@ -130,27 +127,40 @@ def decompress_and_parse_r3( selector_byte_length = header["selector_byte_length"] matrix_byte_length = header["matrix_byte_length"] + metadata: Dict[str, Any] = { + "routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)), + "selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)), + "total_token_count": total_token_count, + "replayed_token_count": replayed_token_count, + "replay_start_token": replay_start_token, + } + + if replayed_token_count == 0: + return [None] * total_token_count, metadata + # Per-token matrix byte size is implicit in the payload: all replayed # tokens share the same matrix length, so we can recover it from the # matrix section total length divided by the replayed-token count. - if replayed_token_count > 0: - if matrix_byte_length % replayed_token_count != 0: - raise ValueError( - f"matrix_byte_length ({matrix_byte_length}) is not a multiple of " - f"replayed_token_count ({replayed_token_count}); cannot split " - "into per-token matrices" - ) - matrix_elem_size = matrix_byte_length // replayed_token_count - else: - matrix_elem_size = 0 + if matrix_byte_length % replayed_token_count != 0: + raise ValueError( + f"matrix_byte_length ({matrix_byte_length}) is not a multiple of " + f"replayed_token_count ({replayed_token_count}); cannot split " + "into per-token matrices" + ) + matrix_elem_size = matrix_byte_length // replayed_token_count body = raw[HEADER_SIZE:] + expected_body_length = selector_byte_length + matrix_byte_length + if len(body) < expected_body_length: + raise ValueError( + f"Payload body too short for selector and matrix sections: " + f"{len(body)} < {expected_body_length}" + ) + selector_bytes = body[:selector_byte_length] matrix_bytes = body[selector_byte_length : selector_byte_length + matrix_byte_length] - if matrix_elem_size == 0: - replayed_positions: List[int] = [] - elif selector_mode == _SelectorMode.ALL: + if selector_mode == _SelectorMode.ALL: replayed_positions = list(range(total_token_count)) elif selector_mode == _SelectorMode.SUFFIX: replayed_positions = list( @@ -161,26 +171,17 @@ def decompress_and_parse_r3( else: raise ValueError(f"Unknown selector_mode: {selector_mode}") + if len(replayed_positions) != replayed_token_count: + raise ValueError( + f"Selector produced {len(replayed_positions)} replayed positions, " + f"but header replayed_token_count is {replayed_token_count}" + ) + # Split matrix bytes into per-token chunks and base64-encode each one matrices: List[Optional[str]] = [None] * total_token_count for idx, pos in enumerate(replayed_positions): start = idx * matrix_elem_size end = start + matrix_elem_size - if end > len(matrix_bytes): - logger.warning( - "R3 matrix data truncated at token %d (position %d): " - "expected %d bytes but only %d remaining", - idx, pos, matrix_elem_size, len(matrix_bytes) - start, - ) - break matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii") - metadata: Dict[str, Any] = { - "routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)), - "selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)), - "total_token_count": total_token_count, - "replayed_token_count": replayed_token_count, - "replay_start_token": replay_start_token, - } - return matrices, metadata