diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 20ead36627c..ea4d49a79bb 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -43,6 +43,9 @@ from .decompose_cumsum_pass import DecomposeCumsumPass # noqa from .decompose_div_pass import DecomposeDivPass # noqa from .decompose_div_tensor_mode import DecomposeDivTensorModePass # noqa +from .decompose_dynamic_adaptive_avg_pool2d_pass import ( # noqa + DecomposeDynamicAdaptiveAvgPool2dPass, +) from .decompose_dynamic_full_pass import DecomposeDynamicFullPass # noqa from .decompose_einsum_pass import DecomposeEinsumPass # noqa from .decompose_elu_pass import ConvertEluFamilyToEluPass, DecomposeEluPass # noqa diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 748c369482f..485e01278d9 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -49,6 +49,7 @@ DecomposeCumsumPass, DecomposeDivPass, DecomposeDivTensorModePass, + DecomposeDynamicAdaptiveAvgPool2dPass, DecomposeDynamicFullPass, DecomposeEinsumPass, DecomposeEluPass, @@ -463,6 +464,7 @@ def _tosa_pipeline( AccumulateIndexPutPass(), DecomposeIndexTensorToGatherPass(), DecomposeAdaptiveAvgPool2dPass(), + DecomposeDynamicAdaptiveAvgPool2dPass(), DecomposeAvgPool2dPass(), Conv1dUnsqueezePass(), ] diff --git a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py index 58fcf69cd8f..07fd5c9e358 100644 --- a/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py +++ b/backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py @@ -12,7 +12,11 @@ from executorch.backends.arm._passes.decompose_avg_pool2d_pass import ( DecomposeAvgPool2dPass, ) - +from executorch.backends.arm.constants import NHWC_INVERSE_ORDER, NHWC_ORDER +from executorch.backends.arm.tosa.specification import ( + get_context_shape_env, + get_context_spec, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, NodeMetadata @@ -37,12 +41,13 @@ def _get_decomposition(op) -> tuple: class DecomposeAdaptiveAvgPool2dPass(ArmOpTargetedPass): - """Decomposes AdaptiveAvgPool2d into AvgPool2d operations. + """Decompose static-shape topology-changing AdaptiveAvgPool2d cases. - An input tensor of shape (N, C, H, W) is transformed into an output tensor - of shape (N, C, output_size_h, output_size_w). + Static input/output shapes use the existing slice + avg_pool2d + cat + lowering, with a fast path for directly representable uniform regions. - The output is of size output_size_h x output_size_w for any input. + Dynamic cases are left untouched for dedicated dynamic rewrite/decompose + passes later in the TOSA pipeline. """ @@ -50,40 +55,288 @@ class DecomposeAdaptiveAvgPool2dPass(ArmOpTargetedPass): target_ops = edge_ops + aten_ops check_allowed_to_transform = True - def call_operator(self, op, args, kwargs, meta, updated=False): - if op not in self.target_ops or not self.allowed_to_transform(meta): - return super().call_operator(op, args, kwargs, meta, updated) + @staticmethod + def _is_static_dim(dim) -> bool: + return not isinstance(dim, torch.SymInt) - avg_pool2d_op, slice_op, cat_op = _get_decomposition(op) + @classmethod + def _is_static_shape(cls, *dims) -> bool: + return all(cls._is_static_dim(dim) for dim in dims) - x = args[0] + @staticmethod + def _has_dynamic_spatial_shape(x) -> bool: _, _, input_size_h, input_size_w = x.data.shape + return isinstance(input_size_h, torch.SymInt) or isinstance( + input_size_w, torch.SymInt + ) + + def _call_const_shape(self, value: int, meta: NodeMetadata): + return super().call_shape_operator( + exir_ops.backend.tosa.CONST_SHAPE.default, + ([value],), + {}, + meta, + True, + ) + + def _get_dim_shape(self, x, axis: int, meta: NodeMetadata): + dim = x.data.shape[axis] + if isinstance(dim, torch.SymInt): + return super().call_shape_operator( + exir_ops.backend.tosa.DIM.default, + (x,), + {"axis": axis}, + meta, + True, + ) + return self._call_const_shape(dim, meta) + + def _shape_mul_const(self, value, factor: int, meta: NodeMetadata): + return super().call_shape_operator( + exir_ops.backend.tosa.MUL_SHAPE.default, + (value, self._call_const_shape(factor, meta)), + {}, + meta, + True, + ) + + def _shape_add_const(self, value, addend: int, meta: NodeMetadata): + return super().call_shape_operator( + exir_ops.backend.tosa.ADD_SHAPE.default, + (value, self._call_const_shape(addend, meta)), + {}, + meta, + True, + ) + + def _shape_floor_div_const(self, value, divisor: int, meta: NodeMetadata): + return super().call_shape_operator( + exir_ops.backend.tosa.DIV_FLOOR_SHAPE.default, + (value, self._call_const_shape(divisor, meta)), + {}, + meta, + True, + ) + + def _shape_sub(self, lhs, rhs, meta: NodeMetadata): + return super().call_shape_operator( + exir_ops.backend.tosa.SUB_SHAPE.default, + (lhs, rhs), + {}, + meta, + True, + ) + + def _shape_concat(self, parts: list, meta: NodeMetadata): + return super().call_shape_operator( + exir_ops.backend.tosa.CONCAT_SHAPE.default, + (parts,), + {}, + meta, + True, + ) + + def _is_directly_representable(self, input_size, output_size) -> bool: + if isinstance(output_size, torch.SymInt): + return False + if self._is_static_dim(input_size): + return input_size % output_size in (0, 1) + + try: + remainder_range = get_context_shape_env().bound_sympy( + (input_size % output_size).node.expr + ) + except Exception: + return False + return remainder_range.is_singleton() and remainder_range.upper in (0, 1) + + def _is_dynamic_direct_case(self, x, output_size_h, output_size_w) -> bool: + _, _, input_size_h, input_size_w = x.data.shape + if not self._has_dynamic_spatial_shape(x): + return False + return self._is_directly_representable( + input_size_h, output_size_h + ) and self._is_directly_representable(input_size_w, output_size_w) + + @staticmethod + def _static_bin_bounds( + input_size: int, output_size: int, out_idx: int + ) -> tuple[int, int]: + start = floor(out_idx * input_size / output_size) + end = ceil((out_idx + 1) * input_size / output_size) + return start, end - (output_size_h, output_size_w) = args[1] + def _symbolic_bin_bounds(self, input_size, output_size: int, out_idx: int, meta): + start_num = self._shape_mul_const(input_size, out_idx, meta) + start = self._shape_floor_div_const(start_num, output_size, meta) - # Vela currently only allows a stride in the interval of [1,3] for AvgPool2d. - # To accommodate this, the AvgPool2d op is applied to pooling regions and the results are concatenated. + end_num = self._shape_mul_const(input_size, out_idx + 1, meta) + end_num = self._shape_add_const(end_num, output_size - 1, meta) + end = self._shape_floor_div_const(end_num, output_size, meta) + + size = self._shape_sub(end, start, meta) + return start, end, size + + def _emit_tosa_slice(self, x, start_h, size_h, start_w, size_w, meta): + start = self._shape_concat( + [ + self._call_const_shape(0, meta), + self._call_const_shape(0, meta), + start_h, + start_w, + ], + meta, + ) + size = self._shape_concat( + [ + self._get_dim_shape(x, 0, meta), + self._get_dim_shape(x, 1, meta), + size_h, + size_w, + ], + meta, + ) + return super().call_operator( + exir_ops.backend.tosa.SLICE.default, + (x, start, size), + {}, + meta, + True, + ) + + def _emit_adaptive_pool(self, x_slice, size_h, size_w, meta): + in_qparams = meta.data.get("input_qparams", {}) + in_zp_val = in_qparams[0].get_zp_per_tensor() if 0 in in_qparams else 0 + input_zp = self.call_scalar(in_zp_val, meta) + + out_qparams = meta.data.get("output_qparams", {}) + out_zp_val = out_qparams[0].get_zp_per_tensor() if 0 in out_qparams else 0 + output_zp = self.call_scalar(out_zp_val, meta) + + acc_type = ( + torch.int32 + if x_slice.data.dtype in (torch.int8, torch.int16) + else torch.float32 + ) + stride = [1, 1] + pad = [0, 0, 0, 0] + x_slice_nhwc = super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (x_slice, list(NHWC_ORDER)), + {}, + meta, + True, + ) + pad = super().call_shape_operator( + exir_ops.backend.tosa.CONST_SHAPE.default, + (pad,), + {}, + meta, + ) + kernel = [size_h, size_w] + if all(isinstance(k, int) for k in kernel): + kernel = super().call_shape_operator( + exir_ops.backend.tosa.CONST_SHAPE.default, + (kernel,), + {}, + meta, + ) + else: + kernel = self._shape_concat( + [ + self._get_dim_shape(x_slice_nhwc, 1, meta), + self._get_dim_shape(x_slice_nhwc, 2, meta), + ], + meta, + ) + if all(isinstance(s, int) for s in stride): + stride = super().call_shape_operator( + exir_ops.backend.tosa.CONST_SHAPE.default, + (stride,), + {}, + meta, + ) + pooled_nhwc = super().call_operator( + exir_ops.backend.tosa.AVG_POOL2D_ADAPTIVE.default, + (x_slice_nhwc, input_zp, output_zp, kernel, stride, pad, acc_type), + {}, + meta, + True, + ) + return super().call_operator( + exir_ops.edge.aten.permute_copy.default, + (pooled_nhwc, list(NHWC_INVERSE_ORDER)), + {}, + meta, + True, + ) + + @staticmethod + def _supports_dynamic_tosa_adaptive() -> bool: + try: + tosa_spec = get_context_spec() + except Exception: + return False + return ( + tosa_spec.version.major == 1 + and tosa_spec.version.minor >= 1 + and tosa_spec.support_extension("shape") + ) + + def _decompose_static( + self, + avg_pool2d_op, + slice_op, + cat_op, + x, + output_size_h, + output_size_w, + kwargs, + meta, + ): + _, _, input_size_h, input_size_w = x.data.shape + + stride_h = floor(input_size_h / output_size_h) + stride_w = floor(input_size_w / output_size_w) + if ( + self._is_directly_representable(input_size_h, output_size_h) + and self._is_directly_representable(input_size_w, output_size_w) + and stride_h in (1, 2, 3) + and stride_w in (1, 2, 3) + ): + kernel_h = stride_h + (input_size_h % output_size_h) + kernel_w = stride_w + (input_size_w % output_size_w) + return super().call_operator( + avg_pool2d_op, + (x, (kernel_h, kernel_w), (stride_h, stride_w), (0, 0)), + kwargs, + meta, + True, + ) - # Slices and concats does not require quantization parameters metadata_dict = dict(meta.data) metadata_dict["input_qparams"] = {} metadata_dict["output_qparams"] = {} meta_with_no_qparams = NodeMetadata(metadata_dict) + res = [] for out_i in range(output_size_h): row = [] for out_j in range(output_size_w): - # Calculate pooling regions - start_h = floor(out_i * input_size_h / output_size_h) - end_h = ceil((out_i + 1) * input_size_h / output_size_h) - start_w = floor(out_j * input_size_w / output_size_w) - end_w = ceil((out_j + 1) * input_size_w / output_size_w) + start_h, end_h = self._static_bin_bounds( + input_size_h, output_size_h, out_i + ) + start_w, end_w = self._static_bin_bounds( + input_size_w, output_size_w, out_j + ) - # Slice along H x_h = super().call_operator( - slice_op, (x, 2, start_h, end_h), kwargs, meta_with_no_qparams, True + slice_op, + (x, 2, start_h, end_h), + kwargs, + meta_with_no_qparams, + True, ) - # Slice along W x_hw = super().call_operator( slice_op, (x_h, 3, start_w, end_w), @@ -92,28 +345,118 @@ def call_operator(self, op, args, kwargs, meta, updated=False): True, ) - # Apply avg pooling with kernel size equal to the pooling region kernel_h = end_h - start_h kernel_w = end_w - start_w - pool_args = (x_hw, (kernel_h, kernel_w), (1, 1), (0, 0)) pooled = super().call_operator( - avg_pool2d_op, pool_args, kwargs, meta, True + avg_pool2d_op, + (x_hw, (kernel_h, kernel_w), (1, 1), (0, 0)), + kwargs, + meta, + True, ) row.append(pooled) - # Concatenate row results along width (dim=3) if more than one. - if len(row) > 1: - row_tensor = super().call_operator( - cat_op, (row, 3), kwargs, meta_with_no_qparams, True + + row_tensor = ( + super().call_operator( + cat_op, + (row, 3), + kwargs, + meta_with_no_qparams, + True, ) - else: - row_tensor = row[0] + if len(row) > 1 + else row[0] + ) res.append(row_tensor) - # Concatenate all rows along height (dim=2) if more than one. - if len(res) > 1: - out = super().call_operator( - cat_op, (res, 2), kwargs, meta_with_no_qparams, True + return ( + super().call_operator( + cat_op, + (res, 2), + kwargs, + meta_with_no_qparams, + True, ) - else: - out = res[0] - return out + if len(res) > 1 + else res[0] + ) + + def _decompose_dynamic_static_output( + self, x, cat_op, output_size_h: int, output_size_w: int, kwargs, meta + ): + metadata_dict = dict(meta.data) + metadata_dict["input_qparams"] = {} + metadata_dict["output_qparams"] = {} + meta_with_no_qparams = NodeMetadata(metadata_dict) + + input_h_shape = self._get_dim_shape(x, 2, meta_with_no_qparams) + input_w_shape = self._get_dim_shape(x, 3, meta_with_no_qparams) + + res = [] + for out_i in range(output_size_h): + row = [] + start_h, _end_h, size_h = self._symbolic_bin_bounds( + input_h_shape, output_size_h, out_i, meta_with_no_qparams + ) + for out_j in range(output_size_w): + start_w, _end_w, size_w = self._symbolic_bin_bounds( + input_w_shape, output_size_w, out_j, meta_with_no_qparams + ) + x_slice = self._emit_tosa_slice( + x, start_h, size_h, start_w, size_w, meta_with_no_qparams + ) + pooled = self._emit_adaptive_pool(x_slice, size_h, size_w, meta) + row.append(pooled) + + row_tensor = ( + super().call_operator( + cat_op, + (row, 3), + kwargs, + meta_with_no_qparams, + True, + ) + if len(row) > 1 + else row[0] + ) + res.append(row_tensor) + + return ( + super().call_operator( + cat_op, + (res, 2), + kwargs, + meta_with_no_qparams, + True, + ) + if len(res) > 1 + else res[0] + ) + + def call_operator(self, op, args, kwargs, meta, updated=False): + if op not in (edge_ops + aten_ops) or not self.allowed_to_transform(meta): + return super().call_operator(op, args, kwargs, meta, updated) + + avg_pool2d_op, slice_op, cat_op = _get_decomposition(op) + x = args[0] + output_size_h, output_size_w = args[1] + + if isinstance(output_size_h, torch.SymInt) or isinstance( + output_size_w, torch.SymInt + ): + return super().call_operator(op, args, kwargs, meta, updated) + + _, _, input_size_h, input_size_w = x.data.shape + if not self._is_static_shape(input_size_h, input_size_w): + return super().call_operator(op, args, kwargs, meta, updated) + + return self._decompose_static( + avg_pool2d_op, + slice_op, + cat_op, + x, + output_size_h, + output_size_w, + kwargs, + meta, + ) diff --git a/backends/arm/_passes/decompose_dynamic_adaptive_avg_pool2d_pass.py b/backends/arm/_passes/decompose_dynamic_adaptive_avg_pool2d_pass.py new file mode 100644 index 00000000000..0bb7ec7c41a --- /dev/null +++ b/backends/arm/_passes/decompose_dynamic_adaptive_avg_pool2d_pass.py @@ -0,0 +1,57 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from executorch.backends.arm._passes.decompose_adaptive_avg_pool2d_pass import ( + _get_decomposition, + aten_ops, + DecomposeAdaptiveAvgPool2dPass, + edge_ops, +) +from executorch.backends.arm._passes.rewrite_adaptive_avg_pool2d import ( + RewriteAdaptiveAvgPool2dPass, +) + + +class DecomposeDynamicAdaptiveAvgPool2dPass(DecomposeAdaptiveAvgPool2dPass): + """Decompose symbolic irregular AdaptiveAvgPool2d to TOSA shape ops. + + Directly representable dynamic cases are left to + ``RewriteAdaptiveAvgPool2dPass``. Static cases stay in + ``DecomposeAdaptiveAvgPool2dPass``. + + """ + + _passes_required_after = {RewriteAdaptiveAvgPool2dPass} + + def call_operator(self, op, args, kwargs, meta, updated=False): + if op not in (edge_ops + aten_ops) or not self.allowed_to_transform(meta): + return super().call_operator(op, args, kwargs, meta, updated) + + x = args[0] + output_size_h, output_size_w = args[1] + if isinstance(output_size_h, torch.SymInt) or isinstance( + output_size_w, torch.SymInt + ): + return super().call_operator(op, args, kwargs, meta, updated) + + if not self._has_dynamic_spatial_shape(x): + return super().call_operator(op, args, kwargs, meta, updated) + + if self._is_dynamic_direct_case(x, output_size_h, output_size_w): + return super().call_operator(op, args, kwargs, meta, updated) + + if not self._supports_dynamic_tosa_adaptive(): + return super().call_operator(op, args, kwargs, meta, updated) + + _, _, input_size_h, input_size_w = x.data.shape + if self._is_static_shape(input_size_h, input_size_w): + return super().call_operator(op, args, kwargs, meta, updated) + + _, _, cat_op = _get_decomposition(op) + return self._decompose_dynamic_static_output( + x, cat_op, output_size_h, output_size_w, kwargs, meta + ) diff --git a/backends/arm/_passes/insert_dynamic_padding.py b/backends/arm/_passes/insert_dynamic_padding.py index 22de1262e83..bfc0382e4ad 100644 --- a/backends/arm/_passes/insert_dynamic_padding.py +++ b/backends/arm/_passes/insert_dynamic_padding.py @@ -31,6 +31,7 @@ class InsertDynamicPaddingPass(ArmOpTargetedPass): exir_ops.backend.tosa.CONV2D.default, exir_ops.backend.tosa.DEPTHWISE_CONV2D.default, exir_ops.backend.tosa.MAX_POOL2D.default, + exir_ops.backend.tosa.AVG_POOL2D.default, ) def _is_dynamic_padding( @@ -48,6 +49,8 @@ def call_operator(self, op, args, kwargs, meta, updated=False) -> ProxyValue: return super().call_operator(op, args, kwargs, meta, updated) if op == exir_ops.backend.tosa.MAX_POOL2D.default: padding_index = 3 + elif op == exir_ops.backend.tosa.AVG_POOL2D.default: + padding_index = 5 else: padding_index = 4 padding = args[padding_index] diff --git a/backends/arm/test/passes/test_decompose_adaptive_avg_pool2d_pass.py b/backends/arm/test/passes/test_decompose_adaptive_avg_pool2d_pass.py new file mode 100644 index 00000000000..abb54b4f4de --- /dev/null +++ b/backends/arm/test/passes/test_decompose_adaptive_avg_pool2d_pass.py @@ -0,0 +1,191 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm._passes.decompose_adaptive_avg_pool2d_pass import ( + DecomposeAdaptiveAvgPool2dPass, +) +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir import to_edge +from torch.export import export + +input_t = Tuple[torch.Tensor] + + +class AdaptiveAvgPoolUniform(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 8, 8),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.adaptive_avg_pool2d(x, (4, 4)) + + +class AdaptiveAvgPoolIrregular(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 7, 7),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.adaptive_avg_pool2d(x, (4, 4)) + + +class AdaptiveAvgPoolLargeStride(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 32, 32),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.adaptive_avg_pool2d(x, (4, 4)) + + +class AdaptiveAvgPoolAsymmetric(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 9, 13),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.adaptive_avg_pool2d(x, (2, 3)) + + +class AdaptiveAvgPoolKeepWidth(torch.nn.Module): + def get_inputs(self) -> input_t: + return (torch.rand(1, 3, 10, 16),) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.adaptive_avg_pool2d(x, (2, None)) + + +def _run_static_decomposition(module: torch.nn.Module, inputs: input_t): + ep = export(module, inputs) + edge_model = to_edge(ep) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+FP")): + edge_model = edge_model.transform([DecomposeAdaptiveAvgPool2dPass()]) + return edge_model.exported_program().graph_module + + +def test_decompose_adaptive_avg_pool2d_uniform_regions_rewrite_to_avg_pool2d(): + module = AdaptiveAvgPoolUniform() + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(), + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten__adaptive_avg_pool2d_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 1, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten__adaptive_avg_pool2d_default", + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", + "executorch_exir_dialects_edge__ops_aten_cat_default", + "executorch_exir_dialects_backend__ops_tosa_AVG_POOL2D_ADAPTIVE_default", + ], + pass_list=[DecomposeAdaptiveAvgPool2dPass], + ) + pipeline.run() + + +def test_decompose_adaptive_avg_pool2d_no_target_irregular_regions(): + module = AdaptiveAvgPoolIrregular() + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(), + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten__adaptive_avg_pool2d_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 16, + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 32, + "executorch_exir_dialects_edge__ops_aten_cat_default": 5, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten__adaptive_avg_pool2d_default", + ], + pass_list=[DecomposeAdaptiveAvgPool2dPass], + ) + pipeline.run() + + +def test_decompose_adaptive_avg_pool2d_no_target_large_stride_still_decomposes(): + module = AdaptiveAvgPoolLargeStride() + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(), + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten__adaptive_avg_pool2d_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 16, + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 32, + "executorch_exir_dialects_edge__ops_aten_cat_default": 5, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten__adaptive_avg_pool2d_default", + ], + pass_list=[DecomposeAdaptiveAvgPool2dPass], + ) + pipeline.run() + + +def test_decompose_adaptive_avg_pool2d_asymmetric_regions_compare_numerically(): + module = AdaptiveAvgPoolAsymmetric() + inputs = ( + torch.arange(1, 1 + 1 * 3 * 9 * 13, dtype=torch.float32).reshape(1, 3, 9, 13), + ) + transformed = _run_static_decomposition(module, inputs) + + reference = module(*inputs) + result = transformed(*inputs) + if isinstance(result, tuple): + result = result[0] + + assert torch.allclose(result, reference) + + +def test_decompose_adaptive_avg_pool2d_asymmetric_regions_decompose(): + module = AdaptiveAvgPoolAsymmetric() + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(), + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten__adaptive_avg_pool2d_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 6, + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 12, + "executorch_exir_dialects_edge__ops_aten_cat_default": 3, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten__adaptive_avg_pool2d_default", + "executorch_exir_dialects_backend__ops_tosa_AVG_POOL2D_ADAPTIVE_default", + ], + pass_list=[DecomposeAdaptiveAvgPool2dPass], + ) + pipeline.run() + + +def test_decompose_adaptive_avg_pool2d_keep_width_decompose(): + module = AdaptiveAvgPoolKeepWidth() + pipeline = PassPipeline[input_t]( + module, + module.get_inputs(), + ops_before_pass={ + "executorch_exir_dialects_edge__ops_aten__adaptive_avg_pool2d_default": 1, + }, + ops_after_pass={ + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default": 32, + "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 64, + "executorch_exir_dialects_edge__ops_aten_cat_default": 3, + }, + ops_not_after_pass=[ + "executorch_exir_dialects_edge__ops_aten__adaptive_avg_pool2d_default", + "executorch_exir_dialects_backend__ops_tosa_AVG_POOL2D_ADAPTIVE_default", + ], + pass_list=[DecomposeAdaptiveAvgPool2dPass], + ) + pipeline.run() diff --git a/backends/arm/test/passes/test_decompose_dynamic_adaptive_avg_pool2d_pass.py b/backends/arm/test/passes/test_decompose_dynamic_adaptive_avg_pool2d_pass.py new file mode 100644 index 00000000000..0e6a1e81b78 --- /dev/null +++ b/backends/arm/test/passes/test_decompose_dynamic_adaptive_avg_pool2d_pass.py @@ -0,0 +1,93 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from executorch.backends.arm._passes.decompose_dynamic_adaptive_avg_pool2d_pass import ( + DecomposeDynamicAdaptiveAvgPool2dPass, +) +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir import to_edge +from executorch.exir.dialects._ops import ops as exir_ops +from torch._export.utils import _get_shape_env_from_gm +from torch.export import Dim, export + +input_t = Tuple[torch.Tensor] + + +class AdaptiveAvgPoolDynamic(torch.nn.Module): + def __init__(self, output_size: tuple[int | None, int | None] = (4, 4)): + super().__init__() + self.output_size = output_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size) + + +def _run_dynamic_decomposition(dynamic_shapes, output_size=(4, 4)): + module = AdaptiveAvgPoolDynamic(output_size) + example_inputs = (torch.rand(1, 3, 8, 8),) + ep = export(module, example_inputs, dynamic_shapes=dynamic_shapes) + edge_model = to_edge(ep) + shape_env = _get_shape_env_from_gm(edge_model.exported_program().graph_module) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env=shape_env + ): + edge_model = edge_model.transform([DecomposeDynamicAdaptiveAvgPool2dPass()]) + return list(edge_model.exported_program().graph.nodes) + + +def test_decompose_dynamic_adaptive_avg_pool2d_irregular_uses_tosa_adaptive(): + nodes = _run_dynamic_decomposition( + { + "x": { + 2: Dim("height", min=4, max=10), + 3: Dim("width", min=4, max=10), + } + } + ) + + assert not any( + n.target == exir_ops.edge.aten._adaptive_avg_pool2d.default for n in nodes + ) + assert ( + sum( + n.target == exir_ops.backend.tosa.AVG_POOL2D_ADAPTIVE.default for n in nodes + ) + == 16 + ) + assert sum(n.target == exir_ops.backend.tosa.SLICE.default for n in nodes) == 16 + assert sum(n.target == exir_ops.edge.aten.permute_copy.default for n in nodes) == 32 + assert any(n.target == exir_ops.backend.tosa.DIM.default for n in nodes) + assert any(n.target == exir_ops.backend.tosa.DIV_FLOOR_SHAPE.default for n in nodes) + assert any(n.target == exir_ops.backend.tosa.SUB_SHAPE.default for n in nodes) + assert any(n.target == exir_ops.backend.tosa.CONCAT_SHAPE.default for n in nodes) + + +def test_rewrite_adaptive_avg_pool2d_does_not_require_dynamic_decompose_pass(): + from executorch.backends.arm._passes.rewrite_adaptive_avg_pool2d import ( + RewriteAdaptiveAvgPool2dPass, + ) + + assert ( + DecomposeDynamicAdaptiveAvgPool2dPass + not in RewriteAdaptiveAvgPool2dPass._passes_required_after + ) + + +def test_decompose_dynamic_adaptive_avg_pool2d_requires_rewrite_adaptive_avg_pool2d(): + from executorch.backends.arm._passes.rewrite_adaptive_avg_pool2d import ( + RewriteAdaptiveAvgPool2dPass, + ) + + assert ( + RewriteAdaptiveAvgPool2dPass + in DecomposeDynamicAdaptiveAvgPool2dPass._passes_required_after + )