@@ -32,6 +32,7 @@ pub struct LazyGraphExecutor {
3232 pass_index : u64 ,
3333 inplace_support : bool ,
3434 caching_enabled : bool ,
35+ shared_object_allocation_enabled : bool ,
3536}
3637
3738fn panic_cycle ( id : TensorId ) {
@@ -96,14 +97,19 @@ fn compute_post_order_from_nodes(roots: Vec<&Tensor>) -> PostOrderData {
9697}
9798
9899impl LazyGraphExecutor {
99- pub fn new ( inplace_support : bool , caching_enabled : bool ) -> Self {
100+ pub fn new (
101+ inplace_support : bool ,
102+ caching_enabled : bool ,
103+ shared_object_allocation_enabled : bool ,
104+ ) -> Self {
100105 Self {
101106 tensors : Arc :: new ( RwLock :: new ( BTreeMap :: default ( ) ) ) ,
102107 cache : HashMap :: default ( ) ,
103108 pass_index : Default :: default ( ) ,
104109 inplace_support,
105110 step_log_config : None ,
106111 caching_enabled,
112+ shared_object_allocation_enabled,
107113 }
108114 }
109115
@@ -413,15 +419,25 @@ impl LazyGraphExecutor {
413419 }
414420 }
415421
416- let mut cached_exec = if use_cache {
422+ let ( mut cached_exec, do_shared_realloc , is_shared_realloc ) = if use_cache {
417423 self . cache
418424 . remove ( & hash)
419- . and_then ( |cached_exec| Arc :: try_unwrap ( cached_exec. executable ) . ok ( ) )
425+ . map ( |cached_exec| {
426+ if cached_exec. is_shared_realloc {
427+ // Cache hit, no need to realloc, shared realloc
428+ ( Arc :: try_unwrap ( cached_exec. executable ) . ok ( ) , false , true )
429+ } else {
430+ // Cache hit, not shared realloc and needs shared realloc, not yet shared
431+ // realloc
432+ ( None , true , false )
433+ }
434+ } )
435+ // Cache miss, no need to realloc, can't be shared realloc
436+ . unwrap_or ( ( None , false , false ) )
420437 } else {
421- None
438+ // Not using cache, no need to realloc, we don't allow shared realloc
439+ ( None , false , false )
422440 } ;
423- let do_shared_realloc = false ;
424- let is_shared_realloc = false ;
425441
426442 let mut compiled_ops = Vec :: with_capacity ( post_order. len ( ) ) ;
427443
@@ -681,7 +697,8 @@ impl LazyGraphExecutor {
681697 hash,
682698 CachedExecutable {
683699 executable : Arc :: new ( executable) ,
684- is_shared_realloc : do_shared_realloc,
700+ // If we already did a shared realloc, we don't need to do it again
701+ is_shared_realloc : is_shared_realloc || do_shared_realloc,
685702 } ,
686703 ) ;
687704 }
@@ -719,6 +736,14 @@ impl LazyGraphExecutor {
719736 self . caching_enabled
720737 }
721738
739+ pub fn set_shared_object_allocation_enabled ( & mut self , enabled : bool ) {
740+ self . shared_object_allocation_enabled = enabled;
741+ }
742+
743+ pub fn shared_object_allocation_enabled ( & self ) -> bool {
744+ self . shared_object_allocation_enabled
745+ }
746+
722747 pub fn set_inplace_support ( & mut self , enabled : bool ) {
723748 self . inplace_support = enabled;
724749 }
@@ -730,7 +755,7 @@ impl LazyGraphExecutor {
730755
731756impl Default for LazyGraphExecutor {
732757 fn default ( ) -> Self {
733- Self :: new ( false , false )
758+ Self :: new ( false , false , false )
734759 }
735760}
736761
0 commit comments