@@ -295,6 +295,20 @@ def _export(
295295 self .onnx_path = onnx_path
296296 return onnx_path
297297
298+ def get_onnx_path (self , prefill_only : Optional [bool ] = False ,
299+ specializations : Optional [List [Dict [str , int ]]] = None ,
300+ offload_pt_weights : Optional [bool ] = True ):
301+ kwargs = {"offload_pt_weights" : offload_pt_weights }
302+ if prefill_only :
303+ if self .prefill_onnx_path is None :
304+ kwargs .update ({"prefill_only" : prefill_only , "prefill_seq_len" : specializations [0 ].get ("seq_len" )})
305+ self .export (** kwargs )
306+ return self .prefill_onnx_path
307+ else :
308+ if self .onnx_path is None :
309+ self .export (** kwargs )
310+ return self .onnx_path
311+
298312 @dump_qconfig
299313 def _compile (
300314 self ,
@@ -335,17 +349,7 @@ def _compile(
335349
336350 For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
337351 """
338- kwargs = {"offload_pt_weights" : offload_pt_weights }
339- if onnx_path is None and prefill_only :
340- kwargs .update ({"prefill_only" : prefill_only , "prefill_seq_len" : specializations [0 ].get ("seq_len" )})
341- self .export (** kwargs )
342- onnx_path = Path (self .prefill_onnx_path )
343- elif onnx_path is None :
344- self .export (** kwargs )
345- onnx_path = Path (self .onnx_path )
346- else :
347- onnx_path = Path (onnx_path )
348-
352+ onnx_path = Path (onnx_path if onnx_path else self .get_onnx_path (prefill_only , specializations , offload_pt_weights ))
349353 compile_dir = Path (compile_dir or onnx_path .parent )
350354 qpc_path = compile_dir / "qpc"
351355 if not onnx_path .is_file ():
0 commit comments