Skip to content

Commit c493e2d

Browse files
jgibson2Copilot
andauthored
Add ability to specify CoreML pipeline passes (#16118)
### Summary Adds the ability to specify a set of CoreML passes as a CompileSpec, allowing additional customization of the model compilation. ### Test plan Converted a model and made sure it worked with a custom pipeline. Also ensured via print statements that the passes were translated correctly. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 0686e6a commit c493e2d

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

backends/apple/coreml/compiler/coreml_preprocess.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class COMPILE_SPEC_KEYS(Enum):
4343
MODEL_COMPUTE_PRECISION = "model_compute_precision"
4444
OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"
4545
ENUMERATED_SHAPES = "enumerated_shapes"
46+
PASS_PIPELINE = "pass_pipeline"
4647

4748

4849
class MODEL_PATHS(Enum):
@@ -220,6 +221,33 @@ def op_linear_quantizer_config_from_compile_specs(
220221

221222
return None
222223

224+
@staticmethod
225+
def generate_pass_pipeline_compile_spec(pass_names: List[str]) -> CompileSpec:
226+
"""
227+
Creates a compile spec representing the pass pipeline to be used by the CoreML backend
228+
:param pass_names: the list of pass names
229+
"""
230+
str_representation = json.dumps(pass_names)
231+
byte_representation = str_representation.encode("utf-8")
232+
return CompileSpec(COMPILE_SPEC_KEYS.PASS_PIPELINE.value, byte_representation)
233+
234+
@staticmethod
235+
def pass_pipeline_from_compile_specs(
236+
compile_specs: List[CompileSpec],
237+
) -> ct.PassPipeline:
238+
"""
239+
Creates a PassPipeline from the list of compile specs, or returns the default if none are provided.
240+
"""
241+
for compile_spec in compile_specs:
242+
if compile_spec.key == COMPILE_SPEC_KEYS.PASS_PIPELINE.value:
243+
pass_names_str = compile_spec.value.decode("utf-8")
244+
pass_names = json.loads(pass_names_str)
245+
return ct.PassPipeline(
246+
pass_names, pipeline_name="executorch_user_pipeline"
247+
)
248+
249+
return ct.PassPipeline.DEFAULT
250+
223251
@staticmethod
224252
def generate_enumerated_shapes_compile_spec(
225253
ep: ExportedProgram,
@@ -275,6 +303,7 @@ def generate_compile_specs(
275303
compute_precision: ct.precision = ct.precision.FLOAT16,
276304
model_type: MODEL_TYPE = MODEL_TYPE.MODEL,
277305
op_linear_quantizer_config: Optional[Dict] = None,
306+
pass_names: Optional[List[str]] = None,
278307
) -> List[CompileSpec]:
279308
"""
280309
Returns the list of compile specs that's used by CoreMLBackend to lower the module.
@@ -298,6 +327,10 @@ def generate_compile_specs(
298327
op_linear_quantizer_config
299328
)
300329
)
330+
if pass_names is not None:
331+
compile_specs.append(
332+
CoreMLBackend.generate_pass_pipeline_compile_spec(pass_names)
333+
)
301334

302335
return compile_specs
303336

@@ -503,6 +536,9 @@ def preprocess(
503536
enumerated_shapes = CoreMLBackend.enumerated_shapes_from_compile_specs(
504537
compile_specs
505538
)
539+
pass_pipeline: ct.PassPipeline = CoreMLBackend.pass_pipeline_from_compile_specs(
540+
compile_specs
541+
)
506542

507543
# If using enumerated shapes, we need to pass the inputs to CoreML's convert() function
508544
# explicitly
@@ -530,7 +566,7 @@ def preprocess(
530566
model=edge_program,
531567
source="pytorch",
532568
convert_to="mlprogram",
533-
pass_pipeline=ct.PassPipeline.DEFAULT,
569+
pass_pipeline=pass_pipeline,
534570
skip_model_load=skip_model_load,
535571
compute_precision=model_compute_precision,
536572
minimum_deployment_target=minimum_deployment_target,

0 commit comments

Comments
 (0)