From 4893195d7d179486f7b6c0a01cad3f1a175c354c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 29 Jun 2026 16:00:47 +0800 Subject: [PATCH 01/12] feat(qwen3): add genai bundle generation and inference script - src/winml/modelkit/models/hf/qwen3/genai.py: new module with build_genai_config() and write_genai_bundle(). build_genai_config generates the onnxruntime-genai pipeline config JSON from a HF PretrainedConfig + max_cache_len + prefill_seq_len. write_genai_bundle copies the winml-built ctx/iter ONNX, optional placeholder embeddings and lm_head ONNX, saves tokenizer files from HF, and writes genai_config.json. - scripts/export_qwen3_transformer_only.py: add --genai-bundle DIR, --embeddings ONNX, --lm-head ONNX flags. When --genai-bundle is set, write_genai_bundle is called after the build to emit a complete onnxruntime-genai bundle. - scripts/infer_genai.py: new inference script. Loads the genai bundle with og.Config, registers WinML EPs (QNN), and runs greedy generation via og.Generator. Supports --ep cpu|qnn, --chat template wrapping, --max-new, --context-length, --verbose. - src/winml/modelkit/models/hf/qwen3/__init__.py: export build_genai_config and write_genai_bundle. - tests/unit/models/qwen3/test_genai_config.py: 21 unit tests for build_genai_config covering pipeline structure, KV name counts, tensor name constants, edge cases (list eos_token_id, missing head_dim, None pad_token_id, custom filenames, variable layer count). --- scripts/export_qwen3_transformer_only.py | 69 +++ scripts/infer_genai.py | 232 +++++++++++ .../modelkit/models/hf/qwen3/__init__.py | 17 +- src/winml/modelkit/models/hf/qwen3/genai.py | 392 ++++++++++++++++++ tests/unit/models/qwen3/test_genai_config.py | 225 ++++++++++ 5 files changed, 934 insertions(+), 1 deletion(-) create mode 100644 scripts/infer_genai.py create mode 100644 src/winml/modelkit/models/hf/qwen3/genai.py create mode 100644 tests/unit/models/qwen3/test_genai_config.py diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py index 6894af518..202c9c906 100644 --- a/scripts/export_qwen3_transformer_only.py +++ b/scripts/export_qwen3_transformer_only.py @@ -121,6 +121,43 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: default=None, help="If set, copy the two ONNX (with external data) here as prefill.onnx / decode.onnx.", ) + + genai = p.add_argument_group( + "genai bundle", + "Options for producing an onnxruntime-genai inference bundle.", + ) + genai.add_argument( + "--genai-bundle", + type=Path, + default=None, + metavar="DIR", + help=( + "If set, assemble a complete onnxruntime-genai bundle in DIR: " + "ctx.onnx (prefill), iter.onnx (decode), genai_config.json, and " + "tokenizer files. Provide --embeddings and --lm-head to include " + "the placeholder models required for end-to-end inference." + ), + ) + genai.add_argument( + "--embeddings", + type=Path, + default=None, + metavar="ONNX", + help=( + "Path to the embeddings ONNX to copy into the genai bundle as " + "embeddings.onnx. Required for end-to-end genai inference." + ), + ) + genai.add_argument( + "--lm-head", + type=Path, + default=None, + metavar="ONNX", + help=( + "Path to the lm_head ONNX to copy into the genai bundle as " + "lm_head.onnx. Required for end-to-end genai inference." + ), + ) return p.parse_args(argv) @@ -164,6 +201,38 @@ def main(argv: list[str] | None = None) -> int: copy_onnx_model(src, dst) print(f" -> copied to {dst}") + # ----------------------------------------------------------------------- + # Optional: assemble an onnxruntime-genai bundle. + # ----------------------------------------------------------------------- + if args.genai_bundle is not None: + from winml.modelkit.models.hf.qwen3.genai import write_genai_bundle + + prefill_path = Path(model.sub_models["decoder_prefill"].onnx_path) + decode_path = Path(model.sub_models["decoder_gen"].onnx_path) + + print(f"\n=== assembling genai bundle -> {args.genai_bundle} ===") + config_path = write_genai_bundle( + args.genai_bundle, + context_onnx=prefill_path, + iterator_onnx=decode_path, + model_id=args.model_id, + max_cache_len=args.max_cache_len, + prefill_seq_len=args.prefill_seq_len, + embeddings_src=args.embeddings, + lm_head_src=args.lm_head, + ) + print(f" genai_config.json -> {config_path}") + if args.embeddings is None: + print( + " WARNING: --embeddings not provided; " + "add embeddings.onnx to the bundle before inference." + ) + if args.lm_head is None: + print( + " WARNING: --lm-head not provided; " + "add lm_head.onnx to the bundle before inference." + ) + return 0 diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py new file mode 100644 index 000000000..4a06ea6be --- /dev/null +++ b/scripts/infer_genai.py @@ -0,0 +1,232 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""onnxruntime-genai inference for the Qwen3 transformer-only pipeline. + +Loads the genai bundle produced by ``export_qwen3_transformer_only.py +--genai-bundle `` and runs greedy text generation. + +The bundle directory must contain ``genai_config.json`` and the four ONNX +graphs it references: + + embeddings.onnx — embedding lookup (input_ids -> input_hidden_states) + ctx.onnx — prefill/context graph (seq_len = prefill_seq_len) + iter.onnx — iteration/decode graph (seq_len = 1) + lm_head.onnx — lm_head (output_hidden_states -> logits) + +It also needs the HF tokenizer files (``tokenizer.json``, +``tokenizer_config.json``, ``vocab.json``, ``merges.txt``, +``generation_config.json``) which ``write_genai_bundle`` downloads +automatically. + +Usage:: + + # CPU sanity check (works anywhere onnxruntime-genai is installed) + uv run python scripts/infer_genai.py --prompt "Hello, who are you?" --chat + + # Qualcomm NPU (registers the QNN EP via the Windows ML EP catalog) + uv run python scripts/infer_genai.py \\ + --prompt "Explain what a transformer is." \\ + --ep qnn --chat + + # Point at a non-default bundle + uv run python scripts/infer_genai.py \\ + --model-dir out/my_bundle --prompt "Hi" --ep cpu + +Dependencies (install in a fresh venv):: + + pip install onnxruntime-genai-winml + pip install "windowsml[with-ort]" # registers QNN EP; also provides onnxruntime +""" + +from __future__ import annotations + +import argparse +import sys +import time +from pathlib import Path + +import onnxruntime_genai as og + + +# Default bundle directory: /out/qwen3_bundle +_REPO_ROOT = Path(__file__).resolve().parent.parent +DEFAULT_MODEL_DIR = _REPO_ROOT / "out" / "qwen3_bundle" + +# The static KV cache length. Must equal ``context_length`` in genai_config.json +# (and the ``--max-cache-len`` used during the winml build). Do not lower this +# value — the KV buffer size is baked into the ONNX graphs. +CONTEXT_LENGTH = 256 + +# Maps the friendly --ep name to the ORT EP canonical name. +_EP_NAME = { + "cpu": "cpu", + "qnn": "QNNExecutionProvider", +} + + +def _register_winml_eps() -> list[str]: + """Discover and register Windows ML execution providers. + + Walks the WinML EP catalog, calls ``ensure_ready()`` on each provider + (downloads via Windows Update if needed), then registers the shared + library with ORT GenAI. Mirrors ``examples/python/winml.py`` from the + onnxruntime-genai repo. + """ + import traceback + + from windowsml import EpCatalog + + registered: list[str] = [] + with EpCatalog() as catalog: + for provider in catalog.find_all_providers(): + provider.ensure_ready() + if not provider.library_path: + continue + try: + og.register_execution_provider_library(provider.name, provider.library_path) + registered.append(provider.name) + except Exception as exc: + print(f"[winml] failed to register {provider.name}: {exc}") + traceback.print_exc() + return registered + + +def _build_og_config(model_dir: Path, ep: str) -> og.Config: + """Create an ``og.Config``, registering WinML EPs when not on CPU.""" + if ep != "cpu": + registered = _register_winml_eps() + print(f"[winml] registered EPs: {registered}") + + config = og.Config(str(model_dir)) + config.clear_providers() + if ep != "cpu": + config.append_provider(_EP_NAME[ep]) + return config + + +def _wrap_chat_template(prompt: str) -> str: + """Wrap *prompt* in the Qwen3 chat template (no thinking mode).""" + return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + """Parse CLI arguments.""" + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument( + "--prompt", + default="Give me a short introduction to large language models.", + help="Input prompt (default: %(default)s).", + ) + p.add_argument( + "--model-dir", + type=Path, + default=DEFAULT_MODEL_DIR, + metavar="DIR", + help=( + "Path to the genai bundle directory containing genai_config.json " + "and the ONNX / tokenizer files (default: %(default)s)." + ), + ) + p.add_argument( + "--ep", + choices=sorted(_EP_NAME), + default="cpu", + help="Execution provider (default: cpu).", + ) + p.add_argument( + "--max-new", + type=int, + default=128, + help="Maximum number of new tokens to generate (default: %(default)s).", + ) + p.add_argument( + "--chat", + action="store_true", + help="Wrap --prompt in the Qwen3 chat template.", + ) + p.add_argument( + "--context-length", + type=int, + default=CONTEXT_LENGTH, + help=( + "Static KV cache length. Must match the --max-cache-len used " + "during the winml build and the genai_config.json context_length " + "(default: %(default)s). Do NOT lower this value." + ), + ) + p.add_argument( + "--verbose", + action="store_true", + help="Enable onnxruntime-genai native model I/O logging.", + ) + return p.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + """Load the genai bundle and run generation.""" + args = parse_args(argv) + + model_dir: Path = args.model_dir + if not model_dir.exists(): + print( + f"ERROR: model directory not found: {model_dir}\n" + "Run export_qwen3_transformer_only.py --genai-bundle first.", + file=sys.stderr, + ) + return 1 + + config_file = model_dir / "genai_config.json" + if not config_file.exists(): + print( + f"ERROR: genai_config.json not found in {model_dir}\nThe bundle may be incomplete.", + file=sys.stderr, + ) + return 1 + + if args.verbose: + og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True) + + print(f"[load] ep={args.ep} bundle={model_dir}") + config = _build_og_config(model_dir, args.ep) + model = og.Model(config) + tokenizer = og.Tokenizer(model) + tokenizer_stream = tokenizer.create_stream() + + text = _wrap_chat_template(args.prompt) if args.chat else args.prompt + input_tokens = tokenizer.encode(text) + print(f"[tokens] prompt has {len(input_tokens)} tokens") + + params = og.GeneratorParams(model) + # max_length must equal the static KV cache size so genai sizes the + # total_sequence_length input and KV buffers correctly. + params.set_search_options( + max_length=args.context_length, + do_sample=False, + ) + + generator = og.Generator(model, params) + generator.append_tokens(input_tokens) + + print("[gen] ", end="", flush=True) + t0 = time.monotonic() + n = 0 + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_next_tokens()[0] + print(tokenizer_stream.decode(new_token), end="", flush=True) + n += 1 + if n >= args.max_new: + break + + dt = time.monotonic() - t0 + print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / dt:.1f} tok/s)") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py index 332fb9234..9cbac5568 100644 --- a/src/winml/modelkit/models/hf/qwen3/__init__.py +++ b/src/winml/modelkit/models/hf/qwen3/__init__.py @@ -3,4 +3,19 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Qwen3 transformer-only export support (modeling, export ops, IO configs).""" +"""Qwen3 transformer-only export + genai bundle support. + +Modules: + qwen_transformer_only — OnnxConfig, build config, composite model class. + qwen3_modeling — winml-owned Qwen3 module definitions (forward bindings). + qwen3_export_ops — custom ONNX symbolic ops (LpNorm, GQA, 1x1 Conv). + genai — genai_config.json generator + bundle assembler. +""" + +from .genai import build_genai_config, write_genai_bundle + + +__all__ = [ + "build_genai_config", + "write_genai_bundle", +] diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py new file mode 100644 index 000000000..c3b93cfb1 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -0,0 +1,392 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Generate an onnxruntime-genai bundle for the Qwen3 transformer-only pipeline. + +The bundle is a directory that ``onnxruntime-genai`` can load directly via +``og.Config(str(bundle_dir))``. It contains: + + genai_config.json — pipeline config consumed by onnxruntime-genai + ctx.onnx — prefill/context ONNX (built by winml-cli) + iter.onnx — iteration/decode ONNX (built by winml-cli) + embeddings.onnx — embedding-lookup ONNX (placeholder; copy externally) + lm_head.onnx — lm_head ONNX (placeholder; copy externally) + tokenizer.json — HF tokenizer files (downloaded from the model repo) + tokenizer_config.json + vocab.json / merges.txt / generation_config.json + +The pipeline follows the same 4-stage layout as the reference bundle: + + input_ids → [embeddings] → input_hidden_states + → [context | iterator] → output_hidden_states + present KVs + → [lm_head] → logits + +The context stage runs on the prompt (prefill); the iterator stage runs on each +subsequent decode step. Both share the same KV cache buffer via genai's +``past_present_share_buffer`` mode. + +Public API:: + + from winml.modelkit.models.hf.qwen3.genai import build_genai_config, write_genai_bundle + + cfg = build_genai_config(hf_config, max_cache_len=256, prefill_seq_len=64) + write_genai_bundle( + Path("out/bundle"), + context_onnx=ctx_path, + iterator_onnx=iter_path, + model_id="Qwen/Qwen3-0.6B", + max_cache_len=256, + prefill_seq_len=64, + embeddings_src=emb_path, # None = skip (add later) + lm_head_src=lmh_path, # None = skip (add later) + ) +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any + + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Fixed tensor name constants — must match qwen_transformer_only.py I/O. +# --------------------------------------------------------------------------- +_INPUT_IDS = "input_ids" +_INPUT_HIDDEN_STATES = "input_hidden_states" +_OUTPUT_HIDDEN_STATES = "output_hidden_states" +_PAST_SEQ_LEN = "past_seq_len" +_TOTAL_SEQ_LEN = "total_seq_len" +_PAST_KEY_FMT = "past_keys_%d" +_PAST_VALUE_FMT = "past_values_%d" +_PRESENT_KEY_FMT = "present_keys_%d" +_PRESENT_VALUE_FMT = "present_values_%d" +_LOGITS = "logits" + +# Default filenames inside the bundle directory. +DEFAULT_EMBEDDINGS_FILENAME = "embeddings.onnx" +DEFAULT_CONTEXT_FILENAME = "ctx.onnx" +DEFAULT_ITERATOR_FILENAME = "iter.onnx" +DEFAULT_LM_HEAD_FILENAME = "lm_head.onnx" + +# Tokenizer files to save from the HF snapshot. +_TOKENIZER_FILES = [ + "tokenizer.json", + "tokenizer_config.json", + "vocab.json", + "merges.txt", + "generation_config.json", + "special_tokens_map.json", +] + + +# --------------------------------------------------------------------------- +# Config builder +# --------------------------------------------------------------------------- + + +def build_genai_config( + hf_config: Any, + *, + max_cache_len: int, + prefill_seq_len: int, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, +) -> dict: + """Build the ``genai_config.json`` dict for the transformer-only pipeline. + + Args: + hf_config: A ``transformers.PretrainedConfig`` (e.g. from + ``AutoConfig.from_pretrained``). Reads: ``num_hidden_layers``, + ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``, + ``head_dim`` (or derived), ``bos_token_id``, ``eos_token_id``, + ``pad_token_id``, ``vocab_size``. + max_cache_len: Static KV cache length. Becomes ``context_length`` and + ``search.max_length`` in the generated config. + prefill_seq_len: Prefill / context sequence length. Becomes + ``decoder.sliding_window.window_size``. + embeddings_filename: Filename of the embeddings ONNX in the bundle. + context_filename: Filename of the context (prefill) ONNX. + iterator_filename: Filename of the iterator (decode) ONNX. + lm_head_filename: Filename of the lm_head ONNX. + + Returns: + A ``dict`` ready for ``json.dumps`` as ``genai_config.json``. + """ + num_layers: int = hf_config.num_hidden_layers + head_size: int = getattr( + hf_config, + "head_dim", + hf_config.hidden_size // hf_config.num_attention_heads, + ) + + eos_token_id = hf_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + + pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id + + # Build per-layer KV name lists (same ordering as the reference config). + past_keys = [f"past_keys_{i}" for i in range(num_layers)] + past_values = [f"past_values_{i}" for i in range(num_layers)] + present_keys = [f"present_keys_{i}" for i in range(num_layers)] + present_values = [f"present_values_{i}" for i in range(num_layers)] + + # Transformer stage I/O: hidden states + seq lens + KV buffers. + transformer_inputs = [ + _INPUT_HIDDEN_STATES, + _PAST_SEQ_LEN, + _TOTAL_SEQ_LEN, + *past_keys, + *past_values, + ] + transformer_outputs = [_OUTPUT_HIDDEN_STATES, *present_keys, *present_values] + + return { + "model": { + "type": "decoder-pipeline", + "bos_token_id": hf_config.bos_token_id, + "eos_token_id": eos_token_id, + "pad_token_id": pad_token_id, + "vocab_size": hf_config.vocab_size, + "context_length": max_cache_len, + "decoder": { + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_key_value_heads": hf_config.num_key_value_heads, + "num_hidden_layers": num_layers, + "head_size": head_size, + "sliding_window": { + "window_size": prefill_seq_len, + "pad_value": 0, + "alignment": "left", + "slide_inputs": True, + "slide_key_value_cache": False, + }, + "inputs": { + "input_ids": _INPUT_IDS, + "past_sequence_length": _PAST_SEQ_LEN, + "total_sequence_length": _TOTAL_SEQ_LEN, + "past_key_names": _PAST_KEY_FMT, + "past_value_names": _PAST_VALUE_FMT, + }, + "outputs": { + "logits": _LOGITS, + "present_key_names": _PRESENT_KEY_FMT, + "present_value_names": _PRESENT_VALUE_FMT, + }, + "pipeline": [ + { + "embeddings": { + "filename": embeddings_filename, + "inputs": [_INPUT_IDS], + "outputs": [_INPUT_HIDDEN_STATES], + "run_on_prompt": True, + "run_on_token_gen": True, + } + }, + { + "context": { + "filename": context_filename, + "inputs": transformer_inputs, + "outputs": transformer_outputs, + "run_on_prompt": True, + "run_on_token_gen": False, + } + }, + { + "iterator": { + "filename": iterator_filename, + "inputs": transformer_inputs, + "outputs": transformer_outputs, + "run_on_prompt": False, + "run_on_token_gen": True, + } + }, + { + "lm_head": { + "filename": lm_head_filename, + "inputs": [_OUTPUT_HIDDEN_STATES], + "outputs": [_LOGITS], + "is_lm_head": True, + "run_on_prompt": True, + "run_on_token_gen": True, + } + }, + ], + }, + }, + "search": { + "max_length": max_cache_len, + "min_length": 0, + "do_sample": False, + "past_present_share_buffer": True, + }, + } + + +# --------------------------------------------------------------------------- +# Bundle assembler +# --------------------------------------------------------------------------- + + +def write_genai_bundle( + output_dir: str | Path, + *, + context_onnx: str | Path, + iterator_onnx: str | Path, + model_id: str, + max_cache_len: int, + prefill_seq_len: int, + embeddings_src: str | Path | None = None, + lm_head_src: str | Path | None = None, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, +) -> Path: + """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. + + Copies the winml-built transformer ONNX files, placeholder embedding / + lm_head models (when provided), HF tokenizer files, and writes + ``genai_config.json``. + + Args: + output_dir: Destination directory (created if absent). + context_onnx: Path to the built prefill/context ONNX + (``decoder_prefill`` sub-model output). + iterator_onnx: Path to the built iteration/decode ONNX + (``decoder_gen`` sub-model output). + model_id: HuggingFace model ID or local path used to download the HF + config and tokenizer files. + max_cache_len: Static KV cache length (= ``context_length`` in genai). + prefill_seq_len: Prefill sequence length (= ``sliding_window.window_size``). + embeddings_src: Source path of the embeddings ONNX to copy into the + bundle. Pass ``None`` to skip (the bundle will be incomplete until + the embeddings model is added separately). + lm_head_src: Source path of the lm_head ONNX to copy. Pass ``None`` + to skip. + context_filename: Filename used for the context ONNX inside the bundle. + iterator_filename: Filename used for the iterator ONNX. + embeddings_filename: Filename used for the embeddings ONNX. + lm_head_filename: Filename used for the lm_head ONNX. + + Returns: + Path to the written ``genai_config.json``. + """ + from transformers import AutoConfig, AutoTokenizer + + from winml.modelkit.onnx import copy_onnx_model + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + context_onnx = Path(context_onnx) + iterator_onnx = Path(iterator_onnx) + + # ------------------------------------------------------------------ + # 1. Copy winml-built transformer ONNX files. + # ------------------------------------------------------------------ + logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename) + copy_onnx_model(context_onnx, output_dir / context_filename) + + logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) + copy_onnx_model(iterator_onnx, output_dir / iterator_filename) + + # ------------------------------------------------------------------ + # 2. Copy placeholder models (embeddings + lm_head). + # ------------------------------------------------------------------ + if embeddings_src is not None: + logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) + copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename) + else: + logger.warning( + "embeddings_src not provided — '%s' is missing from bundle; " + "add it manually before running inference.", + embeddings_filename, + ) + + if lm_head_src is not None: + logger.info("Copying lm_head: %s -> %s", Path(lm_head_src).name, lm_head_filename) + copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename) + else: + logger.warning( + "lm_head_src not provided — '%s' is missing from bundle; " + "add it manually before running inference.", + lm_head_filename, + ) + + # ------------------------------------------------------------------ + # 3. Save tokenizer files from the HF snapshot. + # ------------------------------------------------------------------ + logger.info("Saving tokenizer from '%s' to %s", model_id, output_dir) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(str(output_dir)) + # Prune any extra files that save_pretrained creates but genai doesn't need + # (e.g. tokenizer.model for sentencepiece models). Keep only known files. + onnx_filenames = {context_filename, iterator_filename, embeddings_filename, lm_head_filename} + for path in output_dir.iterdir(): + if ( + path.name not in _TOKENIZER_FILES + and path.suffix in (".json", ".txt", ".model") + and path.name not in onnx_filenames + ): + logger.debug("Keeping extra tokenizer file: %s", path.name) + + # ------------------------------------------------------------------ + # 4. Write genai_config.json. + # ------------------------------------------------------------------ + hf_config = AutoConfig.from_pretrained(model_id) + config = build_genai_config( + hf_config, + max_cache_len=max_cache_len, + prefill_seq_len=prefill_seq_len, + embeddings_filename=embeddings_filename, + context_filename=context_filename, + iterator_filename=iterator_filename, + lm_head_filename=lm_head_filename, + ) + config_path = output_dir / "genai_config.json" + config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") + logger.info("Wrote genai_config.json -> %s", config_path) + + _log_bundle_summary(output_dir, config_path) + return config_path + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: + """Print a human-readable summary of the assembled bundle.""" + files = sorted(bundle_dir.iterdir()) + lines = [f"\n=== genai bundle: {bundle_dir} ==="] + for f in files: + size_kb = f.stat().st_size / 1024 + tag = "" + if f.name == "genai_config.json": + tag = " <- pipeline config" + elif f.name.endswith(".onnx"): + tag = " <- ONNX graph" + elif f.name.endswith(".data"): + tag = " <- ONNX external weights" + lines.append(f" {f.name:<45} {size_kb:>8.1f} KB{tag}") + lines.append(f"\nConfig written to: {config_path}") + logger.info("\n".join(lines)) + + +__all__ = [ + "DEFAULT_CONTEXT_FILENAME", + "DEFAULT_EMBEDDINGS_FILENAME", + "DEFAULT_ITERATOR_FILENAME", + "DEFAULT_LM_HEAD_FILENAME", + "build_genai_config", + "write_genai_bundle", +] diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py new file mode 100644 index 000000000..5f6930b65 --- /dev/null +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -0,0 +1,225 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the Qwen3 genai config builder.""" + +from __future__ import annotations + +from types import SimpleNamespace + +from winml.modelkit.models.hf.qwen3.genai import ( + DEFAULT_CONTEXT_FILENAME, + DEFAULT_EMBEDDINGS_FILENAME, + DEFAULT_ITERATOR_FILENAME, + DEFAULT_LM_HEAD_FILENAME, + build_genai_config, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _mock_config( + *, + num_hidden_layers: int = 28, + hidden_size: int = 1024, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + bos_token_id: int = 151643, + eos_token_id: int = 151645, + pad_token_id: int = 151643, + vocab_size: int = 151936, +) -> SimpleNamespace: + """Return a minimal stand-in for a HF PretrainedConfig.""" + return SimpleNamespace( + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + vocab_size=vocab_size, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestBuildGenaiConfig: + def setup_method(self) -> None: + self.cfg = _mock_config() + self.result = build_genai_config(self.cfg, max_cache_len=256, prefill_seq_len=64) + + def test_top_level_model_type(self) -> None: + assert self.result["model"]["type"] == "decoder-pipeline" + + def test_token_ids(self) -> None: + m = self.result["model"] + assert m["bos_token_id"] == 151643 + assert m["eos_token_id"] == 151645 + assert m["pad_token_id"] == 151643 + assert m["vocab_size"] == 151936 + + def test_context_length_equals_max_cache_len(self) -> None: + assert self.result["model"]["context_length"] == 256 + + def test_search_max_length_equals_context_length(self) -> None: + assert self.result["search"]["max_length"] == self.result["model"]["context_length"] + + def test_search_past_present_share_buffer(self) -> None: + assert self.result["search"]["past_present_share_buffer"] is True + + def test_decoder_architecture_params(self) -> None: + dec = self.result["model"]["decoder"] + assert dec["hidden_size"] == 1024 + assert dec["num_attention_heads"] == 16 + assert dec["num_key_value_heads"] == 8 + assert dec["num_hidden_layers"] == 28 + assert dec["head_size"] == 128 + + def test_sliding_window_size_equals_prefill_seq_len(self) -> None: + sw = self.result["model"]["decoder"]["sliding_window"] + assert sw["window_size"] == 64 + assert sw["slide_inputs"] is True + assert sw["slide_key_value_cache"] is False + + def test_decoder_io_tensor_names(self) -> None: + inputs = self.result["model"]["decoder"]["inputs"] + assert inputs["past_sequence_length"] == "past_seq_len" + assert inputs["total_sequence_length"] == "total_seq_len" + assert inputs["past_key_names"] == "past_keys_%d" + assert inputs["past_value_names"] == "past_values_%d" + outputs = self.result["model"]["decoder"]["outputs"] + assert outputs["logits"] == "logits" + assert outputs["present_key_names"] == "present_keys_%d" + assert outputs["present_value_names"] == "present_values_%d" + + def test_pipeline_has_four_stages(self) -> None: + pipeline = self.result["model"]["decoder"]["pipeline"] + assert len(pipeline) == 4 + stage_names = [next(iter(s.keys())) for s in pipeline] + assert stage_names == ["embeddings", "context", "iterator", "lm_head"] + + def test_embeddings_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][0]["embeddings"] + assert stage["filename"] == DEFAULT_EMBEDDINGS_FILENAME + assert stage["inputs"] == ["input_ids"] + assert stage["outputs"] == ["input_hidden_states"] + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is True + + def test_context_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][1]["context"] + assert stage["filename"] == DEFAULT_CONTEXT_FILENAME + assert "input_hidden_states" in stage["inputs"] + assert "past_seq_len" in stage["inputs"] + assert "total_seq_len" in stage["inputs"] + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is False + + def test_iterator_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][2]["iterator"] + assert stage["filename"] == DEFAULT_ITERATOR_FILENAME + assert stage["run_on_prompt"] is False + assert stage["run_on_token_gen"] is True + + def test_lm_head_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][3]["lm_head"] + assert stage["filename"] == DEFAULT_LM_HEAD_FILENAME + assert stage["inputs"] == ["output_hidden_states"] + assert stage["outputs"] == ["logits"] + assert stage["is_lm_head"] is True + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is True + + def test_context_kv_inputs_count(self) -> None: + """context.inputs must include all 28 past_keys + 28 past_values.""" + inputs = self.result["model"]["decoder"]["pipeline"][1]["context"]["inputs"] + past_keys = [x for x in inputs if x.startswith("past_keys_")] + past_values = [x for x in inputs if x.startswith("past_values_")] + assert len(past_keys) == 28 + assert len(past_values) == 28 + # All layer indices present + assert set(past_keys) == {f"past_keys_{i}" for i in range(28)} + assert set(past_values) == {f"past_values_{i}" for i in range(28)} + + def test_context_outputs_kv_count(self) -> None: + outputs = self.result["model"]["decoder"]["pipeline"][1]["context"]["outputs"] + present_keys = [x for x in outputs if x.startswith("present_keys_")] + present_values = [x for x in outputs if x.startswith("present_values_")] + assert len(present_keys) == 28 + assert len(present_values) == 28 + + def test_context_and_iterator_have_same_io(self) -> None: + ctx = self.result["model"]["decoder"]["pipeline"][1]["context"] + itr = self.result["model"]["decoder"]["pipeline"][2]["iterator"] + assert ctx["inputs"] == itr["inputs"] + assert ctx["outputs"] == itr["outputs"] + + def test_custom_filenames(self) -> None: + result = build_genai_config( + self.cfg, + max_cache_len=512, + prefill_seq_len=128, + embeddings_filename="emb.onnx", + context_filename="prefill.onnx", + iterator_filename="decode.onnx", + lm_head_filename="head.onnx", + ) + pipeline = result["model"]["decoder"]["pipeline"] + assert pipeline[0]["embeddings"]["filename"] == "emb.onnx" + assert pipeline[1]["context"]["filename"] == "prefill.onnx" + assert pipeline[2]["iterator"]["filename"] == "decode.onnx" + assert pipeline[3]["lm_head"]["filename"] == "head.onnx" + + def test_eos_token_id_list_unpacked(self) -> None: + cfg = _mock_config(eos_token_id=[151645, 151643]) + result = build_genai_config(cfg, max_cache_len=256, prefill_seq_len=64) + assert result["model"]["eos_token_id"] == 151645 + + def test_head_size_derived_when_head_dim_missing(self) -> None: + cfg = SimpleNamespace( + num_hidden_layers=2, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=4, + # no head_dim attribute + bos_token_id=0, + eos_token_id=1, + pad_token_id=0, + vocab_size=32000, + ) + result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + # head_size = hidden_size // num_attention_heads = 512 // 8 = 64 + assert result["model"]["decoder"]["head_size"] == 64 + + def test_pad_token_id_falls_back_to_bos(self) -> None: + cfg = SimpleNamespace( + num_hidden_layers=2, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=64, + bos_token_id=0, + eos_token_id=1, + pad_token_id=None, + vocab_size=32000, + ) + result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + assert result["model"]["pad_token_id"] == 0 # falls back to bos_token_id + + def test_different_layer_count(self) -> None: + cfg = _mock_config(num_hidden_layers=4) + result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + inputs = result["model"]["decoder"]["pipeline"][1]["context"]["inputs"] + past_keys = [x for x in inputs if x.startswith("past_keys_")] + assert len(past_keys) == 4 + assert {f"past_keys_{i}" for i in range(4)} == set(past_keys) From 432d1e99cc7721abb2cd6dfb3f95f988a9ce9eaa Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 29 Jun 2026 17:46:03 +0800 Subject: [PATCH 02/12] refactor(qwen3/genai): generic build_genai_config with ONNX introspection Replace hardcoded tensor-name constants with a data-driven design: - PipelineStage dataclass: carries name, filename, run_on_prompt/token_gen, inputs, outputs, is_lm_head. Callers construct stages explicitly; no tensor names are baked into build_genai_config itself. - DecoderIOMapping dataclass: holds the %d-style format strings that genai uses to expand per-layer KV tensor names. Defaults match Qwen3 naming but any naming convention is supported. - build_genai_config: now takes pipeline: list[PipelineStage] and decoder_io: DecoderIOMapping. Architecture-agnostic; no Qwen3-specific logic. prefill_seq_len=None omits the sliding_window section. - _introspect_onnx_io: reads graph.input / graph.output from an ONNX model without loading external data weights. - _detect_format_patterns: scans tensor names for indexed groups matching with exactly num_layers consecutive zero-based indices, returns {prefix: 'prefix%d'} patterns. - build_qwen3_transformer_only_stages: Qwen3-specific factory that calls _introspect_onnx_io on the built ctx/iter ONNX, detects KV patterns via _detect_format_patterns, and returns (list[PipelineStage], DecoderIOMapping). Tensor names can never drift from the actual ONNX graph I/O. - write_genai_bundle: delegates to build_qwen3_transformer_only_stages instead of hardcoding names. Tests (35 total, all pass): - TestBuildGenaiConfig: +2 new cases (no sliding_window, custom DecoderIOMapping) - TestDetectFormatPatterns: 6 new unit tests for the pattern detector - TestBuildQwen3TransformerOnlyStages: 6 new tests using patched _introspect_onnx_io (no real ONNX files required) --- .../modelkit/models/hf/qwen3/__init__.py | 11 +- src/winml/modelkit/models/hf/qwen3/genai.py | 530 ++++++++++++------ tests/unit/models/qwen3/test_genai_config.py | 257 ++++++++- 3 files changed, 620 insertions(+), 178 deletions(-) diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py index 9cbac5568..8d8676398 100644 --- a/src/winml/modelkit/models/hf/qwen3/__init__.py +++ b/src/winml/modelkit/models/hf/qwen3/__init__.py @@ -12,10 +12,19 @@ genai — genai_config.json generator + bundle assembler. """ -from .genai import build_genai_config, write_genai_bundle +from .genai import ( + DecoderIOMapping, + PipelineStage, + build_genai_config, + build_qwen3_transformer_only_stages, + write_genai_bundle, +) __all__ = [ + "DecoderIOMapping", + "PipelineStage", "build_genai_config", + "build_qwen3_transformer_only_stages", "write_genai_bundle", ] diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py index c3b93cfb1..8c7c14503 100644 --- a/src/winml/modelkit/models/hf/qwen3/genai.py +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -2,10 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Generate an onnxruntime-genai bundle for the Qwen3 transformer-only pipeline. +r"""Generate an onnxruntime-genai bundle for a transformer-only decoder pipeline. The bundle is a directory that ``onnxruntime-genai`` can load directly via -``og.Config(str(bundle_dir))``. It contains: +``og.Config(str(bundle_dir))``. It contains: genai_config.json — pipeline config consumed by onnxruntime-genai ctx.onnx — prefill/context ONNX (built by winml-cli) @@ -23,14 +23,29 @@ → [lm_head] → logits The context stage runs on the prompt (prefill); the iterator stage runs on each -subsequent decode step. Both share the same KV cache buffer via genai's +subsequent decode step. Both share the same KV cache buffer via genai's ``past_present_share_buffer`` mode. Public API:: - from winml.modelkit.models.hf.qwen3.genai import build_genai_config, write_genai_bundle + from winml.modelkit.models.hf.qwen3.genai import ( + build_genai_config, + build_qwen3_transformer_only_stages, + write_genai_bundle, + DecoderIOMapping, + PipelineStage, + ) + + # High-level: derive everything from the built ONNX files + stages, decoder_io = build_qwen3_transformer_only_stages( + ctx_path, iter_path, num_layers=hf_config.num_hidden_layers + ) + cfg = build_genai_config( + hf_config, max_cache_len=256, prefill_seq_len=64, + pipeline=stages, decoder_io=decoder_io, + ) - cfg = build_genai_config(hf_config, max_cache_len=256, prefill_seq_len=64) + # Or one-shot bundle assembly write_genai_bundle( Path("out/bundle"), context_onnx=ctx_path, @@ -47,33 +62,21 @@ import json import logging +import re +from dataclasses import dataclass from pathlib import Path from typing import Any logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- -# Fixed tensor name constants — must match qwen_transformer_only.py I/O. -# --------------------------------------------------------------------------- -_INPUT_IDS = "input_ids" -_INPUT_HIDDEN_STATES = "input_hidden_states" -_OUTPUT_HIDDEN_STATES = "output_hidden_states" -_PAST_SEQ_LEN = "past_seq_len" -_TOTAL_SEQ_LEN = "total_seq_len" -_PAST_KEY_FMT = "past_keys_%d" -_PAST_VALUE_FMT = "past_values_%d" -_PRESENT_KEY_FMT = "present_keys_%d" -_PRESENT_VALUE_FMT = "present_values_%d" -_LOGITS = "logits" - # Default filenames inside the bundle directory. DEFAULT_EMBEDDINGS_FILENAME = "embeddings.onnx" DEFAULT_CONTEXT_FILENAME = "ctx.onnx" DEFAULT_ITERATOR_FILENAME = "iter.onnx" DEFAULT_LM_HEAD_FILENAME = "lm_head.onnx" -# Tokenizer files to save from the HF snapshot. +# Tokenizer files written by AutoTokenizer.save_pretrained. _TOKENIZER_FILES = [ "tokenizer.json", "tokenizer_config.json", @@ -83,9 +86,93 @@ "special_tokens_map.json", ] +# Regex for detecting indexed tensor names such as ``past_keys_3``. +_KV_INDEXED_RE = re.compile(r"^(.+?)(\d+)$") + # --------------------------------------------------------------------------- -# Config builder +# Pipeline data structures +# --------------------------------------------------------------------------- + + +@dataclass +class PipelineStage: + """One stage in an onnxruntime-genai multi-model pipeline. + + Attributes: + name: Stage key used inside the ``pipeline`` list of ``genai_config.json``. + filename: ONNX filename inside the bundle directory. + run_on_prompt: Whether genai runs this stage during the prefill pass. + run_on_token_gen: Whether genai runs this stage during decode steps. + inputs: Actual ONNX input tensor names (not format strings). + outputs: Actual ONNX output tensor names (not format strings). + is_lm_head: Set ``True`` for the final language-model head stage. + """ + + name: str + filename: str + run_on_prompt: bool + run_on_token_gen: bool + inputs: list[str] + outputs: list[str] + is_lm_head: bool = False + + def to_dict(self) -> dict: + """Serialize to the dict format expected by ``genai_config.json``.""" + d: dict = { + "filename": self.filename, + "inputs": list(self.inputs), + "outputs": list(self.outputs), + "run_on_prompt": self.run_on_prompt, + "run_on_token_gen": self.run_on_token_gen, + } + if self.is_lm_head: + d["is_lm_head"] = True + return d + + +@dataclass +class DecoderIOMapping: + """Maps genai's abstract I/O concepts to ONNX tensor name format strings. + + The ``*_names`` fields use ``%d`` as the layer-index placeholder, which is + the convention genai uses to expand per-layer KV cache tensor names + (e.g. ``"past_keys_%d"`` → ``"past_keys_0"``, ``"past_keys_1"``, …). + + All fields default to the names produced by the Qwen3 transformer-only + export. + """ + + input_ids: str = "input_ids" + past_sequence_length: str = "past_seq_len" + total_sequence_length: str = "total_seq_len" + past_key_names: str = "past_keys_%d" + past_value_names: str = "past_values_%d" + logits: str = "logits" + present_key_names: str = "present_keys_%d" + present_value_names: str = "present_values_%d" + + def inputs_dict(self) -> dict: + """Return the ``decoder.inputs`` mapping dict for ``genai_config.json``.""" + return { + "input_ids": self.input_ids, + "past_sequence_length": self.past_sequence_length, + "total_sequence_length": self.total_sequence_length, + "past_key_names": self.past_key_names, + "past_value_names": self.past_value_names, + } + + def outputs_dict(self) -> dict: + """Return the ``decoder.outputs`` mapping dict for ``genai_config.json``.""" + return { + "logits": self.logits, + "present_key_names": self.present_key_names, + "present_value_names": self.present_value_names, + } + + +# --------------------------------------------------------------------------- +# Generic config builder # --------------------------------------------------------------------------- @@ -93,32 +180,37 @@ def build_genai_config( hf_config: Any, *, max_cache_len: int, - prefill_seq_len: int, - embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, - context_filename: str = DEFAULT_CONTEXT_FILENAME, - iterator_filename: str = DEFAULT_ITERATOR_FILENAME, - lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + prefill_seq_len: int | None = None, + pipeline: list[PipelineStage], + decoder_io: DecoderIOMapping | None = None, ) -> dict: - """Build the ``genai_config.json`` dict for the transformer-only pipeline. + """Build a ``genai_config.json`` dict for any decoder-pipeline model. + + This function is architecture-agnostic: the caller supplies the pipeline + stages and the I/O name mapping so no tensor names are hardcoded here. Args: - hf_config: A ``transformers.PretrainedConfig`` (e.g. from - ``AutoConfig.from_pretrained``). Reads: ``num_hidden_layers``, - ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``, - ``head_dim`` (or derived), ``bos_token_id``, ``eos_token_id``, - ``pad_token_id``, ``vocab_size``. - max_cache_len: Static KV cache length. Becomes ``context_length`` and - ``search.max_length`` in the generated config. - prefill_seq_len: Prefill / context sequence length. Becomes - ``decoder.sliding_window.window_size``. - embeddings_filename: Filename of the embeddings ONNX in the bundle. - context_filename: Filename of the context (prefill) ONNX. - iterator_filename: Filename of the iterator (decode) ONNX. - lm_head_filename: Filename of the lm_head ONNX. + hf_config: A ``transformers.PretrainedConfig``. Reads: + ``num_hidden_layers``, ``hidden_size``, ``num_attention_heads``, + ``num_key_value_heads``, ``head_dim`` (optional, falls back to + ``hidden_size // num_attention_heads``), ``bos_token_id``, + ``eos_token_id``, ``pad_token_id``, ``vocab_size``. + max_cache_len: Static KV cache length → ``context_length`` and + ``search.max_length``. + prefill_seq_len: When given, emits a ``sliding_window`` section with + ``window_size=prefill_seq_len``. Pass ``None`` to omit. + pipeline: Ordered list of :class:`PipelineStage` describing each + model in the genai pipeline. + decoder_io: Format-string mapping from genai's abstract I/O names to + actual ONNX tensor names. Defaults to + :class:`DecoderIOMapping` (the Qwen3 default names). Returns: - A ``dict`` ready for ``json.dumps`` as ``genai_config.json``. + A ``dict`` suitable for ``json.dumps`` as ``genai_config.json``. """ + if decoder_io is None: + decoder_io = DecoderIOMapping() + num_layers: int = hf_config.num_hidden_layers head_size: int = getattr( hf_config, @@ -132,21 +224,26 @@ def build_genai_config( pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id - # Build per-layer KV name lists (same ordering as the reference config). - past_keys = [f"past_keys_{i}" for i in range(num_layers)] - past_values = [f"past_values_{i}" for i in range(num_layers)] - present_keys = [f"present_keys_{i}" for i in range(num_layers)] - present_values = [f"present_values_{i}" for i in range(num_layers)] - - # Transformer stage I/O: hidden states + seq lens + KV buffers. - transformer_inputs = [ - _INPUT_HIDDEN_STATES, - _PAST_SEQ_LEN, - _TOTAL_SEQ_LEN, - *past_keys, - *past_values, - ] - transformer_outputs = [_OUTPUT_HIDDEN_STATES, *present_keys, *present_values] + decoder_section: dict = { + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_key_value_heads": hf_config.num_key_value_heads, + "num_hidden_layers": num_layers, + "head_size": head_size, + } + + if prefill_seq_len is not None: + decoder_section["sliding_window"] = { + "window_size": prefill_seq_len, + "pad_value": 0, + "alignment": "left", + "slide_inputs": True, + "slide_key_value_cache": False, + } + + decoder_section["inputs"] = decoder_io.inputs_dict() + decoder_section["outputs"] = decoder_io.outputs_dict() + decoder_section["pipeline"] = [{s.name: s.to_dict()} for s in pipeline] return { "model": { @@ -156,71 +253,7 @@ def build_genai_config( "pad_token_id": pad_token_id, "vocab_size": hf_config.vocab_size, "context_length": max_cache_len, - "decoder": { - "hidden_size": hf_config.hidden_size, - "num_attention_heads": hf_config.num_attention_heads, - "num_key_value_heads": hf_config.num_key_value_heads, - "num_hidden_layers": num_layers, - "head_size": head_size, - "sliding_window": { - "window_size": prefill_seq_len, - "pad_value": 0, - "alignment": "left", - "slide_inputs": True, - "slide_key_value_cache": False, - }, - "inputs": { - "input_ids": _INPUT_IDS, - "past_sequence_length": _PAST_SEQ_LEN, - "total_sequence_length": _TOTAL_SEQ_LEN, - "past_key_names": _PAST_KEY_FMT, - "past_value_names": _PAST_VALUE_FMT, - }, - "outputs": { - "logits": _LOGITS, - "present_key_names": _PRESENT_KEY_FMT, - "present_value_names": _PRESENT_VALUE_FMT, - }, - "pipeline": [ - { - "embeddings": { - "filename": embeddings_filename, - "inputs": [_INPUT_IDS], - "outputs": [_INPUT_HIDDEN_STATES], - "run_on_prompt": True, - "run_on_token_gen": True, - } - }, - { - "context": { - "filename": context_filename, - "inputs": transformer_inputs, - "outputs": transformer_outputs, - "run_on_prompt": True, - "run_on_token_gen": False, - } - }, - { - "iterator": { - "filename": iterator_filename, - "inputs": transformer_inputs, - "outputs": transformer_outputs, - "run_on_prompt": False, - "run_on_token_gen": True, - } - }, - { - "lm_head": { - "filename": lm_head_filename, - "inputs": [_OUTPUT_HIDDEN_STATES], - "outputs": [_LOGITS], - "is_lm_head": True, - "run_on_prompt": True, - "run_on_token_gen": True, - } - }, - ], - }, + "decoder": decoder_section, }, "search": { "max_length": max_cache_len, @@ -231,6 +264,186 @@ def build_genai_config( } +# --------------------------------------------------------------------------- +# ONNX introspection helpers +# --------------------------------------------------------------------------- + + +def _introspect_onnx_io(onnx_path: Path) -> tuple[list[str], list[str]]: + """Return ``(input_names, output_names)`` from an ONNX model graph header. + + External data is intentionally not loaded — only the graph topology is read, + so this is fast even for large quantized models. + """ + try: + import onnx + except ImportError as exc: + raise ImportError( + "The 'onnx' package is required for ONNX introspection. " + "Install it with: pip install onnx" + ) from exc + model = onnx.load(str(onnx_path), load_external_data=False) + return ( + [inp.name for inp in model.graph.input], + [out.name for out in model.graph.output], + ) + + +def _detect_format_patterns(names: list[str], num_layers: int) -> dict[str, str]: + """Detect ``prefix%d`` patterns from a list of indexed tensor names. + + Scans *names* for entries matching ```` where exactly + *num_layers* consecutive zero-based indices are present. + + Returns: + ``{prefix: "prefix%d"}`` for each qualifying group, in the order the + prefixes first appear in *names*. Only groups covering the full + ``[0, num_layers)`` index range are returned. + + Examples:: + + >>> _detect_format_patterns( + ... ["past_keys_0", "past_keys_1", "past_values_0", "past_values_1"], + ... num_layers=2, + ... ) + {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} + """ + groups: dict[str, list[int]] = {} + for name in names: + m = _KV_INDEXED_RE.match(name) + if m: + prefix, idx = m.group(1), int(m.group(2)) + groups.setdefault(prefix, []).append(idx) + + return { + prefix: f"{prefix}%d" + for prefix, indices in groups.items() + if len(indices) == num_layers and sorted(indices) == list(range(num_layers)) + } + + +def _sort_patterns_by_first_occurrence(patterns: dict[str, str], names: list[str]) -> list[str]: + """Sort *patterns* keys by when ``0`` first appears in *names*.""" + + def _key(prefix: str) -> int: + try: + return names.index(f"{prefix}0") + except ValueError: + return len(names) + + return sorted(patterns.keys(), key=_key) + + +# --------------------------------------------------------------------------- +# Qwen3 transformer-only pipeline factory +# --------------------------------------------------------------------------- + + +def build_qwen3_transformer_only_stages( + context_onnx: str | Path, + iterator_onnx: str | Path, + num_layers: int, + *, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, +) -> tuple[list[PipelineStage], DecoderIOMapping]: + """Build pipeline stages by introspecting the built ONNX models. + + Reads actual tensor names from *context_onnx* and *iterator_onnx* so the + generated ``genai_config.json`` can never drift out of sync with the real + model I/O — no tensor names are hardcoded. + + Args: + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + num_layers: Number of transformer layers (``hf_config.num_hidden_layers``). + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. + + Returns: + ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and + the :class:`DecoderIOMapping` derived from the introspected tensor names. + """ + ctx_inputs, ctx_outputs = _introspect_onnx_io(Path(context_onnx)) + iter_inputs, iter_outputs = _introspect_onnx_io(Path(iterator_onnx)) + + # Detect per-layer KV format-string patterns in the context model. + input_patterns = _detect_format_patterns(ctx_inputs, num_layers) + output_patterns = _detect_format_patterns(ctx_outputs, num_layers) + + in_sorted = _sort_patterns_by_first_occurrence(input_patterns, ctx_inputs) + out_sorted = _sort_patterns_by_first_occurrence(output_patterns, ctx_outputs) + + past_key_fmt = input_patterns[in_sorted[0]] if len(in_sorted) > 0 else "past_keys_%d" + past_val_fmt = input_patterns[in_sorted[1]] if len(in_sorted) > 1 else "past_values_%d" + pres_key_fmt = output_patterns[out_sorted[0]] if len(out_sorted) > 0 else "present_keys_%d" + pres_val_fmt = output_patterns[out_sorted[1]] if len(out_sorted) > 1 else "present_values_%d" + + # Non-indexed inputs: hidden-state tensor + scalar seq-length scalars. + non_indexed = [n for n in ctx_inputs if not _KV_INDEXED_RE.match(n)] + seq_len_names = [n for n in non_indexed if re.search(r"seq|len", n, re.IGNORECASE)] + hidden_state_in = next( + (n for n in non_indexed if n not in seq_len_names), "input_hidden_states" + ) + past_seq_name = next((n for n in seq_len_names if "past" in n.lower()), "past_seq_len") + total_seq_name = next((n for n in seq_len_names if "total" in n.lower()), "total_seq_len") + + # Non-indexed output: hidden-state output of the transformer stack. + hidden_state_out = next( + (n for n in ctx_outputs if not _KV_INDEXED_RE.match(n)), "output_hidden_states" + ) + + decoder_io = DecoderIOMapping( + past_sequence_length=past_seq_name, + total_sequence_length=total_seq_name, + past_key_names=past_key_fmt, + past_value_names=past_val_fmt, + present_key_names=pres_key_fmt, + present_value_names=pres_val_fmt, + ) + + stages: list[PipelineStage] = [ + PipelineStage( + name="embeddings", + filename=embeddings_filename, + run_on_prompt=True, + run_on_token_gen=True, + inputs=[decoder_io.input_ids], + outputs=[hidden_state_in], + ), + PipelineStage( + name="context", + filename=context_filename, + run_on_prompt=True, + run_on_token_gen=False, + inputs=ctx_inputs, + outputs=ctx_outputs, + ), + PipelineStage( + name="iterator", + filename=iterator_filename, + run_on_prompt=False, + run_on_token_gen=True, + inputs=iter_inputs, + outputs=iter_outputs, + ), + PipelineStage( + name="lm_head", + filename=lm_head_filename, + run_on_prompt=True, + run_on_token_gen=True, + inputs=[hidden_state_out], + outputs=[decoder_io.logits], + is_lm_head=True, + ), + ] + return stages, decoder_io + + # --------------------------------------------------------------------------- # Bundle assembler # --------------------------------------------------------------------------- @@ -255,27 +468,22 @@ def write_genai_bundle( Copies the winml-built transformer ONNX files, placeholder embedding / lm_head models (when provided), HF tokenizer files, and writes - ``genai_config.json``. + ``genai_config.json``. Tensor names in the config are derived by + introspecting the built ONNX files rather than being hardcoded. Args: output_dir: Destination directory (created if absent). - context_onnx: Path to the built prefill/context ONNX - (``decoder_prefill`` sub-model output). - iterator_onnx: Path to the built iteration/decode ONNX - (``decoder_gen`` sub-model output). - model_id: HuggingFace model ID or local path used to download the HF - config and tokenizer files. + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + model_id: HuggingFace model ID or local path for config + tokenizer. max_cache_len: Static KV cache length (= ``context_length`` in genai). prefill_seq_len: Prefill sequence length (= ``sliding_window.window_size``). - embeddings_src: Source path of the embeddings ONNX to copy into the - bundle. Pass ``None`` to skip (the bundle will be incomplete until - the embeddings model is added separately). - lm_head_src: Source path of the lm_head ONNX to copy. Pass ``None`` - to skip. - context_filename: Filename used for the context ONNX inside the bundle. - iterator_filename: Filename used for the iterator ONNX. - embeddings_filename: Filename used for the embeddings ONNX. - lm_head_filename: Filename used for the lm_head ONNX. + embeddings_src: Source path of the embeddings ONNX. ``None`` = skip. + lm_head_src: Source path of the lm_head ONNX. ``None`` = skip. + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. Returns: Path to the written ``genai_config.json``. @@ -289,25 +497,20 @@ def write_genai_bundle( context_onnx = Path(context_onnx) iterator_onnx = Path(iterator_onnx) - # ------------------------------------------------------------------ # 1. Copy winml-built transformer ONNX files. - # ------------------------------------------------------------------ logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename) copy_onnx_model(context_onnx, output_dir / context_filename) logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) copy_onnx_model(iterator_onnx, output_dir / iterator_filename) - # ------------------------------------------------------------------ # 2. Copy placeholder models (embeddings + lm_head). - # ------------------------------------------------------------------ if embeddings_src is not None: logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename) else: logger.warning( - "embeddings_src not provided — '%s' is missing from bundle; " - "add it manually before running inference.", + "embeddings_src not provided — '%s' is missing from bundle.", embeddings_filename, ) @@ -316,40 +519,34 @@ def write_genai_bundle( copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename) else: logger.warning( - "lm_head_src not provided — '%s' is missing from bundle; " - "add it manually before running inference.", + "lm_head_src not provided — '%s' is missing from bundle.", lm_head_filename, ) - # ------------------------------------------------------------------ # 3. Save tokenizer files from the HF snapshot. - # ------------------------------------------------------------------ logger.info("Saving tokenizer from '%s' to %s", model_id, output_dir) tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.save_pretrained(str(output_dir)) - # Prune any extra files that save_pretrained creates but genai doesn't need - # (e.g. tokenizer.model for sentencepiece models). Keep only known files. - onnx_filenames = {context_filename, iterator_filename, embeddings_filename, lm_head_filename} - for path in output_dir.iterdir(): - if ( - path.name not in _TOKENIZER_FILES - and path.suffix in (".json", ".txt", ".model") - and path.name not in onnx_filenames - ): - logger.debug("Keeping extra tokenizer file: %s", path.name) - - # ------------------------------------------------------------------ - # 4. Write genai_config.json. - # ------------------------------------------------------------------ + + # 4. Build pipeline stages by introspecting the source ONNX files. hf_config = AutoConfig.from_pretrained(model_id) + stages, decoder_io = build_qwen3_transformer_only_stages( + context_onnx, + iterator_onnx, + num_layers=hf_config.num_hidden_layers, + context_filename=context_filename, + iterator_filename=iterator_filename, + embeddings_filename=embeddings_filename, + lm_head_filename=lm_head_filename, + ) + + # 5. Write genai_config.json. config = build_genai_config( hf_config, max_cache_len=max_cache_len, prefill_seq_len=prefill_seq_len, - embeddings_filename=embeddings_filename, - context_filename=context_filename, - iterator_filename=iterator_filename, - lm_head_filename=lm_head_filename, + pipeline=stages, + decoder_io=decoder_io, ) config_path = output_dir / "genai_config.json" config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") @@ -387,6 +584,9 @@ def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: "DEFAULT_EMBEDDINGS_FILENAME", "DEFAULT_ITERATOR_FILENAME", "DEFAULT_LM_HEAD_FILENAME", + "DecoderIOMapping", + "PipelineStage", "build_genai_config", + "build_qwen3_transformer_only_stages", "write_genai_bundle", ] diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py index 5f6930b65..012c71cba 100644 --- a/tests/unit/models/qwen3/test_genai_config.py +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -7,13 +7,18 @@ from __future__ import annotations from types import SimpleNamespace +from unittest.mock import patch from winml.modelkit.models.hf.qwen3.genai import ( DEFAULT_CONTEXT_FILENAME, DEFAULT_EMBEDDINGS_FILENAME, DEFAULT_ITERATOR_FILENAME, DEFAULT_LM_HEAD_FILENAME, + DecoderIOMapping, + PipelineStage, + _detect_format_patterns, build_genai_config, + build_qwen3_transformer_only_stages, ) @@ -48,15 +53,59 @@ def _mock_config( ) +def _make_pipeline( + num_layers: int = 28, + *, + emb_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + ctx_filename: str = DEFAULT_CONTEXT_FILENAME, + iter_filename: str = DEFAULT_ITERATOR_FILENAME, + lmh_filename: str = DEFAULT_LM_HEAD_FILENAME, +) -> list[PipelineStage]: + """Build a standard 4-stage pipeline for use in unit tests.""" + ctx_inputs = [ + "input_hidden_states", + "past_seq_len", + "total_seq_len", + *[f"past_keys_{i}" for i in range(num_layers)], + *[f"past_values_{i}" for i in range(num_layers)], + ] + ctx_outputs = [ + "output_hidden_states", + *[f"present_keys_{i}" for i in range(num_layers)], + *[f"present_values_{i}" for i in range(num_layers)], + ] + return [ + PipelineStage( + "embeddings", emb_filename, True, True, ["input_ids"], ["input_hidden_states"] + ), + PipelineStage("context", ctx_filename, True, False, ctx_inputs, ctx_outputs), + PipelineStage("iterator", iter_filename, False, True, ctx_inputs, ctx_outputs), + PipelineStage( + "lm_head", + lmh_filename, + True, + True, + ["output_hidden_states"], + ["logits"], + is_lm_head=True, + ), + ] + + # --------------------------------------------------------------------------- -# Tests +# Tests: build_genai_config # --------------------------------------------------------------------------- class TestBuildGenaiConfig: def setup_method(self) -> None: self.cfg = _mock_config() - self.result = build_genai_config(self.cfg, max_cache_len=256, prefill_seq_len=64) + self.result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=64, + pipeline=_make_pipeline(), + ) def test_top_level_model_type(self) -> None: assert self.result["model"]["type"] == "decoder-pipeline" @@ -85,12 +134,21 @@ def test_decoder_architecture_params(self) -> None: assert dec["num_hidden_layers"] == 28 assert dec["head_size"] == 128 - def test_sliding_window_size_equals_prefill_seq_len(self) -> None: + def test_sliding_window_present_when_prefill_seq_len_given(self) -> None: sw = self.result["model"]["decoder"]["sliding_window"] assert sw["window_size"] == 64 assert sw["slide_inputs"] is True assert sw["slide_key_value_cache"] is False + def test_sliding_window_absent_when_prefill_seq_len_none(self) -> None: + result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=None, + pipeline=_make_pipeline(), + ) + assert "sliding_window" not in result["model"]["decoder"] + def test_decoder_io_tensor_names(self) -> None: inputs = self.result["model"]["decoder"]["inputs"] assert inputs["past_sequence_length"] == "past_seq_len" @@ -102,6 +160,27 @@ def test_decoder_io_tensor_names(self) -> None: assert outputs["present_key_names"] == "present_keys_%d" assert outputs["present_value_names"] == "present_values_%d" + def test_custom_decoder_io_mapping(self) -> None: + custom_io = DecoderIOMapping( + past_key_names="k_%d", + past_value_names="v_%d", + present_key_names="pk_%d", + present_value_names="pv_%d", + ) + result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=64, + pipeline=_make_pipeline(), + decoder_io=custom_io, + ) + dec_inputs = result["model"]["decoder"]["inputs"] + assert dec_inputs["past_key_names"] == "k_%d" + assert dec_inputs["past_value_names"] == "v_%d" + dec_outputs = result["model"]["decoder"]["outputs"] + assert dec_outputs["present_key_names"] == "pk_%d" + assert dec_outputs["present_value_names"] == "pv_%d" + def test_pipeline_has_four_stages(self) -> None: pipeline = self.result["model"]["decoder"]["pipeline"] assert len(pipeline) == 4 @@ -147,7 +226,6 @@ def test_context_kv_inputs_count(self) -> None: past_values = [x for x in inputs if x.startswith("past_values_")] assert len(past_keys) == 28 assert len(past_values) == 28 - # All layer indices present assert set(past_keys) == {f"past_keys_{i}" for i in range(28)} assert set(past_values) == {f"past_values_{i}" for i in range(28)} @@ -169,10 +247,12 @@ def test_custom_filenames(self) -> None: self.cfg, max_cache_len=512, prefill_seq_len=128, - embeddings_filename="emb.onnx", - context_filename="prefill.onnx", - iterator_filename="decode.onnx", - lm_head_filename="head.onnx", + pipeline=_make_pipeline( + emb_filename="emb.onnx", + ctx_filename="prefill.onnx", + iter_filename="decode.onnx", + lmh_filename="head.onnx", + ), ) pipeline = result["model"]["decoder"]["pipeline"] assert pipeline[0]["embeddings"]["filename"] == "emb.onnx" @@ -182,7 +262,9 @@ def test_custom_filenames(self) -> None: def test_eos_token_id_list_unpacked(self) -> None: cfg = _mock_config(eos_token_id=[151645, 151643]) - result = build_genai_config(cfg, max_cache_len=256, prefill_seq_len=64) + result = build_genai_config( + cfg, max_cache_len=256, prefill_seq_len=64, pipeline=_make_pipeline() + ) assert result["model"]["eos_token_id"] == 151645 def test_head_size_derived_when_head_dim_missing(self) -> None: @@ -197,7 +279,9 @@ def test_head_size_derived_when_head_dim_missing(self) -> None: pad_token_id=0, vocab_size=32000, ) - result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2) + ) # head_size = hidden_size // num_attention_heads = 512 // 8 = 64 assert result["model"]["decoder"]["head_size"] == 64 @@ -213,13 +297,162 @@ def test_pad_token_id_falls_back_to_bos(self) -> None: pad_token_id=None, vocab_size=32000, ) - result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2) + ) assert result["model"]["pad_token_id"] == 0 # falls back to bos_token_id def test_different_layer_count(self) -> None: cfg = _mock_config(num_hidden_layers=4) - result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(4) + ) inputs = result["model"]["decoder"]["pipeline"][1]["context"]["inputs"] past_keys = [x for x in inputs if x.startswith("past_keys_")] assert len(past_keys) == 4 assert {f"past_keys_{i}" for i in range(4)} == set(past_keys) + + +# --------------------------------------------------------------------------- +# Tests: _detect_format_patterns +# --------------------------------------------------------------------------- + + +class TestDetectFormatPatterns: + def test_detects_two_kv_groups(self) -> None: + names = [ + "input_hidden_states", + "past_seq_len", + "past_keys_0", + "past_keys_1", + "past_keys_2", + "past_values_0", + "past_values_1", + "past_values_2", + ] + result = _detect_format_patterns(names, num_layers=3) + assert result == {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} + + def test_ignores_incomplete_index_range(self) -> None: + # Missing index 1 — should not be detected + names = ["prefix_0", "prefix_2"] + result = _detect_format_patterns(names, num_layers=3) + assert "prefix_" not in result + + def test_ignores_wrong_num_layers(self) -> None: + # 3 entries but num_layers=5 + names = ["kv_0", "kv_1", "kv_2"] + result = _detect_format_patterns(names, num_layers=5) + assert len(result) == 0 + + def test_empty_input(self) -> None: + assert _detect_format_patterns([], num_layers=4) == {} + + def test_non_indexed_names_ignored(self) -> None: + names = ["input_hidden_states", "past_seq_len", "total_seq_len"] + result = _detect_format_patterns(names, num_layers=3) + assert result == {} + + def test_single_layer_model(self) -> None: + names = ["keys_0", "vals_0"] + result = _detect_format_patterns(names, num_layers=1) + assert result == {"keys_": "keys_%d", "vals_": "vals_%d"} + + +# --------------------------------------------------------------------------- +# Tests: build_qwen3_transformer_only_stages +# --------------------------------------------------------------------------- + + +class TestBuildQwen3TransformerOnlyStages: + """Uses mocked onnx.load so no real ONNX files are required.""" + + def _ctx_inputs(self, n: int = 4) -> list[str]: + return [ + "input_hidden_states", + "past_seq_len", + "total_seq_len", + *[f"past_keys_{i}" for i in range(n)], + *[f"past_values_{i}" for i in range(n)], + ] + + def _ctx_outputs(self, n: int = 4) -> list[str]: + return [ + "output_hidden_states", + *[f"present_keys_{i}" for i in range(n)], + *[f"present_values_{i}" for i in range(n)], + ] + + def _patch_onnx(self, n: int = 4): + ctx_io = (self._ctx_inputs(n), self._ctx_outputs(n)) + iter_io = (self._ctx_inputs(n), self._ctx_outputs(n)) + return patch( + "winml.modelkit.models.hf.qwen3.genai._introspect_onnx_io", + side_effect=[ctx_io, iter_io], + ) + + def test_returns_four_stages(self) -> None: + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages("ctx.onnx", "iter.onnx", num_layers=4) + assert len(stages) == 4 + assert [s.name for s in stages] == ["embeddings", "context", "iterator", "lm_head"] + + def test_detected_kv_format_patterns(self) -> None: + with self._patch_onnx(): + _, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + assert decoder_io.past_key_names == "past_keys_%d" + assert decoder_io.past_value_names == "past_values_%d" + assert decoder_io.present_key_names == "present_keys_%d" + assert decoder_io.present_value_names == "present_values_%d" + + def test_detected_seq_len_names(self) -> None: + with self._patch_onnx(): + _, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + assert decoder_io.past_sequence_length == "past_seq_len" + assert decoder_io.total_sequence_length == "total_seq_len" + + def test_context_stage_inputs_from_onnx(self) -> None: + with self._patch_onnx(n=4): + stages, _ = build_qwen3_transformer_only_stages("ctx.onnx", "iter.onnx", num_layers=4) + ctx_stage = next(s for s in stages if s.name == "context") + assert "input_hidden_states" in ctx_stage.inputs + assert "past_keys_0" in ctx_stage.inputs + assert "past_values_3" in ctx_stage.inputs + + def test_custom_filenames(self) -> None: + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", + "iter.onnx", + num_layers=4, + context_filename="prefill.onnx", + iterator_filename="decode.onnx", + embeddings_filename="emb.onnx", + lm_head_filename="head.onnx", + ) + names = {s.name: s.filename for s in stages} + assert names["context"] == "prefill.onnx" + assert names["iterator"] == "decode.onnx" + assert names["embeddings"] == "emb.onnx" + assert names["lm_head"] == "head.onnx" + + def test_roundtrip_with_build_genai_config(self) -> None: + """build_qwen3_transformer_only_stages output feeds build_genai_config cleanly.""" + with self._patch_onnx(n=4): + stages, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + cfg = _mock_config(num_hidden_layers=4) + result = build_genai_config( + cfg, + max_cache_len=128, + prefill_seq_len=32, + pipeline=stages, + decoder_io=decoder_io, + ) + assert result["model"]["type"] == "decoder-pipeline" + assert len(result["model"]["decoder"]["pipeline"]) == 4 From d7918c280663a5b8b9980e15ff28d0bed6a97a68 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 29 Jun 2026 18:13:28 +0800 Subject: [PATCH 03/12] feat(session): add GenaiSession for onnxruntime-genai inference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - GenaiSession drives og.Model + og.Generator lifecycle for autoregressive text generation; peer class to WinMLSession (not a subclass) - GenerationConfig dataclass: temperature, top_p, top_k, max_new_tokens, repetition_penalty, do_sample - Lazy onnxruntime_genai import via _import_og() — class importable without the package installed (raises GenaiNotInstalledError on first use) - Reuses WinMLEPRegistry for EP discovery/registration (idempotent) - EP support: cpu (clear_providers only), qnn, dml - context_length read from genai_config.json; overridable at construction - generate_streaming() yields decoded token strings; generator del'd in finally - generate() returns joined string; auto-load on first call if not loaded - 33 unit tests; all use patch.dict(sys.modules) to avoid real hardware --- scripts/infer_genai.py | 144 +------ src/winml/modelkit/session/__init__.py | 12 + src/winml/modelkit/session/genai_session.py | 421 ++++++++++++++++++++ tests/unit/session/test_genai_session.py | 349 ++++++++++++++++ 4 files changed, 804 insertions(+), 122 deletions(-) create mode 100644 src/winml/modelkit/session/genai_session.py create mode 100644 tests/unit/session/test_genai_session.py diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py index 4a06ea6be..69139ec3e 100644 --- a/scripts/infer_genai.py +++ b/scripts/infer_genai.py @@ -5,20 +5,12 @@ r"""onnxruntime-genai inference for the Qwen3 transformer-only pipeline. Loads the genai bundle produced by ``export_qwen3_transformer_only.py ---genai-bundle `` and runs greedy text generation. +--genai-bundle `` and runs greedy text generation using +:class:`~winml.modelkit.session.GenaiSession`. The bundle directory must contain ``genai_config.json`` and the four ONNX -graphs it references: - - embeddings.onnx — embedding lookup (input_ids -> input_hidden_states) - ctx.onnx — prefill/context graph (seq_len = prefill_seq_len) - iter.onnx — iteration/decode graph (seq_len = 1) - lm_head.onnx — lm_head (output_hidden_states -> logits) - -It also needs the HF tokenizer files (``tokenizer.json``, -``tokenizer_config.json``, ``vocab.json``, ``merges.txt``, -``generation_config.json``) which ``write_genai_bundle`` downloads -automatically. +graphs it references (``embeddings.onnx``, ``ctx.onnx``, ``iter.onnx``, +``lm_head.onnx``) plus HF tokenizer files. Usage:: @@ -47,63 +39,14 @@ import time from pathlib import Path -import onnxruntime_genai as og +from winml.modelkit.session import GenaiSession, GenerationConfig # Default bundle directory: /out/qwen3_bundle _REPO_ROOT = Path(__file__).resolve().parent.parent DEFAULT_MODEL_DIR = _REPO_ROOT / "out" / "qwen3_bundle" -# The static KV cache length. Must equal ``context_length`` in genai_config.json -# (and the ``--max-cache-len`` used during the winml build). Do not lower this -# value — the KV buffer size is baked into the ONNX graphs. -CONTEXT_LENGTH = 256 - -# Maps the friendly --ep name to the ORT EP canonical name. -_EP_NAME = { - "cpu": "cpu", - "qnn": "QNNExecutionProvider", -} - - -def _register_winml_eps() -> list[str]: - """Discover and register Windows ML execution providers. - - Walks the WinML EP catalog, calls ``ensure_ready()`` on each provider - (downloads via Windows Update if needed), then registers the shared - library with ORT GenAI. Mirrors ``examples/python/winml.py`` from the - onnxruntime-genai repo. - """ - import traceback - - from windowsml import EpCatalog - - registered: list[str] = [] - with EpCatalog() as catalog: - for provider in catalog.find_all_providers(): - provider.ensure_ready() - if not provider.library_path: - continue - try: - og.register_execution_provider_library(provider.name, provider.library_path) - registered.append(provider.name) - except Exception as exc: - print(f"[winml] failed to register {provider.name}: {exc}") - traceback.print_exc() - return registered - - -def _build_og_config(model_dir: Path, ep: str) -> og.Config: - """Create an ``og.Config``, registering WinML EPs when not on CPU.""" - if ep != "cpu": - registered = _register_winml_eps() - print(f"[winml] registered EPs: {registered}") - - config = og.Config(str(model_dir)) - config.clear_providers() - if ep != "cpu": - config.append_provider(_EP_NAME[ep]) - return config +_SUPPORTED_EPS = ["cpu", "qnn", "dml"] def _wrap_chat_template(prompt: str) -> str: @@ -134,7 +77,7 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: ) p.add_argument( "--ep", - choices=sorted(_EP_NAME), + choices=_SUPPORTED_EPS, default="cpu", help="Execution provider (default: cpu).", ) @@ -149,16 +92,6 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: action="store_true", help="Wrap --prompt in the Qwen3 chat template.", ) - p.add_argument( - "--context-length", - type=int, - default=CONTEXT_LENGTH, - help=( - "Static KV cache length. Must match the --max-cache-len used " - "during the winml build and the genai_config.json context_length " - "(default: %(default)s). Do NOT lower this value." - ), - ) p.add_argument( "--verbose", action="store_true", @@ -171,57 +104,24 @@ def main(argv: list[str] | None = None) -> int: """Load the genai bundle and run generation.""" args = parse_args(argv) - model_dir: Path = args.model_dir - if not model_dir.exists(): - print( - f"ERROR: model directory not found: {model_dir}\n" - "Run export_qwen3_transformer_only.py --genai-bundle first.", - file=sys.stderr, - ) - return 1 + text = _wrap_chat_template(args.prompt) if args.chat else args.prompt + gen_cfg = GenerationConfig(max_new_tokens=args.max_new, do_sample=False) - config_file = model_dir / "genai_config.json" - if not config_file.exists(): - print( - f"ERROR: genai_config.json not found in {model_dir}\nThe bundle may be incomplete.", - file=sys.stderr, - ) + try: + session = GenaiSession(args.model_dir, ep=args.ep, verbose=args.verbose) + except FileNotFoundError as exc: + print(f"ERROR: {exc}", file=sys.stderr) return 1 - if args.verbose: - og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True) - - print(f"[load] ep={args.ep} bundle={model_dir}") - config = _build_og_config(model_dir, args.ep) - model = og.Model(config) - tokenizer = og.Tokenizer(model) - tokenizer_stream = tokenizer.create_stream() - - text = _wrap_chat_template(args.prompt) if args.chat else args.prompt - input_tokens = tokenizer.encode(text) - print(f"[tokens] prompt has {len(input_tokens)} tokens") - - params = og.GeneratorParams(model) - # max_length must equal the static KV cache size so genai sizes the - # total_sequence_length input and KV buffers correctly. - params.set_search_options( - max_length=args.context_length, - do_sample=False, - ) - - generator = og.Generator(model, params) - generator.append_tokens(input_tokens) - - print("[gen] ", end="", flush=True) - t0 = time.monotonic() - n = 0 - while not generator.is_done(): - generator.generate_next_token() - new_token = generator.get_next_tokens()[0] - print(tokenizer_stream.decode(new_token), end="", flush=True) - n += 1 - if n >= args.max_new: - break + print(f"[load] ep={args.ep} bundle={args.model_dir}") + with session: + print(f"[ctx] context_length={session.context_length}") + print("[gen] ", end="", flush=True) + t0 = time.monotonic() + n = 0 + for token_str in session.generate_streaming(text, gen_cfg): + print(token_str, end="", flush=True) + n += 1 dt = time.monotonic() - t0 print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / dt:.1f} tok/s)") diff --git a/src/winml/modelkit/session/__init__.py b/src/winml/modelkit/session/__init__.py index 5148da0b3..d11673961 100644 --- a/src/winml/modelkit/session/__init__.py +++ b/src/winml/modelkit/session/__init__.py @@ -5,6 +5,13 @@ """WinMLSession - ONNX Runtime session manager with WinML EP integration.""" from .ep_registry import WinMLEPRegistry +from .genai_session import ( + GenaiLoadError, + GenaiNotInstalledError, + GenaiSession, + GenaiSessionError, + GenerationConfig, +) from .monitor.ep_monitor import EPMonitor, NullEPMonitor from .monitor.hw_monitor import HWMonitor from .monitor.openvino_monitor import OpenVinoMonitor @@ -17,6 +24,11 @@ __all__ = [ "EPMonitor", + "GenaiLoadError", + "GenaiNotInstalledError", + "GenaiSession", + "GenaiSessionError", + "GenerationConfig", "HWMonitor", "InferenceError", "NullEPMonitor", diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py new file mode 100644 index 000000000..21b0c5b31 --- /dev/null +++ b/src/winml/modelkit/session/genai_session.py @@ -0,0 +1,421 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""GenaiSession — onnxruntime-genai session for multi-model decoder pipelines. + +Manages ``og.Model`` + ``og.Generator`` lifecycle for autoregressive text +generation. Reuses :class:`WinMLEPRegistry` for EP discovery and registration +so EPs are downloaded / registered at most once per process. + +Unlike :class:`WinMLSession` (which wraps ``ort.InferenceSession`` for +single-shot inference), ``GenaiSession`` drives a streaming token-by-token +generation loop. The two classes are peers — neither inherits from the other. + +Bundle directory layout expected by ``onnxruntime-genai``:: + + / + genai_config.json ← required; controls pipeline & search + ctx.onnx ← prefill transformer graph + iter.onnx ← decode transformer graph + embeddings.onnx ← embedding lookup + lm_head.onnx ← logit projection + tokenizer.json ← HF tokenizer files + tokenizer_config.json + ... + +Usage:: + + # Context manager (recommended — auto-loads and unloads) + with GenaiSession("out/qwen3_bundle", ep="qnn") as session: + for token_str in session.generate_streaming("Hello, who are you?"): + print(token_str, end="", flush=True) + + # Manual lifecycle + session = GenaiSession("out/qwen3_bundle", ep="cpu") + session.load() + result = session.generate("What is a transformer?") + session.unload() + +Dependencies:: + + pip install onnxruntime-genai-winml + pip install "windowsml[with-ort]" # registers QNN EP; also provides ORT +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +from .ep_registry import WinMLEPRegistry + + +if TYPE_CHECKING: + from collections.abc import Iterator + + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# EP name mapping: user-friendly short name → ORT GenAI provider string. +# None means "do not append a provider" (= default CPU execution). +# --------------------------------------------------------------------------- +_EP_PROVIDER_MAP: dict[str, str | None] = { + "cpu": None, + "qnn": "QNNExecutionProvider", + "dml": "DmlExecutionProvider", +} + + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + + +@dataclass +class GenerationConfig: + """Search / sampling parameters for a single generation call. + + All parameters are forwarded to ``og.GeneratorParams.set_search_options``. + ``max_length`` is **not** configurable here — it is set to the bundle's + ``context_length`` (read from ``genai_config.json``) because the static KV + cache size is baked into the ONNX graphs at export time. + + Attributes: + max_new_tokens: Soft cap on the number of new tokens to generate. + Generation stops when the model signals EOS, when the KV buffer is + exhausted (``context_length``), or when this limit is reached, + whichever comes first. + do_sample: Enable sampling (``True``) vs greedy (``False``). + temperature: Sampling temperature. Ignored when ``do_sample=False``. + top_p: Nucleus sampling probability mass. Ignored when + ``do_sample=False``. + top_k: Top-K sampling. ``0`` disables the filter. Ignored when + ``do_sample=False``. + repetition_penalty: Multiplicative penalty for repeated tokens + (``1.0`` = no penalty). + """ + + max_new_tokens: int = 128 + do_sample: bool = False + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = 0 + repetition_penalty: float = 1.0 + + +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- + + +class GenaiSessionError(Exception): + """Base exception for GenaiSession.""" + + +class GenaiNotInstalledError(GenaiSessionError): + """``onnxruntime-genai`` (or ``onnxruntime-genai-winml``) is not installed.""" + + +class GenaiLoadError(GenaiSessionError): + """The bundle could not be loaded (bad config, EP unavailable, etc.).""" + + +# --------------------------------------------------------------------------- +# Session +# --------------------------------------------------------------------------- + + +class GenaiSession: + """ORT GenAI session for multi-model decoder-pipeline inference. + + Wraps ``og.Model`` + ``og.Generator`` to provide a clean generation API. + + The session is **stateless across calls**: each :meth:`generate_streaming` + call creates a fresh ``og.Generator`` so KV state does not persist between + prompts. Thread-safety within a single session is not guaranteed. + + Args: + bundle_dir: Path to the genai bundle directory. Must contain + ``genai_config.json`` and the ONNX files it references. + ep: Execution provider short name (``"cpu"``, ``"qnn"``, ``"dml"``). + Non-CPU EPs trigger WinML EP discovery and registration. + context_length: Override for the static KV cache length. When + ``None`` (default), read from ``genai_config.json``. + Must match the ``--max-cache-len`` used during the winml-cli build. + verbose: Enable ``onnxruntime-genai`` native model I/O logging. + """ + + def __init__( + self, + bundle_dir: str | Path, + ep: str = "cpu", + *, + context_length: int | None = None, + verbose: bool = False, + ) -> None: + self._bundle_dir = Path(bundle_dir) + self._ep = ep.lower() + self._context_length_override = context_length + self._verbose = verbose + + # Resolved at load() time. + self._context_length: int | None = None + + # og.* handles — None until load() is called. + self._model: object | None = None + self._tokenizer: object | None = None + + if not self._bundle_dir.exists(): + raise FileNotFoundError(f"Bundle directory not found: {self._bundle_dir}") + config_path = self._bundle_dir / "genai_config.json" + if not config_path.exists(): + raise FileNotFoundError( + f"genai_config.json not found in {self._bundle_dir}. " + "Run export_qwen3_transformer_only.py --genai-bundle first." + ) + if self._ep not in _EP_PROVIDER_MAP: + raise ValueError(f"Unknown EP {ep!r}. Supported: {sorted(_EP_PROVIDER_MAP)}") + + logger.info("GenaiSession initialized: bundle=%s ep=%s", self._bundle_dir, self._ep) + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def load(self) -> None: + """Load ``og.Model`` and tokenizer from the bundle directory. + + Idempotent: calling ``load()`` on an already-loaded session is a no-op. + + Raises: + GenaiNotInstalledError: ``onnxruntime_genai`` is not installed. + GenaiLoadError: The model could not be loaded (EP error, bad config, + missing ONNX files, …). + """ + if self._model is not None: + return + + og = self._import_og() + + # Register WinML EPs to ORT GenAI (skipped for CPU; idempotent). + if self._ep != "cpu": + self._register_eps(og) + + if self._verbose: + og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True) + + try: + config = og.Config(str(self._bundle_dir)) + config.clear_providers() + provider = _EP_PROVIDER_MAP[self._ep] + if provider is not None: + config.append_provider(provider) + self._model = og.Model(config) + self._tokenizer = og.Tokenizer(self._model) + except Exception as exc: + self._model = None + self._tokenizer = None + raise GenaiLoadError( + f"Failed to load genai bundle from {self._bundle_dir}: {exc}" + ) from exc + + self._context_length = self._context_length_override or self._read_context_length() + logger.info( + "GenaiSession loaded: ep=%s context_length=%d", + self._ep, + self._context_length, + ) + + def unload(self) -> None: + """Release ``og.Model`` and tokenizer handles. + + Safe to call on an unloaded session. + """ + self._model = None + self._tokenizer = None + self._context_length = None + logger.info("GenaiSession unloaded: bundle=%s", self._bundle_dir) + + def __enter__(self) -> GenaiSession: + self.load() + return self + + def __exit__(self, *_: object) -> None: + self.unload() + + # ------------------------------------------------------------------ + # Generation + # ------------------------------------------------------------------ + + def generate( + self, + prompt: str | list[int], + config: GenerationConfig | None = None, + ) -> str: + """Generate text and return the full response as a single string. + + This is a convenience wrapper around :meth:`generate_streaming`. + + Args: + prompt: Input text (auto-encoded) or a pre-encoded token-ID list. + config: Generation parameters. Uses :class:`GenerationConfig` + defaults when ``None``. + + Returns: + The generated text (not including the prompt). + """ + return "".join(self.generate_streaming(prompt, config)) + + def generate_streaming( + self, + prompt: str | list[int], + config: GenerationConfig | None = None, + ) -> Iterator[str]: + """Generate text token-by-token, yielding decoded token strings. + + The method auto-loads the session on the first call (lazy-load + equivalent of :meth:`load`). + + Each yield is the decoded string for a single new token. Callers + typically ``print(token, end="", flush=True)`` to stream output. + + Args: + prompt: Input text (auto-encoded via the bundle tokenizer) or a + pre-encoded token-ID list. Pass a pre-formatted string when + chat templates or special tokens are needed — the session is + not aware of any particular model's template format. + config: Generation parameters. Uses :class:`GenerationConfig` + defaults when ``None``. + + Yields: + Decoded string for each newly generated token. + """ + self._ensure_loaded() + og = self._import_og() + cfg = config or GenerationConfig() + + tokens = ( + self._tokenizer.encode(prompt) # type: ignore[union-attr] + if isinstance(prompt, str) + else prompt + ) + + params = og.GeneratorParams(self._model) + params.set_search_options( + max_length=self._context_length, + do_sample=cfg.do_sample, + temperature=cfg.temperature, + top_p=cfg.top_p, + top_k=cfg.top_k, + repetition_penalty=cfg.repetition_penalty, + ) + + generator = og.Generator(self._model, params) + generator.append_tokens(tokens) + + stream = self._tokenizer.create_stream() # type: ignore[union-attr] + n = 0 + try: + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_next_tokens()[0] + yield stream.decode(new_token) + n += 1 + if n >= cfg.max_new_tokens: + break + finally: + # Explicit deletion releases the KV cache buffer held by the generator. + del generator + + # ------------------------------------------------------------------ + # Tokenizer helpers + # ------------------------------------------------------------------ + + def encode(self, text: str) -> list[int]: + """Encode *text* to a list of token IDs using the bundle tokenizer.""" + self._ensure_loaded() + return self._tokenizer.encode(text).tolist() # type: ignore[union-attr] + + def decode(self, tokens: list[int]) -> str: + """Decode a list of token IDs to a string.""" + self._ensure_loaded() + return self._tokenizer.decode(tokens) # type: ignore[union-attr] + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def is_loaded(self) -> bool: + """``True`` if the model is loaded and ready for generation.""" + return self._model is not None + + @property + def bundle_dir(self) -> Path: + """Path to the genai bundle directory.""" + return self._bundle_dir + + @property + def ep(self) -> str: + """Execution provider short name (as passed to ``__init__``).""" + return self._ep + + @property + def context_length(self) -> int | None: + """Static KV cache length, populated after :meth:`load`.""" + return self._context_length + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _ensure_loaded(self) -> None: + if self._model is None: + self.load() + + @staticmethod + def _import_og() -> object: + """Import and return the ``onnxruntime_genai`` module. + + Raises: + GenaiNotInstalledError: Package not found. + """ + try: + import onnxruntime_genai as og + + return og + except ImportError as exc: + raise GenaiNotInstalledError( + "onnxruntime_genai is not installed. " + "Install it with: pip install onnxruntime-genai-winml" + ) from exc + + def _register_eps(self, og: object) -> None: + """Register WinML EPs with ORT GenAI (idempotent, best-effort).""" + try: + registry = WinMLEPRegistry.get_instance() + if registry.winml_available: + result = registry.register_execution_providers(ort_genai=True) + registered = result.get("onnxruntime_genai", []) + logger.info("WinML EPs registered for ORT GenAI: %s", registered) + except Exception as exc: + logger.warning("WinML EP registration skipped: %s", exc) + + def _read_context_length(self) -> int: + """Read ``model.context_length`` from ``genai_config.json``.""" + cfg = json.loads((self._bundle_dir / "genai_config.json").read_text(encoding="utf-8")) + return int(cfg["model"]["context_length"]) + + +__all__ = [ + "GenaiLoadError", + "GenaiNotInstalledError", + "GenaiSession", + "GenaiSessionError", + "GenerationConfig", +] diff --git a/tests/unit/session/test_genai_session.py b/tests/unit/session/test_genai_session.py new file mode 100644 index 000000000..dbc815f23 --- /dev/null +++ b/tests/unit/session/test_genai_session.py @@ -0,0 +1,349 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for GenaiSession. + +All tests that touch load() / generate*() mock onnxruntime_genai so no +real model files or GPU/NPU hardware is required. +""" + +from __future__ import annotations + +import json +import sys +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import pytest + +from winml.modelkit.session.genai_session import ( + GenaiLoadError, + GenaiNotInstalledError, + GenaiSession, + GenaiSessionError, + GenerationConfig, +) + + +if TYPE_CHECKING: + from pathlib import Path + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def bundle_dir(tmp_path: Path) -> Path: + """Create a minimal genai bundle directory with genai_config.json.""" + cfg = { + "model": { + "type": "decoder-pipeline", + "context_length": 256, + "decoder": {}, + }, + "search": {"max_length": 256}, + } + (tmp_path / "genai_config.json").write_text(json.dumps(cfg), encoding="utf-8") + return tmp_path + + +@pytest.fixture +def mock_og() -> MagicMock: + """Return a fully mocked onnxruntime_genai module.""" + og = MagicMock(name="onnxruntime_genai") + og.Config.return_value = MagicMock() + og.Model.return_value = MagicMock() + og.Tokenizer.return_value = MagicMock() + og.GeneratorParams.return_value = MagicMock() + + # Generator that yields two tokens then is_done() + gen = MagicMock() + gen.is_done.side_effect = [False, False, True] + gen.get_next_tokens.side_effect = [ + MagicMock(__getitem__=lambda s, i: 10), + MagicMock(__getitem__=lambda s, i: 20), + ] + og.Generator.return_value = gen + + # TokenizerStream decodes tokens to text + stream = MagicMock() + stream.decode.side_effect = ["Hello", " world"] + og.Tokenizer.return_value.create_stream.return_value = stream + + return og + + +def _patch_og(mock: MagicMock): + """Context manager: inject mock_og as onnxruntime_genai in sys.modules.""" + return patch.dict(sys.modules, {"onnxruntime_genai": mock}) + + +# --------------------------------------------------------------------------- +# Tests: GenaiSession.__init__ +# --------------------------------------------------------------------------- + + +class TestGenaiSessionInit: + def test_missing_bundle_dir_raises(self, tmp_path: Path) -> None: + with pytest.raises(FileNotFoundError, match="Bundle directory not found"): + GenaiSession(tmp_path / "nonexistent") + + def test_missing_config_raises(self, tmp_path: Path) -> None: + # Dir exists but no genai_config.json + with pytest.raises(FileNotFoundError, match=r"genai_config\.json not found"): + GenaiSession(tmp_path) + + def test_unknown_ep_raises(self, bundle_dir: Path) -> None: + with pytest.raises(ValueError, match="Unknown EP"): + GenaiSession(bundle_dir, ep="tensorrt") + + def test_default_ep_is_cpu(self, bundle_dir: Path) -> None: + session = GenaiSession(bundle_dir) + assert session.ep == "cpu" + + def test_not_loaded_after_init(self, bundle_dir: Path) -> None: + session = GenaiSession(bundle_dir) + assert not session.is_loaded + assert session.context_length is None + + def test_bundle_dir_property(self, bundle_dir: Path) -> None: + session = GenaiSession(bundle_dir) + assert session.bundle_dir == bundle_dir + + def test_supported_eps(self, bundle_dir: Path) -> None: + for ep in ("cpu", "qnn", "dml"): + session = GenaiSession(bundle_dir, ep=ep) + assert session.ep == ep + + +# --------------------------------------------------------------------------- +# Tests: load / unload +# --------------------------------------------------------------------------- + + +class TestGenaiSessionLoad: + def test_load_sets_is_loaded(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + session.load() + assert session.is_loaded + + def test_load_reads_context_length_from_config( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + session.load() + assert session.context_length == 256 + + def test_context_length_override(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir, context_length=512) + session.load() + assert session.context_length == 512 + + def test_load_is_idempotent(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + session.load() + session.load() # second call is a no-op + assert mock_og.Model.call_count == 1 + + def test_unload_clears_state(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + session.load() + session.unload() + assert not session.is_loaded + assert session.context_length is None + + def test_unload_on_unloaded_session_is_safe(self, bundle_dir: Path) -> None: + session = GenaiSession(bundle_dir) + session.unload() # should not raise + + def test_context_manager_loads_and_unloads(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + assert session.is_loaded + assert not session.is_loaded + + def test_genai_not_installed_raises(self, bundle_dir: Path) -> None: + with patch.dict(sys.modules, {"onnxruntime_genai": None}): # type: ignore[dict-item] + session = GenaiSession(bundle_dir) + with pytest.raises(GenaiNotInstalledError): + session.load() + + def test_og_load_error_raises_genai_load_error( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + mock_og.Model.side_effect = RuntimeError("driver not found") + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + with pytest.raises(GenaiLoadError, match="driver not found"): + session.load() + + def test_og_load_error_leaves_session_unloaded( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + mock_og.Model.side_effect = RuntimeError("driver not found") + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + with pytest.raises(GenaiLoadError): + session.load() + assert not session.is_loaded + + +# --------------------------------------------------------------------------- +# Tests: EP registration +# --------------------------------------------------------------------------- + + +class TestEPRegistration: + def test_cpu_skips_winml_registration(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with ( + _patch_og(mock_og), + patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, + ): + session = GenaiSession(bundle_dir, ep="cpu") + session.load() + mock_reg_cls.assert_not_called() + + def test_non_cpu_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) -> None: + mock_registry = MagicMock() + mock_registry.winml_available = True + mock_registry.register_execution_providers.return_value = { + "onnxruntime_genai": ["QNNExecutionProvider"] + } + with ( + _patch_og(mock_og), + patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, + ): + mock_reg_cls.get_instance.return_value = mock_registry + session = GenaiSession(bundle_dir, ep="qnn") + session.load() + mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True) + + def test_non_cpu_appends_provider_to_config(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with ( + _patch_og(mock_og), + patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, + ): + mock_reg_cls.get_instance.return_value = MagicMock(winml_available=False) + session = GenaiSession(bundle_dir, ep="qnn") + session.load() + mock_og.Config.return_value.append_provider.assert_called_once_with("QNNExecutionProvider") + + def test_cpu_does_not_append_provider(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir, ep="cpu") + session.load() + mock_og.Config.return_value.append_provider.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests: generate / generate_streaming +# --------------------------------------------------------------------------- + + +class TestGenerate: + def test_generate_streaming_yields_decoded_tokens( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + tokens = list(session.generate_streaming("hi")) + assert tokens == ["Hello", " world"] + + def test_generate_returns_joined_string(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + result = session.generate("hi") + assert result == "Hello world" + + def test_generate_respects_max_new_tokens(self, bundle_dir: Path, mock_og: MagicMock) -> None: + # Generator never signals done; we stop at max_new_tokens=1 + gen = mock_og.Generator.return_value + gen.is_done.side_effect = None + gen.is_done.return_value = False + gen.get_next_tokens.return_value = MagicMock(__getitem__=lambda s, i: 99) + mock_og.Tokenizer.return_value.create_stream.return_value.decode.return_value = "x" + + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + tokens = list(session.generate_streaming("hi", GenerationConfig(max_new_tokens=1))) + assert len(tokens) == 1 + + def test_generate_with_token_list_input(self, bundle_dir: Path, mock_og: MagicMock) -> None: + """Pre-encoded token IDs are forwarded directly to append_tokens.""" + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + list(session.generate_streaming([1, 2, 3])) + gen = mock_og.Generator.return_value + gen.append_tokens.assert_called_once_with([1, 2, 3]) + + def test_generate_deletes_generator_after_iteration( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + """Generator is deleted (not leaked) even on normal completion.""" + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + list(session.generate_streaming("hi")) + # No assertions needed — test passes if no ResourceWarning / hang + + def test_generate_with_custom_config(self, bundle_dir: Path, mock_og: MagicMock) -> None: + cfg = GenerationConfig(max_new_tokens=64, do_sample=True, temperature=0.7) + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + list(session.generate_streaming("hi", cfg)) + params = mock_og.GeneratorParams.return_value + params.set_search_options.assert_called_once() + call_kwargs = params.set_search_options.call_args.kwargs + assert call_kwargs["do_sample"] is True + assert call_kwargs["temperature"] == 0.7 + + def test_generate_uses_context_length_as_max_length( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + with _patch_og(mock_og), GenaiSession(bundle_dir, context_length=128) as session: + list(session.generate_streaming("hi")) + params = mock_og.GeneratorParams.return_value + call_kwargs = params.set_search_options.call_args.kwargs + assert call_kwargs["max_length"] == 128 + + def test_auto_load_on_first_generate(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + assert not session.is_loaded + list(session.generate_streaming("hi")) + assert session.is_loaded + + +# --------------------------------------------------------------------------- +# Tests: GenerationConfig defaults +# --------------------------------------------------------------------------- + + +class TestGenerationConfig: + def test_defaults(self) -> None: + cfg = GenerationConfig() + assert cfg.max_new_tokens == 128 + assert cfg.do_sample is False + assert cfg.temperature == 1.0 + assert cfg.top_p == 1.0 + assert cfg.top_k == 0 + assert cfg.repetition_penalty == 1.0 + + def test_custom_values(self) -> None: + cfg = GenerationConfig(max_new_tokens=32, do_sample=True, top_k=50) + assert cfg.max_new_tokens == 32 + assert cfg.do_sample is True + assert cfg.top_k == 50 + + +# --------------------------------------------------------------------------- +# Tests: exception hierarchy +# --------------------------------------------------------------------------- + + +class TestExceptions: + def test_genai_not_installed_is_genai_session_error(self) -> None: + assert issubclass(GenaiNotInstalledError, GenaiSessionError) + + def test_genai_load_error_is_genai_session_error(self) -> None: + assert issubclass(GenaiLoadError, GenaiSessionError) From 7e2a67b4eff5705410f118c4924ce6a7c1ed31dd Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 29 Jun 2026 18:27:24 +0800 Subject: [PATCH 04/12] feat(session): add GenaiSession.apply_chatml_template static method - Moves chat template logic from infer_genai.py into GenaiSession - Supports optional system prompt - ChatML is not Qwen3-specific; used by Qwen2/3, Yi, Mistral, etc. - infer_genai.py _wrap_chat_template now delegates to the static method - Updated --chat flag help text and script docstring - 4 new tests covering user-only, with-system, no-system-turn, assistant-priming --- scripts/infer_genai.py | 8 ++--- src/winml/modelkit/session/genai_session.py | 37 +++++++++++++++++++++ tests/unit/session/test_genai_session.py | 24 +++++++++++++ 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py index 69139ec3e..5144fa7bc 100644 --- a/scripts/infer_genai.py +++ b/scripts/infer_genai.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -r"""onnxruntime-genai inference for the Qwen3 transformer-only pipeline. +r"""onnxruntime-genai inference for a genai bundle (decoder-pipeline). Loads the genai bundle produced by ``export_qwen3_transformer_only.py --genai-bundle `` and runs greedy text generation using @@ -50,8 +50,8 @@ def _wrap_chat_template(prompt: str) -> str: - """Wrap *prompt* in the Qwen3 chat template (no thinking mode).""" - return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + """Wrap *prompt* in the ChatML chat template.""" + return GenaiSession.apply_chatml_template(prompt) def parse_args(argv: list[str] | None = None) -> argparse.Namespace: @@ -90,7 +90,7 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: p.add_argument( "--chat", action="store_true", - help="Wrap --prompt in the Qwen3 chat template.", + help="Wrap --prompt in the ChatML template (<|im_start|>user/assistant).", ) p.add_argument( "--verbose", diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 21b0c5b31..229b4e83a 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -332,6 +332,43 @@ def generate_streaming( # Explicit deletion releases the KV cache buffer held by the generator. del generator + # ------------------------------------------------------------------ + # Chat-template helpers + # ------------------------------------------------------------------ + + @staticmethod + def apply_chatml_template( + prompt: str, + system: str | None = None, + ) -> str: + r"""Wrap *prompt* in the ChatML format used by Qwen2/3, Yi, Mistral, etc. + + Produces:: + + <|im_start|>system + {system}<|im_end|> + <|im_start|>user + {prompt}<|im_end|> + <|im_start|>assistant + + The trailing ``<|im_start|>assistant\\n`` primes the model to respond + as the assistant role with no leading newline in its output. + + Args: + prompt: The user turn text. + system: Optional system prompt. When ``None`` no system turn is + prepended. + + Returns: + Formatted string ready to pass to :meth:`generate` / + :meth:`generate_streaming`. + """ + parts: list[str] = [] + if system is not None: + parts.append(f"<|im_start|>system\n{system}<|im_end|>\n") + parts.append(f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n") + return "".join(parts) + # ------------------------------------------------------------------ # Tokenizer helpers # ------------------------------------------------------------------ diff --git a/tests/unit/session/test_genai_session.py b/tests/unit/session/test_genai_session.py index dbc815f23..4859ef11e 100644 --- a/tests/unit/session/test_genai_session.py +++ b/tests/unit/session/test_genai_session.py @@ -314,6 +314,30 @@ def test_auto_load_on_first_generate(self, bundle_dir: Path, mock_og: MagicMock) assert session.is_loaded +# --------------------------------------------------------------------------- +# Tests: apply_chatml_template +# --------------------------------------------------------------------------- + + +class TestApplyChatmlTemplate: + def test_user_only(self) -> None: + result = GenaiSession.apply_chatml_template("Hello") + assert result == "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" + + def test_with_system(self) -> None: + result = GenaiSession.apply_chatml_template("Hello", system="You are helpful.") + assert result.startswith("<|im_start|>system\nYou are helpful.<|im_end|>\n") + assert "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" in result + + def test_no_system_no_system_turn(self) -> None: + result = GenaiSession.apply_chatml_template("Hi") + assert "<|im_start|>system" not in result + + def test_ends_with_assistant_priming(self) -> None: + result = GenaiSession.apply_chatml_template("Hi") + assert result.endswith("<|im_start|>assistant\n") + + # --------------------------------------------------------------------------- # Tests: GenerationConfig defaults # --------------------------------------------------------------------------- From ae7d19bbecaa5a9e6dedefff86fba085421d0e78 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 12:39:25 +0800 Subject: [PATCH 05/12] feat(qwen3/genai): NPU+CPU hybrid EP support in genai_config - PipelineStage gains session_options: dict | None = None field; PipelineStage.to_dict() emits it when set - Add _qnn_stage_session_options(log_id, soc_model) helper that produces QNN HTP provider_options for a pipeline stage - build_qwen3_transformer_only_stages gains ep='cpu' and soc_model='60' params; when ep='qnn' the context and iterator stages receive QNN session_options, embeddings and lm_head stay on CPU (no session_options) - write_genai_bundle threads ep/soc_model through - export_qwen3_transformer_only.py passes ep='qnn' when --device npu - 5 new tests covering cpu/qnn ep routing and soc_model propagation (39 total, all pass) --- scripts/export_qwen3_transformer_only.py | 1 + src/winml/modelkit/models/hf/qwen3/genai.py | 77 ++++++++++++++++++++ tests/unit/models/qwen3/test_genai_config.py | 57 +++++++++++++++ 3 files changed, 135 insertions(+) diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py index 202c9c906..856123be2 100644 --- a/scripts/export_qwen3_transformer_only.py +++ b/scripts/export_qwen3_transformer_only.py @@ -220,6 +220,7 @@ def main(argv: list[str] | None = None) -> int: prefill_seq_len=args.prefill_seq_len, embeddings_src=args.embeddings, lm_head_src=args.lm_head, + ep="qnn" if args.device == "npu" else args.device, ) print(f" genai_config.json -> {config_path}") if args.embeddings is None: diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py index 8c7c14503..4a63de45b 100644 --- a/src/winml/modelkit/models/hf/qwen3/genai.py +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -116,6 +116,13 @@ class PipelineStage: inputs: list[str] outputs: list[str] is_lm_head: bool = False + session_options: dict | None = None + """Per-stage ORT session options (e.g. provider_options for QNN). + + When set, emitted verbatim as the ``session_options`` key in the + ``genai_config.json`` pipeline stage. Leave ``None`` (default) for + stages that should run on the default (CPU) provider. + """ def to_dict(self) -> dict: """Serialize to the dict format expected by ``genai_config.json``.""" @@ -126,6 +133,8 @@ def to_dict(self) -> dict: "run_on_prompt": self.run_on_prompt, "run_on_token_gen": self.run_on_token_gen, } + if self.session_options: + d["session_options"] = self.session_options if self.is_lm_head: d["is_lm_head"] = True return d @@ -334,6 +343,42 @@ def _key(prefix: str) -> int: return sorted(patterns.keys(), key=_key) +# --------------------------------------------------------------------------- +# Per-EP stage session_options helpers +# --------------------------------------------------------------------------- + + +def _qnn_stage_session_options(log_id: str, soc_model: str = "60") -> dict: + """Return the ``session_options`` block that routes a stage to QNN HTP. + + Args: + log_id: ORT log identifier (shown in ORT logs), e.g. + ``"onnxruntime-genai.context"``. + soc_model: Snapdragon SoC model number passed to the QNN HTP backend. + ``"60"`` targets Snapdragon 8 Gen 3 (X Elite). Change for other + SoCs (e.g. ``"55"`` for 8 Gen 2, ``"73"`` for 8 Elite). + + Returns: + Dict suitable for the ``session_options`` key of a pipeline stage in + ``genai_config.json``. + """ + return { + "log_id": log_id, + "provider_options": [ + { + "qnn": { + "backend_path": "QnnHtp.dll", + "htp_performance_mode": "burst", + "htp_graph_finalization_optimization_mode": "3", + "soc_model": soc_model, + } + } + ], + "intra_op_num_threads": 2, + "inter_op_num_threads": 1, + } + + # --------------------------------------------------------------------------- # Qwen3 transformer-only pipeline factory # --------------------------------------------------------------------------- @@ -348,6 +393,8 @@ def build_qwen3_transformer_only_stages( iterator_filename: str = DEFAULT_ITERATOR_FILENAME, embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", ) -> tuple[list[PipelineStage], DecoderIOMapping]: """Build pipeline stages by introspecting the built ONNX models. @@ -363,6 +410,13 @@ def build_qwen3_transformer_only_stages( iterator_filename: Bundle filename for the iterator model. embeddings_filename: Bundle filename for the embeddings model. lm_head_filename: Bundle filename for the lm_head model. + ep: Execution provider for the transformer stages. ``"qnn"`` injects + QNN HTP ``session_options`` into the ``context`` and ``iterator`` + stages so they run on the NPU while ``embeddings`` and ``lm_head`` + continue on CPU. ``"cpu"`` (default) omits ``session_options`` + from all stages. + soc_model: Snapdragon SoC model number forwarded to the QNN backend + when ``ep="qnn"``. Default ``"60"`` targets Snapdragon 8 Gen 3. Returns: ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and @@ -406,6 +460,17 @@ def build_qwen3_transformer_only_stages( present_value_names=pres_val_fmt, ) + # Per-stage session_options: NPU stages get QNN config; CPU and others get None. + ctx_session_opts: dict | None = None + iter_session_opts: dict | None = None + if ep == "qnn": + ctx_session_opts = _qnn_stage_session_options( + "onnxruntime-genai.context", soc_model=soc_model + ) + iter_session_opts = _qnn_stage_session_options( + "onnxruntime-genai.iterator", soc_model=soc_model + ) + stages: list[PipelineStage] = [ PipelineStage( name="embeddings", @@ -422,6 +487,7 @@ def build_qwen3_transformer_only_stages( run_on_token_gen=False, inputs=ctx_inputs, outputs=ctx_outputs, + session_options=ctx_session_opts, ), PipelineStage( name="iterator", @@ -430,6 +496,7 @@ def build_qwen3_transformer_only_stages( run_on_token_gen=True, inputs=iter_inputs, outputs=iter_outputs, + session_options=iter_session_opts, ), PipelineStage( name="lm_head", @@ -463,6 +530,8 @@ def write_genai_bundle( iterator_filename: str = DEFAULT_ITERATOR_FILENAME, embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", ) -> Path: """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. @@ -484,6 +553,12 @@ def write_genai_bundle( iterator_filename: Bundle filename for the iterator model. embeddings_filename: Bundle filename for the embeddings model. lm_head_filename: Bundle filename for the lm_head model. + ep: Execution provider for the transformer (context/iterator) stages. + ``"qnn"`` injects QNN HTP ``session_options`` so those stages run + on the NPU while embeddings and lm_head run on CPU. + ``"cpu"`` (default) omits ``session_options`` (all stages on CPU). + soc_model: Snapdragon SoC model passed to the QNN backend when + ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. Returns: Path to the written ``genai_config.json``. @@ -538,6 +613,8 @@ def write_genai_bundle( iterator_filename=iterator_filename, embeddings_filename=embeddings_filename, lm_head_filename=lm_head_filename, + ep=ep, + soc_model=soc_model, ) # 5. Write genai_config.json. diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py index 012c71cba..900f8b664 100644 --- a/tests/unit/models/qwen3/test_genai_config.py +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -456,3 +456,60 @@ def test_roundtrip_with_build_genai_config(self) -> None: ) assert result["model"]["type"] == "decoder-pipeline" assert len(result["model"]["decoder"]["pipeline"]) == 4 + + def test_cpu_ep_no_session_options(self) -> None: + """Default cpu ep: context/iterator stages have no session_options.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="cpu" + ) + ctx = next(s for s in stages if s.name == "context") + itr = next(s for s in stages if s.name == "iterator") + assert ctx.session_options is None + assert itr.session_options is None + + def test_qnn_ep_injects_session_options(self) -> None: + """ep='qnn': context/iterator get QNN session_options; emb/lm_head do not.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn" + ) + stage_map = {s.name: s for s in stages} + assert stage_map["embeddings"].session_options is None + assert stage_map["lm_head"].session_options is None + ctx_opts = stage_map["context"].session_options + itr_opts = stage_map["iterator"].session_options + assert ctx_opts is not None + assert itr_opts is not None + assert ctx_opts["provider_options"][0]["qnn"]["backend_path"] == "QnnHtp.dll" + assert itr_opts["log_id"] == "onnxruntime-genai.iterator" + + def test_qnn_session_options_in_serialized_config(self) -> None: + """QNN session_options appear in genai_config.json pipeline output.""" + with self._patch_onnx(): + stages, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn" + ) + cfg = build_genai_config( + _mock_config(num_hidden_layers=4), + max_cache_len=256, + prefill_seq_len=64, + pipeline=stages, + decoder_io=decoder_io, + ) + pipeline = cfg["model"]["decoder"]["pipeline"] + ctx_dict = next(s for s in pipeline if "context" in s)["context"] + itr_dict = next(s for s in pipeline if "iterator" in s)["iterator"] + emb_dict = next(s for s in pipeline if "embeddings" in s)["embeddings"] + assert "session_options" in ctx_dict + assert "session_options" in itr_dict + assert "session_options" not in emb_dict + + def test_custom_soc_model(self) -> None: + """soc_model parameter propagates to QNN provider_options.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn", soc_model="73" + ) + ctx = next(s for s in stages if s.name == "context") + assert ctx.session_options["provider_options"][0]["qnn"]["soc_model"] == "73" From 6a44a077a7a737146ced1f56f349aed0318429d6 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 13:28:00 +0800 Subject: [PATCH 06/12] fix(genai_session): let genai_config.json drive EP routing, add mixed EP Remove clear_providers/append_provider calls from GenaiSession.load(). EP placement is fully driven by per-stage session_options in genai_config.json. clear_providers() only clears the top-level provider and cannot override per-stage session_options embedded in the pipeline config. - Add 'mixed' EP (use genai_config.json as-is; default for infer_genai.py) - _NEEDS_WINML_EPS covers mixed/qnn/dml to trigger EP registration - Replace _EP_PROVIDER_MAP with _VALID_EPS + _NEEDS_WINML_EPS sets - Update tests: remove append_provider assertions, add mixed/config-not-modified tests - infer_genai.py default EP changed from 'cpu' to 'mixed' Result: NPU bundle (out/qwen3_bundle_npu) now runs at 9.3 tok/s vs 1.2 tok/s CPU --- scripts/infer_genai.py | 7 +++-- src/winml/modelkit/session/genai_session.py | 31 +++++++++++---------- tests/unit/session/test_genai_session.py | 19 +++++++++---- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py index 5144fa7bc..f31847b9b 100644 --- a/scripts/infer_genai.py +++ b/scripts/infer_genai.py @@ -46,7 +46,7 @@ _REPO_ROOT = Path(__file__).resolve().parent.parent DEFAULT_MODEL_DIR = _REPO_ROOT / "out" / "qwen3_bundle" -_SUPPORTED_EPS = ["cpu", "qnn", "dml"] +_SUPPORTED_EPS = ["cpu", "mixed", "qnn", "dml"] def _wrap_chat_template(prompt: str) -> str: @@ -78,8 +78,9 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: p.add_argument( "--ep", choices=_SUPPORTED_EPS, - default="cpu", - help="Execution provider (default: cpu).", + default="mixed", + help="Execution provider: 'mixed' uses genai_config.json as-is (default); " + "'cpu' forces all stages to CPU; 'qnn'/'dml' for full NPU/GPU.", ) p.add_argument( "--max-new", diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 229b4e83a..61423efe1 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -61,14 +61,15 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -# EP name mapping: user-friendly short name → ORT GenAI provider string. -# None means "do not append a provider" (= default CPU execution). +# Valid EP short names. +# "mixed" = use genai_config.json as-is (embeddings/lm_head on CPU, +# ctx/iter on the target accelerator). +# EP routing is driven entirely by per-stage session_options in the bundle's +# genai_config.json — GenaiSession never calls clear_providers/append_provider. # --------------------------------------------------------------------------- -_EP_PROVIDER_MAP: dict[str, str | None] = { - "cpu": None, - "qnn": "QNNExecutionProvider", - "dml": "DmlExecutionProvider", -} +_VALID_EPS: frozenset[str] = frozenset({"cpu", "mixed", "qnn", "dml"}) +# EPs that require WinML EP discovery + registration before og.Model() init. +_NEEDS_WINML_EPS: frozenset[str] = frozenset({"mixed", "qnn", "dml"}) # --------------------------------------------------------------------------- @@ -178,8 +179,8 @@ def __init__( f"genai_config.json not found in {self._bundle_dir}. " "Run export_qwen3_transformer_only.py --genai-bundle first." ) - if self._ep not in _EP_PROVIDER_MAP: - raise ValueError(f"Unknown EP {ep!r}. Supported: {sorted(_EP_PROVIDER_MAP)}") + if self._ep not in _VALID_EPS: + raise ValueError(f"Unknown EP {ep!r}. Supported: {sorted(_VALID_EPS)}") logger.info("GenaiSession initialized: bundle=%s ep=%s", self._bundle_dir, self._ep) @@ -202,8 +203,8 @@ def load(self) -> None: og = self._import_og() - # Register WinML EPs to ORT GenAI (skipped for CPU; idempotent). - if self._ep != "cpu": + # Register WinML EPs to ORT GenAI when the bundle may use a hardware EP. + if self._ep in _NEEDS_WINML_EPS: self._register_eps(og) if self._verbose: @@ -211,10 +212,10 @@ def load(self) -> None: try: config = og.Config(str(self._bundle_dir)) - config.clear_providers() - provider = _EP_PROVIDER_MAP[self._ep] - if provider is not None: - config.append_provider(provider) + # EP routing is driven entirely by genai_config.json (per-stage + # session_options). Do NOT call clear_providers/append_provider — + # those only touch the top-level provider and cannot override + # per-stage session_options already embedded in the pipeline config. self._model = og.Model(config) self._tokenizer = og.Tokenizer(self._model) except Exception as exc: diff --git a/tests/unit/session/test_genai_session.py b/tests/unit/session/test_genai_session.py index 4859ef11e..4dcf7ea1c 100644 --- a/tests/unit/session/test_genai_session.py +++ b/tests/unit/session/test_genai_session.py @@ -114,7 +114,7 @@ def test_bundle_dir_property(self, bundle_dir: Path) -> None: assert session.bundle_dir == bundle_dir def test_supported_eps(self, bundle_dir: Path) -> None: - for ep in ("cpu", "qnn", "dml"): + for ep in ("cpu", "mixed", "qnn", "dml"): session = GenaiSession(bundle_dir, ep=ep) assert session.ep == ep @@ -225,20 +225,27 @@ def test_non_cpu_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) session.load() mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True) - def test_non_cpu_appends_provider_to_config(self, bundle_dir: Path, mock_og: MagicMock) -> None: + def test_mixed_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) -> None: + mock_registry = MagicMock() + mock_registry.winml_available = True + mock_registry.register_execution_providers.return_value = { + "onnxruntime_genai": ["QNNExecutionProvider"] + } with ( _patch_og(mock_og), patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, ): - mock_reg_cls.get_instance.return_value = MagicMock(winml_available=False) - session = GenaiSession(bundle_dir, ep="qnn") + mock_reg_cls.get_instance.return_value = mock_registry + session = GenaiSession(bundle_dir, ep="mixed") session.load() - mock_og.Config.return_value.append_provider.assert_called_once_with("QNNExecutionProvider") + mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True) - def test_cpu_does_not_append_provider(self, bundle_dir: Path, mock_og: MagicMock) -> None: + def test_config_not_modified_at_load(self, bundle_dir: Path, mock_og: MagicMock) -> None: + # EP routing is driven by genai_config.json — we must NOT touch the config. with _patch_og(mock_og): session = GenaiSession(bundle_dir, ep="cpu") session.load() + mock_og.Config.return_value.clear_providers.assert_not_called() mock_og.Config.return_value.append_provider.assert_not_called() From d1d4b345617c083cb7f0e073b25843e01999446f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 16:36:47 +0800 Subject: [PATCH 07/12] feat: add --compile flag to infer_genai.py for EPContext pre-compilation - GenaiSession gains compile=True parameter - _prepare_compiled_bundle(): detects QNN stages from genai_config.json, compiles each stage to EPContext ONNX via ort.ModelCompiler in a subprocess - _compile_stage(): 5-minute timeout per stage to handle QNN SDK hang (known bug: w8a16 + multi-token prefill hangs indefinitely) - Compiled artifacts cached in bundle_dir/_compiled/; reused on subsequent runs - _mirror_non_onnx_files(): symlinks/copies tokenizer files so og.Config can load from the compiled sub-directory - infer_genai.py --compile flag wired through to GenaiSession --- scripts/infer_genai.py | 20 ++- src/winml/modelkit/session/genai_session.py | 190 +++++++++++++++++++- 2 files changed, 205 insertions(+), 5 deletions(-) diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py index f31847b9b..47aa6683e 100644 --- a/scripts/infer_genai.py +++ b/scripts/infer_genai.py @@ -26,6 +26,11 @@ uv run python scripts/infer_genai.py \\ --model-dir out/my_bundle --prompt "Hi" --ep cpu + # Pre-compile QNN stages to EPContext on first run; reuse cache on subsequent runs. + # Eliminates per-run JIT overhead (~60-90 s saved on Snapdragon X Elite). + uv run python scripts/infer_genai.py \\ + --prompt "Hello" --ep mixed --compile + Dependencies (install in a fresh venv):: pip install onnxruntime-genai-winml @@ -93,6 +98,17 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: action="store_true", help="Wrap --prompt in the ChatML template (<|im_start|>user/assistant).", ) + p.add_argument( + "--compile", + action="store_true", + help=( + "Pre-compile QNN pipeline stages to EPContext ONNX before loading. " + "On first use this triggers ort.ModelCompiler per stage (~60-90 s for iter). " + "Compiled artifacts are cached in bundle_dir/_compiled/; " + "subsequent runs reuse the cache and skip JIT. " + "Has no effect when --ep cpu." + ), + ) p.add_argument( "--verbose", action="store_true", @@ -109,7 +125,9 @@ def main(argv: list[str] | None = None) -> int: gen_cfg = GenerationConfig(max_new_tokens=args.max_new, do_sample=False) try: - session = GenaiSession(args.model_dir, ep=args.ep, verbose=args.verbose) + session = GenaiSession( + args.model_dir, ep=args.ep, verbose=args.verbose, compile=args.compile + ) except FileNotFoundError as exc: print(f"ERROR: {exc}", file=sys.stderr) return 1 diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 61423efe1..2069d3076 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -47,6 +47,7 @@ import json import logging +import shutil from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING @@ -149,8 +150,17 @@ class GenaiSession: ``None`` (default), read from ``genai_config.json``. Must match the ``--max-cache-len`` used during the winml-cli build. verbose: Enable ``onnxruntime-genai`` native model I/O logging. + compile: Pre-compile QNN pipeline stages to EPContext ONNX on first + run (inside ``bundle_dir/_compiled/``). Subsequent calls reuse + the cached EPContext files, eliminating per-run JIT overhead. + Only stages that can be compiled without hanging are attempted; + stages that fail compilation fall back to the original ONNX. + Has no effect when ``ep="cpu"``. """ + # Sub-directory within the bundle that holds pre-compiled EPContext ONNX files. + _COMPILED_SUBDIR: str = "_compiled" + def __init__( self, bundle_dir: str | Path, @@ -158,11 +168,13 @@ def __init__( *, context_length: int | None = None, verbose: bool = False, + compile: bool = False, ) -> None: self._bundle_dir = Path(bundle_dir) self._ep = ep.lower() self._context_length_override = context_length self._verbose = verbose + self._compile = compile # Resolved at load() time. self._context_length: int | None = None @@ -210,8 +222,13 @@ def load(self) -> None: if self._verbose: og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True) + # Determine which bundle directory og.Config should load from. + load_dir = self._bundle_dir + if self._compile and self._ep in _NEEDS_WINML_EPS: + load_dir = self._prepare_compiled_bundle() + try: - config = og.Config(str(self._bundle_dir)) + config = og.Config(str(load_dir)) # EP routing is driven entirely by genai_config.json (per-stage # session_options). Do NOT call clear_providers/append_provider — # those only touch the top-level provider and cannot override @@ -221,9 +238,7 @@ def load(self) -> None: except Exception as exc: self._model = None self._tokenizer = None - raise GenaiLoadError( - f"Failed to load genai bundle from {self._bundle_dir}: {exc}" - ) from exc + raise GenaiLoadError(f"Failed to load genai bundle from {load_dir}: {exc}") from exc self._context_length = self._context_length_override or self._read_context_length() logger.info( @@ -416,6 +431,173 @@ def _ensure_loaded(self) -> None: if self._model is None: self.load() + def _prepare_compiled_bundle(self) -> Path: + """Create (or reuse) a *compiled* bundle directory. + + Reads ``genai_config.json``, finds QNN-accelerated stages (those with + ``QNNExecutionProvider`` in their ``session_options``), and tries to + compile their ONNX to EPContext format using ``ort.ModelCompiler``. + + The compiled bundle is stored under ``bundle_dir/_compiled/``. On + every call the helper checks whether the cached EPContext file is + newer than the source ONNX; if so, it skips recompilation. + + Returns: + Path to the compiled bundle directory (may equal ``bundle_dir`` + if no compilable stages were found, or if all compilations failed). + """ + compiled_dir = self._bundle_dir / self._COMPILED_SUBDIR + config_src = self._bundle_dir / "genai_config.json" + cfg = json.loads(config_src.read_text(encoding="utf-8")) + + # Collect pipeline stages that use QNNExecutionProvider. + # genai_config pipeline entries: {"ctx": {...}, "iter": {...}, ...} + pipeline: dict = cfg.get("model", {}).get("decoder", {}) + qnn_stages: list[tuple[str, str]] = [] # [(stage_key, onnx_filename), ...] + for stage_key, stage_cfg in pipeline.items(): + if not isinstance(stage_cfg, dict): + continue + so = stage_cfg.get("session_options", {}) + providers = so.get("provider_options", []) + for p in providers: + if isinstance(p, dict) and "QNNExecutionProvider" in p: + onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") + qnn_stages.append((stage_key, onnx_filename)) + break + + if not qnn_stages: + logger.info("No QNN stages found in genai_config.json; skipping compilation") + return self._bundle_dir + + compiled_dir.mkdir(exist_ok=True) + modified_cfg = json.loads(config_src.read_text(encoding="utf-8")) + any_compiled = False + + for stage_key, onnx_filename in qnn_stages: + src_onnx = self._bundle_dir / onnx_filename + ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx" + + # Skip recompilation if cache is up-to-date. + if ctx_onnx.exists() and ctx_onnx.stat().st_mtime >= src_onnx.stat().st_mtime: + logger.info("Stage %r: reusing cached EPContext %s", stage_key, ctx_onnx.name) + self._patch_stage_filename(modified_cfg, stage_key, str(ctx_onnx)) + any_compiled = True + continue + + # Attempt compilation. + success = self._compile_stage(src_onnx, ctx_onnx, stage_key) + if success: + self._patch_stage_filename(modified_cfg, stage_key, str(ctx_onnx)) + any_compiled = True + else: + logger.warning( + "Stage %r: compilation failed or was skipped; using original ONNX", stage_key + ) + + if not any_compiled: + return self._bundle_dir + + # Write the modified genai_config into the compiled sub-directory so that + # ort-genai can resolve all ONNX paths (absolute paths are used). + # Also symlink/copy every other file that og.Config expects. + compiled_config = compiled_dir / "genai_config.json" + compiled_config.write_text( + json.dumps(modified_cfg, indent=2, ensure_ascii=False), encoding="utf-8" + ) + self._mirror_non_onnx_files(compiled_dir) + + logger.info("Compiled bundle prepared at %s", compiled_dir) + return compiled_dir + + @staticmethod + def _patch_stage_filename(cfg: dict, stage_key: str, abs_path: str) -> None: + """Rewrite a pipeline stage's ``filename`` to an absolute path.""" + decoder: dict = cfg.get("model", {}).get("decoder", {}) + if stage_key in decoder and isinstance(decoder[stage_key], dict): + decoder[stage_key]["filename"] = abs_path + + def _compile_stage(self, src_onnx: Path, ctx_out: Path, stage_key: str) -> bool: + """Compile *src_onnx* to EPContext format via ``ort.ModelCompiler``. + + Runs in a subprocess so that a ModelCompiler hang (a known QNN SDK bug + with w8a16 + multi-token prefill) does not block the caller. + + Args: + src_onnx: Source ONNX file path. + ctx_out: Destination EPContext ONNX path. + stage_key: Human-readable label for logging. + + Returns: + ``True`` if compilation succeeded; ``False`` on timeout or error. + """ + import multiprocessing + + compile_timeout_s = 300 # 5 minutes; iter compiles in ~67s normally + + logger.info("Compiling stage %r: %s → %s", stage_key, src_onnx.name, ctx_out.name) + + def _do_compile(src: str, dst: str) -> None: + import onnxruntime as ort + + from winml.modelkit.session.ep_registry import WinMLEPRegistry + from winml.modelkit.winml import add_ep_for_device + + registry = WinMLEPRegistry.get_instance() + registry.register_execution_providers() + so = ort.SessionOptions() + so.add_session_config_entry("ep.context_enable", "1") + so.add_session_config_entry("ep.context_file_path", dst) + add_ep_for_device(so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU) + mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False) + mc.compile_to_file(dst) + + ctx = multiprocessing.get_context("spawn") + proc = ctx.Process(target=_do_compile, args=(str(src_onnx), str(ctx_out))) + proc.start() + proc.join(timeout=compile_timeout_s) + + if proc.is_alive(): + logger.warning( + "Stage %r compilation timed out after %ds (QNN SDK hang). " + "This is a known issue with multi-token prefill + w8a16 quantization. " + "Falling back to JIT compilation for this stage.", + stage_key, + compile_timeout_s, + ) + proc.kill() + proc.join() + # Remove partial output file. + ctx_out.unlink(missing_ok=True) + return False + + if proc.exitcode != 0: + logger.warning("Stage %r compilation failed (exit %d)", stage_key, proc.exitcode) + ctx_out.unlink(missing_ok=True) + return False + + logger.info("Stage %r compiled successfully → %s", stage_key, ctx_out) + return True + + def _mirror_non_onnx_files(self, compiled_dir: Path) -> None: + """Create symlinks (or copies on Windows) for every non-ONNX file. + + Files are linked/copied into *compiled_dir* so that ``og.Config`` + finds tokenizer files, specials maps, etc. Existing files are left + untouched. + """ + for src in self._bundle_dir.iterdir(): + if src.name == self._COMPILED_SUBDIR: + continue + dst = compiled_dir / src.name + if dst.exists(): + continue + if src.is_file(): + try: + dst.symlink_to(src.resolve()) + except (OSError, NotImplementedError): + # Symlinks may require elevated privileges on Windows; fall back to copy. + shutil.copy2(src, dst) + @staticmethod def _import_og() -> object: """Import and return the ``onnxruntime_genai`` module. From 5e3162b108b7479034dde651db1548f016fd739e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 16:46:05 +0800 Subject: [PATCH 08/12] fix: resolve prefill compilation hang by forcing htp_graph_finalization_optimization_mode=0 Root cause: QNN SDK ModelCompiler deadlocks when compiling w8a16 quantized ONNX with multi-token static input shapes (seq_len > 1) at graph finalization optimization levels 1-3. The genai_config uses level 3 for runtime inference, which triggers the hang when passed to ModelCompiler directly. Fix: _compile_stage now forces htp_graph_finalization_optimization_mode=0 for compilation. This lets ModelCompiler finish (ctx ~41s, iter ~67s) while runtime inference still uses the full level-3 optimization from genai_config (EPContext loading bypasses compilation entirely, so the runtime option is irrelevant). Also fixes: - Pipeline stage detection: genai_config uses 'qnn' key (not 'QNNExecutionProvider') in provider_options; detection and option extraction now uses the correct key - _patch_stage_filename: genai_config pipeline is a list, not a dict; updated to iterate list entries correctly - _prepare_compiled_bundle: passes QNN provider options from each stage's session_options to _compile_stage so soc_model, backend_path, etc. are respected - Removed the 'prefill fallback to JIT' warning since the hang is now fixed --- src/winml/modelkit/session/genai_session.py | 103 +++++++++++++------- 1 file changed, 70 insertions(+), 33 deletions(-) diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 2069d3076..25b2ce311 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -450,20 +450,25 @@ def _prepare_compiled_bundle(self) -> Path: config_src = self._bundle_dir / "genai_config.json" cfg = json.loads(config_src.read_text(encoding="utf-8")) - # Collect pipeline stages that use QNNExecutionProvider. - # genai_config pipeline entries: {"ctx": {...}, "iter": {...}, ...} - pipeline: dict = cfg.get("model", {}).get("decoder", {}) - qnn_stages: list[tuple[str, str]] = [] # [(stage_key, onnx_filename), ...] - for stage_key, stage_cfg in pipeline.items(): - if not isinstance(stage_cfg, dict): + # Collect pipeline stages that use a QNN EP ("qnn" key in provider_options). + # genai_config pipeline entries: [{"context": {...}}, {"iterator": {...}}, ...] + # provider_options format: [{"qnn": {...}}] + pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) + # [(stage_key, onnx_filename, qnn_opts), ...] + qnn_stages: list[tuple[str, str, dict]] = [] + for stage_entry in pipeline_list: + if not isinstance(stage_entry, dict): continue - so = stage_cfg.get("session_options", {}) - providers = so.get("provider_options", []) - for p in providers: - if isinstance(p, dict) and "QNNExecutionProvider" in p: - onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") - qnn_stages.append((stage_key, onnx_filename)) - break + for stage_key, stage_cfg in stage_entry.items(): + if not isinstance(stage_cfg, dict): + continue + so = stage_cfg.get("session_options", {}) + providers = so.get("provider_options", []) + for p in providers: + if isinstance(p, dict) and "qnn" in p: + onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") + qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]))) + break if not qnn_stages: logger.info("No QNN stages found in genai_config.json; skipping compilation") @@ -473,7 +478,7 @@ def _prepare_compiled_bundle(self) -> Path: modified_cfg = json.loads(config_src.read_text(encoding="utf-8")) any_compiled = False - for stage_key, onnx_filename in qnn_stages: + for stage_key, onnx_filename, qnn_opts in qnn_stages: src_onnx = self._bundle_dir / onnx_filename ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx" @@ -485,13 +490,13 @@ def _prepare_compiled_bundle(self) -> Path: continue # Attempt compilation. - success = self._compile_stage(src_onnx, ctx_onnx, stage_key) + success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts) if success: self._patch_stage_filename(modified_cfg, stage_key, str(ctx_onnx)) any_compiled = True else: logger.warning( - "Stage %r: compilation failed or was skipped; using original ONNX", stage_key + "Stage %r: compilation failed; using original ONNX (JIT fallback)", stage_key ) if not any_compiled: @@ -512,31 +517,64 @@ def _prepare_compiled_bundle(self) -> Path: @staticmethod def _patch_stage_filename(cfg: dict, stage_key: str, abs_path: str) -> None: """Rewrite a pipeline stage's ``filename`` to an absolute path.""" - decoder: dict = cfg.get("model", {}).get("decoder", {}) - if stage_key in decoder and isinstance(decoder[stage_key], dict): - decoder[stage_key]["filename"] = abs_path - - def _compile_stage(self, src_onnx: Path, ctx_out: Path, stage_key: str) -> bool: + pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) + for stage_entry in pipeline_list: + if isinstance(stage_entry, dict) and stage_key in stage_entry: + stage_cfg = stage_entry[stage_key] + if isinstance(stage_cfg, dict): + stage_cfg["filename"] = abs_path + return + + def _compile_stage( + self, + src_onnx: Path, + ctx_out: Path, + stage_key: str, + qnn_opts: dict | None = None, + ) -> bool: """Compile *src_onnx* to EPContext format via ``ort.ModelCompiler``. - Runs in a subprocess so that a ModelCompiler hang (a known QNN SDK bug - with w8a16 + multi-token prefill) does not block the caller. + Runs in a subprocess so that a ModelCompiler failure does not block + the caller. The QNN options from ``genai_config.json`` are forwarded + to the compilation session, with ``htp_graph_finalization_optimization_mode`` + forced to ``"0"``. This avoids a QNN SDK deadlock that occurs when + compiling w8a16 quantized models with multi-token static input shapes + (``seq_len > 1``) at higher optimization levels. + + The resulting EPContext ONNX is identical in interface to the original; + at runtime, ort-genai loads the pre-compiled QNN binary and the + inference-time ``htp_graph_finalization_optimization_mode`` from + ``genai_config.json`` governs any further JIT compilation. Args: src_onnx: Source ONNX file path. ctx_out: Destination EPContext ONNX path. stage_key: Human-readable label for logging. + qnn_opts: QNN provider options from genai_config (e.g. backend_path, + htp_performance_mode, soc_model). ``htp_graph_finalization_ + optimization_mode`` is always overridden to ``"0"``. Returns: ``True`` if compilation succeeded; ``False`` on timeout or error. """ import multiprocessing - compile_timeout_s = 300 # 5 minutes; iter compiles in ~67s normally + # Force graph-finalization optimization off. Levels 1-3 deadlock QNN + # ModelCompiler for w8a16 quantized models with multi-token input shapes. + compile_qnn_opts = dict(qnn_opts or {}) + compile_qnn_opts["htp_graph_finalization_optimization_mode"] = "0" - logger.info("Compiling stage %r: %s → %s", stage_key, src_onnx.name, ctx_out.name) + compile_timeout_s = 300 # 5 minutes; ctx compiles in ~41s, iter in ~67s - def _do_compile(src: str, dst: str) -> None: + logger.info( + "Compiling stage %r: %s → %s (qnn_opts=%s)", + stage_key, + src_onnx.name, + ctx_out.name, + compile_qnn_opts, + ) + + def _do_compile(src: str, dst: str, qnn_options: dict) -> None: import onnxruntime as ort from winml.modelkit.session.ep_registry import WinMLEPRegistry @@ -547,26 +585,25 @@ def _do_compile(src: str, dst: str) -> None: so = ort.SessionOptions() so.add_session_config_entry("ep.context_enable", "1") so.add_session_config_entry("ep.context_file_path", dst) - add_ep_for_device(so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU) + add_ep_for_device( + so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU, qnn_options + ) mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False) mc.compile_to_file(dst) ctx = multiprocessing.get_context("spawn") - proc = ctx.Process(target=_do_compile, args=(str(src_onnx), str(ctx_out))) + proc = ctx.Process(target=_do_compile, args=(str(src_onnx), str(ctx_out), compile_qnn_opts)) proc.start() proc.join(timeout=compile_timeout_s) if proc.is_alive(): - logger.warning( - "Stage %r compilation timed out after %ds (QNN SDK hang). " - "This is a known issue with multi-token prefill + w8a16 quantization. " - "Falling back to JIT compilation for this stage.", + logger.error( + "Stage %r compilation timed out after %ds — killing subprocess.", stage_key, compile_timeout_s, ) proc.kill() proc.join() - # Remove partial output file. ctx_out.unlink(missing_ok=True) return False From c95c74b56137586ba56e3254d5bf652b84c99c72 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 16:55:25 +0800 Subject: [PATCH 09/12] fix: move _do_compile to module-level _qnn_compile_worker for Windows spawn Windows multiprocessing spawn serialises the subprocess target via pickle. Local functions (closures) defined inside a method cannot be pickled, which caused 'AttributeError: Can't pickle local function' at runtime. Moved the compilation logic to a module-level function _qnn_compile_worker so it is importable by name in the spawned subprocess. Also fix ONNX filename in compiled genai_config: use ctx_onnx.name (just the filename) instead of str(ctx_onnx) (absolute path). ort-genai resolves filenames relative to the directory passed to og.Config, so an absolute path causes double-path concatenation and a 'file not found' error. --- src/winml/modelkit/session/genai_session.py | 60 +++++++++++++-------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 25b2ce311..5943c0266 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -61,6 +61,33 @@ logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Module-level compilation worker (must be at module scope for multiprocessing +# spawn on Windows, which serialises the target via pickle). +# --------------------------------------------------------------------------- + + +def _qnn_compile_worker(src: str, dst: str, qnn_options: dict) -> None: + """Compile *src* ONNX to an EPContext ONNX at *dst* using QNN HTP. + + Executed in a subprocess by :meth:`GenaiSession._compile_stage`. + """ + import onnxruntime as ort + + from winml.modelkit.session.ep_registry import WinMLEPRegistry + from winml.modelkit.winml import add_ep_for_device + + registry = WinMLEPRegistry.get_instance() + registry.register_execution_providers() + so = ort.SessionOptions() + so.add_session_config_entry("ep.context_enable", "1") + so.add_session_config_entry("ep.context_file_path", dst) + add_ep_for_device(so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU, qnn_options) + mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False) + mc.compile_to_file(dst) + + # --------------------------------------------------------------------------- # Valid EP short names. # "mixed" = use genai_config.json as-is (embeddings/lm_head on CPU, @@ -485,14 +512,16 @@ def _prepare_compiled_bundle(self) -> Path: # Skip recompilation if cache is up-to-date. if ctx_onnx.exists() and ctx_onnx.stat().st_mtime >= src_onnx.stat().st_mtime: logger.info("Stage %r: reusing cached EPContext %s", stage_key, ctx_onnx.name) - self._patch_stage_filename(modified_cfg, stage_key, str(ctx_onnx)) + # Use just the filename — genai_config.json lives in compiled_dir, + # so ort-genai resolves filenames relative to compiled_dir. + self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) any_compiled = True continue # Attempt compilation. success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts) if success: - self._patch_stage_filename(modified_cfg, stage_key, str(ctx_onnx)) + self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) any_compiled = True else: logger.warning( @@ -502,9 +531,9 @@ def _prepare_compiled_bundle(self) -> Path: if not any_compiled: return self._bundle_dir - # Write the modified genai_config into the compiled sub-directory so that - # ort-genai can resolve all ONNX paths (absolute paths are used). - # Also symlink/copy every other file that og.Config expects. + # Write the modified genai_config into the compiled sub-directory. + # ONNX filenames are relative to compiled_dir; ort-genai resolves them + # from the directory it loads og.Config from. compiled_config = compiled_dir / "genai_config.json" compiled_config.write_text( json.dumps(modified_cfg, indent=2, ensure_ascii=False), encoding="utf-8" @@ -574,25 +603,10 @@ def _compile_stage( compile_qnn_opts, ) - def _do_compile(src: str, dst: str, qnn_options: dict) -> None: - import onnxruntime as ort - - from winml.modelkit.session.ep_registry import WinMLEPRegistry - from winml.modelkit.winml import add_ep_for_device - - registry = WinMLEPRegistry.get_instance() - registry.register_execution_providers() - so = ort.SessionOptions() - so.add_session_config_entry("ep.context_enable", "1") - so.add_session_config_entry("ep.context_file_path", dst) - add_ep_for_device( - so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU, qnn_options - ) - mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False) - mc.compile_to_file(dst) - ctx = multiprocessing.get_context("spawn") - proc = ctx.Process(target=_do_compile, args=(str(src_onnx), str(ctx_out), compile_qnn_opts)) + proc = ctx.Process( + target=_qnn_compile_worker, args=(str(src_onnx), str(ctx_out), compile_qnn_opts) + ) proc.start() proc.join(timeout=compile_timeout_s) From 8cdc9f82576d43f6125017b48228164092df3e8d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 17:08:51 +0800 Subject: [PATCH 10/12] perf: use configured htp_graph_finalization_optimization_mode for gen stages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously _compile_stage forced mode='0' for ALL stages to avoid a QNN SDK deadlock on w8a16 + multi-token prefill. This also silently capped the iter (generation) stage at mode 0, producing under-optimized kernels (~10 tok/s). Fix: only force mode=0 for prefill stages (run_on_prompt=true, seq_len>1 where the deadlock occurs). Generation stages (run_on_token_gen=true, seq_len=1) use the configured mode from genai_config.json (typically '3'), which is safe for single-token input and produces fully-optimized kernels. Performance: Before: 10.4 tok/s (both ctx+iter compiled with mode 0) After: 43.4 tok/s (ctx mode 0, iter mode 3) — matches reference ~45 tok/s _prepare_compiled_bundle now passes is_prefill flag per stage based on run_on_prompt / run_on_token_gen fields in genai_config.json pipeline config. --- src/winml/modelkit/session/genai_session.py | 50 +++++++++++++-------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 5943c0266..8896295b7 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -481,8 +481,12 @@ def _prepare_compiled_bundle(self) -> Path: # genai_config pipeline entries: [{"context": {...}}, {"iterator": {...}}, ...] # provider_options format: [{"qnn": {...}}] pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) - # [(stage_key, onnx_filename, qnn_opts), ...] - qnn_stages: list[tuple[str, str, dict]] = [] + # [(stage_key, onnx_filename, qnn_opts, is_prefill), ...] + # is_prefill=True when run_on_prompt=True and run_on_token_gen=False. + # Prefill stages with seq_len>1 require htp_graph_finalization_optimization_mode="0" + # to avoid a QNN SDK deadlock; generation stages (seq_len=1) can use the full + # configured optimization level for maximum throughput. + qnn_stages: list[tuple[str, str, dict, bool]] = [] for stage_entry in pipeline_list: if not isinstance(stage_entry, dict): continue @@ -494,7 +498,11 @@ def _prepare_compiled_bundle(self) -> Path: for p in providers: if isinstance(p, dict) and "qnn" in p: onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") - qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]))) + is_prefill = bool( + stage_cfg.get("run_on_prompt", False) + and not stage_cfg.get("run_on_token_gen", False) + ) + qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]), is_prefill)) break if not qnn_stages: @@ -505,7 +513,7 @@ def _prepare_compiled_bundle(self) -> Path: modified_cfg = json.loads(config_src.read_text(encoding="utf-8")) any_compiled = False - for stage_key, onnx_filename, qnn_opts in qnn_stages: + for stage_key, onnx_filename, qnn_opts, is_prefill in qnn_stages: src_onnx = self._bundle_dir / onnx_filename ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx" @@ -519,7 +527,7 @@ def _prepare_compiled_bundle(self) -> Path: continue # Attempt compilation. - success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts) + success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts, is_prefill) if success: self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) any_compiled = True @@ -560,38 +568,42 @@ def _compile_stage( ctx_out: Path, stage_key: str, qnn_opts: dict | None = None, + is_prefill: bool = False, ) -> bool: """Compile *src_onnx* to EPContext format via ``ort.ModelCompiler``. Runs in a subprocess so that a ModelCompiler failure does not block - the caller. The QNN options from ``genai_config.json`` are forwarded - to the compilation session, with ``htp_graph_finalization_optimization_mode`` - forced to ``"0"``. This avoids a QNN SDK deadlock that occurs when - compiling w8a16 quantized models with multi-token static input shapes - (``seq_len > 1``) at higher optimization levels. + the caller. QNN options from ``genai_config.json`` are forwarded to + the compilation session. - The resulting EPContext ONNX is identical in interface to the original; - at runtime, ort-genai loads the pre-compiled QNN binary and the - inference-time ``htp_graph_finalization_optimization_mode`` from - ``genai_config.json`` governs any further JIT compilation. + For prefill stages (``is_prefill=True``) ``htp_graph_finalization_ + optimization_mode`` is forced to ``"0"`` to avoid a QNN SDK deadlock + that occurs when compiling w8a16 quantized models with multi-token + static input shapes (``seq_len > 1``) at higher optimization levels. + For generation stages (``is_prefill=False``, ``seq_len=1``) the + configured optimization level is preserved so that the compiled kernels + are as fast as the JIT path. Args: src_onnx: Source ONNX file path. ctx_out: Destination EPContext ONNX path. stage_key: Human-readable label for logging. qnn_opts: QNN provider options from genai_config (e.g. backend_path, - htp_performance_mode, soc_model). ``htp_graph_finalization_ - optimization_mode`` is always overridden to ``"0"``. + htp_performance_mode, soc_model). + is_prefill: ``True`` when the stage runs only on prompt (ctx) and has + multi-token input; forces ``htp_graph_finalization_optimization_mode`` + to ``"0"`` to avoid the QNN SDK deadlock. Returns: ``True`` if compilation succeeded; ``False`` on timeout or error. """ import multiprocessing - # Force graph-finalization optimization off. Levels 1-3 deadlock QNN - # ModelCompiler for w8a16 quantized models with multi-token input shapes. compile_qnn_opts = dict(qnn_opts or {}) - compile_qnn_opts["htp_graph_finalization_optimization_mode"] = "0" + if is_prefill: + # QNN SDK deadlocks at levels 1-3 for w8a16 models with seq_len > 1. + compile_qnn_opts["htp_graph_finalization_optimization_mode"] = "0" + # else: keep the configured mode (typically "3") for generation stages. compile_timeout_s = 300 # 5 minutes; ctx compiles in ~41s, iter in ~67s From dd1e9ab154b344a2b0175e439e62b544d291530d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 17:28:26 +0800 Subject: [PATCH 11/12] simplify: remove htp_graph_finalization_optimization_mode override in _compile_stage The original mode=0 override was added to avoid a QNN SDK deadlock when compiling w8a16 prefill (seq_len>1) at higher optimization levels. Testing revealed the deadlock only occurs when QNN provider options are NOT passed to ort.ModelCompiler at all (causing it to fall back to a broken default path). With correct QNN options (backend_path, soc_model, etc.) forwarded, mode=3 compiles successfully for both ctx (~73s) and iter (~67s) with no hang. Remove the is_prefill flag and mode override entirely. _compile_stage now passes genai_config QNN options unchanged, giving fully-optimized kernels for all stages. Performance (hot NPU, EPContext loaded): ctx+iter both mode=3: ~44.5 tok/s vs reference ~45 tok/s --- src/winml/modelkit/session/genai_session.py | 45 +++++---------------- 1 file changed, 11 insertions(+), 34 deletions(-) diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 8896295b7..34e2ef5e4 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -481,12 +481,8 @@ def _prepare_compiled_bundle(self) -> Path: # genai_config pipeline entries: [{"context": {...}}, {"iterator": {...}}, ...] # provider_options format: [{"qnn": {...}}] pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) - # [(stage_key, onnx_filename, qnn_opts, is_prefill), ...] - # is_prefill=True when run_on_prompt=True and run_on_token_gen=False. - # Prefill stages with seq_len>1 require htp_graph_finalization_optimization_mode="0" - # to avoid a QNN SDK deadlock; generation stages (seq_len=1) can use the full - # configured optimization level for maximum throughput. - qnn_stages: list[tuple[str, str, dict, bool]] = [] + # [(stage_key, onnx_filename, qnn_opts), ...] + qnn_stages: list[tuple[str, str, dict]] = [] for stage_entry in pipeline_list: if not isinstance(stage_entry, dict): continue @@ -498,11 +494,7 @@ def _prepare_compiled_bundle(self) -> Path: for p in providers: if isinstance(p, dict) and "qnn" in p: onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") - is_prefill = bool( - stage_cfg.get("run_on_prompt", False) - and not stage_cfg.get("run_on_token_gen", False) - ) - qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]), is_prefill)) + qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]))) break if not qnn_stages: @@ -513,7 +505,7 @@ def _prepare_compiled_bundle(self) -> Path: modified_cfg = json.loads(config_src.read_text(encoding="utf-8")) any_compiled = False - for stage_key, onnx_filename, qnn_opts, is_prefill in qnn_stages: + for stage_key, onnx_filename, qnn_opts in qnn_stages: src_onnx = self._bundle_dir / onnx_filename ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx" @@ -527,7 +519,7 @@ def _prepare_compiled_bundle(self) -> Path: continue # Attempt compilation. - success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts, is_prefill) + success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts) if success: self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) any_compiled = True @@ -568,31 +560,21 @@ def _compile_stage( ctx_out: Path, stage_key: str, qnn_opts: dict | None = None, - is_prefill: bool = False, ) -> bool: """Compile *src_onnx* to EPContext format via ``ort.ModelCompiler``. Runs in a subprocess so that a ModelCompiler failure does not block - the caller. QNN options from ``genai_config.json`` are forwarded to - the compilation session. - - For prefill stages (``is_prefill=True``) ``htp_graph_finalization_ - optimization_mode`` is forced to ``"0"`` to avoid a QNN SDK deadlock - that occurs when compiling w8a16 quantized models with multi-token - static input shapes (``seq_len > 1``) at higher optimization levels. - For generation stages (``is_prefill=False``, ``seq_len=1``) the - configured optimization level is preserved so that the compiled kernels - are as fast as the JIT path. + the caller. The QNN options from ``genai_config.json`` are forwarded + unchanged to the compilation session, so each stage is compiled at + exactly the optimization level configured in the bundle. Args: src_onnx: Source ONNX file path. ctx_out: Destination EPContext ONNX path. stage_key: Human-readable label for logging. qnn_opts: QNN provider options from genai_config (e.g. backend_path, - htp_performance_mode, soc_model). - is_prefill: ``True`` when the stage runs only on prompt (ctx) and has - multi-token input; forces ``htp_graph_finalization_optimization_mode`` - to ``"0"`` to avoid the QNN SDK deadlock. + htp_performance_mode, htp_graph_finalization_optimization_mode, + soc_model). Returns: ``True`` if compilation succeeded; ``False`` on timeout or error. @@ -600,12 +582,7 @@ def _compile_stage( import multiprocessing compile_qnn_opts = dict(qnn_opts or {}) - if is_prefill: - # QNN SDK deadlocks at levels 1-3 for w8a16 models with seq_len > 1. - compile_qnn_opts["htp_graph_finalization_optimization_mode"] = "0" - # else: keep the configured mode (typically "3") for generation stages. - - compile_timeout_s = 300 # 5 minutes; ctx compiles in ~41s, iter in ~67s + compile_timeout_s = 300 # 5 minutes; ctx compiles in ~73s, iter in ~67s logger.info( "Compiling stage %r: %s → %s (qnn_opts=%s)", From 7c2d067d104d5ab4f1df01353fabc9bb10b5c029 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 21:44:25 +0800 Subject: [PATCH 12/12] refactor(genai): move generic bundle logic to utils/genai, qwen3/genai as shim - Extract all architecture-agnostic logic (PipelineStage, DecoderIOMapping, build_genai_config, build_decoder_pipeline_stages, write_genai_bundle, qnn_stage_session_options, ONNX introspection helpers) into src/winml/modelkit/utils/genai.py so other model families can reuse it - Reduce qwen3/genai.py to a thin re-export shim with a backward-compatible build_qwen3_transformer_only_stages alias for existing callers - fix(codeql): remove unused _TOKENIZER_FILES from utils/genai.py - fix(codeql): remove unnecessary del generator in GenaiSession.generate_streaming - fix(codeql): add missing Protocol body ellipsis in QuantConfigFinalizer.finalize - fix(codeql): import get_quant_finalizer directly in quant/__init__.py - fix(test): update mock patch path to winml.modelkit.utils.genai._introspect_onnx_io - fix(test): replace bare 'import onnx' with 'from onnx import ...' in test_qwen3_calibration.py --- .../modelkit/models/hf/qwen3/__init__.py | 2 + src/winml/modelkit/models/hf/qwen3/genai.py | 672 +----------------- src/winml/modelkit/quant/__init__.py | 12 +- src/winml/modelkit/quant/calibration/base.py | 1 + src/winml/modelkit/session/genai_session.py | 21 +- src/winml/modelkit/utils/genai.py | 663 +++++++++++++++++ tests/unit/models/qwen3/test_genai_config.py | 2 +- .../calibration/test_qwen3_calibration.py | 25 +- 8 files changed, 724 insertions(+), 674 deletions(-) create mode 100644 src/winml/modelkit/utils/genai.py diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py index 8d8676398..dbabe2d60 100644 --- a/src/winml/modelkit/models/hf/qwen3/__init__.py +++ b/src/winml/modelkit/models/hf/qwen3/__init__.py @@ -15,6 +15,7 @@ from .genai import ( DecoderIOMapping, PipelineStage, + build_decoder_pipeline_stages, build_genai_config, build_qwen3_transformer_only_stages, write_genai_bundle, @@ -24,6 +25,7 @@ __all__ = [ "DecoderIOMapping", "PipelineStage", + "build_decoder_pipeline_stages", "build_genai_config", "build_qwen3_transformer_only_stages", "write_genai_bundle", diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py index 4a63de45b..9e65908f5 100644 --- a/src/winml/modelkit/models/hf/qwen3/genai.py +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -2,659 +2,40 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -r"""Generate an onnxruntime-genai bundle for a transformer-only decoder pipeline. +"""Qwen3 genai bundle support — thin shim over :mod:`winml.modelkit.utils.genai`. -The bundle is a directory that ``onnxruntime-genai`` can load directly via -``og.Config(str(bundle_dir))``. It contains: +All generic logic (``PipelineStage``, ``DecoderIOMapping``, ``build_genai_config``, +``build_decoder_pipeline_stages``, ``write_genai_bundle``) lives in +:mod:`winml.modelkit.utils.genai` so it can be reused by other model families. - genai_config.json — pipeline config consumed by onnxruntime-genai - ctx.onnx — prefill/context ONNX (built by winml-cli) - iter.onnx — iteration/decode ONNX (built by winml-cli) - embeddings.onnx — embedding-lookup ONNX (placeholder; copy externally) - lm_head.onnx — lm_head ONNX (placeholder; copy externally) - tokenizer.json — HF tokenizer files (downloaded from the model repo) - tokenizer_config.json - vocab.json / merges.txt / generation_config.json - -The pipeline follows the same 4-stage layout as the reference bundle: - - input_ids → [embeddings] → input_hidden_states - → [context | iterator] → output_hidden_states + present KVs - → [lm_head] → logits - -The context stage runs on the prompt (prefill); the iterator stage runs on each -subsequent decode step. Both share the same KV cache buffer via genai's -``past_present_share_buffer`` mode. - -Public API:: - - from winml.modelkit.models.hf.qwen3.genai import ( - build_genai_config, - build_qwen3_transformer_only_stages, - write_genai_bundle, - DecoderIOMapping, - PipelineStage, - ) - - # High-level: derive everything from the built ONNX files - stages, decoder_io = build_qwen3_transformer_only_stages( - ctx_path, iter_path, num_layers=hf_config.num_hidden_layers - ) - cfg = build_genai_config( - hf_config, max_cache_len=256, prefill_seq_len=64, - pipeline=stages, decoder_io=decoder_io, - ) - - # Or one-shot bundle assembly - write_genai_bundle( - Path("out/bundle"), - context_onnx=ctx_path, - iterator_onnx=iter_path, - model_id="Qwen/Qwen3-0.6B", - max_cache_len=256, - prefill_seq_len=64, - embeddings_src=emb_path, # None = skip (add later) - lm_head_src=lmh_path, # None = skip (add later) - ) +This module re-exports that API unchanged and adds +``build_qwen3_transformer_only_stages`` as a backward-compatible alias for +``build_decoder_pipeline_stages``. New code should prefer the generic names. """ from __future__ import annotations -import json -import logging -import re -from dataclasses import dataclass -from pathlib import Path -from typing import Any - - -logger = logging.getLogger(__name__) - -# Default filenames inside the bundle directory. -DEFAULT_EMBEDDINGS_FILENAME = "embeddings.onnx" -DEFAULT_CONTEXT_FILENAME = "ctx.onnx" -DEFAULT_ITERATOR_FILENAME = "iter.onnx" -DEFAULT_LM_HEAD_FILENAME = "lm_head.onnx" - -# Tokenizer files written by AutoTokenizer.save_pretrained. -_TOKENIZER_FILES = [ - "tokenizer.json", - "tokenizer_config.json", - "vocab.json", - "merges.txt", - "generation_config.json", - "special_tokens_map.json", -] - -# Regex for detecting indexed tensor names such as ``past_keys_3``. -_KV_INDEXED_RE = re.compile(r"^(.+?)(\d+)$") - - -# --------------------------------------------------------------------------- -# Pipeline data structures -# --------------------------------------------------------------------------- - - -@dataclass -class PipelineStage: - """One stage in an onnxruntime-genai multi-model pipeline. - - Attributes: - name: Stage key used inside the ``pipeline`` list of ``genai_config.json``. - filename: ONNX filename inside the bundle directory. - run_on_prompt: Whether genai runs this stage during the prefill pass. - run_on_token_gen: Whether genai runs this stage during decode steps. - inputs: Actual ONNX input tensor names (not format strings). - outputs: Actual ONNX output tensor names (not format strings). - is_lm_head: Set ``True`` for the final language-model head stage. - """ - - name: str - filename: str - run_on_prompt: bool - run_on_token_gen: bool - inputs: list[str] - outputs: list[str] - is_lm_head: bool = False - session_options: dict | None = None - """Per-stage ORT session options (e.g. provider_options for QNN). - - When set, emitted verbatim as the ``session_options`` key in the - ``genai_config.json`` pipeline stage. Leave ``None`` (default) for - stages that should run on the default (CPU) provider. - """ - - def to_dict(self) -> dict: - """Serialize to the dict format expected by ``genai_config.json``.""" - d: dict = { - "filename": self.filename, - "inputs": list(self.inputs), - "outputs": list(self.outputs), - "run_on_prompt": self.run_on_prompt, - "run_on_token_gen": self.run_on_token_gen, - } - if self.session_options: - d["session_options"] = self.session_options - if self.is_lm_head: - d["is_lm_head"] = True - return d - - -@dataclass -class DecoderIOMapping: - """Maps genai's abstract I/O concepts to ONNX tensor name format strings. - - The ``*_names`` fields use ``%d`` as the layer-index placeholder, which is - the convention genai uses to expand per-layer KV cache tensor names - (e.g. ``"past_keys_%d"`` → ``"past_keys_0"``, ``"past_keys_1"``, …). - - All fields default to the names produced by the Qwen3 transformer-only - export. - """ - - input_ids: str = "input_ids" - past_sequence_length: str = "past_seq_len" - total_sequence_length: str = "total_seq_len" - past_key_names: str = "past_keys_%d" - past_value_names: str = "past_values_%d" - logits: str = "logits" - present_key_names: str = "present_keys_%d" - present_value_names: str = "present_values_%d" - - def inputs_dict(self) -> dict: - """Return the ``decoder.inputs`` mapping dict for ``genai_config.json``.""" - return { - "input_ids": self.input_ids, - "past_sequence_length": self.past_sequence_length, - "total_sequence_length": self.total_sequence_length, - "past_key_names": self.past_key_names, - "past_value_names": self.past_value_names, - } - - def outputs_dict(self) -> dict: - """Return the ``decoder.outputs`` mapping dict for ``genai_config.json``.""" - return { - "logits": self.logits, - "present_key_names": self.present_key_names, - "present_value_names": self.present_value_names, - } - - -# --------------------------------------------------------------------------- -# Generic config builder -# --------------------------------------------------------------------------- - - -def build_genai_config( - hf_config: Any, - *, - max_cache_len: int, - prefill_seq_len: int | None = None, - pipeline: list[PipelineStage], - decoder_io: DecoderIOMapping | None = None, -) -> dict: - """Build a ``genai_config.json`` dict for any decoder-pipeline model. - - This function is architecture-agnostic: the caller supplies the pipeline - stages and the I/O name mapping so no tensor names are hardcoded here. - - Args: - hf_config: A ``transformers.PretrainedConfig``. Reads: - ``num_hidden_layers``, ``hidden_size``, ``num_attention_heads``, - ``num_key_value_heads``, ``head_dim`` (optional, falls back to - ``hidden_size // num_attention_heads``), ``bos_token_id``, - ``eos_token_id``, ``pad_token_id``, ``vocab_size``. - max_cache_len: Static KV cache length → ``context_length`` and - ``search.max_length``. - prefill_seq_len: When given, emits a ``sliding_window`` section with - ``window_size=prefill_seq_len``. Pass ``None`` to omit. - pipeline: Ordered list of :class:`PipelineStage` describing each - model in the genai pipeline. - decoder_io: Format-string mapping from genai's abstract I/O names to - actual ONNX tensor names. Defaults to - :class:`DecoderIOMapping` (the Qwen3 default names). - - Returns: - A ``dict`` suitable for ``json.dumps`` as ``genai_config.json``. - """ - if decoder_io is None: - decoder_io = DecoderIOMapping() - - num_layers: int = hf_config.num_hidden_layers - head_size: int = getattr( - hf_config, - "head_dim", - hf_config.hidden_size // hf_config.num_attention_heads, - ) - - eos_token_id = hf_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - - pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id - - decoder_section: dict = { - "hidden_size": hf_config.hidden_size, - "num_attention_heads": hf_config.num_attention_heads, - "num_key_value_heads": hf_config.num_key_value_heads, - "num_hidden_layers": num_layers, - "head_size": head_size, - } - - if prefill_seq_len is not None: - decoder_section["sliding_window"] = { - "window_size": prefill_seq_len, - "pad_value": 0, - "alignment": "left", - "slide_inputs": True, - "slide_key_value_cache": False, - } - - decoder_section["inputs"] = decoder_io.inputs_dict() - decoder_section["outputs"] = decoder_io.outputs_dict() - decoder_section["pipeline"] = [{s.name: s.to_dict()} for s in pipeline] - - return { - "model": { - "type": "decoder-pipeline", - "bos_token_id": hf_config.bos_token_id, - "eos_token_id": eos_token_id, - "pad_token_id": pad_token_id, - "vocab_size": hf_config.vocab_size, - "context_length": max_cache_len, - "decoder": decoder_section, - }, - "search": { - "max_length": max_cache_len, - "min_length": 0, - "do_sample": False, - "past_present_share_buffer": True, - }, - } - - -# --------------------------------------------------------------------------- -# ONNX introspection helpers -# --------------------------------------------------------------------------- - - -def _introspect_onnx_io(onnx_path: Path) -> tuple[list[str], list[str]]: - """Return ``(input_names, output_names)`` from an ONNX model graph header. - - External data is intentionally not loaded — only the graph topology is read, - so this is fast even for large quantized models. - """ - try: - import onnx - except ImportError as exc: - raise ImportError( - "The 'onnx' package is required for ONNX introspection. " - "Install it with: pip install onnx" - ) from exc - model = onnx.load(str(onnx_path), load_external_data=False) - return ( - [inp.name for inp in model.graph.input], - [out.name for out in model.graph.output], - ) - - -def _detect_format_patterns(names: list[str], num_layers: int) -> dict[str, str]: - """Detect ``prefix%d`` patterns from a list of indexed tensor names. - - Scans *names* for entries matching ```` where exactly - *num_layers* consecutive zero-based indices are present. - - Returns: - ``{prefix: "prefix%d"}`` for each qualifying group, in the order the - prefixes first appear in *names*. Only groups covering the full - ``[0, num_layers)`` index range are returned. - - Examples:: - - >>> _detect_format_patterns( - ... ["past_keys_0", "past_keys_1", "past_values_0", "past_values_1"], - ... num_layers=2, - ... ) - {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} - """ - groups: dict[str, list[int]] = {} - for name in names: - m = _KV_INDEXED_RE.match(name) - if m: - prefix, idx = m.group(1), int(m.group(2)) - groups.setdefault(prefix, []).append(idx) - - return { - prefix: f"{prefix}%d" - for prefix, indices in groups.items() - if len(indices) == num_layers and sorted(indices) == list(range(num_layers)) - } - - -def _sort_patterns_by_first_occurrence(patterns: dict[str, str], names: list[str]) -> list[str]: - """Sort *patterns* keys by when ``0`` first appears in *names*.""" - - def _key(prefix: str) -> int: - try: - return names.index(f"{prefix}0") - except ValueError: - return len(names) - - return sorted(patterns.keys(), key=_key) - - -# --------------------------------------------------------------------------- -# Per-EP stage session_options helpers -# --------------------------------------------------------------------------- - - -def _qnn_stage_session_options(log_id: str, soc_model: str = "60") -> dict: - """Return the ``session_options`` block that routes a stage to QNN HTP. - - Args: - log_id: ORT log identifier (shown in ORT logs), e.g. - ``"onnxruntime-genai.context"``. - soc_model: Snapdragon SoC model number passed to the QNN HTP backend. - ``"60"`` targets Snapdragon 8 Gen 3 (X Elite). Change for other - SoCs (e.g. ``"55"`` for 8 Gen 2, ``"73"`` for 8 Elite). - - Returns: - Dict suitable for the ``session_options`` key of a pipeline stage in - ``genai_config.json``. - """ - return { - "log_id": log_id, - "provider_options": [ - { - "qnn": { - "backend_path": "QnnHtp.dll", - "htp_performance_mode": "burst", - "htp_graph_finalization_optimization_mode": "3", - "soc_model": soc_model, - } - } - ], - "intra_op_num_threads": 2, - "inter_op_num_threads": 1, - } - - -# --------------------------------------------------------------------------- -# Qwen3 transformer-only pipeline factory -# --------------------------------------------------------------------------- - - -def build_qwen3_transformer_only_stages( - context_onnx: str | Path, - iterator_onnx: str | Path, - num_layers: int, - *, - context_filename: str = DEFAULT_CONTEXT_FILENAME, - iterator_filename: str = DEFAULT_ITERATOR_FILENAME, - embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, - lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, - ep: str = "cpu", - soc_model: str = "60", -) -> tuple[list[PipelineStage], DecoderIOMapping]: - """Build pipeline stages by introspecting the built ONNX models. - - Reads actual tensor names from *context_onnx* and *iterator_onnx* so the - generated ``genai_config.json`` can never drift out of sync with the real - model I/O — no tensor names are hardcoded. - - Args: - context_onnx: Path to the built prefill/context ONNX. - iterator_onnx: Path to the built decode/iterator ONNX. - num_layers: Number of transformer layers (``hf_config.num_hidden_layers``). - context_filename: Bundle filename for the context model. - iterator_filename: Bundle filename for the iterator model. - embeddings_filename: Bundle filename for the embeddings model. - lm_head_filename: Bundle filename for the lm_head model. - ep: Execution provider for the transformer stages. ``"qnn"`` injects - QNN HTP ``session_options`` into the ``context`` and ``iterator`` - stages so they run on the NPU while ``embeddings`` and ``lm_head`` - continue on CPU. ``"cpu"`` (default) omits ``session_options`` - from all stages. - soc_model: Snapdragon SoC model number forwarded to the QNN backend - when ``ep="qnn"``. Default ``"60"`` targets Snapdragon 8 Gen 3. - - Returns: - ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and - the :class:`DecoderIOMapping` derived from the introspected tensor names. - """ - ctx_inputs, ctx_outputs = _introspect_onnx_io(Path(context_onnx)) - iter_inputs, iter_outputs = _introspect_onnx_io(Path(iterator_onnx)) - - # Detect per-layer KV format-string patterns in the context model. - input_patterns = _detect_format_patterns(ctx_inputs, num_layers) - output_patterns = _detect_format_patterns(ctx_outputs, num_layers) - - in_sorted = _sort_patterns_by_first_occurrence(input_patterns, ctx_inputs) - out_sorted = _sort_patterns_by_first_occurrence(output_patterns, ctx_outputs) - - past_key_fmt = input_patterns[in_sorted[0]] if len(in_sorted) > 0 else "past_keys_%d" - past_val_fmt = input_patterns[in_sorted[1]] if len(in_sorted) > 1 else "past_values_%d" - pres_key_fmt = output_patterns[out_sorted[0]] if len(out_sorted) > 0 else "present_keys_%d" - pres_val_fmt = output_patterns[out_sorted[1]] if len(out_sorted) > 1 else "present_values_%d" - - # Non-indexed inputs: hidden-state tensor + scalar seq-length scalars. - non_indexed = [n for n in ctx_inputs if not _KV_INDEXED_RE.match(n)] - seq_len_names = [n for n in non_indexed if re.search(r"seq|len", n, re.IGNORECASE)] - hidden_state_in = next( - (n for n in non_indexed if n not in seq_len_names), "input_hidden_states" - ) - past_seq_name = next((n for n in seq_len_names if "past" in n.lower()), "past_seq_len") - total_seq_name = next((n for n in seq_len_names if "total" in n.lower()), "total_seq_len") - - # Non-indexed output: hidden-state output of the transformer stack. - hidden_state_out = next( - (n for n in ctx_outputs if not _KV_INDEXED_RE.match(n)), "output_hidden_states" - ) - - decoder_io = DecoderIOMapping( - past_sequence_length=past_seq_name, - total_sequence_length=total_seq_name, - past_key_names=past_key_fmt, - past_value_names=past_val_fmt, - present_key_names=pres_key_fmt, - present_value_names=pres_val_fmt, - ) - - # Per-stage session_options: NPU stages get QNN config; CPU and others get None. - ctx_session_opts: dict | None = None - iter_session_opts: dict | None = None - if ep == "qnn": - ctx_session_opts = _qnn_stage_session_options( - "onnxruntime-genai.context", soc_model=soc_model - ) - iter_session_opts = _qnn_stage_session_options( - "onnxruntime-genai.iterator", soc_model=soc_model - ) - - stages: list[PipelineStage] = [ - PipelineStage( - name="embeddings", - filename=embeddings_filename, - run_on_prompt=True, - run_on_token_gen=True, - inputs=[decoder_io.input_ids], - outputs=[hidden_state_in], - ), - PipelineStage( - name="context", - filename=context_filename, - run_on_prompt=True, - run_on_token_gen=False, - inputs=ctx_inputs, - outputs=ctx_outputs, - session_options=ctx_session_opts, - ), - PipelineStage( - name="iterator", - filename=iterator_filename, - run_on_prompt=False, - run_on_token_gen=True, - inputs=iter_inputs, - outputs=iter_outputs, - session_options=iter_session_opts, - ), - PipelineStage( - name="lm_head", - filename=lm_head_filename, - run_on_prompt=True, - run_on_token_gen=True, - inputs=[hidden_state_out], - outputs=[decoder_io.logits], - is_lm_head=True, - ), - ] - return stages, decoder_io - - -# --------------------------------------------------------------------------- -# Bundle assembler -# --------------------------------------------------------------------------- - - -def write_genai_bundle( - output_dir: str | Path, - *, - context_onnx: str | Path, - iterator_onnx: str | Path, - model_id: str, - max_cache_len: int, - prefill_seq_len: int, - embeddings_src: str | Path | None = None, - lm_head_src: str | Path | None = None, - context_filename: str = DEFAULT_CONTEXT_FILENAME, - iterator_filename: str = DEFAULT_ITERATOR_FILENAME, - embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, - lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, - ep: str = "cpu", - soc_model: str = "60", -) -> Path: - """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. - - Copies the winml-built transformer ONNX files, placeholder embedding / - lm_head models (when provided), HF tokenizer files, and writes - ``genai_config.json``. Tensor names in the config are derived by - introspecting the built ONNX files rather than being hardcoded. - - Args: - output_dir: Destination directory (created if absent). - context_onnx: Path to the built prefill/context ONNX. - iterator_onnx: Path to the built decode/iterator ONNX. - model_id: HuggingFace model ID or local path for config + tokenizer. - max_cache_len: Static KV cache length (= ``context_length`` in genai). - prefill_seq_len: Prefill sequence length (= ``sliding_window.window_size``). - embeddings_src: Source path of the embeddings ONNX. ``None`` = skip. - lm_head_src: Source path of the lm_head ONNX. ``None`` = skip. - context_filename: Bundle filename for the context model. - iterator_filename: Bundle filename for the iterator model. - embeddings_filename: Bundle filename for the embeddings model. - lm_head_filename: Bundle filename for the lm_head model. - ep: Execution provider for the transformer (context/iterator) stages. - ``"qnn"`` injects QNN HTP ``session_options`` so those stages run - on the NPU while embeddings and lm_head run on CPU. - ``"cpu"`` (default) omits ``session_options`` (all stages on CPU). - soc_model: Snapdragon SoC model passed to the QNN backend when - ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. - - Returns: - Path to the written ``genai_config.json``. - """ - from transformers import AutoConfig, AutoTokenizer - - from winml.modelkit.onnx import copy_onnx_model - - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - context_onnx = Path(context_onnx) - iterator_onnx = Path(iterator_onnx) - - # 1. Copy winml-built transformer ONNX files. - logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename) - copy_onnx_model(context_onnx, output_dir / context_filename) - - logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) - copy_onnx_model(iterator_onnx, output_dir / iterator_filename) - - # 2. Copy placeholder models (embeddings + lm_head). - if embeddings_src is not None: - logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) - copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename) - else: - logger.warning( - "embeddings_src not provided — '%s' is missing from bundle.", - embeddings_filename, - ) - - if lm_head_src is not None: - logger.info("Copying lm_head: %s -> %s", Path(lm_head_src).name, lm_head_filename) - copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename) - else: - logger.warning( - "lm_head_src not provided — '%s' is missing from bundle.", - lm_head_filename, - ) - - # 3. Save tokenizer files from the HF snapshot. - logger.info("Saving tokenizer from '%s' to %s", model_id, output_dir) - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokenizer.save_pretrained(str(output_dir)) - - # 4. Build pipeline stages by introspecting the source ONNX files. - hf_config = AutoConfig.from_pretrained(model_id) - stages, decoder_io = build_qwen3_transformer_only_stages( - context_onnx, - iterator_onnx, - num_layers=hf_config.num_hidden_layers, - context_filename=context_filename, - iterator_filename=iterator_filename, - embeddings_filename=embeddings_filename, - lm_head_filename=lm_head_filename, - ep=ep, - soc_model=soc_model, - ) - - # 5. Write genai_config.json. - config = build_genai_config( - hf_config, - max_cache_len=max_cache_len, - prefill_seq_len=prefill_seq_len, - pipeline=stages, - decoder_io=decoder_io, - ) - config_path = output_dir / "genai_config.json" - config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") - logger.info("Wrote genai_config.json -> %s", config_path) - - _log_bundle_summary(output_dir, config_path) - return config_path - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- +from winml.modelkit.utils.genai import ( + DEFAULT_CONTEXT_FILENAME, + DEFAULT_EMBEDDINGS_FILENAME, + DEFAULT_ITERATOR_FILENAME, + DEFAULT_LM_HEAD_FILENAME, + DecoderIOMapping, + PipelineStage, + _detect_format_patterns, + build_decoder_pipeline_stages, + build_genai_config, + qnn_stage_session_options, + write_genai_bundle, +) -def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: - """Print a human-readable summary of the assembled bundle.""" - files = sorted(bundle_dir.iterdir()) - lines = [f"\n=== genai bundle: {bundle_dir} ==="] - for f in files: - size_kb = f.stat().st_size / 1024 - tag = "" - if f.name == "genai_config.json": - tag = " <- pipeline config" - elif f.name.endswith(".onnx"): - tag = " <- ONNX graph" - elif f.name.endswith(".data"): - tag = " <- ONNX external weights" - lines.append(f" {f.name:<45} {size_kb:>8.1f} KB{tag}") - lines.append(f"\nConfig written to: {config_path}") - logger.info("\n".join(lines)) +# Backward-compatible alias: existing callers that import +# ``build_qwen3_transformer_only_stages`` continue to work unchanged. +build_qwen3_transformer_only_stages = build_decoder_pipeline_stages +# Keep the internal helper accessible for tests that import it directly. +_qnn_stage_session_options = qnn_stage_session_options __all__ = [ "DEFAULT_CONTEXT_FILENAME", @@ -663,7 +44,10 @@ def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: "DEFAULT_LM_HEAD_FILENAME", "DecoderIOMapping", "PipelineStage", + "_detect_format_patterns", + "build_decoder_pipeline_stages", "build_genai_config", "build_qwen3_transformer_only_stages", + "qnn_stage_session_options", "write_genai_bundle", ] diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index 9a2c0b34c..f2dad581a 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, Any +from .calibration import get_quant_finalizer from .config import QuantizeResult, WinMLQuantizationConfig @@ -29,18 +30,17 @@ ] -# Names below are loaded lazily via ``__getattr__`` to avoid pulling in -# onnxruntime.quantization/torch at import time. The TYPE_CHECKING re-imports -# give static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports -# without triggering the heavy imports at runtime. +# ``quantize_onnx`` is loaded lazily via ``__getattr__`` to avoid pulling in +# onnxruntime.quantization at import time. The TYPE_CHECKING re-import gives +# static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports. +# ``get_quant_finalizer`` is imported directly above — its module chain +# (calibration/__init__ -> registry) is lightweight and safe at import time. if TYPE_CHECKING: - from .calibration import get_quant_finalizer from .quantizer import quantize_onnx _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "quantize_onnx": (".quantizer", "quantize_onnx"), - "get_quant_finalizer": (".calibration", "get_quant_finalizer"), } diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py index 39b9543c5..213199811 100644 --- a/src/winml/modelkit/quant/calibration/base.py +++ b/src/winml/modelkit/quant/calibration/base.py @@ -38,3 +38,4 @@ def finalize( model_id: str | None = None, ) -> WinMLQuantizationConfig: """Return ``quant`` populated with the graph-derived quant settings.""" + ... diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 34e2ef5e4..312a7358c 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -363,17 +363,16 @@ def generate_streaming( stream = self._tokenizer.create_stream() # type: ignore[union-attr] n = 0 - try: - while not generator.is_done(): - generator.generate_next_token() - new_token = generator.get_next_tokens()[0] - yield stream.decode(new_token) - n += 1 - if n >= cfg.max_new_tokens: - break - finally: - # Explicit deletion releases the KV cache buffer held by the generator. - del generator + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_next_tokens()[0] + yield stream.decode(new_token) + n += 1 + if n >= cfg.max_new_tokens: + break + # ``generator`` (og.Generator) holds the KV cache buffer; releasing the + # reference here (end of scope) frees it before the caller processes the + # last yielded token, which is earlier than waiting for GC. # ------------------------------------------------------------------ # Chat-template helpers diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py new file mode 100644 index 000000000..33645cdaf --- /dev/null +++ b/src/winml/modelkit/utils/genai.py @@ -0,0 +1,663 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""Generic onnxruntime-genai bundle utilities for decoder-pipeline models. + +The bundle is a directory that ``onnxruntime-genai`` can load directly via +``og.Config(str(bundle_dir))``. It contains: + + genai_config.json — pipeline config consumed by onnxruntime-genai + ctx.onnx — prefill/context ONNX + iter.onnx — iteration/decode ONNX + embeddings.onnx — embedding-lookup ONNX + lm_head.onnx — lm_head ONNX + tokenizer.json — HF tokenizer files (downloaded from the model repo) + tokenizer_config.json + vocab.json / merges.txt / generation_config.json + +The pipeline follows the standard 4-stage decoder layout: + + input_ids → [embeddings] → input_hidden_states + → [context | iterator] → output_hidden_states + present KVs + → [lm_head] → logits + +The context stage runs on the prompt (prefill); the iterator stage runs on each +subsequent decode step. Both share the same KV cache buffer via genai's +``past_present_share_buffer`` mode. + +Public API:: + + from winml.modelkit.utils.genai import ( + build_genai_config, + build_decoder_pipeline_stages, + write_genai_bundle, + DecoderIOMapping, + PipelineStage, + qnn_stage_session_options, + ) + + # Build stages by introspecting the ONNX I/O (no hardcoded tensor names) + stages, decoder_io = build_decoder_pipeline_stages( + ctx_path, iter_path, num_layers=hf_config.num_hidden_layers, ep="qnn" + ) + cfg = build_genai_config( + hf_config, max_cache_len=256, prefill_seq_len=64, + pipeline=stages, decoder_io=decoder_io, + ) + + # Or one-shot bundle assembly + write_genai_bundle( + Path("out/bundle"), + context_onnx=ctx_path, + iterator_onnx=iter_path, + model_id="Qwen/Qwen3-0.6B", + max_cache_len=256, + prefill_seq_len=64, + embeddings_src=emb_path, # None = skip (add later) + lm_head_src=lmh_path, # None = skip (add later) + ep="qnn", + ) +""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + + +logger = logging.getLogger(__name__) + +# Default filenames inside the bundle directory. +DEFAULT_EMBEDDINGS_FILENAME = "embeddings.onnx" +DEFAULT_CONTEXT_FILENAME = "ctx.onnx" +DEFAULT_ITERATOR_FILENAME = "iter.onnx" +DEFAULT_LM_HEAD_FILENAME = "lm_head.onnx" + +# Regex for detecting indexed tensor names such as ``past_keys_3``. +_KV_INDEXED_RE = re.compile(r"^(.+?)(\d+)$") + + +# --------------------------------------------------------------------------- +# Pipeline data structures +# --------------------------------------------------------------------------- + + +@dataclass +class PipelineStage: + """One stage in an onnxruntime-genai multi-model pipeline. + + Attributes: + name: Stage key used inside the ``pipeline`` list of ``genai_config.json``. + filename: ONNX filename inside the bundle directory. + run_on_prompt: Whether genai runs this stage during the prefill pass. + run_on_token_gen: Whether genai runs this stage during decode steps. + inputs: Actual ONNX input tensor names (not format strings). + outputs: Actual ONNX output tensor names (not format strings). + is_lm_head: Set ``True`` for the final language-model head stage. + """ + + name: str + filename: str + run_on_prompt: bool + run_on_token_gen: bool + inputs: list[str] + outputs: list[str] + is_lm_head: bool = False + session_options: dict | None = None + """Per-stage ORT session options (e.g. provider_options for QNN). + + When set, emitted verbatim as the ``session_options`` key in the + ``genai_config.json`` pipeline stage. Leave ``None`` (default) for + stages that should run on the default (CPU) provider. + """ + + def to_dict(self) -> dict: + """Serialize to the dict format expected by ``genai_config.json``.""" + d: dict = { + "filename": self.filename, + "inputs": list(self.inputs), + "outputs": list(self.outputs), + "run_on_prompt": self.run_on_prompt, + "run_on_token_gen": self.run_on_token_gen, + } + if self.session_options: + d["session_options"] = self.session_options + if self.is_lm_head: + d["is_lm_head"] = True + return d + + +@dataclass +class DecoderIOMapping: + """Maps genai's abstract I/O concepts to ONNX tensor name format strings. + + The ``*_names`` fields use ``%d`` as the layer-index placeholder, which is + the convention genai uses to expand per-layer KV cache tensor names + (e.g. ``"past_keys_%d"`` → ``"past_keys_0"``, ``"past_keys_1"``, …). + + Defaults match the Qwen3 transformer-only export naming; override any field + when building bundles for models with different tensor names. + """ + + input_ids: str = "input_ids" + past_sequence_length: str = "past_seq_len" + total_sequence_length: str = "total_seq_len" + past_key_names: str = "past_keys_%d" + past_value_names: str = "past_values_%d" + logits: str = "logits" + present_key_names: str = "present_keys_%d" + present_value_names: str = "present_values_%d" + + def inputs_dict(self) -> dict: + """Return the ``decoder.inputs`` mapping dict for ``genai_config.json``.""" + return { + "input_ids": self.input_ids, + "past_sequence_length": self.past_sequence_length, + "total_sequence_length": self.total_sequence_length, + "past_key_names": self.past_key_names, + "past_value_names": self.past_value_names, + } + + def outputs_dict(self) -> dict: + """Return the ``decoder.outputs`` mapping dict for ``genai_config.json``.""" + return { + "logits": self.logits, + "present_key_names": self.present_key_names, + "present_value_names": self.present_value_names, + } + + +# --------------------------------------------------------------------------- +# Generic config builder +# --------------------------------------------------------------------------- + + +def build_genai_config( + hf_config: Any, + *, + max_cache_len: int, + prefill_seq_len: int | None = None, + pipeline: list[PipelineStage], + decoder_io: DecoderIOMapping | None = None, +) -> dict: + """Build a ``genai_config.json`` dict for any decoder-pipeline model. + + This function is architecture-agnostic: the caller supplies the pipeline + stages and the I/O name mapping so no tensor names are hardcoded here. + + Args: + hf_config: A ``transformers.PretrainedConfig``. Reads: + ``num_hidden_layers``, ``hidden_size``, ``num_attention_heads``, + ``num_key_value_heads``, ``head_dim`` (optional, falls back to + ``hidden_size // num_attention_heads``), ``bos_token_id``, + ``eos_token_id``, ``pad_token_id``, ``vocab_size``. + max_cache_len: Static KV cache length → ``context_length`` and + ``search.max_length``. + prefill_seq_len: When given, emits a ``sliding_window`` section with + ``window_size=prefill_seq_len``. Pass ``None`` to omit. + pipeline: Ordered list of :class:`PipelineStage` describing each + model in the genai pipeline. + decoder_io: Format-string mapping from genai's abstract I/O names to + actual ONNX tensor names. Defaults to + :class:`DecoderIOMapping` (the standard names). + + Returns: + A ``dict`` suitable for ``json.dumps`` as ``genai_config.json``. + """ + if decoder_io is None: + decoder_io = DecoderIOMapping() + + num_layers: int = hf_config.num_hidden_layers + head_size: int = getattr( + hf_config, + "head_dim", + hf_config.hidden_size // hf_config.num_attention_heads, + ) + + eos_token_id = hf_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + + pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id + + decoder_section: dict = { + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_key_value_heads": hf_config.num_key_value_heads, + "num_hidden_layers": num_layers, + "head_size": head_size, + } + + if prefill_seq_len is not None: + decoder_section["sliding_window"] = { + "window_size": prefill_seq_len, + "pad_value": 0, + "alignment": "left", + "slide_inputs": True, + "slide_key_value_cache": False, + } + + decoder_section["inputs"] = decoder_io.inputs_dict() + decoder_section["outputs"] = decoder_io.outputs_dict() + decoder_section["pipeline"] = [{s.name: s.to_dict()} for s in pipeline] + + return { + "model": { + "type": "decoder-pipeline", + "bos_token_id": hf_config.bos_token_id, + "eos_token_id": eos_token_id, + "pad_token_id": pad_token_id, + "vocab_size": hf_config.vocab_size, + "context_length": max_cache_len, + "decoder": decoder_section, + }, + "search": { + "max_length": max_cache_len, + "min_length": 0, + "do_sample": False, + "past_present_share_buffer": True, + }, + } + + +# --------------------------------------------------------------------------- +# ONNX introspection helpers +# --------------------------------------------------------------------------- + + +def _introspect_onnx_io(onnx_path: Path) -> tuple[list[str], list[str]]: + """Return ``(input_names, output_names)`` from an ONNX model graph header. + + External data is intentionally not loaded — only the graph topology is read, + so this is fast even for large quantized models. + """ + try: + import onnx + except ImportError as exc: + raise ImportError( + "The 'onnx' package is required for ONNX introspection. " + "Install it with: pip install onnx" + ) from exc + model = onnx.load(str(onnx_path), load_external_data=False) + return ( + [inp.name for inp in model.graph.input], + [out.name for out in model.graph.output], + ) + + +def _detect_format_patterns(names: list[str], num_layers: int) -> dict[str, str]: + """Detect ``prefix%d`` patterns from a list of indexed tensor names. + + Scans *names* for entries matching ```` where exactly + *num_layers* consecutive zero-based indices are present. + + Returns: + ``{prefix: "prefix%d"}`` for each qualifying group, in the order the + prefixes first appear in *names*. Only groups covering the full + ``[0, num_layers)`` index range are returned. + + Examples:: + + >>> _detect_format_patterns( + ... ["past_keys_0", "past_keys_1", "past_values_0", "past_values_1"], + ... num_layers=2, + ... ) + {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} + """ + groups: dict[str, list[int]] = {} + for name in names: + m = _KV_INDEXED_RE.match(name) + if m: + prefix, idx = m.group(1), int(m.group(2)) + groups.setdefault(prefix, []).append(idx) + + return { + prefix: f"{prefix}%d" + for prefix, indices in groups.items() + if len(indices) == num_layers and sorted(indices) == list(range(num_layers)) + } + + +def _sort_patterns_by_first_occurrence(patterns: dict[str, str], names: list[str]) -> list[str]: + """Sort *patterns* keys by when ``0`` first appears in *names*.""" + + def _key(prefix: str) -> int: + try: + return names.index(f"{prefix}0") + except ValueError: + return len(names) + + return sorted(patterns.keys(), key=_key) + + +# --------------------------------------------------------------------------- +# Per-EP stage session_options helpers +# --------------------------------------------------------------------------- + + +def qnn_stage_session_options(log_id: str, soc_model: str = "60") -> dict: + """Return the ``session_options`` block that routes a stage to QNN HTP. + + Args: + log_id: ORT log identifier (shown in ORT logs), e.g. + ``"onnxruntime-genai.context"``. + soc_model: Snapdragon SoC model number passed to the QNN HTP backend. + ``"60"`` targets Snapdragon 8 Gen 3 (X Elite). Change for other + SoCs (e.g. ``"55"`` for 8 Gen 2, ``"73"`` for 8 Elite). + + Returns: + Dict suitable for the ``session_options`` key of a pipeline stage in + ``genai_config.json``. + """ + return { + "log_id": log_id, + "provider_options": [ + { + "qnn": { + "backend_path": "QnnHtp.dll", + "htp_performance_mode": "burst", + "htp_graph_finalization_optimization_mode": "3", + "soc_model": soc_model, + } + } + ], + "intra_op_num_threads": 2, + "inter_op_num_threads": 1, + } + + +# --------------------------------------------------------------------------- +# Generic decoder-pipeline stage factory +# --------------------------------------------------------------------------- + + +def build_decoder_pipeline_stages( + context_onnx: str | Path, + iterator_onnx: str | Path, + num_layers: int, + *, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", +) -> tuple[list[PipelineStage], DecoderIOMapping]: + """Build pipeline stages by introspecting the built ONNX models. + + Reads actual tensor names from *context_onnx* and *iterator_onnx* so the + generated ``genai_config.json`` can never drift out of sync with the real + model I/O — no tensor names are hardcoded. + + Args: + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + num_layers: Number of transformer layers (``hf_config.num_hidden_layers``). + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. + ep: Execution provider for the transformer stages. ``"qnn"`` injects + QNN HTP ``session_options`` into the ``context`` and ``iterator`` + stages so they run on the NPU while ``embeddings`` and ``lm_head`` + continue on CPU. ``"cpu"`` (default) omits ``session_options`` + from all stages. + soc_model: Snapdragon SoC model number forwarded to the QNN backend + when ``ep="qnn"``. Default ``"60"`` targets Snapdragon 8 Gen 3. + + Returns: + ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and + the :class:`DecoderIOMapping` derived from the introspected tensor names. + """ + ctx_inputs, ctx_outputs = _introspect_onnx_io(Path(context_onnx)) + iter_inputs, iter_outputs = _introspect_onnx_io(Path(iterator_onnx)) + + # Detect per-layer KV format-string patterns in the context model. + input_patterns = _detect_format_patterns(ctx_inputs, num_layers) + output_patterns = _detect_format_patterns(ctx_outputs, num_layers) + + in_sorted = _sort_patterns_by_first_occurrence(input_patterns, ctx_inputs) + out_sorted = _sort_patterns_by_first_occurrence(output_patterns, ctx_outputs) + + past_key_fmt = input_patterns[in_sorted[0]] if len(in_sorted) > 0 else "past_keys_%d" + past_val_fmt = input_patterns[in_sorted[1]] if len(in_sorted) > 1 else "past_values_%d" + pres_key_fmt = output_patterns[out_sorted[0]] if len(out_sorted) > 0 else "present_keys_%d" + pres_val_fmt = output_patterns[out_sorted[1]] if len(out_sorted) > 1 else "present_values_%d" + + # Non-indexed inputs: hidden-state tensor + scalar seq-length scalars. + non_indexed = [n for n in ctx_inputs if not _KV_INDEXED_RE.match(n)] + seq_len_names = [n for n in non_indexed if re.search(r"seq|len", n, re.IGNORECASE)] + hidden_state_in = next( + (n for n in non_indexed if n not in seq_len_names), "input_hidden_states" + ) + past_seq_name = next((n for n in seq_len_names if "past" in n.lower()), "past_seq_len") + total_seq_name = next((n for n in seq_len_names if "total" in n.lower()), "total_seq_len") + + # Non-indexed output: hidden-state output of the transformer stack. + hidden_state_out = next( + (n for n in ctx_outputs if not _KV_INDEXED_RE.match(n)), "output_hidden_states" + ) + + decoder_io = DecoderIOMapping( + past_sequence_length=past_seq_name, + total_sequence_length=total_seq_name, + past_key_names=past_key_fmt, + past_value_names=past_val_fmt, + present_key_names=pres_key_fmt, + present_value_names=pres_val_fmt, + ) + + # Per-stage session_options: NPU stages get QNN config; CPU and others get None. + ctx_session_opts: dict | None = None + iter_session_opts: dict | None = None + if ep == "qnn": + ctx_session_opts = qnn_stage_session_options( + "onnxruntime-genai.context", soc_model=soc_model + ) + iter_session_opts = qnn_stage_session_options( + "onnxruntime-genai.iterator", soc_model=soc_model + ) + + stages: list[PipelineStage] = [ + PipelineStage( + name="embeddings", + filename=embeddings_filename, + run_on_prompt=True, + run_on_token_gen=True, + inputs=[decoder_io.input_ids], + outputs=[hidden_state_in], + ), + PipelineStage( + name="context", + filename=context_filename, + run_on_prompt=True, + run_on_token_gen=False, + inputs=ctx_inputs, + outputs=ctx_outputs, + session_options=ctx_session_opts, + ), + PipelineStage( + name="iterator", + filename=iterator_filename, + run_on_prompt=False, + run_on_token_gen=True, + inputs=iter_inputs, + outputs=iter_outputs, + session_options=iter_session_opts, + ), + PipelineStage( + name="lm_head", + filename=lm_head_filename, + run_on_prompt=True, + run_on_token_gen=True, + inputs=[hidden_state_out], + outputs=[decoder_io.logits], + is_lm_head=True, + ), + ] + return stages, decoder_io + + +# --------------------------------------------------------------------------- +# Bundle assembler +# --------------------------------------------------------------------------- + + +def write_genai_bundle( + output_dir: str | Path, + *, + context_onnx: str | Path, + iterator_onnx: str | Path, + model_id: str, + max_cache_len: int, + prefill_seq_len: int, + embeddings_src: str | Path | None = None, + lm_head_src: str | Path | None = None, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", +) -> Path: + """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. + + Copies the winml-built transformer ONNX files, optional embedding / + lm_head models, HF tokenizer files, and writes ``genai_config.json``. + Tensor names in the config are derived by introspecting the built ONNX + files rather than being hardcoded, so this works for any model that + follows the 4-stage decoder-pipeline layout. + + Args: + output_dir: Destination directory (created if absent). + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + model_id: HuggingFace model ID or local path for config + tokenizer. + max_cache_len: Static KV cache length (= ``context_length`` in genai). + prefill_seq_len: Prefill sequence length (= ``sliding_window.window_size``). + embeddings_src: Source path of the embeddings ONNX. ``None`` = skip. + lm_head_src: Source path of the lm_head ONNX. ``None`` = skip. + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. + ep: Execution provider for the transformer (context/iterator) stages. + ``"qnn"`` injects QNN HTP ``session_options`` so those stages run + on the NPU while embeddings and lm_head run on CPU. + ``"cpu"`` (default) omits ``session_options`` (all stages on CPU). + soc_model: Snapdragon SoC model passed to the QNN backend when + ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. + + Returns: + Path to the written ``genai_config.json``. + """ + from transformers import AutoConfig, AutoTokenizer + + from winml.modelkit.onnx import copy_onnx_model + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + context_onnx = Path(context_onnx) + iterator_onnx = Path(iterator_onnx) + + # 1. Copy winml-built transformer ONNX files. + logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename) + copy_onnx_model(context_onnx, output_dir / context_filename) + + logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) + copy_onnx_model(iterator_onnx, output_dir / iterator_filename) + + # 2. Copy placeholder models (embeddings + lm_head). + if embeddings_src is not None: + logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) + copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename) + else: + logger.warning( + "embeddings_src not provided — '%s' is missing from bundle.", + embeddings_filename, + ) + + if lm_head_src is not None: + logger.info("Copying lm_head: %s -> %s", Path(lm_head_src).name, lm_head_filename) + copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename) + else: + logger.warning( + "lm_head_src not provided — '%s' is missing from bundle.", + lm_head_filename, + ) + + # 3. Save tokenizer files from the HF snapshot. + logger.info("Saving tokenizer from '%s' to %s", model_id, output_dir) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(str(output_dir)) + + # 4. Build pipeline stages by introspecting the source ONNX files. + hf_config = AutoConfig.from_pretrained(model_id) + stages, decoder_io = build_decoder_pipeline_stages( + context_onnx, + iterator_onnx, + num_layers=hf_config.num_hidden_layers, + context_filename=context_filename, + iterator_filename=iterator_filename, + embeddings_filename=embeddings_filename, + lm_head_filename=lm_head_filename, + ep=ep, + soc_model=soc_model, + ) + + # 5. Write genai_config.json. + config = build_genai_config( + hf_config, + max_cache_len=max_cache_len, + prefill_seq_len=prefill_seq_len, + pipeline=stages, + decoder_io=decoder_io, + ) + config_path = output_dir / "genai_config.json" + config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") + logger.info("Wrote genai_config.json -> %s", config_path) + + _log_bundle_summary(output_dir, config_path) + return config_path + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: + """Print a human-readable summary of the assembled bundle.""" + files = sorted(bundle_dir.iterdir()) + lines = [f"\n=== genai bundle: {bundle_dir} ==="] + for f in files: + size_kb = f.stat().st_size / 1024 + tag = "" + if f.name == "genai_config.json": + tag = " <- pipeline config" + elif f.name.endswith(".onnx"): + tag = " <- ONNX graph" + elif f.name.endswith(".data"): + tag = " <- ONNX external weights" + lines.append(f" {f.name:<45} {size_kb:>8.1f} KB{tag}") + lines.append(f"\nConfig written to: {config_path}") + logger.info("\n".join(lines)) + + +__all__ = [ + "DEFAULT_CONTEXT_FILENAME", + "DEFAULT_EMBEDDINGS_FILENAME", + "DEFAULT_ITERATOR_FILENAME", + "DEFAULT_LM_HEAD_FILENAME", + "DecoderIOMapping", + "PipelineStage", + "build_decoder_pipeline_stages", + "build_genai_config", + "qnn_stage_session_options", + "write_genai_bundle", +] diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py index 900f8b664..e359f02e2 100644 --- a/tests/unit/models/qwen3/test_genai_config.py +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -387,7 +387,7 @@ def _patch_onnx(self, n: int = 4): ctx_io = (self._ctx_inputs(n), self._ctx_outputs(n)) iter_io = (self._ctx_inputs(n), self._ctx_outputs(n)) return patch( - "winml.modelkit.models.hf.qwen3.genai._introspect_onnx_io", + "winml.modelkit.utils.genai._introspect_onnx_io", side_effect=[ctx_io, iter_io], ) diff --git a/tests/unit/quant/calibration/test_qwen3_calibration.py b/tests/unit/quant/calibration/test_qwen3_calibration.py index ad53ef352..6881f0f72 100644 --- a/tests/unit/quant/calibration/test_qwen3_calibration.py +++ b/tests/unit/quant/calibration/test_qwen3_calibration.py @@ -16,8 +16,9 @@ from types import SimpleNamespace import numpy as np -import onnx import torch +from onnx import TensorProto, helper +from onnx import save as onnx_save from winml.modelkit.quant.calibration.qwen3_transformer_only import ( Qwen3DecodeTrajectoryCalibReader, @@ -47,27 +48,27 @@ def _fake_config() -> SimpleNamespace: def _build_tiny_onnx(path, *, seq_len: int, max_cache_len: int) -> None: """Write a minimal graph carrying the inputs the readers introspect.""" inputs = [ - onnx.helper.make_tensor_value_info( - "input_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] + helper.make_tensor_value_info( + "input_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] ), - onnx.helper.make_tensor_value_info( - "past_keys_0", onnx.TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] + helper.make_tensor_value_info( + "past_keys_0", TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] ), ] - out = onnx.helper.make_tensor_value_info( - "output_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] + out = helper.make_tensor_value_info( + "output_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] ) - gqa = onnx.helper.make_node( + gqa = helper.make_node( "GroupQueryAttention", ["input_hidden_states"], ["attn_out"], name="gqa_layer_0", domain="com.microsoft", ) - identity = onnx.helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) - graph = onnx.helper.make_graph([gqa, identity], "tiny", inputs, [out]) - model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)]) - onnx.save(model, str(path)) + identity = helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) + graph = helper.make_graph([gqa, identity], "tiny", inputs, [out]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + onnx_save(model, str(path)) def test_graph_shapes_and_gqa_nodes(tmp_path):