diff --git a/olive/cli/model_package.py b/olive/cli/model_package.py
index 4481df1b9..5513c2d71 100644
--- a/olive/cli/model_package.py
+++ b/olive/cli/model_package.py
@@ -2,12 +2,52 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
+"""``olive generate-model-package`` CLI command.
+
+Assemble one or more Olive output directories into a proposal-shaped ORT
+model package.
+
+Each ``--source`` directory is one Olive output (an ``ONNXModel`` or a
+``CompositeModel`` with ONNX components). Single-source packages are
+allowed: a single variant under one component is a normal, valid package.
+
+Output layout (per the ORT model-package proposal)::
+
+    <package>/
+    ├── manifest.json
+    ├── configs/
+    │   └── <file>                  # tokenizer, genai_config, ...
+    └── <component>/
+        ├── metadata.json
+        ├── shared_weights/
+        │   └── <sha256>/<blob>     # opt-in cross-variant dedup
+        └── <variant>/
+            ├── variant.json
+            ├── model.onnx
+            └── ...
+
+Notes:
+- ``shared_weights`` is opt-in per blob. A blob whose SHA-256 appears in only
+  one variant stays inline next to its ONNX file in the variant directory,
+  keeping the single-variant case loadable by stock ORT.
+- Cross-variant dedup moves a duplicated blob to
+  ``<component>/shared_weights/<sha256>/<blob>`` and records the mapping
+  in the per-file ``shared_files`` map of the variant's ``variant.json``.
+  Loading such a variant requires a model-package-aware consumer.
+- ``genai_config.json`` is copied verbatim into ``<package>/configs/``;
+  per-variant overlays are ORT-GenAI's responsibility, not Olive's.
+
+"""
+
+import hashlib
 import json
 import logging
+import re
 import shutil
 from argparse import ArgumentParser
+from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional
 
 from olive.cli.base import (
     BaseOliveCLICommand,
@@ -18,18 +58,34 @@
 
 logger = logging.getLogger(__name__)
 
-# Model file suffixes that belong in the models/ directory, not configs/
+# Files inside an Olive output dir that always belong next to the ONNX model
+# rather than under <package>/configs/.
 _MODEL_SUFFIXES = {".onnx", ".bin", ".data", ".xml"}
 
+# Schema version emitted in manifest.json. Keep in sync with the proposal.
+_MANIFEST_SCHEMA_VERSION = 1
+
+# Hash chunk size for SHA-256 over external-data blobs.
+_HASH_CHUNK = 1024 * 1024
+
+# Disallow path separators / traversal in component and variant names so a
+# producer can't write files outside the package directory.
+_NAME_RE = re.compile(r"^[A-Za-z0-9._-][A-Za-z0-9._\- ]*$")
+
+
+# ---------------------------------------------------------------------------
+# CLI command
+# ---------------------------------------------------------------------------
+
 
 class ModelPackageCommand(BaseOliveCLICommand):
-    """Merge multiple Olive output directories into a model package with manifest."""
+    """Merge one or more Olive output directories into a model package."""
 
     @staticmethod
     def register_subcommand(parser: ArgumentParser):
         sub_parser = parser.add_parser(
             "generate-model-package",
-            help="Merge multiple model outputs into a model package with manifest",
+            help="Merge one or more Olive output directories into a model package",
         )
 
         sub_parser.add_argument(
@@ -38,7 +94,10 @@ def register_subcommand(parser: ArgumentParser):
             "-s",
             "--source",
             type=str,
             action="append",
             required=True,
-            help="Source Olive output directory. Can be specified multiple times.",
+            help=(
+                "Source Olive output directory. Repeat to add multiple variants. 
" + "A single source is allowed (single-variant package)." + ), ) sub_parser.add_argument( @@ -46,21 +105,21 @@ def register_subcommand(parser: ArgumentParser): "--output_path", type=str, required=True, - help="Output directory for the merged model package.", + help="Output directory for the model package. Must be empty or non-existent.", ) sub_parser.add_argument( "--model_name", type=str, default=None, - help="Model name for the manifest. If not set, derived from the output directory name.", + help="Optional model name recorded under manifest.producer.", ) sub_parser.add_argument( "--model_version", type=str, default="1.0", - help="Model version string for the manifest. Default: 1.0", + help="Optional model version recorded under manifest.producer. Default: 1.0", ) add_logging_options(sub_parser) @@ -71,141 +130,136 @@ def register_subcommand(parser: ArgumentParser): def run(self): sources = self._parse_sources() output_dir = Path(self.args.output_path) - output_dir.mkdir(parents=True, exist_ok=True) - - model_name = self.args.model_name or output_dir.name - model_version = self.args.model_version - # Read model configs from each source targets = [] for target_name, source_path in sources: model_config = self._read_model_config(source_path) targets.append((target_name, source_path, model_config)) - is_composite = targets[0][2].get("type") == "CompositeModel" + types = {targets[i][2].get("type") for i in range(len(targets))} + if types - {"ONNXModel", "CompositeModel"}: + unsupported = sorted(types - {"ONNXModel", "CompositeModel"}) + raise ValueError( + f"Unsupported source model type(s) {unsupported!r}. " + "generate-model-package supports ONNXModel and CompositeModel only." + ) + if len(types) > 1: + raise ValueError( + f"Sources mix model types {sorted(types)!r}. All sources must share the same type " + "(all ONNXModel or all CompositeModel)." 
+ ) + is_composite = next(iter(types)) == "CompositeModel" + if is_composite: - self._package_composite(targets, output_dir, model_name, model_version) + variants = self._build_composite_variants(targets) else: - self._package_single(targets, output_dir, model_name, model_version) + variants = self._build_single_variants(targets) + + config_files = self._collect_config_files(targets) + + task = self._extract_task(targets) + producer_info: dict[str, str] = {"tool": "olive-ai"} + try: + from olive import __version__ as _olive_version + + producer_info["tool_version"] = _olive_version + except Exception: + logger.debug("Could not read olive.__version__", exc_info=True) + producer_info["model_name"] = self.args.model_name or output_dir.name + producer_info["model_version"] = self.args.model_version + if task: + producer_info["task"] = task + + write_model_package( + output_dir=output_dir, + variants=variants, + config_files=config_files, + producer_info=producer_info, + ) logger.info("Model package generated at %s", output_dir) print(f"Model package generated at {output_dir}") # ------------------------------------------------------------------ - # Single-component packaging + # VariantSpec construction # ------------------------------------------------------------------ - def _package_single( - self, - targets: list[tuple[str, Path, dict]], - output_dir: Path, - model_name: str, - model_version: str, - ) -> None: - """Package non-composite models (single ONNX per target).""" - config_file_names = self._copy_config_files(targets, output_dir) + def _build_single_variants(self, targets: list[tuple[str, Path, dict]]) -> list["VariantSpec"]: task = self._extract_task(targets) component_name = _task_to_component_name(task) - - component_dir = output_dir / "models" / component_name - component_dir.mkdir(parents=True, exist_ok=True) - - model_variants = {} - for target_name, _source_path, model_config in targets: + variants: list[VariantSpec] = [] + for target_name, _src, model_config in targets: attrs = _get_model_attributes(model_config) - model_path = Path(model_config["config"]["model_path"]) - - target_dir = component_dir / target_name - _copy_model_files_single(model_path, target_dir) - - constraints = _build_constraints(attrs, model_path) - model_variants[target_name] = {"file": model_path.name, "constraints": constraints} - - _remove_config_files(component_dir, config_file_names) - - metadata = {"name": component_name, "model_variants": model_variants} - _write_json(component_dir / "metadata.json", metadata) - - manifest = { - "name": model_name, - "model_version": model_version, - "task": task, - "component_models": [component_name], - } - _write_json(output_dir / "manifest.json", manifest) - - # ------------------------------------------------------------------ - # Composite-model packaging - # ------------------------------------------------------------------ + onnx_path = _resolve_onnx_path(model_config) + ep, device, compatibility = _ep_device_compatibility(attrs, onnx_path) + variants.append( + VariantSpec( + component_name=component_name, + variant_name=target_name, + onnx_files=[onnx_path], + ep=ep, + device=device, + compatibility=compatibility, + inference_settings=model_config.get("config", {}).get("inference_settings") or {}, + ) + ) + return variants - def _package_composite( - self, - targets: list[tuple[str, Path, dict]], - output_dir: Path, - model_name: str, - model_version: str, - ) -> None: - """Package composite models with per-component directory layout.""" - 
config_file_names = self._copy_config_files(targets, output_dir) - - # Collect component info: component_data[comp_name][target_name] = (comp_config, target_attrs) + def _build_composite_variants(self, targets: list[tuple[str, Path, dict]]) -> list["VariantSpec"]: from collections import OrderedDict - component_data: dict[str, dict] = OrderedDict() + # Track per-component variants in source insertion order. + component_variants: dict[str, list[VariantSpec]] = OrderedDict() - for target_name, _source_path, model_config in targets: + for target_name, _src, model_config in targets: target_attrs = _get_model_attributes(model_config) + target_inference = model_config.get("config", {}).get("inference_settings") or {} components = model_config["config"].get("model_components", []) component_names = model_config["config"].get("component_names", []) - for comp_config, comp_name in zip(components, component_names): - if comp_name not in component_data: - component_data[comp_name] = OrderedDict() - component_data[comp_name][target_name] = (comp_config, target_attrs) - - models_dir = output_dir / "models" - comp_names_list = list(component_data.keys()) - - for comp_name in comp_names_list: - comp_dir = models_dir / comp_name - comp_dir.mkdir(parents=True, exist_ok=True) - - model_variants = {} - for target_name, (comp_config, target_attrs) in component_data[comp_name].items(): - comp_model_path = Path(comp_config["config"]["model_path"]) - target_dir = comp_dir / target_name - _copy_component_files(comp_model_path, target_dir) - - constraints = _build_constraints(target_attrs, comp_model_path) - model_variants[target_name] = {"file": comp_model_path.name, "constraints": constraints} + if not components: + raise ValueError(f"Composite source {target_name!r} declares no model_components.") - _remove_config_files(comp_dir, config_file_names) - - metadata = {"name": comp_name, "model_variants": model_variants} - _write_json(comp_dir / "metadata.json", metadata) + for comp_config, comp_name in zip(components, component_names): + # Component-level inference_settings overrides target-level if present. + comp_inference = comp_config.get("config", {}).get("inference_settings") or target_inference + # Component-level model_attributes overlay target-level. + comp_attrs = dict(target_attrs) + comp_attrs.update(_get_model_attributes(comp_config)) + + onnx_path = _resolve_onnx_path(comp_config) + ep, device, compatibility = _ep_device_compatibility(comp_attrs, onnx_path) + + spec = VariantSpec( + component_name=comp_name, + variant_name=target_name, + onnx_files=[onnx_path], + ep=ep, + device=device, + compatibility=compatibility, + inference_settings=comp_inference, + ) + component_variants.setdefault(comp_name, []).append(spec) - task = self._extract_task(targets) - manifest = { - "name": model_name, - "model_version": model_version, - "task": task, - "component_models": comp_names_list, - } - _write_json(output_dir / "manifest.json", manifest) + flat: list[VariantSpec] = [] + for comp_specs in component_variants.values(): + flat.extend(comp_specs) + return flat # ------------------------------------------------------------------ # Config file handling # ------------------------------------------------------------------ @staticmethod - def _copy_config_files( - targets: list[tuple[str, Path, dict]], - output_dir: Path, - ) -> set[str]: - """Copy non-model config files (genai_config, tokenizer, etc.) 
to configs/."""
+    def _collect_config_files(targets: list[tuple[str, Path, dict]]) -> dict[str, Path]:
+        """Pick consumer-shared config files (genai_config, tokenizer, ...).
+
+        Selection order:
+        1. ``model_attributes.additional_files`` from the first source that
+           provides any.
+        2. Otherwise, the first source's non-model files.
+        """
         config_entries: dict[str, Path] = {}
 
-        # Collect from the first target's additional_files or source directory
         for _target_name, _source_path, model_config in targets:
             attrs = _get_model_attributes(model_config)
             for fp in attrs.get("additional_files", []):
@@ -215,7 +269,6 @@ def _copy_config_files(
             if config_entries:
                 break
 
-        # Fall back to scanning the source directory for non-model files
         if not config_entries:
             for _target_name, source_path, _model_config in targets:
                 for f in sorted(source_path.iterdir()):
@@ -226,69 +279,47 @@ def _copy_config_files(
             if config_entries:
                 break
 
-        if not config_entries:
-            return set()
-
-        configs_dir = output_dir / "configs"
-        configs_dir.mkdir(parents=True, exist_ok=True)
-
-        for name, src_path in config_entries.items():
-            dest = configs_dir / name
-            if src_path.is_dir():
-                if not dest.exists():
-                    shutil.copytree(str(src_path), str(dest))
-            else:
-                shutil.copy2(str(src_path), str(dest))
-            logger.info("Copied %s to %s", name, configs_dir)
-
-        return set(config_entries.keys())
+        return config_entries
 
     # ------------------------------------------------------------------
-    # Source validation and reading
+    # Source validation / reading
    # ------------------------------------------------------------------
 
     def _parse_sources(self) -> list[tuple[str, Path]]:
-        sources = []
+        sources: list[tuple[str, Path]] = []
+        seen_names: set[str] = set()
         for source in self.args.source:
             path = Path(source)
             if not path.is_dir():
                 raise ValueError(f"Source path does not exist or is not a directory: {path}")
-
             if not (path / "model_config.json").exists():
                 raise ValueError(
                     f"No model_config.json found in {path}. "
                     "Source must be an Olive output directory with model_config.json."
                 )
-
-            sources.append((path.name, path))
-
-        if len(sources) < 2:
-            raise ValueError("At least two --source directories are required to merge.")
-
+            name = path.name
+            if name in seen_names:
+                raise ValueError(
+                    f"Two sources share the directory name {name!r}. Variant names are derived from "
+                    "the source directory name; please rename so each source is unique."
+ ) + seen_names.add(name) + sources.append((name, path)) + if not sources: + raise ValueError("At least one --source directory is required.") return sources @staticmethod def _read_model_config(source_path: Path) -> dict: - config_path = source_path / "model_config.json" - with open(config_path) as f: + with (source_path / "model_config.json").open() as f: return json.load(f) - @staticmethod - def _extract_accelerator_info(target_models: list[dict]) -> tuple[str, str]: - for model_config in target_models: - attrs = model_config.get("config", {}).get("model_attributes") or {} - ep = attrs.get("ep", "CPUExecutionProvider") - device = attrs.get("device", "cpu") - return ep, device.lower() - return "CPUExecutionProvider", "cpu" - # ------------------------------------------------------------------ # Task extraction # ------------------------------------------------------------------ @staticmethod def _extract_task(targets: list[tuple[str, Path, dict]]) -> str: - """Extract the HuggingFace pipeline task for the model.""" model_name_or_path = "" for _target_name, _source_path, model_config in targets: attrs = _get_model_attributes(model_config) @@ -310,40 +341,490 @@ def _extract_task(targets: list[tuple[str, Path, dict]]) -> str: return "" -# ------------------------------------------------------------------ -# Module-level helpers -# ------------------------------------------------------------------ +# --------------------------------------------------------------------------- +# Writer (CLI-private; kept here because only this command produces packages) +# --------------------------------------------------------------------------- + + +@dataclass +class VariantSpec: + """One variant of one component, ready to be packaged.""" + + component_name: str + variant_name: str + onnx_files: list[Path] + ep: str + device: Optional[str] = None + compatibility: list[str] = field(default_factory=list) + inference_settings: dict[str, Any] = field(default_factory=dict) + consumer_metadata: Optional[dict[str, Any]] = None + + +def write_model_package( + output_dir: Path, + variants: list[VariantSpec], + config_files: Optional[dict[str, Path]] = None, + producer_info: Optional[dict[str, Any]] = None, +) -> None: + """Materialize a model package on disk. + + :param output_dir: Target directory. Must be empty (or non-existent) so a + partial overwrite cannot mix the new layout with stale files from a + previous run. + :param variants: Ordered list of variants. Component insertion order is + the order each component first appears in this list. + :param config_files: Map from filename (basename) to source path; copied + into ``/configs/``. Same-named files contributed by + different sources should be byte-identical; the first wins on + conflict and a warning is logged. + :param producer_info: Olive-specific provenance recorded under + ``manifest.producer``. Schema-tolerated extra field (the proposal + defines only ``schema_version``, ``components``, and + ``merge_provenance``; producers may add namespaced extras). + """ + if not variants: + raise ValueError("write_model_package requires at least one variant.") + + output_dir = Path(output_dir) + _ensure_empty_output_dir(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Group by component while preserving insertion order. 
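+    # (A plain dict is enough here: dicts preserve insertion order on
+    # Python 3.7+, so the manifest's "components" list comes out in
+    # first-appearance order.)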
+ components: dict[str, list[VariantSpec]] = {} + for v in variants: + _validate_name(v.component_name, "component") + _validate_name(v.variant_name, "variant") + components.setdefault(v.component_name, []).append(v) + + # Per component, fail fast on duplicate variant names. The caller is + # expected to disambiguate (e.g. with a rank suffix) before calling. + for comp_name, comp_variants in components.items(): + seen: set[str] = set() + for v in comp_variants: + if v.variant_name in seen: + raise ValueError( + f"Duplicate variant name '{v.variant_name}' under component " + f"'{comp_name}'. Variant names must be unique per component." + ) + seen.add(v.variant_name) + + for comp_name, comp_variants in components.items(): + _write_component(output_dir, comp_name, comp_variants) + + if config_files: + _copy_config_files(output_dir, config_files) + + _write_manifest(output_dir, list(components.keys()), producer_info) + + +def _write_component(output_dir: Path, component_name: str, comp_variants: list[VariantSpec]) -> None: + component_dir = output_dir / component_name + component_dir.mkdir(parents=True, exist_ok=True) + + # First pass: copy each variant's ONNX file(s) and discover external-data + # references. We hash blobs as we copy so multi-variant packages don't + # re-read the data later. + blob_index: dict[str, dict[str, Any]] = {} + variant_files: dict[str, list[tuple[str, list[tuple[str, str]]]]] = {} + + for v in comp_variants: + if not v.onnx_files: + raise ValueError(f"Variant '{v.variant_name}' under component '{component_name}' has no ONNX files.") + + variant_dir = component_dir / v.variant_name + variant_dir.mkdir(parents=True, exist_ok=True) + files_for_variant: list[tuple[str, list[tuple[str, str]]]] = [] + + for onnx_src in v.onnx_files: + onnx_src_path = Path(onnx_src) + if not onnx_src_path.is_file(): + raise FileNotFoundError(f"ONNX file not found: {onnx_src_path}") + + onnx_dst = variant_dir / onnx_src_path.name + shutil.copy2(str(onnx_src_path), str(onnx_dst)) + + ext_refs = _discover_external_data(onnx_src_path) + external_root = onnx_src_path.parent.resolve() + blob_records: list[tuple[str, str]] = [] + for graph_location in ext_refs: + blob_src = (onnx_src_path.parent / graph_location).resolve() + if not blob_src.is_relative_to(external_root): + logger.warning( + "External-data file referenced by %s resolves outside its source directory " + "(symlink escape?); skipping: %s", + onnx_src_path, + blob_src, + ) + continue + if not blob_src.is_file(): + logger.warning( + "External-data file referenced by %s but missing: %s", + onnx_src_path, + blob_src, + ) + continue + + blob_dst = variant_dir / graph_location + blob_dst.parent.mkdir(parents=True, exist_ok=True) + if not blob_dst.exists(): + shutil.copy2(str(blob_src), str(blob_dst)) + + sha = _sha256_file(blob_dst) + blob_records.append((graph_location, sha)) + + entry = blob_index.setdefault( + sha, {"first_path": blob_dst, "occurrences": 0, "basename": Path(graph_location).name} + ) + entry["occurrences"] += 1 + + files_for_variant.append((onnx_dst.name, blob_records)) + + variant_files[v.variant_name] = files_for_variant + + # Second pass: dedup any blob that appears in 2+ variants of this + # component into /shared_weights//. Single- + # occurrence blobs stay inline so single-variant packages remain + # loadable without the package API. 
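+    # Resulting layout for a deduped blob, e.g.:
+    #   <component>/shared_weights/<sha256>/model.onnx.data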
+ shared_weights_dir = component_dir / "shared_weights" + shared_blob_paths: dict[str, Path] = {} + for sha, entry in blob_index.items(): + if entry["occurrences"] < 2: + continue + sha_dir = shared_weights_dir / sha + sha_dir.mkdir(parents=True, exist_ok=True) + target = sha_dir / entry["basename"] + if not target.exists(): + shutil.copy2(str(entry["first_path"]), str(target)) + shared_blob_paths[sha] = target + + # Third pass: for each variant, remove deduped blobs from the variant + # directory and emit variant.json with the right shared_files map per + # files[i]. Then emit metadata.json for the component. + for v in comp_variants: + variant_dir = component_dir / v.variant_name + files_payload: list[dict[str, Any]] = [] + for onnx_filename, blob_records in variant_files[v.variant_name]: + shared_files: dict[str, str] = {} + for graph_location, sha in blob_records: + if sha in shared_blob_paths: + inline = variant_dir / graph_location + if inline.exists(): + inline.unlink() + # Clean up any now-empty parent directories created for + # nested graph_location paths, but stop at variant_dir. + parent = inline.parent + while parent != variant_dir and parent.is_dir() and not any(parent.iterdir()): + parent.rmdir() + parent = parent.parent + shared_files[graph_location] = sha + + file_entry: dict[str, Any] = {"filename": onnx_filename} + so = (v.inference_settings or {}).get("session_options") or {} + po = _provider_options_for_ep(v.inference_settings or {}, v.ep) + if so: + file_entry["session_options"] = so + if po: + file_entry["provider_options"] = po + if shared_files: + file_entry["shared_files"] = shared_files + files_payload.append(file_entry) + + variant_payload: dict[str, Any] = {"files": files_payload} + if v.consumer_metadata is not None: + variant_payload["consumer_metadata"] = v.consumer_metadata + _write_json(variant_dir / "variant.json", variant_payload) + + _write_metadata(component_dir, comp_variants) + + +def _write_metadata(component_dir: Path, comp_variants: list[VariantSpec]) -> None: + variants_payload: dict[str, Any] = {} + for v in comp_variants: + ep_entry: dict[str, Any] = {"ep": v.ep} + if v.device: + ep_entry["device"] = v.device + if v.compatibility: + ep_entry["compatibility"] = list(v.compatibility) + variants_payload[v.variant_name] = {"ep_compatibility": [ep_entry]} + _write_json(component_dir / "metadata.json", {"variants": variants_payload}) + + +def _write_manifest( + output_dir: Path, + components: list[str], + producer_info: Optional[dict[str, Any]], +) -> None: + manifest: dict[str, Any] = { + "schema_version": _MANIFEST_SCHEMA_VERSION, + "components": components, + } + if producer_info: + # Olive-specific provenance under a namespaced key so future schema + # evolution can't collide with it. 
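+        # e.g. {"schema_version": 1, "components": ["decoder"],
+        #       "producer": {"tool": "olive-ai", "tool_version": "1.2.3", ...}}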
+ manifest["producer"] = producer_info + _write_json(output_dir / "manifest.json", manifest) + + +# --------------------------------------------------------------------------- +# configs/ handling +# --------------------------------------------------------------------------- + + +def _copy_config_files(output_dir: Path, config_files: dict[str, Path]) -> None: + configs_dir = output_dir / "configs" + configs_dir.mkdir(parents=True, exist_ok=True) + configs_root = configs_dir.resolve() + for name, src in config_files.items(): + if "/" in name or "\\" in name or name in ("", ".", ".."): + logger.warning("Skipping config file with unsafe name %r.", name) + continue + src_path = Path(src) + dest = configs_dir / name + # Belt-and-suspenders: even with the name check above, refuse a dest + # that doesn't land directly under configs/. + if dest.resolve().parent != configs_root: + logger.warning("Skipping config file %r: resolved path escapes configs/.", name) + continue + if dest.exists(): + if not _paths_equal(src_path, dest): + logger.warning( + "configs/%s already present and differs from %s; keeping the existing copy. " + "Per-variant config differences belong in variant.json's consumer_metadata, " + "which is consumer-defined and out of Olive's scope.", + name, + src_path, + ) + continue + if src_path.is_dir(): + shutil.copytree(str(src_path), str(dest)) + elif src_path.is_file(): + shutil.copy2(str(src_path), str(dest)) + else: + logger.warning("Config source %s does not exist; skipping.", src_path) + + +def _paths_equal(a: Path, b: Path) -> bool: + """Return True if a and b have identical content (file or directory).""" + if a.is_file() and b.is_file(): + if a.stat().st_size != b.stat().st_size: + return False + return _sha256_file(a) == _sha256_file(b) + if a.is_dir() and b.is_dir(): + a_entries = sorted(p.name for p in a.iterdir()) + b_entries = sorted(p.name for p in b.iterdir()) + if a_entries != b_entries: + return False + return all(_paths_equal(a / name, b / name) for name in a_entries) + return False + + +# --------------------------------------------------------------------------- +# ONNX external-data discovery +# --------------------------------------------------------------------------- + + +def _discover_external_data(onnx_path: Path) -> list[str]: + """Return the relative ``location`` strings of every external-data blob. + + Locations are validated as safe relative paths (no absolute paths, no + upward traversal). Unsafe references are dropped with a warning rather + than failing — better to package a slightly broken model than to refuse + progress on something the user can fix downstream. 
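+
+    For example, a model written with
+    ``onnx.save(model, path, save_as_external_data=True, location="model.onnx.data")``
+    typically yields ``["model.onnx.data"]`` here.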
+ """ + try: + import onnx + except ImportError: + logger.warning("onnx package not available; external-data discovery skipped.") + return [] + + try: + model = onnx.load(str(onnx_path), load_external_data=False) + except Exception: + logger.debug("Failed to parse %s; skipping external-data discovery.", onnx_path, exc_info=True) + return [] + + locations: list[str] = [] + seen: set[str] = set() + for init in model.graph.initializer: + if init.data_location != onnx.TensorProto.EXTERNAL: + continue + for entry in init.external_data: + if entry.key != "location": + continue + location = entry.value + if not _is_safe_relative_location(location): + logger.warning( + "Skipping unsafe external-data location %r in %s.", + location, + onnx_path, + ) + continue + if location not in seen: + locations.append(location) + seen.add(location) + return locations + + +def _is_safe_relative_location(location: str) -> bool: + if not location: + return False + p = Path(location) + if p.is_absolute(): + return False + parts = p.parts + if any(part in ("..", "") for part in parts): + return False + # Reject Windows-drive style paths that slip through is_absolute on POSIX. + return not (len(location) >= 2 and location[1] == ":") + + +# --------------------------------------------------------------------------- +# Helpers (module-level so tests can exercise them directly) +# --------------------------------------------------------------------------- + + +def _provider_options_for_ep(inference_settings: dict[str, Any], ep: str) -> dict[str, Any]: + """Return the provider_options dict that matches ``ep`` by name. + + Olive's inference_settings has ``execution_provider`` (list of EP names) + and ``provider_options`` (parallel list). Match by EP name; do not rely on + positional indexing. + """ + eps = inference_settings.get("execution_provider") or [] + pos = inference_settings.get("provider_options") or [] + for name, opts in zip(eps, pos): + if name == ep: + return opts or {} + return {} + + +def _sha256_file(path: Path) -> str: + h = hashlib.sha256() + with path.open("rb") as fh: + while True: + chunk = fh.read(_HASH_CHUNK) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + + +def _validate_name(name: str, kind: str) -> None: + if not name or not _NAME_RE.match(name): + raise ValueError( + f"Invalid {kind} name {name!r}: must be non-empty and contain only " + "alphanumerics, dot, underscore, hyphen, and space." + ) + if name in (".", "..") or "/" in name or "\\" in name: + raise ValueError(f"Invalid {kind} name {name!r}: path separators and traversal are not allowed.") + + +def _ensure_empty_output_dir(output_dir: Path) -> None: + if output_dir.exists(): + if not output_dir.is_dir(): + raise ValueError(f"Output path {output_dir} exists and is not a directory.") + if any(output_dir.iterdir()): + raise ValueError( + f"Output directory {output_dir} is not empty. Refusing to mix stale files with a new " + "package; please point at an empty (or non-existent) directory." + ) + + +def _write_json(path: Path, data: dict[str, Any]) -> None: + with path.open("w", encoding="utf-8") as fh: + json.dump(data, fh, indent=2) + fh.write("\n") + logger.info("Wrote %s", path) + + +def parse_compatibility_strings(raw: Optional[str]) -> list[str]: + """Split Olive's ``ep_compatibility_info.`` ONNX metadata string. + + Producers store comma-delimited lists today (e.g. ``"sm_80,sm_86,sm_90"``); + the proposal expects a JSON list of opaque strings. Splitting here keeps + consumers from having to know Olive's convention. 
+ """ + if not raw: + return [] + return [tok.strip() for tok in raw.split(",") if tok.strip()] + + +def disambiguate_variant_names(candidates: list[tuple[str, str]]) -> list[str]: + """Return per-candidate variant names with rank suffixes on collision. + + ``candidates`` is a list of ``(component_name, base_variant_name)`` + tuples; the function returns a parallel list of disambiguated variant + names (suffixing ``_rank{N}`` deterministically when two candidates land + on the same ``(component, base_variant)``). + """ + counts: dict[tuple[str, str], int] = {} + for key in candidates: + counts[key] = counts.get(key, 0) + 1 + + used: dict[tuple[str, str], int] = {} + result: list[str] = [] + for comp, base in candidates: + if counts[(comp, base)] == 1: + result.append(base) + continue + used[(comp, base)] = used.get((comp, base), 0) + 1 + result.append(f"{base}_rank{used[(comp, base)]}") + return result + + +# --------------------------------------------------------------------------- +# Olive model-config helpers +# --------------------------------------------------------------------------- def _get_model_attributes(model_config: dict) -> dict: return model_config.get("config", {}).get("model_attributes") or {} -def _write_json(path: Path, data: dict) -> None: - with open(path, "w") as f: - json.dump(data, f, indent=2) - logger.info("Generated %s", path) +def _resolve_onnx_path(model_config: dict) -> Path: + """Resolve the ONNX file path from an Olive model config. + The config's ``model_path`` may be either: + - the ONNX file itself (a ``LocalFile`` resource), + - a directory containing the ONNX file (a ``LocalFolder`` resource), + in which case ``onnx_file_name`` (or a single ``.onnx`` in the dir) + identifies the actual file. + """ + cfg = model_config.get("config", {}) or {} + raw = cfg.get("model_path") + if not raw: + raise ValueError("Model config has no model_path.") + p = Path(raw) + if p.is_file(): + return p + if p.is_dir(): + onnx_name = cfg.get("onnx_file_name") + if onnx_name: + candidate = p / onnx_name + if candidate.is_file(): + return candidate + onnx_files = list(p.glob("*.onnx")) + if len(onnx_files) == 1: + return onnx_files[0] + raise ValueError( + f"Cannot resolve a unique ONNX file under {p}; " + "set onnx_file_name in the model config or pass the file path directly." 
+ ) + raise FileNotFoundError(f"model_path does not exist: {p}") -def _build_constraints(attrs: dict, model_path: Path) -> dict: - """Build variant constraints from model attributes and ONNX metadata.""" - constraints = {} - ep = attrs.get("ep") - if ep: - constraints["ep"] = ep - device = attrs.get("device") - if device: - constraints["device"] = device - ep_compat = _extract_ep_compatibility_from_onnx(model_path, ep or "") - constraints["ep_compatibility_info"] = ep_compat or "" - return constraints + +def _ep_device_compatibility(attrs: dict, onnx_path: Path) -> tuple[str, Optional[str], list[str]]: + """Extract (ep, device, compatibility[]) for one variant from Olive metadata.""" + ep = attrs.get("ep") or "CPUExecutionProvider" + device = attrs.get("device") or None + compatibility = parse_compatibility_strings(_extract_ep_compatibility_from_onnx(onnx_path, ep)) + return ep, device, compatibility def _extract_ep_compatibility_from_onnx(model_path: Path, ep: str = "") -> Optional[str]: - """Extract ep_compatibility_info from ONNX model custom metadata.""" + """Read ``ep_compatibility_info.`` from the ONNX model's metadata_props.""" if not model_path.is_file(): return None - try: import onnx @@ -365,74 +846,7 @@ def _extract_ep_compatibility_from_onnx(model_path: Path, ep: str = "") -> Optio return None -def _copy_model_files_single(model_path: Path, dest_dir: Path) -> None: - """Copy model files for a single ONNX model into dest_dir.""" - if dest_dir.exists(): - return - - src_dir = model_path.parent if model_path.is_file() else model_path - if src_dir.is_dir(): - shutil.copytree(str(src_dir), str(dest_dir)) - else: - dest_dir.mkdir(parents=True, exist_ok=True) - shutil.copy2(str(model_path), str(dest_dir)) - - -def _copy_component_files(model_path: Path, dest_dir: Path) -> None: - """Copy files for a single ONNX component to dest_dir. - - Copies the .onnx file and its associated context binary (.bin) files - and external data files. 
- """ - if dest_dir.exists(): - return - - dest_dir.mkdir(parents=True, exist_ok=True) - src_dir = model_path.parent - - # Copy the ONNX file itself - shutil.copy2(str(model_path), str(dest_dir / model_path.name)) - - # Find associated files - associated_files: set[str] = set() - try: - from olive.passes.onnx.common import get_context_bin_file_names - - associated_files.update(get_context_bin_file_names(str(model_path))) - except Exception: - logger.debug("Could not read context binary file names from %s", model_path, exc_info=True) - - try: - import onnx - - onnx_model = onnx.load(str(model_path), load_external_data=False) - for init in onnx_model.graph.initializer: - if init.data_location == onnx.TensorProto.EXTERNAL: - for entry in init.external_data: - if entry.key == "location": - associated_files.add(entry.value) - except Exception: - logger.debug("Could not read ONNX external data from %s", model_path, exc_info=True) - - for file_name in associated_files: - src = src_dir / file_name - if src.is_file(): - shutil.copy2(str(src), str(dest_dir / file_name)) - - -def _remove_config_files(component_dir: Path, config_file_names: set[str]) -> None: - """Remove config files from variant subdirectories (they belong in configs/).""" - for name in config_file_names: - for p in component_dir.rglob(name): - if p.is_dir(): - shutil.rmtree(str(p)) - else: - p.unlink() - logger.debug("Removed duplicate config entry %s from variant directory", p) - - def _task_to_component_name(task: str) -> str: - """Map a task string to a component name for single-component models.""" task_component_map = { "text_generation": "decoder", "text2text_generation": "encoder_decoder", diff --git a/test/cli/test_model_package.py b/test/cli/test_model_package.py index 458337e50..3f2549634 100644 --- a/test/cli/test_model_package.py +++ b/test/cli/test_model_package.py @@ -3,24 +3,104 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- # pylint: disable=protected-access +"""Tests for ``olive generate-model-package``. + +Covers both the CLI argument-parsing / source-validation surface and the +underlying writer (``write_model_package`` and helpers); they live in the +same module (``olive.cli.model_package``). 
+""" + import json from argparse import ArgumentParser +from pathlib import Path +import onnx import pytest +from onnx import TensorProto, helper + +from olive.cli.model_package import ( + ModelPackageCommand, + VariantSpec, + disambiguate_variant_names, + parse_compatibility_strings, + write_model_package, +) + +# --------------------------------------------------------------------------- +# ONNX fixture helpers +# --------------------------------------------------------------------------- + -from olive.cli.model_package import ModelPackageCommand +def _make_onnx_inline(onnx_path: Path, metadata_props: dict[str, str] | None = None) -> Path: + """Write a minimal ONNX file with no external data.""" + onnx_path.parent.mkdir(parents=True, exist_ok=True) + init = helper.make_tensor("weight", TensorProto.FLOAT, [1], [1.0]) + output = helper.make_tensor_value_info("y", TensorProto.FLOAT, [None]) + node = helper.make_node("Identity", inputs=["weight"], outputs=["y"]) + graph = helper.make_graph([node], "test", inputs=[], outputs=[output], initializer=[init]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + if metadata_props: + for k, v in metadata_props.items(): + entry = model.metadata_props.add() + entry.key = k + entry.value = v + onnx.save(model, str(onnx_path)) + return onnx_path -def _create_source_dir(tmp_path, name, model_attributes): - """Create a fake Olive output directory with model_config.json and a dummy .onnx file.""" +def _make_onnx_with_external( + onnx_path: Path, + blob_relpath: str, + blob_bytes: bytes, + metadata_props: dict[str, str] | None = None, +) -> Path: + """Write a minimal ONNX file whose only initializer points at an external-data blob.""" + onnx_path.parent.mkdir(parents=True, exist_ok=True) + blob_path = onnx_path.parent / blob_relpath + blob_path.parent.mkdir(parents=True, exist_ok=True) + blob_path.write_bytes(blob_bytes) + + init = TensorProto() + init.name = "weight" + init.data_type = TensorProto.FLOAT + init.dims.extend([max(1, len(blob_bytes) // 4)]) + init.data_location = TensorProto.EXTERNAL + for k, v in (("location", blob_relpath), ("offset", "0"), ("length", str(len(blob_bytes)))): + entry = init.external_data.add() + entry.key = k + entry.value = v + + output = helper.make_tensor_value_info("y", TensorProto.FLOAT, [None]) + node = helper.make_node("Identity", inputs=["weight"], outputs=["y"]) + graph = helper.make_graph([node], "test", inputs=[], outputs=[output], initializer=[init]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + if metadata_props: + for k, v in metadata_props.items(): + entry = model.metadata_props.add() + entry.key = k + entry.value = v + onnx.save(model, str(onnx_path)) + return onnx_path + + +def _create_source_dir( + tmp_path: Path, + name: str, + model_attributes: dict, + *, + onnx_metadata: dict[str, str] | None = None, + inference_settings: dict | None = None, +) -> Path: + """Create a fake Olive output directory with model_config.json and a real ONNX file.""" source_dir = tmp_path / name source_dir.mkdir(parents=True) - model_config = { - "type": "ONNXModel", - "config": {"model_path": str(source_dir / "model.onnx"), "model_attributes": model_attributes}, - } + onnx_path = source_dir / "model.onnx" + _make_onnx_inline(onnx_path, metadata_props=onnx_metadata) + cfg: dict = {"model_path": str(onnx_path), "model_attributes": model_attributes} + if inference_settings is not None: + cfg["inference_settings"] = inference_settings + model_config = {"type": 
"ONNXModel", "config": cfg} (source_dir / "model_config.json").write_text(json.dumps(model_config)) - (source_dir / "model.onnx").write_text("dummy") return source_dir @@ -33,20 +113,21 @@ def _make_command(args_list): return parsed_args.func(parser, parsed_args, unknown) -class TestSourceValidation: - """Tests for _parse_sources validation logic.""" +# --------------------------------------------------------------------------- +# CLI: source validation +# --------------------------------------------------------------------------- - def test_rejects_single_source(self, tmp_path): - # setup + +class TestSourceValidation: + def test_accepts_single_source(self, tmp_path): src = _create_source_dir(tmp_path, "soc_60", {"ep": "QNNExecutionProvider"}) cmd = _make_command(["generate-model-package", "-s", str(src), "-o", str(tmp_path / "out")]) - # execute + assert - with pytest.raises(ValueError, match="At least two"): - cmd._parse_sources() + sources = cmd._parse_sources() + + assert sources == [("soc_60", src)] def test_rejects_missing_model_config(self, tmp_path): - # setup no_config = tmp_path / "no_config" no_config.mkdir() valid = _create_source_dir(tmp_path, "valid", {"ep": "QNNExecutionProvider"}) @@ -54,45 +135,50 @@ def test_rejects_missing_model_config(self, tmp_path): ["generate-model-package", "-s", str(no_config), "-s", str(valid), "-o", str(tmp_path / "out")] ) - # execute + assert with pytest.raises(ValueError, match=r"model_config\.json"): cmd._parse_sources() def test_rejects_nonexistent_path(self, tmp_path): - # setup valid = _create_source_dir(tmp_path, "valid", {"ep": "QNNExecutionProvider"}) cmd = _make_command( ["generate-model-package", "-s", "/nonexistent/path", "-s", str(valid), "-o", str(tmp_path / "out")] ) - # execute + assert with pytest.raises(ValueError, match="does not exist"): cmd._parse_sources() + def test_rejects_duplicate_source_basenames(self, tmp_path): + # Two source dirs share basename "soc_60" — variant names would collide. 
+ src_a = _create_source_dir(tmp_path / "a", "soc_60", {"ep": "QNNExecutionProvider"}) + src_b = _create_source_dir(tmp_path / "b", "soc_60", {"ep": "QNNExecutionProvider"}) + cmd = _make_command(["generate-model-package", "-s", str(src_a), "-s", str(src_b), "-o", str(tmp_path / "out")]) + + with pytest.raises(ValueError, match="share the directory name"): + cmd._parse_sources() + def test_parses_two_valid_sources(self, tmp_path): - # setup src1 = _create_source_dir(tmp_path, "soc_60", {"ep": "QNNExecutionProvider"}) src2 = _create_source_dir(tmp_path, "soc_73", {"ep": "QNNExecutionProvider"}) cmd = _make_command(["generate-model-package", "-s", str(src1), "-s", str(src2), "-o", str(tmp_path / "out")]) - # execute sources = cmd._parse_sources() - # assert assert len(sources) == 2 assert sources[0] == ("soc_60", src1) assert sources[1] == ("soc_73", src2) -class TestGeneratePackageSingle: - """Tests for single-component model package generation.""" +# --------------------------------------------------------------------------- +# CLI: end-to-end (single component, multi-variant) +# --------------------------------------------------------------------------- + - def test_generates_manifest_and_metadata(self, tmp_path): - """Package output should have manifest.json and metadata.json.""" +class TestGeneratePackageMultiVariant: + def test_writes_proposal_layout(self, tmp_path): # setup src1 = _create_source_dir(tmp_path, "soc_60", {"ep": "QNNExecutionProvider", "device": "NPU"}) src2 = _create_source_dir(tmp_path, "soc_73", {"ep": "QNNExecutionProvider", "device": "NPU"}) - out_dir = tmp_path / "out" + out = tmp_path / "out" cmd = _make_command( [ "generate-model-package", @@ -101,7 +187,7 @@ def test_generates_manifest_and_metadata(self, tmp_path): "-s", str(src2), "-o", - str(out_dir), + str(out), "--model_name", "test_model", "--model_version", @@ -112,36 +198,627 @@ def test_generates_manifest_and_metadata(self, tmp_path): # execute cmd.run() - # assert: manifest - manifest_path = out_dir / "manifest.json" - assert manifest_path.exists() - manifest = json.loads(manifest_path.read_text()) - assert manifest["name"] == "test_model" - assert manifest["model_version"] == "2.0" - assert "component_models" in manifest - - # assert: metadata in component dir - component_name = manifest["component_models"][0] - metadata_path = out_dir / "models" / component_name / "metadata.json" - assert metadata_path.exists() - metadata = json.loads(metadata_path.read_text()) - assert "soc_60" in metadata["model_variants"] - assert "soc_73" in metadata["model_variants"] - - # assert: constraints - for variant in metadata["model_variants"].values(): - assert variant["constraints"]["ep"] == "QNNExecutionProvider" - assert variant["constraints"]["device"] == "NPU" - - -class TestAcceleratorInfo: - """Test accelerator info extraction.""" - - def test_defaults_accelerator_when_no_attributes(self): - """Falls back to CPUExecutionProvider/cpu when model_attributes is empty.""" - # setup + execute - ep, device = ModelPackageCommand._extract_accelerator_info([{"type": "ONNXModel", "config": {}}]) - - # assert - assert ep == "CPUExecutionProvider" - assert device == "cpu" + # assert: top-level layout (no models/ wrapper) + assert (out / "manifest.json").is_file() + assert not (out / "models").exists() + + manifest = json.loads((out / "manifest.json").read_text()) + assert manifest["schema_version"] == 1 + assert manifest["components"] == ["model"] + assert manifest["producer"]["model_name"] == "test_model" + assert 
manifest["producer"]["model_version"] == "2.0" + + # metadata uses ep_compatibility[] + metadata = json.loads((out / "model" / "metadata.json").read_text()) + assert set(metadata["variants"]) == {"soc_60", "soc_73"} + for variant_payload in metadata["variants"].values(): + ep_compat = variant_payload["ep_compatibility"] + assert ep_compat == [{"ep": "QNNExecutionProvider", "device": "NPU"}] + + # variant.json contains files[] with filename + for v in ("soc_60", "soc_73"): + variant_json = json.loads((out / "model" / v / "variant.json").read_text()) + assert variant_json["files"][0]["filename"] == "model.onnx" + assert (out / "model" / v / "model.onnx").is_file() + + +class TestGeneratePackageSingleSource: + def test_single_source_is_valid_package(self, tmp_path): + src = _create_source_dir(tmp_path, "cpu_x64", {"ep": "CPUExecutionProvider"}) + out = tmp_path / "out" + cmd = _make_command(["generate-model-package", "-s", str(src), "-o", str(out)]) + + cmd.run() + + manifest = json.loads((out / "manifest.json").read_text()) + assert manifest["components"] == ["model"] + metadata = json.loads((out / "model" / "metadata.json").read_text()) + assert "cpu_x64" in metadata["variants"] + assert metadata["variants"]["cpu_x64"]["ep_compatibility"] == [{"ep": "CPUExecutionProvider"}] + # No shared_weights because nothing to dedup. + assert not (out / "model" / "shared_weights").exists() + + +# --------------------------------------------------------------------------- +# Writer: layout + manifest + metadata + variant.json +# --------------------------------------------------------------------------- + + +class TestWriteModelPackageLayout: + def test_writes_proposal_shape_for_single_variant(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + device="cpu", + ) + ], + producer_info={"tool": "olive-ai", "model_name": "demo"}, + ) + + assert (out / "manifest.json").is_file() + assert (out / "decoder" / "metadata.json").is_file() + assert (out / "decoder" / "cpu" / "variant.json").is_file() + assert (out / "decoder" / "cpu" / "model.onnx").is_file() + assert not (out / "models").exists() + + def test_manifest_uses_proposal_schema(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + producer_info={"tool": "olive-ai", "tool_version": "1.2.3", "model_name": "demo"}, + ) + + manifest = json.loads((out / "manifest.json").read_text()) + assert manifest["schema_version"] == 1 + assert manifest["components"] == ["decoder"] + assert manifest["producer"] == { + "tool": "olive-ai", + "tool_version": "1.2.3", + "model_name": "demo", + } + # No legacy fields + assert "name" not in manifest + assert "component_models" not in manifest + assert "model_version" not in manifest + + def test_metadata_uses_ep_compatibility_array(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="qnn-npu", + onnx_files=[onnx_path], + ep="QNNExecutionProvider", + device="NPU", + 
compatibility=["soc_60", "soc_69"], + ) + ], + ) + + metadata = json.loads((out / "decoder" / "metadata.json").read_text()) + ep_compat = metadata["variants"]["qnn-npu"]["ep_compatibility"] + assert ep_compat == [{"ep": "QNNExecutionProvider", "device": "NPU", "compatibility": ["soc_60", "soc_69"]}] + assert "model_variants" not in metadata + + def test_metadata_omits_optional_fields_when_unset(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + ) + + metadata = json.loads((out / "decoder" / "metadata.json").read_text()) + ep_compat = metadata["variants"]["cpu"]["ep_compatibility"][0] + assert ep_compat == {"ep": "CPUExecutionProvider"} + + def test_variant_json_carries_session_and_provider_options(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + inference = { + "session_options": {"graph_optimization_level": 3}, + "execution_provider": ["CPUExecutionProvider"], + "provider_options": [{"intra_op_num_threads": 4}], + } + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + inference_settings=inference, + ) + ], + ) + + variant = json.loads((out / "decoder" / "cpu" / "variant.json").read_text()) + assert variant["files"] == [ + { + "filename": "model.onnx", + "session_options": {"graph_optimization_level": 3}, + "provider_options": {"intra_op_num_threads": 4}, + } + ] + + def test_provider_options_match_ep_by_name(self, tmp_path): + """When inference_settings has multiple EPs, pick the one whose name matches VariantSpec.ep.""" + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + inference = { + "session_options": {}, + "execution_provider": ["CPUExecutionProvider", "QNNExecutionProvider"], + "provider_options": [{"cpu_only": "1"}, {"backend_path": "QnnHtp.so"}], + } + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="qnn", + onnx_files=[onnx_path], + ep="QNNExecutionProvider", + inference_settings=inference, + ) + ], + ) + + variant = json.loads((out / "decoder" / "qnn" / "variant.json").read_text()) + assert variant["files"][0].get("provider_options") == {"backend_path": "QnnHtp.so"} + assert "session_options" not in variant["files"][0] + + +# --------------------------------------------------------------------------- +# Writer: shared_weights / external-data dedup +# --------------------------------------------------------------------------- + + +class TestSharedWeightsDedup: + def test_dedups_identical_external_data_across_variants(self, tmp_path): + blob = b"\x00\x01\x02\x03" * 64 + a = _make_onnx_with_external(tmp_path / "a" / "model.onnx", "model.onnx.data", blob) + b = _make_onnx_with_external(tmp_path / "b" / "model.onnx", "model.onnx.data", blob) + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="v1", + onnx_files=[a], + ep="CPUExecutionProvider", + ), + VariantSpec( + component_name="decoder", + variant_name="v2", + onnx_files=[b], + ep="CPUExecutionProvider", + ), + ], + ) + + shared_root = out / "decoder" / "shared_weights" + assert shared_root.is_dir() + 
sha_dirs = list(shared_root.iterdir()) + assert len(sha_dirs) == 1 + sha = sha_dirs[0].name + assert (shared_root / sha / "model.onnx.data").is_file() + assert not (out / "decoder" / "v1" / "model.onnx.data").exists() + assert not (out / "decoder" / "v2" / "model.onnx.data").exists() + + for v in ("v1", "v2"): + variant = json.loads((out / "decoder" / v / "variant.json").read_text()) + entry = variant["files"][0] + assert entry["filename"] == "model.onnx" + assert entry["shared_files"] == {"model.onnx.data": sha} + + def test_keeps_external_data_inline_when_unique(self, tmp_path): + a = _make_onnx_with_external(tmp_path / "a" / "model.onnx", "model.onnx.data", b"a-bytes" * 32) + b = _make_onnx_with_external(tmp_path / "b" / "model.onnx", "model.onnx.data", b"b-bytes" * 32) + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="v1", + onnx_files=[a], + ep="CPUExecutionProvider", + ), + VariantSpec( + component_name="decoder", + variant_name="v2", + onnx_files=[b], + ep="CPUExecutionProvider", + ), + ], + ) + + assert not (out / "decoder" / "shared_weights").exists() + assert (out / "decoder" / "v1" / "model.onnx.data").is_file() + assert (out / "decoder" / "v2" / "model.onnx.data").is_file() + + for v in ("v1", "v2"): + variant = json.loads((out / "decoder" / v / "variant.json").read_text()) + assert "shared_files" not in variant["files"][0] + + def test_single_variant_keeps_blob_inline(self, tmp_path): + onnx_path = _make_onnx_with_external(tmp_path / "src" / "model.onnx", "model.onnx.data", b"x" * 128) + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + ) + + assert (out / "decoder" / "cpu" / "model.onnx.data").is_file() + assert not (out / "decoder" / "shared_weights").exists() + variant = json.loads((out / "decoder" / "cpu" / "variant.json").read_text()) + assert "shared_files" not in variant["files"][0] + + +# --------------------------------------------------------------------------- +# Writer: configs/ + safety +# --------------------------------------------------------------------------- + + +class TestConfigsAndSafety: + def test_copies_config_files_into_configs_dir(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + cfg_a = tmp_path / "configs_src" / "tokenizer.json" + cfg_a.parent.mkdir(parents=True) + cfg_a.write_text("{}") + cfg_b = tmp_path / "configs_src" / "genai_config.json" + cfg_b.write_text("{}") + out = tmp_path / "package" + + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + config_files={"tokenizer.json": cfg_a, "genai_config.json": cfg_b}, + ) + + assert (out / "configs" / "tokenizer.json").is_file() + assert (out / "configs" / "genai_config.json").is_file() + + def test_rejects_non_empty_output_dir(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + out.mkdir() + (out / "stale.txt").write_text("stale") + + with pytest.raises(ValueError, match="not empty"): + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + ) + + def test_rejects_invalid_component_name(self, 
tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + with pytest.raises(ValueError, match="component name"): + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="../escape", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + ) + + def test_rejects_invalid_variant_name(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + with pytest.raises(ValueError, match="variant name"): + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="bad/name", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + ) + + def test_rejects_duplicate_variant_names_per_component(self, tmp_path): + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + out = tmp_path / "package" + + with pytest.raises(ValueError, match="Duplicate variant name"): + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ), + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ), + ], + ) + + def test_rejects_empty_variants(self, tmp_path): + with pytest.raises(ValueError, match="at least one variant"): + write_model_package(output_dir=tmp_path / "package", variants=[]) + + def test_skips_config_file_with_unsafe_key(self, tmp_path): + # setup: a real source plus a config_files map with a path-escaping key. + onnx_path = _make_onnx_inline(tmp_path / "src" / "model.onnx") + bad = tmp_path / "configs_src" / "evil.txt" + bad.parent.mkdir(parents=True) + bad.write_text("oops") + out = tmp_path / "package" + + # execute + write_model_package( + output_dir=out, + variants=[ + VariantSpec( + component_name="decoder", + variant_name="cpu", + onnx_files=[onnx_path], + ep="CPUExecutionProvider", + ) + ], + config_files={"../escape.txt": bad, "subdir/nested.txt": bad, "ok.txt": bad}, + ) + + # assert: unsafe keys are dropped, safe key copied + assert not (out.parent / "escape.txt").exists() + assert not (out / "configs" / "subdir").exists() + assert not (out / "configs" / "..").is_dir() or not (out / ".." 
/ "escape.txt").exists() + assert (out / "configs" / "ok.txt").exists() + # configs/ should contain only the one safe entry + assert sorted(p.name for p in (out / "configs").iterdir()) == ["ok.txt"] + + +# --------------------------------------------------------------------------- +# CLI: mixed source types +# --------------------------------------------------------------------------- + + +class TestMixedSourceTypes: + def test_rejects_mixed_onnx_and_composite(self, tmp_path): + # setup: one ONNXModel source, one CompositeModel source + onnx_src = _create_source_dir(tmp_path, "onnx_src", {"ep": "CPUExecutionProvider"}) + comp_src = tmp_path / "comp_src" + comp_src.mkdir() + comp_onnx = _make_onnx_inline(comp_src / "comp.onnx") + (comp_src / "model_config.json").write_text( + json.dumps( + { + "type": "CompositeModel", + "config": { + "model_components": [{"type": "ONNXModel", "config": {"model_path": str(comp_onnx)}}], + "component_names": ["decoder"], + }, + } + ) + ) + cmd = _make_command( + ["generate-model-package", "-s", str(onnx_src), "-s", str(comp_src), "-o", str(tmp_path / "out")] + ) + + # execute + assert + with pytest.raises(ValueError, match="mix model types"): + cmd.run() + + +# --------------------------------------------------------------------------- +# Helper functions +# --------------------------------------------------------------------------- + + +class TestParseCompatibilityStrings: + def test_splits_comma_delimited_string(self): + assert parse_compatibility_strings("sm_80,sm_86,sm_90") == ["sm_80", "sm_86", "sm_90"] + + def test_strips_whitespace_and_drops_empty(self): + assert parse_compatibility_strings(" sm_80 , , sm_86 ") == ["sm_80", "sm_86"] + + def test_returns_empty_for_none_or_empty(self): + assert parse_compatibility_strings(None) == [] + assert parse_compatibility_strings("") == [] + + +class TestDisambiguateVariantNames: + def test_passes_unique_names_through(self): + assert disambiguate_variant_names([("c", "a"), ("c", "b")]) == ["a", "b"] + + def test_appends_rank_suffix_on_collision(self): + out = disambiguate_variant_names([("c", "a"), ("c", "a"), ("c", "a")]) + assert out == ["a_rank1", "a_rank2", "a_rank3"] + + def test_isolates_collisions_per_component(self): + out = disambiguate_variant_names([("c1", "a"), ("c2", "a")]) + assert out == ["a", "a"] + + +# --------------------------------------------------------------------------- +# CLI: comma-delimited compatibility from ONNX metadata +# --------------------------------------------------------------------------- + + +class TestCompatibilityFromOnnxMetadata: + def test_splits_comma_delimited_metadata(self, tmp_path): + # setup: source with QNNExecutionProvider compat info in ONNX metadata_props + src = _create_source_dir( + tmp_path, + "soc_60", + {"ep": "QNNExecutionProvider", "device": "NPU"}, + onnx_metadata={"ep_compatibility_info.QNNExecutionProvider": "soc_60,soc_69,soc_73"}, + ) + out = tmp_path / "out" + cmd = _make_command(["generate-model-package", "-s", str(src), "-o", str(out)]) + + # execute + cmd.run() + + # assert: compatibility array reflects the comma-split list + metadata = json.loads((out / "model" / "metadata.json").read_text()) + ep_compat = metadata["variants"]["soc_60"]["ep_compatibility"][0] + assert ep_compat["ep"] == "QNNExecutionProvider" + assert ep_compat["compatibility"] == ["soc_60", "soc_69", "soc_73"] + + +# --------------------------------------------------------------------------- +# CLI: composite (per-component inference_settings precedence) +# 
+# ---------------------------------------------------------------------------
+# CLI: composite (per-component inference_settings precedence)
+# ---------------------------------------------------------------------------
+
+
+def _create_composite_source(
+    tmp_path: Path,
+    name: str,
+    components: list[dict],
+    component_names: list[str],
+    *,
+    target_inference: dict | None = None,
+    target_attrs: dict | None = None,
+) -> Path:
+    """Create an Olive-style CompositeModel source directory."""
+    source_dir = tmp_path / name
+    source_dir.mkdir(parents=True)
+    cfg = {"model_components": components, "component_names": component_names}
+    if target_inference is not None:
+        cfg["inference_settings"] = target_inference
+    if target_attrs is not None:
+        cfg["model_attributes"] = target_attrs
+    (source_dir / "model_config.json").write_text(json.dumps({"type": "CompositeModel", "config": cfg}))
+    return source_dir
+
+
+class TestCompositeBuild:
+    def test_per_component_inference_settings_wins(self, tmp_path):
+        # setup: component-level inference_settings should override target-level
+        comp_a_onnx = _make_onnx_inline(tmp_path / "comp_a" / "model.onnx")
+        comp_b_onnx = _make_onnx_inline(tmp_path / "comp_b" / "model.onnx")
+
+        target_inference = {
+            "session_options": {"graph_optimization_level": 1},
+            "execution_provider": ["CPUExecutionProvider"],
+            "provider_options": [{}],
+        }
+        comp_b_inference = {
+            "session_options": {"graph_optimization_level": 99},
+            "execution_provider": ["CPUExecutionProvider"],
+            "provider_options": [{}],
+        }
+        components = [
+            {"type": "ONNXModel", "config": {"model_path": str(comp_a_onnx)}},
+            {
+                "type": "ONNXModel",
+                "config": {"model_path": str(comp_b_onnx), "inference_settings": comp_b_inference},
+            },
+        ]
+        src = _create_composite_source(
+            tmp_path,
+            "soc_60",
+            components,
+            ["encoder", "decoder"],
+            target_inference=target_inference,
+            target_attrs={"ep": "CPUExecutionProvider"},
+        )
+        out = tmp_path / "out"
+        cmd = _make_command(["generate-model-package", "-s", str(src), "-o", str(out)])
+
+        # execute
+        cmd.run()
+
+        # assert: encoder uses target-level, decoder uses component-level
+        encoder_v = json.loads((out / "encoder" / "soc_60" / "variant.json").read_text())
+        assert encoder_v["files"][0]["session_options"] == {"graph_optimization_level": 1}
+
+        decoder_v = json.loads((out / "decoder" / "soc_60" / "variant.json").read_text())
+        assert decoder_v["files"][0]["session_options"] == {"graph_optimization_level": 99}
+
+
+# ---------------------------------------------------------------------------
+# CLI: unsupported model type
+# ---------------------------------------------------------------------------
+
+
+class TestUnsupportedModelType:
+    def test_rejects_pytorch_model(self, tmp_path):
+        # setup: a source whose model_config declares an unsupported type
+        source_dir = tmp_path / "pytorch_src"
+        source_dir.mkdir()
+        (source_dir / "model_config.json").write_text(
+            json.dumps({"type": "PyTorchModel", "config": {"model_path": "pt"}})
+        )
+        out = tmp_path / "out"
+        cmd = _make_command(["generate-model-package", "-s", str(source_dir), "-o", str(out)])
+
+        # execute + assert
+        with pytest.raises(ValueError, match="Unsupported source model type"):
+            cmd.run()
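+
+
+# NOTE: illustrative sketch of the precedence TestCompositeBuild pins down,
+# not the CLI's actual code path; the helper name is hypothetical. A
+# component's own inference_settings wins; the composite target's settings
+# apply only to components that carry none of their own.
+def _sketch_effective_inference_settings(component_config: dict, target_config: dict):
+    # encoder above has no per-component settings -> target's level 1 applies;
+    # decoder carries its own -> its level 99 survives into variant.json.
+    return component_config.get("inference_settings") or target_config.get("inference_settings")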