Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions olive/passes/onnx/rtn_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import shutil
from pathlib import Path
from typing import Optional

Expand Down Expand Up @@ -69,9 +70,92 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
default_value=None,
description="List of node names to include in quantization.",
),
"components_to_skip": PassConfigParam(
type_=list[str],
default_value=None,
description=(
"Optional list of component names to skip quantization for "
"(e.g. ['embedding'] to pass the embedding model through unchanged). "
"When a composite model component's name matches an entry in this list, "
"its files are copied to the output path without modification. "
"When not set, all components are quantized (default, backward compatible). "
"Has no effect on single-component (non-composite) models."
),
),
**get_external_data_config(),
}

def run(self, model, output_model_path: str):
    """Run quantization, skipping components listed in components_to_skip.

    Overrides the base Pass.run() to intercept CompositeModelHandler processing.
    Components whose names appear in config.components_to_skip are copied to the
    output path unchanged instead of being quantized.
    """
    from olive.model import CompositeModelHandler

    skip_names: set[str] = set(self.config.components_to_skip or [])
    if not (skip_names and isinstance(model, CompositeModelHandler)):
        # Nothing to skip, or not a composite model: defer to the base implementation.
        return super().run(model, output_model_path)

    # Warn about names that cannot match any component — a misspelled entry
    # would otherwise be silently ignored, since skipping is non-fatal.
    known_names = {name for name, _ in model.get_model_components()}
    missing = skip_names - known_names
    if missing:
        logger.warning(
            "OnnxBlockWiseRtnQuantization: components_to_skip contains name(s) not found "
            "in this composite model: %s. Available components: %s",
            sorted(missing),
            sorted(known_names),
        )

    # Replicate the _initialized guard from the base Pass.run() implementation.
    # Pass.run() checks and sets self._initialized before calling _run_for_config;
    # because we bypass super().run() for composite models, lazy initialization
    # (e.g. loading config, setting up hardware state) must be triggered here.
    if not self._initialized:
        self._initialize()
        self._initialized = True

    model_dir = Path(output_model_path).with_suffix("")
    model_dir.mkdir(parents=True, exist_ok=True)

    out_components = []
    out_names = []
    for name, component in model.get_model_components():
        target = model_dir / name
        if name not in skip_names:
            # Recurse: a single ONNXModelHandler falls through to super().run().
            result = self.run(component, str(target))
        else:
            logger.info(
                "OnnxBlockWiseRtnQuantization: skipping quantization for component '%s'.",
                name,
            )
            src = Path(component.model_path)
            # model_path may point at the .onnx file itself rather than its parent dir.
            src_dir = src if not src.is_file() else src.parent
            if src_dir != target:
                if target.exists():
                    shutil.rmtree(str(target))
                shutil.copytree(str(src_dir), str(target))
            # onnx_file_name may be None when the handler was created without an
            # explicit name; fall back to 'model.onnx', the standard Olive convention.
            file_name = getattr(component, "onnx_file_name", None) or "model.onnx"
            result = ONNXModelHandler(
                model_path=str(target),
                onnx_file_name=file_name,
                model_attributes=component.model_attributes,
            )
            Pass._carry_forward_additional_files(component, result)
        out_components.append(result)
        out_names.append(name)

    output_model = CompositeModelHandler(out_components, out_names, model_path=model_dir)
    output_model.model_attributes = output_model.model_attributes or model.model_attributes
    Pass._carry_forward_additional_files(model, output_model)
    return output_model

