Skip to content

Commit ffc4415

Browse files
committed
optimization
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 864b4d0 commit ffc4415

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,26 +95,27 @@ def __call__(
9595
# reduce memory movement by keeping modules onloaded
9696
with disable_offloading():
9797
# do a preliminary pass to trigger modifier hooks
98-
for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc):
99-
inputs = activations.fetch(batch_idx, subgraph.input_names)
98+
for b_idx in tqdm(range(len(dataloader)), desc=calib_desc):
99+
inputs = activations.fetch(b_idx, subgraph.input_names)
100100
outputs = subgraph.forward(model, **inputs)
101101

102102
if not dataset_args.propagate_error:
103-
activations.update(batch_idx, outputs)
104-
activations.delete(batch_idx, subgraph.consumed_names)
103+
activations.update(b_idx, outputs)
104+
activations.delete(b_idx, subgraph.consumed_names)
105105

106106
LifecycleCallbacks.sequential_epoch_end(subgraph)
107107

108-
# this pass does not trigger modifier hooks
109-
# and is only used for capturing outputs of newly compressed modules
110-
with HooksMixin.disable_hooks():
111-
for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc):
112-
inputs = activations.fetch(batch_idx, subgraph.input_names)
113-
outputs = subgraph.forward(model, **inputs)
114-
115-
if dataset_args.propagate_error:
116-
activations.update(batch_idx, outputs)
117-
activations.delete(batch_idx, subgraph.consumed_names)
108+
if not dataset_args.propagate_error:
109+
# this pass does not trigger modifier hooks
110+
# and is only used for capturing outputs of compressed modules
111+
with HooksMixin.disable_hooks():
112+
for b_idx in tqdm(range(len(dataloader)), desc=prop_desc):
113+
inputs = activations.fetch(b_idx, subgraph.input_names)
114+
outputs = subgraph.forward(model, **inputs)
115+
116+
if dataset_args.propagate_error:
117+
activations.update(b_idx, outputs)
118+
activations.delete(b_idx, subgraph.consumed_names)
118119

119120
# redundant, finish any remaining compression
120121
LifecycleCallbacks.calibration_epoch_end()

0 commit comments

Comments
 (0)