|
13 | 13 | import logging |
14 | 14 | import operator |
15 | 15 | from dataclasses import dataclass, replace |
16 | | -from typing import Callable, cast, Iterable, List, Optional, Sequence |
| 16 | +from typing import Any, Callable, cast, Iterable, List, NamedTuple, Optional, Sequence |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 | import torch.fx |
20 | 20 | from executorch.backends.arm.common.debug import get_node_debug_info |
21 | 21 | from executorch.backends.arm.common.type import ensure_type |
22 | 22 | from executorch.backends.arm.quantizer import QuantizationConfig |
23 | | -from torch._subclasses import FakeTensor |
24 | 23 |
|
| 24 | +from torch._subclasses import FakeTensor |
25 | 25 | from torch.fx import Node |
26 | 26 | from torchao.quantization.pt2e import ( |
27 | 27 | FakeQuantize, |
28 | 28 | FusedMovingAvgObsFakeQuantize, |
29 | 29 | MovingAveragePerChannelMinMaxObserver, |
30 | 30 | PartialWrapper, |
31 | 31 | ) |
| 32 | + |
32 | 33 | from torchao.quantization.pt2e.quantizer import ( |
33 | 34 | annotate_input_qspec_map, |
34 | 35 | annotate_output_qspec, |
| 36 | + FixedQParamsQuantizationSpec, |
35 | 37 | QuantizationSpec, |
36 | 38 | QuantizationSpecBase, |
37 | 39 | SharedQuantizationSpec, |
@@ -78,6 +80,11 @@ def __init__(self): |
78 | 80 | self.quant_output: Optional[_QuantProperty] = None |
79 | 81 |
|
80 | 82 |
|
| 83 | +class _QParams(NamedTuple): |
| 84 | + scale: float |
| 85 | + zero_point: int |
| 86 | + |
| 87 | + |
81 | 88 | def _as_list(x): |
82 | 89 | """Return ``x`` wrapped as a list if needed. |
83 | 90 |
|
@@ -443,6 +450,29 @@ def _match_pattern( |
443 | 450 | torch.ops.aten.conv3d.padding, |
444 | 451 | } |
445 | 452 |
|
# For these ops, we use fixed qspecs, meaning that quantization params for
# these are statically defined. This is to prevent issues with out-of-range
# values when using dynamic quantization.
#
# Maps an operator to a table of fixed input qparams for that operator,
# keyed by bit width of the activation dtype (currently 8 or 16 — looking
# up any other width raises KeyError at annotation time).
_fixed_input_qspec_ops: dict[Any, dict[int, _QParams]] = {
    # acos is only defined for inputs in the domain [-1, 1]
    torch.ops.aten.acos.default: {
        8: _QParams((1.0 - (-1.0)) / (1 << 8), 0),
        16: _QParams((1.0 - (-1.0)) / (1 << 16), 0),
    },
    # asin is only defined for inputs in the domain [-1, 1]
    torch.ops.aten.asin.default: {
        8: _QParams((1.0 - (-1.0)) / (1 << 8), 0),
        16: _QParams((1.0 - (-1.0)) / (1 << 16), 0),
    },
    # atanh is only defined on the open interval (-1, 1); the endpoints are
    # pulled slightly inward so quantized values stay strictly inside it.
    torch.ops.aten.atanh.default: {
        8: _QParams((0.999 - (-0.999)) / (1 << 8), 0),
        16: _QParams((0.99999 - (-0.99999)) / (1 << 16), 0),
    },
}
446 | 476 | _one_to_one = { |
447 | 477 | torch.ops.aten.abs.default, |
448 | 478 | torch.ops.aten.ceil.default, |
@@ -474,11 +504,8 @@ def _match_pattern( |
474 | 504 | torch.ops.aten.log1p.default, |
475 | 505 | torch.ops.aten.acosh.default, |
476 | 506 | torch.ops.aten.sign.default, |
477 | | - torch.ops.aten.asin.default, |
478 | | - torch.ops.aten.atanh.default, |
479 | 507 | torch.ops.aten.asinh.default, |
480 | 508 | torch.ops.aten.cosh.default, |
481 | | - torch.ops.aten.acos.default, |
482 | 509 | torch.ops.aten.cumsum.default, |
483 | 510 | torch.ops.aten.tan.default, |
484 | 511 | } |
@@ -784,6 +811,25 @@ def any_or_hardtanh_min_zero(n: Node): |
784 | 811 | elif node.target in _one_to_one: |
785 | 812 | quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] |
786 | 813 | quant_properties.quant_output = _QuantProperty(0, output_act_qspec) |
| 814 | + elif node.target in _fixed_input_qspec_ops: |
| 815 | + num_bits = torch.iinfo(input_act_qspec.dtype).bits |
| 816 | + qparams = _fixed_input_qspec_ops[node.target][num_bits] |
| 817 | + |
| 818 | + quant_properties.quant_inputs = [ |
| 819 | + _QuantProperty( |
| 820 | + 0, |
| 821 | + FixedQParamsQuantizationSpec( |
| 822 | + dtype=input_act_qspec.dtype, |
| 823 | + scale=qparams.scale, |
| 824 | + zero_point=qparams.zero_point, |
| 825 | + quant_min=input_act_qspec.quant_min, |
| 826 | + quant_max=input_act_qspec.quant_max, |
| 827 | + qscheme=input_act_qspec.qscheme, |
| 828 | + is_dynamic=input_act_qspec.is_dynamic, |
| 829 | + ), |
| 830 | + ) |
| 831 | + ] |
| 832 | + quant_properties.quant_output = _QuantProperty(0, output_act_qspec) |
787 | 833 | elif node.target in _one_to_one_shared_input_qspec: |
788 | 834 | input_node = ensure_type(Node, node.args[0]) |
789 | 835 | quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)] |
|
0 commit comments