diff --git a/olive/passes/onnx/rtn_quantization.py b/olive/passes/onnx/rtn_quantization.py
index e66cec112..9148903f4 100644
--- a/olive/passes/onnx/rtn_quantization.py
+++ b/olive/passes/onnx/rtn_quantization.py
@@ -3,6 +3,7 @@
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
 import logging
+import shutil
 from pathlib import Path
 from typing import Optional
 
@@ -69,9 +70,92 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
                 default_value=None,
                 description="List of node names to include in quantization.",
             ),
+            "components_to_skip": PassConfigParam(
+                type_=list[str],
+                default_value=None,
+                description=(
+                    "Optional list of component names to skip quantization for "
+                    "(e.g. ['embedding'] to pass the embedding model through unchanged). "
+                    "When a composite model component's name matches an entry in this list, "
+                    "its files are copied to the output path without modification. "
+                    "When not set, all components are quantized (default, backward compatible). "
+                    "Has no effect on single-component (non-composite) models."
+                ),
+            ),
             **get_external_data_config(),
         }
 
+    def run(self, model, output_model_path: str):
+        """Run quantization, skipping components listed in components_to_skip.
+
+        Overrides the base Pass.run() to intercept CompositeModelHandler processing.
+        Components whose names appear in config.components_to_skip are copied to the
+        output path unchanged instead of being quantized.
+        """
+        from olive.model import CompositeModelHandler
+
+        components_to_skip: set[str] = set(self.config.components_to_skip or [])
+        if not components_to_skip or not isinstance(model, CompositeModelHandler):
+            return super().run(model, output_model_path)
+
+        # Warn about component names that won't match anything — misspellings are
+        # silently ignored otherwise since skipping is non-fatal.
+        all_component_names = {name for name, _ in model.get_model_components()}
+        unknown_skips = components_to_skip - all_component_names
+        if unknown_skips:
+            logger.warning(
+                "OnnxBlockWiseRtnQuantization: components_to_skip contains name(s) not found "
+                "in this composite model: %s. Available components: %s",
+                sorted(unknown_skips),
+                sorted(all_component_names),
+            )
+
+        # Mirror the _initialized guard from the base Pass.run() implementation.
+        # Pass.run() checks and sets self._initialized before calling _run_for_config;
+        # since we bypass super().run() for composite models, we must replicate it here
+        # so lazy initialization (e.g. loading config, setting up hardware state) still runs.
+        if not self._initialized:
+            self._initialize()
+            self._initialized = True
+
+        model_dir = Path(output_model_path).with_suffix("")
+        model_dir.mkdir(parents=True, exist_ok=True)
+
+        components = []
+        component_names = []
+        for component_name, component_model in model.get_model_components():
+            component_output_path = model_dir / component_name
+            if component_name in components_to_skip:
+                logger.info(
+                    "OnnxBlockWiseRtnQuantization: skipping quantization for component '%s'.",
+                    component_name,
+                )
+                src = Path(component_model.model_path)
+                # model_path may point to the .onnx file rather than its parent dir
+                src_dir = src.parent if src.is_file() else src
+                if src_dir.resolve() != component_output_path.resolve():
+                    if component_output_path.exists():
+                        shutil.rmtree(str(component_output_path))
+                    shutil.copytree(str(src_dir), str(component_output_path))
+                # onnx_file_name may be None if the handler was created without an explicit name;
+                # reuse the source file's basename when model_path was a file, else 'model.onnx'.
+                onnx_file_name = getattr(component_model, "onnx_file_name", None) or (src.name if src.is_file() else "model.onnx")
+                output_component = ONNXModelHandler(
+                    model_path=str(component_output_path),
+                    onnx_file_name=onnx_file_name,
+                    model_attributes=component_model.model_attributes,
+                )
+                Pass._carry_forward_additional_files(component_model, output_component)
+            else:
+                output_component = self.run(component_model, str(component_output_path))
+            components.append(output_component)
+            component_names.append(component_name)
+
+        output_model = CompositeModelHandler(components, component_names, model_path=model_dir)
+        output_model.model_attributes = output_model.model_attributes or model.model_attributes
+        Pass._carry_forward_additional_files(model, output_model)
+        return output_model
+
     def _run_for_config(
         self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str
     ) -> ONNXModelHandler:
