diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 0f32fbb52df..443a06e6b53 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -16,6 +16,7 @@ ) from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.exir import ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -187,12 +188,75 @@ def memory_format_differs(shape, spatial_rank): channel_dim = shape[channel_idx] return channel_dim > 1 and any(dim > 1 for dim in spatial_dims) + @staticmethod + def _is_nhwc_safe_reshape( + input_shape, output_shape, cl_order: tuple[int, ...] + ) -> bool: + """Return ``True`` when a 4-D+ reshape can operate directly on NHWC + data. + + A reshape is NHWC-safe when its shape_indices are monotonic, both the + batch dimension (index 0) and the channel dimension (last index) are + preserved alone in their output groups, and every merged group contains + only dims that are contiguous in the NHWC physical layout. + + """ + rank_in = len(input_shape) + rank_out = len(output_shape) + if rank_in < 4 or rank_out < 4: + return False + + indices = ToTosaMemoryFormatPass._get_shape_indices( + list(input_shape), list(output_shape) + ) + if indices is None or not ToTosaMemoryFormatPass._is_monotonic(indices): + return False + + # The channel dim (last axis in NHWC) and batch dim (index 0) + # must each appear alone — merging either with spatial dims + # would reorder data or change element pairing semantics. 
+ channel_idx = rank_in - 1 + batch_idx = 0 + for group in indices: + if channel_idx in group and len(group) != 1: + return False + if batch_idx in group and len(group) != 1: + return False + + batch_found = any(batch_idx in g for g in indices) + channel_found = any(channel_idx in g for g in indices) + if not (batch_found and channel_found): + return False + + # Merged dims must be contiguous in the NHWC physical layout. + # The TOSA RESHAPE operates on row-major data in NHWC order, + # so only dims adjacent in that order can be validly merged. + nhwc_pos = [0] * rank_in + for pos, dim in enumerate(cl_order): + nhwc_pos[dim] = pos + for group in indices: + if len(group) <= 1: + continue + positions = sorted(nhwc_pos[d] for d in group) + for i in range(1, len(positions)): + if positions[i] != positions[i - 1] + 1: + return False + + return True + @staticmethod def is_channel_reshape( input_shape, output_shape, input_spatial_rank, output_spatial_rank ): """Check whether a reshape touches the logical channel or consolidated - batch dimensions, which would invalidate dim-order annotations. + batch dimensions in a way that would invalidate dim-order annotations. + + Returns ``False`` (no transposes needed) when either: + - The reshape does not change the channel or batch dimensions at all, OR + - The reshape is NHWC-safe: monotonic shape_indices with both batch + (index 0) and channel (last index) preserved alone in their output + groups, meaning the view_copy can operate directly on NHWC data. + """ valid_ranks = {4, 5, 6} @@ -220,7 +284,27 @@ def get_batch_prod_dim(shape, spatial_rank): N_old = get_batch_prod_dim(input_shape, input_spatial_rank) N_new = get_batch_prod_dim(output_shape, output_spatial_rank) - return (N_old != N_new) or (C_old != C_new) + if (N_old == N_new) and (C_old == C_new): + return False + + # The reshape touches batch/channel dims — check whether it is + # NHWC-safe (can operate directly on NHWC data without transposes). 
+ # This optimisation is only valid when both tensors use the same + # channels-last permutation; when the spatial rank changes relative + # to the tensor rank the NHWC axis mapping differs and the reshape + # would scramble data. + in_cl = ToTosaMemoryFormatPass._channels_last_order( + len(input_shape), input_spatial_rank + ) + out_cl = ToTosaMemoryFormatPass._channels_last_order( + len(output_shape), output_spatial_rank + ) + if in_cl == out_cl and ToTosaMemoryFormatPass._is_nhwc_safe_reshape( + input_shape, output_shape, in_cl + ): + return False + + return True @staticmethod def insert_input_transpose(node, input_node, graph_module): @@ -271,7 +355,7 @@ def insert_output_transpose(node, graph_module): # Guard: mem_format must be a true permutation for the current rank assert sorted(mem_format) == list( range(rank) - ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose" + ), f"bad perm {mem_format} for rank {rank} in insert_output_transpose" with graph_module.graph.inserting_after(node): permute_node = create_node( @@ -296,6 +380,65 @@ def insert_output_transpose(node, graph_module): for user in users: user.replace_input_with(node, permute_node) + @staticmethod + def _get_shape_indices( + src_shape: list[int], tgt_shape: list[int] + ) -> list[list[int]] | None: + """Greedy dimension matching for reshape operations. + + For each target dimension, greedily consumes contiguous source + dimensions whose product equals the target size. Size-1 target + dimensions that do not correspond to any source dimension produce + empty index lists (inserted dims). + + Returns ``None`` when no valid mapping exists. 
+ + """ + src_idx = 0 + result: list[list[int]] = [] + + for tgt_dim in tgt_shape: + if tgt_dim <= 0: + return None + + indices: list[int] = [] + remaining = tgt_dim + + while src_idx < len(src_shape): + if src_shape[src_idx] == 0: + return None + if remaining % src_shape[src_idx] != 0: + break + indices.append(src_idx) + remaining //= src_shape[src_idx] + src_idx += 1 + if remaining == 1: + break + + if remaining != 1: + return None + + result.append(indices) + + if src_idx != len(src_shape): + return None + + return result + + @staticmethod + def _is_monotonic(indices: list[list[int]]) -> bool: + """Return ``True`` when all non-empty index groups are strictly ordered + — i.e. each group's indices follow the previous group's. + """ + last_max = -1 + for group in indices: + if not group: + continue + if group[0] <= last_max: + return False + last_max = group[-1] + return True + @staticmethod def _insert_view_transpose( input_shape, output_shape, node, input_node, graph_module @@ -329,6 +472,110 @@ def _insert_view_transpose( ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr): ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module) + @staticmethod + def _is_input_channels_last(input_node: torch.fx.Node, cl_order: list[int]) -> bool: + """Return True if *input_node* is already in channels-last order. + + Only when the input is in NHWC does a cl_order/cl_inv permute duplicate + the tosa_dim_order annotation. When the input is in NCHW (e.g. from a + placeholder or non-spatial op) the permute is the model's intended + computation and must be kept. + + """ + input_dim_order = input_node.meta.get("tosa_dim_order") + if input_dim_order is None: + return True + return list(input_dim_order) == cl_order + + @staticmethod + def _is_semantic_permute(input_node: torch.fx.Node) -> bool: + """Return True if the permute's input traces back to a shape- + manipulation op through transpose/permute nodes. 
+ + Walk upstream through tosa.TRANSPOSE, aten.permute_copy and aten.permute + nodes (chained permutes arise from decomposition passes, e.g. unfold -> + as_strided + movedim -> permute_copy). If a shape-manipulation op is + found, the permute is semantic, not a format conversion. + + """ + upstream: torch.fx.Node | object = input_node + while isinstance(upstream, torch.fx.Node) and upstream.target in ( + exir_ops.backend.tosa.TRANSPOSE.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.permute.default, + ): + upstream = upstream.args[0] + return isinstance(upstream, torch.fx.Node) and upstream.target in ( + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.reshape.default, + exir_ops.edge.aten.as_strided.default, + exir_ops.edge.aten.as_strided_copy.default, + ) + + def _try_replace_redundant_permute( + self, node: torch.fx.Node, graph_module: torch.fx.GraphModule + ) -> bool: + """Replace a permute_copy that duplicates tosa_dim_order. + + When a permute_copy's permutation matches the channels-last order + (or its inverse) AND the input is already in NHWC dim_order, the + permute does the same NCHW<>NHWC conversion that tosa_dim_order + already handles — keeping both would double-convert. Replace the + permute with a view_copy to the permute's output shape. + + Returns ``True`` if the node was replaced. 
+ + """ + if node.target not in ( + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.permute.default, + ): + return False + + perm_arg = node.args[1] + assert isinstance(perm_arg, (list, tuple)) + perm = list(perm_arg) + rank = len(perm) + sr = node.meta.get("tosa_spatial_rank", 0) + + if rank < 3 or sr < 1: + return False + + cl_order = list(self._channels_last_order(rank, sr)) + cl_inv = list(self._channels_last_inverse_order(rank, sr)) + if perm != cl_order and perm != cl_inv: + return False + + input_node = node.args[0] + if not isinstance(input_node, torch.fx.Node): + return False + + if not self._is_input_channels_last(input_node, cl_order): + return False + + if self._is_semantic_permute(input_node): + return False + + output_shape = list(node.meta["val"].shape) + with graph_module.graph.inserting_before(node): + const_shape_node = graph_module.graph.call_function( + exir_ops.backend.tosa.CONST_SHAPE.default, + (output_shape,), + ) + const_shape_node.meta["val"] = output_shape + const_shape_node.meta["tosa_dim_order"] = node.meta.get( + "tosa_dim_order", tuple(range(rank)) + ) + const_shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE + view_node = graph_module.graph.call_function( + exir_ops.edge.aten.view_copy.default, + (input_node, const_shape_node), + ) + view_node.meta = dict(node.meta) + node.replace_all_uses_with(view_node) + graph_module.graph.erase_node(node) + return True + def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): """Transposes are needed for operators transforming the input to a different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC- @@ -345,12 +592,15 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): - 1D/2D tensors """ - for node in graph_module.graph.nodes: + for node in list(graph_module.graph.nodes): if node.op != "call_function": continue + if self._try_replace_redundant_permute(node, graph_module): + continue + # Transpose views - elif 
node.target == exir_ops.edge.aten.view_copy.default: + if node.target == exir_ops.edge.aten.view_copy.default: input_node = node.args[0] input_shape = input_node.meta["val"].shape output_shape = node.meta["val"].shape diff --git a/backends/arm/operators/op_tosa_conv3d.py b/backends/arm/operators/op_tosa_conv3d.py index c033314f9a7..d0ad9bf977a 100644 --- a/backends/arm/operators/op_tosa_conv3d.py +++ b/backends/arm/operators/op_tosa_conv3d.py @@ -15,7 +15,7 @@ class Conv3dVisitor(Conv2dVisitor): target = "tosa.CONV3D.default" def _get_tosa_op(self): - import serializer.tosa_serializer as ts # type: ignore + import tosa_serializer as ts # type: ignore return ts.Op.CONV3D diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py index dfd57aa7e61..49dafc94a65 100644 --- a/backends/arm/test/passes/test_to_tosa_memory_format.py +++ b/backends/arm/test/passes/test_to_tosa_memory_format.py @@ -177,11 +177,77 @@ def get_inputs(self) -> input_t: return (torch.rand(4, 4, 4, 4),) +class NHWCSafeSpatialMerge(torch.nn.Module): + """Test-module with a 4D->4D reshape that merges NCHW dims 1 and 2. + + For models with view_copy shapes [1,2,14,72]->[1,28,1,72] where C=2 + sits at NCHW position 1. Dims 1 and 2 map to NHWC positions 3 and 1 + (not contiguous), so the reshape is NOT NHWC-safe and transposes are + inserted around the view_copy. + + Setup: conv2d (forces NHWC, C=2) -> view_copy -> add (keeps in NHWC). 
+ + """ + + ops_before_pass: Dict[str, int] = {} + # 2 I/O transposes for conv + 2 for view_copy (NHWC-unsafe merge) + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4 + } + ops_not_after_pass: List[str] = [] + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=2, out_channels=2, kernel_size=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) # forces NHWC path; output [1, 2, 14, 72] + x = x.view(1, 28, 1, 72) # merges C*H=2*14->28 (NCHW dims 1, 2); W=72 preserved + return x + x # keep result 4-D in NHWC + + def get_inputs(self) -> input_t: + return (torch.randn(1, 2, 14, 72),) + + +class NHWCUnsafeChannelChange(torch.nn.Module): + """Test-module with a 4D->4D reshape that is NOT NHWC-safe because the + target shape cannot be produced by a monotonic merge of NHWC input dims. + + The pass MUST still insert transposes around the view_copy. + + """ + + ops_before_pass: Dict[str, int] = {} + # conv I/O transposes (2) + view_copy transposes (2) = 4 + ops_after_pass: Dict[str, int] = { + "executorch_exir_dialects_backend__ops_tosa_TRANSPOSE_default": 4 + } + ops_not_after_pass: List[str] = [] + + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=72, out_channels=72, kernel_size=1, bias=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) # output [1, 72, 2, 14] + x = x.view(1, 14, 2, 72) # not NHWC-safe (channels shuffled) + return x + x + + def get_inputs(self) -> input_t: + return (torch.randn(1, 72, 2, 14),) + + modules: Dict[str, ModuleMetadata] = { "no_nhwc": NoNHWC(), "parallel_clusters": ParallelClusters(), "serial_clusters": SerialClusters(), "reshapes": Reshapes(), + "nhwc_safe_spatial_merge": NHWCSafeSpatialMerge(), + "nhwc_unsafe_channel_change": NHWCUnsafeChannelChange(), } @@ -209,3 +275,79 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: 
ModuleMetadata) -> No module_nn = cast(torch.nn.Module, module) pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), []) pipeline.run() + + +# --- Direct unit tests for NHWC-safe reshape detection in is_channel_reshape --- + + +def test_get_shape_indices_spatial_merge(): + """[1,2,14,72] -> [1,28,1,72]: merge H*W, insert size-1 dim, preserve C.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([1, 2, 14, 72], [1, 28, 1, 72]) + assert indices == [[0], [1, 2], [], [3]] + + +def test_get_shape_indices_identity(): + """Same shape => each dim maps to itself.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [2, 3, 4]) + assert indices == [[0], [1], [2]] + + +def test_get_shape_indices_full_merge(): + """[2, 3, 4] -> [24]: merge all dims into one.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 4], [24]) + assert indices == [[0, 1, 2]] + + +def test_get_shape_indices_incompatible(): + """Sizes that don't divide => None.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([2, 3, 5], [6, 4]) + assert indices is None + + +def test_get_shape_indices_size_one_insert(): + """[6, 4] -> [6, 1, 4]: inserted size-1 dim in the middle.""" + indices = ToTosaMemoryFormatPass._get_shape_indices([6, 4], [6, 1, 4]) + assert indices is not None + assert indices == [[0], [], [1]] + + +def test_is_monotonic_true(): + assert ToTosaMemoryFormatPass._is_monotonic([[0], [1, 2], [], [3]]) + assert ToTosaMemoryFormatPass._is_monotonic([[0], [], [1], [2, 3]]) + assert ToTosaMemoryFormatPass._is_monotonic([[], [0, 1, 2]]) + + +def test_is_monotonic_false(): + assert not ToTosaMemoryFormatPass._is_monotonic([[1], [0]]) + assert not ToTosaMemoryFormatPass._is_monotonic([[0, 2], [1]]) + + +def test_channel_reshape_nhwc_unsafe_merge(): + """[1,2,14,72] -> [1,28,1,72] merges NCHW dims 1 and 2. + + Dims 1,2 map to NHWC positions 3,1 (not contiguous), so reshape is NOT NHWC- + safe. is_channel_reshape should return True. 
+ + """ + assert ToTosaMemoryFormatPass.is_channel_reshape( + [1, 2, 14, 72], [1, 28, 1, 72], input_spatial_rank=2, output_spatial_rank=2 + ) + + +def test_channel_reshape_non_4d(): + """Reshapes below rank 4 always return False from is_channel_reshape.""" + assert not ToTosaMemoryFormatPass.is_channel_reshape( + [6, 4], [24], input_spatial_rank=0, output_spatial_rank=0 + ) + + +def test_channel_reshape_batch_merge(): + """Reshapes merging batch with spatial dims are NOT NHWC-safe.""" + # [1,2,5,10] -> [2,1,5,10]: merges N(=1) with H(=2) — not safe + assert ToTosaMemoryFormatPass.is_channel_reshape( + [1, 2, 5, 10], [2, 1, 5, 10], input_spatial_rank=2, output_spatial_rank=2 + ) + # [5,10,25,20] -> [1250,20,1,1]: merges N+H+W — not safe (Linear decomp) + assert ToTosaMemoryFormatPass.is_channel_reshape( + [5, 10, 25, 20], [1250, 20, 1, 1], input_spatial_rank=2, output_spatial_rank=2 + ) diff --git a/backends/arm/test/targets.bzl b/backends/arm/test/targets.bzl index 85367b10ae8..0ae5a3cc817 100644 --- a/backends/arm/test/targets.bzl +++ b/backends/arm/test/targets.bzl @@ -19,18 +19,21 @@ def define_arm_tests(): "ops/test_avg_pool2d.py", "ops/test_cat.py", "ops/test_conv2d.py", + "ops/test_conv3d.py", + "ops/test_cos.py", "ops/test_linear.py", + "ops/test_max_pool.py", "ops/test_max_pool1d.py", "ops/test_mul.py", "ops/test_permute.py", "ops/test_rsqrt.py", - "ops/test_slice.py", "ops/test_sigmoid.py", + "ops/test_slice.py", "ops/test_sub.py", "ops/test_tanh.py", - "ops/test_view.py", - "ops/test_cos.py", "ops/test_to_copy.py", + "ops/test_unsqueeze.py", + "ops/test_view.py", ] # Quantization