@@ -95,26 +95,27 @@ def __call__(
9595 # reduce memory movement by keeping modules onloaded
9696 with disable_offloading ():
9797 # do a preliminary pass to trigger modifier hooks
98- for batch_idx in tqdm (range (len (dataloader )), desc = calib_desc ):
99- inputs = activations .fetch (batch_idx , subgraph .input_names )
98+ for b_idx in tqdm (range (len (dataloader )), desc = calib_desc ):
99+ inputs = activations .fetch (b_idx , subgraph .input_names )
100100 outputs = subgraph .forward (model , ** inputs )
101101
102102 if not dataset_args .propagate_error :
103- activations .update (batch_idx , outputs )
104- activations .delete (batch_idx , subgraph .consumed_names )
103+ activations .update (b_idx , outputs )
104+ activations .delete (b_idx , subgraph .consumed_names )
105105
106106 LifecycleCallbacks .sequential_epoch_end (subgraph )
107107
108- # this pass does not trigger modifier hooks
109- # and is only used for capturing outputs of newly compressed modules
110- with HooksMixin .disable_hooks ():
111- for batch_idx in tqdm (range (len (dataloader )), desc = prop_desc ):
112- inputs = activations .fetch (batch_idx , subgraph .input_names )
113- outputs = subgraph .forward (model , ** inputs )
114-
115- if dataset_args .propagate_error :
116- activations .update (batch_idx , outputs )
117- activations .delete (batch_idx , subgraph .consumed_names )
108+ if not dataset_args .propagate_error :
109+ # this pass does not trigger modifier hooks
110+ # and is only used for capturing outputs of compressed modules
111+ with HooksMixin .disable_hooks ():
112+ for b_idx in tqdm (range (len (dataloader )), desc = prop_desc ):
113+ inputs = activations .fetch (b_idx , subgraph .input_names )
114+ outputs = subgraph .forward (model , ** inputs )
115+
116+ if dataset_args .propagate_error :
117+ activations .update (b_idx , outputs )
118+ activations .delete (b_idx , subgraph .consumed_names )
118119
119120 # redundant, finish any remaining compression
120121 LifecycleCallbacks .calibration_epoch_end ()
0 commit comments