diff --git a/olive/olive_config.json b/olive/olive_config.json index 50e1f36d6c..d20f8f2d6d 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -673,6 +673,15 @@ "supported_algorithms": [ ], "supported_quantization_encodings": [ ], "run_on_target": true + }, + "VitisGenerateModelSD": { + "module_path": "olive.passes.onnx.vitis_ai.vitis_generate_model_sd.VitisGenerateModelSD", + "supported_providers": [ "CPUExecutionProvider" ], + "supported_accelerators": [ "cpu" ], + "supported_precisions": [ "int8" ], + "supported_algorithms": [ ], + "supported_quantization_encodings": [ ], + "run_on_target": true } }, "extra_dependencies": { diff --git a/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py new file mode 100644 index 0000000000..2cf4e602e5 --- /dev/null +++ b/olive/passes/onnx/vitis_ai/vitis_generate_model_sd.py @@ -0,0 +1,167 @@ +# ------------------------------------------------------------------------- +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +# ------------------------------------------------------------------------- + +"""Olive Pass for Vitis NPU Stable Diffusion submodel generation. + +Accepts ONNX input only; run OnnxConversion to produce ONNX input model first, +then this pass runs generate_sd_model to generate NPU-ready models. +""" + +from __future__ import annotations + +import logging +import shutil +from pathlib import Path + +from olive.model import ONNXModelHandler +from olive.passes import Pass +from olive.passes.pass_config import BasePassConfig, PassConfigParam + +logger = logging.getLogger(__name__) + + +class VitisGenerateModelSD(Pass): + """Generate Vitis NPU-ready SD submodel from ONNX input. + + Use OnnxConversion to produce ONNX input model. + Optional resolutions to generate NPU-ready models. Default is [512x512]. 
+ """ + + @classmethod + def _default_config(cls, accelerator_spec): + return { + "model_type": PassConfigParam( + type_=str, + required=True, + description="SD submodel type.", + ), + "resolutions": PassConfigParam( + type_=list[str], + default_value=["512x512"], + required=False, + description="List of resolutions (e.g. ['512x512', '1024x1024']) Default is [512x512].", + ), + } + + @staticmethod + def get_supported_sd_model_types(): + try: + from model_generate.recipes import get_supported_sd_model_types + except ImportError as e: + raise ImportError( + "model_generate is required for VitisGenerateModelSD. Please install the model_generate package." + ) from e + + return get_supported_sd_model_types() + + @classmethod + def _validate_model_type(cls, model_type: str) -> None: + supported_types = cls.get_supported_sd_model_types() + + if model_type not in supported_types: + raise ValueError(f"model_type must be one of {', '.join(supported_types)}, got {model_type!r}") + + def _run_for_config( + self, + model: ONNXModelHandler, + config: BasePassConfig, + output_model_path: str, + ) -> ONNXModelHandler: + try: + from model_generate import generate_model + except ImportError as e: + raise ImportError( + "model_generate is required for VitisGenerateModelSD. Please install the model_generate package." + ) from e + + if not isinstance(model, ONNXModelHandler): + raise TypeError( + f"VitisGenerateModelSD requires ONNXModelHandler (run OnnxConversion first). 
Got {type(model).__name__}" + ) + model_type = config.model_type + self._validate_model_type(model_type) + + output_dir = Path(output_model_path) + if output_dir.suffix == ".onnx": + output_dir = output_dir.parent + output_dir.mkdir(parents=True, exist_ok=True) + + logger.info( + "[VitisGenerateModelSD] output_dir=%s, model_type=%s", + output_dir, + model_type, + ) + + onnx_input_path = self.resolve_onnx_input_path(model) + logger.info("[VitisGenerateModelSD] ONNX input path: %s", onnx_input_path) + + resolutions = getattr(config, "resolutions", None) + extra_options = {"model_type": model_type} + if resolutions: + logger.info( + "[VitisGenerateModelSD] Using resolutions: %s", + resolutions, + ) + extra_options["resolutions"] = ",".join(resolutions) + + generate_model( + mode="sd", + input_model=str(onnx_input_path), + output_dir=str(output_dir), + extra_options=extra_options, + ) + + self._ensure_model_onnx(output_dir) + + return ONNXModelHandler( + model_path=str(output_dir), + onnx_file_name="model.onnx", + ) + + def resolve_onnx_input_path(self, model: ONNXModelHandler) -> Path: + p = Path(model.model_path) + if p.is_file(): + return p + if p.is_dir(): + name = getattr(model, "onnx_file_name", None) + if name: + f = p / name + if f.exists(): + return f + raise FileNotFoundError(f"Specified onnx_file_name does not exist under {p}: {name}") + + default_model_path = p / "model.onnx" + if default_model_path.exists(): + return default_model_path + + onnx_files = sorted(path for path in p.glob("*.onnx") if path.is_file()) + if len(onnx_files) == 1: + return onnx_files[0] + if len(onnx_files) > 1: + candidates = ", ".join(path.name for path in onnx_files) + raise ValueError( + f"Multiple .onnx model files found under {p}: {candidates}. Please specify one using the onnx_file_name argument." 
+ ) + else: + raise FileNotFoundError(f"No .onnx file found under {p}") + raise FileNotFoundError(f"Model path does not exist: {p}") + + def _ensure_model_onnx(self, output_dir: Path) -> None: + """Copy actual generate_sd_model output to output_dir/model.onnx if needed.""" + model_onnx = output_dir / "model.onnx" + if model_onnx.exists(): + return + optimized = output_dir / "optimized.onnx" + dd_replaced = output_dir / "dd" / "replaced.onnx" + if dd_replaced.exists(): + shutil.copy2(dd_replaced, model_onnx) + logger.info("[VitisGenerateModelSD] Wrote model.onnx from dd/replaced.onnx") + elif optimized.exists(): + shutil.copy2(optimized, model_onnx) + logger.info("[VitisGenerateModelSD] Wrote model.onnx from optimized.onnx") + else: + raise FileNotFoundError( + f"[VitisGenerateModelSD] No optimized.onnx or dd/replaced.onnx found under {output_dir}. Please check the output directory.", + ) diff --git a/test/passes/vitis_ai/test_vitis_generate_model_sd.py b/test/passes/vitis_ai/test_vitis_generate_model_sd.py new file mode 100644 index 0000000000..b9f274ba1d --- /dev/null +++ b/test/passes/vitis_ai/test_vitis_generate_model_sd.py @@ -0,0 +1,256 @@ +# ------------------------------------------------------------------------- +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. 
# SPDX-License-Identifier: MIT
# -------------------------------------------------------------------------

"""Unit tests for the VitisGenerateModelSD Olive pass."""

import builtins
import sys
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

from olive.model import ONNXModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.vitis_ai.vitis_generate_model_sd import VitisGenerateModelSD
from test.utils import ONNX_MODEL_PATH, get_onnx_model

_PATCH_GEN = "model_generate.generate_model"
_PATCH_SUPPORTED = "model_generate.recipes.get_supported_sd_model_types"


def _make_pass(**kwargs):
    """Build the pass with test defaults (unet, no resolutions); kwargs override."""
    cfg = {"model_type": "unet", "resolutions": [], **kwargs}
    return create_pass_from_dict(VitisGenerateModelSD, cfg, disable_search=True)


def _generate_writes_placeholder(**kwargs):
    """Mock generate_model: leave output Olive's _ensure_model_onnx can satisfy."""
    out = Path(kwargs["output_dir"])
    out.mkdir(parents=True, exist_ok=True)
    (out / "optimized.onnx").write_bytes(b"placeholder")


def test_get_supported_sd_model_types_wraps_import_error():
    # Hide any installed model_generate, force its import to fail, then restore.
    saved_mods = {
        k: sys.modules.pop(k) for k in list(sys.modules) if k == "model_generate" or k.startswith("model_generate.")
    }
    real_import = builtins.__import__

    def guarded_import(name, glbs=None, locs=None, fromlist=(), level=0):
        if name == "model_generate.recipes" and fromlist:
            raise ImportError("simulated missing recipes")
        return real_import(name, glbs, locs, fromlist, level)

    try:
        with (
            patch.object(builtins, "__import__", guarded_import),
            pytest.raises(ImportError, match="model_generate is required for VitisGenerateModelSD"),
        ):
            VitisGenerateModelSD.get_supported_sd_model_types()
    finally:
        sys.modules.update(saved_mods)


@pytest.mark.parametrize(
    ("model_type", "supported"),
    [
        ("bad", ["unet"]),
        ("", ["unet", "vae"]),
    ],
)
def test_run_invalid_model_type_raises(model_type, supported, tmp_path):
    with patch(_PATCH_SUPPORTED, return_value=supported), patch(_PATCH_GEN, MagicMock()):
        p = _make_pass(model_type=model_type)
        with pytest.raises(ValueError, match="model_type must be one of"):
            p.run(get_onnx_model(), str(tmp_path / "out"))


def test_run_includes_resolutions_in_extra_options(tmp_path):
    gen = MagicMock(side_effect=_generate_writes_placeholder)
    with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen):
        p = _make_pass(resolutions=["512x512", "768x768"])
        p.run(get_onnx_model(), str(tmp_path / "sd_out"))

    gen.assert_called_once()
    kwargs = gen.call_args.kwargs
    assert kwargs["mode"] == "sd"
    assert kwargs["extra_options"]["model_type"] == "unet"
    assert kwargs["extra_options"]["resolutions"] == "512x512,768x768"


