Skip to content

Commit fb90480

Browse files
Martin Lindström (martinlsm)
authored and committed
Arm backend: Set fixed qparams for acos, asin and atanh
acos, asin and atanh have shown to be problematic to quantize properly due to their limited input range between [-1, 1] (inclusively for asin and acos and exclusively for atanh). Before this patch, the approach for quantizing these ops was to use the quantization spec from the quantization config (typically a HistogramObserver if using the default symmetric quantization config). This caused problems when calibrating the model with inputs close to -1 or 1, because they could land outside the valid range of the operator. When this happened, the resulting TABLE op set the output of these outliers to zero, which is not ideal. To mitigate this problem, use fixed quantization params for these ops by statically defining them in quantization_annotator.py. With this solution, we potentially lose a bit of numerical precision because the ops won't be affected by quantization calibration at all, but the resulting TABLE ops won't set any zeros since there is no input that can be outside the [-1, 1] interval anymore. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Change-Id: Id55156be5ca7fcfbf9f9c3f8ae88fb075509ce0c
1 parent e5be6d5 commit fb90480

4 files changed

Lines changed: 56 additions & 12 deletions

File tree

backends/arm/quantizer/quantization_annotator.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,27 @@
1313
import logging
1414
import operator
1515
from dataclasses import dataclass, replace
16-
from typing import Callable, cast, Iterable, List, Optional, Sequence
16+
from typing import Any, Callable, cast, Iterable, List, NamedTuple, Optional, Sequence
1717

1818
import torch
1919
import torch.fx
2020
from executorch.backends.arm.common.debug import get_node_debug_info
2121
from executorch.backends.arm.common.type import ensure_type
2222
from executorch.backends.arm.quantizer import QuantizationConfig
23-
from torch._subclasses import FakeTensor
2423

