1414from __future__ import annotations
1515
1616import base64
17- import logging
1817import struct
1918from enum import IntEnum
2019from typing import Any , Dict , List , Optional , Tuple
2120
2221import zstandard as zstd
2322
24- logger = logging .getLogger (__name__ )
25-
2623MAGIC = b"R3V1"
2724HEADER_FORMAT = "<4sBBBBIIIIQ"
2825HEADER_SIZE = struct .calcsize (HEADER_FORMAT ) # 32 bytes
@@ -130,27 +127,40 @@ def decompress_and_parse_r3(
130127 selector_byte_length = header ["selector_byte_length" ]
131128 matrix_byte_length = header ["matrix_byte_length" ]
132129
130+ metadata : Dict [str , Any ] = {
131+ "routing_dtype" : _ROUTING_DTYPE_NAMES .get (routing_dtype , str (routing_dtype )),
132+ "selector_mode" : _SELECTOR_MODE_NAMES .get (selector_mode , str (selector_mode )),
133+ "total_token_count" : total_token_count ,
134+ "replayed_token_count" : replayed_token_count ,
135+ "replay_start_token" : replay_start_token ,
136+ }
137+
138+ if replayed_token_count == 0 :
139+ return [None ] * total_token_count , metadata
140+
133141 # Per-token matrix byte size is implicit in the payload: all replayed
134142 # tokens share the same matrix length, so we can recover it from the
135143 # matrix section total length divided by the replayed-token count.
136- if replayed_token_count > 0 :
137- if matrix_byte_length % replayed_token_count != 0 :
138- raise ValueError (
139- f"matrix_byte_length ({ matrix_byte_length } ) is not a multiple of "
140- f"replayed_token_count ({ replayed_token_count } ); cannot split "
141- "into per-token matrices"
142- )
143- matrix_elem_size = matrix_byte_length // replayed_token_count
144- else :
145- matrix_elem_size = 0
144+ if matrix_byte_length % replayed_token_count != 0 :
145+ raise ValueError (
146+ f"matrix_byte_length ({ matrix_byte_length } ) is not a multiple of "
147+ f"replayed_token_count ({ replayed_token_count } ); cannot split "
148+ "into per-token matrices"
149+ )
150+ matrix_elem_size = matrix_byte_length // replayed_token_count
146151
147152 body = raw [HEADER_SIZE :]
153+ expected_body_length = selector_byte_length + matrix_byte_length
154+ if len (body ) < expected_body_length :
155+ raise ValueError (
156+ f"Payload body too short for selector and matrix sections: "
157+ f"{ len (body )} < { expected_body_length } "
158+ )
159+
148160 selector_bytes = body [:selector_byte_length ]
149161 matrix_bytes = body [selector_byte_length : selector_byte_length + matrix_byte_length ]
150162
151- if matrix_elem_size == 0 :
152- replayed_positions : List [int ] = []
153- elif selector_mode == _SelectorMode .ALL :
163+ if selector_mode == _SelectorMode .ALL :
154164 replayed_positions = list (range (total_token_count ))
155165 elif selector_mode == _SelectorMode .SUFFIX :
156166 replayed_positions = list (
@@ -161,26 +171,17 @@ def decompress_and_parse_r3(
161171 else :
162172 raise ValueError (f"Unknown selector_mode: { selector_mode } " )
163173
174+ if len (replayed_positions ) != replayed_token_count :
175+ raise ValueError (
176+ f"Selector produced { len (replayed_positions )} replayed positions, "
177+ f"but header replayed_token_count is { replayed_token_count } "
178+ )
179+
164180 # Split matrix bytes into per-token chunks and base64-encode each one
165181 matrices : List [Optional [str ]] = [None ] * total_token_count
166182 for idx , pos in enumerate (replayed_positions ):
167183 start = idx * matrix_elem_size
168184 end = start + matrix_elem_size
169- if end > len (matrix_bytes ):
170- logger .warning (
171- "R3 matrix data truncated at token %d (position %d): "
172- "expected %d bytes but only %d remaining" ,
173- idx , pos , matrix_elem_size , len (matrix_bytes ) - start ,
174- )
175- break
176185 matrices [pos ] = base64 .b64encode (matrix_bytes [start :end ]).decode ("ascii" )
177186
178- metadata : Dict [str , Any ] = {
179- "routing_dtype" : _ROUTING_DTYPE_NAMES .get (routing_dtype , str (routing_dtype )),
180- "selector_mode" : _SELECTOR_MODE_NAMES .get (selector_mode , str (selector_mode )),
181- "total_token_count" : total_token_count ,
182- "replayed_token_count" : replayed_token_count ,
183- "replay_start_token" : replay_start_token ,
184- }
185-
186187 return matrices , metadata
0 commit comments