Skip to content

Commit 864b4d0

Browse files
committed
implement propagate_error
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 1c85a66 commit 864b4d0

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

src/llmcompressor/args/dataset_arguments.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,19 @@ class DatasetArguments(CustomDatasetArguments):
223223
quantization_aware_calibration: bool = field(
224224
default=True,
225225
metadata={
226-
"help": "Whether to enable quantization-aware calibration in the pipeline. "
227-
"When True, quantization is applied during forward pass in calibration. "
228-
"When False, quantization is disabled during forward pass in calibration. "
229-
"Default is set to True."
226+
"help": "Only relevant for the sequential pipeline. "
227+
"If True, quantization is applied during forward pass in calibration. "
228+
"If False, quantization is disabled during forward pass in calibration. "
229+
"Default is True."
230+
},
231+
)
232+
propagate_error: bool = field(
233+
default=True,
234+
metadata={
235+
"help": "Only relevant for the sequential pipeline. If True, use quantized "
236+
"layer outputs as the inputs to the next sequential layer. If False, use "
237+
"unquantized layer outputs as the inputs to the next sequential layer. "
238+
"Default is True"
230239
},
231240
)
232241

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ def __call__(
9797
# do a preliminary pass to trigger modifier hooks
9898
for batch_idx in tqdm(range(len(dataloader)), desc=calib_desc):
9999
inputs = activations.fetch(batch_idx, subgraph.input_names)
100-
subgraph.forward(model, **inputs)
100+
outputs = subgraph.forward(model, **inputs)
101+
102+
if not dataset_args.propagate_error:
103+
activations.update(batch_idx, outputs)
104+
activations.delete(batch_idx, subgraph.consumed_names)
101105

102106
LifecycleCallbacks.sequential_epoch_end(subgraph)
103107

@@ -106,10 +110,10 @@ def __call__(
106110
with HooksMixin.disable_hooks():
107111
for batch_idx in tqdm(range(len(dataloader)), desc=prop_desc):
108112
inputs = activations.fetch(batch_idx, subgraph.input_names)
109-
output = subgraph.forward(model, **inputs)
113+
outputs = subgraph.forward(model, **inputs)
110114

111-
if subgraph_index < num_subgraphs - 1:
112-
activations.update(batch_idx, output)
115+
if dataset_args.propagate_error:
116+
activations.update(batch_idx, outputs)
113117
activations.delete(batch_idx, subgraph.consumed_names)
114118

115119
# redundant, finish any remaining compression

0 commit comments

Comments
 (0)