Skip to content

Commit eb92cec

Browse files
authored
Qualcomm/op reciprocal (#18220)
1 parent 202c6af commit eb92cec

9 files changed

Lines changed: 163 additions & 1 deletion

File tree

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .decompose_linalg_vector_norm import DecomposeLinalgVectorNorm
2525
from .decompose_maxpool3d import DecomposeMaxPool3d
2626
from .decompose_minmaxdim import DecomposeMinMaxDim
27+
from .decompose_reciprocal import DecomposeReciprocal
2728
from .decompose_roll import DecomposeRoll
2829
from .decompose_silu import DecomposeSilu
2930
from .decompose_threshold import DecomposeThreshold
@@ -73,6 +74,7 @@
7374
DecomposeLinalgVectorNorm,
7475
DecomposeMaxPool3d,
7576
DecomposeMinMaxDim,
77+
DecomposeReciprocal,
7678
DecomposeRoll,
7779
DecomposeSilu,
7880
DecomposeThreshold,
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
10+
from .utils import copy_meta
11+
12+
13+
class DecomposeReciprocal(ExportPass):
    """Decompose ``aten.reciprocal`` into ``aten.div`` (i.e. ``1 / x``).

    Some QNN backends lack ElementWiseUnary with operation=reciprocal,
    so the op is rewritten as a division the backends do support.
    """

    def __init__(self):
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite every reciprocal node in *graph_module* and return the result."""
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target != torch.ops.aten.reciprocal.default:
                continue
            dividend = node.args[0]
            # Insert before the reciprocal node rather than after its input:
            # if the input is a graph placeholder, inserting right after it
            # would place a call_function node among the placeholders and
            # break FX's placeholders-first node ordering. Inserting before
            # `node` is always topologically valid.
            with graph.inserting_before(node):
                div_node = graph.call_function(
                    torch.ops.aten.div.Tensor,
                    (1, dividend),
                )
            div_node.meta = copy_meta(node.meta)

            # Re-route every consumer of the reciprocal to the division;
            # the now-unused reciprocal node is removed by DCE below.
            for user in node.users.copy():
                user.replace_input_with(node, div_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
DecomposeLinalgVectorNorm,
3030
DecomposeMaxPool3d,
3131
DecomposeMinMaxDim,
32+
DecomposeReciprocal,
3233
DecomposeRoll,
3334
DecomposeSilu,
3435
DecomposeThreshold,
@@ -220,6 +221,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
220221
self.add_pass(DecomposeEinsum())
221222
self.add_pass(DecomposeExpM1())
222223
self.add_pass(DecomposeGlu())
224+
# HTP and GPU don't support ElementWiseUnary with operation=reciprocal.
# Decompose Reciprocal into Div for these two backends.
# TODO: Skip this pass for the CPU backend (Dependency: backend-aware pass manager)
227+
self.add_pass(DecomposeReciprocal())
223228
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
224229
self.add_pass(ReplaceInfValues())
225230
self.add_pass(LiftConstantScalarOperands())
@@ -243,6 +248,10 @@ def transform_for_export_pipeline(
243248
# This pass is needed before to_edge pipeline to avoid mixed type for div operator with RemoveMixedTypeOperators pass.
244249
self.add_pass(DecomposeFloorDivide())
245250
self.add_pass(DecomposeWrapWithAutocast())
251+
# HTP and GPU don't support ElementWiseUnary with operation=reciprocal.
# Decompose Reciprocal into Div for these two backends.
# TODO: Skip this pass for the CPU backend (Dependency: backend-aware pass manager)
254+
self.add_pass(DecomposeReciprocal())
246255
# this pass will rewrite state_dict, it needs to be accomplished before
247256
# to_edge_transform_and_lower
248257
self.add_pass(CanonicalizeConv(exported_program))

backends/qualcomm/builders/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ Please help update following table if you are contributing new operators:
422422
| ElementWiseSquaredDifference | ✗ |
423423
| ElementWiseSquareRoot | ✓ |
424424
| ElementWiseSubtract | ✓ |
425-
| ElementWiseUnary | ✗ |
425+
| ElementWiseUnary | ✓ |
426426
| ElementWiseXor | ✓ |
427427
| Elu | ✓ |
428428
| ExpandDims | ✓ |

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
op_to,
105105
op_topk,
106106
op_transpose,
107+
op_unary,
107108
op_unbind,
108109
op_unsqueeze,
109110
op_upsample_bilinear2d,
@@ -185,6 +186,7 @@
185186
op_pow,
186187
op_prelu,
187188
op_quantize,
189+
op_unary,
188190
op_relu,
189191
op_repeat,
190192
op_reshape,
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
from typing import Dict
7+
8+
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnWrapper
9+
10+
import numpy as np
11+
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
13+
14+
from .node_visitor import NodeVisitor
15+
from .node_visitor_manager import register_node_visitor
16+
from .qnn_constants import OpElementWiseUnary, QNN_OP_PACKAGE_NAME_QTI_AISW
17+
18+
19+
@register_node_visitor
class Unary(NodeVisitor):
    """Lower ``aten.reciprocal`` to QNN's ElementWiseUnary(RECIPROCAL) op."""

    target = ["aten.reciprocal.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build and return the QNN op wrapper for a reciprocal node."""
        # Wrap the single input tensor of the unary op.
        src_node = self.get_node(node.args[0])
        src_tensor = self.get_tensor(src_node, node)
        input_wrapper = self.define_tensor(
            src_node,
            node,
            src_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        # Wrap the output tensor.
        out_tensor = self.get_tensor(node, node)
        output_wrapper = self.define_tensor(
            node,
            node,
            out_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )

        unary_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpElementWiseUnary.op_name,
        )
        unary_op.AddInputTensors([input_wrapper])
        unary_op.AddOutputTensors([output_wrapper])

        # The generic unary op is specialized via its `operation` scalar param.
        unary_op.AddScalarParam(
            OpElementWiseUnary.param_operation,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(OpElementWiseUnary.Operation.RECIPROCAL)},
        )

        return unary_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,31 @@ class OpElementWiseSubtract:
280280
op_name = "ElementWiseSubtract"
281281

282282

283+
@dataclass(init=False, frozen=True)
class OpElementWiseUnary:
    """QNN ElementWiseUnary op: its name plus the `operation` selector values."""

    op_name: str = "ElementWiseUnary"
    param_operation: str = "operation"

    @unique
    class Operation(IntEnum):
        # Numeric IDs are part of the QNN wire contract — do not renumber.
        ABS = 0
        ASIN = 1
        ATAN = 2
        CEIL = 3
        COS = 4
        EXP = 5
        FLOOR = 6
        LOG = 7
        NEG = 8
        NOT = 9
        RECIPROCAL = 10
        ROUND = 11
        RSQRT = 12
        SIGN = 13
        SIN = 14
        SQRT = 15
307+
283308
@dataclass(init=False, frozen=True)
284309
class OpElementWiseXor:
285310
op_name: str = "ElementWiseXor"

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,6 +1765,14 @@ def forward(self, x):
17651765
return self.prelu(x)
17661766

17671767

1768+
class Reciprocal(torch.nn.Module):
    """Minimal module computing the element-wise reciprocal (1 / x)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Tensor-method form; equivalent to torch.reciprocal(x).
        return x.reciprocal()
1774+
1775+
17681776
class Relu(torch.nn.Module):
17691777
def __init__(self):
17701778
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,11 @@ def test_qnn_backend_prelu(self):
16061606
index += 1
16071607
self.lower_module_and_test_output(module, sample_input)
16081608

1609+
    def test_qnn_backend_reciprocal(self):
        # Lower a standalone reciprocal module through the QNN delegate
        # (floating-point path) and compare backend output with eager mode.
        module = Reciprocal()  # noqa: F405
        sample_input = (torch.randn([2, 2, 2, 2]),)
        self.lower_module_and_test_output(module, sample_input)
1613+
16091614
def test_qnn_backend_relu(self):
16101615
module = Relu() # noqa: F405
16111616
sample_input = (torch.randn([2, 5, 1, 3]),)
@@ -3931,6 +3936,12 @@ def test_qnn_backend_prelu(self):
39313936
module = self.get_qdq_module(module, sample_input)
39323937
self.lower_module_and_test_output(module, sample_input)
39333938

3939+
    def test_qnn_backend_reciprocal(self):
        # Quantized path: convert the module to its QDQ form first, then
        # lower through the QNN delegate and compare with eager output.
        module = Reciprocal()  # noqa: F405
        sample_input = (torch.randn([2, 5, 1, 3]),)
        module = self.get_qdq_module(module, sample_input)
        self.lower_module_and_test_output(module, sample_input)
3944+
39343945
def test_qnn_backend_relu(self):
39353946
module = Relu() # noqa: F405
39363947
sample_input = (torch.randn([2, 5, 1, 3]),)

0 commit comments

Comments
 (0)