Skip to content

Commit ca2a616

Browse files
authored
Qualcomm AI Engine Direct - Adding QNN backend support for isInf core ATen op (#18249)
1 parent b17937b commit ca2a616

7 files changed

Lines changed: 118 additions & 0 deletions

File tree

backends/qualcomm/builders/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ Please help update following table if you are contributing new operators:
439439
| GroupNorm | ✓ |
440440
| HardSwish | ✓ |
441441
| InstanceNorm | ✓ |
442+
| IsInf | ✓ |
442443
| L2Norm | ✗ |
443444
| LayerNorm | ✓ |
444445
| LogSoftmax | ✓ |

backends/qualcomm/builders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
op_index_put,
5656
op_index_select,
5757
op_instance_norm,
58+
op_is_inf,
5859
op_layer_norm,
5960
op_le,
6061
op_linear,
@@ -164,6 +165,7 @@
164165
op_index_put,
165166
op_index_select,
166167
op_instance_norm,
168+
op_is_inf,
167169
op_layer_norm,
168170
op_le,
169171
op_linear,

backends/qualcomm/builders/node_visitor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
}
6060
QNN_TENSOR_TYPE_MAP = {
6161
torch.bool: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
62+
torch.float16: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_16,
6263
torch.float32: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
6364
# Note that there is no float64 tensor data type in Qnn.
6465
torch.float64: PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import warnings
7+
from typing import Dict
8+
9+
import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
10+
11+
import torch
12+
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
13+
14+
from .node_visitor import NodeVisitor
15+
from .node_visitor_manager import register_node_visitor
16+
from .qnn_constants import OpIsInf, QNN_OP_PACKAGE_NAME_QTI_AISW
17+
18+
19+
@register_node_visitor
class IsInf(NodeVisitor):
    """Lowers ``aten.isinf.default`` to the QNN ``IsInf`` op."""

    target = ["aten.isinf.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
    ) -> PyQnnManager.PyQnnOpWrapper:
        """Build and return the QNN IsInf op wrapper for *node*.

        Args:
            node: The FX node for ``aten.isinf.default``.
            nodes_to_wrappers: Cache of already-defined tensor wrappers.

        Returns:
            The configured ``PyQnnOpWrapper``, or ``None`` when the input
            dtype is unsupported (the node is then skipped for delegation).
        """
        input_node = self.get_node(node.args[0])
        input_tensor = self.get_tensor(input_node, node)

        # QNN IsInf only accepts FP16 inputs; bail out (returning None)
        # for any other dtype rather than producing an invalid graph.
        if input_tensor.dtype != torch.float16:
            warnings.warn(
                "[QNN Delegate Op Builder]: QNN IsInf only supports FP16 inputs.",
                stacklevel=1,
            )
            return None

        input_tensor_wrapper = self.define_tensor(
            input_node,
            node,
            input_tensor,  # reuse the tensor fetched above instead of re-fetching it
            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        input_tensors = [input_tensor_wrapper]

        out_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            node,
            out_tensor,
            PyQnnManager.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
        )
        output_tensors = [output_tensor_wrapper]

        isinf_op = PyQnnManager.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpIsInf.op_name,
        )
        isinf_op.AddInputTensors(input_tensors)
        isinf_op.AddOutputTensors(output_tensors)

        # Enable detection of both -inf and +inf so the op matches
        # torch.isinf semantics (which reports either sign of infinity).
        isinf_op.AddScalarParam(
            OpIsInf.param_detect_negative,
            PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: True},
        )
        isinf_op.AddScalarParam(
            OpIsInf.param_detect_positive,
            PyQnnManager.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
            {QCOM_DATA: True},
        )

        return isinf_op

backends/qualcomm/builders/qnn_constants.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,13 @@ class OpInstanceNorm:
389389
param_region = "region"
390390

391391

392+
# Name constants for the QNN "IsInf" op and its scalar parameters.
# NOTE(review): param names presumably mirror the QNN AISW op-package
# schema ("detect_negative"/"detect_positive") — confirm against the
# QNN operator definitions if changing.
@dataclass(init=False, frozen=True)
class OpIsInf:
    op_name: str = "IsInf"
    # Scalar bool params selecting which sign(s) of infinity are detected.
    param_detect_negative = "detect_negative"
    param_detect_positive = "detect_positive"
397+
398+
392399
@dataclass(init=False, frozen=True)
393400
class OpLayerNorm:
394401
op_name: str = "LayerNorm"

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,14 @@ def forward(self, x):
13101310
return self.instance_norm(x)
13111311

13121312

1313+
class IsInf(torch.nn.Module):
    """Minimal module wrapping the isinf op for delegate testing."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Element-wise check for +/- infinity; yields a bool tensor.
        return x.isinf()
1319+
1320+
13131321
class LargeTensorLinear(torch.nn.Module):
13141322
def __init__(self):
13151323
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,6 +1251,25 @@ def test_qnn_backend_instance_norm_2d(self):
12511251
with self.subTest(i=i):
12521252
self.lower_module_and_test_output(module, sample_input)
12531253

1254+
def test_qnn_backend_is_inf(self):
    """Lower the IsInf module via QNN and compare against eager output."""
    module = IsInf()  # noqa: F405
    # FP16 input mixing finite values, both infinities, and NaN so the
    # detect_positive/detect_negative behavior is fully exercised
    # (NaN must map to False, unlike isnan).
    values = [
        1.1,
        float("inf"),
        -float("inf"),
        0.0,
        float("nan"),
        0.6,
        float("nan"),
        -5.0,
    ]
    sample_input = (torch.tensor(values, dtype=torch.float16),)
    self.lower_module_and_test_output(module, sample_input)
1272+
12541273
def test_qnn_backend_interpolate_bicubic(self):
12551274
modules = [
12561275
ResizeBicubic([2, 2], None, False), # noqa: F405

0 commit comments

Comments
 (0)