diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py
index 6894af518..856123be2 100644
--- a/scripts/export_qwen3_transformer_only.py
+++ b/scripts/export_qwen3_transformer_only.py
@@ -121,6 +121,43 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
default=None,
help="If set, copy the two ONNX (with external data) here as prefill.onnx / decode.onnx.",
)
+
+ genai = p.add_argument_group(
+ "genai bundle",
+ "Options for producing an onnxruntime-genai inference bundle.",
+ )
+ genai.add_argument(
+ "--genai-bundle",
+ type=Path,
+ default=None,
+ metavar="DIR",
+ help=(
+ "If set, assemble a complete onnxruntime-genai bundle in DIR: "
+ "ctx.onnx (prefill), iter.onnx (decode), genai_config.json, and "
+ "tokenizer files. Provide --embeddings and --lm-head to include "
+ "the placeholder models required for end-to-end inference."
+ ),
+ )
+ genai.add_argument(
+ "--embeddings",
+ type=Path,
+ default=None,
+ metavar="ONNX",
+ help=(
+ "Path to the embeddings ONNX to copy into the genai bundle as "
+ "embeddings.onnx. Required for end-to-end genai inference."
+ ),
+ )
+ genai.add_argument(
+ "--lm-head",
+ type=Path,
+ default=None,
+ metavar="ONNX",
+ help=(
+ "Path to the lm_head ONNX to copy into the genai bundle as "
+ "lm_head.onnx. Required for end-to-end genai inference."
+ ),
+ )
return p.parse_args(argv)
@@ -164,6 +201,39 @@ def main(argv: list[str] | None = None) -> int:
copy_onnx_model(src, dst)
print(f" -> copied to {dst}")
+ # -----------------------------------------------------------------------
+ # Optional: assemble an onnxruntime-genai bundle.
+ # -----------------------------------------------------------------------
+ if args.genai_bundle is not None:
+ from winml.modelkit.models.hf.qwen3.genai import write_genai_bundle
+
+ prefill_path = Path(model.sub_models["decoder_prefill"].onnx_path)
+ decode_path = Path(model.sub_models["decoder_gen"].onnx_path)
+
+ print(f"\n=== assembling genai bundle -> {args.genai_bundle} ===")
+ config_path = write_genai_bundle(
+ args.genai_bundle,
+ context_onnx=prefill_path,
+ iterator_onnx=decode_path,
+ model_id=args.model_id,
+ max_cache_len=args.max_cache_len,
+ prefill_seq_len=args.prefill_seq_len,
+ embeddings_src=args.embeddings,
+ lm_head_src=args.lm_head,
+ ep="qnn" if args.device == "npu" else args.device,
+ )
+ print(f" genai_config.json -> {config_path}")
+ if args.embeddings is None:
+ print(
+ " WARNING: --embeddings not provided; "
+ "add embeddings.onnx to the bundle before inference."
+ )
+ if args.lm_head is None:
+ print(
+ " WARNING: --lm-head not provided; "
+ "add lm_head.onnx to the bundle before inference."
+ )
+
return 0
diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py
new file mode 100644
index 000000000..47aa6683e
--- /dev/null
+++ b/scripts/infer_genai.py
@@ -0,0 +1,151 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+r"""onnxruntime-genai inference for a genai bundle (decoder-pipeline).
+
+Loads the genai bundle produced by ``export_qwen3_transformer_only.py
+--genai-bundle
`` and runs greedy text generation using
+:class:`~winml.modelkit.session.GenaiSession`.
+
+The bundle directory must contain ``genai_config.json`` and the four ONNX
+graphs it references (``embeddings.onnx``, ``ctx.onnx``, ``iter.onnx``,
+``lm_head.onnx``) plus HF tokenizer files.
+
+Usage::
+
+ # CPU sanity check (works anywhere onnxruntime-genai is installed)
+ uv run python scripts/infer_genai.py --prompt "Hello, who are you?" --chat
+
+ # Qualcomm NPU (registers the QNN EP via the Windows ML EP catalog)
+ uv run python scripts/infer_genai.py \\
+ --prompt "Explain what a transformer is." \\
+ --ep qnn --chat
+
+ # Point at a non-default bundle
+ uv run python scripts/infer_genai.py \\
+ --model-dir out/my_bundle --prompt "Hi" --ep cpu
+
+ # Pre-compile QNN stages to EPContext on first run; reuse cache on subsequent runs.
+ # Eliminates per-run JIT overhead (~60-90 s saved on Snapdragon X Elite).
+ uv run python scripts/infer_genai.py \\
+ --prompt "Hello" --ep mixed --compile
+
+Dependencies (install in a fresh venv)::
+
+ pip install onnxruntime-genai-winml
+ pip install "windowsml[with-ort]" # registers QNN EP; also provides onnxruntime
+"""
+
+from __future__ import annotations
+
+import argparse
+import sys
+import time
+from pathlib import Path
+
+from winml.modelkit.session import GenaiSession, GenerationConfig
+
+
+# Default bundle directory: /out/qwen3_bundle
+_REPO_ROOT = Path(__file__).resolve().parent.parent
+DEFAULT_MODEL_DIR = _REPO_ROOT / "out" / "qwen3_bundle"
+
+_SUPPORTED_EPS = ["cpu", "mixed", "qnn", "dml"]
+
+
+def _wrap_chat_template(prompt: str) -> str:
+ """Wrap *prompt* in the ChatML chat template."""
+ return GenaiSession.apply_chatml_template(prompt)
+
+
+def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
+ """Parse CLI arguments."""
+ p = argparse.ArgumentParser(
+ description=__doc__,
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ p.add_argument(
+ "--prompt",
+ default="Give me a short introduction to large language models.",
+ help="Input prompt (default: %(default)s).",
+ )
+ p.add_argument(
+ "--model-dir",
+ type=Path,
+ default=DEFAULT_MODEL_DIR,
+ metavar="DIR",
+ help=(
+ "Path to the genai bundle directory containing genai_config.json "
+ "and the ONNX / tokenizer files (default: %(default)s)."
+ ),
+ )
+ p.add_argument(
+ "--ep",
+ choices=_SUPPORTED_EPS,
+ default="mixed",
+ help="Execution provider: 'mixed' uses genai_config.json as-is (default); "
+ "'cpu' forces all stages to CPU; 'qnn'/'dml' for full NPU/GPU.",
+ )
+ p.add_argument(
+ "--max-new",
+ type=int,
+ default=128,
+ help="Maximum number of new tokens to generate (default: %(default)s).",
+ )
+ p.add_argument(
+ "--chat",
+ action="store_true",
+ help="Wrap --prompt in the ChatML template (<|im_start|>user/assistant).",
+ )
+ p.add_argument(
+ "--compile",
+ action="store_true",
+ help=(
+ "Pre-compile QNN pipeline stages to EPContext ONNX before loading. "
+ "On first use this triggers ort.ModelCompiler per stage (~60-90 s for iter). "
+ "Compiled artifacts are cached in bundle_dir/_compiled/; "
+ "subsequent runs reuse the cache and skip JIT. "
+ "Has no effect when --ep cpu."
+ ),
+ )
+ p.add_argument(
+ "--verbose",
+ action="store_true",
+ help="Enable onnxruntime-genai native model I/O logging.",
+ )
+ return p.parse_args(argv)
+
+
+def main(argv: list[str] | None = None) -> int:
+ """Load the genai bundle and run generation."""
+ args = parse_args(argv)
+
+ text = _wrap_chat_template(args.prompt) if args.chat else args.prompt
+ gen_cfg = GenerationConfig(max_new_tokens=args.max_new, do_sample=False)
+
+ try:
+ session = GenaiSession(
+ args.model_dir, ep=args.ep, verbose=args.verbose, compile=args.compile
+ )
+ except FileNotFoundError as exc:
+ print(f"ERROR: {exc}", file=sys.stderr)
+ return 1
+
+ print(f"[load] ep={args.ep} bundle={args.model_dir}")
+ with session:
+ print(f"[ctx] context_length={session.context_length}")
+ print("[gen] ", end="", flush=True)
+ t0 = time.monotonic()
+ n = 0
+ for token_str in session.generate_streaming(text, gen_cfg):
+ print(token_str, end="", flush=True)
+ n += 1
+
+ dt = time.monotonic() - t0
+ print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / dt:.1f} tok/s)")
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py
index 332fb9234..dbabe2d60 100644
--- a/src/winml/modelkit/models/hf/qwen3/__init__.py
+++ b/src/winml/modelkit/models/hf/qwen3/__init__.py
@@ -3,4 +3,30 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
-"""Qwen3 transformer-only export support (modeling, export ops, IO configs)."""
+"""Qwen3 transformer-only export + genai bundle support.
+
+Modules:
+ qwen_transformer_only — OnnxConfig, build config, composite model class.
+ qwen3_modeling — winml-owned Qwen3 module definitions (forward bindings).
+ qwen3_export_ops — custom ONNX symbolic ops (LpNorm, GQA, 1x1 Conv).
+ genai — genai_config.json generator + bundle assembler.
+"""
+
+from .genai import (
+ DecoderIOMapping,
+ PipelineStage,
+ build_decoder_pipeline_stages,
+ build_genai_config,
+ build_qwen3_transformer_only_stages,
+ write_genai_bundle,
+)
+
+
+__all__ = [
+ "DecoderIOMapping",
+ "PipelineStage",
+ "build_decoder_pipeline_stages",
+ "build_genai_config",
+ "build_qwen3_transformer_only_stages",
+ "write_genai_bundle",
+]
diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py
new file mode 100644
index 000000000..9e65908f5
--- /dev/null
+++ b/src/winml/modelkit/models/hf/qwen3/genai.py
@@ -0,0 +1,53 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""Qwen3 genai bundle support — thin shim over :mod:`winml.modelkit.utils.genai`.
+
+All generic logic (``PipelineStage``, ``DecoderIOMapping``, ``build_genai_config``,
+``build_decoder_pipeline_stages``, ``write_genai_bundle``) lives in
+:mod:`winml.modelkit.utils.genai` so it can be reused by other model families.
+
+This module re-exports that API unchanged and adds
+``build_qwen3_transformer_only_stages`` as a backward-compatible alias for
+``build_decoder_pipeline_stages``. New code should prefer the generic names.
+"""
+
+from __future__ import annotations
+
+from winml.modelkit.utils.genai import (
+ DEFAULT_CONTEXT_FILENAME,
+ DEFAULT_EMBEDDINGS_FILENAME,
+ DEFAULT_ITERATOR_FILENAME,
+ DEFAULT_LM_HEAD_FILENAME,
+ DecoderIOMapping,
+ PipelineStage,
+ _detect_format_patterns,
+ build_decoder_pipeline_stages,
+ build_genai_config,
+ qnn_stage_session_options,
+ write_genai_bundle,
+)
+
+
+# Backward-compatible alias: existing callers that import
+# ``build_qwen3_transformer_only_stages`` continue to work unchanged.
+build_qwen3_transformer_only_stages = build_decoder_pipeline_stages
+
+# Keep the internal helper accessible for tests that import it directly.
+_qnn_stage_session_options = qnn_stage_session_options
+
+__all__ = [
+ "DEFAULT_CONTEXT_FILENAME",
+ "DEFAULT_EMBEDDINGS_FILENAME",
+ "DEFAULT_ITERATOR_FILENAME",
+ "DEFAULT_LM_HEAD_FILENAME",
+ "DecoderIOMapping",
+ "PipelineStage",
+ "_detect_format_patterns",
+ "build_decoder_pipeline_stages",
+ "build_genai_config",
+ "build_qwen3_transformer_only_stages",
+ "qnn_stage_session_options",
+ "write_genai_bundle",
+]
diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py
index 9a2c0b34c..f2dad581a 100644
--- a/src/winml/modelkit/quant/__init__.py
+++ b/src/winml/modelkit/quant/__init__.py
@@ -18,6 +18,7 @@
from typing import TYPE_CHECKING, Any
+from .calibration import get_quant_finalizer
from .config import QuantizeResult, WinMLQuantizationConfig
@@ -29,18 +30,17 @@
]
-# Names below are loaded lazily via ``__getattr__`` to avoid pulling in
-# onnxruntime.quantization/torch at import time. The TYPE_CHECKING re-imports
-# give static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports
-# without triggering the heavy imports at runtime.
+# ``quantize_onnx`` is loaded lazily via ``__getattr__`` to avoid pulling in
+# onnxruntime.quantization at import time. The TYPE_CHECKING re-import gives
+# static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports.
+# ``get_quant_finalizer`` is imported directly above — its module chain
+# (calibration/__init__ -> registry) is lightweight and safe at import time.
if TYPE_CHECKING:
- from .calibration import get_quant_finalizer
from .quantizer import quantize_onnx
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
"quantize_onnx": (".quantizer", "quantize_onnx"),
- "get_quant_finalizer": (".calibration", "get_quant_finalizer"),
}
diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py
index 39b9543c5..213199811 100644
--- a/src/winml/modelkit/quant/calibration/base.py
+++ b/src/winml/modelkit/quant/calibration/base.py
@@ -38,3 +38,4 @@ def finalize(
model_id: str | None = None,
) -> WinMLQuantizationConfig:
"""Return ``quant`` populated with the graph-derived quant settings."""
+ ...
diff --git a/src/winml/modelkit/session/__init__.py b/src/winml/modelkit/session/__init__.py
index 5148da0b3..d11673961 100644
--- a/src/winml/modelkit/session/__init__.py
+++ b/src/winml/modelkit/session/__init__.py
@@ -5,6 +5,13 @@
"""WinMLSession - ONNX Runtime session manager with WinML EP integration."""
from .ep_registry import WinMLEPRegistry
+from .genai_session import (
+ GenaiLoadError,
+ GenaiNotInstalledError,
+ GenaiSession,
+ GenaiSessionError,
+ GenerationConfig,
+)
from .monitor.ep_monitor import EPMonitor, NullEPMonitor
from .monitor.hw_monitor import HWMonitor
from .monitor.openvino_monitor import OpenVinoMonitor
@@ -17,6 +24,11 @@
__all__ = [
"EPMonitor",
+ "GenaiLoadError",
+ "GenaiNotInstalledError",
+ "GenaiSession",
+ "GenaiSessionError",
+ "GenerationConfig",
"HWMonitor",
"InferenceError",
"NullEPMonitor",
diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py
new file mode 100644
index 000000000..312a7358c
--- /dev/null
+++ b/src/winml/modelkit/session/genai_session.py
@@ -0,0 +1,680 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""GenaiSession — onnxruntime-genai session for multi-model decoder pipelines.
+
+Manages ``og.Model`` + ``og.Generator`` lifecycle for autoregressive text
+generation. Reuses :class:`WinMLEPRegistry` for EP discovery and registration
+so EPs are downloaded / registered at most once per process.
+
+Unlike :class:`WinMLSession` (which wraps ``ort.InferenceSession`` for
+single-shot inference), ``GenaiSession`` drives a streaming token-by-token
+generation loop. The two classes are peers — neither inherits from the other.
+
+Bundle directory layout expected by ``onnxruntime-genai``::
+
+ /
+ genai_config.json ← required; controls pipeline & search
+ ctx.onnx ← prefill transformer graph
+ iter.onnx ← decode transformer graph
+ embeddings.onnx ← embedding lookup
+ lm_head.onnx ← logit projection
+ tokenizer.json ← HF tokenizer files
+ tokenizer_config.json
+ ...
+
+Usage::
+
+ # Context manager (recommended — auto-loads and unloads)
+ with GenaiSession("out/qwen3_bundle", ep="qnn") as session:
+ for token_str in session.generate_streaming("Hello, who are you?"):
+ print(token_str, end="", flush=True)
+
+ # Manual lifecycle
+ session = GenaiSession("out/qwen3_bundle", ep="cpu")
+ session.load()
+ result = session.generate("What is a transformer?")
+ session.unload()
+
+Dependencies::
+
+ pip install onnxruntime-genai-winml
+ pip install "windowsml[with-ort]" # registers QNN EP; also provides ORT
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import shutil
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+from .ep_registry import WinMLEPRegistry
+
+
+if TYPE_CHECKING:
+ from collections.abc import Iterator
+
+
+logger = logging.getLogger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# Module-level compilation worker (must be at module scope for multiprocessing
+# spawn on Windows, which serialises the target via pickle).
+# ---------------------------------------------------------------------------
+
+
+def _qnn_compile_worker(src: str, dst: str, qnn_options: dict) -> None:
+ """Compile *src* ONNX to an EPContext ONNX at *dst* using QNN HTP.
+
+ Executed in a subprocess by :meth:`GenaiSession._compile_stage`.
+ """
+ import onnxruntime as ort
+
+ from winml.modelkit.session.ep_registry import WinMLEPRegistry
+ from winml.modelkit.winml import add_ep_for_device
+
+ registry = WinMLEPRegistry.get_instance()
+ registry.register_execution_providers()
+ so = ort.SessionOptions()
+ so.add_session_config_entry("ep.context_enable", "1")
+ so.add_session_config_entry("ep.context_file_path", dst)
+ add_ep_for_device(so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU, qnn_options)
+ mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False)
+ mc.compile_to_file(dst)
+
+
+# ---------------------------------------------------------------------------
+# Valid EP short names.
+# "mixed" = use genai_config.json as-is (embeddings/lm_head on CPU,
+# ctx/iter on the target accelerator).
+# EP routing is driven entirely by per-stage session_options in the bundle's
+# genai_config.json — GenaiSession never calls clear_providers/append_provider.
+# ---------------------------------------------------------------------------
+_VALID_EPS: frozenset[str] = frozenset({"cpu", "mixed", "qnn", "dml"})
+# EPs that require WinML EP discovery + registration before og.Model() init.
+_NEEDS_WINML_EPS: frozenset[str] = frozenset({"mixed", "qnn", "dml"})
+
+
+# ---------------------------------------------------------------------------
+# Data classes
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class GenerationConfig:
+ """Search / sampling parameters for a single generation call.
+
+ All parameters are forwarded to ``og.GeneratorParams.set_search_options``.
+ ``max_length`` is **not** configurable here — it is set to the bundle's
+ ``context_length`` (read from ``genai_config.json``) because the static KV
+ cache size is baked into the ONNX graphs at export time.
+
+ Attributes:
+ max_new_tokens: Soft cap on the number of new tokens to generate.
+ Generation stops when the model signals EOS, when the KV buffer is
+ exhausted (``context_length``), or when this limit is reached,
+ whichever comes first.
+ do_sample: Enable sampling (``True``) vs greedy (``False``).
+ temperature: Sampling temperature. Ignored when ``do_sample=False``.
+ top_p: Nucleus sampling probability mass. Ignored when
+ ``do_sample=False``.
+ top_k: Top-K sampling. ``0`` disables the filter. Ignored when
+ ``do_sample=False``.
+ repetition_penalty: Multiplicative penalty for repeated tokens
+ (``1.0`` = no penalty).
+ """
+
+ max_new_tokens: int = 128
+ do_sample: bool = False
+ temperature: float = 1.0
+ top_p: float = 1.0
+ top_k: int = 0
+ repetition_penalty: float = 1.0
+
+
+# ---------------------------------------------------------------------------
+# Exceptions
+# ---------------------------------------------------------------------------
+
+
+class GenaiSessionError(Exception):
+ """Base exception for GenaiSession."""
+
+
+class GenaiNotInstalledError(GenaiSessionError):
+ """``onnxruntime-genai`` (or ``onnxruntime-genai-winml``) is not installed."""
+
+
+class GenaiLoadError(GenaiSessionError):
+ """The bundle could not be loaded (bad config, EP unavailable, etc.)."""
+
+
+# ---------------------------------------------------------------------------
+# Session
+# ---------------------------------------------------------------------------
+
+
+class GenaiSession:
+ """ORT GenAI session for multi-model decoder-pipeline inference.
+
+ Wraps ``og.Model`` + ``og.Generator`` to provide a clean generation API.
+
+ The session is **stateless across calls**: each :meth:`generate_streaming`
+ call creates a fresh ``og.Generator`` so KV state does not persist between
+ prompts. Thread-safety within a single session is not guaranteed.
+
+ Args:
+ bundle_dir: Path to the genai bundle directory. Must contain
+ ``genai_config.json`` and the ONNX files it references.
+ ep: Execution provider short name (``"cpu"``, ``"qnn"``, ``"dml"``).
+ Non-CPU EPs trigger WinML EP discovery and registration.
+ context_length: Override for the static KV cache length. When
+ ``None`` (default), read from ``genai_config.json``.
+ Must match the ``--max-cache-len`` used during the winml-cli build.
+ verbose: Enable ``onnxruntime-genai`` native model I/O logging.
+ compile: Pre-compile QNN pipeline stages to EPContext ONNX on first
+ run (inside ``bundle_dir/_compiled/``). Subsequent calls reuse
+ the cached EPContext files, eliminating per-run JIT overhead.
+ Only stages that can be compiled without hanging are attempted;
+ stages that fail compilation fall back to the original ONNX.
+ Has no effect when ``ep="cpu"``.
+ """
+
+ # Sub-directory within the bundle that holds pre-compiled EPContext ONNX files.
+ _COMPILED_SUBDIR: str = "_compiled"
+
+ def __init__(
+ self,
+ bundle_dir: str | Path,
+ ep: str = "cpu",
+ *,
+ context_length: int | None = None,
+ verbose: bool = False,
+ compile: bool = False,
+ ) -> None:
+ self._bundle_dir = Path(bundle_dir)
+ self._ep = ep.lower()
+ self._context_length_override = context_length
+ self._verbose = verbose
+ self._compile = compile
+
+ # Resolved at load() time.
+ self._context_length: int | None = None
+
+ # og.* handles — None until load() is called.
+ self._model: object | None = None
+ self._tokenizer: object | None = None
+
+ if not self._bundle_dir.exists():
+ raise FileNotFoundError(f"Bundle directory not found: {self._bundle_dir}")
+ config_path = self._bundle_dir / "genai_config.json"
+ if not config_path.exists():
+ raise FileNotFoundError(
+ f"genai_config.json not found in {self._bundle_dir}. "
+ "Run export_qwen3_transformer_only.py --genai-bundle first."
+ )
+ if self._ep not in _VALID_EPS:
+ raise ValueError(f"Unknown EP {ep!r}. Supported: {sorted(_VALID_EPS)}")
+
+ logger.info("GenaiSession initialized: bundle=%s ep=%s", self._bundle_dir, self._ep)
+
+ # ------------------------------------------------------------------
+ # Lifecycle
+ # ------------------------------------------------------------------
+
+ def load(self) -> None:
+ """Load ``og.Model`` and tokenizer from the bundle directory.
+
+ Idempotent: calling ``load()`` on an already-loaded session is a no-op.
+
+ Raises:
+ GenaiNotInstalledError: ``onnxruntime_genai`` is not installed.
+ GenaiLoadError: The model could not be loaded (EP error, bad config,
+ missing ONNX files, …).
+ """
+ if self._model is not None:
+ return
+
+ og = self._import_og()
+
+ # Register WinML EPs to ORT GenAI when the bundle may use a hardware EP.
+ if self._ep in _NEEDS_WINML_EPS:
+ self._register_eps(og)
+
+ if self._verbose:
+ og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True)
+
+ # Determine which bundle directory og.Config should load from.
+ load_dir = self._bundle_dir
+ if self._compile and self._ep in _NEEDS_WINML_EPS:
+ load_dir = self._prepare_compiled_bundle()
+
+ try:
+ config = og.Config(str(load_dir))
+ # EP routing is driven entirely by genai_config.json (per-stage
+ # session_options). Do NOT call clear_providers/append_provider —
+ # those only touch the top-level provider and cannot override
+ # per-stage session_options already embedded in the pipeline config.
+ self._model = og.Model(config)
+ self._tokenizer = og.Tokenizer(self._model)
+ except Exception as exc:
+ self._model = None
+ self._tokenizer = None
+ raise GenaiLoadError(f"Failed to load genai bundle from {load_dir}: {exc}") from exc
+
+ self._context_length = self._context_length_override or self._read_context_length()
+ logger.info(
+ "GenaiSession loaded: ep=%s context_length=%d",
+ self._ep,
+ self._context_length,
+ )
+
+ def unload(self) -> None:
+ """Release ``og.Model`` and tokenizer handles.
+
+ Safe to call on an unloaded session.
+ """
+ self._model = None
+ self._tokenizer = None
+ self._context_length = None
+ logger.info("GenaiSession unloaded: bundle=%s", self._bundle_dir)
+
+ def __enter__(self) -> GenaiSession:
+ self.load()
+ return self
+
+ def __exit__(self, *_: object) -> None:
+ self.unload()
+
+ # ------------------------------------------------------------------
+ # Generation
+ # ------------------------------------------------------------------
+
+ def generate(
+ self,
+ prompt: str | list[int],
+ config: GenerationConfig | None = None,
+ ) -> str:
+ """Generate text and return the full response as a single string.
+
+ This is a convenience wrapper around :meth:`generate_streaming`.
+
+ Args:
+ prompt: Input text (auto-encoded) or a pre-encoded token-ID list.
+ config: Generation parameters. Uses :class:`GenerationConfig`
+ defaults when ``None``.
+
+ Returns:
+ The generated text (not including the prompt).
+ """
+ return "".join(self.generate_streaming(prompt, config))
+
+ def generate_streaming(
+ self,
+ prompt: str | list[int],
+ config: GenerationConfig | None = None,
+ ) -> Iterator[str]:
+ """Generate text token-by-token, yielding decoded token strings.
+
+ The method auto-loads the session on the first call (lazy-load
+ equivalent of :meth:`load`).
+
+ Each yield is the decoded string for a single new token. Callers
+ typically ``print(token, end="", flush=True)`` to stream output.
+
+ Args:
+ prompt: Input text (auto-encoded via the bundle tokenizer) or a
+ pre-encoded token-ID list. Pass a pre-formatted string when
+ chat templates or special tokens are needed — the session is
+ not aware of any particular model's template format.
+ config: Generation parameters. Uses :class:`GenerationConfig`
+ defaults when ``None``.
+
+ Yields:
+ Decoded string for each newly generated token.
+ """
+ self._ensure_loaded()
+ og = self._import_og()
+ cfg = config or GenerationConfig()
+
+ tokens = (
+ self._tokenizer.encode(prompt) # type: ignore[union-attr]
+ if isinstance(prompt, str)
+ else prompt
+ )
+
+ params = og.GeneratorParams(self._model)
+ params.set_search_options(
+ max_length=self._context_length,
+ do_sample=cfg.do_sample,
+ temperature=cfg.temperature,
+ top_p=cfg.top_p,
+ top_k=cfg.top_k,
+ repetition_penalty=cfg.repetition_penalty,
+ )
+
+ generator = og.Generator(self._model, params)
+ generator.append_tokens(tokens)
+
+ stream = self._tokenizer.create_stream() # type: ignore[union-attr]
+ n = 0
+ while not generator.is_done():
+ generator.generate_next_token()
+ new_token = generator.get_next_tokens()[0]
+ yield stream.decode(new_token)
+ n += 1
+ if n >= cfg.max_new_tokens:
+ break
+ # ``generator`` (og.Generator) holds the KV cache buffer; releasing the
+ # reference here (end of scope) frees it before the caller processes the
+ # last yielded token, which is earlier than waiting for GC.
+
+ # ------------------------------------------------------------------
+ # Chat-template helpers
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def apply_chatml_template(
+ prompt: str,
+ system: str | None = None,
+ ) -> str:
+ r"""Wrap *prompt* in the ChatML format used by Qwen2/3, Yi, Mistral, etc.
+
+ Produces::
+
+ <|im_start|>system
+ {system}<|im_end|>
+ <|im_start|>user
+ {prompt}<|im_end|>
+ <|im_start|>assistant
+
+ The trailing ``<|im_start|>assistant\\n`` primes the model to respond
+ as the assistant role with no leading newline in its output.
+
+ Args:
+ prompt: The user turn text.
+ system: Optional system prompt. When ``None`` no system turn is
+ prepended.
+
+ Returns:
+ Formatted string ready to pass to :meth:`generate` /
+ :meth:`generate_streaming`.
+ """
+ parts: list[str] = []
+ if system is not None:
+ parts.append(f"<|im_start|>system\n{system}<|im_end|>\n")
+ parts.append(f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n")
+ return "".join(parts)
+
+ # ------------------------------------------------------------------
+ # Tokenizer helpers
+ # ------------------------------------------------------------------
+
+ def encode(self, text: str) -> list[int]:
+ """Encode *text* to a list of token IDs using the bundle tokenizer."""
+ self._ensure_loaded()
+ return self._tokenizer.encode(text).tolist() # type: ignore[union-attr]
+
+ def decode(self, tokens: list[int]) -> str:
+ """Decode a list of token IDs to a string."""
+ self._ensure_loaded()
+ return self._tokenizer.decode(tokens) # type: ignore[union-attr]
+
+ # ------------------------------------------------------------------
+ # Properties
+ # ------------------------------------------------------------------
+
+ @property
+ def is_loaded(self) -> bool:
+ """``True`` if the model is loaded and ready for generation."""
+ return self._model is not None
+
+ @property
+ def bundle_dir(self) -> Path:
+ """Path to the genai bundle directory."""
+ return self._bundle_dir
+
+ @property
+ def ep(self) -> str:
+ """Execution provider short name (as passed to ``__init__``)."""
+ return self._ep
+
+ @property
+ def context_length(self) -> int | None:
+ """Static KV cache length, populated after :meth:`load`."""
+ return self._context_length
+
+ # ------------------------------------------------------------------
+ # Private helpers
+ # ------------------------------------------------------------------
+
+ def _ensure_loaded(self) -> None:
+ if self._model is None:
+ self.load()
+
+ def _prepare_compiled_bundle(self) -> Path:
+ """Create (or reuse) a *compiled* bundle directory.
+
+ Reads ``genai_config.json``, finds QNN-accelerated stages (those with
+ ``QNNExecutionProvider`` in their ``session_options``), and tries to
+ compile their ONNX to EPContext format using ``ort.ModelCompiler``.
+
+ The compiled bundle is stored under ``bundle_dir/_compiled/``. On
+ every call the helper checks whether the cached EPContext file is
+ newer than the source ONNX; if so, it skips recompilation.
+
+ Returns:
+ Path to the compiled bundle directory (may equal ``bundle_dir``
+ if no compilable stages were found, or if all compilations failed).
+ """
+ compiled_dir = self._bundle_dir / self._COMPILED_SUBDIR
+ config_src = self._bundle_dir / "genai_config.json"
+ cfg = json.loads(config_src.read_text(encoding="utf-8"))
+
+ # Collect pipeline stages that use a QNN EP ("qnn" key in provider_options).
+ # genai_config pipeline entries: [{"context": {...}}, {"iterator": {...}}, ...]
+ # provider_options format: [{"qnn": {...}}]
+ pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", [])
+ # [(stage_key, onnx_filename, qnn_opts), ...]
+ qnn_stages: list[tuple[str, str, dict]] = []
+ for stage_entry in pipeline_list:
+ if not isinstance(stage_entry, dict):
+ continue
+ for stage_key, stage_cfg in stage_entry.items():
+ if not isinstance(stage_cfg, dict):
+ continue
+ so = stage_cfg.get("session_options", {})
+ providers = so.get("provider_options", [])
+ for p in providers:
+ if isinstance(p, dict) and "qnn" in p:
+ onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx")
+ qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"])))
+ break
+
+ if not qnn_stages:
+ logger.info("No QNN stages found in genai_config.json; skipping compilation")
+ return self._bundle_dir
+
+ compiled_dir.mkdir(exist_ok=True)
+ modified_cfg = json.loads(config_src.read_text(encoding="utf-8"))
+ any_compiled = False
+
+ for stage_key, onnx_filename, qnn_opts in qnn_stages:
+ src_onnx = self._bundle_dir / onnx_filename
+ ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx"
+
+ # Skip recompilation if cache is up-to-date.
+ if ctx_onnx.exists() and ctx_onnx.stat().st_mtime >= src_onnx.stat().st_mtime:
+ logger.info("Stage %r: reusing cached EPContext %s", stage_key, ctx_onnx.name)
+ # Use just the filename — genai_config.json lives in compiled_dir,
+ # so ort-genai resolves filenames relative to compiled_dir.
+ self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name)
+ any_compiled = True
+ continue
+
+ # Attempt compilation.
+ success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts)
+ if success:
+ self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name)
+ any_compiled = True
+ else:
+ logger.warning(
+ "Stage %r: compilation failed; using original ONNX (JIT fallback)", stage_key
+ )
+
+ if not any_compiled:
+ return self._bundle_dir
+
+ # Write the modified genai_config into the compiled sub-directory.
+ # ONNX filenames are relative to compiled_dir; ort-genai resolves them
+ # from the directory it loads og.Config from.
+ compiled_config = compiled_dir / "genai_config.json"
+ compiled_config.write_text(
+ json.dumps(modified_cfg, indent=2, ensure_ascii=False), encoding="utf-8"
+ )
+ self._mirror_non_onnx_files(compiled_dir)
+
+ logger.info("Compiled bundle prepared at %s", compiled_dir)
+ return compiled_dir
+
+ @staticmethod
+ def _patch_stage_filename(cfg: dict, stage_key: str, abs_path: str) -> None:
+ """Rewrite a pipeline stage's ``filename`` to an absolute path."""
+ pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", [])
+ for stage_entry in pipeline_list:
+ if isinstance(stage_entry, dict) and stage_key in stage_entry:
+ stage_cfg = stage_entry[stage_key]
+ if isinstance(stage_cfg, dict):
+ stage_cfg["filename"] = abs_path
+ return
+
+ def _compile_stage(
+ self,
+ src_onnx: Path,
+ ctx_out: Path,
+ stage_key: str,
+ qnn_opts: dict | None = None,
+ ) -> bool:
+ """Compile *src_onnx* to EPContext format via ``ort.ModelCompiler``.
+
+ Runs in a subprocess so that a ModelCompiler failure does not block
+ the caller. The QNN options from ``genai_config.json`` are forwarded
+ unchanged to the compilation session, so each stage is compiled at
+ exactly the optimization level configured in the bundle.
+
+ Args:
+ src_onnx: Source ONNX file path.
+ ctx_out: Destination EPContext ONNX path.
+ stage_key: Human-readable label for logging.
+ qnn_opts: QNN provider options from genai_config (e.g. backend_path,
+ htp_performance_mode, htp_graph_finalization_optimization_mode,
+ soc_model).
+
+ Returns:
+ ``True`` if compilation succeeded; ``False`` on timeout or error.
+ """
+ import multiprocessing
+
+ compile_qnn_opts = dict(qnn_opts or {})
+ compile_timeout_s = 300 # 5 minutes; ctx compiles in ~73s, iter in ~67s
+
+ logger.info(
+ "Compiling stage %r: %s → %s (qnn_opts=%s)",
+ stage_key,
+ src_onnx.name,
+ ctx_out.name,
+ compile_qnn_opts,
+ )
+
+ ctx = multiprocessing.get_context("spawn")
+ proc = ctx.Process(
+ target=_qnn_compile_worker, args=(str(src_onnx), str(ctx_out), compile_qnn_opts)
+ )
+ proc.start()
+ proc.join(timeout=compile_timeout_s)
+
+ if proc.is_alive():
+ logger.error(
+ "Stage %r compilation timed out after %ds — killing subprocess.",
+ stage_key,
+ compile_timeout_s,
+ )
+ proc.kill()
+ proc.join()
+ ctx_out.unlink(missing_ok=True)
+ return False
+
+ if proc.exitcode != 0:
+ logger.warning("Stage %r compilation failed (exit %d)", stage_key, proc.exitcode)
+ ctx_out.unlink(missing_ok=True)
+ return False
+
+ logger.info("Stage %r compiled successfully → %s", stage_key, ctx_out)
+ return True
+
+ def _mirror_non_onnx_files(self, compiled_dir: Path) -> None:
+ """Create symlinks (or copies on Windows) for every non-ONNX file.
+
+ Files are linked/copied into *compiled_dir* so that ``og.Config``
+ finds tokenizer files, specials maps, etc. Existing files are left
+ untouched.
+ """
+ for src in self._bundle_dir.iterdir():
+ if src.name == self._COMPILED_SUBDIR:
+ continue
+ dst = compiled_dir / src.name
+ if dst.exists():
+ continue
+ if src.is_file():
+ try:
+ dst.symlink_to(src.resolve())
+ except (OSError, NotImplementedError):
+ # Symlinks may require elevated privileges on Windows; fall back to copy.
+ shutil.copy2(src, dst)
+
+ @staticmethod
+ def _import_og() -> object:
+ """Import and return the ``onnxruntime_genai`` module.
+
+ Raises:
+ GenaiNotInstalledError: Package not found.
+ """
+ try:
+ import onnxruntime_genai as og
+
+ return og
+ except ImportError as exc:
+ raise GenaiNotInstalledError(
+ "onnxruntime_genai is not installed. "
+ "Install it with: pip install onnxruntime-genai-winml"
+ ) from exc
+
+ def _register_eps(self, og: object) -> None:
+ """Register WinML EPs with ORT GenAI (idempotent, best-effort)."""
+ try:
+ registry = WinMLEPRegistry.get_instance()
+ if registry.winml_available:
+ result = registry.register_execution_providers(ort_genai=True)
+ registered = result.get("onnxruntime_genai", [])
+ logger.info("WinML EPs registered for ORT GenAI: %s", registered)
+ except Exception as exc:
+ logger.warning("WinML EP registration skipped: %s", exc)
+
+ def _read_context_length(self) -> int:
+ """Read ``model.context_length`` from ``genai_config.json``."""
+ cfg = json.loads((self._bundle_dir / "genai_config.json").read_text(encoding="utf-8"))
+ return int(cfg["model"]["context_length"])
+
+
+__all__ = [
+ "GenaiLoadError",
+ "GenaiNotInstalledError",
+ "GenaiSession",
+ "GenaiSessionError",
+ "GenerationConfig",
+]
diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py
new file mode 100644
index 000000000..33645cdaf
--- /dev/null
+++ b/src/winml/modelkit/utils/genai.py
@@ -0,0 +1,663 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+r"""Generic onnxruntime-genai bundle utilities for decoder-pipeline models.
+
+The bundle is a directory that ``onnxruntime-genai`` can load directly via
+``og.Config(str(bundle_dir))``. It contains:
+
+ genai_config.json — pipeline config consumed by onnxruntime-genai
+ ctx.onnx — prefill/context ONNX
+ iter.onnx — iteration/decode ONNX
+ embeddings.onnx — embedding-lookup ONNX
+ lm_head.onnx — lm_head ONNX
+ tokenizer.json — HF tokenizer files (downloaded from the model repo)
+ tokenizer_config.json
+ vocab.json / merges.txt / generation_config.json
+
+The pipeline follows the standard 4-stage decoder layout:
+
+ input_ids → [embeddings] → input_hidden_states
+ → [context | iterator] → output_hidden_states + present KVs
+ → [lm_head] → logits
+
+The context stage runs on the prompt (prefill); the iterator stage runs on each
+subsequent decode step. Both share the same KV cache buffer via genai's
+``past_present_share_buffer`` mode.
+
+Public API::
+
+ from winml.modelkit.utils.genai import (
+ build_genai_config,
+ build_decoder_pipeline_stages,
+ write_genai_bundle,
+ DecoderIOMapping,
+ PipelineStage,
+ qnn_stage_session_options,
+ )
+
+ # Build stages by introspecting the ONNX I/O (no hardcoded tensor names)
+ stages, decoder_io = build_decoder_pipeline_stages(
+ ctx_path, iter_path, num_layers=hf_config.num_hidden_layers, ep="qnn"
+ )
+ cfg = build_genai_config(
+ hf_config, max_cache_len=256, prefill_seq_len=64,
+ pipeline=stages, decoder_io=decoder_io,
+ )
+
+ # Or one-shot bundle assembly
+ write_genai_bundle(
+ Path("out/bundle"),
+ context_onnx=ctx_path,
+ iterator_onnx=iter_path,
+ model_id="Qwen/Qwen3-0.6B",
+ max_cache_len=256,
+ prefill_seq_len=64,
+ embeddings_src=emb_path, # None = skip (add later)
+ lm_head_src=lmh_path, # None = skip (add later)
+ ep="qnn",
+ )
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+
+logger = logging.getLogger(__name__)
+
+# Default filenames inside the bundle directory.
+DEFAULT_EMBEDDINGS_FILENAME = "embeddings.onnx"
+DEFAULT_CONTEXT_FILENAME = "ctx.onnx"
+DEFAULT_ITERATOR_FILENAME = "iter.onnx"
+DEFAULT_LM_HEAD_FILENAME = "lm_head.onnx"
+
+# Regex for detecting indexed tensor names such as ``past_keys_3``.
+_KV_INDEXED_RE = re.compile(r"^(.+?)(\d+)$")
+
+
+# ---------------------------------------------------------------------------
+# Pipeline data structures
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class PipelineStage:
+ """One stage in an onnxruntime-genai multi-model pipeline.
+
+ Attributes:
+ name: Stage key used inside the ``pipeline`` list of ``genai_config.json``.
+ filename: ONNX filename inside the bundle directory.
+ run_on_prompt: Whether genai runs this stage during the prefill pass.
+ run_on_token_gen: Whether genai runs this stage during decode steps.
+ inputs: Actual ONNX input tensor names (not format strings).
+ outputs: Actual ONNX output tensor names (not format strings).
+ is_lm_head: Set ``True`` for the final language-model head stage.
+ """
+
+ name: str
+ filename: str
+ run_on_prompt: bool
+ run_on_token_gen: bool
+ inputs: list[str]
+ outputs: list[str]
+ is_lm_head: bool = False
+ session_options: dict | None = None
+ """Per-stage ORT session options (e.g. provider_options for QNN).
+
+ When set, emitted verbatim as the ``session_options`` key in the
+ ``genai_config.json`` pipeline stage. Leave ``None`` (default) for
+ stages that should run on the default (CPU) provider.
+ """
+
+ def to_dict(self) -> dict:
+ """Serialize to the dict format expected by ``genai_config.json``."""
+ d: dict = {
+ "filename": self.filename,
+ "inputs": list(self.inputs),
+ "outputs": list(self.outputs),
+ "run_on_prompt": self.run_on_prompt,
+ "run_on_token_gen": self.run_on_token_gen,
+ }
+ if self.session_options:
+ d["session_options"] = self.session_options
+ if self.is_lm_head:
+ d["is_lm_head"] = True
+ return d
+
+
+@dataclass
+class DecoderIOMapping:
+ """Maps genai's abstract I/O concepts to ONNX tensor name format strings.
+
+ The ``*_names`` fields use ``%d`` as the layer-index placeholder, which is
+ the convention genai uses to expand per-layer KV cache tensor names
+ (e.g. ``"past_keys_%d"`` → ``"past_keys_0"``, ``"past_keys_1"``, …).
+
+ Defaults match the Qwen3 transformer-only export naming; override any field
+ when building bundles for models with different tensor names.
+ """
+
+ input_ids: str = "input_ids"
+ past_sequence_length: str = "past_seq_len"
+ total_sequence_length: str = "total_seq_len"
+ past_key_names: str = "past_keys_%d"
+ past_value_names: str = "past_values_%d"
+ logits: str = "logits"
+ present_key_names: str = "present_keys_%d"
+ present_value_names: str = "present_values_%d"
+
+ def inputs_dict(self) -> dict:
+ """Return the ``decoder.inputs`` mapping dict for ``genai_config.json``."""
+ return {
+ "input_ids": self.input_ids,
+ "past_sequence_length": self.past_sequence_length,
+ "total_sequence_length": self.total_sequence_length,
+ "past_key_names": self.past_key_names,
+ "past_value_names": self.past_value_names,
+ }
+
+ def outputs_dict(self) -> dict:
+ """Return the ``decoder.outputs`` mapping dict for ``genai_config.json``."""
+ return {
+ "logits": self.logits,
+ "present_key_names": self.present_key_names,
+ "present_value_names": self.present_value_names,
+ }
+
+
+# ---------------------------------------------------------------------------
+# Generic config builder
+# ---------------------------------------------------------------------------
+
+
+def build_genai_config(
+ hf_config: Any,
+ *,
+ max_cache_len: int,
+ prefill_seq_len: int | None = None,
+ pipeline: list[PipelineStage],
+ decoder_io: DecoderIOMapping | None = None,
+) -> dict:
+ """Build a ``genai_config.json`` dict for any decoder-pipeline model.
+
+ This function is architecture-agnostic: the caller supplies the pipeline
+ stages and the I/O name mapping so no tensor names are hardcoded here.
+
+ Args:
+ hf_config: A ``transformers.PretrainedConfig``. Reads:
+ ``num_hidden_layers``, ``hidden_size``, ``num_attention_heads``,
+ ``num_key_value_heads``, ``head_dim`` (optional, falls back to
+ ``hidden_size // num_attention_heads``), ``bos_token_id``,
+ ``eos_token_id``, ``pad_token_id``, ``vocab_size``.
+ max_cache_len: Static KV cache length → ``context_length`` and
+ ``search.max_length``.
+ prefill_seq_len: When given, emits a ``sliding_window`` section with
+ ``window_size=prefill_seq_len``. Pass ``None`` to omit.
+ pipeline: Ordered list of :class:`PipelineStage` describing each
+ model in the genai pipeline.
+ decoder_io: Format-string mapping from genai's abstract I/O names to
+ actual ONNX tensor names. Defaults to
+ :class:`DecoderIOMapping` (the standard names).
+
+ Returns:
+ A ``dict`` suitable for ``json.dumps`` as ``genai_config.json``.
+ """
+ if decoder_io is None:
+ decoder_io = DecoderIOMapping()
+
+ num_layers: int = hf_config.num_hidden_layers
+ head_size: int = getattr(
+ hf_config,
+ "head_dim",
+ hf_config.hidden_size // hf_config.num_attention_heads,
+ )
+
+ eos_token_id = hf_config.eos_token_id
+ if isinstance(eos_token_id, list):
+ eos_token_id = eos_token_id[0]
+
+ pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id
+
+ decoder_section: dict = {
+ "hidden_size": hf_config.hidden_size,
+ "num_attention_heads": hf_config.num_attention_heads,
+ "num_key_value_heads": hf_config.num_key_value_heads,
+ "num_hidden_layers": num_layers,
+ "head_size": head_size,
+ }
+
+ if prefill_seq_len is not None:
+ decoder_section["sliding_window"] = {
+ "window_size": prefill_seq_len,
+ "pad_value": 0,
+ "alignment": "left",
+ "slide_inputs": True,
+ "slide_key_value_cache": False,
+ }
+
+ decoder_section["inputs"] = decoder_io.inputs_dict()
+ decoder_section["outputs"] = decoder_io.outputs_dict()
+ decoder_section["pipeline"] = [{s.name: s.to_dict()} for s in pipeline]
+
+ return {
+ "model": {
+ "type": "decoder-pipeline",
+ "bos_token_id": hf_config.bos_token_id,
+ "eos_token_id": eos_token_id,
+ "pad_token_id": pad_token_id,
+ "vocab_size": hf_config.vocab_size,
+ "context_length": max_cache_len,
+ "decoder": decoder_section,
+ },
+ "search": {
+ "max_length": max_cache_len,
+ "min_length": 0,
+ "do_sample": False,
+ "past_present_share_buffer": True,
+ },
+ }
+
+
+# ---------------------------------------------------------------------------
+# ONNX introspection helpers
+# ---------------------------------------------------------------------------
+
+
+def _introspect_onnx_io(onnx_path: Path) -> tuple[list[str], list[str]]:
+ """Return ``(input_names, output_names)`` from an ONNX model graph header.
+
+ External data is intentionally not loaded — only the graph topology is read,
+ so this is fast even for large quantized models.
+ """
+ try:
+ import onnx
+ except ImportError as exc:
+ raise ImportError(
+ "The 'onnx' package is required for ONNX introspection. "
+ "Install it with: pip install onnx"
+ ) from exc
+ model = onnx.load(str(onnx_path), load_external_data=False)
+ return (
+ [inp.name for inp in model.graph.input],
+ [out.name for out in model.graph.output],
+ )
+
+
+def _detect_format_patterns(names: list[str], num_layers: int) -> dict[str, str]:
+ """Detect ``prefix%d`` patterns from a list of indexed tensor names.
+
+ Scans *names* for entries matching ```` where exactly
+ *num_layers* consecutive zero-based indices are present.
+
+ Returns:
+ ``{prefix: "prefix%d"}`` for each qualifying group, in the order the
+ prefixes first appear in *names*. Only groups covering the full
+ ``[0, num_layers)`` index range are returned.
+
+ Examples::
+
+ >>> _detect_format_patterns(
+ ... ["past_keys_0", "past_keys_1", "past_values_0", "past_values_1"],
+ ... num_layers=2,
+ ... )
+ {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"}
+ """
+ groups: dict[str, list[int]] = {}
+ for name in names:
+ m = _KV_INDEXED_RE.match(name)
+ if m:
+ prefix, idx = m.group(1), int(m.group(2))
+ groups.setdefault(prefix, []).append(idx)
+
+ return {
+ prefix: f"{prefix}%d"
+ for prefix, indices in groups.items()
+ if len(indices) == num_layers and sorted(indices) == list(range(num_layers))
+ }
+
+
+def _sort_patterns_by_first_occurrence(patterns: dict[str, str], names: list[str]) -> list[str]:
+ """Sort *patterns* keys by when ``0`` first appears in *names*."""
+
+ def _key(prefix: str) -> int:
+ try:
+ return names.index(f"{prefix}0")
+ except ValueError:
+ return len(names)
+
+ return sorted(patterns.keys(), key=_key)
+
+
+# ---------------------------------------------------------------------------
+# Per-EP stage session_options helpers
+# ---------------------------------------------------------------------------
+
+
+def qnn_stage_session_options(log_id: str, soc_model: str = "60") -> dict:
+ """Return the ``session_options`` block that routes a stage to QNN HTP.
+
+ Args:
+ log_id: ORT log identifier (shown in ORT logs), e.g.
+ ``"onnxruntime-genai.context"``.
+ soc_model: Snapdragon SoC model number passed to the QNN HTP backend.
+ ``"60"`` targets Snapdragon 8 Gen 3 (X Elite). Change for other
+ SoCs (e.g. ``"55"`` for 8 Gen 2, ``"73"`` for 8 Elite).
+
+ Returns:
+ Dict suitable for the ``session_options`` key of a pipeline stage in
+ ``genai_config.json``.
+ """
+ return {
+ "log_id": log_id,
+ "provider_options": [
+ {
+ "qnn": {
+ "backend_path": "QnnHtp.dll",
+ "htp_performance_mode": "burst",
+ "htp_graph_finalization_optimization_mode": "3",
+ "soc_model": soc_model,
+ }
+ }
+ ],
+ "intra_op_num_threads": 2,
+ "inter_op_num_threads": 1,
+ }
+
+
+# ---------------------------------------------------------------------------
+# Generic decoder-pipeline stage factory
+# ---------------------------------------------------------------------------
+
+
+def build_decoder_pipeline_stages(
+ context_onnx: str | Path,
+ iterator_onnx: str | Path,
+ num_layers: int,
+ *,
+ context_filename: str = DEFAULT_CONTEXT_FILENAME,
+ iterator_filename: str = DEFAULT_ITERATOR_FILENAME,
+ embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME,
+ lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME,
+ ep: str = "cpu",
+ soc_model: str = "60",
+) -> tuple[list[PipelineStage], DecoderIOMapping]:
+ """Build pipeline stages by introspecting the built ONNX models.
+
+ Reads actual tensor names from *context_onnx* and *iterator_onnx* so the
+ generated ``genai_config.json`` can never drift out of sync with the real
+ model I/O — no tensor names are hardcoded.
+
+ Args:
+ context_onnx: Path to the built prefill/context ONNX.
+ iterator_onnx: Path to the built decode/iterator ONNX.
+ num_layers: Number of transformer layers (``hf_config.num_hidden_layers``).
+ context_filename: Bundle filename for the context model.
+ iterator_filename: Bundle filename for the iterator model.
+ embeddings_filename: Bundle filename for the embeddings model.
+ lm_head_filename: Bundle filename for the lm_head model.
+ ep: Execution provider for the transformer stages. ``"qnn"`` injects
+ QNN HTP ``session_options`` into the ``context`` and ``iterator``
+ stages so they run on the NPU while ``embeddings`` and ``lm_head``
+ continue on CPU. ``"cpu"`` (default) omits ``session_options``
+ from all stages.
+ soc_model: Snapdragon SoC model number forwarded to the QNN backend
+ when ``ep="qnn"``. Default ``"60"`` targets Snapdragon 8 Gen 3.
+
+ Returns:
+ ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and
+ the :class:`DecoderIOMapping` derived from the introspected tensor names.
+ """
+ ctx_inputs, ctx_outputs = _introspect_onnx_io(Path(context_onnx))
+ iter_inputs, iter_outputs = _introspect_onnx_io(Path(iterator_onnx))
+
+ # Detect per-layer KV format-string patterns in the context model.
+ input_patterns = _detect_format_patterns(ctx_inputs, num_layers)
+ output_patterns = _detect_format_patterns(ctx_outputs, num_layers)
+
+ in_sorted = _sort_patterns_by_first_occurrence(input_patterns, ctx_inputs)
+ out_sorted = _sort_patterns_by_first_occurrence(output_patterns, ctx_outputs)
+
+ past_key_fmt = input_patterns[in_sorted[0]] if len(in_sorted) > 0 else "past_keys_%d"
+ past_val_fmt = input_patterns[in_sorted[1]] if len(in_sorted) > 1 else "past_values_%d"
+ pres_key_fmt = output_patterns[out_sorted[0]] if len(out_sorted) > 0 else "present_keys_%d"
+ pres_val_fmt = output_patterns[out_sorted[1]] if len(out_sorted) > 1 else "present_values_%d"
+
+ # Non-indexed inputs: hidden-state tensor + scalar seq-length scalars.
+ non_indexed = [n for n in ctx_inputs if not _KV_INDEXED_RE.match(n)]
+ seq_len_names = [n for n in non_indexed if re.search(r"seq|len", n, re.IGNORECASE)]
+ hidden_state_in = next(
+ (n for n in non_indexed if n not in seq_len_names), "input_hidden_states"
+ )
+ past_seq_name = next((n for n in seq_len_names if "past" in n.lower()), "past_seq_len")
+ total_seq_name = next((n for n in seq_len_names if "total" in n.lower()), "total_seq_len")
+
+ # Non-indexed output: hidden-state output of the transformer stack.
+ hidden_state_out = next(
+ (n for n in ctx_outputs if not _KV_INDEXED_RE.match(n)), "output_hidden_states"
+ )
+
+ decoder_io = DecoderIOMapping(
+ past_sequence_length=past_seq_name,
+ total_sequence_length=total_seq_name,
+ past_key_names=past_key_fmt,
+ past_value_names=past_val_fmt,
+ present_key_names=pres_key_fmt,
+ present_value_names=pres_val_fmt,
+ )
+
+ # Per-stage session_options: NPU stages get QNN config; CPU and others get None.
+ ctx_session_opts: dict | None = None
+ iter_session_opts: dict | None = None
+ if ep == "qnn":
+ ctx_session_opts = qnn_stage_session_options(
+ "onnxruntime-genai.context", soc_model=soc_model
+ )
+ iter_session_opts = qnn_stage_session_options(
+ "onnxruntime-genai.iterator", soc_model=soc_model
+ )
+
+ stages: list[PipelineStage] = [
+ PipelineStage(
+ name="embeddings",
+ filename=embeddings_filename,
+ run_on_prompt=True,
+ run_on_token_gen=True,
+ inputs=[decoder_io.input_ids],
+ outputs=[hidden_state_in],
+ ),
+ PipelineStage(
+ name="context",
+ filename=context_filename,
+ run_on_prompt=True,
+ run_on_token_gen=False,
+ inputs=ctx_inputs,
+ outputs=ctx_outputs,
+ session_options=ctx_session_opts,
+ ),
+ PipelineStage(
+ name="iterator",
+ filename=iterator_filename,
+ run_on_prompt=False,
+ run_on_token_gen=True,
+ inputs=iter_inputs,
+ outputs=iter_outputs,
+ session_options=iter_session_opts,
+ ),
+ PipelineStage(
+ name="lm_head",
+ filename=lm_head_filename,
+ run_on_prompt=True,
+ run_on_token_gen=True,
+ inputs=[hidden_state_out],
+ outputs=[decoder_io.logits],
+ is_lm_head=True,
+ ),
+ ]
+ return stages, decoder_io
+
+
+# ---------------------------------------------------------------------------
+# Bundle assembler
+# ---------------------------------------------------------------------------
+
+
+def write_genai_bundle(
+ output_dir: str | Path,
+ *,
+ context_onnx: str | Path,
+ iterator_onnx: str | Path,
+ model_id: str,
+ max_cache_len: int,
+ prefill_seq_len: int,
+ embeddings_src: str | Path | None = None,
+ lm_head_src: str | Path | None = None,
+ context_filename: str = DEFAULT_CONTEXT_FILENAME,
+ iterator_filename: str = DEFAULT_ITERATOR_FILENAME,
+ embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME,
+ lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME,
+ ep: str = "cpu",
+ soc_model: str = "60",
+) -> Path:
+ """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*.
+
+ Copies the winml-built transformer ONNX files, optional embedding /
+ lm_head models, HF tokenizer files, and writes ``genai_config.json``.
+ Tensor names in the config are derived by introspecting the built ONNX
+ files rather than being hardcoded, so this works for any model that
+ follows the 4-stage decoder-pipeline layout.
+
+ Args:
+ output_dir: Destination directory (created if absent).
+ context_onnx: Path to the built prefill/context ONNX.
+ iterator_onnx: Path to the built decode/iterator ONNX.
+ model_id: HuggingFace model ID or local path for config + tokenizer.
+ max_cache_len: Static KV cache length (= ``context_length`` in genai).
+ prefill_seq_len: Prefill sequence length (= ``sliding_window.window_size``).
+ embeddings_src: Source path of the embeddings ONNX. ``None`` = skip.
+ lm_head_src: Source path of the lm_head ONNX. ``None`` = skip.
+ context_filename: Bundle filename for the context model.
+ iterator_filename: Bundle filename for the iterator model.
+ embeddings_filename: Bundle filename for the embeddings model.
+ lm_head_filename: Bundle filename for the lm_head model.
+ ep: Execution provider for the transformer (context/iterator) stages.
+ ``"qnn"`` injects QNN HTP ``session_options`` so those stages run
+ on the NPU while embeddings and lm_head run on CPU.
+ ``"cpu"`` (default) omits ``session_options`` (all stages on CPU).
+ soc_model: Snapdragon SoC model passed to the QNN backend when
+ ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite.
+
+ Returns:
+ Path to the written ``genai_config.json``.
+ """
+ from transformers import AutoConfig, AutoTokenizer
+
+ from winml.modelkit.onnx import copy_onnx_model
+
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ context_onnx = Path(context_onnx)
+ iterator_onnx = Path(iterator_onnx)
+
+ # 1. Copy winml-built transformer ONNX files.
+ logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename)
+ copy_onnx_model(context_onnx, output_dir / context_filename)
+
+ logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename)
+ copy_onnx_model(iterator_onnx, output_dir / iterator_filename)
+
+ # 2. Copy placeholder models (embeddings + lm_head).
+ if embeddings_src is not None:
+ logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename)
+ copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename)
+ else:
+ logger.warning(
+ "embeddings_src not provided — '%s' is missing from bundle.",
+ embeddings_filename,
+ )
+
+ if lm_head_src is not None:
+ logger.info("Copying lm_head: %s -> %s", Path(lm_head_src).name, lm_head_filename)
+ copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename)
+ else:
+ logger.warning(
+ "lm_head_src not provided — '%s' is missing from bundle.",
+ lm_head_filename,
+ )
+
+ # 3. Save tokenizer files from the HF snapshot.
+ logger.info("Saving tokenizer from '%s' to %s", model_id, output_dir)
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokenizer.save_pretrained(str(output_dir))
+
+ # 4. Build pipeline stages by introspecting the source ONNX files.
+ hf_config = AutoConfig.from_pretrained(model_id)
+ stages, decoder_io = build_decoder_pipeline_stages(
+ context_onnx,
+ iterator_onnx,
+ num_layers=hf_config.num_hidden_layers,
+ context_filename=context_filename,
+ iterator_filename=iterator_filename,
+ embeddings_filename=embeddings_filename,
+ lm_head_filename=lm_head_filename,
+ ep=ep,
+ soc_model=soc_model,
+ )
+
+ # 5. Write genai_config.json.
+ config = build_genai_config(
+ hf_config,
+ max_cache_len=max_cache_len,
+ prefill_seq_len=prefill_seq_len,
+ pipeline=stages,
+ decoder_io=decoder_io,
+ )
+ config_path = output_dir / "genai_config.json"
+ config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")
+ logger.info("Wrote genai_config.json -> %s", config_path)
+
+ _log_bundle_summary(output_dir, config_path)
+ return config_path
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None:
+ """Print a human-readable summary of the assembled bundle."""
+ files = sorted(bundle_dir.iterdir())
+ lines = [f"\n=== genai bundle: {bundle_dir} ==="]
+ for f in files:
+ size_kb = f.stat().st_size / 1024
+ tag = ""
+ if f.name == "genai_config.json":
+ tag = " <- pipeline config"
+ elif f.name.endswith(".onnx"):
+ tag = " <- ONNX graph"
+ elif f.name.endswith(".data"):
+ tag = " <- ONNX external weights"
+ lines.append(f" {f.name:<45} {size_kb:>8.1f} KB{tag}")
+ lines.append(f"\nConfig written to: {config_path}")
+ logger.info("\n".join(lines))
+
+
+__all__ = [
+ "DEFAULT_CONTEXT_FILENAME",
+ "DEFAULT_EMBEDDINGS_FILENAME",
+ "DEFAULT_ITERATOR_FILENAME",
+ "DEFAULT_LM_HEAD_FILENAME",
+ "DecoderIOMapping",
+ "PipelineStage",
+ "build_decoder_pipeline_stages",
+ "build_genai_config",
+ "qnn_stage_session_options",
+ "write_genai_bundle",
+]
diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py
new file mode 100644
index 000000000..e359f02e2
--- /dev/null
+++ b/tests/unit/models/qwen3/test_genai_config.py
@@ -0,0 +1,515 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""Unit tests for the Qwen3 genai config builder."""
+
+from __future__ import annotations
+
+from types import SimpleNamespace
+from unittest.mock import patch
+
+from winml.modelkit.models.hf.qwen3.genai import (
+ DEFAULT_CONTEXT_FILENAME,
+ DEFAULT_EMBEDDINGS_FILENAME,
+ DEFAULT_ITERATOR_FILENAME,
+ DEFAULT_LM_HEAD_FILENAME,
+ DecoderIOMapping,
+ PipelineStage,
+ _detect_format_patterns,
+ build_genai_config,
+ build_qwen3_transformer_only_stages,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _mock_config(
+ *,
+ num_hidden_layers: int = 28,
+ hidden_size: int = 1024,
+ num_attention_heads: int = 16,
+ num_key_value_heads: int = 8,
+ head_dim: int = 128,
+ bos_token_id: int = 151643,
+ eos_token_id: int = 151645,
+ pad_token_id: int = 151643,
+ vocab_size: int = 151936,
+) -> SimpleNamespace:
+ """Return a minimal stand-in for a HF PretrainedConfig."""
+ return SimpleNamespace(
+ num_hidden_layers=num_hidden_layers,
+ hidden_size=hidden_size,
+ num_attention_heads=num_attention_heads,
+ num_key_value_heads=num_key_value_heads,
+ head_dim=head_dim,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ pad_token_id=pad_token_id,
+ vocab_size=vocab_size,
+ )
+
+
+def _make_pipeline(
+ num_layers: int = 28,
+ *,
+ emb_filename: str = DEFAULT_EMBEDDINGS_FILENAME,
+ ctx_filename: str = DEFAULT_CONTEXT_FILENAME,
+ iter_filename: str = DEFAULT_ITERATOR_FILENAME,
+ lmh_filename: str = DEFAULT_LM_HEAD_FILENAME,
+) -> list[PipelineStage]:
+ """Build a standard 4-stage pipeline for use in unit tests."""
+ ctx_inputs = [
+ "input_hidden_states",
+ "past_seq_len",
+ "total_seq_len",
+ *[f"past_keys_{i}" for i in range(num_layers)],
+ *[f"past_values_{i}" for i in range(num_layers)],
+ ]
+ ctx_outputs = [
+ "output_hidden_states",
+ *[f"present_keys_{i}" for i in range(num_layers)],
+ *[f"present_values_{i}" for i in range(num_layers)],
+ ]
+ return [
+ PipelineStage(
+ "embeddings", emb_filename, True, True, ["input_ids"], ["input_hidden_states"]
+ ),
+ PipelineStage("context", ctx_filename, True, False, ctx_inputs, ctx_outputs),
+ PipelineStage("iterator", iter_filename, False, True, ctx_inputs, ctx_outputs),
+ PipelineStage(
+ "lm_head",
+ lmh_filename,
+ True,
+ True,
+ ["output_hidden_states"],
+ ["logits"],
+ is_lm_head=True,
+ ),
+ ]
+
+
+# ---------------------------------------------------------------------------
+# Tests: build_genai_config
+# ---------------------------------------------------------------------------
+
+
+class TestBuildGenaiConfig:
+ def setup_method(self) -> None:
+ self.cfg = _mock_config()
+ self.result = build_genai_config(
+ self.cfg,
+ max_cache_len=256,
+ prefill_seq_len=64,
+ pipeline=_make_pipeline(),
+ )
+
+ def test_top_level_model_type(self) -> None:
+ assert self.result["model"]["type"] == "decoder-pipeline"
+
+ def test_token_ids(self) -> None:
+ m = self.result["model"]
+ assert m["bos_token_id"] == 151643
+ assert m["eos_token_id"] == 151645
+ assert m["pad_token_id"] == 151643
+ assert m["vocab_size"] == 151936
+
+ def test_context_length_equals_max_cache_len(self) -> None:
+ assert self.result["model"]["context_length"] == 256
+
+ def test_search_max_length_equals_context_length(self) -> None:
+ assert self.result["search"]["max_length"] == self.result["model"]["context_length"]
+
+ def test_search_past_present_share_buffer(self) -> None:
+ assert self.result["search"]["past_present_share_buffer"] is True
+
+ def test_decoder_architecture_params(self) -> None:
+ dec = self.result["model"]["decoder"]
+ assert dec["hidden_size"] == 1024
+ assert dec["num_attention_heads"] == 16
+ assert dec["num_key_value_heads"] == 8
+ assert dec["num_hidden_layers"] == 28
+ assert dec["head_size"] == 128
+
+ def test_sliding_window_present_when_prefill_seq_len_given(self) -> None:
+ sw = self.result["model"]["decoder"]["sliding_window"]
+ assert sw["window_size"] == 64
+ assert sw["slide_inputs"] is True
+ assert sw["slide_key_value_cache"] is False
+
+ def test_sliding_window_absent_when_prefill_seq_len_none(self) -> None:
+ result = build_genai_config(
+ self.cfg,
+ max_cache_len=256,
+ prefill_seq_len=None,
+ pipeline=_make_pipeline(),
+ )
+ assert "sliding_window" not in result["model"]["decoder"]
+
+ def test_decoder_io_tensor_names(self) -> None:
+ inputs = self.result["model"]["decoder"]["inputs"]
+ assert inputs["past_sequence_length"] == "past_seq_len"
+ assert inputs["total_sequence_length"] == "total_seq_len"
+ assert inputs["past_key_names"] == "past_keys_%d"
+ assert inputs["past_value_names"] == "past_values_%d"
+ outputs = self.result["model"]["decoder"]["outputs"]
+ assert outputs["logits"] == "logits"
+ assert outputs["present_key_names"] == "present_keys_%d"
+ assert outputs["present_value_names"] == "present_values_%d"
+
+ def test_custom_decoder_io_mapping(self) -> None:
+ custom_io = DecoderIOMapping(
+ past_key_names="k_%d",
+ past_value_names="v_%d",
+ present_key_names="pk_%d",
+ present_value_names="pv_%d",
+ )
+ result = build_genai_config(
+ self.cfg,
+ max_cache_len=256,
+ prefill_seq_len=64,
+ pipeline=_make_pipeline(),
+ decoder_io=custom_io,
+ )
+ dec_inputs = result["model"]["decoder"]["inputs"]
+ assert dec_inputs["past_key_names"] == "k_%d"
+ assert dec_inputs["past_value_names"] == "v_%d"
+ dec_outputs = result["model"]["decoder"]["outputs"]
+ assert dec_outputs["present_key_names"] == "pk_%d"
+ assert dec_outputs["present_value_names"] == "pv_%d"
+
+ def test_pipeline_has_four_stages(self) -> None:
+ pipeline = self.result["model"]["decoder"]["pipeline"]
+ assert len(pipeline) == 4
+ stage_names = [next(iter(s.keys())) for s in pipeline]
+ assert stage_names == ["embeddings", "context", "iterator", "lm_head"]
+
+ def test_embeddings_stage(self) -> None:
+ stage = self.result["model"]["decoder"]["pipeline"][0]["embeddings"]
+ assert stage["filename"] == DEFAULT_EMBEDDINGS_FILENAME
+ assert stage["inputs"] == ["input_ids"]
+ assert stage["outputs"] == ["input_hidden_states"]
+ assert stage["run_on_prompt"] is True
+ assert stage["run_on_token_gen"] is True
+
+ def test_context_stage(self) -> None:
+ stage = self.result["model"]["decoder"]["pipeline"][1]["context"]
+ assert stage["filename"] == DEFAULT_CONTEXT_FILENAME
+ assert "input_hidden_states" in stage["inputs"]
+ assert "past_seq_len" in stage["inputs"]
+ assert "total_seq_len" in stage["inputs"]
+ assert stage["run_on_prompt"] is True
+ assert stage["run_on_token_gen"] is False
+
+ def test_iterator_stage(self) -> None:
+ stage = self.result["model"]["decoder"]["pipeline"][2]["iterator"]
+ assert stage["filename"] == DEFAULT_ITERATOR_FILENAME
+ assert stage["run_on_prompt"] is False
+ assert stage["run_on_token_gen"] is True
+
+ def test_lm_head_stage(self) -> None:
+ stage = self.result["model"]["decoder"]["pipeline"][3]["lm_head"]
+ assert stage["filename"] == DEFAULT_LM_HEAD_FILENAME
+ assert stage["inputs"] == ["output_hidden_states"]
+ assert stage["outputs"] == ["logits"]
+ assert stage["is_lm_head"] is True
+ assert stage["run_on_prompt"] is True
+ assert stage["run_on_token_gen"] is True
+
+ def test_context_kv_inputs_count(self) -> None:
+ """context.inputs must include all 28 past_keys + 28 past_values."""
+ inputs = self.result["model"]["decoder"]["pipeline"][1]["context"]["inputs"]
+ past_keys = [x for x in inputs if x.startswith("past_keys_")]
+ past_values = [x for x in inputs if x.startswith("past_values_")]
+ assert len(past_keys) == 28
+ assert len(past_values) == 28
+ assert set(past_keys) == {f"past_keys_{i}" for i in range(28)}
+ assert set(past_values) == {f"past_values_{i}" for i in range(28)}
+
+ def test_context_outputs_kv_count(self) -> None:
+ outputs = self.result["model"]["decoder"]["pipeline"][1]["context"]["outputs"]
+ present_keys = [x for x in outputs if x.startswith("present_keys_")]
+ present_values = [x for x in outputs if x.startswith("present_values_")]
+ assert len(present_keys) == 28
+ assert len(present_values) == 28
+
+ def test_context_and_iterator_have_same_io(self) -> None:
+ ctx = self.result["model"]["decoder"]["pipeline"][1]["context"]
+ itr = self.result["model"]["decoder"]["pipeline"][2]["iterator"]
+ assert ctx["inputs"] == itr["inputs"]
+ assert ctx["outputs"] == itr["outputs"]
+
+ def test_custom_filenames(self) -> None:
+ result = build_genai_config(
+ self.cfg,
+ max_cache_len=512,
+ prefill_seq_len=128,
+ pipeline=_make_pipeline(
+ emb_filename="emb.onnx",
+ ctx_filename="prefill.onnx",
+ iter_filename="decode.onnx",
+ lmh_filename="head.onnx",
+ ),
+ )
+ pipeline = result["model"]["decoder"]["pipeline"]
+ assert pipeline[0]["embeddings"]["filename"] == "emb.onnx"
+ assert pipeline[1]["context"]["filename"] == "prefill.onnx"
+ assert pipeline[2]["iterator"]["filename"] == "decode.onnx"
+ assert pipeline[3]["lm_head"]["filename"] == "head.onnx"
+
+ def test_eos_token_id_list_unpacked(self) -> None:
+ cfg = _mock_config(eos_token_id=[151645, 151643])
+ result = build_genai_config(
+ cfg, max_cache_len=256, prefill_seq_len=64, pipeline=_make_pipeline()
+ )
+ assert result["model"]["eos_token_id"] == 151645
+
+ def test_head_size_derived_when_head_dim_missing(self) -> None:
+ cfg = SimpleNamespace(
+ num_hidden_layers=2,
+ hidden_size=512,
+ num_attention_heads=8,
+ num_key_value_heads=4,
+ # no head_dim attribute
+ bos_token_id=0,
+ eos_token_id=1,
+ pad_token_id=0,
+ vocab_size=32000,
+ )
+ result = build_genai_config(
+ cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2)
+ )
+ # head_size = hidden_size // num_attention_heads = 512 // 8 = 64
+ assert result["model"]["decoder"]["head_size"] == 64
+
+ def test_pad_token_id_falls_back_to_bos(self) -> None:
+ cfg = SimpleNamespace(
+ num_hidden_layers=2,
+ hidden_size=512,
+ num_attention_heads=8,
+ num_key_value_heads=4,
+ head_dim=64,
+ bos_token_id=0,
+ eos_token_id=1,
+ pad_token_id=None,
+ vocab_size=32000,
+ )
+ result = build_genai_config(
+ cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2)
+ )
+ assert result["model"]["pad_token_id"] == 0 # falls back to bos_token_id
+
+ def test_different_layer_count(self) -> None:
+ cfg = _mock_config(num_hidden_layers=4)
+ result = build_genai_config(
+ cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(4)
+ )
+ inputs = result["model"]["decoder"]["pipeline"][1]["context"]["inputs"]
+ past_keys = [x for x in inputs if x.startswith("past_keys_")]
+ assert len(past_keys) == 4
+ assert {f"past_keys_{i}" for i in range(4)} == set(past_keys)
+
+
+# ---------------------------------------------------------------------------
+# Tests: _detect_format_patterns
+# ---------------------------------------------------------------------------
+
+
+class TestDetectFormatPatterns:
+ def test_detects_two_kv_groups(self) -> None:
+ names = [
+ "input_hidden_states",
+ "past_seq_len",
+ "past_keys_0",
+ "past_keys_1",
+ "past_keys_2",
+ "past_values_0",
+ "past_values_1",
+ "past_values_2",
+ ]
+ result = _detect_format_patterns(names, num_layers=3)
+ assert result == {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"}
+
+ def test_ignores_incomplete_index_range(self) -> None:
+ # Missing index 1 — should not be detected
+ names = ["prefix_0", "prefix_2"]
+ result = _detect_format_patterns(names, num_layers=3)
+ assert "prefix_" not in result
+
+ def test_ignores_wrong_num_layers(self) -> None:
+ # 3 entries but num_layers=5
+ names = ["kv_0", "kv_1", "kv_2"]
+ result = _detect_format_patterns(names, num_layers=5)
+ assert len(result) == 0
+
+ def test_empty_input(self) -> None:
+ assert _detect_format_patterns([], num_layers=4) == {}
+
+ def test_non_indexed_names_ignored(self) -> None:
+ names = ["input_hidden_states", "past_seq_len", "total_seq_len"]
+ result = _detect_format_patterns(names, num_layers=3)
+ assert result == {}
+
+ def test_single_layer_model(self) -> None:
+ names = ["keys_0", "vals_0"]
+ result = _detect_format_patterns(names, num_layers=1)
+ assert result == {"keys_": "keys_%d", "vals_": "vals_%d"}
+
+
+# ---------------------------------------------------------------------------
+# Tests: build_qwen3_transformer_only_stages
+# ---------------------------------------------------------------------------
+
+
+class TestBuildQwen3TransformerOnlyStages:
+ """Uses mocked onnx.load so no real ONNX files are required."""
+
+ def _ctx_inputs(self, n: int = 4) -> list[str]:
+ return [
+ "input_hidden_states",
+ "past_seq_len",
+ "total_seq_len",
+ *[f"past_keys_{i}" for i in range(n)],
+ *[f"past_values_{i}" for i in range(n)],
+ ]
+
+ def _ctx_outputs(self, n: int = 4) -> list[str]:
+ return [
+ "output_hidden_states",
+ *[f"present_keys_{i}" for i in range(n)],
+ *[f"present_values_{i}" for i in range(n)],
+ ]
+
+ def _patch_onnx(self, n: int = 4):
+ ctx_io = (self._ctx_inputs(n), self._ctx_outputs(n))
+ iter_io = (self._ctx_inputs(n), self._ctx_outputs(n))
+ return patch(
+ "winml.modelkit.utils.genai._introspect_onnx_io",
+ side_effect=[ctx_io, iter_io],
+ )
+
+ def test_returns_four_stages(self) -> None:
+ with self._patch_onnx():
+ stages, _ = build_qwen3_transformer_only_stages("ctx.onnx", "iter.onnx", num_layers=4)
+ assert len(stages) == 4
+ assert [s.name for s in stages] == ["embeddings", "context", "iterator", "lm_head"]
+
+ def test_detected_kv_format_patterns(self) -> None:
+ with self._patch_onnx():
+ _, decoder_io = build_qwen3_transformer_only_stages(
+ "ctx.onnx", "iter.onnx", num_layers=4
+ )
+ assert decoder_io.past_key_names == "past_keys_%d"
+ assert decoder_io.past_value_names == "past_values_%d"
+ assert decoder_io.present_key_names == "present_keys_%d"
+ assert decoder_io.present_value_names == "present_values_%d"
+
+ def test_detected_seq_len_names(self) -> None:
+ with self._patch_onnx():
+ _, decoder_io = build_qwen3_transformer_only_stages(
+ "ctx.onnx", "iter.onnx", num_layers=4
+ )
+ assert decoder_io.past_sequence_length == "past_seq_len"
+ assert decoder_io.total_sequence_length == "total_seq_len"
+
+ def test_context_stage_inputs_from_onnx(self) -> None:
+ with self._patch_onnx(n=4):
+ stages, _ = build_qwen3_transformer_only_stages("ctx.onnx", "iter.onnx", num_layers=4)
+ ctx_stage = next(s for s in stages if s.name == "context")
+ assert "input_hidden_states" in ctx_stage.inputs
+ assert "past_keys_0" in ctx_stage.inputs
+ assert "past_values_3" in ctx_stage.inputs
+
+ def test_custom_filenames(self) -> None:
+ with self._patch_onnx():
+ stages, _ = build_qwen3_transformer_only_stages(
+ "ctx.onnx",
+ "iter.onnx",
+ num_layers=4,
+ context_filename="prefill.onnx",
+ iterator_filename="decode.onnx",
+ embeddings_filename="emb.onnx",
+ lm_head_filename="head.onnx",
+ )
+ names = {s.name: s.filename for s in stages}
+ assert names["context"] == "prefill.onnx"
+ assert names["iterator"] == "decode.onnx"
+ assert names["embeddings"] == "emb.onnx"
+ assert names["lm_head"] == "head.onnx"
+
+ def test_roundtrip_with_build_genai_config(self) -> None:
+ """build_qwen3_transformer_only_stages output feeds build_genai_config cleanly."""
+ with self._patch_onnx(n=4):
+ stages, decoder_io = build_qwen3_transformer_only_stages(
+ "ctx.onnx", "iter.onnx", num_layers=4
+ )
+ cfg = _mock_config(num_hidden_layers=4)
+ result = build_genai_config(
+ cfg,
+ max_cache_len=128,
+ prefill_seq_len=32,
+ pipeline=stages,
+ decoder_io=decoder_io,
+ )
+ assert result["model"]["type"] == "decoder-pipeline"
+ assert len(result["model"]["decoder"]["pipeline"]) == 4
+
+ def test_cpu_ep_no_session_options(self) -> None:
+ """Default cpu ep: context/iterator stages have no session_options."""
+ with self._patch_onnx():
+ stages, _ = build_qwen3_transformer_only_stages(
+ "ctx.onnx", "iter.onnx", num_layers=4, ep="cpu"
+ )
+ ctx = next(s for s in stages if s.name == "context")
+ itr = next(s for s in stages if s.name == "iterator")
+ assert ctx.session_options is None
+ assert itr.session_options is None
+
+ def test_qnn_ep_injects_session_options(self) -> None:
+ """ep='qnn': context/iterator get QNN session_options; emb/lm_head do not."""
+ with self._patch_onnx():
+ stages, _ = build_qwen3_transformer_only_stages(
+ "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn"
+ )
+ stage_map = {s.name: s for s in stages}
+ assert stage_map["embeddings"].session_options is None
+ assert stage_map["lm_head"].session_options is None
+ ctx_opts = stage_map["context"].session_options
+ itr_opts = stage_map["iterator"].session_options
+ assert ctx_opts is not None
+ assert itr_opts is not None
+ assert ctx_opts["provider_options"][0]["qnn"]["backend_path"] == "QnnHtp.dll"
+ assert itr_opts["log_id"] == "onnxruntime-genai.iterator"
+
+ def test_qnn_session_options_in_serialized_config(self) -> None:
+ """QNN session_options appear in genai_config.json pipeline output."""
+ with self._patch_onnx():
+ stages, decoder_io = build_qwen3_transformer_only_stages(
+ "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn"
+ )
+ cfg = build_genai_config(
+ _mock_config(num_hidden_layers=4),
+ max_cache_len=256,
+ prefill_seq_len=64,
+ pipeline=stages,
+ decoder_io=decoder_io,
+ )
+ pipeline = cfg["model"]["decoder"]["pipeline"]
+ ctx_dict = next(s for s in pipeline if "context" in s)["context"]
+ itr_dict = next(s for s in pipeline if "iterator" in s)["iterator"]
+ emb_dict = next(s for s in pipeline if "embeddings" in s)["embeddings"]
+ assert "session_options" in ctx_dict
+ assert "session_options" in itr_dict
+ assert "session_options" not in emb_dict
+
+ def test_custom_soc_model(self) -> None:
+ """soc_model parameter propagates to QNN provider_options."""
+ with self._patch_onnx():
+ stages, _ = build_qwen3_transformer_only_stages(
+ "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn", soc_model="73"
+ )
+ ctx = next(s for s in stages if s.name == "context")
+ assert ctx.session_options["provider_options"][0]["qnn"]["soc_model"] == "73"
diff --git a/tests/unit/quant/calibration/test_qwen3_calibration.py b/tests/unit/quant/calibration/test_qwen3_calibration.py
index ad53ef352..6881f0f72 100644
--- a/tests/unit/quant/calibration/test_qwen3_calibration.py
+++ b/tests/unit/quant/calibration/test_qwen3_calibration.py
@@ -16,8 +16,9 @@
from types import SimpleNamespace
import numpy as np
-import onnx
import torch
+from onnx import TensorProto, helper
+from onnx import save as onnx_save
from winml.modelkit.quant.calibration.qwen3_transformer_only import (
Qwen3DecodeTrajectoryCalibReader,
@@ -47,27 +48,27 @@ def _fake_config() -> SimpleNamespace:
def _build_tiny_onnx(path, *, seq_len: int, max_cache_len: int) -> None:
"""Write a minimal graph carrying the inputs the readers introspect."""
inputs = [
- onnx.helper.make_tensor_value_info(
- "input_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN]
+ helper.make_tensor_value_info(
+ "input_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN]
),
- onnx.helper.make_tensor_value_info(
- "past_keys_0", onnx.TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM]
+ helper.make_tensor_value_info(
+ "past_keys_0", TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM]
),
]
- out = onnx.helper.make_tensor_value_info(
- "output_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN]
+ out = helper.make_tensor_value_info(
+ "output_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN]
)
- gqa = onnx.helper.make_node(
+ gqa = helper.make_node(
"GroupQueryAttention",
["input_hidden_states"],
["attn_out"],
name="gqa_layer_0",
domain="com.microsoft",
)
- identity = onnx.helper.make_node("Identity", ["attn_out"], ["output_hidden_states"])
- graph = onnx.helper.make_graph([gqa, identity], "tiny", inputs, [out])
- model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)])
- onnx.save(model, str(path))
+ identity = helper.make_node("Identity", ["attn_out"], ["output_hidden_states"])
+ graph = helper.make_graph([gqa, identity], "tiny", inputs, [out])
+ model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
+ onnx_save(model, str(path))
def test_graph_shapes_and_gqa_nodes(tmp_path):
diff --git a/tests/unit/session/test_genai_session.py b/tests/unit/session/test_genai_session.py
new file mode 100644
index 000000000..4dcf7ea1c
--- /dev/null
+++ b/tests/unit/session/test_genai_session.py
@@ -0,0 +1,380 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""Unit tests for GenaiSession.
+
+All tests that touch load() / generate*() mock onnxruntime_genai so no
+real model files or GPU/NPU hardware is required.
+"""
+
+from __future__ import annotations
+
+import json
+import sys
+from typing import TYPE_CHECKING
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from winml.modelkit.session.genai_session import (
+ GenaiLoadError,
+ GenaiNotInstalledError,
+ GenaiSession,
+ GenaiSessionError,
+ GenerationConfig,
+)
+
+
+if TYPE_CHECKING:
+ from pathlib import Path
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture
+def bundle_dir(tmp_path: Path) -> Path:
+ """Create a minimal genai bundle directory with genai_config.json."""
+ cfg = {
+ "model": {
+ "type": "decoder-pipeline",
+ "context_length": 256,
+ "decoder": {},
+ },
+ "search": {"max_length": 256},
+ }
+ (tmp_path / "genai_config.json").write_text(json.dumps(cfg), encoding="utf-8")
+ return tmp_path
+
+
+@pytest.fixture
+def mock_og() -> MagicMock:
+ """Return a fully mocked onnxruntime_genai module."""
+ og = MagicMock(name="onnxruntime_genai")
+ og.Config.return_value = MagicMock()
+ og.Model.return_value = MagicMock()
+ og.Tokenizer.return_value = MagicMock()
+ og.GeneratorParams.return_value = MagicMock()
+
+ # Generator that yields two tokens then is_done()
+ gen = MagicMock()
+ gen.is_done.side_effect = [False, False, True]
+ gen.get_next_tokens.side_effect = [
+ MagicMock(__getitem__=lambda s, i: 10),
+ MagicMock(__getitem__=lambda s, i: 20),
+ ]
+ og.Generator.return_value = gen
+
+ # TokenizerStream decodes tokens to text
+ stream = MagicMock()
+ stream.decode.side_effect = ["Hello", " world"]
+ og.Tokenizer.return_value.create_stream.return_value = stream
+
+ return og
+
+
+def _patch_og(mock: MagicMock):
+ """Context manager: inject mock_og as onnxruntime_genai in sys.modules."""
+ return patch.dict(sys.modules, {"onnxruntime_genai": mock})
+
+
+# ---------------------------------------------------------------------------
+# Tests: GenaiSession.__init__
+# ---------------------------------------------------------------------------
+
+
+class TestGenaiSessionInit:
+ def test_missing_bundle_dir_raises(self, tmp_path: Path) -> None:
+ with pytest.raises(FileNotFoundError, match="Bundle directory not found"):
+ GenaiSession(tmp_path / "nonexistent")
+
+ def test_missing_config_raises(self, tmp_path: Path) -> None:
+ # Dir exists but no genai_config.json
+ with pytest.raises(FileNotFoundError, match=r"genai_config\.json not found"):
+ GenaiSession(tmp_path)
+
+ def test_unknown_ep_raises(self, bundle_dir: Path) -> None:
+ with pytest.raises(ValueError, match="Unknown EP"):
+ GenaiSession(bundle_dir, ep="tensorrt")
+
+ def test_default_ep_is_cpu(self, bundle_dir: Path) -> None:
+ session = GenaiSession(bundle_dir)
+ assert session.ep == "cpu"
+
+ def test_not_loaded_after_init(self, bundle_dir: Path) -> None:
+ session = GenaiSession(bundle_dir)
+ assert not session.is_loaded
+ assert session.context_length is None
+
+ def test_bundle_dir_property(self, bundle_dir: Path) -> None:
+ session = GenaiSession(bundle_dir)
+ assert session.bundle_dir == bundle_dir
+
+ def test_supported_eps(self, bundle_dir: Path) -> None:
+ for ep in ("cpu", "mixed", "qnn", "dml"):
+ session = GenaiSession(bundle_dir, ep=ep)
+ assert session.ep == ep
+
+
+# ---------------------------------------------------------------------------
+# Tests: load / unload
+# ---------------------------------------------------------------------------
+
+
+class TestGenaiSessionLoad:
+ def test_load_sets_is_loaded(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ with _patch_og(mock_og):
+ session = GenaiSession(bundle_dir)
+ session.load()
+ assert session.is_loaded
+
+ def test_load_reads_context_length_from_config(
+ self, bundle_dir: Path, mock_og: MagicMock
+ ) -> None:
+ with _patch_og(mock_og):
+ session = GenaiSession(bundle_dir)
+ session.load()
+ assert session.context_length == 256
+
+ def test_context_length_override(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ with _patch_og(mock_og):
+ session = GenaiSession(bundle_dir, context_length=512)
+ session.load()
+ assert session.context_length == 512
+
+ def test_load_is_idempotent(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ with _patch_og(mock_og):
+ session = GenaiSession(bundle_dir)
+ session.load()
+ session.load() # second call is a no-op
+ assert mock_og.Model.call_count == 1
+
+ def test_unload_clears_state(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ with _patch_og(mock_og):
+ session = GenaiSession(bundle_dir)
+ session.load()
+ session.unload()
+ assert not session.is_loaded
+ assert session.context_length is None
+
+ def test_unload_on_unloaded_session_is_safe(self, bundle_dir: Path) -> None:
+ session = GenaiSession(bundle_dir)
+ session.unload() # should not raise
+
+ def test_context_manager_loads_and_unloads(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
+ assert session.is_loaded
+ assert not session.is_loaded
+
+ def test_genai_not_installed_raises(self, bundle_dir: Path) -> None:
+ with patch.dict(sys.modules, {"onnxruntime_genai": None}): # type: ignore[dict-item]
+ session = GenaiSession(bundle_dir)
+ with pytest.raises(GenaiNotInstalledError):
+ session.load()
+
+ def test_og_load_error_raises_genai_load_error(
+ self, bundle_dir: Path, mock_og: MagicMock
+ ) -> None:
+ mock_og.Model.side_effect = RuntimeError("driver not found")
+ with _patch_og(mock_og):
+ session = GenaiSession(bundle_dir)
+ with pytest.raises(GenaiLoadError, match="driver not found"):
+ session.load()
+
+ def test_og_load_error_leaves_session_unloaded(
+ self, bundle_dir: Path, mock_og: MagicMock
+ ) -> None:
+ mock_og.Model.side_effect = RuntimeError("driver not found")
+ with _patch_og(mock_og):
+ session = GenaiSession(bundle_dir)
+ with pytest.raises(GenaiLoadError):
+ session.load()
+ assert not session.is_loaded
+
+
+# ---------------------------------------------------------------------------
+# Tests: EP registration
+# ---------------------------------------------------------------------------
+
+
+class TestEPRegistration:
+ def test_cpu_skips_winml_registration(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ with (
+ _patch_og(mock_og),
+ patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls,
+ ):
+ session = GenaiSession(bundle_dir, ep="cpu")
+ session.load()
+ mock_reg_cls.assert_not_called()
+
+ def test_non_cpu_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ mock_registry = MagicMock()
+ mock_registry.winml_available = True
+ mock_registry.register_execution_providers.return_value = {
+ "onnxruntime_genai": ["QNNExecutionProvider"]
+ }
+ with (
+ _patch_og(mock_og),
+ patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls,
+ ):
+ mock_reg_cls.get_instance.return_value = mock_registry
+ session = GenaiSession(bundle_dir, ep="qnn")
+ session.load()
+ mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True)
+
+ def test_mixed_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ mock_registry = MagicMock()
+ mock_registry.winml_available = True
+ mock_registry.register_execution_providers.return_value = {
+ "onnxruntime_genai": ["QNNExecutionProvider"]
+ }
+ with (
+ _patch_og(mock_og),
+ patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls,
+ ):
+ mock_reg_cls.get_instance.return_value = mock_registry
+ session = GenaiSession(bundle_dir, ep="mixed")
+ session.load()
+ mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True)
+
+ def test_config_not_modified_at_load(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ # EP routing is driven by genai_config.json — we must NOT touch the config.
+ with _patch_og(mock_og):
+ session = GenaiSession(bundle_dir, ep="cpu")
+ session.load()
+ mock_og.Config.return_value.clear_providers.assert_not_called()
+ mock_og.Config.return_value.append_provider.assert_not_called()
+
+
+# ---------------------------------------------------------------------------
+# Tests: generate / generate_streaming
+# ---------------------------------------------------------------------------
+
+
+class TestGenerate:
+ def test_generate_streaming_yields_decoded_tokens(
+ self, bundle_dir: Path, mock_og: MagicMock
+ ) -> None:
+ with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
+ tokens = list(session.generate_streaming("hi"))
+ assert tokens == ["Hello", " world"]
+
+ def test_generate_returns_joined_string(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
+ result = session.generate("hi")
+ assert result == "Hello world"
+
+ def test_generate_respects_max_new_tokens(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ # Generator never signals done; we stop at max_new_tokens=1
+ gen = mock_og.Generator.return_value
+ gen.is_done.side_effect = None
+ gen.is_done.return_value = False
+ gen.get_next_tokens.return_value = MagicMock(__getitem__=lambda s, i: 99)
+ mock_og.Tokenizer.return_value.create_stream.return_value.decode.return_value = "x"
+
+ with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
+ tokens = list(session.generate_streaming("hi", GenerationConfig(max_new_tokens=1)))
+ assert len(tokens) == 1
+
+ def test_generate_with_token_list_input(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ """Pre-encoded token IDs are forwarded directly to append_tokens."""
+ with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
+ list(session.generate_streaming([1, 2, 3]))
+ gen = mock_og.Generator.return_value
+ gen.append_tokens.assert_called_once_with([1, 2, 3])
+
+ def test_generate_deletes_generator_after_iteration(
+ self, bundle_dir: Path, mock_og: MagicMock
+ ) -> None:
+ """Generator is deleted (not leaked) even on normal completion."""
+ with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
+ list(session.generate_streaming("hi"))
+ # No assertions needed — test passes if no ResourceWarning / hang
+
+ def test_generate_with_custom_config(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ cfg = GenerationConfig(max_new_tokens=64, do_sample=True, temperature=0.7)
+ with _patch_og(mock_og), GenaiSession(bundle_dir) as session:
+ list(session.generate_streaming("hi", cfg))
+ params = mock_og.GeneratorParams.return_value
+ params.set_search_options.assert_called_once()
+ call_kwargs = params.set_search_options.call_args.kwargs
+ assert call_kwargs["do_sample"] is True
+ assert call_kwargs["temperature"] == 0.7
+
+ def test_generate_uses_context_length_as_max_length(
+ self, bundle_dir: Path, mock_og: MagicMock
+ ) -> None:
+ with _patch_og(mock_og), GenaiSession(bundle_dir, context_length=128) as session:
+ list(session.generate_streaming("hi"))
+ params = mock_og.GeneratorParams.return_value
+ call_kwargs = params.set_search_options.call_args.kwargs
+ assert call_kwargs["max_length"] == 128
+
+ def test_auto_load_on_first_generate(self, bundle_dir: Path, mock_og: MagicMock) -> None:
+ with _patch_og(mock_og):
+ session = GenaiSession(bundle_dir)
+ assert not session.is_loaded
+ list(session.generate_streaming("hi"))
+ assert session.is_loaded
+
+
+# ---------------------------------------------------------------------------
+# Tests: apply_chatml_template
+# ---------------------------------------------------------------------------
+
+
+class TestApplyChatmlTemplate:
+ def test_user_only(self) -> None:
+ result = GenaiSession.apply_chatml_template("Hello")
+ assert result == "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
+
+ def test_with_system(self) -> None:
+ result = GenaiSession.apply_chatml_template("Hello", system="You are helpful.")
+ assert result.startswith("<|im_start|>system\nYou are helpful.<|im_end|>\n")
+ assert "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" in result
+
+ def test_no_system_no_system_turn(self) -> None:
+ result = GenaiSession.apply_chatml_template("Hi")
+ assert "<|im_start|>system" not in result
+
+ def test_ends_with_assistant_priming(self) -> None:
+ result = GenaiSession.apply_chatml_template("Hi")
+ assert result.endswith("<|im_start|>assistant\n")
+
+
+# ---------------------------------------------------------------------------
+# Tests: GenerationConfig defaults
+# ---------------------------------------------------------------------------
+
+
+class TestGenerationConfig:
+ def test_defaults(self) -> None:
+ cfg = GenerationConfig()
+ assert cfg.max_new_tokens == 128
+ assert cfg.do_sample is False
+ assert cfg.temperature == 1.0
+ assert cfg.top_p == 1.0
+ assert cfg.top_k == 0
+ assert cfg.repetition_penalty == 1.0
+
+ def test_custom_values(self) -> None:
+ cfg = GenerationConfig(max_new_tokens=32, do_sample=True, top_k=50)
+ assert cfg.max_new_tokens == 32
+ assert cfg.do_sample is True
+ assert cfg.top_k == 50
+
+
+# ---------------------------------------------------------------------------
+# Tests: exception hierarchy
+# ---------------------------------------------------------------------------
+
+
+class TestExceptions:
+ def test_genai_not_installed_is_genai_session_error(self) -> None:
+ assert issubclass(GenaiNotInstalledError, GenaiSessionError)
+
+ def test_genai_load_error_is_genai_session_error(self) -> None:
+ assert issubclass(GenaiLoadError, GenaiSessionError)