From a54dbb6994e4f70dbb1c678b97777062f107f914 Mon Sep 17 00:00:00 2001 From: liujij Date: Mon, 16 Mar 2026 01:51:03 -0500 Subject: [PATCH 01/19] [feat] add vitis_generate_model_sd.py. --- .../onnx/vitis_ai/vitis_generate_model_sd.py | 169 ++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py new file mode 100644 index 0000000000..c47ab7ec0a --- /dev/null +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -0,0 +1,169 @@ +# +# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +# + +"""Olive Pass for Vitis NPU Stable Diffusion submodel generation (UNet / VAE decoder). +Accepts ONNX input only; run OnnxConversion (e.g. from PyTorchModel + olive user_script) first, +then this pass runs generate_sd_model for preprocess + partition. +""" + +from __future__ import annotations + +import logging +import shutil +from pathlib import Path +from typing import Optional + +from olive.model import ONNXModelHandler +from olive.passes import Pass +from olive.passes.pass_config import BasePassConfig, PassConfigParam + +logger = logging.getLogger(__name__) + + +def _get_sd_registry(): + """Import registry from npu_model_gen to keep model_type choices in sync.""" + from model_generate import _SD_CONFIG_REGISTRY + return _SD_CONFIG_REGISTRY + + +def _build_fixed_shapes(dim_param: Optional[list], dim_value: Optional[list]) -> Optional[list[str]]: + """Build --fixed-shapes style list (e.g. ['batch=1', 'height=64']) from dim_param and dim_value.""" + if not dim_param or not dim_value: + return None + if len(dim_param) != len(dim_value): + raise ValueError("dim_param and dim_value must have the same length.") + return [f"{p}={v}" for p, v in zip(dim_param, dim_value)] + + +class VitisGenerateModelSD(Pass): + """Generate Vitis NPU-ready SD submodel (unet or vae_decoder) from ONNX input. + Use OnnxConversion (PyTorchModel + olive user_script) upstream to produce ONNX. + Optional dim_param / dim_value override the default fixed shapes used in preprocess (like DynamicToFixedShape). + """ + + @classmethod + def _default_config(cls, accelerator_spec): + registry = _get_sd_registry() + return { + "model_type": PassConfigParam( + type_=str, + required=True, + description="SD submodel type, must be a key from SD config registry (e.g. sd_unet, sd_vae_decoder, sd_vae_encoder).", + ), + "fixed_shapes_dim_param": PassConfigParam( + type_=list, + default_value=None, + required=False, + description=( + "Symbolic dimension names for fixed shapes (e.g. ['batch','channels','height','width']). " + ), + ), + "fixed_shapes_dim_value": PassConfigParam( + type_=list, + default_value=None, + required=False, + description=( + "Defines the values for dimensions listed in fixed_shapes_dim_param (e.g., [1, 4, 64, 64]). " + "Use 'x' to preserve a dynamic dimension (e.g., [1, 4, 'x', 'x']). " + "The length must match fixed_shapes_dim_param if specified." + ), + ), + } + + @staticmethod + def _validate_model_type(model_type: str) -> None: + registry = _get_sd_registry() + if model_type not in registry: + raise ValueError( + f"model_type must be one of {list(registry.keys())}, got {model_type!r}" + ) + + def _run_for_config( + self, + model: ONNXModelHandler, + config: BasePassConfig, + output_model_path: str, + ) -> ONNXModelHandler: + if not isinstance(model, ONNXModelHandler): + raise TypeError( + "VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). " + f"Got {type(model).__name__}" + ) + model_type = config.model_type + self._validate_model_type(model_type) + + output_dir = Path(output_model_path) + if output_dir.suffix == ".onnx": + output_dir = output_dir.parent + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info( + "[VitisGenerateModelSD] output_dir=%s, model_type=%s", + output_dir, + model_type, + ) + + onnx_input_path = self._resolve_onnx_input_path(model) + logger.info("[VitisGenerateModelSD] ONNX input path: %s", onnx_input_path) + + fixed_shapes = _build_fixed_shapes( + getattr(config, "fixed_shapes_dim_param", None), getattr(config, "fixed_shapes_dim_value", None) + ) + if fixed_shapes: + logger.info( + "[VitisGenerateModelSD] Overriding fixed shapes: %s", + fixed_shapes, + ) + + from model_generate import generate_sd_model + + generate_sd_model( + input_model=str(onnx_input_path), + output_dir=str(output_dir), + model_type=model_type, + fixed_shapes=fixed_shapes, + ) + + self._ensure_model_onnx(output_dir) + + return ONNXModelHandler( + model_path=str(output_dir), + onnx_file_name="model.onnx", + ) + + def _resolve_onnx_input_path(self, model: ONNXModelHandler) -> Path: + p = Path(model.model_path) + if p.is_file(): + return p + if p.is_dir(): + name = getattr(model, "onnx_file_name", None) + if name: + f = p / name + if f.exists(): + return f + onnx_files = list(p.glob("*.onnx")) + if onnx_files: + return onnx_files[0] + raise FileNotFoundError(f"No .onnx file found under {p}") + raise FileNotFoundError(f"Model path does not exist: {p}") + + def _ensure_model_onnx(self, output_dir: Path) -> None: + """Copy actual generate_sd_model output to output_dir/model.onnx if needed.""" + model_onnx = output_dir / "model.onnx" + if model_onnx.exists(): + return + optimized = output_dir / "optimized.onnx" + dd_replaced = output_dir / "dd" / "replaced.onnx" + if dd_replaced.exists(): + shutil.copy2(dd_replaced, model_onnx) + logger.info("[VitisGenerateModelSD] Wrote model.onnx from dd/replaced.onnx") + elif optimized.exists(): + shutil.copy2(optimized, model_onnx) + logger.info("[VitisGenerateModelSD] Wrote model.onnx from optimized.onnx") + else: + logger.warning( + "[VitisGenerateModelSD] No optimized.onnx or dd/replaced.onnx found under %s", + output_dir, + ) From 0a98c0dc12daf868f4029c531d483631d8b5adaa Mon Sep 17 00:00:00 2001 From: liujij Date: Tue, 17 Mar 2026 03:08:34 -0500 Subject: [PATCH 02/19] up olive_config.json. --- olive/olive_config.json | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/olive/olive_config.json b/olive/olive_config.json index 2748c39101..dcdf4af072 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -626,6 +626,15 @@ "supported_algorithms": [ ], "supported_quantization_encodings": [ ], "run_on_target": true + }, + "VitisGenerateModelSD": { + "module_path": "olive.passes.onnx.vitis_ai.vitis_generate_model_sd.VitisGenerateModelSD", + "supported_providers": [ "CPUExecutionProvider" ], + "supported_accelerators": [ "cpu" ], + "supported_precisions": [ "bf16", "bfp16" ], + "supported_algorithms": [ ], + "supported_quantization_encodings": [ ], + "run_on_target": true } }, "extra_dependencies": { From 6f675d4f82d1796c1abf22cf6bbbeb4bc8c2fe38 Mon Sep 17 00:00:00 2001 From: liujij Date: Tue, 24 Mar 2026 04:54:05 -0500 Subject: [PATCH 03/19] update codes. --- olive/olive_config.json | 2 +- .../onnx/vitis_ai/vitis_generate_model_sd.py | 46 +++++-------------- 2 files changed, 12 insertions(+), 36 deletions(-) diff --git a/olive/olive_config.json b/olive/olive_config.json index dcdf4af072..4168e0b2d3 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -631,7 +631,7 @@ "module_path": "olive.passes.onnx.vitis_ai.vitis_generate_model_sd.VitisGenerateModelSD", "supported_providers": [ "CPUExecutionProvider" ], "supported_accelerators": [ "cpu" ], - "supported_precisions": [ "bf16", "bfp16" ], + "supported_precisions": [ "int8" ], "supported_algorithms": [ ], "supported_quantization_encodings": [ ], "run_on_target": true diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index c47ab7ec0a..e5bd88cd11 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -27,20 +27,10 @@ def _get_sd_registry(): from model_generate import _SD_CONFIG_REGISTRY return _SD_CONFIG_REGISTRY - -def _build_fixed_shapes(dim_param: Optional[list], dim_value: Optional[list]) -> Optional[list[str]]: - """Build --fixed-shapes style list (e.g. ['batch=1', 'height=64']) from dim_param and dim_value.""" - if not dim_param or not dim_value: - return None - if len(dim_param) != len(dim_value): - raise ValueError("dim_param and dim_value must have the same length.") - return [f"{p}={v}" for p, v in zip(dim_param, dim_value)] - - class VitisGenerateModelSD(Pass): """Generate Vitis NPU-ready SD submodel (unet or vae_decoder) from ONNX input. Use OnnxConversion (PyTorchModel + olive user_script) upstream to produce ONNX. - Optional dim_param / dim_value override the default fixed shapes used in preprocess (like DynamicToFixedShape). + Optional resolutions override the default fixed shapes used in preprocess. Default is [512x512]. """ @classmethod @@ -50,25 +40,13 @@ def _default_config(cls, accelerator_spec): "model_type": PassConfigParam( type_=str, required=True, - description="SD submodel type, must be a key from SD config registry (e.g. sd_unet, sd_vae_decoder, sd_vae_encoder).", + description=f"SD submodel type, must be a key from SD config registry (e.g. {list(registry.keys())}).", ), - "fixed_shapes_dim_param": PassConfigParam( - type_=list, - default_value=None, + "resolutions": PassConfigParam( + type_=list[str], + default_value=["512x512"], required=False, - description=( - "Symbolic dimension names for fixed shapes (e.g. ['batch','channels','height','width']). " - ), - ), - "fixed_shapes_dim_value": PassConfigParam( - type_=list, - default_value=None, - required=False, - description=( - "Defines the values for dimensions listed in fixed_shapes_dim_param (e.g., [1, 4, 64, 64]). " - "Use 'x' to preserve a dynamic dimension (e.g., [1, 4, 'x', 'x']). " - "The length must match fixed_shapes_dim_param if specified." - ), + description="List of resolutions (e.g. ['512x512', '1024x1024']) Default is [512x512].", ), } @@ -108,13 +86,11 @@ def _run_for_config( onnx_input_path = self._resolve_onnx_input_path(model) logger.info("[VitisGenerateModelSD] ONNX input path: %s", onnx_input_path) - fixed_shapes = _build_fixed_shapes( - getattr(config, "fixed_shapes_dim_param", None), getattr(config, "fixed_shapes_dim_value", None) - ) - if fixed_shapes: + resolutions = getattr(config, "resolutions", None) + if resolutions: logger.info( - "[VitisGenerateModelSD] Overriding fixed shapes: %s", - fixed_shapes, + "[VitisGenerateModelSD] Using resolutions: %s", + resolutions, ) from model_generate import generate_sd_model @@ -123,7 +99,7 @@ def _run_for_config( input_model=str(onnx_input_path), output_dir=str(output_dir), model_type=model_type, - fixed_shapes=fixed_shapes, + resolutions=resolutions, ) self._ensure_model_onnx(output_dir) From 342c43da2c8cc9ec2afd44c1bbf0649bdfb26dc8 Mon Sep 17 00:00:00 2001 From: liujij Date: Wed, 25 Mar 2026 04:24:05 -0500 Subject: [PATCH 04/19] up. --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index e5bd88cd11..1d17806a4b 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -24,8 +24,8 @@ def _get_sd_registry(): """Import registry from npu_model_gen to keep model_type choices in sync.""" - from model_generate import _SD_CONFIG_REGISTRY - return _SD_CONFIG_REGISTRY + import model_generate + return model_generate.SUPPORTED_SD_MODEL_TYPES class VitisGenerateModelSD(Pass): """Generate Vitis NPU-ready SD submodel (unet or vae_decoder) from ONNX input. @@ -40,7 +40,7 @@ def _default_config(cls, accelerator_spec): "model_type": PassConfigParam( type_=str, required=True, - description=f"SD submodel type, must be a key from SD config registry (e.g. {list(registry.keys())}).", + description=f"SD submodel type, must be one of {', '.join(registry)}.", ), "resolutions": PassConfigParam( type_=list[str], From f8b9e49d076487171185a5831a3d2fb2f77d8fdf Mon Sep 17 00:00:00 2001 From: liujij Date: Wed, 25 Mar 2026 04:35:44 -0500 Subject: [PATCH 05/19] lint. --- .../onnx/vitis_ai/vitis_generate_model_sd.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 1d17806a4b..22112e321a 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -4,6 +4,7 @@ # """Olive Pass for Vitis NPU Stable Diffusion submodel generation (UNet / VAE decoder). + Accepts ONNX input only; run OnnxConversion (e.g. from PyTorchModel + olive user_script) first, then this pass runs generate_sd_model for preprocess + partition. """ @@ -13,7 +14,6 @@ import logging import shutil from pathlib import Path -from typing import Optional from olive.model import ONNXModelHandler from olive.passes import Pass @@ -25,12 +25,15 @@ def _get_sd_registry(): """Import registry from npu_model_gen to keep model_type choices in sync.""" import model_generate + return model_generate.SUPPORTED_SD_MODEL_TYPES + class VitisGenerateModelSD(Pass): """Generate Vitis NPU-ready SD submodel (unet or vae_decoder) from ONNX input. - Use OnnxConversion (PyTorchModel + olive user_script) upstream to produce ONNX. - Optional resolutions override the default fixed shapes used in preprocess. Default is [512x512]. + + Use OnnxConversion to produce ONNX input model. + Optional resolutions to generate NPU-ready models. Default is [512x512]. """ @classmethod @@ -54,9 +57,7 @@ def _default_config(cls, accelerator_spec): def _validate_model_type(model_type: str) -> None: registry = _get_sd_registry() if model_type not in registry: - raise ValueError( - f"model_type must be one of {list(registry.keys())}, got {model_type!r}" - ) + raise ValueError(f"model_type must be one of {list(registry.keys())}, got {model_type!r}") def _run_for_config( self, @@ -66,8 +67,7 @@ def _run_for_config( ) -> ONNXModelHandler: if not isinstance(model, ONNXModelHandler): raise TypeError( - "VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). " - f"Got {type(model).__name__}" + f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). Got {type(model).__name__}" ) model_type = config.model_type self._validate_model_type(model_type) From caa483cae0edc6f8a4f44370779493d758179810 Mon Sep 17 00:00:00 2001 From: liujij Date: Thu, 26 Mar 2026 03:12:06 -0500 Subject: [PATCH 06/19] up. --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 22112e321a..9fe864171e 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -1,12 +1,12 @@ # -# Copyright (C) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT # """Olive Pass for Vitis NPU Stable Diffusion submodel generation (UNet / VAE decoder). -Accepts ONNX input only; run OnnxConversion (e.g. from PyTorchModel + olive user_script) first, -then this pass runs generate_sd_model for preprocess + partition. +Accepts ONNX input only; run OnnxConversion to produce ONNX input model first, +then this pass runs generate_sd_model to generate NPU-ready models. """ from __future__ import annotations @@ -57,7 +57,7 @@ def _default_config(cls, accelerator_spec): def _validate_model_type(model_type: str) -> None: registry = _get_sd_registry() if model_type not in registry: - raise ValueError(f"model_type must be one of {list(registry.keys())}, got {model_type!r}") + raise ValueError(f"model_type must be one of {', '.join(registry)}, got {model_type!r}") def _run_for_config( self, @@ -67,7 +67,8 @@ def _run_for_config( ) -> ONNXModelHandler: if not isinstance(model, ONNXModelHandler): raise TypeError( - f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). Got {type(model).__name__}" + "VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). " + f"Got {type(model).__name__}" ) model_type = config.model_type self._validate_model_type(model_type) From e3a73fb8f8cffd6f5755f254529657324bf3d306 Mon Sep 17 00:00:00 2001 From: liujij Date: Thu, 26 Mar 2026 03:17:15 -0500 Subject: [PATCH 07/19] lint. --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 9fe864171e..ca9a3216f8 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -38,12 +38,11 @@ class VitisGenerateModelSD(Pass): @classmethod def _default_config(cls, accelerator_spec): - registry = _get_sd_registry() return { "model_type": PassConfigParam( type_=str, required=True, - description=f"SD submodel type, must be one of {', '.join(registry)}.", + description=f"SD submodel type, must be one of {', '.join(_get_sd_registry())}.", ), "resolutions": PassConfigParam( type_=list[str], @@ -67,8 +66,7 @@ def _run_for_config( ) -> ONNXModelHandler: if not isinstance(model, ONNXModelHandler): raise TypeError( - "VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). " - f"Got {type(model).__name__}" + f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). Got {type(model).__name__}" ) model_type = config.model_type self._validate_model_type(model_type) From e0966e556b331f4d058282e6e9febee9e7fb9c75 Mon Sep 17 00:00:00 2001 From: liujij Date: Thu, 26 Mar 2026 22:52:52 -0500 Subject: [PATCH 08/19] ruff. --- .../onnx/vitis_ai/vitis_generate_model_sd.py | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index ca9a3216f8..ab1a7517a6 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: MIT # -"""Olive Pass for Vitis NPU Stable Diffusion submodel generation (UNet / VAE decoder). +"""Olive Pass for Vitis NPU Stable Diffusion submodel generation. Accepts ONNX input only; run OnnxConversion to produce ONNX input model first, then this pass runs generate_sd_model to generate NPU-ready models. @@ -15,6 +15,8 @@ import shutil from pathlib import Path +from model_generate import SUPPORTED_SD_MODEL_TYPES, generate_sd_model + from olive.model import ONNXModelHandler from olive.passes import Pass from olive.passes.pass_config import BasePassConfig, PassConfigParam @@ -22,15 +24,8 @@ logger = logging.getLogger(__name__) -def _get_sd_registry(): - """Import registry from npu_model_gen to keep model_type choices in sync.""" - import model_generate - - return model_generate.SUPPORTED_SD_MODEL_TYPES - - class VitisGenerateModelSD(Pass): - """Generate Vitis NPU-ready SD submodel (unet or vae_decoder) from ONNX input. + """Generate Vitis NPU-ready SD submodel from ONNX input. Use OnnxConversion to produce ONNX input model. Optional resolutions to generate NPU-ready models. Default is [512x512]. @@ -42,7 +37,7 @@ def _default_config(cls, accelerator_spec): "model_type": PassConfigParam( type_=str, required=True, - description=f"SD submodel type, must be one of {', '.join(_get_sd_registry())}.", + description=f"SD submodel type, must be one of {', '.join(SUPPORTED_SD_MODEL_TYPES)}.", ), "resolutions": PassConfigParam( type_=list[str], @@ -54,9 +49,8 @@ def _default_config(cls, accelerator_spec): @staticmethod def _validate_model_type(model_type: str) -> None: - registry = _get_sd_registry() - if model_type not in registry: - raise ValueError(f"model_type must be one of {', '.join(registry)}, got {model_type!r}") + if model_type not in SUPPORTED_SD_MODEL_TYPES: + raise ValueError(f"model_type must be one of {', '.join(SUPPORTED_SD_MODEL_TYPES)}, got {model_type!r}") def _run_for_config( self, @@ -92,8 +86,6 @@ def _run_for_config( resolutions, ) - from model_generate import generate_sd_model - generate_sd_model( input_model=str(onnx_input_path), output_dir=str(output_dir), From c168335a3e83b9a2bbe737fbda6d8ef90713d091 Mon Sep 17 00:00:00 2001 From: liujij Date: Thu, 26 Mar 2026 23:11:23 -0500 Subject: [PATCH 09/19] ruff. --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index ab1a7517a6..4b4773fb4a 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -1,7 +1,7 @@ -# +# ------------------------------------------------------------------------- # Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT -# +# ------------------------------------------------------------------------- """Olive Pass for Vitis NPU Stable Diffusion submodel generation. From 70cf7d2a168647ec199111c337622908b6946674 Mon Sep 17 00:00:00 2001 From: liujij Date: Mon, 30 Mar 2026 04:28:05 -0500 Subject: [PATCH 10/19] update olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py with new model_generate. --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 4b4773fb4a..9fc5059aec 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -15,13 +15,15 @@ import shutil from pathlib import Path -from model_generate import SUPPORTED_SD_MODEL_TYPES, generate_sd_model +from model_generate import generate_model +from model_generate.recipes import get_supported_sd_model_types from olive.model import ONNXModelHandler from olive.passes import Pass from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) +SUPPORTED_SD_MODEL_TYPES = get_supported_sd_model_types() class VitisGenerateModelSD(Pass): @@ -86,11 +88,11 @@ def _run_for_config( resolutions, ) - generate_sd_model( + generate_model( + mode="sd", input_model=str(onnx_input_path), output_dir=str(output_dir), - model_type=model_type, - resolutions=resolutions, + extra_options={"model_type": model_type, "resolutions": ",".join(resolutions)}, ) self._ensure_model_onnx(output_dir) From ee3d37c8c490e4a927638c308368604321eee629 Mon Sep 17 00:00:00 2001 From: liujij <134048367+liujij@users.noreply.github.com> Date: Thu, 7 May 2026 14:28:33 +0800 Subject: [PATCH 11/19] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 9fc5059aec..70ee060466 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -82,17 +82,19 @@ def _run_for_config( logger.info("[VitisGenerateModelSD] ONNX input path: %s", onnx_input_path) resolutions = getattr(config, "resolutions", None) + extra_options = {"model_type": model_type} if resolutions: logger.info( "[VitisGenerateModelSD] Using resolutions: %s", resolutions, ) + extra_options["resolutions"] = ",".join(resolutions) generate_model( mode="sd", input_model=str(onnx_input_path), output_dir=str(output_dir), - extra_options={"model_type": model_type, "resolutions": ",".join(resolutions)}, + extra_options=extra_options, ) self._ensure_model_onnx(output_dir) From 87036780a05a045fc37f3d22b0c6c0fce4607832 Mon Sep 17 00:00:00 2001 From: Ji Date: Thu, 7 May 2026 17:39:53 +0800 Subject: [PATCH 12/19] Resolve review feedback. --- .../onnx/vitis_ai/vitis_generate_model_sd.py | 55 ++++++++++++++----- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 70ee060466..40c9f28783 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -15,15 +15,11 @@ import shutil from pathlib import Path -from model_generate import generate_model -from model_generate.recipes import get_supported_sd_model_types - from olive.model import ONNXModelHandler from olive.passes import Pass from olive.passes.pass_config import BasePassConfig, PassConfigParam logger = logging.getLogger(__name__) -SUPPORTED_SD_MODEL_TYPES = get_supported_sd_model_types() class VitisGenerateModelSD(Pass): @@ -39,7 +35,7 @@ def _default_config(cls, accelerator_spec): "model_type": PassConfigParam( type_=str, required=True, - description=f"SD submodel type, must be one of {', '.join(SUPPORTED_SD_MODEL_TYPES)}.", + description="SD submodel type.", ), "resolutions": PassConfigParam( type_=list[str], @@ -50,9 +46,22 @@ def _default_config(cls, accelerator_spec): } @staticmethod - def _validate_model_type(model_type: str) -> None: - if model_type not in SUPPORTED_SD_MODEL_TYPES: - raise ValueError(f"model_type must be one of {', '.join(SUPPORTED_SD_MODEL_TYPES)}, got {model_type!r}") + def _get_supported_sd_model_types(): + try: + from model_generate.recipes import get_supported_sd_model_types + except ImportError as e: + raise ImportError( + "model_generate is required for VitisGenerateModelSD. Please install the model_generate package." + ) from e + + return get_supported_sd_model_types() + + @classmethod + def _validate_model_type(cls, model_type: str) -> None: + supported_types = cls._get_supported_sd_model_types() + + if model_type not in supported_types: + raise ValueError(f"model_type must be one of {', '.join(supported_types)}, got {model_type!r}") def _run_for_config( self, @@ -60,6 +69,13 @@ def _run_for_config( config: BasePassConfig, output_model_path: str, ) -> ONNXModelHandler: + try: + from model_generate import generate_model + except ImportError as e: + raise ImportError( + "model_generate is required for VitisGenerateModelSD. Please install the model_generate package." + ) from e + if not isinstance(model, ONNXModelHandler): raise TypeError( f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). Got {type(model).__name__}" @@ -114,10 +130,22 @@ def _resolve_onnx_input_path(self, model: ONNXModelHandler) -> Path: f = p / name if f.exists(): return f - onnx_files = list(p.glob("*.onnx")) - if onnx_files: + raise FileNotFoundError(f"Specified onnx_file_name does not exist under {p}: {name}") + + default_model_path = p / "model.onnx" + if default_model_path.exists(): + return default_model_path + + onnx_files = sorted(path for path in p.glob("*.onnx") if path.is_file()) + if len(onnx_files) == 1: return onnx_files[0] - raise FileNotFoundError(f"No .onnx file found under {p}") + if len(onnx_files) > 1: + candidates = ", ".join(path.name for path in onnx_files) + raise ValueError( + f"Multiple .onnx model files found under {p}: {candidates}. Please specify one using the onnx_file_name argument." + ) + else: + raise FileNotFoundError(f"No .onnx file found under {p}") raise FileNotFoundError(f"Model path does not exist: {p}") def _ensure_model_onnx(self, output_dir: Path) -> None: @@ -134,7 +162,6 @@ def _ensure_model_onnx(self, output_dir: Path) -> None: shutil.copy2(optimized, model_onnx) logger.info("[VitisGenerateModelSD] Wrote model.onnx from optimized.onnx") else: - logger.warning( - "[VitisGenerateModelSD] No optimized.onnx or dd/replaced.onnx found under %s", - output_dir, + raise FileNotFoundError( + f"[VitisGenerateModelSD] No optimized.onnx or dd/replaced.onnx found under {output_dir}. Please check the output directory.", ) From fb0fba596b9089381b56de3f27b9966f57e36347 Mon Sep 17 00:00:00 2001 From: Ji Date: Thu, 7 May 2026 18:09:45 +0800 Subject: [PATCH 13/19] add unit test for vitis_generate_model_sd.py. --- .../vitis_ai/test_vitis_generate_model_sd.py | 252 ++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 test/passes/vitis_ai/test_vitis_generate_model_sd.py diff --git a/test/passes/vitis_ai/test_vitis_generate_model_sd.py b/test/passes/vitis_ai/test_vitis_generate_model_sd.py new file mode 100644 index 0000000000..58130f9cfb --- /dev/null +++ b/test/passes/vitis_ai/test_vitis_generate_model_sd.py @@ -0,0 +1,252 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import builtins +import sys +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from olive.model import ONNXModelHandler +from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.onnx.vitis_ai.vitis_generate_model_sd import VitisGenerateModelSD +from test.utils import ONNX_MODEL_PATH, get_onnx_model + +_PATCH_GEN = "model_generate.generate_model" +_PATCH_SUPPORTED = "model_generate.recipes.get_supported_sd_model_types" + + +def _make_pass(**kwargs): + cfg = {"model_type": "unet", "resolutions": [], **kwargs} + return create_pass_from_dict(VitisGenerateModelSD, cfg, disable_search=True) + + +def _generate_writes_placeholder(**kwargs): + """Mock generate_model: leave output Olive's _ensure_model_onnx can satisfy.""" + out = Path(kwargs["output_dir"]) + out.mkdir(parents=True, exist_ok=True) + (out / "optimized.onnx").write_bytes(b"placeholder") + + +def test_get_supported_sd_model_types_wraps_import_error(): + saved_mods = {k: sys.modules.pop(k) for k in list(sys.modules) if k == "model_generate" or k.startswith("model_generate.")} + real_import = builtins.__import__ + + def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "model_generate.recipes" and fromlist: + raise ImportError("simulated missing recipes") + return real_import(name, globals, locals, fromlist, level) + + try: + with patch.object(builtins, "__import__", guarded_import): + with pytest.raises(ImportError, match="model_generate is required for VitisGenerateModelSD"): + VitisGenerateModelSD._get_supported_sd_model_types() + finally: + sys.modules.update(saved_mods) + + +@pytest.mark.parametrize( + ("model_type", "supported"), + [ + ("bad", ["unet"]), + ("", ["unet", "vae"]), + ], +) +def test_run_invalid_model_type_raises(model_type, supported, tmp_path): + gen = MagicMock() + with patch(_PATCH_SUPPORTED, return_value=supported), patch(_PATCH_GEN, gen): + p = create_pass_from_dict( + VitisGenerateModelSD, + {"model_type": model_type, "resolutions": []}, + disable_search=True, + ) + with pytest.raises(ValueError, match="model_type must be one of"): + p.run(get_onnx_model(), str(tmp_path / "out")) + + +def test_run_includes_resolutions_in_extra_options(tmp_path): + gen = MagicMock(side_effect=_generate_writes_placeholder) + with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen): + p = create_pass_from_dict( + VitisGenerateModelSD, + { + "model_type": "unet", + "resolutions": ["512x512", "768x768"], + }, + disable_search=True, + ) + p.run(get_onnx_model(), str(tmp_path / "sd_out")) + + gen.assert_called_once() + kwargs = gen.call_args.kwargs + assert kwargs["mode"] == "sd" + assert kwargs["extra_options"]["model_type"] == "unet" + assert kwargs["extra_options"]["resolutions"] == "512x512,768x768" + + +def test_run_default_resolutions_passed_when_using_defaults(tmp_path): + gen = MagicMock(side_effect=_generate_writes_placeholder) + with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen): + p = create_pass_from_dict( + VitisGenerateModelSD, + {"model_type": "unet"}, + disable_search=True, + ) + p.run(get_onnx_model(), str(tmp_path / "out")) + + assert gen.call_args.kwargs["extra_options"].get("resolutions") == "512x512" + + +def test_run_omits_resolutions_when_empty_list(tmp_path): + gen = MagicMock(side_effect=_generate_writes_placeholder) + with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen): + p = create_pass_from_dict( + VitisGenerateModelSD, + {"model_type": "unet", "resolutions": []}, + disable_search=True, + ) + p.run(get_onnx_model(), str(tmp_path / "out")) + + assert "resolutions" not in gen.call_args.kwargs["extra_options"] + + +def test_ensure_model_onnx_copies_optimized(tmp_path): + def write_optimized(**kwargs): + out = Path(kwargs["output_dir"]) + (out / "optimized.onnx").write_text("from_optimized", encoding="utf-8") + + gen = MagicMock(side_effect=write_optimized) + with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen): + p = create_pass_from_dict( + VitisGenerateModelSD, + {"model_type": "unet", "resolutions": []}, + disable_search=True, + ) + p.run(get_onnx_model(), str(tmp_path / "out")) + + assert (tmp_path / "out" / "model.onnx").read_text(encoding="utf-8") == "from_optimized" + + +def test_ensure_model_onnx_prefers_dd_replaced_over_optimized(tmp_path): + def write_both(**kwargs): + out = Path(kwargs["output_dir"]) + (out / "optimized.onnx").write_text("from_optimized", encoding="utf-8") + dd = out / "dd" + dd.mkdir(parents=True) + (dd / "replaced.onnx").write_text("from_dd", encoding="utf-8") + + gen = MagicMock(side_effect=write_both) + with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen): + p = create_pass_from_dict( + VitisGenerateModelSD, + {"model_type": "unet", "resolutions": []}, + disable_search=True, + ) + p.run(get_onnx_model(), str(tmp_path / "out")) + + assert (tmp_path / "out" / "model.onnx").read_text(encoding="utf-8") == "from_dd" + + +def test_ensure_model_onnx_skips_copy_when_model_onnx_exists(tmp_path): + def write_only_original(**kwargs): + out = Path(kwargs["output_dir"]) + (out / "model.onnx").write_text("original", encoding="utf-8") + (out / "optimized.onnx").write_text("optimized", encoding="utf-8") + + gen = MagicMock(side_effect=write_only_original) + with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen): + p = create_pass_from_dict( + VitisGenerateModelSD, + {"model_type": "unet", "resolutions": []}, + disable_search=True, + ) + p.run(get_onnx_model(), str(tmp_path / "out")) + + assert (tmp_path / "out" / "model.onnx").read_text(encoding="utf-8") == "original" + + +def test_ensure_model_onnx_raises_when_no_candidate_files(tmp_path): + gen = MagicMock() + with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen): + p = create_pass_from_dict( + VitisGenerateModelSD, + {"model_type": "unet", "resolutions": []}, + disable_search=True, + ) + with pytest.raises(FileNotFoundError, match="No optimized.onnx or dd/replaced.onnx"): + p.run(get_onnx_model(), str(tmp_path / "out")) + + +def test_resolve_onnx_input_path_single_file(): + p = _make_pass() + h = ONNXModelHandler(model_path=str(ONNX_MODEL_PATH)) + assert p._resolve_onnx_input_path(h) == Path(ONNX_MODEL_PATH) + + +def test_resolve_onnx_input_path_dir_with_model_onnx(tmp_path): + (tmp_path / "model.onnx").write_bytes(b"x") + p = _make_pass() + h = ONNXModelHandler(model_path=str(tmp_path)) + assert p._resolve_onnx_input_path(h) == tmp_path / "model.onnx" + + +def test_resolve_onnx_input_path_dir_with_onnx_file_name(tmp_path): + (tmp_path / "custom.onnx").write_bytes(b"x") + p = _make_pass() + h = ONNXModelHandler(model_path=str(tmp_path), onnx_file_name="custom.onnx") + assert p._resolve_onnx_input_path(h) == tmp_path / "custom.onnx" + + +def test_resolve_onnx_input_path_dir_onnx_file_name_missing_raises(tmp_path): + p = _make_pass() + h = SimpleNamespace(model_path=str(tmp_path), onnx_file_name="missing.onnx") + with pytest.raises(FileNotFoundError, match="Specified onnx_file_name"): + p._resolve_onnx_input_path(h) + + +def test_resolve_onnx_input_path_dir_single_unnamed_onnx(tmp_path): + (tmp_path / "only.onnx").write_bytes(b"x") + p = _make_pass() + h = ONNXModelHandler(model_path=str(tmp_path)) + assert p._resolve_onnx_input_path(h) == tmp_path / "only.onnx" + + +def test_resolve_onnx_input_path_dir_multiple_onnx_raises(tmp_path): + (tmp_path / "a.onnx").write_bytes(b"x") + (tmp_path / "b.onnx").write_bytes(b"y") + p = _make_pass() + h = SimpleNamespace(model_path=str(tmp_path)) + with pytest.raises(ValueError, match="Multiple .onnx model files found"): + p._resolve_onnx_input_path(h) + + +def test_resolve_onnx_input_path_dir_no_onnx_raises(tmp_path): + p = _make_pass() + h = SimpleNamespace(model_path=str(tmp_path)) + with pytest.raises(FileNotFoundError, match="No .onnx file found"): + p._resolve_onnx_input_path(h) + + +def test_resolve_onnx_input_path_missing_path_raises(tmp_path): + p = _make_pass() + missing = tmp_path / "nope" + h = SimpleNamespace(model_path=str(missing)) + with pytest.raises(FileNotFoundError, match="Model path does not exist"): + p._resolve_onnx_input_path(h) + + +def test_run_requires_onnx_model_handler(tmp_path): + gen = MagicMock() + with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen): + p = create_pass_from_dict( + VitisGenerateModelSD, + {"model_type": "unet", "resolutions": []}, + disable_search=True, + ) + bad = MagicMock() + with pytest.raises(TypeError, match="ONNXModelHandler"): + p.run(bad, str(tmp_path / "out")) From ee6964256ab94cb9e40d676e1d6942f2057fc731 Mon Sep 17 00:00:00 2001 From: liujij Date: Fri, 8 May 2026 22:37:45 -0500 Subject: [PATCH 14/19] lint. --- test/passes/vitis_ai/test_vitis_generate_model_sd.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/passes/vitis_ai/test_vitis_generate_model_sd.py b/test/passes/vitis_ai/test_vitis_generate_model_sd.py index 58130f9cfb..bd68511afd 100644 --- a/test/passes/vitis_ai/test_vitis_generate_model_sd.py +++ b/test/passes/vitis_ai/test_vitis_generate_model_sd.py @@ -33,7 +33,9 @@ def _generate_writes_placeholder(**kwargs): def test_get_supported_sd_model_types_wraps_import_error(): - saved_mods = {k: sys.modules.pop(k) for k in list(sys.modules) if k == "model_generate" or k.startswith("model_generate.")} + saved_mods = { + k: sys.modules.pop(k) for k in list(sys.modules) if k == "model_generate" or k.startswith("model_generate.") + } real_import = builtins.__import__ def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): From 2f3ed1fd96a670c6670edeb60fc6487af8f4586a Mon Sep 17 00:00:00 2001 From: liujij Date: Fri, 8 May 2026 23:12:58 -0500 Subject: [PATCH 15/19] fix format issue. --- .../vitis_ai/test_vitis_generate_model_sd.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/test/passes/vitis_ai/test_vitis_generate_model_sd.py b/test/passes/vitis_ai/test_vitis_generate_model_sd.py index bd68511afd..9d12175006 100644 --- a/test/passes/vitis_ai/test_vitis_generate_model_sd.py +++ b/test/passes/vitis_ai/test_vitis_generate_model_sd.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +# pylint: disable=protected-access import builtins import sys @@ -38,15 +39,17 @@ def test_get_supported_sd_model_types_wraps_import_error(): } real_import = builtins.__import__ - def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): + def guarded_import(name, glbs=None, locs=None, fromlist=(), level=0): if name == "model_generate.recipes" and fromlist: raise ImportError("simulated missing recipes") - return real_import(name, globals, locals, fromlist, level) + return real_import(name, glbs, locs, fromlist, level) try: - with patch.object(builtins, "__import__", guarded_import): - with pytest.raises(ImportError, match="model_generate is required for VitisGenerateModelSD"): - VitisGenerateModelSD._get_supported_sd_model_types() + with ( + patch.object(builtins, "__import__", guarded_import), + pytest.raises(ImportError, match="model_generate is required for VitisGenerateModelSD"), + ): + VitisGenerateModelSD._get_supported_sd_model_types() finally: sys.modules.update(saved_mods) @@ -179,7 +182,7 @@ def test_ensure_model_onnx_raises_when_no_candidate_files(tmp_path): {"model_type": "unet", "resolutions": []}, disable_search=True, ) - with pytest.raises(FileNotFoundError, match="No optimized.onnx or dd/replaced.onnx"): + with pytest.raises(FileNotFoundError, match=r"No optimized\.onnx or dd/replaced\.onnx"): p.run(get_onnx_model(), str(tmp_path / "out")) @@ -222,14 +225,14 @@ def test_resolve_onnx_input_path_dir_multiple_onnx_raises(tmp_path): (tmp_path / "b.onnx").write_bytes(b"y") p = _make_pass() h = SimpleNamespace(model_path=str(tmp_path)) - with pytest.raises(ValueError, match="Multiple .onnx model files found"): + with pytest.raises(ValueError, match=r"Multiple \.onnx model files found"): p._resolve_onnx_input_path(h) def test_resolve_onnx_input_path_dir_no_onnx_raises(tmp_path): p = _make_pass() h = SimpleNamespace(model_path=str(tmp_path)) - with pytest.raises(FileNotFoundError, match="No .onnx file found"): + with pytest.raises(FileNotFoundError, match=r"No \.onnx file found"): p._resolve_onnx_input_path(h) From 05b5c5471814f259913da084a0f1aca614b11bff Mon Sep 17 00:00:00 2001 From: liujij Date: Fri, 8 May 2026 23:28:07 -0500 Subject: [PATCH 16/19] fix issue from github-advanced-security. --- .../onnx/vitis_ai/vitis_generate_model_sd.py | 8 +++--- .../vitis_ai/test_vitis_generate_model_sd.py | 25 +++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py index 40c9f28783..2cf4e602e5 100644 --- a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -46,7 +46,7 @@ def _default_config(cls, accelerator_spec): } @staticmethod - def _get_supported_sd_model_types(): + def get_supported_sd_model_types(): try: from model_generate.recipes import get_supported_sd_model_types except ImportError as e: @@ -58,7 +58,7 @@ def _get_supported_sd_model_types(): @classmethod def _validate_model_type(cls, model_type: str) -> None: - supported_types = cls._get_supported_sd_model_types() + supported_types = cls.get_supported_sd_model_types() if model_type not in supported_types: raise ValueError(f"model_type must be one of {', '.join(supported_types)}, got {model_type!r}") @@ -94,7 +94,7 @@ def _run_for_config( model_type, ) - onnx_input_path = self._resolve_onnx_input_path(model) + onnx_input_path = self.resolve_onnx_input_path(model) logger.info("[VitisGenerateModelSD] ONNX input path: %s", onnx_input_path) resolutions = getattr(config, "resolutions", None) @@ -120,7 +120,7 @@ def _run_for_config( onnx_file_name="model.onnx", ) - def _resolve_onnx_input_path(self, model: ONNXModelHandler) -> Path: + def resolve_onnx_input_path(self, model: ONNXModelHandler) -> Path: p = Path(model.model_path) if p.is_file(): return p diff --git a/test/passes/vitis_ai/test_vitis_generate_model_sd.py b/test/passes/vitis_ai/test_vitis_generate_model_sd.py index 9d12175006..b9f274ba1d 100644 --- a/test/passes/vitis_ai/test_vitis_generate_model_sd.py +++ b/test/passes/vitis_ai/test_vitis_generate_model_sd.py @@ -1,8 +1,7 @@ # ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# pylint: disable=protected-access +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +# ------------------------------------------------------------------------- import builtins import sys @@ -49,7 +48,7 @@ def guarded_import(name, glbs=None, locs=None, fromlist=(), level=0): patch.object(builtins, "__import__", guarded_import), pytest.raises(ImportError, match="model_generate is required for VitisGenerateModelSD"), ): - VitisGenerateModelSD._get_supported_sd_model_types() + VitisGenerateModelSD.get_supported_sd_model_types() finally: sys.modules.update(saved_mods) @@ -189,35 +188,35 @@ def test_ensure_model_onnx_raises_when_no_candidate_files(tmp_path): def test_resolve_onnx_input_path_single_file(): p = _make_pass() h = ONNXModelHandler(model_path=str(ONNX_MODEL_PATH)) - assert p._resolve_onnx_input_path(h) == Path(ONNX_MODEL_PATH) + assert p.resolve_onnx_input_path(h) == Path(ONNX_MODEL_PATH) def test_resolve_onnx_input_path_dir_with_model_onnx(tmp_path): (tmp_path / "model.onnx").write_bytes(b"x") p = _make_pass() h = ONNXModelHandler(model_path=str(tmp_path)) - assert p._resolve_onnx_input_path(h) == tmp_path / "model.onnx" + assert p.resolve_onnx_input_path(h) == tmp_path / "model.onnx" def test_resolve_onnx_input_path_dir_with_onnx_file_name(tmp_path): (tmp_path / "custom.onnx").write_bytes(b"x") p = _make_pass() h = ONNXModelHandler(model_path=str(tmp_path), onnx_file_name="custom.onnx") - assert p._resolve_onnx_input_path(h) == tmp_path / "custom.onnx" + assert p.resolve_onnx_input_path(h) == tmp_path / "custom.onnx" def test_resolve_onnx_input_path_dir_onnx_file_name_missing_raises(tmp_path): p = _make_pass() h = SimpleNamespace(model_path=str(tmp_path), onnx_file_name="missing.onnx") with pytest.raises(FileNotFoundError, match="Specified onnx_file_name"): - p._resolve_onnx_input_path(h) + p.resolve_onnx_input_path(h) def test_resolve_onnx_input_path_dir_single_unnamed_onnx(tmp_path): (tmp_path / "only.onnx").write_bytes(b"x") p = _make_pass() h = ONNXModelHandler(model_path=str(tmp_path)) - assert p._resolve_onnx_input_path(h) == tmp_path / "only.onnx" + assert p.resolve_onnx_input_path(h) == tmp_path / "only.onnx" def test_resolve_onnx_input_path_dir_multiple_onnx_raises(tmp_path): @@ -226,14 +225,14 @@ def test_resolve_onnx_input_path_dir_multiple_onnx_raises(tmp_path): p = _make_pass() h = SimpleNamespace(model_path=str(tmp_path)) with pytest.raises(ValueError, match=r"Multiple \.onnx model files found"): - p._resolve_onnx_input_path(h) + p.resolve_onnx_input_path(h) def test_resolve_onnx_input_path_dir_no_onnx_raises(tmp_path): p = _make_pass() h = SimpleNamespace(model_path=str(tmp_path)) with pytest.raises(FileNotFoundError, match=r"No \.onnx file found"): - p._resolve_onnx_input_path(h) + p.resolve_onnx_input_path(h) def test_resolve_onnx_input_path_missing_path_raises(tmp_path): @@ -241,7 +240,7 @@ def test_resolve_onnx_input_path_missing_path_raises(tmp_path): missing = tmp_path / "nope" h = SimpleNamespace(model_path=str(missing)) with pytest.raises(FileNotFoundError, match="Model path does not exist"): - p._resolve_onnx_input_path(h) + p.resolve_onnx_input_path(h) def test_run_requires_onnx_model_handler(tmp_path): From f1c6ec80e48616002d37f5551adf0c68d1e2468d Mon Sep 17 00:00:00 2001 From: liujij Date: Mon, 11 May 2026 04:17:46 -0500 Subject: [PATCH 17/19] ruff. --- test/passes/vitis_ai/test_vitis_generate_model_sd.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/passes/vitis_ai/test_vitis_generate_model_sd.py b/test/passes/vitis_ai/test_vitis_generate_model_sd.py index b9f274ba1d..94b706bf38 100644 --- a/test/passes/vitis_ai/test_vitis_generate_model_sd.py +++ b/test/passes/vitis_ai/test_vitis_generate_model_sd.py @@ -1,7 +1,7 @@ -# ------------------------------------------------------------------------- -# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. -# SPDX-License-Identifier: MIT -# ------------------------------------------------------------------------- +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- import builtins import sys From 0cf11d31206b709f24ec4f1f99aab78d5c3dc1c8 Mon Sep 17 00:00:00 2001 From: liujij Date: Mon, 11 May 2026 04:19:22 -0500 Subject: [PATCH 18/19] ruff. --- test/passes/vitis_ai/test_vitis_generate_model_sd.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/passes/vitis_ai/test_vitis_generate_model_sd.py b/test/passes/vitis_ai/test_vitis_generate_model_sd.py index 94b706bf38..2fc840fb6d 100644 --- a/test/passes/vitis_ai/test_vitis_generate_model_sd.py +++ b/test/passes/vitis_ai/test_vitis_generate_model_sd.py @@ -1,4 +1,3 @@ -# -------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- From 5487b09763d6f5383c07132c923d5aca692e1a9b Mon Sep 17 00:00:00 2001 From: liujij Date: Mon, 11 May 2026 04:41:23 -0500 Subject: [PATCH 19/19] lint. --- test/passes/vitis_ai/test_vitis_generate_model_sd.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/passes/vitis_ai/test_vitis_generate_model_sd.py b/test/passes/vitis_ai/test_vitis_generate_model_sd.py index 2fc840fb6d..b9f274ba1d 100644 --- a/test/passes/vitis_ai/test_vitis_generate_model_sd.py +++ b/test/passes/vitis_ai/test_vitis_generate_model_sd.py @@ -1,6 +1,7 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- +# ------------------------------------------------------------------------- +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +# ------------------------------------------------------------------------- import builtins import sys