def test_run_default_resolutions_passed_when_using_defaults(tmp_path):
    gen = MagicMock(side_effect=_generate_writes_placeholder)
    with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen):
        # Deliberately omit "resolutions" so the declared default value applies;
        # _make_pass always injects the key, so build the pass directly here.
        p = create_pass_from_dict(
            VitisGenerateModelSD,
            {"model_type": "unet"},
            disable_search=True,
        )
        p.run(get_onnx_model(), str(tmp_path / "out"))

    assert gen.call_args.kwargs["extra_options"].get("resolutions") == "512x512"


def test_run_omits_resolutions_when_empty_list(tmp_path):
    gen = MagicMock(side_effect=_generate_writes_placeholder)
    with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen):
        p = _make_pass()
        p.run(get_onnx_model(), str(tmp_path / "out"))

    assert "resolutions" not in gen.call_args.kwargs["extra_options"]


def test_ensure_model_onnx_copies_optimized(tmp_path):
    def write_optimized(**kwargs):
        out = Path(kwargs["output_dir"])
        (out / "optimized.onnx").write_text("from_optimized", encoding="utf-8")

    gen = MagicMock(side_effect=write_optimized)
    with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen):
        p = _make_pass()
        p.run(get_onnx_model(), str(tmp_path / "out"))

    assert (tmp_path / "out" / "model.onnx").read_text(encoding="utf-8") == "from_optimized"


