Skip to content

Commit 84b4a33

Browse files
Arm backend: Add support for masked_fill_.Scalar (#16272)
### Summary Add support for inplace masked_fill in DecomposeMaskedFillPass. As a consequence, we can now partition the full Conformer model for TOSA INT and U85. Signed-off-by: Oscar Andersson <oscar.andersson@arm.com>
1 parent a3a0190 commit 84b4a33

File tree

3 files changed

+13
-41
lines changed

3 files changed

+13
-41
lines changed

backends/arm/_passes/decompose_masked_fill_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
edge_ops = (exir_ops.edge.aten.masked_fill.Scalar,)
20-
aten_ops = (torch.ops.aten.masked_fill.Scalar,)
20+
aten_ops = (torch.ops.aten.masked_fill.Scalar, torch.ops.aten.masked_fill_.Scalar)
2121

2222

2323
def _get_decomposition(op) -> tuple:
@@ -26,7 +26,7 @@ def _get_decomposition(op) -> tuple:
2626
exir_ops.edge.aten.where.self,
2727
exir_ops.edge.aten.full_like.default,
2828
)
29-
if op in aten_ops:
29+
elif op in aten_ops:
3030
return (
3131
torch.ops.aten.where.self,
3232
torch.ops.aten.full_like.default,
@@ -44,7 +44,7 @@ class DecomposeMaskedFillPass(ArmPass):
4444
_passes_required_after: Set[Type[ExportPass]] = {ConvertFullLikeToFullPass}
4545

4646
def call_operator(self, op, args, kwargs, meta, updated=False):
47-
if op not in (edge_ops + aten_ops):
47+
if op not in (*aten_ops, *edge_ops):
4848
return super().call_operator(op, args, kwargs, meta, updated)
4949

5050
x, mask, scalar = args

backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ class TestCLIPTextModelWithProjection:
4545
"executorch_exir_dialects_edge__ops_aten_index_select_default": 1,
4646
"executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor": 1,
4747
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
48-
"executorch_exir_dialects_edge__ops_aten_where_self": 1,
4948
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
50-
"torch.ops.aten.scalar_tensor.default": 1,
5149
"torch.ops.higher_order.executorch_call_delegate": 2,
5250
}
5351

backends/arm/test/models/test_conformer.py

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,8 @@ def test_conformer_tosa_INT():
7272
aten_op=[],
7373
exir_op=[],
7474
use_to_edge_transform_and_lower=True,
75-
)
76-
pipeline.pop_stage("check_count.exir")
77-
pipeline.change_args(
78-
"run_method_and_compare_outputs",
79-
get_test_inputs(
80-
TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
81-
),
82-
rtol=TestConformer.rtol,
8375
atol=TestConformer.atol,
76+
rtol=TestConformer.rtol,
8477
)
8578
pipeline.run()
8679

@@ -93,38 +86,26 @@ def test_conformer_u55_INT():
9386
pipeline = EthosU55PipelineINT[input_t](
9487
TestConformer.conformer,
9588
TestConformer.model_example_inputs,
96-
aten_ops=TestConformer.aten_ops,
89+
aten_ops=[],
9790
exir_ops=[],
9891
use_to_edge_transform_and_lower=True,
92+
atol=TestConformer.atol,
93+
rtol=TestConformer.rtol,
9994
)
100-
pipeline.change_args(
101-
"run_method_and_compare_outputs",
102-
get_test_inputs(
103-
TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
104-
),
105-
rtol=1.0,
106-
atol=5.0,
107-
)
95+
pipeline.pop_stage("check_count.exir")
10896
pipeline.run()
10997

11098

11199
@common.XfailIfNoCorstone320
112-
@pytest.mark.xfail(reason="All IO needs to have the same data type (MLETORCH-635)")
113100
def test_conformer_u85_INT():
114101
pipeline = EthosU85PipelineINT[input_t](
115102
TestConformer.conformer,
116103
TestConformer.model_example_inputs,
117-
aten_ops=TestConformer.aten_ops,
104+
aten_ops=[],
118105
exir_ops=[],
119106
use_to_edge_transform_and_lower=True,
120-
)
121-
pipeline.change_args(
122-
"run_method_and_compare_outputs",
123-
get_test_inputs(
124-
TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
125-
),
126-
rtol=1.0,
127-
atol=5.0,
107+
atol=TestConformer.atol,
108+
rtol=TestConformer.rtol,
128109
)
129110
pipeline.run()
130111

@@ -137,16 +118,9 @@ def test_conformer_vgf_quant():
137118
aten_op=[],
138119
exir_op=[],
139120
use_to_edge_transform_and_lower=True,
140-
quantize=True,
141-
)
142-
pipeline.pop_stage("check_count.exir")
143-
pipeline.change_args(
144-
"run_method_and_compare_outputs",
145-
get_test_inputs(
146-
TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
147-
),
148-
rtol=TestConformer.rtol,
149121
atol=TestConformer.atol,
122+
rtol=TestConformer.rtol,
123+
quantize=True,
150124
)
151125
pipeline.run()
152126

0 commit comments

Comments
 (0)