@@ -13,21 +13,22 @@ pub mod onnx {
1313 use dr_transform:: converter:: { BatchPredictionRequestToTorchTensorConverter , Converter } ;
1414 use itertools:: Itertools ;
1515 use log:: { debug, info} ;
16- use ort:: environment:: Environment ;
17- use ort:: session:: Session ;
18- use ort:: tensor:: InputTensor ;
19- use ort:: { ExecutionProvider , GraphOptimizationLevel , SessionBuilder } ;
16+ use dr_transform:: ort:: environment:: Environment ;
17+ use dr_transform:: ort:: session:: Session ;
18+ use dr_transform:: ort:: tensor:: InputTensor ;
19+ use dr_transform:: ort:: { ExecutionProvider , GraphOptimizationLevel , SessionBuilder } ;
20+ use dr_transform:: ort:: LoggingLevel ;
2021 use serde_json:: Value ;
2122 use std:: fmt:: { Debug , Display } ;
2223 use std:: sync:: Arc ;
2324 use std:: { fmt, fs} ;
2425 use tokio:: time:: Instant ;
25-
2626 lazy_static ! {
/// ONNX Runtime `Environment` held in an `Arc` and exposed as a public static.
///
/// The builder names the environment "onnx home", sets the log level to
/// `LoggingLevel::Error`, and passes a clone of
/// `ARGS.onnx_global_thread_pool_options` to `with_global_thread_pool`.
///
/// NOTE(review): this span is a diff hunk. The `30-` line is the removed form of
/// the `with_log_level` call; the `30+` and `31+` lines are the added form.
///
/// # Panics
///
/// The result of `build()` is unwrapped, so a failed build panics. Because this
/// is declared through `lazy_static!`, that presumably happens on first access
/// rather than at program start (confirm against the `lazy_static` docs).
2727 pub static ref ENVIRONMENT : Arc <Environment > = Arc :: new(
2828 Environment :: builder( )
2929 . with_name( "onnx home" )
30- . with_log_level( ort:: LoggingLevel :: Error )
30+ . with_log_level( LoggingLevel :: Error )
31+ . with_global_thread_pool( ARGS . onnx_global_thread_pool_options. clone( ) )
3132 . build( )
3233 . unwrap( )
3334 ) ;
@@ -101,23 +102,30 @@ pub mod onnx {
101102 let meta_info = format ! ( "{}/{}/{}" , ARGS . model_dir[ idx] , version, META_INFO ) ;
102103 let mut builder = SessionBuilder :: new ( & ENVIRONMENT ) ?
103104 . with_optimization_level ( GraphOptimizationLevel :: Level3 ) ?
104- . with_parallel_execution ( ARGS . onnx_use_parallel_mode == "true" ) ?
105- . with_inter_threads (
106- utils:: get_config_or (
107- model_config,
108- "inter_op_parallelism" ,
109- & ARGS . inter_op_parallelism [ idx] ,
110- )
111- . parse ( ) ?,
112- ) ?
113- . with_intra_threads (
114- utils:: get_config_or (
115- model_config,
116- "intra_op_parallelism" ,
117- & ARGS . intra_op_parallelism [ idx] ,
118- )
119- . parse ( ) ?,
120- ) ?
105+ . with_parallel_execution ( ARGS . onnx_use_parallel_mode == "true" ) ?;
106+ if ARGS . onnx_global_thread_pool_options . is_empty ( ) {
107+ builder = builder
108+ . with_inter_threads (
109+ utils:: get_config_or (
110+ model_config,
111+ "inter_op_parallelism" ,
112+ & ARGS . inter_op_parallelism [ idx] ,
113+ )
114+ . parse ( ) ?,
115+ ) ?
116+ . with_intra_threads (
117+ utils:: get_config_or (
118+ model_config,
119+ "intra_op_parallelism" ,
120+ & ARGS . intra_op_parallelism [ idx] ,
121+ )
122+ . parse ( ) ?,
123+ ) ?;
124+ }
125+ else {
126+ builder = builder. with_disable_per_session_threads ( ) ?;
127+ }
128+ builder = builder
121129 . with_memory_pattern ( ARGS . onnx_use_memory_pattern == "true" ) ?
122130 . with_execution_providers ( & OnnxModel :: ep_choices ( ) ) ?;
123131 match & ARGS . profiling {
0 commit comments