Skip to content

Commit 7c27a3e

Browse files
fix linear bias decomposition invokation
1 parent e5122d0 commit 7c27a3e

2 files changed

Lines changed: 6 additions & 6 deletions

File tree

examples/models/parakeet/export_parakeet_tdt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -513,9 +513,9 @@ def _create_metal_partitioners(programs):
513513
# print(f"Running decompositions for {key}")
514514
# print(ep.graph_module)
515515
if key != "preprocessor":
516-
updated_programs[key] = ep.run_decompositions(
517-
{torch.ops.aten.linear.default: _linear_bias_decomposition}
518-
)
516+
decomp_table = torch.export.default_decompositions()
517+
decomp_table[torch.ops.aten.linear.default] = _linear_bias_decomposition
518+
updated_programs[key] = ep.run_decompositions(decomp_table)
519519
else:
520520
updated_programs[key] = ep
521521

examples/models/voxtral_realtime/export_voxtral_rt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -395,9 +395,9 @@ def lower_to_executorch(programs, metadata, backend="xnnpack"):
395395
# Run decompositions for Metal backend
396396
updated_programs = {}
397397
for key, ep in programs.items():
398-
updated_programs[key] = ep.run_decompositions(
399-
{torch.ops.aten.linear.default: _linear_bias_decomposition}
400-
)
398+
decomp_table = torch.export.default_decompositions()
399+
decomp_table[torch.ops.aten.linear.default] = _linear_bias_decomposition
400+
updated_programs[key] = ep.run_decompositions(decomp_table)
401401
programs = updated_programs
402402

403403
partitioner = {}

0 commit comments

Comments
 (0)