From e9b1f53631209e86aa59199abbda8ab53f1fb01d Mon Sep 17 00:00:00 2001 From: Ethan Ng Date: Mon, 8 Jun 2026 21:17:29 -0700 Subject: [PATCH] Walk transparent ops when extracting input quant params Summary: `extract_input_quant_params_from_graph` (used by `QuantizedInputWrapper` to dequantize pre-quantized inputs) only matched a `quantize_per_tensor` placed directly on the input placeholder. Inputs that pass through a shape-only op first -- an nhwc `permute`, or a patchify `reshape` -- defeated it, so callers had to pre-compute the quant params themselves. Walk from each input placeholder through a chain of transparent ops (`view`/`reshape`/`permute`/`transpose`/`slice`/`to`/...) to the first `quantize_per_tensor`, so `QuantizedInputWrapper(module, input_names)` resolves those inputs automatically. Backward compatible (a quantize directly on the placeholder is just a zero-hop walk) and subsumes the narrower `extract_quant_params_through_permute`. Differential Revision: D107922730 --- backends/cadence/aot/compiler_funcs.py | 49 ++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 7 deletions(-) diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py index e8c0f2a602b..71b80e5b0c9 100644 --- a/backends/cadence/aot/compiler_funcs.py +++ b/backends/cadence/aot/compiler_funcs.py @@ -22,6 +22,28 @@ logger: logging.Logger = logging.getLogger(__name__) QuantArgs = tuple[float, int, int, int, torch.dtype] +TRANSPARENT_OPS: frozenset[torch._ops.OpOverloadPacket] = frozenset( + { + torch.ops.aten.view, + torch.ops.aten.view_copy, + torch.ops.aten._unsafe_view, + torch.ops.aten.reshape, + torch.ops.aten.permute, + torch.ops.aten.permute_copy, + torch.ops.aten.transpose, + torch.ops.aten.transpose_copy, + torch.ops.aten.squeeze, + torch.ops.aten.squeeze_copy, + torch.ops.aten.unsqueeze, + torch.ops.aten.unsqueeze_copy, + torch.ops.aten.slice, + torch.ops.aten.slice_copy, + torch.ops.aten.contiguous, + torch.ops.aten.clone, + torch.ops.aten.to, + torch.ops.aten._to_copy, + } +) @torch.no_grad() @@ -251,17 +273,24 @@ def extract_input_quant_params_from_graph( if not input_names: return quant_args + placeholders = {n.name: n for n in module.graph.nodes if n.op == "placeholder"} + for idx, name in enumerate(input_names): - for node in module.graph.nodes: - if node.op != "call_function": + placeholder = placeholders.get(name) + if placeholder is None: + continue + seen: set[torch.fx.Node] = set() + to_visit: list[torch.fx.Node] = list(placeholder.users) + while to_visit: + node = to_visit.pop() + if node in seen or node.op != "call_function": continue - + seen.add(node) + target_str = str(node.target) if ( - node.args - and isinstance(node.args[0], torch.fx.Node) - and node.args[0].name == name + "quantize_per_tensor" in target_str + and "dequantize" not in target_str and not node.name.startswith("_assert_tensor_metadata") - and "quantize_per_tensor" in str(node.target) ): args = node.args[1:] if len(args) >= 5: @@ -274,6 +303,12 @@ def extract_input_quant_params_from_graph( ) found_names.add(name) break + target = node.target + if ( + isinstance(target, torch._ops.OpOverload) + and target.overloadpacket in TRANSPARENT_OPS + ): + to_visit.extend(node.users) missing_names = set(input_names) - found_names if missing_names: