<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Quantization Explained with PyTorch</title>
<style>
body {
font-family: sans-serif;
max-width: 900px;
margin: 0 auto;
padding: 2rem;
line-height: 1.6;
color: #333;
}
h1, h2, h3, h4 {
color: #2c3e50; /* Darker shade for better contrast */
margin-top: 1.5em;
margin-bottom: 0.5em;
}
h1 { font-size: 2.5em; border-bottom: 2px solid #3498db; padding-bottom: 0.3em;}
h2 { font-size: 2em; border-bottom: 1px solid #bdc3c7; padding-bottom: 0.2em;}
h3 { font-size: 1.5em; }
h4 { font-size: 1.2em; color: #555;}
nav { margin-bottom: 30px; padding: 10px; background: #ecf0f1; border: 1px solid #bdc3c7; border-radius: 4px;}
nav ul { list-style: none; padding: 0; }
nav li { display: inline-block; margin-right: 15px; }
nav a { text-decoration: none; color: #3498db; font-weight: bold;}
nav a:hover { text-decoration: underline; color: #2980b9;}
pre {
background: #f8f9f9; /* Lighter background for code blocks */
padding: 1rem;
overflow-x: auto;
border: 1px solid #e1e4e8; /* Softer border */
border-left: 4px solid #3498db; /* Accent border */
border-radius: 4px;
font-size: 0.9em;
}
code {
font-family: 'SFMono-Regular', Consolas, 'Liberation Mono', Menlo, Courier, monospace;
}
/* For inline code */
p > code, li > code, table td > code {
background: #e8eaed;
padding: 0.2em 0.4em;
border-radius: 3px;
font-size: 0.85em;
}
pre code { /* Reset for code inside pre, already handled by pre styling */
background: none;
padding: 0;
font-size: 1em; /* Ensure pre's font size is inherited */
}
ul {
padding-left: 20px;
}
li {
margin-bottom: 0.5em;
}
strong {
color: #2980b9;
}
</style>
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/styles/monokai.min.css">
<script src="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/11.9.0/highlight.min.js"></script>
<script>hljs.highlightAll();</script>
</head>
<body>
<nav>
<ul>
<li><a href="index.html">Home</a></li>
<li><a href="https://github.com/raminaeye/ML-Concepts/blob/main/notebooks/quantization/qunatization.ipynb">Quantization-Notebook</a></li>
</ul>
</nav>
<h1>Quantization</h1>
<p>Why quantize? A smaller model needs less memory and less memory bandwidth. Inference gets faster because each value takes fewer bits to move and int8 arithmetic is cheaper than float32 on most hardware, which also lowers power draw. You get the idea.</p>
<p>There are two main types of quantization: <strong>Post-Training Quantization (PTQ)</strong> and <strong>Quantization-Aware Training (QAT)</strong>. PTQ comes in two flavors: <strong>Dynamic</strong> and <strong>Static</strong>. Let’s break them down.</p>
<section>
<h2>Dynamic Quantization</h2>
<p>This is the lazy version of quantization, and that’s not a bad thing. Super simple. The weights are converted to int8 once, after training. Activations stay in float32 between layers and are quantized on the fly, right before each quantized matmul. No calibration. No fuss.</p>
<pre><code class="language-python">
import torch

model = YourModel()  # placeholder: any trained float model with nn.Linear layers
model.eval()
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
</code></pre>
<p>In this setup:</p>
<ul>
<li>Weights → int8</li>
<li>Activations → float32</li>
<li>Quantization of activations happens at runtime, per batch</li>
<li>Dequantization happens immediately after matmul</li>
</ul>
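<p>The four bullets above can be traced by hand on a single linear layer. The sketch below is plain Python with no PyTorch, and it assumes symmetric int8 for both weights and activations to keep the arithmetic short; PyTorch's own kernels may choose a different activation scheme. The names <code>symmetric_scale</code>, <code>quantize</code>, and <code>dynamic_linear</code> are made up for this illustration.</p>

```python
# Plain-Python sketch of dynamic quantization for one linear layer (no PyTorch).

def symmetric_scale(values, qmax=127):
    """Scale that maps the largest magnitude in `values` onto qmax."""
    largest = max(abs(v) for v in values)
    return largest / qmax if largest > 0 else 1.0

def quantize(values, scale, qmin=-127, qmax=127):
    """Round to the nearest integer step and clamp to the int8 range."""
    return [max(qmin, min(qmax, round(v / scale))) for v in values]

# Done once, after training: weights go to int8.
weights = [[0.5, -1.0, 0.25], [2.0, 0.1, -0.4]]   # 2 outputs, 3 inputs
w_scale = symmetric_scale([w for row in weights for w in row])
w_int8 = [quantize(row, w_scale) for row in weights]

def dynamic_linear(x):
    """Quantize the activation on the fly, do an integer matmul, dequantize."""
    x_scale = symmetric_scale(x)                   # range measured per call
    x_int8 = quantize(x, x_scale)
    acc = [sum(w * a for w, a in zip(row, x_int8)) for row in w_int8]  # integer accumulate
    return [a * w_scale * x_scale for a in acc]    # back to float right after the matmul

x = [1.0, -2.0, 0.5]
float_out = [sum(w * a for w, a in zip(row, x)) for row in weights]
quant_out = dynamic_linear(x)
print(float_out, quant_out)
```

<p>The two printed vectors should agree to within a few hundredths, which is the rounding noise introduced by the two int8 conversions.</p>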
<p>So why use it? It’s stupid easy and gives you a fast win on CPUs—especially for LSTMs and Linear-heavy models.</p>
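<p>But why doesn’t it work well for CNNs? Because the win comes from weight-heavy layers. Linear and LSTM layers spend most of their time loading big weight matrices, so int8 weights pay off even when activations are quantized on the fly. Conv layers reuse small kernels over large activation maps, have gnarly memory access patterns, and need pre-quantized activations to be efficient. Dynamic quantization doesn’t do that. It would have to measure and quantize those activation maps on every inference. Not fun. Not fast.</p>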
<p>Also, dynamic quantization doesn’t touch things like BatchNorm or ReLU. It just swaps in dynamic versions of Linear and LSTM layers. So, no fusion here. No point fusing layers if you're not touching them.</p>
</section>
<section>
<h2>Static Quantization</h2>
<p>Now we’re doing it properly. Both weights and activations go to int8. But to quantize activations, you first need to know their range. That’s where <strong>calibration</strong> comes in. You run some input data through the model (no labels needed) to collect min and max values of activations. PyTorch uses that to decide how to quantize.</p>
<p>This works great for CNNs and edge deployment. You get full int8 inference.</p>
<p>The workflow looks like this:</p>
<pre><code class="language-python">
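import torch
import torch.nn as nn

class ConvBNReLU(nn.Module):  # illustrative block; your layer names may differ
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

class QuantCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float to int8 at the input
        self.conv_bn_relu1 = ConvBNReLU(3, 16)
        self.dequant = torch.quantization.DeQuantStub()  # int8 to float at the output

    def forward(self, x):
        return self.dequant(self.conv_bn_relu1(self.quant(x)))

    def fuse_model(self):
        # Fuse every child block that follows the conv, bn, relu naming pattern
        for module_name, module in self.named_children():
            if "conv_bn_relu" in module_name:
                torch.quantization.fuse_modules(module, ['conv', 'bn', 'relu'], inplace=True)

model = QuantCNN()  # Instantiate your model
model.eval()
model.fuse_model()  # Apply fusion if defined in your model
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # or 'qnnpack' for ARM
torch.quantization.prepare(model, inplace=True)  # Insert observers
# Calibration step
with torch.no_grad():
    for images in calibration_loader:  # calibration_loader should yield batches of input data
        model(images)
torch.quantization.convert(model, inplace=True)  # Convert to int8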
</code></pre>
<p>Let’s talk about what’s going on.</p>
<h3>Fusion</h3>
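<p>Before quantizing, you fuse Conv + BatchNorm + ReLU into a single op. Why? At inference time BatchNorm is just a per-channel scale and shift, so it can be folded into the conv weights and bias. The fused op skips writing intermediate tensors to memory, which speeds things up, and it quantizes once at the end instead of after every sub-op, which minimizes numerical errors.</p>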
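<p>The folding arithmetic is short enough to check by hand. Below is a minimal plain-Python sketch for one output channel with a single weight; <code>fold_batchnorm</code>, <code>unfused</code>, and <code>fused</code> are illustrative names, and the numbers are arbitrary.</p>

```python
import math

def fold_batchnorm(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold y = gamma * (conv - mean) / sqrt(var + eps) + beta into the conv's w and b."""
    k = gamma / math.sqrt(var + eps)
    return w * k, (b - mean) * k + beta

# One output channel with a single weight, for readability.
w, b = 0.8, 0.1
gamma, beta, mean, var = 1.5, -0.2, 0.3, 4.0

def unfused(x):
    conv = w * x + b
    bn = gamma * (conv - mean) / math.sqrt(var + 1e-5) + beta
    return max(bn, 0.0)                    # ReLU

w_f, b_f = fold_batchnorm(w, b, gamma, beta, mean, var)

def fused(x):
    return max(w_f * x + b_f, 0.0)         # one op: conv with folded BN, then ReLU

inputs = [-2.0, -0.5, 0.0, 1.0, 3.0]
max_diff = max(abs(unfused(x) - fused(x)) for x in inputs)
print(max_diff)
```

<p>The difference is floating-point noise only, which is why fusion is safe to apply to a trained model in eval mode.</p>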
<h3>QConfig</h3>
<p>QConfig tells PyTorch how to quantize: symmetric vs asymmetric, per-tensor vs per-channel. Example:</p>
<pre><code class="language-python">
import torch
from torch.quantization import default_per_channel_weight_observer, MinMaxObserver

my_qconfig = torch.quantization.QConfig(
    activation=MinMaxObserver.with_args(dtype=torch.quint8),
    weight=default_per_channel_weight_observer.with_args(dtype=torch.qint8)  # Specify dtype for weights
)
model.qconfig = my_qconfig  # Then assign it to your model
</code></pre>
<h3>Observers</h3>
<p>Observers are how PyTorch measures activation ranges. Here's a basic one (conceptual, PyTorch's <code>ObserverBase</code> is more complex):</p>
<pre><code class="language-python">
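import torch
from torch.ao.quantization.observer import ObserverBase

class MinMaxObserverExample(ObserverBase):  # Conceptual
    def __init__(self, dtype=torch.quint8):
        super().__init__(dtype=dtype)
        self.min_val = float('inf')
        self.max_val = float('-inf')

    def forward(self, x):
        # Track the running extremes and pass the tensor through unchanged.
        self.min_val = min(self.min_val, x.min().item())
        self.max_val = max(self.max_val, x.max().item())
        return x

    def calculate_qparams(self):
        # ObserverBase is abstract, so this must exist; asymmetric quint8 onto [0, 255].
        lo, hi = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = max((hi - lo) / 255.0, 1e-8)
        return torch.tensor([scale]), torch.tensor([int(round(-lo / scale))])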
</code></pre>
<p>During calibration, these observers record activation stats. Then PyTorch figures out the scale and zero-point needed to map float values to int8.</p>
<p>But MinMax isn’t always enough.</p>
</section>
<section>
<h2>Histogram Observers</h2>
<p>MinMax gets clobbered by outliers. Histogram observers are smarter. They build a histogram of the activations, simulate multiple clipping thresholds, and pick the one that gives the smallest quantization error.</p>
<pre><code class="language-python">
import torch  # needed for torch.quint8 etc.; `import torch.quantization as tq` alone only binds `tq`
import torch.quantization as tq

qconfig_hist = tq.QConfig(
    activation=tq.observer.HistogramObserver.with_args(
        dtype=torch.quint8, qscheme=torch.per_tensor_affine
    ),
    weight=tq.default_per_channel_weight_observer.with_args(dtype=torch.qint8)
)
model.qconfig = qconfig_hist
</code></pre>
<p>They look at the L2 error between the original and the quantized-dequantized tensor and find the best range for preserving accuracy.</p>
<p>This is where quantization actually starts acting intelligent. Especially important when your activations have long tails, skew, or multimodal distributions.</p>
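<p>To see why searching over clipping thresholds helps, here is a toy version of the idea in plain Python. It is not how <code>HistogramObserver</code> is implemented: it works on raw values instead of histogram bins, tries only a handful of hand-picked thresholds, and uses 16 levels (4-bit) so the effect is visible on a small dataset. All function names are made up for the sketch.</p>

```python
def quant_dequant(values, lo, hi, levels=16):
    """Asymmetric quantize-dequantize onto `levels` steps spanning [lo, hi]."""
    scale = (hi - lo) / (levels - 1)
    out = []
    for v in values:
        q = round((min(max(v, lo), hi) - lo) / scale)   # clip, then snap to a step
        out.append(lo + q * scale)
    return out

def l2_error(values, lo, hi):
    """Sum of squared differences between the original and round-tripped values."""
    return sum((v - r) ** 2 for v, r in zip(values, quant_dequant(values, lo, hi)))

# 10,000 well-behaved activations in [0, 1) plus one outlier.
activations = [i / 10000 for i in range(10000)] + [8.0]

# Min-max range: the outlier stretches the range, so every step is coarse.
minmax_err = l2_error(activations, min(activations), max(activations))

# Histogram-style search: try several upper clipping thresholds, keep the best.
candidates = [max(activations) * f for f in (1.0, 0.75, 0.5, 0.25, 0.125)]
best_hi = min(candidates, key=lambda hi: l2_error(activations, 0.0, hi))
best_err = l2_error(activations, 0.0, best_hi)
print(minmax_err, best_hi, best_err)
```

<p>Clipping the single outlier costs some error on that one value but buys much finer steps for everything else, so the total L2 error drops.</p>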
</section>
<section>
<h2>Quick Refresher: Scale and Zero-Point</h2>
<p>When you quantize a float:</p>
<pre><code>Q = clamp(round(X / scale) + zero_point, qmin, qmax)</code></pre>
<p>And to dequantize:</p>
<pre><code>X = scale * (Q - zero_point)</code></pre>
<h3>Symmetric</h3>
<ul>
<li><code>zero_point = 0</code></li>
<li><code>scale = max(abs(min_val), abs(max_val)) / 127</code> (for int8, typically range [-127, 127] or [-128, 127])</li>
</ul>
<p>Used for weights. Assumes distribution is centered around 0 (which is usually true for weights).</p>
<h3>Asymmetric</h3>
<ul>
<li><code>zero_point ≠ 0</code></li>
<li><code>scale = (x_max - x_min) / 255</code> (for quint8, typically range [0, 255])</li>
<li><code>zero_point = round(qmin - x_min / scale)</code> (adjusted to be within [qmin, qmax])</li>
</ul>
<p>Used for activations. Better when data is skewed (e.g., ReLU output).</p>
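<p>Here is a small worked example of both schemes in plain Python, using the formulas from this section. The helper names are illustrative, and widening the asymmetric range to include 0.0 is an assumption of this sketch (it keeps zero exactly representable).</p>

```python
def symmetric_qparams(min_val, max_val, qmax=127):
    """Symmetric scheme: zero_point is fixed at 0."""
    scale = max(abs(min_val), abs(max_val)) / qmax
    return scale, 0

def asymmetric_qparams(min_val, max_val, qmin=0, qmax=255):
    """Asymmetric scheme: the range is widened to include 0.0 (sketch assumption)."""
    min_val, max_val = min(min_val, 0.0), max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    zero_point = round(qmin - min_val / scale)
    return scale, max(qmin, min(qmax, zero_point))

def quantize(x, scale, zero_point, qmin, qmax):
    return max(qmin, min(qmax, round(x / scale) + zero_point))

def dequantize(q, scale, zero_point):
    return scale * (q - zero_point)

# Weights centred on zero: symmetric int8.
w_scale, w_zp = symmetric_qparams(-0.9, 1.27)
w_q = quantize(0.5, w_scale, w_zp, -127, 127)

# Skewed activations: asymmetric quint8.
a_scale, a_zp = asymmetric_qparams(-1.0, 4.1)
a_q = quantize(2.0, a_scale, a_zp, 0, 255)

print(w_scale, w_zp, w_q)
print(a_scale, a_zp, a_q, dequantize(a_q, a_scale, a_zp))
```

<p>With these inputs the weight scale is about 0.01 and the activation scale about 0.02, so 0.5 lands on integer 50 and 2.0 lands on 150 with a zero-point of 50.</p>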
</section>
<section>
<h2>Calibration Matters</h2>
<p>Bad calibration = bad quantization.</p>
<p>If your calibration dataset is too small or unrepresentative, the model sees weird ranges and either clips too much or wastes resolution. The more realistic your calibration inputs, the better your scale and zero-point.</p>
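<p>Both failure modes are easy to reproduce with a few numbers. The plain-Python sketch below calibrates an asymmetric quint8 range three ways and measures the worst round-trip error on the same runtime values; <code>calibrate</code> and <code>roundtrip</code> are illustrative helpers and the data is invented.</p>

```python
def calibrate(samples, qmin=0, qmax=255):
    """Min-max calibration for asymmetric quint8; the range always includes 0.0."""
    lo, hi = min(min(samples), 0.0), max(max(samples), 0.0)
    scale = (hi - lo) / (qmax - qmin)
    zero_point = round(qmin - lo / scale)
    return scale, zero_point

def roundtrip(x, scale, zero_point, qmin=0, qmax=255):
    """Quantize then dequantize a single value."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return scale * (q - zero_point)

runtime = [0.0, 0.5, 1.5, 3.0, 6.0]        # values seen in production

narrow = calibrate([0.1, 0.4, 0.9])        # never saw anything above 0.9: clips
wide = calibrate([0.0, 600.0])             # one absurd sample: wastes resolution
good = calibrate([0.0, 1.2, 3.3, 6.0])     # covers the runtime range

narrow_err = max(abs(x - roundtrip(x, *narrow)) for x in runtime)
wide_err = max(abs(x - roundtrip(x, *wide)) for x in runtime)
good_err = max(abs(x - roundtrip(x, *good)) for x in runtime)
print(narrow_err, wide_err, good_err)
```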
<p>Also, remember: PTQ is a one-shot deal. Once you calibrate and convert, you don’t get to fine-tune anymore. It’s set in stone.</p>
</section>
<section>
<h2>Be Smart About Quant Boundaries</h2>
<p>Quantizing just part of a model? You’d better manage the float-int transitions carefully. These quant-dequant boundaries are expensive. You’re moving data from int8 to float and back. It kills performance and breaks layer fusion.</p>
<p>Once you quantize, stay quantized as long as possible. You can safely quantize ReLU, pooling, and many elementwise ops.</p>
</section>
<section>
<h2>Visualize with FX</h2>
<p>Use FX to trace the model into a graph that you can inspect and manipulate.</p>
<pre><code class="language-python">
from torch.fx import symbolic_trace

traced = symbolic_trace(quantized_model)  # returns a torch.fx.GraphModule
print(traced.graph)  # For a textual representation
traced.graph.print_tabular()  # For a more readable table (needs the `tabulate` package)
</code></pre>
<p>FX makes it easier to:</p>
<ul>
<li>Insert observers</li>
<li>Fuse layers</li>
<li>Quantize cleanly</li>
<li>Inspect the graph</li>
</ul>
<p>It’s the modern way to do quantization in PyTorch. Forget the old Eager mode unless you like pain.</p>
<pre><code class="language-python">
import torch # Ensure torch is imported for torch.quantization
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
</code></pre>
<p>FX takes care of applying qconfig recursively, inserting only the observers you need, and fusing layers correctly.</p>
</section>
<section>
<h2>PTQ Example: CNN with Manual Observers</h2>
<p>Let’s do an example with a CNN model to wrap up PTQ. We manually add <code>PerChannelMinMaxObserver</code> for weights and <code>MinMaxObserver</code> for activation.</p>
<pre><code class="language-python">
import torch
import torch.nn as nn
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver

class ObserverCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(32, 10)
        # Manual observers
        self.obs_act1 = MinMaxObserver(dtype=torch.quint8)
        self.obs_act2 = MinMaxObserver(dtype=torch.quint8)
        self.obs_fc_in = MinMaxObserver(dtype=torch.quint8)
        self.obs_weight1 = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8)
        self.obs_weight2 = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8)
        self.obs_weight_fc = PerChannelMinMaxObserver(ch_axis=0, dtype=torch.qint8)  # For Linear weights

    def forward(self, x):
        # Observe conv1 weight
        self.obs_weight1(self.conv1.weight)
        x = self.conv1(x)  # Apply conv1
        x = self.relu1(x)  # Apply relu1
        self.obs_act1(x)  # Observe output of relu1
        # Observe conv2 weight
        self.obs_weight2(self.conv2.weight)
        x = self.conv2(x)  # Apply conv2
        x = self.relu2(x)  # Apply relu2
        self.obs_act2(x)  # Observe output of relu2
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        self.obs_weight_fc(self.fc.weight)  # Observe fc weight
        self.obs_fc_in(x)  # Observe input to fc layer
        x = self.fc(x)  # Apply fc
        return x

model_ptq_example = ObserverCNN().eval()
# Fake calibration pass
with torch.no_grad():
    for _ in range(10):  # simulate 10 batches
        dummy_input = torch.randn(8, 3, 32, 32)
        _ = model_ptq_example(dummy_input)

print("Conv1 activation range:", model_ptq_example.obs_act1.min_val.item(), model_ptq_example.obs_act1.max_val.item())
print("Conv2 activation range:", model_ptq_example.obs_act2.min_val.item(), model_ptq_example.obs_act2.max_val.item())
print("FC input range:", model_ptq_example.obs_fc_in.min_val.item(), model_ptq_example.obs_fc_in.max_val.item())
print("Conv1 weight qparams:", model_ptq_example.obs_weight1.calculate_qparams())
print("Conv2 weight qparams:", model_ptq_example.obs_weight2.calculate_qparams())
print("FC weight qparams:", model_ptq_example.obs_weight_fc.calculate_qparams())
</code></pre>
<p><code>calculate_qparams()</code> is a method used by observers in PyTorch to compute scale and zero-point. That's the observer's job to begin with.</p>
<p>General quantization formula: <code>quantized = clamp(round(x / scale) + zero_point, qmin, qmax)</code></p>
<h4>For Symmetric Quantization:</h4>
<ul>
<li><code>scale = max(abs(min_val), abs(max_val)) / 127</code> (or 128)</li>
<li><code>zero_point = 0</code></li>
</ul>
<h4>For Asymmetric Quantization:</h4>
<ul>
<li><code>scale = (max_val - min_val) / (qmax - qmin)</code></li>
<li><code>zero_point = round(qmin - min_val / scale)</code></li>
</ul>
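<p>The example above observes weights per channel (<code>ch_axis=0</code>) rather than per tensor. A quick plain-Python sketch shows what that buys when output channels live on very different scales; the weight values and helper names are invented for the illustration.</p>

```python
def symmetric_scale(values, qmax=127):
    """Symmetric scale that maps the largest magnitude onto qmax."""
    return max(abs(v) for v in values) / qmax

def roundtrip(v, scale, qmax=127):
    """Quantize to int8 and immediately dequantize."""
    q = max(-qmax, min(qmax, round(v / scale)))
    return q * scale

# Two output channels whose weights live on very different scales.
weight = [
    [0.010, -0.020, 0.015],   # channel 0: tiny weights
    [1.500, -2.000, 0.800],   # channel 1: large weights
]

# Per-tensor: one scale for everything, dominated by channel 1.
tensor_scale = symmetric_scale([w for row in weight for w in row])
per_tensor_err = max(abs(w - roundtrip(w, tensor_scale)) for w in weight[0])

# Per-channel (ch_axis=0): one scale per output channel.
channel_scales = [symmetric_scale(row) for row in weight]
per_channel_err = max(abs(w - roundtrip(w, channel_scales[0])) for w in weight[0])

print(per_tensor_err, per_channel_err)
```

<p>With a shared scale, the tiny channel collapses onto one or two integer levels; with its own scale it uses the full int8 range.</p>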
<p>Coming up next: Quantization Aware Training, where we teach the model to embrace the pain and learn to live with it.</p>
</section>
<section>
<h2>Quantization-Aware Training (QAT)</h2>
<p>If PTQ drops too much accuracy, QAT is how you can recover it. QAT requires retraining, so it’s slower, but it is usually more accurate. The model can learn to suppress outliers, and it works well with fine-tuning. You start by inserting observers just like in PTQ, except that in QAT they are wrapped in <code>FakeQuantize</code> modules. These modules are built so that gradients can flow through them, so the model learns to adapt to quantization noise.</p>
<pre><code class="language-python">
import torch
from torch.ao.quantization import get_default_qat_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

qat_qconfig = get_default_qat_qconfig('fbgemm')
qconfig_dict_qat = {"": qat_qconfig}
model_for_qat = ObserverCNN()  # Assuming ObserverCNN or similar model
model_for_qat.train()  # Set model to training mode for QAT
# Note: prepare_qat_fx typically requires the model to be in training mode.
# Recent PyTorch releases also require example inputs; the shape here is an assumption.
example_inputs = (torch.randn(1, 3, 32, 32),)
model_prepared_qat = prepare_qat_fx(model_for_qat, qconfig_dict_qat, example_inputs=example_inputs)
</code></pre>
<p>Then you train the <code>model_prepared_qat</code> and convert it to int8 using <code>convert_fx(model_prepared_qat.eval())</code>. Inside the <code>FakeQuantize</code> module, it takes <code>x</code>, computes scale and zero-point, and simulates quantization like:</p>
<pre><code>
q = clamp(round(x / scale) + zero_point, qmin, qmax)
x_hat = (q - zero_point) * scale
</code></pre>
<p>QAT can learn activation clipping, push small weights out of dead zones, and learn quantization-friendly distributions.</p>
<p>QAT simulates quantization in the forward pass. Rounding, however, has a gradient of zero almost everywhere, so backpropagating through it directly would stop all learning. QAT uses a <strong>Straight-Through Estimator (STE)</strong>: the forward pass rounds as usual, but the backward pass treats the rounding as the identity and passes the incoming gradient through unchanged. Values that were clamped to the edge of the range are typically given zero gradient.</p>
<pre><code class="language-python">
import torch

class RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # straight-through (acts like identity for the rounding part)
</code></pre>
<p>STE allows the model to learn parameters and "survive" quantization during training.</p>
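<p>A scalar toy problem makes the effect concrete. The plain-Python sketch below fits <code>y = w * x</code> with the weight fake-quantized in the forward pass; the step size, learning rate, and names such as <code>fake_quant</code> and <code>ste_grad</code> are arbitrary choices for the illustration, not PyTorch internals.</p>

```python
def fake_quant(w, scale=0.25):
    """Forward pass: snap w to the nearest multiple of scale (quantize-dequantize)."""
    return round(w / scale) * scale

def ste_grad(upstream_grad):
    """Backward pass: treat round() as the identity, so the gradient passes through."""
    return upstream_grad

# Fit y = w * x to a target with the weight fake-quantized in the forward pass.
x, target = 2.0, 3.0
w, lr = 0.1, 0.05
for _ in range(200):
    w_hat = fake_quant(w)                    # what the quantized model would actually use
    loss_grad = 2 * (w_hat * x - target)     # d(loss)/d(prediction) for squared error
    grad_w_hat = loss_grad * x               # chain rule through y = w_hat * x
    w -= lr * ste_grad(grad_w_hat)           # the true d(w_hat)/d(w) is 0 almost everywhere

final_w_hat = fake_quant(w)
print(w, final_w_hat, final_w_hat * x)
```

<p>Without the straight-through step the update would always be zero and <code>w</code> would never leave its starting value. With it, the latent float weight drifts until its quantized version hits the target.</p>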
<p>Now for <code>MovingAverageMinMaxObserver</code>, which is often used within <code>FakeQuantize</code> for QAT. It tracks an exponential moving average of the min/max values observed during training. <code>FakeQuantize</code> modules are inserted during <code>prepare_qat_fx()</code> and are active during the training phase. After training, during <code>convert_fx()</code>, they’re replaced with real quantization operations.</p>
<p>The moving average update rules are typically:</p>
<pre><code>
min_val = (1 - α) * min_val + α * current_batch_min_val
max_val = (1 - α) * max_val + α * current_batch_max_val
</code></pre>
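<p>Where α (alpha) is a small smoothing constant. In PyTorch’s <code>MovingAverageMinMaxObserver</code> it is the <code>averaging_constant</code> argument, which defaults to 0.01.</p>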
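<p>The update rule is easiest to appreciate next to a plain running maximum. In the plain-Python sketch below, one batch contains an outlier; <code>ema_update</code> is an illustrative helper, and α is set to 0.1 (larger than a typical default) so the effect shows up within a few batches.</p>

```python
def ema_update(running, batch_value, alpha=0.01):
    """One exponential-moving-average step, as in the update rules above."""
    return (1 - alpha) * running + alpha * batch_value

batch_maxes = [4.0, 4.2, 3.9, 40.0, 4.1, 4.0]   # one batch contains an outlier

running_max = batch_maxes[0]       # the first batch initialises the range
global_max = batch_maxes[0]
for m in batch_maxes[1:]:
    running_max = ema_update(running_max, m, alpha=0.1)
    global_max = max(global_max, m)    # what a plain min-max observer would keep

print(running_max, global_max)
```

<p>The plain maximum stays pinned at the outlier forever, while the moving average is nudged up once and then decays back toward the typical range.</p>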
</section>
<section>
<h2>QAT Training &amp; FakeQuantize Details</h2>
<p>Training a QAT model often requires lower learning rates and more epochs to converge. It also needs careful initialization, for example, starting from a pretrained FP32 model. Similar to static quantization, BatchNorm (BN) layers are fused with the preceding convolution or linear layer: the FX workflow handles the fusion itself, and the BN statistics end up folded into the quantized weights by the time <code>convert_fx</code> has run.</p>
<p>After inserting <code>FakeQuantize</code> modules and monitoring ranges, it helps to fine-tune the model to the "frozen" quantization noise. This means that after an initial "warm-up" phase where quantization parameters (scale and zero-point) are being learned, these parameters are then fixed, and the model continues training to adapt to these specific, now constant, quantization effects.</p>
<h3>A note on observers and FakeQuantize modules</h3>
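<p>Inside the <code>FakeQuantize</code> modules, there is an observer instance (or logic) and a quantize-dequantize function. You can inspect them. Where they live depends on the workflow. With eager-mode preparation, the <code>activation_post_process</code> attribute on a layer (e.g., <code>model.layer_name.activation_post_process</code>) <em>is</em> the <code>FakeQuantize</code> module itself. With <code>prepare_qat_fx</code>, they are attached to the graph module under names such as <code>activation_post_process_0</code>, so the snippet below falls through to its <code>else</code> branch on an FX-prepared model.</p>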
<pre><code class="language-python">
if hasattr(model_prepared_qat, 'conv1') and hasattr(model_prepared_qat.conv1, 'activation_post_process'):
    fq_module = model_prepared_qat.conv1.activation_post_process  # This IS the FakeQuantize module
    current_scale, current_zero_point = fq_module.calculate_qparams()
    print(f"Scale: {current_scale}, Zero-point: {current_zero_point}")
    # The raw min/max usually live on the observer wrapped by FakeQuantize.
    observer = getattr(fq_module, 'activation_post_process', fq_module)
    if hasattr(observer, 'min_val') and hasattr(observer, 'max_val'):
        print(f"Observed min: {observer.min_val.item()}, Observed max: {observer.max_val.item()}")
else:
    print("conv1 or activation_post_process not found in conceptual model_prepared_qat.")
</code></pre>
<p>During the forward pass, observers (within <code>FakeQuantize</code>) are enabled to calculate the scale and zero-point. The fake quantizer then applies the quantize-dequantize simulation and uses STE during backpropagation. During QAT training (warm-up phase), observers are active, updating the range statistics, and the model learns under these evolving quantization conditions. Once the warm-up is done, the scale and zero-point are often "frozen." Training continues with this consistent quantization noise to improve stability.</p>
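<p>The warm-up-then-freeze schedule can be mimicked with a toy class. The sketch below is plain Python, and <code>TinyFakeQuant</code> is a made-up stand-in rather than PyTorch's <code>FakeQuantize</code>; in PyTorch the freeze is commonly done with something like <code>model.apply(torch.ao.quantization.disable_observer)</code>, which is worth checking against your installed version.</p>

```python
class TinyFakeQuant:
    """Toy stand-in for a FakeQuantize module: an EMA observer plus a freeze switch."""

    def __init__(self, alpha=0.1, qmax=127):
        self.alpha, self.qmax = alpha, qmax
        self.max_abs = None
        self.observer_enabled = True

    def scale(self):
        return self.max_abs / self.qmax

    def __call__(self, values):
        if self.observer_enabled:                    # warm-up: keep tracking the range
            batch = max(abs(v) for v in values)
            if self.max_abs is None:
                self.max_abs = batch
            else:
                self.max_abs = (1 - self.alpha) * self.max_abs + self.alpha * batch
        s = self.scale()
        # Quantize-dequantize with the current (possibly frozen) scale.
        return [max(-self.qmax, min(self.qmax, round(v / s))) * s for v in values]

fq = TinyFakeQuant()
batches = [[0.5, -1.0], [2.0, 0.3], [-1.5, 0.9], [3.0, -0.2]]

for step, batch in enumerate(batches):
    if step == 2:
        fq.observer_enabled = False                  # freeze the scale after warm-up
        frozen_scale = fq.scale()
    fq(batch)

print(frozen_scale, fq.scale())
```

<p>After the freeze, later batches are still fake-quantized, but with a constant scale, so the value 3.0 in the last batch is simply clipped to the frozen range.</p>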
</section>
<section>
<h2>Quantization Bugs</h2>
<p>Common issues include:</p>
<ul>
<li><strong>Accuracy drops. A lot.</strong>
<ul>
<li>Bad scale/zero-point, poor calibration.</li>
<li>Layers not fused properly or unnecessary quant-dequant boundaries.</li>
<li>ReLU ranges are too wide, skewing observers.</li>
<li>Using per-tensor quant when you should use per-channel.</li>
</ul>
</li>
<li><strong>Layers behave abnormally.</strong>
<ul>
<li>Observer didn’t see representative data.</li>
<li>Quant-dequant boundary in the wrong place.</li>
<li>Rounding and clamping causing clipping.</li>
</ul>
</li>
<li><strong>FakeQuant not simulating, model not converging (QAT).</strong>
<ul>
<li>Observers are not updating.</li>
<li>Observers are frozen too soon.</li>
<li>FakeQuant clipping too aggressively.</li>
</ul>
</li>
<li><strong><code>convert_fx()</code> fails - this happens a lot!</strong>
<ul>
<li>Model wasn’t fused correctly.</li>
<li>QConfig was missing for a submodule.</li>
<li>FX graph couldn’t trace.</li>
</ul>
</li>
</ul>
</section>
<section>
<h2>How to Trace the Bugs?</h2>
<h3>Print quantization parameters</h3>
<pre><code class="language-python">
import torch

for name, module in model_prepared_qat.named_modules():
    # named_modules() also visits nested weight_fake_quant modules, so one check covers both.
    if isinstance(module, torch.ao.quantization.FakeQuantize):
        # scale and zero_point are tensor buffers; per-channel weight quantizers hold
        # one entry per channel, so tolist() is used instead of item().
        print(f"{name} -- scale: {module.scale.tolist()}, "
              f"zero_point: {module.zero_point.tolist()}")
</code></pre>
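<p>If a scale comes out as zero or NaN, or is still sitting at its initial value after calibration or warm-up, the observer probably never saw usable data. If a scale is far larger or smaller than its neighbours, the range is likely skewed by outliers or has collapsed.</p>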
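<p>A check like this is easy to automate once the scales have been collected into a dictionary. The plain-Python sketch below flags suspicious values; <code>check_scale</code> is an illustrative helper and the thresholds <code>lo</code> and <code>hi</code> are arbitrary assumptions that you would tune per model.</p>

```python
import math

def check_scale(name, scale, lo=1e-6, hi=1.0):
    """Return a warning string for a suspicious scale, or None if it looks plausible."""
    if math.isnan(scale) or math.isinf(scale):
        return f"{name}: scale is {scale}; the layer probably produced NaN/Inf activations"
    if scale <= 0:
        return f"{name}: non-positive scale {scale}; the observed range is degenerate"
    if scale < lo:
        return f"{name}: very small scale {scale}; the observed range is near zero"
    if scale > hi:
        return f"{name}: very large scale {scale}; an outlier may be stretching the range"
    return None

# Hypothetical scales, as they might be read off the FakeQuantize modules.
scales = {"conv1": 0.02, "conv2": float("nan"), "fc": 3.5, "head": 0.0}

flagged = {}
for layer_name, s in scales.items():
    message = check_scale(layer_name, s)
    if message:
        flagged[layer_name] = message

for message in flagged.values():
    print(message)
```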
<h3>Visualize activation ranges</h3>
<p>It’s important to see what range of values a layer is producing. Make sure the quantization observers capture that range properly and aren’t thrown off by outliers. If the range is too narrow you’ll get clipping; if it’s too wide you’ll get poor resolution and small activations that all round to the same value (dead activations). Hook into a layer to see the float ranges and compare them with what the observers report.</p>
<pre><code class="language-python">
import matplotlib.pyplot as plt
import torch

activation_histograms = {}

def capture_activations_hook(name):
    def hook(module, input_val, output_val):
        data_to_log = output_val
        if isinstance(output_val, tuple):
            data_to_log = output_val[0]  # Handle tuple outputs
        if torch.is_tensor(data_to_log):
            activation_histograms[name] = data_to_log.detach().cpu().flatten().numpy()
    return hook

model_to_debug = ObserverCNN().eval()  # Or your specific model, ensure ObserverCNN is defined
hook_conv1 = model_to_debug.conv1.register_forward_hook(capture_activations_hook("conv1_output"))
hook_relu1 = model_to_debug.relu1.register_forward_hook(capture_activations_hook("relu1_output"))
with torch.no_grad():
    _ = model_to_debug(torch.randn(1, 3, 32, 32))  # Adjust input shape as needed

for name, data in activation_histograms.items():
    plt.figure()
    plt.hist(data, bins=100)
    plt.title(f"Activation Histogram: {name}")
    plt.xlabel("Value"); plt.ylabel("Frequency"); plt.grid(True)
    plt.show()

hook_conv1.remove()  # Clean up hooks
hook_relu1.remove()
</code></pre>
<p>Compare float vs quantized outputs layer by layer. If a single layer is the issue, this comparison narrows it down: the outputs agree up to that layer and diverge after it.</p>
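<p>One common way to put a number on "how different" is the signal-to-quantization-noise ratio (SQNR). The plain-Python sketch below assumes you have already captured matching per-layer outputs from the float and quantized models; the layer names and values are invented, and <code>sqnr_db</code> is an illustrative helper.</p>

```python
import math

def sqnr_db(float_out, quant_out):
    """Signal-to-quantization-noise ratio in dB; higher means the outputs agree better."""
    signal = sum(f * f for f in float_out)
    noise = sum((f - q) ** 2 for f, q in zip(float_out, quant_out))
    if noise == 0:
        return float("inf")
    return 10 * math.log10(signal / noise)

# Hypothetical per-layer outputs captured from the float and quantized models.
layers = {
    "conv1": ([0.50, 1.20, -0.30], [0.50, 1.19, -0.30]),
    "conv2": ([2.00, -1.00, 0.40], [1.20, -0.20, 1.10]),   # badly off
    "fc":    ([0.10, 0.90, 0.00], [0.10, 0.90, 0.01]),
}

report = {name: sqnr_db(f, q) for name, (f, q) in layers.items()}
worst_layer = min(report, key=report.get)

for name, value in report.items():
    print(f"{name}: {value:.1f} dB")
print("worst layer:", worst_layer)
```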
<h3>Disable quantization from layers one by one</h3>
<pre><code class="language-python">
if hasattr(model_prepared_qat, 'conv1') and \
        hasattr(model_prepared_qat.conv1, 'activation_post_process') and \
        hasattr(model_prepared_qat.conv1.activation_post_process, 'disable_fake_quant'):
    model_prepared_qat.conv1.activation_post_process.disable_fake_quant()
    # Re-evaluate and then re-enable with .enable_fake_quant() or re-prepare the model.
</code></pre>
<h3>Run FX <code>print_tabular()</code></h3>
<p>This shows you which quantization observers are attached and where scale and zero points are applied. Look for missing <code>activation_post_process</code> or quant/dequant pairs due to bad layer fusion.</p>
<pre><code class="language-python">
if hasattr(model_prepared_qat, 'graph'):
    model_prepared_qat.graph.print_tabular()
else:
    print("Model does not seem to be an FX graph module or 'graph' attribute is missing.")
</code></pre>
<h3>Inspect <code>activation_post_process</code></h3>
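<p><code>activation_post_process</code> is the name of a hook module (<code>FakeQuantize</code> in QAT) that PyTorch uses to handle quantization of activations. It’s inserted when you prepare the model: eager-mode preparation attaches it to each layer, while <code>prepare_qat_fx()</code> attaches numbered copies such as <code>activation_post_process_0</code> to the graph module.</p>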
<pre><code class="language-python">
# Conceptually, eager-mode preparation does something like this to each layer:
#   self.conv = nn.Conv2d(...)
#   self.conv.activation_post_process = FakeQuantize(...)
if hasattr(model_prepared_qat, 'conv1') and hasattr(model_prepared_qat.conv1, 'activation_post_process'):
    fq_module_inspect = model_prepared_qat.conv1.activation_post_process
    print(fq_module_inspect.scale, fq_module_inspect.zero_point)
    if hasattr(fq_module_inspect, 'min_val'):  # Check if min_val/max_val are directly on FakeQuant
        print("Observed min/max by FakeQuant:", fq_module_inspect.min_val.item(), fq_module_inspect.max_val.item())
</code></pre>
<p>If the values are outside min/max, they’ll be clipped during quantization.</p>
<h3>Use <code>torch.ao.quantization.quant_debug</code></h3>
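<p>This helps you record and compare float vs quantized activations, layer-by-layer differences, and histograms of those differences. PyTorch’s Numeric Suite (under <code>torch.ao.ns</code>) covers similar ground. Module and helper names here vary across PyTorch versions, so treat the ones in this section as placeholders and check what your installed version provides.</p>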
<pre><code class="language-python">
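# Placeholder names: add_debug_observers_qat and activation_debug_stats are not
# verified PyTorch APIs; substitute the comparison utility your version provides.
debug_model = add_debug_observers_qat(model_prepared_qat, qconfig_dict_qat)
print(debug_model.activation_debug_stats)  # Or similar attribute, API might change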
</code></pre>
<h3>Look for outliers and dead activations</h3>
<pre><code class="language-python">
import matplotlib.pyplot as plt
import torch

tensor = torch.randn(1000)  # Example tensor; swap in a captured activation
plt.hist(tensor.detach().cpu().flatten().numpy(), bins=100)
plt.title("Tensor Value Histogram")
plt.show()
</code></pre>
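<p>A histogram is good for eyeballing, but two numbers often tell the same story: the fraction of exact zeros and how far the peak sits above the bulk. The plain-Python sketch below computes both; <code>activation_stats</code>, the 99th-percentile cut, and the factor of 10 are assumptions chosen for illustration, and the data is invented.</p>

```python
def activation_stats(values, outlier_factor=10.0):
    """Fraction of exact zeros, peak magnitude, 99th-percentile magnitude, outlier flag."""
    ordered = sorted(abs(v) for v in values)
    p99 = ordered[int(0.99 * (len(ordered) - 1))]
    peak = ordered[-1]
    dead_fraction = sum(1 for v in values if v == 0.0) / len(values)
    has_outlier = p99 > 0 and peak > outlier_factor * p99
    return dead_fraction, peak, p99, has_outlier

# Hypothetical post-ReLU activations: mostly zeros, a modest bulk, one extreme value.
acts = [0.0] * 700 + [i / 100 for i in range(1, 300)] + [250.0]
dead_fraction, peak, p99, has_outlier = activation_stats(acts)

print(f"dead: {dead_fraction:.0%}, peak: {peak}, p99: {p99}, outlier: {has_outlier}")
```

<p>A high dead fraction suggests the layer is mostly inactive, and a peak far above the 99th percentile suggests a min-max observer would waste most of its range on a single value.</p>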
</section>
<section>
<h2>Wrapping Up</h2>
<ul>
<li>Use dynamic quantization for fast wins on CPU, LSTMs, and fully connected models.</li>
<li>Use static quantization for real performance gains, especially for CNNs and edge deployment.</li>
<li>Fuse before quantizing.</li>
<li>Choose the right observer.</li>
<li>Calibration must reflect real runtime data.</li>
<li>Avoid unnecessary float-int boundaries.</li>
<li>FX is your friend.</li>
<li>Don’t use <code>MovingAverageMinMaxObserver</code> directly in PTQ (it's primarily for QAT's <code>FakeQuantize</code> modules); PTQ typically uses observers like <code>MinMaxObserver</code> or <code>HistogramObserver</code> for one-shot calibration.</li>
</ul>
</section>
</body>
</html>