def test_ensure_model_onnx_prefers_dd_replaced_over_optimized(tmp_path):
    def write_both(**kwargs):
        out = Path(kwargs["output_dir"])
        (out / "optimized.onnx").write_text("from_optimized", encoding="utf-8")
        dd = out / "dd"
        dd.mkdir(parents=True)
        (dd / "replaced.onnx").write_text("from_dd", encoding="utf-8")

    gen = MagicMock(side_effect=write_both)
    with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen):
        p = _make_pass()
        p.run(get_onnx_model(), str(tmp_path / "out"))

    assert (tmp_path / "out" / "model.onnx").read_text(encoding="utf-8") == "from_dd"


def test_ensure_model_onnx_skips_copy_when_model_onnx_exists(tmp_path):
    def write_only_original(**kwargs):
        out = Path(kwargs["output_dir"])
        (out / "model.onnx").write_text("original", encoding="utf-8")
        (out / "optimized.onnx").write_text("optimized", encoding="utf-8")

    gen = MagicMock(side_effect=write_only_original)
    with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, gen):
        p = _make_pass()
        p.run(get_onnx_model(), str(tmp_path / "out"))

    assert (tmp_path / "out" / "model.onnx").read_text(encoding="utf-8") == "original"


def test_ensure_model_onnx_raises_when_no_candidate_files(tmp_path):
    # generate_model writes nothing here, so _ensure_model_onnx must fail.
    with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, MagicMock()):
        p = _make_pass()
        with pytest.raises(FileNotFoundError, match=r"No optimized\.onnx or dd/replaced\.onnx"):
            p.run(get_onnx_model(), str(tmp_path / "out"))


