1010# JIT compiler flows.
1111#
1212
13+ import json
1314from abc import ABC , abstractmethod
1415from dataclasses import dataclass , field
1516from enum import Enum
1617
18+ from executorch .backends .arm .common .pipeline_config import ArmPassPipelineConfig
1719from executorch .backends .arm .tosa import TosaSpecification
1820
1921from executorch .exir .backend .compile_spec_schema import CompileSpec
@@ -36,6 +38,7 @@ class DebugMode(Enum):
3638 _DEBUG_ARTIFACT_KEY = "debug_artifact_path"
3739 _DEBUG_MODE_KEY = "dump_debug_info"
3840 _OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
41+ _TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config"
3942
4043 def _set_compile_specs (
4144 self ,
@@ -44,13 +47,15 @@ def _set_compile_specs(
4447 path_for_intermediates : str | None = None ,
4548 tosa_debug_mode : DebugMode | None = None ,
4649 output_order_workaround : bool = True ,
50+ pipeline_config : ArmPassPipelineConfig | None = None ,
4751 ):
4852 """Set all values of dataclass directly."""
4953 self .tosa_spec = tosa_spec
5054 self .compiler_flags = compiler_flags
5155 self .path_for_intermediates = path_for_intermediates
5256 self .tosa_debug_mode = tosa_debug_mode
5357 self .output_order_workaround = output_order_workaround
58+ self ._pipeline_config = pipeline_config
5459
5560 @classmethod
5661 def from_list (cls , compile_specs : list [CompileSpec ]): # noqa: C901
@@ -60,6 +65,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
6065 path_for_intermediates : str | None = None
6166 tosa_debug_mode : ArmCompileSpec .DebugMode | None = None
6267 output_order_workaround : bool = True
68+ pipeline_config : ArmPassPipelineConfig | None = None
6369 unknown_specs : dict [str , str ] = {}
6470 for spec in compile_specs :
6571 key = spec .key
@@ -98,6 +104,12 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
98104 tosa_debug_mode = ArmCompileSpec .DebugMode [val ]
99105 elif key == ArmCompileSpec ._OUTPUT_REORDER_KEY :
100106 output_order_workaround = val # type: ignore[assignment]
107+ elif key == ArmCompileSpec ._TRANSFORM_PIPELINE_CONFIG_KEY :
108+ if pipeline_config is not None :
109+ raise ValueError (
110+ "More than one transform pipeline entry in compile spec."
111+ )
112+ pipeline_config = ArmPassPipelineConfig .from_dict (json .loads (val ))
101113 else :
102114 unknown_specs [key ] = val
103115
@@ -120,6 +132,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
120132 path_for_intermediates = path_for_intermediates ,
121133 tosa_debug_mode = tosa_debug_mode ,
122134 output_order_workaround = output_order_workaround ,
135+ pipeline_config = pipeline_config ,
123136 )
124137 cls .from_list_hook (compile_spec , unknown_specs )
125138 compile_spec .validate ()
@@ -189,8 +202,33 @@ def to_list(self):
189202 )
190203 )
191204
205+ if self ._pipeline_config is not None and not self ._pipeline_config .is_default ():
206+ compile_spec .append (
207+ CompileSpec (
208+ ArmCompileSpec ._TRANSFORM_PIPELINE_CONFIG_KEY ,
209+ self ._pipeline_config .serialize (),
210+ )
211+ )
192212 return compile_spec
193213
214+ def get_pass_pipeline_config (self ) -> ArmPassPipelineConfig :
215+ """
216+ Returns configuration that controls how the Arm pass pipeline should behave.
217+ Subclasses may override to tweak defaults for specific targets.
218+ """
219+ if self ._pipeline_config is None :
220+ self ._pipeline_config = self ._create_default_pipeline_config ()
221+ return self ._pipeline_config
222+
223+ def set_pass_pipeline_config (self , config : ArmPassPipelineConfig ) -> None :
224+ self ._pipeline_config = config
225+
226+ def _create_default_pipeline_config (self ) -> ArmPassPipelineConfig :
227+ config = ArmPassPipelineConfig ()
228+ if self .tosa_spec .is_U55_subset :
229+ config .disable_masked_softmax ()
230+ return config
231+
194232 def get_intermediate_path (self ) -> str | None :
195233 """
196234 Gets the path used for dumping intermediate results such as tosa and pte.
0 commit comments