Skip to content

Commit 0872ba7

Browse files
fix: do not construct IntEnum for unknown dtype/selector_mode
_RoutingDtype(int) and _SelectorMode(int) raise ValueError for any value not in the enum, so the .get() fallback was unreachable: a future routing_dtype=3 in the header would crash metadata construction before str(int) could run. Look up names by raw int instead — IntEnum keys hash-equal their int values, so known modes resolve to their lowercase name and unknown ones fall back to str(int) without ever constructing the enum. Adds a regression test exercising routing_dtype=99. Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 1faadd5 commit 0872ba7

2 files changed

Lines changed: 16 additions & 6 deletions

File tree

eval_protocol/adapters/r3_deserializer.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,8 @@ def decompress_and_parse_r3(
179179
matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii")
180180

181181
metadata: Dict[str, Any] = {
182-
"routing_dtype": _ROUTING_DTYPE_NAMES.get(
183-
_RoutingDtype(routing_dtype), str(routing_dtype)
184-
),
185-
"selector_mode": _SELECTOR_MODE_NAMES.get(
186-
_SelectorMode(selector_mode), str(selector_mode)
187-
),
182+
"routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)),
183+
"selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)),
188184
"total_token_count": total_token_count,
189185
"replayed_token_count": replayed_token_count,
190186
"replay_start_token": replay_start_token,

tests/adapters/test_r3_deserializer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,20 @@ def test_zero_replayed_tokens(self):
266266
assert all(m is None for m in matrices)
267267
assert metadata["replayed_token_count"] == 0
268268

269+
def test_unknown_routing_dtype_falls_back_to_str(self):
270+
"""Unknown routing_dtype ints (e.g. a future dtype=3) must not crash
271+
metadata construction; the dtype is surfaced as its string repr."""
272+
raw = _make_raw_r3(
273+
routing_dtype=99, # not in _RoutingDtype
274+
total_token_count=2,
275+
replayed_token_count=2,
276+
matrix_data=b"\x00" * 8,
277+
)
278+
blob = _compress_and_b64(raw)
279+
280+
_, metadata = decompress_and_parse_r3(blob)
281+
assert metadata["routing_dtype"] == "99"
282+
269283
def test_high_compression_ratio_payload(self):
270284
"""Highly compressible payloads (e.g. tokens routing to the same
271285
experts) can compress much better than 20:1; the deserializer must

0 commit comments

Comments
 (0)