From af07b09912713da4447d84d4c790ab67a5ab6b5e Mon Sep 17 00:00:00 2001 From: George Gekov Date: Mon, 8 Jun 2026 16:35:22 +0100 Subject: [PATCH] Arm backend: Remove FuseConsecutiveRescalesPass The FuseConsecutiveRescalesPass was destinated to speed-up execution on Ethos-U55 by identifying consequtive rescales that could be optimised and eliminate them. However, since then the compiler has added support for advanced add/sub for the Ethos-U55. We no longer need to remove rescales for the U55, hence removing this pass. Signed-off-by: George Gekov Change-Id: I6e15636f2ea46390cbebc8bb7e2ee2beb3f81512 --- backends/arm/_passes/__init__.py | 1 - backends/arm/_passes/arm_pass_manager.py | 2 - .../_passes/fuse_consecutive_rescales_pass.py | 175 ------------- .../test/models/test_residual_conv_block.py | 8 +- .../test/passes/test_rescale_optimization.py | 244 ------------------ 5 files changed, 2 insertions(+), 428 deletions(-) delete mode 100644 backends/arm/_passes/fuse_consecutive_rescales_pass.py delete mode 100644 backends/arm/test/passes/test_rescale_optimization.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 20ead36627c..42eb6616dea 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -107,7 +107,6 @@ ) from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa from .fuse_consecutive_concat_shapes import FuseConsecutiveConcatShapesPass # noqa -from .fuse_consecutive_rescales_pass import FuseConsecutiveRescalesPass # noqa from .fuse_constant_ops_pass import ( # noqa ComputeConstantOpsAOTPass, FuseConstantArgsPass, diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 748c369482f..42fcbb3f151 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -104,7 +104,6 @@ FoldAndAnnotateQParamsPass, FuseBatchNorm2dPass, FuseConsecutiveConcatShapesPass, - FuseConsecutiveRescalesPass, FuseConstantArgsPass, FuseDuplicateUsersPass, FuseEqualPlaceholdersPass, @@ -417,7 +416,6 @@ def _tosa_pipeline( # Ticket: MLETORCH-1539 DecomposeLinearPass(), InsertRescaleInt32Pass(), - FuseConsecutiveRescalesPass(), InsertControlFlowRescalesPass(), DecomposeQuantNodesPass(), ] diff --git a/backends/arm/_passes/fuse_consecutive_rescales_pass.py b/backends/arm/_passes/fuse_consecutive_rescales_pass.py deleted file mode 100644 index 1deb35d7aee..00000000000 --- a/backends/arm/_passes/fuse_consecutive_rescales_pass.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright 2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from typing import cast, Set, Type - -import torch -from executorch.backends.arm._passes.arm_pass import ArmPass -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult -from torch.fx import GraphModule, Node - -logger: logging.Logger = logging.getLogger(__name__) - -# TOSA RESCALE argument positions: -# args[0] = input tensor (Node) -# args[1] = output dtype (e.g., torch.int8, torch.int32) -# args[2] = scale list (List[float]; per-tensor when len == 1) -# args[3] = input zero point (int) -# args[4] = output zero point (int) -_ARG_INPUT = 0 -_ARG_OUTPUT_DTYPE = 1 -_ARG_SCALE = 2 -_ARG_INPUT_ZP = 3 -_ARG_OUTPUT_ZP = 4 - - -class FuseConsecutiveRescalesPass(ArmPass): - """Fuse consecutive RESCALE(INT32->INT8/INT16) -> RESCALE(INT8/INT16->INT32) - pairs. - - InsertRescaleInt32Pass wraps each quantized arithmetic and comparison - operator (add, sub, mul, abs, eq, ge, gt, le, lt, max, min, sum) with - input rescales (INT8/INT16->INT32) and an output rescale - (INT32->INT8/INT16). When two such ops are chained (e.g., add1 -> add2), - the output rescale of add1 feeds directly into an input rescale of add2, - creating a redundant INT32->INT8/INT16->INT32 round-trip that loses - precision. - - This pass detects such pairs and handles two cases: - - - **Identity** (composed scale ~1.0, matching zero points): Removes both - RESCALEs and directly wires R1's input to R2's users. This eliminates - the entire round-trip. Bypassing the intermediate INT8/INT16 clamp can - in theory cause up to ~120 INT8 steps of output difference when all - inputs are near the clamp boundary; in practice, observed differences - are 0-1 steps for typical distributions. Tests use qtol=1. - - - **Non-identity**: Leaves the pair unchanged. The Vela NPU compiler - cannot correctly process INT32->INT32 RESCALE (produces all-zero NPU - outputs), so non-identity pairs retain their INT8/INT16 intermediate. - - Handles multi-user R1 nodes: when R1 feeds both RESCALE and - non-RESCALE users, each R1->R2 RESCALE pair is fused individually - while preserving R1 for its non-RESCALE users. - - """ - - _passes_required_after: Set[Type[ExportPass]] = set() - - def call(self, graph_module: GraphModule) -> PassResult: - graph = graph_module.graph - modified = False - rescale_before = sum(1 for n in graph.nodes if _is_rescale(n)) - identity_pairs_fused = 0 - - for node in list(graph.nodes): - node = cast(Node, node) - if not _is_fuseable_r1(node): - continue - - r1_input = node.args[_ARG_INPUT] - r1_input_zp = node.args[_ARG_INPUT_ZP] - r1_scale = float(node.args[_ARG_SCALE][0]) # type: ignore[arg-type] - - node_fused = False - for user in list(node.users): - if _try_fuse_identity_pair(node, user, r1_input, r1_input_zp, r1_scale): - node_fused = True - identity_pairs_fused += 1 - - if node_fused: - modified = True - - if modified: - graph.eliminate_dead_code() - rescale_after = sum(1 for n in graph.nodes if _is_rescale(n)) - removed = rescale_before - rescale_after - logger.info( - "FuseConsecutiveRescalesPass: removed %d identity pairs " - "(%d RESCALEs: %d -> %d)", - identity_pairs_fused, - removed, - rescale_before, - rescale_after, - ) - graph_module.recompile() - graph.lint() - # Note: we deliberately skip super().call() — retracing is - # unnecessary since this pass only rewires edges and removes - # nodes without introducing new operations. - - return PassResult(graph_module, modified) - - -def _is_rescale(node: Node) -> bool: - return ( - node.op == "call_function" - and node.target == exir_ops.backend.tosa.RESCALE.default - ) - - -def _is_fuseable_r1(node: Node) -> bool: - """Check if node is an R1 candidate. - - R1 is RESCALE(INT32 -> INT8/INT16) with per-tensor scale. - - """ - if not _is_rescale(node): - return False - if node.args[_ARG_OUTPUT_DTYPE] not in (torch.int8, torch.int16): - return False - if len(node.args[_ARG_SCALE]) != 1: # type: ignore[arg-type] - return False - r1_input = node.args[_ARG_INPUT] - if not isinstance(r1_input, Node) or "val" not in r1_input.meta: - return False - if r1_input.meta["val"].dtype != torch.int32: - return False - return True - - -def _try_fuse_identity_pair( - r1: Node, - r2: Node, - r1_input: Node, - r1_input_zp: int, - r1_scale: float, -) -> bool: - """Try to fuse an R1->R2 identity pair. - - Returns True if fused. - - """ - if not _is_rescale(r2): - return False - if r2.args[_ARG_OUTPUT_DTYPE] != torch.int32: - return False - if r1.args[_ARG_OUTPUT_ZP] != r2.args[_ARG_INPUT_ZP]: - return False - if len(r2.args[_ARG_SCALE]) != 1: # type: ignore[arg-type] - return False - - r2_scale = float(r2.args[_ARG_SCALE][0]) # type: ignore[arg-type, index] - composed_scale = r1_scale * r2_scale - r2_output_zp = r2.args[_ARG_OUTPUT_ZP] - - if abs(composed_scale - 1.0) < 1e-6 and r1_input_zp == r2_output_zp: - # Identity case: remove both RESCALEs and directly wire - # R1's input (INT32) to R2's users. The composed scale - # is ~1.0 so the round-trip is a no-op modulo the INT8 - # clamp. Bypassing the clamp can in theory cause up to - # ~120 INT8 steps of difference near clamp boundaries; - # observed differences are 0-1 steps. Tests use qtol=1. - r2.replace_all_uses_with(r1_input) - return True - - # Non-identity: leave the pair unchanged. Creating a - # single INT32->INT32 RESCALE with the composed scale would - # be semantically correct (and the TOSA ref model handles - # it), but the Vela NPU compiler produces all-zero outputs - # for INT32->INT32 RESCALE operations. - return False diff --git a/backends/arm/test/models/test_residual_conv_block.py b/backends/arm/test/models/test_residual_conv_block.py index b77ea8a9ccc..f2003bf37b9 100644 --- a/backends/arm/test/models/test_residual_conv_block.py +++ b/backends/arm/test/models/test_residual_conv_block.py @@ -5,8 +5,7 @@ """Residual conv block model test for ARM TOSA backend. Tests a minimal residual architecture with conv->batchnorm->relu->add blocks and -permute operations, representative of quantized signal processing models where -FuseConsecutiveRescalesPass eliminates redundant RESCALE pairs. +permute operations. """ @@ -28,10 +27,7 @@ class ResidualConvBlock(torch.nn.Module): Architecture: conv->bn->relu->add (residual) -> permute -> conv->bn->relu->add. When quantized, each residual add is - wrapped with INT32 RESCALEs by InsertRescaleInt32Pass. Stacked - blocks create consecutive RESCALE pairs (INT32->INT8->INT32) - between adjacent adds that FuseConsecutiveRescalesPass - eliminates. + wrapped with INT32 RESCALEs by InsertRescaleInt32Pass. """ diff --git a/backends/arm/test/passes/test_rescale_optimization.py b/backends/arm/test/passes/test_rescale_optimization.py deleted file mode 100644 index fa77c0cae64..00000000000 --- a/backends/arm/test/passes/test_rescale_optimization.py +++ /dev/null @@ -1,244 +0,0 @@ -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Tuple - -import pytest -import torch -from executorch.backends.arm._passes import FuseConsecutiveRescalesPass -from executorch.backends.arm.test import common -from executorch.backends.arm.test.tester.test_pipeline import ( - EthosU55PipelineINT, - TosaPipelineINT, -) -from executorch.backends.arm.tosa.specification import ( - TosaLoweringContext, - TosaSpecification, -) -from executorch.exir.dialects._ops import ops as exir_ops - -RESCALE_TARGET = exir_ops.backend.tosa.RESCALE.default -TOSA_INT_SPEC = TosaSpecification.create_from_string("TOSA-1.0+INT") - - -class AddChain(torch.nn.Module): - """Two cascaded adds: (x + y) + z.""" - - input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor] - - def forward(self, x, y, z): - return (x + y) + z - - @staticmethod - def get_test_inputs(): - return ( - torch.randn(1, 3, 8, 8), - torch.randn(1, 3, 8, 8), - torch.randn(1, 3, 8, 8), - ) - - -class BranchingAdd(torch.nn.Module): - """(x + y) feeds two downstream adds.""" - - input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - - def forward(self, x, y, z, w): - a = x + y - b = a + z - c = a + w - return b + c - - @staticmethod - def get_test_inputs(): - return ( - torch.randn(1, 3, 8, 8), - torch.randn(1, 3, 8, 8), - torch.randn(1, 3, 8, 8), - torch.randn(1, 3, 8, 8), - ) - - -class LSTMGatePattern(torch.nn.Module): - """Mimics the LSTM cell-state update f * c_prev + i * g.""" - - input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] - - def forward(self, forget_gate_in, cell_prev, input_gate_in, candidate): - f = torch.sigmoid(forget_gate_in) - i = torch.sigmoid(input_gate_in) - g = torch.tanh(candidate) - return f * cell_prev + i * g - - @staticmethod - def get_test_inputs(): - return ( - torch.randn(1, 3, 8, 8), - torch.randn(1, 3, 8, 8), - torch.randn(1, 3, 8, 8), - torch.randn(1, 3, 8, 8), - ) - - -def _is_rescale(node) -> bool: - return node.op == "call_function" and node.target == RESCALE_TARGET - - -def _count_rescales(graph_module: torch.fx.GraphModule) -> int: - return sum(1 for node in graph_module.graph.nodes if _is_rescale(node)) - - -def _run_fuse_pass(graph_module: torch.fx.GraphModule): - with TosaLoweringContext(TOSA_INT_SPEC): - return FuseConsecutiveRescalesPass().call(graph_module) - - -def _make_int32_placeholder(graph: torch.fx.Graph, name: str = "x") -> torch.fx.Node: - x = graph.placeholder(name) - x.meta["val"] = torch.ones((1,), dtype=torch.int32) - return x - - -def _make_rescale( - graph: torch.fx.Graph, - input_node: torch.fx.Node, - output_dtype: torch.dtype, - scale: float, - input_zp: int = 0, - output_zp: int = 0, -) -> torch.fx.Node: - return graph.create_node( - "call_function", - RESCALE_TARGET, - args=(input_node, output_dtype, [scale], input_zp, output_zp), - ) - - -@pytest.mark.parametrize( - "r1_scale,r2_scale", - [(0.5, 2.0), (0.25, 4.0), (0.125, 8.0)], - ids=["0.5x2.0", "0.25x4.0", "0.125x8.0"], -) -def test_fuse_consecutive_rescales_tosa_INT_identity_pair_removed( - r1_scale: float, r2_scale: float -) -> None: - """Checks that an identity RESCALE pair is removed.""" - graph = torch.fx.Graph() - x = _make_int32_placeholder(graph) - r1 = _make_rescale(graph, x, torch.int8, r1_scale) - r2 = _make_rescale(graph, r1, torch.int32, r2_scale) - graph.output(r2) - graph_module = torch.fx.GraphModule({}, graph) - rescale_count_before = _count_rescales(graph_module) - - result = _run_fuse_pass(graph_module) - - assert rescale_count_before == 2 - assert result.modified - assert _count_rescales(result.graph_module) == 0 - - -def test_fuse_consecutive_rescales_tosa_INT_non_identity_pair_preserved() -> None: - """Checks that a non-identity RESCALE pair is left unchanged.""" - graph = torch.fx.Graph() - x = _make_int32_placeholder(graph) - r1 = _make_rescale(graph, x, torch.int8, 0.5) - r2 = _make_rescale(graph, r1, torch.int32, 3.0) - graph.output(r2) - graph_module = torch.fx.GraphModule({}, graph) - rescale_count_before = _count_rescales(graph_module) - - result = _run_fuse_pass(graph_module) - - assert not result.modified - assert _count_rescales(result.graph_module) == rescale_count_before - - -def test_fuse_consecutive_rescales_tosa_INT_zero_point_mismatch_preserved() -> None: - """Checks that mismatched zero points prevent fusion.""" - graph = torch.fx.Graph() - x = _make_int32_placeholder(graph) - r1 = _make_rescale(graph, x, torch.int8, 0.5, input_zp=0, output_zp=3) - r2 = _make_rescale(graph, r1, torch.int32, 2.0, input_zp=4, output_zp=0) - graph.output(r2) - graph_module = torch.fx.GraphModule({}, graph) - rescale_count_before = _count_rescales(graph_module) - - result = _run_fuse_pass(graph_module) - - assert not result.modified - assert _count_rescales(result.graph_module) == rescale_count_before - - -def test_fuse_consecutive_rescales_tosa_INT_shared_producer_all_rescale_users_removed() -> ( - None -): - """Checks that a shared producer is removed when all users are fuseable.""" - graph = torch.fx.Graph() - x = _make_int32_placeholder(graph) - r1 = _make_rescale(graph, x, torch.int8, 0.5) - r2_left = _make_rescale(graph, r1, torch.int32, 2.0) - r2_right = _make_rescale(graph, r1, torch.int32, 2.0) - graph.output((r2_left, r2_right)) - graph_module = torch.fx.GraphModule({}, graph) - rescale_count_before = _count_rescales(graph_module) - - result = _run_fuse_pass(graph_module) - - assert rescale_count_before == 3 - assert result.modified - assert _count_rescales(result.graph_module) == 0 - - -def test_fuse_consecutive_rescales_tosa_INT_shared_producer_non_rescale_user_preserved() -> ( - None -): - """Checks that a shared producer is kept for remaining non-RESCALE users.""" - graph = torch.fx.Graph() - x = _make_int32_placeholder(graph) - r1 = _make_rescale(graph, x, torch.int8, 0.5) - r2 = _make_rescale(graph, r1, torch.int32, 2.0) - graph.output((r2, r1)) - graph_module = torch.fx.GraphModule({}, graph) - rescale_count_before = _count_rescales(graph_module) - - result = _run_fuse_pass(graph_module) - - assert rescale_count_before == 2 - assert result.modified - assert _count_rescales(result.graph_module) == 1 - remaining_rescales = [ - node for node in result.graph_module.graph.nodes if _is_rescale(node) - ] - assert remaining_rescales == [r1] - - -def test_fuse_consecutive_rescales_tosa_INT_lstm_gate_pattern_pipeline() -> None: - """Checks the LSTM-shaped regression path in the TOSA INT pipeline.""" - model = LSTMGatePattern() - pipeline = TosaPipelineINT[LSTMGatePattern.input_t]( - model, - model.get_test_inputs(), - aten_op=[], - exir_op=[], - use_to_edge_transform_and_lower=True, - frobenius_threshold=None, - cosine_threshold=None, - ) - pipeline.run() - - -@common.XfailIfNoCorstone300 -def test_fuse_consecutive_rescales_u55_INT_lstm_gate_pattern_pipeline() -> None: - """Checks the LSTM-shaped regression path in the U55 INT pipeline.""" - model = LSTMGatePattern() - pipeline = EthosU55PipelineINT[LSTMGatePattern.input_t]( - model, - model.get_test_inputs(), - aten_ops=[], - exir_ops=[], - use_to_edge_transform_and_lower=True, - ) - pipeline.run()