Skip to content

Commit f1e393d

Browse files
early return
1 parent 4e60711 commit f1e393d

1 file changed

Lines changed: 32 additions & 31 deletions

File tree

eval_protocol/adapters/r3_deserializer.py

Lines changed: 32 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -14,15 +14,12 @@
1414
from __future__ import annotations
1515

1616
import base64
17-
import logging
1817
import struct
1918
from enum import IntEnum
2019
from typing import Any, Dict, List, Optional, Tuple
2120

2221
import zstandard as zstd
2322

24-
logger = logging.getLogger(__name__)
25-
2623
MAGIC = b"R3V1"
2724
HEADER_FORMAT = "<4sBBBBIIIIQ"
2825
HEADER_SIZE = struct.calcsize(HEADER_FORMAT) # 32 bytes
@@ -130,27 +127,40 @@ def decompress_and_parse_r3(
130127
selector_byte_length = header["selector_byte_length"]
131128
matrix_byte_length = header["matrix_byte_length"]
132129

130+
metadata: Dict[str, Any] = {
131+
"routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)),
132+
"selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)),
133+
"total_token_count": total_token_count,
134+
"replayed_token_count": replayed_token_count,
135+
"replay_start_token": replay_start_token,
136+
}
137+
138+
if replayed_token_count == 0:
139+
return [None] * total_token_count, metadata
140+
133141
# Per-token matrix byte size is implicit in the payload: all replayed
134142
# tokens share the same matrix length, so we can recover it from the
135143
# matrix section total length divided by the replayed-token count.
136-
if replayed_token_count > 0:
137-
if matrix_byte_length % replayed_token_count != 0:
138-
raise ValueError(
139-
f"matrix_byte_length ({matrix_byte_length}) is not a multiple of "
140-
f"replayed_token_count ({replayed_token_count}); cannot split "
141-
"into per-token matrices"
142-
)
143-
matrix_elem_size = matrix_byte_length // replayed_token_count
144-
else:
145-
matrix_elem_size = 0
144+
if matrix_byte_length % replayed_token_count != 0:
145+
raise ValueError(
146+
f"matrix_byte_length ({matrix_byte_length}) is not a multiple of "
147+
f"replayed_token_count ({replayed_token_count}); cannot split "
148+
"into per-token matrices"
149+
)
150+
matrix_elem_size = matrix_byte_length // replayed_token_count
146151

147152
body = raw[HEADER_SIZE:]
153+
expected_body_length = selector_byte_length + matrix_byte_length
154+
if len(body) < expected_body_length:
155+
raise ValueError(
156+
f"Payload body too short for selector and matrix sections: "
157+
f"{len(body)} < {expected_body_length}"
158+
)
159+
148160
selector_bytes = body[:selector_byte_length]
149161
matrix_bytes = body[selector_byte_length : selector_byte_length + matrix_byte_length]
150162

151-
if matrix_elem_size == 0:
152-
replayed_positions: List[int] = []
153-
elif selector_mode == _SelectorMode.ALL:
163+
if selector_mode == _SelectorMode.ALL:
154164
replayed_positions = list(range(total_token_count))
155165
elif selector_mode == _SelectorMode.SUFFIX:
156166
replayed_positions = list(
@@ -161,26 +171,17 @@ def decompress_and_parse_r3(
161171
else:
162172
raise ValueError(f"Unknown selector_mode: {selector_mode}")
163173

174+
if len(replayed_positions) != replayed_token_count:
175+
raise ValueError(
176+
f"Selector produced {len(replayed_positions)} replayed positions, "
177+
f"but header replayed_token_count is {replayed_token_count}"
178+
)
179+
164180
# Split matrix bytes into per-token chunks and base64-encode each one
165181
matrices: List[Optional[str]] = [None] * total_token_count
166182
for idx, pos in enumerate(replayed_positions):
167183
start = idx * matrix_elem_size
168184
end = start + matrix_elem_size
169-
if end > len(matrix_bytes):
170-
logger.warning(
171-
"R3 matrix data truncated at token %d (position %d): "
172-
"expected %d bytes but only %d remaining",
173-
idx, pos, matrix_elem_size, len(matrix_bytes) - start,
174-
)
175-
break
176185
matrices[pos] = base64.b64encode(matrix_bytes[start:end]).decode("ascii")
177186

178-
metadata: Dict[str, Any] = {
179-
"routing_dtype": _ROUTING_DTYPE_NAMES.get(routing_dtype, str(routing_dtype)),
180-
"selector_mode": _SELECTOR_MODE_NAMES.get(selector_mode, str(selector_mode)),
181-
"total_token_count": total_token_count,
182-
"replayed_token_count": replayed_token_count,
183-
"replay_start_token": replay_start_token,
184-
}
185-
186187
return matrices, metadata

0 commit comments

Comments
 (0)