def _run_for_config(
self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str
) -> ONNXModelHandler:
Expand Down
146 changes: 146 additions & 0 deletions test/passes/onnx/test_rtn_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,3 +427,149 @@ def test_rtn_quantization_removes_unused_initializers(self, matmul_model_path, t
assert "weight" not in init_names, (
f"Original FP32 'weight' initializer should have been removed, found: {init_names}"
)


class TestRTNQuantizationComponentsToSkip:
    """Tests for the components_to_skip parameter on OnnxBlockWiseRtnQuantization."""

    @staticmethod
    def _make_matmul_model(tmp_path, name: str) -> ONNXModelHandler:
        """Create a tiny single-MatMul ONNX model under ``tmp_path/name``.

        Returns an ONNXModelHandler whose model_path is the directory (with
        onnx_file_name='model.onnx'), matching the Olive directory convention.
        """
        weight = np.random.randn(64, 128).astype(np.float32)
        inp = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 64])
        out = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 128])
        weight_init = onnx.helper.make_tensor(
            name="weight",
            data_type=onnx.TensorProto.FLOAT,
            dims=[64, 128],
            vals=weight.flatten().tolist(),
        )
        node = onnx.helper.make_node("MatMul", ["input", "weight"], ["output"], name="MatMul_Node")
        graph = onnx.helper.make_graph([node], "g", [inp], [out], initializer=[weight_init])
        model_def = onnx.helper.make_model(graph, producer_name="test")
        model_def.opset_import[0].version = 13

        model_dir = tmp_path / name
        model_dir.mkdir(parents=True, exist_ok=True)
        onnx.save(model_def, str(model_dir / "model.onnx"))
        return ONNXModelHandler(model_path=str(model_dir), onnx_file_name="model.onnx")

    @staticmethod
    def _make_pass(components_to_skip=None) -> OnnxBlockWiseRtnQuantization:
        """Build an OnnxBlockWiseRtnQuantization pass with fixed 4-bit RTN settings.

        components_to_skip is only added to the config when explicitly provided,
        so the default (unset) path is also exercised by tests.
        """
        accelerator_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")
        config = {"bits": 4, "block_size": 128, "axis": 0, "is_symmetric": True}
        if components_to_skip is not None:
            config["components_to_skip"] = components_to_skip
        return create_pass_from_dict(
            OnnxBlockWiseRtnQuantization, config, disable_search=True, accelerator_spec=accelerator_spec
        )

    def test_components_to_skip_passes_component_through_unchanged(self, tmp_path):
        """Skipped component's model files are copied without quantization."""
        from olive.model.handler.composite import CompositeModelHandler

        decoder = self._make_matmul_model(tmp_path / "src", "decoder")
        embedding = self._make_matmul_model(tmp_path / "src", "embedding")

        composite = CompositeModelHandler(
            model_components=[decoder, embedding],
            model_component_names=["decoder", "embedding"],
            model_path=str(tmp_path / "src"),
        )

        p = self._make_pass(components_to_skip=["embedding"])
        result = p.run(composite, str(tmp_path / "out"))

        assert isinstance(result, CompositeModelHandler)
        assert result.model_component_names == ["decoder", "embedding"]

        # decoder should be quantized (MatMulNBits present)
        decoder_out = next(m for name, m in result.get_model_components() if name == "decoder")
        decoder_ir = ir.load(decoder_out.model_path)
        assert any(n.op_type == str(OpType.MatMulNBits) for n in decoder_ir.graph.all_nodes()), (
            "decoder should be quantized (MatMulNBits expected)"
        )

        # embedding should be unchanged (original MatMul still present, nothing quantized)
        emb_out = next(m for name, m in result.get_model_components() if name == "embedding")
        emb_ir = ir.load(emb_out.model_path)
        has_matmul = any(n.op_type == str(OpType.MatMul) for n in emb_ir.graph.all_nodes())
        has_nbits = any(n.op_type == str(OpType.MatMulNBits) for n in emb_ir.graph.all_nodes())
        assert has_matmul, "embedding should still contain the original MatMul op"
        assert not has_nbits, "embedding should not be quantized (no MatMulNBits expected)"

    def test_components_to_skip_none_quantizes_all(self, tmp_path):
        """When components_to_skip is not set, all composite components are quantized."""
        from olive.model.handler.composite import CompositeModelHandler

        decoder = self._make_matmul_model(tmp_path / "src", "decoder")
        embedding = self._make_matmul_model(tmp_path / "src", "embedding")

        composite = CompositeModelHandler(
            model_components=[decoder, embedding],
            model_component_names=["decoder", "embedding"],
            model_path=str(tmp_path / "src"),
        )

        p = self._make_pass(components_to_skip=None)
        result = p.run(composite, str(tmp_path / "out"))

        assert isinstance(result, CompositeModelHandler)

        # Backward-compat path: every component must carry quantized MatMulNBits nodes.
        for name, component in result.get_model_components():
            component_ir = ir.load(component.model_path)
            assert any(n.op_type == str(OpType.MatMulNBits) for n in component_ir.graph.all_nodes()), (
                f"component '{name}' should be quantized when components_to_skip is None"
            )

    def test_components_to_skip_does_not_affect_single_model(self, tmp_path):
        """components_to_skip has no effect on non-composite (single) models."""
        model = self._make_matmul_model(tmp_path, "single")
        p = self._make_pass(components_to_skip=["single"])
        result = p.run(model, str(tmp_path / "out"))

        # Single model should still be quantized despite its path matching the skip list
        result_ir = ir.load(result.model_path)
        assert any(n.op_type == str(OpType.MatMulNBits) for n in result_ir.graph.all_nodes()), (
            "Single-component model should be quantized even when components_to_skip is set"
        )

    def test_components_to_skip_in_default_config(self):
        """components_to_skip must appear in _default_config with None as default."""
        accelerator_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")
        config = OnnxBlockWiseRtnQuantization._default_config(accelerator_spec)  # pylint: disable=protected-access
        assert "components_to_skip" in config
        assert config["components_to_skip"].default_value is None
        assert config["components_to_skip"].required is False

    def test_components_to_skip_unknown_name_warns(self, tmp_path, caplog):
        """Misspelled or missing component names in components_to_skip must log a warning."""
        import logging

        from olive.model.handler.composite import CompositeModelHandler

        decoder = self._make_matmul_model(tmp_path / "src", "decoder")
        vision = self._make_matmul_model(tmp_path / "src", "vision")
        composite = CompositeModelHandler(
            model_components=[decoder, vision],
            model_component_names=["decoder", "vision"],
        )

        p = self._make_pass(components_to_skip=["typo_component"])

        # Use pytest's caplog fixture instead of hand-rolling a logging.Handler and
        # mutating logger.handlers in a try/finally: caplog guarantees handler
        # cleanup and level restoration even if the run or assertion fails.
        with caplog.at_level(logging.WARNING, logger="olive.passes.onnx.rtn_quantization"):
            p.run(composite, str(tmp_path / "out"))

        messages = [record.getMessage() for record in caplog.records]
        assert any("typo_component" in msg for msg in messages), (
            f"Expected warning about unknown component name 'typo_component', got: {messages}"
        )
Loading