@@ -2314,6 +2314,35 @@ def get_model_config(self) -> dict:
23142314 """
23152315 return self .model .config .__dict__
23162316
def get_seq_len_and_handle_specialized_prefill_model(self, prefill_seq_len: Optional[int] = None) -> int:
    """
    Prepare a specialized prefill-only model for export and return the export sequence length.

    The number of attention Q-blocks is read from the ``NUM_Q_BLOCKS`` environment variable.
    When that variable is unset, it is derived as ``prefill_seq_len // 128`` and written back
    to the environment. The method then calls ``self.prefill(True)`` and records
    ``prefill_only`` and ``num_blocks`` in ``self.hash_params``.

    Args:
        prefill_seq_len (Optional[int]): Prefill sequence length. Required (a positive
            multiple of 128) only when ``NUM_Q_BLOCKS`` is not set in the environment.

    Returns:
        int: The largest of ``NUM_Q_BLOCKS``, ``NUM_FFN_BLOCKS`` (when set and non-empty) and
        ``constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN``.

    Raises:
        ValueError: If ``NUM_Q_BLOCKS`` is unset and ``prefill_seq_len`` is ``None``, smaller
            than the block size, or not divisible by it; or if either environment variable
            holds a value that is not an integer.
    """
    env_q_blocks = os.environ.get("NUM_Q_BLOCKS")
    if env_q_blocks is None:
        block_size = 128
        if prefill_seq_len is None or prefill_seq_len < block_size or prefill_seq_len % block_size != 0:
            raise ValueError(
                f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. "
                f"Or set `NUM_Q_BLOCKS` ENV variable. "
                f"Received: prefill_seq_len={prefill_seq_len}"
            )

        num_q_blocks = prefill_seq_len // block_size
        logger.warning(
            f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override"
        )
        # os.environ only accepts str values; assigning the int directly raises TypeError.
        os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks)
    else:
        # Environment values are strings; convert once so every later comparison is numeric.
        num_q_blocks = int(env_q_blocks)

    env_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS")
    # An empty NUM_FFN_BLOCKS is treated the same as an unset one.
    min_seq_len = max(num_q_blocks, int(env_ffn_blocks)) if env_ffn_blocks else num_q_blocks

    self.prefill(True)
    self.hash_params["prefill_only"] = True
    # Read back the variable this method actually guarantees is set (NUM_Q_BLOCKS).
    self.hash_params["num_blocks"] = os.environ["NUM_Q_BLOCKS"]
    return max(min_seq_len, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN)
2345+
23172346 def export (
23182347 self ,
23192348 export_dir : Optional [str ] = None ,
@@ -2345,25 +2374,11 @@ def export(
23452374 fbs : int = constants .ONNX_EXPORT_EXAMPLE_FBS
23462375 if prefill_only :
23472376 assert not self .continuous_batching , "prefill_only=True is not supported with continuous_batching=True"
2348-
2349- if self .model .config .model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH :
2350- block_size = os .environ .get ("BLOCK_SIZE" , None )
2351- if block_size is None :
2352- block_size = 128
2353- logger .warning (
2354- "Setting BLOCK_SIZE=128 for prefill_only model, please set ENV variable `BLOCK_SIZE` to override"
2355- )
2356- if prefill_seq_len is None or prefill_seq_len % block_size != 0 :
2357- raise ValueError (
2358- f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={ block_size } . "
2359- f"Received: prefill_seq_len={ prefill_seq_len } "
2360- )
2361- os .environ ["NUM_BLOCKS" ] = str (prefill_seq_len // block_size )
2362-
2363- self .prefill (True )
2364- self .hash_params ["prefill_only" ] = True
2365- self .hash_params ["num_blocks" ] = os .environ ["NUM_BLOCKS" ]
2366- seq_len = prefill_seq_len // block_size if (prefill_seq_len // block_size ) > seq_len else seq_len
2377+ seq_len = (
2378+ self .get_seq_len_and_handle_specialized_prefill_model (prefill_seq_len )
2379+ if self .model .config .model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
2380+ else seq_len
2381+ )
23672382 else :
23682383 self .prefill (False )
23692384 self .hash_params .pop ("prefill_only" , None )
0 commit comments