diff --git a/test/passes/onnx/test_rtn_quantization.py b/test/passes/onnx/test_rtn_quantization.py
index edec80ec0..460cf9142 100644
--- a/test/passes/onnx/test_rtn_quantization.py
+++ b/test/passes/onnx/test_rtn_quantization.py
@@ -427,3 +427,149 @@ def test_rtn_quantization_removes_unused_initializers(self, matmul_model_path, t
         assert "weight" not in init_names, (
             f"Original FP32 'weight' initializer should have been removed, found: {init_names}"
         )
+
+
+class TestRTNQuantizationComponentsToSkip:
+    """Tests for the components_to_skip parameter on OnnxBlockWiseRtnQuantization."""
+
+    @staticmethod
+    def _make_matmul_model(tmp_path, name: str) -> ONNXModelHandler:
+        """Create a tiny MatMul ONNX model and return an ONNXModelHandler."""
+        weight = np.random.randn(64, 128).astype(np.float32)
+        inp = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 64])
+        out = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 128])
+        weight_init = onnx.helper.make_tensor(
+            name="weight",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=[64, 128],
+            vals=weight.flatten().tolist(),
+        )
+        node = onnx.helper.make_node("MatMul", ["input", "weight"], ["output"], name="MatMul_Node")
+        graph = onnx.helper.make_graph([node], "g", [inp], [out], initializer=[weight_init])
+        model_def = onnx.helper.make_model(graph, producer_name="test")
+        model_def.opset_import[0].version = 13
+
+        model_dir = tmp_path / name
+        model_dir.mkdir(parents=True, exist_ok=True)
+        onnx.save(model_def, str(model_dir / "model.onnx"))
+        return ONNXModelHandler(model_path=str(model_dir), onnx_file_name="model.onnx")
+
+    @staticmethod
+    def _make_pass(components_to_skip=None) -> OnnxBlockWiseRtnQuantization:
+        accelerator_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")
+        config = {"bits": 4, "block_size": 128, "axis": 0, "is_symmetric": True}
+        if components_to_skip is not None:
+            config["components_to_skip"] = components_to_skip
+        return create_pass_from_dict(
+            OnnxBlockWiseRtnQuantization, config, disable_search=True, accelerator_spec=accelerator_spec
+        )
+
+    def test_components_to_skip_passes_component_through_unchanged(self, tmp_path):
+        """Skipped component's model files are copied without quantization."""
+        from olive.model.handler.composite import CompositeModelHandler
+
+        decoder = self._make_matmul_model(tmp_path / "src", "decoder")
+        embedding = self._make_matmul_model(tmp_path / "src", "embedding")
+
+        composite = CompositeModelHandler(
+            model_components=[decoder, embedding],
+            model_component_names=["decoder", "embedding"],
+            model_path=str(tmp_path / "src"),
+        )
+
+        p = self._make_pass(components_to_skip=["embedding"])
+        result = p.run(composite, str(tmp_path / "out"))
+
+        assert isinstance(result, CompositeModelHandler)
+        assert result.model_component_names == ["decoder", "embedding"]
+
+        # decoder should be quantized (MatMulNBits present)
+        decoder_out = next(m for name, m in result.get_model_components() if name == "decoder")
+        decoder_ir = ir.load(decoder_out.model_path)
+        assert any(n.op_type == str(OpType.MatMulNBits) for n in decoder_ir.graph.all_nodes()), (
+            "decoder should be quantized (MatMulNBits expected)"
+        )
+
+        # embedding should be unchanged (original MatMul still present)
+        emb_out = next(m for name, m in result.get_model_components() if name == "embedding")
+        emb_ir = ir.load(emb_out.model_path)
+        has_matmul = any(n.op_type == str(OpType.MatMul) for n in emb_ir.graph.all_nodes())
+        has_nbits = any(n.op_type == str(OpType.MatMulNBits) for n in emb_ir.graph.all_nodes())
+        assert has_matmul, "embedding should still contain the original MatMul op"
+        assert not has_nbits, "embedding should not be quantized (no MatMulNBits expected)"
+
+    def test_components_to_skip_none_quantizes_all(self, tmp_path):
+        """When components_to_skip is not set, all composite components are quantized."""
+        from olive.model.handler.composite import CompositeModelHandler
+
+        decoder = self._make_matmul_model(tmp_path / "src", "decoder")
+        embedding = self._make_matmul_model(tmp_path / "src", "embedding")
+
+        composite = CompositeModelHandler(
+            model_components=[decoder, embedding],
+            model_component_names=["decoder", "embedding"],
+            model_path=str(tmp_path / "src"),
+        )
+
+        p = self._make_pass(components_to_skip=None)
+        result = p.run(composite, str(tmp_path / "out"))
+
+        assert isinstance(result, CompositeModelHandler)
+
+        for name, component in result.get_model_components():
+            component_ir = ir.load(component.model_path)
+            assert any(n.op_type == str(OpType.MatMulNBits) for n in component_ir.graph.all_nodes()), (
+                f"component '{name}' should be quantized when components_to_skip is None"
+            )
+
+    def test_components_to_skip_does_not_affect_single_model(self, tmp_path):
+        """components_to_skip has no effect on non-composite (single) models."""
+        model = self._make_matmul_model(tmp_path, "single")
+        p = self._make_pass(components_to_skip=["single"])
+        result = p.run(model, str(tmp_path / "out"))
+
+        # Single model should still be quantized despite its name matching the skip list
+        result_ir = ir.load(result.model_path)
+        assert any(n.op_type == str(OpType.MatMulNBits) for n in result_ir.graph.all_nodes()), (
+            "Single-component model should be quantized even when components_to_skip is set"
+        )
+
+    def test_components_to_skip_in_default_config(self):
+        """components_to_skip must appear in _default_config with None as default."""
+        accelerator_spec = AcceleratorSpec(accelerator_type="CPU", execution_provider="CPUExecutionProvider")
+        config = OnnxBlockWiseRtnQuantization._default_config(accelerator_spec)  # pylint: disable=protected-access
+        assert "components_to_skip" in config
+        assert config["components_to_skip"].default_value is None
+        assert config["components_to_skip"].required is False
+
+    def test_components_to_skip_unknown_name_warns(self, tmp_path):
+        """Misspelled or missing component names in components_to_skip must log a warning."""
+        from olive.model.handler.composite import CompositeModelHandler
+
+        decoder = self._make_matmul_model(tmp_path / "src", "decoder")
+        vision = self._make_matmul_model(tmp_path / "src", "vision")
+        composite = CompositeModelHandler(
+            model_components=[decoder, vision],
+            model_component_names=["decoder", "vision"],
+        )
+
+        p = self._make_pass(components_to_skip=["typo_component"])
+
+        import logging
+
+        records = []
+
+        class _Handler(logging.Handler):
+            def emit(self, record):
+                records.append(record.getMessage())
+
+        rtn_logger = logging.getLogger("olive.passes.onnx.rtn_quantization")
+        rtn_logger.addHandler(_Handler())
+        try:
+            p.run(composite, str(tmp_path / "out"))
+        finally:
+            rtn_logger.handlers = [h for h in rtn_logger.handlers if not isinstance(h, _Handler)]
+
+        assert any("typo_component" in msg for msg in records), (
+            f"Expected warning about unknown component name 'typo_component', got: {records}"
+        )