def test_resolve_onnx_input_path_single_file():
    p = _make_pass()
    h = ONNXModelHandler(model_path=str(ONNX_MODEL_PATH))
    assert p.resolve_onnx_input_path(h) == Path(ONNX_MODEL_PATH)


def test_resolve_onnx_input_path_dir_with_model_onnx(tmp_path):
    (tmp_path / "model.onnx").write_bytes(b"x")
    p = _make_pass()
    h = ONNXModelHandler(model_path=str(tmp_path))
    assert p.resolve_onnx_input_path(h) == tmp_path / "model.onnx"


def test_resolve_onnx_input_path_dir_with_onnx_file_name(tmp_path):
    (tmp_path / "custom.onnx").write_bytes(b"x")
    p = _make_pass()
    h = ONNXModelHandler(model_path=str(tmp_path), onnx_file_name="custom.onnx")
    assert p.resolve_onnx_input_path(h) == tmp_path / "custom.onnx"


def test_resolve_onnx_input_path_dir_onnx_file_name_missing_raises(tmp_path):
    p = _make_pass()
    h = SimpleNamespace(model_path=str(tmp_path), onnx_file_name="missing.onnx")
    with pytest.raises(FileNotFoundError, match="Specified onnx_file_name"):
        p.resolve_onnx_input_path(h)


def test_resolve_onnx_input_path_dir_single_unnamed_onnx(tmp_path):
    (tmp_path / "only.onnx").write_bytes(b"x")
    p = _make_pass()
    h = ONNXModelHandler(model_path=str(tmp_path))
    assert p.resolve_onnx_input_path(h) == tmp_path / "only.onnx"


def test_resolve_onnx_input_path_dir_multiple_onnx_raises(tmp_path):
    (tmp_path / "a.onnx").write_bytes(b"x")
    (tmp_path / "b.onnx").write_bytes(b"y")
    p = _make_pass()
    h = SimpleNamespace(model_path=str(tmp_path))
    with pytest.raises(ValueError, match=r"Multiple \.onnx model files found"):
        p.resolve_onnx_input_path(h)


def test_resolve_onnx_input_path_dir_no_onnx_raises(tmp_path):
    p = _make_pass()
    h = SimpleNamespace(model_path=str(tmp_path))
    with pytest.raises(FileNotFoundError, match=r"No \.onnx file found"):
        p.resolve_onnx_input_path(h)


def test_resolve_onnx_input_path_missing_path_raises(tmp_path):
    p = _make_pass()
    missing = tmp_path / "nope"
    h = SimpleNamespace(model_path=str(missing))
    with pytest.raises(FileNotFoundError, match="Model path does not exist"):
        p.resolve_onnx_input_path(h)


def test_run_requires_onnx_model_handler(tmp_path):
    with patch(_PATCH_SUPPORTED, return_value=["unet"]), patch(_PATCH_GEN, MagicMock()):
        p = _make_pass()
        bad = MagicMock()
        with pytest.raises(TypeError, match="ONNXModelHandler"):
            p.run(bad, str(tmp_path / "out"))