diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py index e8c0f2a602b..71b80e5b0c9 100644 --- a/backends/cadence/aot/compiler_funcs.py +++ b/backends/cadence/aot/compiler_funcs.py @@ -22,6 +22,28 @@ logger: logging.Logger = logging.getLogger(__name__) QuantArgs = tuple[float, int, int, int, torch.dtype] +TRANSPARENT_OPS: frozenset[torch._ops.OpOverloadPacket] = frozenset( + { + torch.ops.aten.view, + torch.ops.aten.view_copy, + torch.ops.aten._unsafe_view, + torch.ops.aten.reshape, + torch.ops.aten.permute, + torch.ops.aten.permute_copy, + torch.ops.aten.transpose, + torch.ops.aten.transpose_copy, + torch.ops.aten.squeeze, + torch.ops.aten.squeeze_copy, + torch.ops.aten.unsqueeze, + torch.ops.aten.unsqueeze_copy, + torch.ops.aten.slice, + torch.ops.aten.slice_copy, + torch.ops.aten.contiguous, + torch.ops.aten.clone, + torch.ops.aten.to, + torch.ops.aten._to_copy, + } +) @torch.no_grad() @@ -251,17 +273,24 @@ def extract_input_quant_params_from_graph( if not input_names: return quant_args + placeholders = {n.name: n for n in module.graph.nodes if n.op == "placeholder"} + for idx, name in enumerate(input_names): - for node in module.graph.nodes: - if node.op != "call_function": + placeholder = placeholders.get(name) + if placeholder is None: + continue + seen: set[torch.fx.Node] = set() + to_visit: list[torch.fx.Node] = list(placeholder.users) + while to_visit: + node = to_visit.pop() + if node in seen or node.op != "call_function": continue - + seen.add(node) + target_str = str(node.target) if ( - node.args - and isinstance(node.args[0], torch.fx.Node) - and node.args[0].name == name + "quantize_per_tensor" in target_str + and "dequantize" not in target_str and not node.name.startswith("_assert_tensor_metadata") - and "quantize_per_tensor" in str(node.target) ): args = node.args[1:] if len(args) >= 5: @@ -274,6 +303,12 @@ def extract_input_quant_params_from_graph( ) found_names.add(name) break + target = node.target + if ( + isinstance(target, torch._ops.OpOverload) + and target.overloadpacket in TRANSPARENT_OPS + ): + to_visit.extend(node.users) missing_names = set(input_names) - found_names if missing_names: