
Commit d263897

Arm backend: Fix two control flow bugs (#16274)
Arm backend: Regularize submodules before processing them.

Add a function _regularize_submodule that contains the special handling needed for submodules. This mainly solves two problems:

- Buffers in submodules are (currently) handled differently depending on whether they come from tracing or are added in passes. The old solution tried to avoid adding a new meta field, but was brittle. The new solution simply marks all placeholders before passes.
- The pass pipeline assumes the dim_order of inputs and outputs matches the actual dim_order (tosa_dim_order). We need to ensure this.

------------------------

Arm backend: Fix while quantization

The output of the while loop body can either re-enter the body or exit the while loop. Therefore, A and B in the diagram below need to share the same quantization parameters.

A -> while ( RESCALE -> ... -> RESCALE -> ) -> B

Earlier tests happened to get equal qparams on the input and output, but this is not the general case.

Signed-off-by: Erik Lundell <erik.lundell@arm.com>
1 parent 84b4a33 commit d263897
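To illustrate the dataflow behind the while-quantization fix: the tensor the body produces is the same value that either re-enters the body as its next input (A) or leaves the loop (B), so both sides must see identical quantization parameters. A minimal sketch with made-up shapes and threshold, assuming torch.ops.higher_order.while_loop can be called eagerly on this PyTorch build:

import torch


def cond_fn(x: torch.Tensor) -> torch.Tensor:
    return x.sum() > 4.0            # keep iterating while the sum exceeds 4


def body_fn(x: torch.Tensor):
    return (x * 0.5,)               # this output is fed back as `x` on the next iteration


x0 = torch.ones(2, 4)               # sum = 8
(out,) = torch.ops.higher_order.while_loop(cond_fn, body_fn, (x0,), ())
print(out.sum())                    # tensor(4.)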

File tree (5 files changed, +96 −56 lines):

- backends/arm/process_node.py
- backends/arm/quantizer/quantization_annotator.py
- backends/arm/test/ops/test_cond.py
- backends/arm/test/ops/test_while.py
- backends/arm/tosa/backend.py

backends/arm/process_node.py

Lines changed: 3 additions & 47 deletions
@@ -15,7 +15,6 @@
 from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
 from executorch.backends.arm.tosa.specification import TosaSpecification
 from executorch.backends.arm.tosa.utils import tosa_shape
-from executorch.exir.graph_module import get_cond_while_submodules
 from torch._export.utils import (
     get_buffer,
     get_lifted_tensor_constant,
@@ -183,9 +182,10 @@ def process_inputs_to_lifted_tensor_constants(
         ) from e
     tensor = get_lifted_tensor_constant(edge_program, node)
     tensor_data = tensor.detach().numpy()  # type: ignore[union-attr]
+    tensor_values = np.transpose(tensor_data, tosa_arg.dim_order)

     tosa_graph.addConst(
-        tensor_data.shape, tosa_arg.dtype, tensor_data, name=tosa_arg.name
+        tensor_values.shape, tosa_arg.dtype, tensor_values, name=tosa_arg.name
     )

@@ -195,46 +195,7 @@ def _is_submodule_input(
     """Determines whether 'node' is an input to a submodule of 'containing_graph_module'."""
     if node.op != "placeholder":
         return False
-
-    for _, _, submodule_node in get_cond_while_submodules(containing_graph_module):
-        args = cast(list[torch.fx.Node], submodule_node.args[-1])
-        for arg in args:
-            if isinstance(arg.target, str):
-                # If argument is a buffer or similar, we can match exactly.
-                if arg.target == node.name:
-                    return True
-            # If argument target has a name, the submodule input is operator name + number to avoid duplication.
-            # For example: cond input namespace::my_op -> submodule input my_op_1
-            if (name_fn := (getattr(arg.target, "name", None))) is not None:
-                op_name = name_fn().split(":")[-1]
-                if op_name in node.name:
-                    return True
-    return False
-
-
-def _submodule_has_user_input(
-    containing_graph_module: torch.fx.GraphModule, edge_program: ExportedProgram
-):
-    # If argument is a user input, there is no such guarantee. We need to to a heuristic match.
-    for _, _, control_flow_node in get_cond_while_submodules(containing_graph_module):
-        match control_flow_node.target:
-            case torch.ops.higher_order.cond:
-                args = control_flow_node.args[-1]
-            case torch.ops.higher_order.while_loop:
-                args = cast(list, control_flow_node.args[-2]) + cast(
-                    list, control_flow_node.args[-1]
-                )
-            case _:
-                raise RuntimeError(
-                    f"Unexpected control flow target: {control_flow_node.target}"
-                )
-        args = cast(list[torch.fx.Node], args)
-        for arg in args:
-            if (
-                isinstance(arg.target, str)
-                and arg.target in edge_program.graph_signature.user_inputs
-            ):
-                return True
+    return node.meta.get("is_input", False)


 def process_placeholder(
@@ -268,11 +229,6 @@ def process_placeholder(
         raise NotImplementedError(
            "Placeholder is of type 'lifted custom object' which is not supported."
         )
-    elif containing_graph_module and _submodule_has_user_input(
-        containing_graph_module, edge_program
-    ):
-        # If we are in a submodule and it has user input, process as regular input.
-        process_inputs(node, tosa_graph, tosa_spec)
     else:
         raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")
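The one-line check above depends on an "is_input" meta field that _regularize_submodule (added in backends/arm/tosa/backend.py below) sets on every placeholder that exists before the pass pipeline runs. A minimal torch.fx sketch of that mechanism; the Body module and names are illustrative:

import torch
from torch.fx import symbolic_trace


class Body(torch.nn.Module):
    def forward(self, x):
        return x + 1


gm = symbolic_trace(Body())

# Mark every placeholder that exists before passes run as a true submodule input.
for node in gm.graph.nodes:
    if node.op == "placeholder":
        node.meta["is_input"] = True


# Equivalent to the new _is_submodule_input: buffers added by later passes show up
# as placeholders without the marker and fall through to False.
def is_submodule_input(node: torch.fx.Node) -> bool:
    if node.op != "placeholder":
        return False
    return node.meta.get("is_input", False)


print([is_submodule_input(n) for n in gm.graph.nodes])  # [True, False, False]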

backends/arm/quantizer/quantization_annotator.py

Lines changed: 9 additions & 1 deletion
@@ -712,6 +712,7 @@ def any_or_hardtanh_min_zero(n: Node):
     ):
         submodule_args_pos = -1 if node.target == torch.ops.higher_order.cond else -2
         submodule_args = node.args[submodule_args_pos]
+        output_qspec = output_act_qspec
         if len(submodule_args) > 0:  # type: ignore[arg-type]
             # The way the TOSA backend handles quantized inputs, arrays of input tensors (such as the input to a
             # conditional graph) need shared quantization.
@@ -727,7 +728,14 @@ def any_or_hardtanh_min_zero(n: Node):
                     ],
                 )
             ]
-            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
+            if node.target == torch.ops.higher_order.while_loop:
+                # The output of the while loop body can either re-enter the body, or exit the while loop.
+                # Therefore, A and B in the diagram below need to share the same quantization parameters.
+                # A -> while ( RESCALE -> ... RESCALE -> ) -> B
+                output_qspec = shared_qspec
+
+        quant_properties.quant_output = _QuantProperty(0, output_qspec)
+
     else:
         return None
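A toy numeric illustration of why the body output reuses shared_qspec rather than output_act_qspec: the loop backedge feeds the quantized tensor straight back into the body with no RESCALE available to convert between parameter sets, so mismatched scales corrupt the carried value. The scales below are made up:

import torch

x = torch.tensor([2.0])
scale_in, scale_out = 0.1, 0.25       # hypothetical, mismatched input/output scales

q = torch.round(x / scale_in)         # body input quantized with scale_in -> 20
wrong = q * scale_out                 # backedge read with scale_out -> 5.0, value corrupted
right = q * scale_in                  # shared parameters recover the original 2.0
print(wrong, right)                   # tensor([5.]) tensor([2.])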
backends/arm/test/ops/test_cond.py

Lines changed: 4 additions & 3 deletions
@@ -47,7 +47,7 @@ def false_branch(arg: torch.Tensor) -> torch.Tensor:
 class CondOneArgBufferOneOutput(torch.nn.Module):
     def __init__(self, *args: common.Any, **kwargs: common.Any) -> None:
         super().__init__(*args, **kwargs)
-        self.buffer = torch.rand(2, 3)
+        self.buffer = torch.rand(1, 1, 2, 2)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         def true_branch(arg: torch.Tensor, buffer: torch.Tensor) -> torch.Tensor:
@@ -159,7 +159,7 @@ def _single_input_case(
     module_factory: Callable[[], torch.nn.Module]
 ) -> Callable[[], tuple[torch.nn.Module, input_t1]]:
     def _create() -> tuple[torch.nn.Module, input_t1]:
-        return module_factory(), (torch.randn(2, 3),)
+        return module_factory(), (torch.randn(1, 1, 2, 2),)

     return _create
@@ -168,7 +168,7 @@ def _dual_input_case(
     module_factory: Callable[[], torch.nn.Module]
 ) -> Callable[[], tuple[torch.nn.Module, input_t2]]:
     def _create() -> tuple[torch.nn.Module, input_t2]:
-        return module_factory(), (torch.randn(2, 3), torch.randn(2, 3))
+        return module_factory(), (torch.randn(2, 3, 4, 6), torch.randn(2, 3, 4, 6))

     return _create
@@ -223,6 +223,7 @@ def test_cond_tosa_FP(case: Callable[[], tuple[torch.nn.Module, tuple]]):
     pipeline = TosaPipelineFP[tuple](
         module, example_inputs, aten_op, tosa_extensions=["cf"]
     )
+
     # Make sure no cond ops are left after partitioning.
     pipeline.add_stage_after(
         "to_edge_transform_and_lower",

backends/arm/test/ops/test_while.py

Lines changed: 23 additions & 1 deletion
@@ -71,6 +71,27 @@ def body_fn(
         return result  # type: ignore


+class DecreasingOutput(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, value: torch.Tensor) -> torch.Tensor:
+        def cond_fn(value: torch.Tensor) -> torch.Tensor:
+            total = value.sum()
+            return torch.gt(total, torch.full((1,), 60.0)).squeeze()
+
+        def body_fn(value: torch.Tensor) -> Tuple[torch.Tensor]:
+            return (torch.div(value, torch.full((1,), 2.0)),)
+
+        result = torch.ops.higher_order.while_loop(
+            cond_fn,
+            body_fn,
+            (value,),
+            (),
+        )
+        return result[0]  # type: ignore
+
+
 class WhileAdditionalArg(torch.nn.Module):
     def __init__(self) -> None:
         super().__init__()
@@ -121,7 +142,7 @@ def _single_input_case(
     module_factory: Callable[[], torch.nn.Module],
 ) -> Callable[[], Tuple[torch.nn.Module, input_single]]:
     def _create() -> Tuple[torch.nn.Module, input_single]:
-        return module_factory(), (torch.ones(2, 3),)
+        return module_factory(), (torch.ones(2, 3, 4, 6),)

     return _create
@@ -138,6 +159,7 @@ def _create() -> Tuple[torch.nn.Module, input_double]:
 test_cases: dict[str, Callable[[], Tuple[torch.nn.Module, Tuple]]] = {
     "two_in_two_out": _dual_input_case(WhileTwoInputsTwoOutputs),
     "one_in_one_buffer_two_out": _single_input_case(WhileOneInputOneBufferTwoOutputs),
+    "decreasing_output": _single_input_case(DecreasingOutput),
     "additional_arg": _single_input_case(WhileAdditionalArg),
     "two_in_one_captured_out": _single_input_case(WhileSingleCapturedOutput),
 }
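For reference, an eager sanity check of the new DecreasingOutput case (assuming the class from the hunk above is in scope and that higher-order while_loop runs eagerly in this PyTorch build):

import torch

module = DecreasingOutput()
x = torch.ones(2, 3, 4, 6) * 10.0     # sum = 1440
out = module(x)
# The body halves the values each iteration and the loop exits once the sum is <= 60:
# 1440 -> 720 -> 360 -> 180 -> 90 -> 45, i.e. five halvings.
print(out.sum())                       # tensor(45.)
torch.testing.assert_close(out, x / 2**5)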

backends/arm/tosa/backend.py

Lines changed: 57 additions & 4 deletions
@@ -20,6 +20,8 @@
 from itertools import count
 from typing import cast, Dict, final, List

+import torch
+
 import tosa_serializer as ts
 from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
 from executorch.backends.arm.common.debug import debug_fail, debug_tosa_dump
@@ -33,6 +35,7 @@
 from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.dim_order_utils import get_memory_format
 from executorch.exir.graph_module import get_cond_while_submodules
 from torch.export.exported_program import ExportedProgram
 from torch.fx import Graph, GraphModule, Node
@@ -112,6 +115,15 @@ def _sort_key(t: Node) -> int:
     return graph_module


+def _get_matching_fake_tensor(node: Node):
+    """Return a fake tensor with the same properties as node,
+    but with .dim_order() == node.meta["tosa_dim_order"]
+    """
+    fake_tensor = node.meta["val"]
+    desired_dim_order = node.meta["tosa_dim_order"]
+    return fake_tensor.to(memory_format=get_memory_format(list(desired_dim_order)))
+
+
 def arm_get_first_delegation_tag(graph_module) -> str:
     """Return the first delegation tag discovered in the FX graph.
@@ -253,6 +265,47 @@ def _preprocess(  # noqa: C901

         return PreprocessResult(processed_bytes=binary)

+    @staticmethod
+    def _regularize_submodule(submodule: GraphModule, submodule_node: Node):
+        """To make a submodule fit into the normal flow of a graph_module, we need to do some regularizations.
+
+        - Buffers created before passes are treated as input to the submodule. Buffers created during passes
+          are treated as "normal" buffers, i.e. gathered from the state_dict.
+          To make it easy to tell them apart, mark all placeholders with "is_input = True" before running passes.
+        - Make sure output node args[0] is always iterable.
+        - Match the dim_order() of the input tensors with the dim orders of the submodule_node inputs.
+        - Match the dim_order() of the out tensors with the dim orders of the submodule_node outputs.
+        """
+        submodule_inputs: list[Node] = []
+        for node in submodule.graph.nodes:
+            if node.op == "placeholder":
+                node.meta["is_input"] = True
+                submodule_inputs.append(node)
+        match submodule_node.target:
+            case torch.ops.higher_order.cond:
+                args = cast(list[Node], submodule_node.args[-1])
+            case torch.ops.higher_order.while_loop:
+                args = cast(list[Node], submodule_node.args[-2]) + cast(
+                    list, submodule_node.args[-1]
+                )
+            case _:
+                raise RuntimeError(
+                    f"Unexpected control flow target: {submodule_node.target}"
+                )
+
+        for submodule_input, submodule_arg in zip(submodule_inputs, args, strict=True):
+            submodule_input.meta["val"] = _get_matching_fake_tensor(submodule_arg)
+
+        output_node = submodule.graph.output_node()
+        if isinstance(output_node.args[0], Node):
+            output_node.update_arg(0, [output_node.args[0]])
+        output_args = cast(list[Node], output_node.args[0])
+
+        # Not all outputs might be used, causing len(users) < len(outputs)
+        # Therefore, strict != True in the zip
+        for submodule_output, submodule_user in zip(output_args, submodule_node.users):
+            submodule_output.meta["val"] = _get_matching_fake_tensor(submodule_user)
+
     @staticmethod
     def _preprocess_module(  # noqa: C901
         graph_module: GraphModule,
@@ -278,9 +331,6 @@ def _preprocess_module(  # noqa: C901

         """
         tosa_spec = compile_spec.tosa_spec
-        output_node = graph_module.graph.output_node()
-        if isinstance(output_node.args[0], Node):
-            output_node.update_arg(0, [output_node.args[0]])
         node_to_id_map = _annotate_external_ids(graph_module.graph)
         artifact_path = compile_spec.get_intermediate_path()
         output_order_workaround = compile_spec.get_output_order_workaround()
@@ -351,7 +401,10 @@ def _preprocess_module(  # noqa: C901
             raise

         # Recursively preprocess controlflow submodules.
-        for name, submodule, _ in get_cond_while_submodules(graph_module):
+        for name, submodule, control_flow_node in get_cond_while_submodules(
+            graph_module
+        ):
+            TOSABackend._regularize_submodule(submodule, control_flow_node)
             TOSABackend._preprocess_module(
                 submodule,
                 edge_program,
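A short illustration of the dim_order/memory_format round trip that _get_matching_fake_tensor performs on node.meta["val"]; plain PyTorch, and Tensor.dim_order() assumes a recent release:

import torch

t = torch.empty(1, 3, 8, 8)                        # contiguous NCHW
print(t.dim_order())                               # (0, 1, 2, 3)

t_nhwc = t.to(memory_format=torch.channels_last)   # NHWC in memory
print(t_nhwc.dim_order())                          # (0, 2, 3, 1)

# _regularize_submodule applies the same conversion to the fake tensors of submodule
# placeholders and outputs so that their dim_order() matches the "tosa_dim_order"
# metadata assumed by the pass pipeline.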
