Skip to content
Merged
2 changes: 1 addition & 1 deletion src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ def _run_quantize_stage(
) -> Path:
"""Run the quantize stage (if quant is configured).

Delegates single-pass quantization to ``quantize_onnx(config=...)``.
Delegates quantization to ``quantize_onnx(config=...)``.
The cmd layer only handles UI display and the QDQ skip check.

Args:
Expand Down
163 changes: 147 additions & 16 deletions src/winml/modelkit/commands/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Examples:
winml quantize -m model.onnx
winml quantize -m model.onnx --precision int8
winml quantize -m model.onnx -o model_qdq.onnx --samples 100
winml quantize -m model.onnx -o model_quantized.onnx --samples 100
winml quantize -m model.onnx --weight-type int8 --activation-type uint8
"""

Expand Down Expand Up @@ -47,14 +47,16 @@
type=click.Path(exists=True, path_type=Path),
help="Input ONNX model file",
)
@cli_utils.output_option("Output path (default: {input}_qdq.onnx)")
@cli_utils.output_option("Output path (default: {input}_quantized.onnx)")
@cli_utils.overwrite_option()
@cli_utils.precision_option(
default=None,
help_text="Quantization precision: auto, fp16, int4, int8, int16, or w{x}a{y} where "
"x in {4,8,16}, y in {8,16} (e.g., w4a16, w8a8, w8a16). "
"int4/w4a16 uses RTN weight-only quantization; "
"fp16 converts all FP32 tensors to FP16 (no QDQ)",
default=(),
multiple=True,
help_text="Quantization precision: fp16, int4, int8, int16, or w{x}a{y} where "
"x in {4,8,16}, y in {8,16} (e.g., w8a8, w8a16). "
"int4 uses RTN weight-only quantization; "
"fp16 converts all FP32 tensors to FP16 (no QDQ). "
"Repeat to chain passes in order (e.g. -p int4 -p fp16)",
optional_message="Overridden by explicit --weight-type/--activation-type",
)
@click.option(
Expand Down Expand Up @@ -111,7 +113,7 @@ def quantize(
model: Path,
output: Path | None,
overwrite: bool,
precision: str | None,
precision: tuple[str, ...],
samples: int,
method: str,
weight_type: str | None,
Expand All @@ -127,9 +129,12 @@ def quantize(
r"""Quantize ONNX model by inserting QDQ nodes, RTN weight-only, or convert to FP16.

This command applies quantization to an ONNX model. The algorithm is
auto-selected from the precision: int4/w4a16 → RTN weight-only,
auto-selected from the precision: int4 → RTN weight-only,
int8/int16/w8a8 → static QDQ, fp16 → FP16 conversion.

Repeat --precision to chain passes in order:
``-p int4 -p fp16`` runs RTN int4 quantization then FP16 conversion.

\b
Examples:
# Basic quantization with defaults (10 samples, uint8)
Expand All @@ -141,6 +146,9 @@ def quantize(
# RTN 4-bit weight-only quantization (no calibration data needed)
winml quantize -m model.onnx --precision int4

# RTN int4 followed by FP16 conversion (two-pass pipeline)
winml quantize -m model.onnx --precision int4 --precision fp16

# Int16 quantization
winml quantize -m model.onnx --precision int16

Expand Down Expand Up @@ -182,8 +190,29 @@ def quantize(
# Import quantizer (late import to speed up CLI)
from ..quant import WinMLQuantizationConfig, quantize_onnx

# ── Build config based on precision ──────────────────────────
precision_lower = precision.lower() if precision else None
# ── Multi-pass pipeline ───────────────────────────────────────
if len(precision) > 1:
_run_multi_precision(
ctx=ctx,
model=model,
output=output,
overwrite=overwrite,
precision=precision,
samples=samples,
method=method,
weight_type=weight_type,
activation_type=activation_type,
per_channel=per_channel,
symmetric=symmetric,
task=task,
model_id=model_id,
console=console,
)
return

# ── Single-precision (or default) path ───────────────────────
single = precision[0] if precision else None
precision_lower = single.lower() if single else None

if precision_lower == "fp16":
# FP16 conversion
Expand Down Expand Up @@ -211,10 +240,10 @@ def quantize(
else:
# QDQ calibrated quantization
resolved_weight, resolved_activation = _resolve_quant_types(
precision, weight_type, activation_type
single, weight_type, activation_type
)
if output is None:
output = model.parent / f"{model.stem}_qdq.onnx"
output = model.parent / f"{model.stem}_quantized.onnx"
config = WinMLQuantizationConfig(
samples=samples,
calibration_method=cast('Literal["minmax", "entropy", "percentile"]', method),
Expand Down Expand Up @@ -243,13 +272,11 @@ def quantize(
console.print(f"[bold blue]Dataset:[/bold blue] {_dataset_display}")

# ── Shared execution: print header, run, report ──────────────
# Refuse to clobber an existing output unless the user opted in. Runs after
# the per-precision default path is resolved, before any mkdir/work.
cli_utils.guard_output(output, overwrite)
output.parent.mkdir(parents=True, exist_ok=True)
console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Output:[/bold blue] {output}")
console.print(f"[bold blue]Precision:[/bold blue] {precision or 'auto'}")
console.print(f"[bold blue]Precision:[/bold blue] {single or 'auto'}")

try:
console.print(f"\n[bold]Running {label.lower()}...[/bold]")
Expand All @@ -275,6 +302,110 @@ def quantize(
raise click.ClickException(f"{label} failed: {e}") from e


def _cli_precision_to_mode(precision: str) -> str:
"""Map a CLI precision string to a quantizer pass mode."""
p = precision.lower()
if p == "fp16":
return "fp16"
if is_weight_only_precision(p):
return "rtn"
return "static"


def _run_multi_precision(
Comment thread
xieofxie marked this conversation as resolved.
*,
ctx: click.Context,
model: Path,
output: Path | None,
overwrite: bool,
precision: tuple[str, ...],
samples: int,
method: str,
weight_type: str | None,
activation_type: str | None,
per_channel: bool,
symmetric: bool,
task: str | None,
model_id: str | None,
console: Console,
) -> None:
"""Execute a multi-pass quantization pipeline from ordered precision strings."""
from ..config.precision import extract_weight_bits
from ..quant import Quantizer, WinMLQuantizationConfig, expand_precision

modes = [_cli_precision_to_mode(p) for p in precision]
has_calibration_pass = any(m == "static" for m in modes)

if not has_calibration_pass:
cli_utils.warn_ignored_calibration_options(
ctx, "No selected pass uses calibration data.", console=console
)

# Extract rtn_bits from the first weight-only precision in the list.
rtn_bits = next(
(extract_weight_bits(p.lower()) for p in precision if is_weight_only_precision(p.lower())),
4,
)

# Resolve weight/activation types from the first static precision in the list
# (same logic as single-pass path) so -p int16 -p fp16 uses int16, not uint8.
first_static = next(
(p for p in precision if _cli_precision_to_mode(p) == "static"),
None,
)
resolved_weight, resolved_activation = _resolve_quant_types(
first_static, weight_type, activation_type
)

config = WinMLQuantizationConfig(
rtn_bits=rtn_bits,
samples=samples,
calibration_method=cast('Literal["minmax", "entropy", "percentile"]', method),
weight_type=cast('Literal["uint8", "int8", "uint16", "int16"]', resolved_weight),
activation_type=cast('Literal["uint8", "int8", "uint16", "int16"]', resolved_activation),
per_channel=per_channel,
symmetric=symmetric,
task=task,
model_id=model_id,
)

passes = []
for mode in modes:
passes.extend(expand_precision(mode, config))

label = " → ".join(p.lower() for p in precision)
if output is None:
suffix = "_".join(p.lower() for p in precision)
output = model.parent / f"{model.stem}_{suffix}.onnx"

cli_utils.guard_output(output, overwrite)
output.parent.mkdir(parents=True, exist_ok=True)
console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Output:[/bold blue] {output}")
console.print(f"[bold blue]Pipeline:[/bold blue] {label}")

try:
console.print(f"\n[bold]Running pipeline: {label}...[/bold]")
result = Quantizer(passes).run(model, output)

if result.success:
console.print("\n[bold green]Success![/bold green] Pipeline complete")
console.print(f"[dim]Output: {result.output_path}[/dim]")
console.print(f"[dim]Total time: {result.total_time_seconds:.2f}s[/dim]")
else:
console.print("\n[bold red]Pipeline failed:[/bold red]")
for error in result.errors:
console.print(f" {error}")
raise click.ClickException("Pipeline failed")

except click.ClickException:
raise
except Exception as e:
console.print(f"\n[bold red]Pipeline failed:[/bold red] {e}")
logger.exception("Pipeline failed")
raise click.ClickException(f"Pipeline failed: {e}") from e


def _resolve_quant_types(
precision: str | None,
weight_type: str | None,
Expand Down
28 changes: 25 additions & 3 deletions src/winml/modelkit/quant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,43 @@
Provides QDQ (Quantize-Dequantize) quantization for ONNX models.

Usage:
from winml.modelkit.quant import quantize_onnx, WinMLQuantizationConfig
from winml.modelkit.quant import (
quantize_onnx,
Quantizer,
expand_precision,
WinMLQuantizationConfig,
)

# Quick quantize with defaults (10 samples, uint8)
result = quantize_onnx("model.onnx")

# Custom config
result = quantize_onnx("model.onnx", WinMLQuantizationConfig(samples=100))

# Pipeline: RTN int4 weight-only
config = WinMLQuantizationConfig(mode="rtn", rtn_bits=4)
result = Quantizer(expand_precision("rtn", config)).run("model.onnx", "out.onnx")
"""

from typing import TYPE_CHECKING, Any

from .config import QuantizeResult, WinMLQuantizationConfig
from .passes import BaseQuantPass, FP16Pass, RTNPass, StaticPass


if TYPE_CHECKING:
from .quantizer import Quantizer, expand_precision, quantize_onnx


__all__ = [
"BaseQuantPass",
"FP16Pass",
"QuantizeResult",
"Quantizer",
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
"RTNPass",
"StaticPass",
"WinMLQuantizationConfig",
"expand_precision",
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
"get_quant_finalizer",
"quantize_onnx",
]
Expand All @@ -35,17 +55,19 @@
# without triggering the heavy imports at runtime.
if TYPE_CHECKING:
from .calibration import get_quant_finalizer
from .quantizer import quantize_onnx
from .quantizer import Quantizer, expand_precision, quantize_onnx


_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
"quantize_onnx": (".quantizer", "quantize_onnx"),
"Quantizer": (".quantizer", "Quantizer"),
"expand_precision": (".quantizer", "expand_precision"),
"get_quant_finalizer": (".calibration", "get_quant_finalizer"),
}


def __getattr__(name: str) -> Any:
"""Lazy-load quantizer (imports onnxruntime.quantization)."""
"""Lazy-load quantizer module (avoids importing onnxruntime at package import time)."""
if name in _LAZY_IMPORTS:
module_path, attr_name = _LAZY_IMPORTS[name]
import importlib
Expand Down
18 changes: 18 additions & 0 deletions src/winml/modelkit/quant/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Quantization passes sub-package."""

from .base import BaseQuantPass
from .fp16 import FP16Pass
from .rtn import RTNPass
from .static import StaticPass


__all__ = [
"BaseQuantPass",
"FP16Pass",
"RTNPass",
"StaticPass",
]
65 changes: 65 additions & 0 deletions src/winml/modelkit/quant/passes/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Base class for quantization passes."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING


if TYPE_CHECKING:
from pathlib import Path

from ..config import QuantizeResult, WinMLQuantizationConfig


class BaseQuantPass(ABC):
"""Abstract base class for a single quantization pass.

Each pass is constructed with a ``WinMLQuantizationConfig`` that provides
all settings. Passes read only the fields relevant to them and ignore the
rest, so a single shared config object can be threaded through every pass
in a :class:`~winml.modelkit.quant.quantizer.Quantizer` pipeline.

Example::

pass_ = FP16Pass(config)
result = pass_.run(model_path, output_path)
"""

def __init__(self, config: WinMLQuantizationConfig) -> None:
self._config = config

@property
def config(self) -> WinMLQuantizationConfig:
"""Return the shared quantization configuration."""
return self._config

@abstractmethod
def run(
self,
model_path: Path,
Comment thread
DingmaomaoBJTU marked this conversation as resolved.
output_path: Path,
*,
use_external_data: bool = True,
) -> QuantizeResult:
"""Run this quantization pass.

Args:
model_path: Path to the input ONNX model.
output_path: Path where the output ONNX model should be written.
use_external_data: Whether to write large tensors as external data.

Returns:
:class:`~winml.modelkit.quant.config.QuantizeResult` describing
the outcome of this pass.

Note:
Passes use file-based I/O because ORT's calibration and RTN APIs
operate on paths, and external-data models cannot be held fully in
memory. A future enhancement could add an optional in-memory
fast-path for small single-pass models.
"""
Loading
Loading