24+
from torch._subclasses import FakeTensor
2525
from torch.fx import Node
2626
from torchao.quantization.pt2e import (
2727
FakeQuantize,
2828
FusedMovingAvgObsFakeQuantize,
2929
MovingAveragePerChannelMinMaxObserver,
3030
PartialWrapper,
3131
)
32+
3233
from torchao.quantization.pt2e.quantizer import (
3334
annotate_input_qspec_map,
3435
annotate_output_qspec,
36+
FixedQParamsQuantizationSpec,
3537
QuantizationSpec,
3638
QuantizationSpecBase,
3739
SharedQuantizationSpec,
@@ -78,6 +80,11 @@ def __init__(self):
7880
self.quant_output: Optional[_QuantProperty] = None
7981

8082

83+
class _QParams(NamedTuple):
84+
scale: float
85+
zero_point: int
86+
87+
8188
def _as_list(x):
8289
"""Return ``x`` wrapped as a list if needed.
8390
@@ -443,6 +450,29 @@ def _match_pattern(
443450
torch.ops.aten.conv3d.padding,
444451
}
445452

453+
# For these ops, we use fixed qspecs, meaning that quantization params for
454+
# these are statically defined. This is to prevent issues with out-of-range
455+
# values when using dynamic quantization.
456+
#
457+
# Dict of operator to a dict of num_bits to qparams for that operator.
458+
_fixed_input_qspec_ops: dict[Any, dict[int, _QParams]] = {
459+
# acos has a valid range of [-1, 1]
460+
torch.ops.aten.acos.default: {
461+
8: _QParams((1.0 - (-1.0)) / (1 << 8), 0),
462+
16: _QParams((1.0 - (-1.0)) / (1 << 16), 0),
463+
},
464+
# asin has a valid range of [-1, 1]
465+
torch.ops.aten.asin.default: {
466+
8: _QParams((1.0 - (-1.0)) / (1 << 8), 0),
467+
16: _QParams((1.0 - (-1.0)) / (1 << 16), 0),
468+
},
469+
# atanh has a valid range of (-1, 1) (excluding -1 and 1).
470+
torch.ops.aten.atanh.default: {
471+
8: _QParams((0.999 - (-0.999)) / (1 << 8), 0),
472+
16: _QParams((0.99999 - (-0.99999)) / (1 << 16), 0),
473+
},
474+
}
475+
446476
_one_to_one = {
447477
torch.ops.aten.abs.default,
448478
torch.ops.aten.ceil.default,
@@ -474,11 +504,8 @@ def _match_pattern(
474504
torch.ops.aten.log1p.default,
475505
torch.ops.aten.acosh.default,
476506
torch.ops.aten.sign.default,
477-
torch.ops.aten.asin.default,
478-
torch.ops.aten.atanh.default,
479507
torch.ops.aten.asinh.default,
480508
torch.ops.aten.cosh.default,
481-
torch.ops.aten.acos.default,
482509
torch.ops.aten.cumsum.default,
483510
torch.ops.aten.tan.default,
484511
}
@@ -784,6 +811,25 @@ def any_or_hardtanh_min_zero(n: Node):
784811
elif node.target in _one_to_one:
785812
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
786813
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
814+
elif node.target in _fixed_input_qspec_ops:
815+
num_bits = torch.iinfo(input_act_qspec.dtype).bits
816+
qparams = _fixed_input_qspec_ops[node.target][num_bits]
817+
818+
quant_properties.quant_inputs = [
819+
_QuantProperty(
820+
0,
821+
FixedQParamsQuantizationSpec(
822+
dtype=input_act_qspec.dtype,
823+
scale=qparams.scale,
824+
zero_point=qparams.zero_point,
825+
quant_min=input_act_qspec.quant_min,
826+
quant_max=input_act_qspec.quant_max,
827+
qscheme=input_act_qspec.qscheme,
828+
is_dynamic=input_act_qspec.is_dynamic,
829+
),
830+
)
831+
]
832+
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
787833
elif node.target in _one_to_one_shared_input_qspec:
788834
input_node = ensure_type(Node, node.args[0])
789835
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]

backends/arm/test/ops/test_acos.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def test_acos_tosa_INT(test_data: Tuple):
6565
(test_data(),),
6666
aten_op=aten_op,
6767
exir_op=exir_op,
68-
frobenius_threshold=0.5, # MLETORCH-1709
6968
)
7069
pipeline.run()
7170

backends/arm/test/ops/test_asin.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,6 @@ def test_asin_tosa_INT(test_data: Tuple):
5555
(test_data(),),
5656
aten_op=[],
5757
exir_op=[],
58-
frobenius_threshold=0.6, # MLETORCH-1709
59-
cosine_threshold=0.8, # MLETORCH-1709
6058
)
6159
pipeline.run()
6260

backends/arm/test/ops/test_atanh.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,10 @@
2626
test_data_suite = {
2727
"zeros": torch.zeros(1, 10, 10, 10),
2828
"zeros_alt_shape": torch.zeros(1, 10, 3, 5),
29-
"ones": torch.ones(10, 10, 10),
3029
"rand": torch.rand(10, 10) - 0.5,
3130
"rand_alt_shape": torch.rand(1, 10, 3, 5) - 0.5,
3231
"ramp": torch.arange(-1, 1, 0.2),
33-
"near_bounds": torch.tensor([-0.999999, -0.999, -0.9, 0.9, 0.999, 0.999999]),
32+
"near_bounds": torch.tensor([-0.99, -0.9, 0.9, 0.99]),
3433
"on_bounds": torch.tensor([-1.0, 1.0]),
3534
}
3635

@@ -58,9 +57,11 @@ def test_atanh_tosa_INT(test_data: Tuple):
5857
(test_data,),
5958
aten_op=aten_op,
6059
exir_op=exir_op,
61-
frobenius_threshold=None, # MLETORCH-1709
62-
cosine_threshold=0.7,
6360
)
61+
if torch.any(test_data >= 1) or torch.any(test_data <= -1):
62+
# The quantized model will saturate to max/min values while the
63+
# original model will return inf/-inf, so comparison wont be valid here.
64+
pipeline.pop_stage("run_method_and_compare_outputs.original_model")
6465
pipeline.run()
6566

6667

0 commit comments

Comments (0)