Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 208 additions & 3 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
)
from executorch.backends.arm.constants import NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER
from executorch.backends.arm.tosa.dialect.shape import is_shape_op_node
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
Expand Down Expand Up @@ -271,7 +272,7 @@ def insert_output_transpose(node, graph_module):
# Guard: mem_format must be a true permutation for the current rank
assert sorted(mem_format) == list(
range(rank)
), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"

with graph_module.graph.inserting_after(node):
permute_node = create_node(
Expand All @@ -296,6 +297,122 @@ def insert_output_transpose(node, graph_module):
for user in users:
user.replace_input_with(node, permute_node)

@staticmethod
def _get_shape_indices(
src_shape: list[int], tgt_shape: list[int]
) -> list[list[int]] | None:
"""Greedy dimension matching for reshape operations.
For each target dimension, greedily consumes contiguous source
dimensions whose product equals the target size. Size-1 target
dimensions that do not correspond to any source dimension produce
empty index lists (inserted dims).
Returns ``None`` when no valid mapping exists.
"""
src_idx = 0
result: list[list[int]] = []

for tgt_dim in tgt_shape:
if tgt_dim <= 0:
return None

indices: list[int] = []
remaining = tgt_dim

while src_idx < len(src_shape):
if src_shape[src_idx] == 0:
return None
if remaining % src_shape[src_idx] != 0:
break
indices.append(src_idx)
remaining //= src_shape[src_idx]
src_idx += 1
if remaining == 1:
break

if remaining != 1:
return None

result.append(indices)

if src_idx != len(src_shape):
return None

return result

@staticmethod
def _is_monotonic(indices: list[list[int]]) -> bool:
"""Return ``True`` when all non-empty index groups are strictly ordered
— i.e. each group's indices follow the previous group's.
"""
last_max = -1
for group in indices:
if not group:
continue
if group[0] <= last_max:
return False
last_max = group[-1]
return True

@staticmethod
def _is_nhwc_safe_reshape(
    input_shape, output_shape, input_sr, output_sr  # noqa: ARG004
) -> bool:
    """Decide whether a 4-D+ reshape can run directly on NHWC data.

    By the time ``ToTosaMemoryFormatPass`` runs, 4-D tensor shapes in
    ``meta["val"]`` are already in NHWC physical order, so the shape
    indices are computed on the raw input/output shapes without any
    extra permutation.

    Returns ``True`` only when the reshape has monotonic shape
    indices (each output dim maps to a contiguous, in-order group of
    input dims) and both the batch dim and the channel dim survive
    unmerged.
    """
    if len(input_shape) < 4 or len(output_shape) < 4:
        return False

    mapping = ToTosaMemoryFormatPass._get_shape_indices(
        list(input_shape), list(output_shape)
    )
    if mapping is None:
        return False
    if not ToTosaMemoryFormatPass._is_monotonic(mapping):
        return False

    # Physical memory order here is NHWC, so the channel dim is the
    # last axis. It must sit alone in its output group: merging it
    # with spatial dims would reorder channel data and invalidate
    # the optimisation. The batch dim (index 0) must likewise stay
    # alone — merging batch with spatial dims changes element
    # pairing for downstream ops (e.g. conv2d after
    # DecomposeLinearPass, or permutes turned into view_copy by
    # ConvertPermuteSingletonToViewPass).
    channel_axis = len(input_shape) - 1
    for grp in mapping:
        if len(grp) != 1 and (channel_axis in grp or 0 in grp):
            return False

    # Both batch and channel dims must actually have been consumed.
    has_batch = any(0 in grp for grp in mapping)
    has_channel = any(channel_axis in grp for grp in mapping)
    return has_batch and has_channel

@staticmethod
def _insert_view_transpose(
input_shape, output_shape, node, input_node, graph_module
Expand All @@ -317,6 +434,14 @@ def _insert_view_transpose(
output_sr,
)

# When the NHWC-space reshape has monotonic shape_indices the
# view_copy can operate directly on NHWC data — no transposes
# are needed.
if channel_reshape and ToTosaMemoryFormatPass._is_nhwc_safe_reshape(
input_shape, output_shape, input_sr, output_sr
):
return

if (
channel_reshape or nhwc_to_nchw
) and ToTosaMemoryFormatPass.memory_format_differs(input_shape, input_sr):
Expand All @@ -329,6 +454,83 @@ def _insert_view_transpose(
) and ToTosaMemoryFormatPass.memory_format_differs(output_shape, output_sr):
ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)

def _try_replace_redundant_permute(
    self, node: torch.fx.Node, graph_module: torch.fx.GraphModule
) -> bool:
    """Remove a permute_copy if it duplicates tosa_dim_order.

    When a permute_copy's permutation matches the channels-last order
    (or its inverse) AND the input is already in NHWC dim_order, the
    permute does the same NCHW<>NHWC conversion that tosa_dim_order
    already handles — keeping both would double-convert. Remove the
    permute by wiring its users directly to its input.

    Args:
        node: Candidate FX node; anything other than a permute op is
            rejected immediately.
        graph_module: Graph being rewritten in place.

    Returns:
        ``True`` if the node was removed (replaced by a view_copy).
    """
    # Only permute ops are candidates; everything else passes through.
    if node.target not in (
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.permute.default,
    ):
        return False

    # Second positional arg of a permute op is the permutation.
    # NOTE(review): assert is stripped under -O — presumably the op
    # schema guarantees a list/tuple here; confirm.
    perm_arg = node.args[1]
    assert isinstance(perm_arg, (list, tuple))
    perm = list(perm_arg)
    rank = len(perm)
    # Spatial rank annotation; 0 when the node carries none.
    sr = node.meta.get("tosa_spatial_rank", 0)

    # Below rank 3, or without spatial dims, there is no NHWC
    # conversion for the permute to duplicate.
    if rank < 3 or sr < 1:
        return False

    # A permute is only redundant if it is exactly the channels-last
    # conversion (or its inverse) for this rank/spatial-rank.
    cl_order = list(self._channels_last_order(rank, sr))
    cl_inv = list(self._channels_last_inverse_order(rank, sr))
    if perm != cl_order and perm != cl_inv:
        return False

    # Only replace when the permute is genuinely redundant with the
    # tosa_dim_order annotation. When the input is already in
    # channels-last order (NHWC), any permute matching cl_order or
    # cl_inv is a format conversion that tosa_dim_order already
    # handles — keeping both would double-convert.
    #
    # When the input is in NCHW (e.g., from a placeholder or a
    # non-spatial op), the permute is the model's intended
    # computation and must NOT be replaced.
    input_node = node.args[0]
    if not isinstance(input_node, torch.fx.Node):
        return False
    input_dim_order = input_node.meta.get("tosa_dim_order")
    if input_dim_order is not None:
        if list(input_dim_order) != cl_order:
            return False
    # NOTE(review): when the producer carries no tosa_dim_order
    # annotation the permute is still replaced — confirm un-annotated
    # inputs are always channels-last at this point in the pipeline.

    # The permute is redundant — tosa_dim_order already handles
    # the format conversion. Replace with a view_copy (identity
    # reshape to the permuted shape) so consumers still see the
    # correct shape. The view_copy must NOT be further processed
    # by _insert_view_transpose (it's a no-op reshape, not a
    # channel-crossing reshape).
    output_shape = list(node.meta["val"].shape)
    with graph_module.graph.inserting_before(node):
        const_shape_node = graph_module.graph.call_function(
            exir_ops.backend.tosa.CONST_SHAPE.default,
            (output_shape,),
        )
        # NOTE(review): "val" is set to a plain Python list here,
        # unlike tensor-backed nodes — confirm downstream consumers
        # accept this for CONST_SHAPE nodes.
        const_shape_node.meta["val"] = output_shape
        const_shape_node.meta["tosa_dim_order"] = node.meta.get(
            "tosa_dim_order", tuple(range(rank))
        )
        # Mark the node as carrying a TOSA shape value, not tensor data.
        const_shape_node.meta[TosaSpecialDtype.meta_key()] = TosaSpecialDtype.SHAPE
        view_node = graph_module.graph.call_function(
            exir_ops.edge.aten.view_copy.default,
            (input_node, const_shape_node),
        )
    # Shallow copy of the permute's meta so the view keeps the
    # original shape/dim-order annotations.
    view_node.meta = dict(node.meta)
    node.replace_all_uses_with(view_node)
    graph_module.graph.erase_node(node)
    return True

def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
"""Transposes are needed for operators transforming the input to a
different rank, as 4D and 5D-tensors are assumed to be in (N)NHWC-
Expand All @@ -345,12 +547,15 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
- 1D/2D tensors
"""
for node in graph_module.graph.nodes:
for node in list(graph_module.graph.nodes):
if node.op != "call_function":
continue

if self._try_replace_redundant_permute(node, graph_module):
continue

# Transpose views
elif node.target == exir_ops.edge.aten.view_copy.default:
if node.target == exir_ops.edge.aten.view_copy.default:
input_node = node.args[0]
input_shape = input_node.meta["val"].shape
output_shape = node.meta["val"].shape
Expand Down
Loading
Loading