diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py index 6894af518..856123be2 100644 --- a/scripts/export_qwen3_transformer_only.py +++ b/scripts/export_qwen3_transformer_only.py @@ -121,6 +121,43 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: default=None, help="If set, copy the two ONNX (with external data) here as prefill.onnx / decode.onnx.", ) + + genai = p.add_argument_group( + "genai bundle", + "Options for producing an onnxruntime-genai inference bundle.", + ) + genai.add_argument( + "--genai-bundle", + type=Path, + default=None, + metavar="DIR", + help=( + "If set, assemble a complete onnxruntime-genai bundle in DIR: " + "ctx.onnx (prefill), iter.onnx (decode), genai_config.json, and " + "tokenizer files. Provide --embeddings and --lm-head to include " + "the placeholder models required for end-to-end inference." + ), + ) + genai.add_argument( + "--embeddings", + type=Path, + default=None, + metavar="ONNX", + help=( + "Path to the embeddings ONNX to copy into the genai bundle as " + "embeddings.onnx. Required for end-to-end genai inference." + ), + ) + genai.add_argument( + "--lm-head", + type=Path, + default=None, + metavar="ONNX", + help=( + "Path to the lm_head ONNX to copy into the genai bundle as " + "lm_head.onnx. Required for end-to-end genai inference." + ), + ) return p.parse_args(argv) @@ -164,6 +201,39 @@ def main(argv: list[str] | None = None) -> int: copy_onnx_model(src, dst) print(f" -> copied to {dst}") + # ----------------------------------------------------------------------- + # Optional: assemble an onnxruntime-genai bundle. 
+ # ----------------------------------------------------------------------- + if args.genai_bundle is not None: + from winml.modelkit.models.hf.qwen3.genai import write_genai_bundle + + prefill_path = Path(model.sub_models["decoder_prefill"].onnx_path) + decode_path = Path(model.sub_models["decoder_gen"].onnx_path) + + print(f"\n=== assembling genai bundle -> {args.genai_bundle} ===") + config_path = write_genai_bundle( + args.genai_bundle, + context_onnx=prefill_path, + iterator_onnx=decode_path, + model_id=args.model_id, + max_cache_len=args.max_cache_len, + prefill_seq_len=args.prefill_seq_len, + embeddings_src=args.embeddings, + lm_head_src=args.lm_head, + ep="qnn" if args.device == "npu" else args.device, + ) + print(f" genai_config.json -> {config_path}") + if args.embeddings is None: + print( + " WARNING: --embeddings not provided; " + "add embeddings.onnx to the bundle before inference." + ) + if args.lm_head is None: + print( + " WARNING: --lm-head not provided; " + "add lm_head.onnx to the bundle before inference." + ) + return 0 diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py new file mode 100644 index 000000000..47aa6683e --- /dev/null +++ b/scripts/infer_genai.py @@ -0,0 +1,151 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""onnxruntime-genai inference for a genai bundle (decoder-pipeline). + +Loads the genai bundle produced by ``export_qwen3_transformer_only.py +--genai-bundle DIR`` and runs greedy text generation using +:class:`~winml.modelkit.session.GenaiSession`. + +The bundle directory must contain ``genai_config.json`` and the four ONNX +graphs it references (``embeddings.onnx``, ``ctx.onnx``, ``iter.onnx``, +``lm_head.onnx``) plus HF tokenizer files. 
+ +Usage:: + + # CPU sanity check (works anywhere onnxruntime-genai is installed) + uv run python scripts/infer_genai.py --prompt "Hello, who are you?" --chat + + # Qualcomm NPU (registers the QNN EP via the Windows ML EP catalog) + uv run python scripts/infer_genai.py \ + --prompt "Explain what a transformer is." \ + --ep qnn --chat + + # Point at a non-default bundle + uv run python scripts/infer_genai.py \ + --model-dir out/my_bundle --prompt "Hi" --ep cpu + + # Pre-compile QNN stages to EPContext on first run; reuse cache on subsequent runs. + # Eliminates per-run JIT overhead (~60-90 s saved on Snapdragon X Elite). + uv run python scripts/infer_genai.py \ + --prompt "Hello" --ep mixed --compile + +Dependencies (install in a fresh venv):: + + pip install onnxruntime-genai-winml + pip install "windowsml[with-ort]" # registers QNN EP; also provides onnxruntime +""" + +from __future__ import annotations + +import argparse +import sys +import time +from pathlib import Path + +from winml.modelkit.session import GenaiSession, GenerationConfig + + +# Default bundle directory: out/qwen3_bundle under the repository root +_REPO_ROOT = Path(__file__).resolve().parent.parent +DEFAULT_MODEL_DIR = _REPO_ROOT / "out" / "qwen3_bundle" + +_SUPPORTED_EPS = ["cpu", "mixed", "qnn", "dml"] + + +def _wrap_chat_template(prompt: str) -> str: + """Wrap *prompt* in the ChatML chat template.""" + return GenaiSession.apply_chatml_template(prompt) + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + """Parse CLI arguments.""" + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument( + "--prompt", + default="Give me a short introduction to large language models.", + help="Input prompt (default: %(default)s).", + ) + p.add_argument( + "--model-dir", + type=Path, + default=DEFAULT_MODEL_DIR, + metavar="DIR", + help=( + "Path to the genai bundle directory containing genai_config.json " + "and the ONNX / tokenizer files 
(default: %(default)s)." + ), + ) + p.add_argument( + "--ep", + choices=_SUPPORTED_EPS, + default="mixed", + help="Execution provider: 'mixed' uses genai_config.json as-is (default); " + "'cpu' forces all stages to CPU; 'qnn'/'dml' for full NPU/GPU.", + ) + p.add_argument( + "--max-new", + type=int, + default=128, + help="Maximum number of new tokens to generate (default: %(default)s).", + ) + p.add_argument( + "--chat", + action="store_true", + help="Wrap --prompt in the ChatML template (<|im_start|>user/assistant).", + ) + p.add_argument( + "--compile", + action="store_true", + help=( + "Pre-compile QNN pipeline stages to EPContext ONNX before loading. " + "On first use this triggers ort.ModelCompiler per stage (~60-90 s for iter). " + "Compiled artifacts are cached in bundle_dir/_compiled/; " + "subsequent runs reuse the cache and skip JIT. " + "Has no effect when --ep cpu." + ), + ) + p.add_argument( + "--verbose", + action="store_true", + help="Enable onnxruntime-genai native model I/O logging.", + ) + return p.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + """Load the genai bundle and run generation.""" + args = parse_args(argv) + + text = _wrap_chat_template(args.prompt) if args.chat else args.prompt + gen_cfg = GenerationConfig(max_new_tokens=args.max_new, do_sample=False) + + try: + session = GenaiSession( + args.model_dir, ep=args.ep, verbose=args.verbose, compile=args.compile + ) + except FileNotFoundError as exc: + print(f"ERROR: {exc}", file=sys.stderr) + return 1 + + print(f"[load] ep={args.ep} bundle={args.model_dir}") + with session: + print(f"[ctx] context_length={session.context_length}") + print("[gen] ", end="", flush=True) + t0 = time.monotonic() + n = 0 + for token_str in session.generate_streaming(text, gen_cfg): + print(token_str, end="", flush=True) + n += 1 + + dt = time.monotonic() - t0 + print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / dt:.1f} tok/s)") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) 
diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py index 332fb9234..dbabe2d60 100644 --- a/src/winml/modelkit/models/hf/qwen3/__init__.py +++ b/src/winml/modelkit/models/hf/qwen3/__init__.py @@ -3,4 +3,30 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Qwen3 transformer-only export support (modeling, export ops, IO configs).""" +"""Qwen3 transformer-only export + genai bundle support. + +Modules: + qwen_transformer_only — OnnxConfig, build config, composite model class. + qwen3_modeling — winml-owned Qwen3 module definitions (forward bindings). + qwen3_export_ops — custom ONNX symbolic ops (LpNorm, GQA, 1x1 Conv). + genai — genai_config.json generator + bundle assembler. +""" + +from .genai import ( + DecoderIOMapping, + PipelineStage, + build_decoder_pipeline_stages, + build_genai_config, + build_qwen3_transformer_only_stages, + write_genai_bundle, +) + + +__all__ = [ + "DecoderIOMapping", + "PipelineStage", + "build_decoder_pipeline_stages", + "build_genai_config", + "build_qwen3_transformer_only_stages", + "write_genai_bundle", +] diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py new file mode 100644 index 000000000..9e65908f5 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -0,0 +1,53 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Qwen3 genai bundle support — thin shim over :mod:`winml.modelkit.utils.genai`. + +All generic logic (``PipelineStage``, ``DecoderIOMapping``, ``build_genai_config``, +``build_decoder_pipeline_stages``, ``write_genai_bundle``) lives in +:mod:`winml.modelkit.utils.genai` so it can be reused by other model families. 
+ +This module re-exports that API unchanged and adds +``build_qwen3_transformer_only_stages`` as a backward-compatible alias for +``build_decoder_pipeline_stages``. New code should prefer the generic names. +""" + +from __future__ import annotations + +from winml.modelkit.utils.genai import ( + DEFAULT_CONTEXT_FILENAME, + DEFAULT_EMBEDDINGS_FILENAME, + DEFAULT_ITERATOR_FILENAME, + DEFAULT_LM_HEAD_FILENAME, + DecoderIOMapping, + PipelineStage, + _detect_format_patterns, + build_decoder_pipeline_stages, + build_genai_config, + qnn_stage_session_options, + write_genai_bundle, +) + + +# Backward-compatible alias: existing callers that import +# ``build_qwen3_transformer_only_stages`` continue to work unchanged. +build_qwen3_transformer_only_stages = build_decoder_pipeline_stages + +# Keep the internal helper accessible for tests that import it directly. +_qnn_stage_session_options = qnn_stage_session_options + +__all__ = [ + "DEFAULT_CONTEXT_FILENAME", + "DEFAULT_EMBEDDINGS_FILENAME", + "DEFAULT_ITERATOR_FILENAME", + "DEFAULT_LM_HEAD_FILENAME", + "DecoderIOMapping", + "PipelineStage", + "_detect_format_patterns", + "build_decoder_pipeline_stages", + "build_genai_config", + "build_qwen3_transformer_only_stages", + "qnn_stage_session_options", + "write_genai_bundle", +] diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index 9a2c0b34c..f2dad581a 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any +from .calibration import get_quant_finalizer from .config import QuantizeResult, WinMLQuantizationConfig @@ -29,18 +30,17 @@ ] -# Names below are loaded lazily via ``__getattr__`` to avoid pulling in -# onnxruntime.quantization/torch at import time. The TYPE_CHECKING re-imports -# give static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports -# without triggering the heavy imports at runtime. 
+# ``quantize_onnx`` is loaded lazily via ``__getattr__`` to avoid pulling in +# onnxruntime.quantization at import time. The TYPE_CHECKING re-import gives +# static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports. +# ``get_quant_finalizer`` is imported directly above — its module chain +# (calibration/__init__ -> registry) is lightweight and safe at import time. if TYPE_CHECKING: - from .calibration import get_quant_finalizer from .quantizer import quantize_onnx _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "quantize_onnx": (".quantizer", "quantize_onnx"), - "get_quant_finalizer": (".calibration", "get_quant_finalizer"), } diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py index 39b9543c5..213199811 100644 --- a/src/winml/modelkit/quant/calibration/base.py +++ b/src/winml/modelkit/quant/calibration/base.py @@ -38,3 +38,4 @@ def finalize( model_id: str | None = None, ) -> WinMLQuantizationConfig: """Return ``quant`` populated with the graph-derived quant settings.""" + ... 
diff --git a/src/winml/modelkit/session/__init__.py b/src/winml/modelkit/session/__init__.py index 5148da0b3..d11673961 100644 --- a/src/winml/modelkit/session/__init__.py +++ b/src/winml/modelkit/session/__init__.py @@ -5,6 +5,13 @@ """WinMLSession - ONNX Runtime session manager with WinML EP integration.""" from .ep_registry import WinMLEPRegistry +from .genai_session import ( + GenaiLoadError, + GenaiNotInstalledError, + GenaiSession, + GenaiSessionError, + GenerationConfig, +) from .monitor.ep_monitor import EPMonitor, NullEPMonitor from .monitor.hw_monitor import HWMonitor from .monitor.openvino_monitor import OpenVinoMonitor @@ -17,6 +24,11 @@ __all__ = [ "EPMonitor", + "GenaiLoadError", + "GenaiNotInstalledError", + "GenaiSession", + "GenaiSessionError", + "GenerationConfig", "HWMonitor", "InferenceError", "NullEPMonitor", diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py new file mode 100644 index 000000000..312a7358c --- /dev/null +++ b/src/winml/modelkit/session/genai_session.py @@ -0,0 +1,680 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""GenaiSession — onnxruntime-genai session for multi-model decoder pipelines. + +Manages ``og.Model`` + ``og.Generator`` lifecycle for autoregressive text +generation. Reuses :class:`WinMLEPRegistry` for EP discovery and registration +so EPs are downloaded / registered at most once per process. + +Unlike :class:`WinMLSession` (which wraps ``ort.InferenceSession`` for +single-shot inference), ``GenaiSession`` drives a streaming token-by-token +generation loop. The two classes are peers — neither inherits from the other. 
+ +Bundle directory layout expected by ``onnxruntime-genai``:: + + bundle_dir/ + genai_config.json ← required; controls pipeline & search + ctx.onnx ← prefill transformer graph + iter.onnx ← decode transformer graph + embeddings.onnx ← embedding lookup + lm_head.onnx ← logit projection + tokenizer.json ← HF tokenizer files + tokenizer_config.json + ... + +Usage:: + + # Context manager (recommended — auto-loads and unloads) + with GenaiSession("out/qwen3_bundle", ep="qnn") as session: + for token_str in session.generate_streaming("Hello, who are you?"): + print(token_str, end="", flush=True) + + # Manual lifecycle + session = GenaiSession("out/qwen3_bundle", ep="cpu") + session.load() + result = session.generate("What is a transformer?") + session.unload() + +Dependencies:: + + pip install onnxruntime-genai-winml + pip install "windowsml[with-ort]" # registers QNN EP; also provides ORT +""" + +from __future__ import annotations + +import json +import logging +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +from .ep_registry import WinMLEPRegistry + + +if TYPE_CHECKING: + from collections.abc import Iterator + + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Module-level compilation worker (must be at module scope for multiprocessing +# spawn on Windows, which serialises the target via pickle). +# --------------------------------------------------------------------------- + + +def _qnn_compile_worker(src: str, dst: str, qnn_options: dict) -> None: + """Compile *src* ONNX to an EPContext ONNX at *dst* using QNN HTP. + + Executed in a subprocess by :meth:`GenaiSession._compile_stage`. 
+ """ + import onnxruntime as ort + + from winml.modelkit.session.ep_registry import WinMLEPRegistry + from winml.modelkit.winml import add_ep_for_device + + registry = WinMLEPRegistry.get_instance() + registry.register_execution_providers() + so = ort.SessionOptions() + so.add_session_config_entry("ep.context_enable", "1") + so.add_session_config_entry("ep.context_file_path", dst) + add_ep_for_device(so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU, qnn_options) + mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False) + mc.compile_to_file(dst) + + +# --------------------------------------------------------------------------- +# Valid EP short names. +# "mixed" = use genai_config.json as-is (embeddings/lm_head on CPU, +# ctx/iter on the target accelerator). +# EP routing is driven entirely by per-stage session_options in the bundle's +# genai_config.json — GenaiSession never calls clear_providers/append_provider. +# --------------------------------------------------------------------------- +_VALID_EPS: frozenset[str] = frozenset({"cpu", "mixed", "qnn", "dml"}) +# EPs that require WinML EP discovery + registration before og.Model() init. +_NEEDS_WINML_EPS: frozenset[str] = frozenset({"mixed", "qnn", "dml"}) + + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + + +@dataclass +class GenerationConfig: + """Search / sampling parameters for a single generation call. + + All parameters are forwarded to ``og.GeneratorParams.set_search_options``. + ``max_length`` is **not** configurable here — it is set to the bundle's + ``context_length`` (read from ``genai_config.json``) because the static KV + cache size is baked into the ONNX graphs at export time. + + Attributes: + max_new_tokens: Soft cap on the number of new tokens to generate. 
+ Generation stops when the model signals EOS, when the KV buffer is + exhausted (``context_length``), or when this limit is reached, + whichever comes first. + do_sample: Enable sampling (``True``) vs greedy (``False``). + temperature: Sampling temperature. Ignored when ``do_sample=False``. + top_p: Nucleus sampling probability mass. Ignored when + ``do_sample=False``. + top_k: Top-K sampling. ``0`` disables the filter. Ignored when + ``do_sample=False``. + repetition_penalty: Multiplicative penalty for repeated tokens + (``1.0`` = no penalty). + """ + + max_new_tokens: int = 128 + do_sample: bool = False + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = 0 + repetition_penalty: float = 1.0 + + +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- + + +class GenaiSessionError(Exception): + """Base exception for GenaiSession.""" + + +class GenaiNotInstalledError(GenaiSessionError): + """``onnxruntime-genai`` (or ``onnxruntime-genai-winml``) is not installed.""" + + +class GenaiLoadError(GenaiSessionError): + """The bundle could not be loaded (bad config, EP unavailable, etc.).""" + + +# --------------------------------------------------------------------------- +# Session +# --------------------------------------------------------------------------- + + +class GenaiSession: + """ORT GenAI session for multi-model decoder-pipeline inference. + + Wraps ``og.Model`` + ``og.Generator`` to provide a clean generation API. + + The session is **stateless across calls**: each :meth:`generate_streaming` + call creates a fresh ``og.Generator`` so KV state does not persist between + prompts. Thread-safety within a single session is not guaranteed. + + Args: + bundle_dir: Path to the genai bundle directory. Must contain + ``genai_config.json`` and the ONNX files it references. + ep: Execution provider short name (``"cpu"``, ``"mixed"``, ``"qnn"``, ``"dml"``). 
+ Non-CPU EPs trigger WinML EP discovery and registration. + context_length: Override for the static KV cache length. When + ``None`` (default), read from ``genai_config.json``. + Must match the ``--max-cache-len`` used during the winml-cli build. + verbose: Enable ``onnxruntime-genai`` native model I/O logging. + compile: Pre-compile QNN pipeline stages to EPContext ONNX on first + run (inside ``bundle_dir/_compiled/``). Subsequent calls reuse + the cached EPContext files, eliminating per-run JIT overhead. + Only stages that can be compiled without hanging are attempted; + stages that fail compilation fall back to the original ONNX. + Has no effect when ``ep="cpu"``. + """ + + # Sub-directory within the bundle that holds pre-compiled EPContext ONNX files. + _COMPILED_SUBDIR: str = "_compiled" + + def __init__( + self, + bundle_dir: str | Path, + ep: str = "cpu", + *, + context_length: int | None = None, + verbose: bool = False, + compile: bool = False, + ) -> None: + self._bundle_dir = Path(bundle_dir) + self._ep = ep.lower() + self._context_length_override = context_length + self._verbose = verbose + self._compile = compile + + # Resolved at load() time. + self._context_length: int | None = None + + # og.* handles — None until load() is called. + self._model: object | None = None + self._tokenizer: object | None = None + + if not self._bundle_dir.exists(): + raise FileNotFoundError(f"Bundle directory not found: {self._bundle_dir}") + config_path = self._bundle_dir / "genai_config.json" + if not config_path.exists(): + raise FileNotFoundError( + f"genai_config.json not found in {self._bundle_dir}. " + "Run export_qwen3_transformer_only.py --genai-bundle first." + ) + if self._ep not in _VALID_EPS: + raise ValueError(f"Unknown EP {ep!r}. 
Supported: {sorted(_VALID_EPS)}") + + logger.info("GenaiSession initialized: bundle=%s ep=%s", self._bundle_dir, self._ep) + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def load(self) -> None: + """Load ``og.Model`` and tokenizer from the bundle directory. + + Idempotent: calling ``load()`` on an already-loaded session is a no-op. + + Raises: + GenaiNotInstalledError: ``onnxruntime_genai`` is not installed. + GenaiLoadError: The model could not be loaded (EP error, bad config, + missing ONNX files, …). + """ + if self._model is not None: + return + + og = self._import_og() + + # Register WinML EPs to ORT GenAI when the bundle may use a hardware EP. + if self._ep in _NEEDS_WINML_EPS: + self._register_eps(og) + + if self._verbose: + og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True) + + # Determine which bundle directory og.Config should load from. + load_dir = self._bundle_dir + if self._compile and self._ep in _NEEDS_WINML_EPS: + load_dir = self._prepare_compiled_bundle() + + try: + config = og.Config(str(load_dir)) + # EP routing is driven entirely by genai_config.json (per-stage + # session_options). Do NOT call clear_providers/append_provider — + # those only touch the top-level provider and cannot override + # per-stage session_options already embedded in the pipeline config. + self._model = og.Model(config) + self._tokenizer = og.Tokenizer(self._model) + except Exception as exc: + self._model = None + self._tokenizer = None + raise GenaiLoadError(f"Failed to load genai bundle from {load_dir}: {exc}") from exc + + self._context_length = self._context_length_override or self._read_context_length() + logger.info( + "GenaiSession loaded: ep=%s context_length=%d", + self._ep, + self._context_length, + ) + + def unload(self) -> None: + """Release ``og.Model`` and tokenizer handles. 
+ + Safe to call on an unloaded session. + """ + self._model = None + self._tokenizer = None + self._context_length = None + logger.info("GenaiSession unloaded: bundle=%s", self._bundle_dir) + + def __enter__(self) -> GenaiSession: + self.load() + return self + + def __exit__(self, *_: object) -> None: + self.unload() + + # ------------------------------------------------------------------ + # Generation + # ------------------------------------------------------------------ + + def generate( + self, + prompt: str | list[int], + config: GenerationConfig | None = None, + ) -> str: + """Generate text and return the full response as a single string. + + This is a convenience wrapper around :meth:`generate_streaming`. + + Args: + prompt: Input text (auto-encoded) or a pre-encoded token-ID list. + config: Generation parameters. Uses :class:`GenerationConfig` + defaults when ``None``. + + Returns: + The generated text (not including the prompt). + """ + return "".join(self.generate_streaming(prompt, config)) + + def generate_streaming( + self, + prompt: str | list[int], + config: GenerationConfig | None = None, + ) -> Iterator[str]: + """Generate text token-by-token, yielding decoded token strings. + + The method auto-loads the session on the first call (lazy-load + equivalent of :meth:`load`). + + Each yield is the decoded string for a single new token. Callers + typically ``print(token, end="", flush=True)`` to stream output. + + Args: + prompt: Input text (auto-encoded via the bundle tokenizer) or a + pre-encoded token-ID list. Pass a pre-formatted string when + chat templates or special tokens are needed — the session is + not aware of any particular model's template format. + config: Generation parameters. Uses :class:`GenerationConfig` + defaults when ``None``. + + Yields: + Decoded string for each newly generated token. 
+ """ + self._ensure_loaded() + og = self._import_og() + cfg = config or GenerationConfig() + + tokens = ( + self._tokenizer.encode(prompt) # type: ignore[union-attr] + if isinstance(prompt, str) + else prompt + ) + + params = og.GeneratorParams(self._model) + params.set_search_options( + max_length=self._context_length, + do_sample=cfg.do_sample, + temperature=cfg.temperature, + top_p=cfg.top_p, + top_k=cfg.top_k, + repetition_penalty=cfg.repetition_penalty, + ) + + generator = og.Generator(self._model, params) + generator.append_tokens(tokens) + + stream = self._tokenizer.create_stream() # type: ignore[union-attr] + n = 0 + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_next_tokens()[0] + yield stream.decode(new_token) + n += 1 + if n >= cfg.max_new_tokens: + break + # ``generator`` (og.Generator) holds the KV cache buffer. It is a local of + # this Python generator function, so it is released only once iteration + # finishes or the caller closes/discards the iterator. + + # ------------------------------------------------------------------ + # Chat-template helpers + # ------------------------------------------------------------------ + + @staticmethod + def apply_chatml_template( + prompt: str, + system: str | None = None, + ) -> str: + r"""Wrap *prompt* in the ChatML format used by Qwen2/3, Yi, Mistral, etc. + + Produces:: + + <|im_start|>system + {system}<|im_end|> + <|im_start|>user + {prompt}<|im_end|> + <|im_start|>assistant + + The trailing ``<|im_start|>assistant\n`` primes the model to respond + as the assistant role with no leading newline in its output. + + Args: + prompt: The user turn text. + system: Optional system prompt. When ``None`` no system turn is + prepended. + + Returns: + Formatted string ready to pass to :meth:`generate` / + :meth:`generate_streaming`. 
+ """ + parts: list[str] = [] + if system is not None: + parts.append(f"<|im_start|>system\n{system}<|im_end|>\n") + parts.append(f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n") + return "".join(parts) + + # ------------------------------------------------------------------ + # Tokenizer helpers + # ------------------------------------------------------------------ + + def encode(self, text: str) -> list[int]: + """Encode *text* to a list of token IDs using the bundle tokenizer.""" + self._ensure_loaded() + return self._tokenizer.encode(text).tolist() # type: ignore[union-attr] + + def decode(self, tokens: list[int]) -> str: + """Decode a list of token IDs to a string.""" + self._ensure_loaded() + return self._tokenizer.decode(tokens) # type: ignore[union-attr] + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def is_loaded(self) -> bool: + """``True`` if the model is loaded and ready for generation.""" + return self._model is not None + + @property + def bundle_dir(self) -> Path: + """Path to the genai bundle directory.""" + return self._bundle_dir + + @property + def ep(self) -> str: + """Execution provider short name (as passed to ``__init__``).""" + return self._ep + + @property + def context_length(self) -> int | None: + """Static KV cache length, populated after :meth:`load`.""" + return self._context_length + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _ensure_loaded(self) -> None: + if self._model is None: + self.load() + + def _prepare_compiled_bundle(self) -> Path: + """Create (or reuse) a *compiled* bundle directory. 
+ + Reads ``genai_config.json``, finds QNN-accelerated stages (those with + ``QNNExecutionProvider`` in their ``session_options``), and tries to + compile their ONNX to EPContext format using ``ort.ModelCompiler``. + + The compiled bundle is stored under ``bundle_dir/_compiled/``. On + every call the helper checks whether the cached EPContext file is + newer than the source ONNX; if so, it skips recompilation. + + Returns: + Path to the compiled bundle directory (may equal ``bundle_dir`` + if no compilable stages were found, or if all compilations failed). + """ + compiled_dir = self._bundle_dir / self._COMPILED_SUBDIR + config_src = self._bundle_dir / "genai_config.json" + cfg = json.loads(config_src.read_text(encoding="utf-8")) + + # Collect pipeline stages that use a QNN EP ("qnn" key in provider_options). + # genai_config pipeline entries: [{"context": {...}}, {"iterator": {...}}, ...] + # provider_options format: [{"qnn": {...}}] + pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) + # [(stage_key, onnx_filename, qnn_opts), ...] 
+ qnn_stages: list[tuple[str, str, dict]] = [] + for stage_entry in pipeline_list: + if not isinstance(stage_entry, dict): + continue + for stage_key, stage_cfg in stage_entry.items(): + if not isinstance(stage_cfg, dict): + continue + so = stage_cfg.get("session_options", {}) + providers = so.get("provider_options", []) + for p in providers: + if isinstance(p, dict) and "qnn" in p: + onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") + qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]))) + break + + if not qnn_stages: + logger.info("No QNN stages found in genai_config.json; skipping compilation") + return self._bundle_dir + + compiled_dir.mkdir(exist_ok=True) + modified_cfg = json.loads(config_src.read_text(encoding="utf-8")) + any_compiled = False + + for stage_key, onnx_filename, qnn_opts in qnn_stages: + src_onnx = self._bundle_dir / onnx_filename + ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx" + + # Skip recompilation if cache is up-to-date. + if ctx_onnx.exists() and ctx_onnx.stat().st_mtime >= src_onnx.stat().st_mtime: + logger.info("Stage %r: reusing cached EPContext %s", stage_key, ctx_onnx.name) + # Use just the filename — genai_config.json lives in compiled_dir, + # so ort-genai resolves filenames relative to compiled_dir. + self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) + any_compiled = True + continue + + # Attempt compilation. + success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts) + if success: + self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) + any_compiled = True + else: + logger.warning( + "Stage %r: compilation failed; using original ONNX (JIT fallback)", stage_key + ) + + if not any_compiled: + return self._bundle_dir + + # Write the modified genai_config into the compiled sub-directory. + # ONNX filenames are relative to compiled_dir; ort-genai resolves them + # from the directory it loads og.Config from. 
@staticmethod
def _patch_stage_filename(cfg: dict, stage_key: str, abs_path: str) -> None:
    """Point one pipeline stage of *cfg* at a different ONNX file (in place).

    Walks ``cfg["model"]["decoder"]["pipeline"]`` and, for the first entry
    that is a dict containing *stage_key* with a dict value, overwrites that
    stage's ``"filename"``. Nothing happens when no such stage exists.

    Args:
        cfg: Parsed ``genai_config.json`` contents; mutated in place.
        stage_key: Pipeline stage name to patch (e.g. ``"context"``).
        abs_path: New value stored verbatim under ``"filename"``. Despite
            the parameter name the value is not checked or resolved, so a
            bare filename works just as well as an absolute path.
    """
    decoder_cfg = cfg.get("model", {}).get("decoder", {})
    for entry in decoder_cfg.get("pipeline", []):
        if not isinstance(entry, dict) or stage_key not in entry:
            continue
        target = entry[stage_key]
        if not isinstance(target, dict):
            continue
        target["filename"] = abs_path
        return
def _mirror_non_onnx_files(self, compiled_dir: Path) -> None:
    """Mirror the bundle's top-level files into *compiled_dir*.

    Every regular file directly inside ``self._bundle_dir`` is symlinked
    into *compiled_dir*; when the symlink cannot be created the file is
    copied instead. No suffix filtering is applied, so despite the method
    name ONNX graphs are mirrored too. Sub-directories, including the
    compiled sub-directory itself, are skipped.

    Entries that already exist in *compiled_dir* are left untouched. A
    dangling symlink at the destination is removed and recreated, because
    it would otherwise make ``symlink_to`` fail and the copy fallback would
    then write through the dead link.

    Args:
        compiled_dir: Existing directory that receives the links/copies.
    """
    for src in self._bundle_dir.iterdir():
        if src.name == self._COMPILED_SUBDIR or not src.is_file():
            continue
        dst = compiled_dir / src.name
        if dst.exists():
            continue
        if dst.is_symlink():
            # exists() follows links, so reaching here means the link is dangling.
            dst.unlink()
        try:
            dst.symlink_to(src.resolve())
        except (OSError, NotImplementedError):
            # Symlinks may require elevated privileges on Windows; fall back to copy.
            shutil.copy2(src, dst)
" + "Install it with: pip install onnxruntime-genai-winml" + ) from exc + + def _register_eps(self, og: object) -> None: + """Register WinML EPs with ORT GenAI (idempotent, best-effort).""" + try: + registry = WinMLEPRegistry.get_instance() + if registry.winml_available: + result = registry.register_execution_providers(ort_genai=True) + registered = result.get("onnxruntime_genai", []) + logger.info("WinML EPs registered for ORT GenAI: %s", registered) + except Exception as exc: + logger.warning("WinML EP registration skipped: %s", exc) + + def _read_context_length(self) -> int: + """Read ``model.context_length`` from ``genai_config.json``.""" + cfg = json.loads((self._bundle_dir / "genai_config.json").read_text(encoding="utf-8")) + return int(cfg["model"]["context_length"]) + + +__all__ = [ + "GenaiLoadError", + "GenaiNotInstalledError", + "GenaiSession", + "GenaiSessionError", + "GenerationConfig", +] diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py new file mode 100644 index 000000000..33645cdaf --- /dev/null +++ b/src/winml/modelkit/utils/genai.py @@ -0,0 +1,663 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""Generic onnxruntime-genai bundle utilities for decoder-pipeline models. + +The bundle is a directory that ``onnxruntime-genai`` can load directly via +``og.Config(str(bundle_dir))``. 
It contains: + + genai_config.json — pipeline config consumed by onnxruntime-genai + ctx.onnx — prefill/context ONNX + iter.onnx — iteration/decode ONNX + embeddings.onnx — embedding-lookup ONNX + lm_head.onnx — lm_head ONNX + tokenizer.json — HF tokenizer files (downloaded from the model repo) + tokenizer_config.json + vocab.json / merges.txt / generation_config.json + +The pipeline follows the standard 4-stage decoder layout: + + input_ids → [embeddings] → input_hidden_states + → [context | iterator] → output_hidden_states + present KVs + → [lm_head] → logits + +The context stage runs on the prompt (prefill); the iterator stage runs on each +subsequent decode step. Both share the same KV cache buffer via genai's +``past_present_share_buffer`` mode. + +Public API:: + + from winml.modelkit.utils.genai import ( + build_genai_config, + build_decoder_pipeline_stages, + write_genai_bundle, + DecoderIOMapping, + PipelineStage, + qnn_stage_session_options, + ) + + # Build stages by introspecting the ONNX I/O (no hardcoded tensor names) + stages, decoder_io = build_decoder_pipeline_stages( + ctx_path, iter_path, num_layers=hf_config.num_hidden_layers, ep="qnn" + ) + cfg = build_genai_config( + hf_config, max_cache_len=256, prefill_seq_len=64, + pipeline=stages, decoder_io=decoder_io, + ) + + # Or one-shot bundle assembly + write_genai_bundle( + Path("out/bundle"), + context_onnx=ctx_path, + iterator_onnx=iter_path, + model_id="Qwen/Qwen3-0.6B", + max_cache_len=256, + prefill_seq_len=64, + embeddings_src=emb_path, # None = skip (add later) + lm_head_src=lmh_path, # None = skip (add later) + ep="qnn", + ) +""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + + +logger = logging.getLogger(__name__) + +# Default filenames inside the bundle directory. 
DEFAULT_EMBEDDINGS_FILENAME = "embeddings.onnx"
DEFAULT_CONTEXT_FILENAME = "ctx.onnx"
DEFAULT_ITERATOR_FILENAME = "iter.onnx"
DEFAULT_LM_HEAD_FILENAME = "lm_head.onnx"

# Regex for detecting indexed tensor names such as ``past_keys_3``:
# group 1 is the (lazily matched) prefix, group 2 the trailing digits.
_KV_INDEXED_RE = re.compile(r"^(.+?)(\d+)$")


# ---------------------------------------------------------------------------
# Pipeline data structures
# ---------------------------------------------------------------------------


@dataclass
class PipelineStage:
    """Describe a single model stage of a ``genai_config.json`` pipeline.

    Attributes:
        name: Key under which the stage appears in the ``pipeline`` list.
        filename: ONNX filename of the stage inside the bundle directory.
        run_on_prompt: Value emitted as ``run_on_prompt``.
        run_on_token_gen: Value emitted as ``run_on_token_gen``.
        inputs: Concrete ONNX input tensor names (not format strings).
        outputs: Concrete ONNX output tensor names (not format strings).
        is_lm_head: When truthy, ``"is_lm_head": True`` is emitted.
        session_options: Optional per-stage session options; emitted
            as-is (same object, not copied) when truthy, omitted otherwise.
    """

    name: str
    filename: str
    run_on_prompt: bool
    run_on_token_gen: bool
    inputs: list[str]
    outputs: list[str]
    is_lm_head: bool = False
    # Per-stage ORT session options (e.g. provider_options for QNN).
    # ``None`` (default) or an empty dict leaves the key out of the
    # serialized stage entirely.
    session_options: dict | None = None

    def to_dict(self) -> dict:
        """Return the stage as the dict stored under its name in the pipeline.

        ``inputs`` and ``outputs`` are copied into new lists; the optional
        keys ``session_options`` and ``is_lm_head`` are appended, in that
        order, only when set.
        """
        payload: dict = {
            "filename": self.filename,
            "inputs": [*self.inputs],
            "outputs": [*self.outputs],
            "run_on_prompt": self.run_on_prompt,
            "run_on_token_gen": self.run_on_token_gen,
        }
        if self.session_options:
            payload["session_options"] = self.session_options
        if self.is_lm_head:
            payload["is_lm_head"] = True
        return payload


@dataclass
class DecoderIOMapping:
    """Tensor names written to the ``decoder.inputs`` / ``decoder.outputs`` sections.

    The ``*_names`` fields are format strings with ``%d`` standing for the
    layer index (``"past_keys_%d"`` covers ``past_keys_0``, ``past_keys_1``
    and so on). Every field has a default, so only the names that differ
    need to be overridden.
    """

    input_ids: str = "input_ids"
    past_sequence_length: str = "past_seq_len"
    total_sequence_length: str = "total_seq_len"
    past_key_names: str = "past_keys_%d"
    past_value_names: str = "past_values_%d"
    logits: str = "logits"
    present_key_names: str = "present_keys_%d"
    present_value_names: str = "present_values_%d"

    def inputs_dict(self) -> dict:
        """Return the ``decoder.inputs`` mapping for ``genai_config.json``."""
        # The config keys are spelled exactly like the dataclass fields.
        input_fields = (
            "input_ids",
            "past_sequence_length",
            "total_sequence_length",
            "past_key_names",
            "past_value_names",
        )
        return {field_name: getattr(self, field_name) for field_name in input_fields}

    def outputs_dict(self) -> dict:
        """Return the ``decoder.outputs`` mapping for ``genai_config.json``."""
        output_fields = ("logits", "present_key_names", "present_value_names")
        return {field_name: getattr(self, field_name) for field_name in output_fields}
--------------------------------------------------------------------------- + + +def build_genai_config( + hf_config: Any, + *, + max_cache_len: int, + prefill_seq_len: int | None = None, + pipeline: list[PipelineStage], + decoder_io: DecoderIOMapping | None = None, +) -> dict: + """Build a ``genai_config.json`` dict for any decoder-pipeline model. + + This function is architecture-agnostic: the caller supplies the pipeline + stages and the I/O name mapping so no tensor names are hardcoded here. + + Args: + hf_config: A ``transformers.PretrainedConfig``. Reads: + ``num_hidden_layers``, ``hidden_size``, ``num_attention_heads``, + ``num_key_value_heads``, ``head_dim`` (optional, falls back to + ``hidden_size // num_attention_heads``), ``bos_token_id``, + ``eos_token_id``, ``pad_token_id``, ``vocab_size``. + max_cache_len: Static KV cache length → ``context_length`` and + ``search.max_length``. + prefill_seq_len: When given, emits a ``sliding_window`` section with + ``window_size=prefill_seq_len``. Pass ``None`` to omit. + pipeline: Ordered list of :class:`PipelineStage` describing each + model in the genai pipeline. + decoder_io: Format-string mapping from genai's abstract I/O names to + actual ONNX tensor names. Defaults to + :class:`DecoderIOMapping` (the standard names). + + Returns: + A ``dict`` suitable for ``json.dumps`` as ``genai_config.json``. 
+ """ + if decoder_io is None: + decoder_io = DecoderIOMapping() + + num_layers: int = hf_config.num_hidden_layers + head_size: int = getattr( + hf_config, + "head_dim", + hf_config.hidden_size // hf_config.num_attention_heads, + ) + + eos_token_id = hf_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + + pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id + + decoder_section: dict = { + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_key_value_heads": hf_config.num_key_value_heads, + "num_hidden_layers": num_layers, + "head_size": head_size, + } + + if prefill_seq_len is not None: + decoder_section["sliding_window"] = { + "window_size": prefill_seq_len, + "pad_value": 0, + "alignment": "left", + "slide_inputs": True, + "slide_key_value_cache": False, + } + + decoder_section["inputs"] = decoder_io.inputs_dict() + decoder_section["outputs"] = decoder_io.outputs_dict() + decoder_section["pipeline"] = [{s.name: s.to_dict()} for s in pipeline] + + return { + "model": { + "type": "decoder-pipeline", + "bos_token_id": hf_config.bos_token_id, + "eos_token_id": eos_token_id, + "pad_token_id": pad_token_id, + "vocab_size": hf_config.vocab_size, + "context_length": max_cache_len, + "decoder": decoder_section, + }, + "search": { + "max_length": max_cache_len, + "min_length": 0, + "do_sample": False, + "past_present_share_buffer": True, + }, + } + + +# --------------------------------------------------------------------------- +# ONNX introspection helpers +# --------------------------------------------------------------------------- + + +def _introspect_onnx_io(onnx_path: Path) -> tuple[list[str], list[str]]: + """Return ``(input_names, output_names)`` from an ONNX model graph header. + + External data is intentionally not loaded — only the graph topology is read, + so this is fast even for large quantized models. 
def _detect_format_patterns(names: list[str], num_layers: int) -> dict[str, str]:
    """Find per-layer ``<prefix>%d`` patterns among indexed tensor names.

    Each name of the form ``<prefix><N>`` (trailing decimal digits) is
    grouped by prefix. A prefix qualifies only when its collected indices
    are exactly ``0 .. num_layers - 1``, each occurring once.

    Returns:
        ``{prefix: prefix + "%d"}`` for every qualifying prefix, ordered by
        the prefix's first appearance in *names*.

    Example: with ``num_layers=2`` the names ``past_keys_0``,
    ``past_keys_1``, ``past_values_0``, ``past_values_1`` produce the keys
    ``past_keys_`` and ``past_values_`` mapped to ``past_keys_%d`` and
    ``past_values_%d``.
    """
    expected = list(range(num_layers))
    seen: dict[str, list[int]] = {}
    for tensor_name in names:
        match = _KV_INDEXED_RE.match(tensor_name)
        if match is None:
            continue
        stem, digits = match.groups()
        seen.setdefault(stem, []).append(int(digits))

    detected: dict[str, str] = {}
    for stem, layer_ids in seen.items():
        # Equality with the full range also pins the number of entries.
        if sorted(layer_ids) == expected:
            detected[stem] = f"{stem}%d"
    return detected


def _sort_patterns_by_first_occurrence(patterns: dict[str, str], names: list[str]) -> list[str]:
    """Order the prefixes in *patterns* by where ``<prefix>0`` sits in *names*.

    Prefixes whose layer-0 name is absent from *names* sort last; ties keep
    the iteration order of *patterns*.
    """
    missing_rank = len(names)

    def first_position(stem: str) -> int:
        layer_zero = f"{stem}0"
        return names.index(layer_zero) if layer_zero in names else missing_rank

    return sorted(patterns, key=first_position)
def build_decoder_pipeline_stages(
    context_onnx: str | Path,
    iterator_onnx: str | Path,
    num_layers: int,
    *,
    context_filename: str = DEFAULT_CONTEXT_FILENAME,
    iterator_filename: str = DEFAULT_ITERATOR_FILENAME,
    embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME,
    lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME,
    ep: str = "cpu",
    soc_model: str = "60",
) -> tuple[list[PipelineStage], DecoderIOMapping]:
    """Build the four pipeline stages from the I/O names of the built ONNX models.

    Tensor names are read from *context_onnx* and *iterator_onnx*. The
    per-layer KV patterns, the sequence-length names and the hidden-state
    names are all derived from the *context* model only; the iterator
    model's names are passed through to its stage unchanged. Whenever a
    name cannot be detected, a hardcoded default is used instead
    (``past_keys_%d``, ``past_values_%d``, ``present_keys_%d``,
    ``present_values_%d``, ``past_seq_len``, ``total_seq_len``,
    ``input_hidden_states``, ``output_hidden_states``).

    Args:
        context_onnx: Path to the built prefill/context ONNX.
        iterator_onnx: Path to the built decode/iterator ONNX.
        num_layers: Number of layers a per-layer tensor group must cover
            to be recognised as a KV pattern.
        context_filename: Bundle filename for the context model.
        iterator_filename: Bundle filename for the iterator model.
        embeddings_filename: Bundle filename for the embeddings model.
        lm_head_filename: Bundle filename for the lm_head model.
        ep: ``"qnn"`` attaches ``qnn_stage_session_options`` to the
            ``context`` and ``iterator`` stages. Any other value (including
            the default ``"cpu"``) leaves ``session_options`` unset on all
            stages.
        soc_model: Forwarded to ``qnn_stage_session_options`` when
            ``ep == "qnn"``; unused otherwise.

    Returns:
        ``(stages, decoder_io)``: the stages in the order embeddings,
        context, iterator, lm_head, and the :class:`DecoderIOMapping`
        built from the detected names.
    """
    ctx_inputs, ctx_outputs = _introspect_onnx_io(Path(context_onnx))
    iter_inputs, iter_outputs = _introspect_onnx_io(Path(iterator_onnx))

    # Detect per-layer KV format-string patterns in the context model.
    input_patterns = _detect_format_patterns(ctx_inputs, num_layers)
    output_patterns = _detect_format_patterns(ctx_outputs, num_layers)

    # Order the detected prefixes by where their layer-0 tensor appears.
    in_sorted = _sort_patterns_by_first_occurrence(input_patterns, ctx_inputs)
    out_sorted = _sort_patterns_by_first_occurrence(output_patterns, ctx_outputs)

    # The first detected group is taken as keys and the second as values.
    # NOTE(review): this relies on the export listing keys before values.
    past_key_fmt = input_patterns[in_sorted[0]] if len(in_sorted) > 0 else "past_keys_%d"
    past_val_fmt = input_patterns[in_sorted[1]] if len(in_sorted) > 1 else "past_values_%d"
    pres_key_fmt = output_patterns[out_sorted[0]] if len(out_sorted) > 0 else "present_keys_%d"
    pres_val_fmt = output_patterns[out_sorted[1]] if len(out_sorted) > 1 else "present_values_%d"

    # Non-indexed inputs: hidden-state tensor + seq-length scalars.
    # A name counts as a sequence-length input when it contains "seq" or
    # "len" (case-insensitive); the first remaining name is the hidden state.
    non_indexed = [n for n in ctx_inputs if not _KV_INDEXED_RE.match(n)]
    seq_len_names = [n for n in non_indexed if re.search(r"seq|len", n, re.IGNORECASE)]
    hidden_state_in = next(
        (n for n in non_indexed if n not in seq_len_names), "input_hidden_states"
    )
    past_seq_name = next((n for n in seq_len_names if "past" in n.lower()), "past_seq_len")
    total_seq_name = next((n for n in seq_len_names if "total" in n.lower()), "total_seq_len")

    # Non-indexed output: the first output without a trailing index is
    # taken as the hidden-state output of the transformer stack.
    hidden_state_out = next(
        (n for n in ctx_outputs if not _KV_INDEXED_RE.match(n)), "output_hidden_states"
    )

    # input_ids and logits are not set here, so they keep the
    # DecoderIOMapping defaults.
    decoder_io = DecoderIOMapping(
        past_sequence_length=past_seq_name,
        total_sequence_length=total_seq_name,
        past_key_names=past_key_fmt,
        past_value_names=past_val_fmt,
        present_key_names=pres_key_fmt,
        present_value_names=pres_val_fmt,
    )

    # Per-stage session_options: only ep == "qnn" produces options; every
    # other ep value leaves both at None.
    ctx_session_opts: dict | None = None
    iter_session_opts: dict | None = None
    if ep == "qnn":
        ctx_session_opts = qnn_stage_session_options(
            "onnxruntime-genai.context", soc_model=soc_model
        )
        iter_session_opts = qnn_stage_session_options(
            "onnxruntime-genai.iterator", soc_model=soc_model
        )

    # embeddings feeds the context model's hidden-state input; lm_head
    # consumes the context model's hidden-state output.
    stages: list[PipelineStage] = [
        PipelineStage(
            name="embeddings",
            filename=embeddings_filename,
            run_on_prompt=True,
            run_on_token_gen=True,
            inputs=[decoder_io.input_ids],
            outputs=[hidden_state_in],
        ),
        PipelineStage(
            name="context",
            filename=context_filename,
            run_on_prompt=True,
            run_on_token_gen=False,
            inputs=ctx_inputs,
            outputs=ctx_outputs,
            session_options=ctx_session_opts,
        ),
        PipelineStage(
            name="iterator",
            filename=iterator_filename,
            run_on_prompt=False,
            run_on_token_gen=True,
            inputs=iter_inputs,
            outputs=iter_outputs,
            session_options=iter_session_opts,
        ),
        PipelineStage(
            name="lm_head",
            filename=lm_head_filename,
            run_on_prompt=True,
            run_on_token_gen=True,
            inputs=[hidden_state_out],
            outputs=[decoder_io.logits],
            is_lm_head=True,
        ),
    ]
    return stages, decoder_io
+ lm_head_filename: Bundle filename for the lm_head model. + ep: Execution provider for the transformer (context/iterator) stages. + ``"qnn"`` injects QNN HTP ``session_options`` so those stages run + on the NPU while embeddings and lm_head run on CPU. + ``"cpu"`` (default) omits ``session_options`` (all stages on CPU). + soc_model: Snapdragon SoC model passed to the QNN backend when + ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. + + Returns: + Path to the written ``genai_config.json``. + """ + from transformers import AutoConfig, AutoTokenizer + + from winml.modelkit.onnx import copy_onnx_model + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + context_onnx = Path(context_onnx) + iterator_onnx = Path(iterator_onnx) + + # 1. Copy winml-built transformer ONNX files. + logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename) + copy_onnx_model(context_onnx, output_dir / context_filename) + + logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) + copy_onnx_model(iterator_onnx, output_dir / iterator_filename) + + # 2. Copy placeholder models (embeddings + lm_head). + if embeddings_src is not None: + logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) + copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename) + else: + logger.warning( + "embeddings_src not provided — '%s' is missing from bundle.", + embeddings_filename, + ) + + if lm_head_src is not None: + logger.info("Copying lm_head: %s -> %s", Path(lm_head_src).name, lm_head_filename) + copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename) + else: + logger.warning( + "lm_head_src not provided — '%s' is missing from bundle.", + lm_head_filename, + ) + + # 3. Save tokenizer files from the HF snapshot. 
def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None:
    """Log a human-readable listing of the assembled bundle.

    Emits a single INFO record through the module logger: a header, one
    line per entry of *bundle_dir* (sorted, with its size in KB and a short
    tag for the config, ``.onnx`` and ``.data`` files), and the location of
    the written config.
    """

    def _annotation(filename: str) -> str:
        # The exact config filename is checked before the suffix rules.
        if filename == "genai_config.json":
            return " <- pipeline config"
        if filename.endswith(".onnx"):
            return " <- ONNX graph"
        if filename.endswith(".data"):
            return " <- ONNX external weights"
        return ""

    report = [f"\n=== genai bundle: {bundle_dir} ==="]
    report.extend(
        f" {entry.name:<45} {entry.stat().st_size / 1024:>8.1f} KB{_annotation(entry.name)}"
        for entry in sorted(bundle_dir.iterdir())
    )
    report.append(f"\nConfig written to: {config_path}")
    logger.info("\n".join(report))
"DEFAULT_LM_HEAD_FILENAME", + "DecoderIOMapping", + "PipelineStage", + "build_decoder_pipeline_stages", + "build_genai_config", + "qnn_stage_session_options", + "write_genai_bundle", +] diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py new file mode 100644 index 000000000..e359f02e2 --- /dev/null +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -0,0 +1,515 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the Qwen3 genai config builder.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + +from winml.modelkit.models.hf.qwen3.genai import ( + DEFAULT_CONTEXT_FILENAME, + DEFAULT_EMBEDDINGS_FILENAME, + DEFAULT_ITERATOR_FILENAME, + DEFAULT_LM_HEAD_FILENAME, + DecoderIOMapping, + PipelineStage, + _detect_format_patterns, + build_genai_config, + build_qwen3_transformer_only_stages, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _mock_config( + *, + num_hidden_layers: int = 28, + hidden_size: int = 1024, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + bos_token_id: int = 151643, + eos_token_id: int = 151645, + pad_token_id: int = 151643, + vocab_size: int = 151936, +) -> SimpleNamespace: + """Return a minimal stand-in for a HF PretrainedConfig.""" + return SimpleNamespace( + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + vocab_size=vocab_size, + ) + 
def _make_pipeline(
    num_layers: int = 28,
    *,
    emb_filename: str = DEFAULT_EMBEDDINGS_FILENAME,
    ctx_filename: str = DEFAULT_CONTEXT_FILENAME,
    iter_filename: str = DEFAULT_ITERATOR_FILENAME,
    lmh_filename: str = DEFAULT_LM_HEAD_FILENAME,
) -> list[PipelineStage]:
    """Return the embeddings/context/iterator/lm_head stages used by the tests.

    The context and iterator stages share the same input and output name
    lists: one hidden-state tensor, the two sequence-length names and
    ``num_layers`` key/value tensors on each side.
    """

    def per_layer(stem: str) -> list[str]:
        return [f"{stem}_{layer}" for layer in range(num_layers)]

    decoder_inputs = [
        "input_hidden_states",
        "past_seq_len",
        "total_seq_len",
        *per_layer("past_keys"),
        *per_layer("past_values"),
    ]
    decoder_outputs = [
        "output_hidden_states",
        *per_layer("present_keys"),
        *per_layer("present_values"),
    ]

    embeddings = PipelineStage(
        name="embeddings",
        filename=emb_filename,
        run_on_prompt=True,
        run_on_token_gen=True,
        inputs=["input_ids"],
        outputs=["input_hidden_states"],
    )
    context = PipelineStage(
        name="context",
        filename=ctx_filename,
        run_on_prompt=True,
        run_on_token_gen=False,
        inputs=decoder_inputs,
        outputs=decoder_outputs,
    )
    iterator = PipelineStage(
        name="iterator",
        filename=iter_filename,
        run_on_prompt=False,
        run_on_token_gen=True,
        inputs=decoder_inputs,
        outputs=decoder_outputs,
    )
    lm_head = PipelineStage(
        name="lm_head",
        filename=lmh_filename,
        run_on_prompt=True,
        run_on_token_gen=True,
        inputs=["output_hidden_states"],
        outputs=["logits"],
        is_lm_head=True,
    )
    return [embeddings, context, iterator, lm_head]
-> None: + assert self.result["search"]["max_length"] == self.result["model"]["context_length"] + + def test_search_past_present_share_buffer(self) -> None: + assert self.result["search"]["past_present_share_buffer"] is True + + def test_decoder_architecture_params(self) -> None: + dec = self.result["model"]["decoder"] + assert dec["hidden_size"] == 1024 + assert dec["num_attention_heads"] == 16 + assert dec["num_key_value_heads"] == 8 + assert dec["num_hidden_layers"] == 28 + assert dec["head_size"] == 128 + + def test_sliding_window_present_when_prefill_seq_len_given(self) -> None: + sw = self.result["model"]["decoder"]["sliding_window"] + assert sw["window_size"] == 64 + assert sw["slide_inputs"] is True + assert sw["slide_key_value_cache"] is False + + def test_sliding_window_absent_when_prefill_seq_len_none(self) -> None: + result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=None, + pipeline=_make_pipeline(), + ) + assert "sliding_window" not in result["model"]["decoder"] + + def test_decoder_io_tensor_names(self) -> None: + inputs = self.result["model"]["decoder"]["inputs"] + assert inputs["past_sequence_length"] == "past_seq_len" + assert inputs["total_sequence_length"] == "total_seq_len" + assert inputs["past_key_names"] == "past_keys_%d" + assert inputs["past_value_names"] == "past_values_%d" + outputs = self.result["model"]["decoder"]["outputs"] + assert outputs["logits"] == "logits" + assert outputs["present_key_names"] == "present_keys_%d" + assert outputs["present_value_names"] == "present_values_%d" + + def test_custom_decoder_io_mapping(self) -> None: + custom_io = DecoderIOMapping( + past_key_names="k_%d", + past_value_names="v_%d", + present_key_names="pk_%d", + present_value_names="pv_%d", + ) + result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=64, + pipeline=_make_pipeline(), + decoder_io=custom_io, + ) + dec_inputs = result["model"]["decoder"]["inputs"] + assert 
dec_inputs["past_key_names"] == "k_%d" + assert dec_inputs["past_value_names"] == "v_%d" + dec_outputs = result["model"]["decoder"]["outputs"] + assert dec_outputs["present_key_names"] == "pk_%d" + assert dec_outputs["present_value_names"] == "pv_%d" + + def test_pipeline_has_four_stages(self) -> None: + pipeline = self.result["model"]["decoder"]["pipeline"] + assert len(pipeline) == 4 + stage_names = [next(iter(s.keys())) for s in pipeline] + assert stage_names == ["embeddings", "context", "iterator", "lm_head"] + + def test_embeddings_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][0]["embeddings"] + assert stage["filename"] == DEFAULT_EMBEDDINGS_FILENAME + assert stage["inputs"] == ["input_ids"] + assert stage["outputs"] == ["input_hidden_states"] + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is True + + def test_context_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][1]["context"] + assert stage["filename"] == DEFAULT_CONTEXT_FILENAME + assert "input_hidden_states" in stage["inputs"] + assert "past_seq_len" in stage["inputs"] + assert "total_seq_len" in stage["inputs"] + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is False + + def test_iterator_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][2]["iterator"] + assert stage["filename"] == DEFAULT_ITERATOR_FILENAME + assert stage["run_on_prompt"] is False + assert stage["run_on_token_gen"] is True + + def test_lm_head_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][3]["lm_head"] + assert stage["filename"] == DEFAULT_LM_HEAD_FILENAME + assert stage["inputs"] == ["output_hidden_states"] + assert stage["outputs"] == ["logits"] + assert stage["is_lm_head"] is True + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is True + + def test_context_kv_inputs_count(self) -> None: + """context.inputs must include all 28 past_keys + 28 
past_values.""" + inputs = self.result["model"]["decoder"]["pipeline"][1]["context"]["inputs"] + past_keys = [x for x in inputs if x.startswith("past_keys_")] + past_values = [x for x in inputs if x.startswith("past_values_")] + assert len(past_keys) == 28 + assert len(past_values) == 28 + assert set(past_keys) == {f"past_keys_{i}" for i in range(28)} + assert set(past_values) == {f"past_values_{i}" for i in range(28)} + + def test_context_outputs_kv_count(self) -> None: + outputs = self.result["model"]["decoder"]["pipeline"][1]["context"]["outputs"] + present_keys = [x for x in outputs if x.startswith("present_keys_")] + present_values = [x for x in outputs if x.startswith("present_values_")] + assert len(present_keys) == 28 + assert len(present_values) == 28 + + def test_context_and_iterator_have_same_io(self) -> None: + ctx = self.result["model"]["decoder"]["pipeline"][1]["context"] + itr = self.result["model"]["decoder"]["pipeline"][2]["iterator"] + assert ctx["inputs"] == itr["inputs"] + assert ctx["outputs"] == itr["outputs"] + + def test_custom_filenames(self) -> None: + result = build_genai_config( + self.cfg, + max_cache_len=512, + prefill_seq_len=128, + pipeline=_make_pipeline( + emb_filename="emb.onnx", + ctx_filename="prefill.onnx", + iter_filename="decode.onnx", + lmh_filename="head.onnx", + ), + ) + pipeline = result["model"]["decoder"]["pipeline"] + assert pipeline[0]["embeddings"]["filename"] == "emb.onnx" + assert pipeline[1]["context"]["filename"] == "prefill.onnx" + assert pipeline[2]["iterator"]["filename"] == "decode.onnx" + assert pipeline[3]["lm_head"]["filename"] == "head.onnx" + + def test_eos_token_id_list_unpacked(self) -> None: + cfg = _mock_config(eos_token_id=[151645, 151643]) + result = build_genai_config( + cfg, max_cache_len=256, prefill_seq_len=64, pipeline=_make_pipeline() + ) + assert result["model"]["eos_token_id"] == 151645 + + def test_head_size_derived_when_head_dim_missing(self) -> None: + cfg = SimpleNamespace( + 
num_hidden_layers=2, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=4, + # no head_dim attribute + bos_token_id=0, + eos_token_id=1, + pad_token_id=0, + vocab_size=32000, + ) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2) + ) + # head_size = hidden_size // num_attention_heads = 512 // 8 = 64 + assert result["model"]["decoder"]["head_size"] == 64 + + def test_pad_token_id_falls_back_to_bos(self) -> None: + cfg = SimpleNamespace( + num_hidden_layers=2, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=64, + bos_token_id=0, + eos_token_id=1, + pad_token_id=None, + vocab_size=32000, + ) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2) + ) + assert result["model"]["pad_token_id"] == 0 # falls back to bos_token_id + + def test_different_layer_count(self) -> None: + cfg = _mock_config(num_hidden_layers=4) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(4) + ) + inputs = result["model"]["decoder"]["pipeline"][1]["context"]["inputs"] + past_keys = [x for x in inputs if x.startswith("past_keys_")] + assert len(past_keys) == 4 + assert {f"past_keys_{i}" for i in range(4)} == set(past_keys) + + +# --------------------------------------------------------------------------- +# Tests: _detect_format_patterns +# --------------------------------------------------------------------------- + + +class TestDetectFormatPatterns: + def test_detects_two_kv_groups(self) -> None: + names = [ + "input_hidden_states", + "past_seq_len", + "past_keys_0", + "past_keys_1", + "past_keys_2", + "past_values_0", + "past_values_1", + "past_values_2", + ] + result = _detect_format_patterns(names, num_layers=3) + assert result == {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} + + def test_ignores_incomplete_index_range(self) -> None: + # Missing index 1 — should not be 
detected + names = ["prefix_0", "prefix_2"] + result = _detect_format_patterns(names, num_layers=3) + assert "prefix_" not in result + + def test_ignores_wrong_num_layers(self) -> None: + # 3 entries but num_layers=5 + names = ["kv_0", "kv_1", "kv_2"] + result = _detect_format_patterns(names, num_layers=5) + assert len(result) == 0 + + def test_empty_input(self) -> None: + assert _detect_format_patterns([], num_layers=4) == {} + + def test_non_indexed_names_ignored(self) -> None: + names = ["input_hidden_states", "past_seq_len", "total_seq_len"] + result = _detect_format_patterns(names, num_layers=3) + assert result == {} + + def test_single_layer_model(self) -> None: + names = ["keys_0", "vals_0"] + result = _detect_format_patterns(names, num_layers=1) + assert result == {"keys_": "keys_%d", "vals_": "vals_%d"} + + +# --------------------------------------------------------------------------- +# Tests: build_qwen3_transformer_only_stages +# --------------------------------------------------------------------------- + + +class TestBuildQwen3TransformerOnlyStages: + """Uses mocked onnx.load so no real ONNX files are required.""" + + def _ctx_inputs(self, n: int = 4) -> list[str]: + return [ + "input_hidden_states", + "past_seq_len", + "total_seq_len", + *[f"past_keys_{i}" for i in range(n)], + *[f"past_values_{i}" for i in range(n)], + ] + + def _ctx_outputs(self, n: int = 4) -> list[str]: + return [ + "output_hidden_states", + *[f"present_keys_{i}" for i in range(n)], + *[f"present_values_{i}" for i in range(n)], + ] + + def _patch_onnx(self, n: int = 4): + ctx_io = (self._ctx_inputs(n), self._ctx_outputs(n)) + iter_io = (self._ctx_inputs(n), self._ctx_outputs(n)) + return patch( + "winml.modelkit.utils.genai._introspect_onnx_io", + side_effect=[ctx_io, iter_io], + ) + + def test_returns_four_stages(self) -> None: + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages("ctx.onnx", "iter.onnx", num_layers=4) + assert len(stages) == 4 + assert 
[s.name for s in stages] == ["embeddings", "context", "iterator", "lm_head"] + + def test_detected_kv_format_patterns(self) -> None: + with self._patch_onnx(): + _, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + assert decoder_io.past_key_names == "past_keys_%d" + assert decoder_io.past_value_names == "past_values_%d" + assert decoder_io.present_key_names == "present_keys_%d" + assert decoder_io.present_value_names == "present_values_%d" + + def test_detected_seq_len_names(self) -> None: + with self._patch_onnx(): + _, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + assert decoder_io.past_sequence_length == "past_seq_len" + assert decoder_io.total_sequence_length == "total_seq_len" + + def test_context_stage_inputs_from_onnx(self) -> None: + with self._patch_onnx(n=4): + stages, _ = build_qwen3_transformer_only_stages("ctx.onnx", "iter.onnx", num_layers=4) + ctx_stage = next(s for s in stages if s.name == "context") + assert "input_hidden_states" in ctx_stage.inputs + assert "past_keys_0" in ctx_stage.inputs + assert "past_values_3" in ctx_stage.inputs + + def test_custom_filenames(self) -> None: + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", + "iter.onnx", + num_layers=4, + context_filename="prefill.onnx", + iterator_filename="decode.onnx", + embeddings_filename="emb.onnx", + lm_head_filename="head.onnx", + ) + names = {s.name: s.filename for s in stages} + assert names["context"] == "prefill.onnx" + assert names["iterator"] == "decode.onnx" + assert names["embeddings"] == "emb.onnx" + assert names["lm_head"] == "head.onnx" + + def test_roundtrip_with_build_genai_config(self) -> None: + """build_qwen3_transformer_only_stages output feeds build_genai_config cleanly.""" + with self._patch_onnx(n=4): + stages, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + cfg = 
_mock_config(num_hidden_layers=4) + result = build_genai_config( + cfg, + max_cache_len=128, + prefill_seq_len=32, + pipeline=stages, + decoder_io=decoder_io, + ) + assert result["model"]["type"] == "decoder-pipeline" + assert len(result["model"]["decoder"]["pipeline"]) == 4 + + def test_cpu_ep_no_session_options(self) -> None: + """Default cpu ep: context/iterator stages have no session_options.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="cpu" + ) + ctx = next(s for s in stages if s.name == "context") + itr = next(s for s in stages if s.name == "iterator") + assert ctx.session_options is None + assert itr.session_options is None + + def test_qnn_ep_injects_session_options(self) -> None: + """ep='qnn': context/iterator get QNN session_options; emb/lm_head do not.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn" + ) + stage_map = {s.name: s for s in stages} + assert stage_map["embeddings"].session_options is None + assert stage_map["lm_head"].session_options is None + ctx_opts = stage_map["context"].session_options + itr_opts = stage_map["iterator"].session_options + assert ctx_opts is not None + assert itr_opts is not None + assert ctx_opts["provider_options"][0]["qnn"]["backend_path"] == "QnnHtp.dll" + assert itr_opts["log_id"] == "onnxruntime-genai.iterator" + + def test_qnn_session_options_in_serialized_config(self) -> None: + """QNN session_options appear in genai_config.json pipeline output.""" + with self._patch_onnx(): + stages, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn" + ) + cfg = build_genai_config( + _mock_config(num_hidden_layers=4), + max_cache_len=256, + prefill_seq_len=64, + pipeline=stages, + decoder_io=decoder_io, + ) + pipeline = cfg["model"]["decoder"]["pipeline"] + ctx_dict = next(s for s in pipeline if "context" in 
s)["context"] + itr_dict = next(s for s in pipeline if "iterator" in s)["iterator"] + emb_dict = next(s for s in pipeline if "embeddings" in s)["embeddings"] + assert "session_options" in ctx_dict + assert "session_options" in itr_dict + assert "session_options" not in emb_dict + + def test_custom_soc_model(self) -> None: + """soc_model parameter propagates to QNN provider_options.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn", soc_model="73" + ) + ctx = next(s for s in stages if s.name == "context") + assert ctx.session_options["provider_options"][0]["qnn"]["soc_model"] == "73" diff --git a/tests/unit/quant/calibration/test_qwen3_calibration.py b/tests/unit/quant/calibration/test_qwen3_calibration.py index ad53ef352..6881f0f72 100644 --- a/tests/unit/quant/calibration/test_qwen3_calibration.py +++ b/tests/unit/quant/calibration/test_qwen3_calibration.py @@ -16,8 +16,9 @@ from types import SimpleNamespace import numpy as np -import onnx import torch +from onnx import TensorProto, helper +from onnx import save as onnx_save from winml.modelkit.quant.calibration.qwen3_transformer_only import ( Qwen3DecodeTrajectoryCalibReader, @@ -47,27 +48,27 @@ def _fake_config() -> SimpleNamespace: def _build_tiny_onnx(path, *, seq_len: int, max_cache_len: int) -> None: """Write a minimal graph carrying the inputs the readers introspect.""" inputs = [ - onnx.helper.make_tensor_value_info( - "input_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] + helper.make_tensor_value_info( + "input_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] ), - onnx.helper.make_tensor_value_info( - "past_keys_0", onnx.TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] + helper.make_tensor_value_info( + "past_keys_0", TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] ), ] - out = onnx.helper.make_tensor_value_info( - "output_hidden_states", onnx.TensorProto.FLOAT, [1, 
seq_len, HIDDEN] + out = helper.make_tensor_value_info( + "output_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] ) - gqa = onnx.helper.make_node( + gqa = helper.make_node( "GroupQueryAttention", ["input_hidden_states"], ["attn_out"], name="gqa_layer_0", domain="com.microsoft", ) - identity = onnx.helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) - graph = onnx.helper.make_graph([gqa, identity], "tiny", inputs, [out]) - model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)]) - onnx.save(model, str(path)) + identity = helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) + graph = helper.make_graph([gqa, identity], "tiny", inputs, [out]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + onnx_save(model, str(path)) def test_graph_shapes_and_gqa_nodes(tmp_path): diff --git a/tests/unit/session/test_genai_session.py b/tests/unit/session/test_genai_session.py new file mode 100644 index 000000000..4dcf7ea1c --- /dev/null +++ b/tests/unit/session/test_genai_session.py @@ -0,0 +1,380 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for GenaiSession. + +All tests that touch load() / generate*() mock onnxruntime_genai so no +real model files or GPU/NPU hardware is required. 
+""" + +from __future__ import annotations + +import json +import sys +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import pytest + +from winml.modelkit.session.genai_session import ( + GenaiLoadError, + GenaiNotInstalledError, + GenaiSession, + GenaiSessionError, + GenerationConfig, +) + + +if TYPE_CHECKING: + from pathlib import Path + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def bundle_dir(tmp_path: Path) -> Path: + """Create a minimal genai bundle directory with genai_config.json.""" + cfg = { + "model": { + "type": "decoder-pipeline", + "context_length": 256, + "decoder": {}, + }, + "search": {"max_length": 256}, + } + (tmp_path / "genai_config.json").write_text(json.dumps(cfg), encoding="utf-8") + return tmp_path + + +@pytest.fixture +def mock_og() -> MagicMock: + """Return a fully mocked onnxruntime_genai module.""" + og = MagicMock(name="onnxruntime_genai") + og.Config.return_value = MagicMock() + og.Model.return_value = MagicMock() + og.Tokenizer.return_value = MagicMock() + og.GeneratorParams.return_value = MagicMock() + + # Generator that yields two tokens then is_done() + gen = MagicMock() + gen.is_done.side_effect = [False, False, True] + gen.get_next_tokens.side_effect = [ + MagicMock(__getitem__=lambda s, i: 10), + MagicMock(__getitem__=lambda s, i: 20), + ] + og.Generator.return_value = gen + + # TokenizerStream decodes tokens to text + stream = MagicMock() + stream.decode.side_effect = ["Hello", " world"] + og.Tokenizer.return_value.create_stream.return_value = stream + + return og + + +def _patch_og(mock: MagicMock): + """Context manager: inject mock_og as onnxruntime_genai in sys.modules.""" + return patch.dict(sys.modules, {"onnxruntime_genai": mock}) + + +# --------------------------------------------------------------------------- +# Tests: 
# GenaiSession.__init__
# ---------------------------------------------------------------------------


class TestGenaiSessionInit:
    """Constructor checks: bundle path validation, EP validation, initial state."""

    def test_missing_bundle_dir_raises(self, tmp_path: Path) -> None:
        """A bundle path that does not exist is rejected."""
        absent_dir = tmp_path / "nonexistent"
        with pytest.raises(FileNotFoundError, match="Bundle directory not found"):
            GenaiSession(absent_dir)

    def test_missing_config_raises(self, tmp_path: Path) -> None:
        """An existing directory without genai_config.json is rejected."""
        # tmp_path exists on disk, but nothing has been written into it.
        with pytest.raises(FileNotFoundError, match=r"genai_config\.json not found"):
            GenaiSession(tmp_path)

    def test_unknown_ep_raises(self, bundle_dir: Path) -> None:
        """An EP name outside the supported set raises ValueError."""
        with pytest.raises(ValueError, match="Unknown EP"):
            GenaiSession(bundle_dir, ep="tensorrt")

    def test_default_ep_is_cpu(self, bundle_dir: Path) -> None:
        """Omitting ``ep`` yields a session whose ``ep`` is "cpu"."""
        assert GenaiSession(bundle_dir).ep == "cpu"

    def test_not_loaded_after_init(self, bundle_dir: Path) -> None:
        """Construction alone leaves the session unloaded with no context length."""
        fresh = GenaiSession(bundle_dir)
        assert not fresh.is_loaded
        assert fresh.context_length is None

    def test_bundle_dir_property(self, bundle_dir: Path) -> None:
        """``bundle_dir`` echoes the path given to the constructor."""
        assert GenaiSession(bundle_dir).bundle_dir == bundle_dir

    def test_supported_eps(self, bundle_dir: Path) -> None:
        """Each of the four accepted EP names is stored unchanged."""
        accepted = ("cpu", "mixed", "qnn", "dml")
        for ep_name in accepted:
            assert GenaiSession(bundle_dir, ep=ep_name).ep == ep_name


# ---------------------------------------------------------------------------
# Tests: load / unload
# ---------------------------------------------------------------------------


class TestGenaiSessionLoad:
    """load()/unload() life cycle, exercised against the mocked onnxruntime_genai."""

    @staticmethod
    def _loaded(bundle_dir: Path, **ctor_kwargs) -> GenaiSession:
        """Build a session from *bundle_dir* and call ``load()`` on it once."""
        sess = GenaiSession(bundle_dir, **ctor_kwargs)
        sess.load()
        return sess

    def test_load_sets_is_loaded(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """After load() the session reports itself as loaded."""
        with _patch_og(mock_og):
            assert self._loaded(bundle_dir).is_loaded

    def test_load_reads_context_length_from_config(
        self, bundle_dir: Path, mock_og: MagicMock
    ) -> None:
        """The fixture's genai_config.json declares 256; load() exposes that value."""
        with _patch_og(mock_og):
            assert self._loaded(bundle_dir).context_length == 256

    def test_context_length_override(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """A constructor-supplied context_length wins over the config value."""
        with _patch_og(mock_og):
            assert self._loaded(bundle_dir, context_length=512).context_length == 512

    def test_load_is_idempotent(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """Calling load() twice constructs the og.Model only once."""
        with _patch_og(mock_og):
            sess = self._loaded(bundle_dir)
            sess.load()  # repeat call must not build a second model
            assert mock_og.Model.call_count == 1

    def test_unload_clears_state(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """unload() returns the session to its just-constructed state."""
        with _patch_og(mock_og):
            sess = self._loaded(bundle_dir)
            sess.unload()
            assert not sess.is_loaded
            assert sess.context_length is None

    def test_unload_on_unloaded_session_is_safe(self, bundle_dir: Path) -> None:
        """unload() on a never-loaded session completes without raising."""
        GenaiSession(bundle_dir).unload()

    def test_context_manager_loads_and_unloads(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """Inside the ``with`` body the session is loaded; after exit it is not."""
        with _patch_og(mock_og):
            with GenaiSession(bundle_dir) as sess:
                assert sess.is_loaded
        assert not sess.is_loaded

    def test_genai_not_installed_raises(self, bundle_dir: Path) -> None:
        """A ``None`` entry in sys.modules makes the import fail inside load()."""
        blocked = {"onnxruntime_genai": None}
        with patch.dict(sys.modules, blocked):  # type: ignore[dict-item]
            sess = GenaiSession(bundle_dir)
            with pytest.raises(GenaiNotInstalledError):
                sess.load()

    def test_og_load_error_raises_genai_load_error(
        self, bundle_dir: Path, mock_og: MagicMock
    ) -> None:
        """A RuntimeError from og.Model surfaces as GenaiLoadError carrying its message."""
        mock_og.Model.side_effect = RuntimeError("driver not found")
        with _patch_og(mock_og):
            sess = GenaiSession(bundle_dir)
            with pytest.raises(GenaiLoadError, match="driver not found"):
                sess.load()

    def test_og_load_error_leaves_session_unloaded(
        self, bundle_dir: Path, mock_og: MagicMock
    ) -> None:
        """A failed load() does not leave the session marked as loaded."""
        mock_og.Model.side_effect = RuntimeError("driver not found")
        with _patch_og(mock_og):
            sess = GenaiSession(bundle_dir)
            with pytest.raises(GenaiLoadError):
                sess.load()
            assert not sess.is_loaded
# ---------------------------------------------------------------------------
# Tests: EP registration
# ---------------------------------------------------------------------------


class TestEPRegistration:
    """Checks when load() consults the (patched) WinMLEPRegistry."""

    @staticmethod
    def _load_with_fake_registry(bundle_dir: Path, mock_og: MagicMock, ep: str) -> MagicMock:
        """Load a session for *ep* with WinMLEPRegistry patched; return the fake registry.

        The fake reports ``winml_available = True`` and returns a canned
        ``register_execution_providers`` result, so callers can assert on how
        load() invoked it.
        """
        mock_registry = MagicMock()
        mock_registry.winml_available = True
        mock_registry.register_execution_providers.return_value = {
            "onnxruntime_genai": ["QNNExecutionProvider"]
        }
        with (
            _patch_og(mock_og),
            patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls,
        ):
            mock_reg_cls.get_instance.return_value = mock_registry
            session = GenaiSession(bundle_dir, ep=ep)
            session.load()
        return mock_registry

    def test_cpu_skips_winml_registration(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """With ep="cpu" the patched registry class is never called."""
        with (
            _patch_og(mock_og),
            patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls,
        ):
            session = GenaiSession(bundle_dir, ep="cpu")
            session.load()
            mock_reg_cls.assert_not_called()

    def test_non_cpu_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """With ep="qnn" the registry is asked once to register EPs for ort-genai."""
        mock_registry = self._load_with_fake_registry(bundle_dir, mock_og, "qnn")
        mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True)

    def test_mixed_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """With ep="mixed" the registry is asked once to register EPs for ort-genai."""
        mock_registry = self._load_with_fake_registry(bundle_dir, mock_og, "mixed")
        mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True)

    def test_config_not_modified_at_load(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """load() must not rewrite the provider list on the og.Config object."""
        # EP routing is driven by genai_config.json, so the config must stay untouched.
        with _patch_og(mock_og):
            session = GenaiSession(bundle_dir, ep="cpu")
            session.load()
            og_config = mock_og.Config.return_value
            og_config.clear_providers.assert_not_called()
            og_config.append_provider.assert_not_called()


# ---------------------------------------------------------------------------
# Tests: generate / generate_streaming
# ---------------------------------------------------------------------------


class TestGenerate:
    """generate()/generate_streaming() driven by the scripted ``mock_og`` fixture."""

    def test_generate_streaming_yields_decoded_tokens(
        self, bundle_dir: Path, mock_og: MagicMock
    ) -> None:
        """The stream yields exactly the two pieces the mocked decoder produces."""
        with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
            tokens = list(session.generate_streaming("hi"))
        assert tokens == ["Hello", " world"]

    def test_generate_returns_joined_string(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """generate() returns the streamed pieces concatenated."""
        with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
            result = session.generate("hi")
        assert result == "Hello world"

    def test_generate_respects_max_new_tokens(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """With a generator that never finishes, output stops at max_new_tokens."""
        # The fixture scripts is_done / get_next_tokens / decode through
        # ``side_effect`` lists. ``side_effect`` takes precedence over
        # ``return_value``, so each one has to be cleared before the constant
        # return value below can take effect.
        gen = mock_og.Generator.return_value
        gen.is_done.side_effect = None
        gen.is_done.return_value = False  # generator never signals done
        gen.get_next_tokens.side_effect = None
        gen.get_next_tokens.return_value = MagicMock(__getitem__=lambda s, i: 99)
        stream = mock_og.Tokenizer.return_value.create_stream.return_value
        stream.decode.side_effect = None
        stream.decode.return_value = "x"

        with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
            tokens = list(session.generate_streaming("hi", GenerationConfig(max_new_tokens=1)))
        assert len(tokens) == 1

    def test_generate_with_token_list_input(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """Pre-encoded token IDs are forwarded directly to append_tokens."""
        with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
            list(session.generate_streaming([1, 2, 3]))
        gen = mock_og.Generator.return_value
        gen.append_tokens.assert_called_once_with([1, 2, 3])

    def test_generate_deletes_generator_after_iteration(
        self, bundle_dir: Path, mock_og: MagicMock
    ) -> None:
        """Generator is deleted (not leaked) even on normal completion."""
        with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
            list(session.generate_streaming("hi"))
        # Smoke test only: it passes when iteration completes without raising.

    def test_generate_with_custom_config(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """Sampling fields of GenerationConfig reach set_search_options as kwargs."""
        cfg = GenerationConfig(max_new_tokens=64, do_sample=True, temperature=0.7)
        with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
            list(session.generate_streaming("hi", cfg))
        params = mock_og.GeneratorParams.return_value
        params.set_search_options.assert_called_once()
        call_kwargs = params.set_search_options.call_args.kwargs
        assert call_kwargs["do_sample"] is True
        assert call_kwargs["temperature"] == 0.7

    def test_generate_uses_context_length_as_max_length(
        self, bundle_dir: Path, mock_og: MagicMock
    ) -> None:
        """The session's context_length is passed as the ``max_length`` search option."""
        with _patch_og(mock_og), GenaiSession(bundle_dir, context_length=128) as session:
            list(session.generate_streaming("hi"))
        params = mock_og.GeneratorParams.return_value
        call_kwargs = params.set_search_options.call_args.kwargs
        assert call_kwargs["max_length"] == 128

    def test_auto_load_on_first_generate(self, bundle_dir: Path, mock_og: MagicMock) -> None:
        """Generating on an unloaded session loads it without an explicit load()."""
        with _patch_og(mock_og):
            session = GenaiSession(bundle_dir)
            assert not session.is_loaded
            list(session.generate_streaming("hi"))
            assert session.is_loaded


# ---------------------------------------------------------------------------
# Tests: apply_chatml_template
# ---------------------------------------------------------------------------


class TestApplyChatmlTemplate:
    """Exact-string checks on the ChatML prompt built by apply_chatml_template."""

    def test_user_only(self) -> None:
        """Without a system prompt: one user turn followed by the assistant primer."""
        result = GenaiSession.apply_chatml_template("Hello")
        assert result == "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"

    def test_with_system(self) -> None:
        """A system prompt is emitted first, followed by the user turn and primer."""
        result = GenaiSession.apply_chatml_template("Hello", system="You are helpful.")
        assert result.startswith("<|im_start|>system\nYou are helpful.<|im_end|>\n")
        assert "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" in result

    def test_no_system_no_system_turn(self) -> None:
        """No system turn appears when ``system`` is not supplied."""
        result = GenaiSession.apply_chatml_template("Hi")
        assert "<|im_start|>system" not in result

    def test_ends_with_assistant_priming(self) -> None:
        """The prompt always ends with the opening of the assistant turn."""
        result = GenaiSession.apply_chatml_template("Hi")
        assert result.endswith("<|im_start|>assistant\n")


# ---------------------------------------------------------------------------
# Tests: GenerationConfig defaults
# ---------------------------------------------------------------------------


class TestGenerationConfig:
    """Default and explicitly supplied field values of GenerationConfig."""

    def test_defaults(self) -> None:
        """A bare GenerationConfig() carries the expected default for every field."""
        cfg = GenerationConfig()
        assert cfg.max_new_tokens == 128
        assert cfg.do_sample is False
        assert cfg.temperature == 1.0
        assert cfg.top_p == 1.0
        assert cfg.top_k == 0
        assert cfg.repetition_penalty == 1.0

    def test_custom_values(self) -> None:
        """Keyword arguments override the corresponding defaults."""
        cfg = GenerationConfig(max_new_tokens=32, do_sample=True, top_k=50)
        assert cfg.max_new_tokens == 32
        assert cfg.do_sample is True
        assert cfg.top_k == 50


# ---------------------------------------------------------------------------
# Tests: exception hierarchy
# ---------------------------------------------------------------------------


class TestExceptions:
    """Both specific errors derive from GenaiSessionError."""

    def test_genai_not_installed_is_genai_session_error(self) -> None:
        assert issubclass(GenaiNotInstalledError, GenaiSessionError)

    def test_genai_load_error_is_genai_session_error(self) -> None:
        assert issubclass(GenaiLoadError, GenaiSessionError)