Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions olive/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,13 +439,19 @@ def save_model(
else:
from olive.passes.onnx.common import resave_model

component_output_name = (
component_name
if Path(component_name).suffix == ".onnx"
else f"{component_name}.onnx"
)

resave_model(
ModelConfig.model_validate(component_model_json).create_model().model_path,
actual_output_dir / f"{component_name}.onnx",
actual_output_dir / component_output_name,
saved_external_files=saved_external_files,
)
component_model_json["config"][resource_name] = str(actual_output_dir)
component_model_json["config"]["onnx_file_name"] = f"{component_name}.onnx"
component_model_json["config"]["onnx_file_name"] = component_output_name

copied_components.append(component_model_json)

Expand Down
10 changes: 9 additions & 1 deletion olive/passes/onnx/kquant_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,15 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon
def _run_for_config(
self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str
) -> ONNXModelHandler:
output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)
# For composite model components (e.g., Whisper encoder.onnx/decoder.onnx),
# output_model_path already includes .onnx extension. Strip it so ir.save doesn't
# create a double extension (.onnx.onnx). For other cases, resolve normally.
output_path_obj = Path(output_model_path)
if output_path_obj.suffix == ".onnx":
output_model_path = str(output_path_obj.with_suffix(""))
else:
output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name)

ir_model = model.load_ir_model()
ir.external_data.load_to_model(ir_model)
ir_model.graph.opset_imports[MSFT_DOMAIN] = 1
Expand Down
99 changes: 74 additions & 25 deletions olive/passes/onnx/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from olive.constants import Precision
from olive.hardware.accelerator import AcceleratorSpec, Device
from olive.hardware.constants import ExecutionProvider
from olive.model import HfModelHandler, ONNXModelHandler
from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler
from olive.model.utils import resolve_onnx_path
from olive.passes import Pass
from olive.passes.olive_pass import PassConfigParam
Expand Down Expand Up @@ -264,8 +264,9 @@ def _run_for_config(
if config.extra_options:
extra_args.update(config.extra_options)

# Ensure output_model_filepath matches the final filename in extra_args
output_model_filepath = Path(output_model_path) / extra_args["filename"]
# Ensure output_model_filepath matches the final filename in extra_args while preserving
# the resolved output directory selected above.
output_model_filepath = output_model_filepath.parent / extra_args["filename"]

model_attributes = copy.deepcopy(model.model_attributes or {})

Expand All @@ -283,26 +284,6 @@ def _run_for_config(
**extra_args,
)

# Apply post-processing annotations (split assignments and/or layer annotations)
# in a single load/save cycle to avoid redundant disk I/O.
split_assignments = model_attributes.get("split_assignments") if not metadata_only else None
layer_annotations = model_attributes.get("layer_annotations") if not metadata_only else None

if split_assignments or layer_annotations:
model_proto = onnx.load(output_model_filepath, load_external_data=False)

if split_assignments:
# NOTE: currently the model builder renames modules to its own naming convention
# so the assignments for the renamed modules won't match
split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()])
onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str})

if layer_annotations:
from olive.passes.onnx.layer_annotation import annotate_proto_model

annotate_proto_model(model_proto, layer_annotations)

onnx.save(model_proto, output_model_filepath)
except Exception:
# if model building fails, clean up the intermediate files in the cache_dir
cache_dir = Path(HF_HUB_CACHE)
Expand All @@ -328,6 +309,58 @@ def _run_for_config(
# tokenizer and generation configs are skipped since they are already saved by the model builder
model.save_metadata(output_model_filepath.parent)

generated_onnx_files = sorted(output_model_filepath.parent.glob("*.onnx")) if not metadata_only else []

# For multi-file models (e.g., Whisper), preserve component file names and process each file independently
# in subsequent passes by returning a CompositeModelHandler.
is_multi_file_model = not metadata_only and len(generated_onnx_files) > 1
resolved_single_model_filepath = output_model_filepath
if (
not metadata_only
and not is_multi_file_model
and not output_model_filepath.exists()
and len(generated_onnx_files) == 1
):
logger.info(
"ONNX model file %s does not exist, using %s instead",
output_model_filepath,
generated_onnx_files[0].name,
)
resolved_single_model_filepath = generated_onnx_files[0]

# Apply post-processing annotations (split assignments and/or layer annotations)
# in a single load/save cycle to avoid redundant disk I/O.
split_assignments = model_attributes.get("split_assignments") if not metadata_only else None
layer_annotations = model_attributes.get("layer_annotations") if not metadata_only else None
if is_multi_file_model:
primary_onnx_files = generated_onnx_files
elif resolved_single_model_filepath.exists():
primary_onnx_files = [resolved_single_model_filepath]
else:
primary_onnx_files = []
if split_assignments or layer_annotations:
if primary_onnx_files:
for primary_onnx_file in primary_onnx_files:
model_proto = onnx.load(primary_onnx_file, load_external_data=False)

if split_assignments:
# NOTE: currently the model builder renames modules to its own naming convention
# so the assignments for the renamed modules won't match
split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()])
onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str})

if layer_annotations:
from olive.passes.onnx.layer_annotation import annotate_proto_model

annotate_proto_model(model_proto, layer_annotations)

onnx.save(model_proto, primary_onnx_file)
else:
logger.warning(
"Skipping split_assignments/layer_annotations because no ONNX file was generated in %s.",
output_model_filepath.parent,
)

Comment thread
kunal-vaishnavi marked this conversation as resolved.
# add additional files generated by model builder to model_attributes
additional_files = model_attributes.get("additional_files") or []
if metadata_only:
Expand All @@ -338,20 +371,36 @@ def _run_for_config(
str(output_model_filepath.parent / "genai_config.json"),
]
else:
primary_model_paths = {str(fp) for fp in primary_onnx_files}
model_attributes["additional_files"] = sorted(
set(additional_files)
# all files in the output directory except the model and model.data files
| {str(fp) for fp in output_model_filepath.parent.iterdir()}
- {str(output_model_filepath), str(output_model_filepath) + ".data"}
- primary_model_paths
- {f"{path}.data" for path in primary_model_paths}
)

if metadata_only:
output_model = copy.copy(model)
output_model.model_attributes = model_attributes
elif is_multi_file_model:
# Use the ONNX filenames as component names so child passes write back to encoder.onnx/decoder.onnx
# instead of defaulting to model.onnx.
component_names = [fp.name for fp in generated_onnx_files]
components = [
ONNXModelHandler(output_model_filepath.parent, onnx_file_name=component_name)
for component_name in component_names
]
output_model = CompositeModelHandler(
components,
component_names,
model_path=output_model_filepath.parent,
model_attributes=model_attributes,
)
else:
output_model = ONNXModelHandler(
output_model_filepath.parent,
onnx_file_name=output_model_filepath.name,
onnx_file_name=resolved_single_model_filepath.name,
model_attributes=model_attributes,
)

Expand Down
98 changes: 97 additions & 1 deletion test/passes/onnx/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,45 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import json
import sys
import types
from pathlib import Path
from unittest.mock import Mock

import onnx
import pytest

from olive.model import ONNXModelHandler
from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler
from olive.passes.olive_pass import create_pass_from_dict
from olive.passes.onnx.model_builder import ModelBuilder
from olive.passes.pytorch.rtn import Rtn
from test.utils import make_local_tiny_llama


def _create_test_onnx_model(model_path: Path, node_name: str) -> None:
    """Write a minimal one-node (Identity) ONNX model to *model_path*.

    The single node is named *node_name* so tests can identify which fake
    component file a proto came from.
    """
    shape = [1, 1]
    value_infos = {
        tensor_name: onnx.helper.make_tensor_value_info(tensor_name, onnx.TensorProto.FLOAT, shape)
        for tensor_name in ("input", "output")
    }
    identity_node = onnx.helper.make_node("Identity", ["input"], ["output"], name=node_name)
    test_graph = onnx.helper.make_graph(
        [identity_node], "test_graph", [value_infos["input"]], [value_infos["output"]]
    )
    onnx.save(onnx.helper.make_model(test_graph), model_path)


def _mock_genai_builder(monkeypatch, create_model_fn):
    """Install a fake ``onnxruntime_genai`` package whose model builder delegates to *create_model_fn*.

    Registers the root package, ``models`` subpackage, and ``models.builder``
    submodule in ``sys.modules`` via monkeypatch, and disables
    ``ModelBuilder.maybe_patch_quant`` (which would touch the real package).
    """
    root = types.ModuleType("onnxruntime_genai")
    models = types.ModuleType("onnxruntime_genai.models")
    builder = types.ModuleType("onnxruntime_genai.models.builder")
    builder.create_model = create_model_fn
    models.builder = builder
    root.__version__ = "0.8.0"
    root.models = models
    fake_modules = {
        "onnxruntime_genai": root,
        "onnxruntime_genai.models": models,
        "onnxruntime_genai.models.builder": builder,
    }
    for module_name, module in fake_modules.items():
        monkeypatch.setitem(sys.modules, module_name, module)
    monkeypatch.setattr(ModelBuilder, "maybe_patch_quant", staticmethod(lambda: None))


@pytest.mark.parametrize("metadata_only", [True, False])
def test_model_builder(tmp_path, metadata_only):
input_model = make_local_tiny_llama(tmp_path / "input_model", "onnx" if metadata_only else "hf")
Expand Down Expand Up @@ -100,3 +127,72 @@ def test_model_builder_layer_annotations(tmp_path, layer_annotations):
assert len(node_names_with_metadata) > 0, (
"Expected nodes with metadata_props when layer_annotations are provided"
)


def test_model_builder_apply_annotations_on_single_file_fallback(tmp_path, monkeypatch):
    """When the builder emits a single .onnx file whose name differs from the requested
    filename, ModelBuilder falls back to the generated file and still applies
    split_assignments metadata to it."""
    def fake_create_model(
        model_name, input_path, output_dir, precision, execution_provider, cache_dir, filename, **kwargs
    ):
        # Deliberately ignore `filename` ("expected.onnx") and write "actual.onnx"
        # instead, so the pass must take the single-file fallback path.
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _create_test_onnx_model(output_dir / "actual.onnx", "test_node")
        (output_dir / "actual.onnx.data").write_text("external_data")
        (output_dir / "tokenizer.json").write_text("{}")
        (output_dir / "genai_config.json").write_text(json.dumps({"search": {}}))

    _mock_genai_builder(monkeypatch, fake_create_model)
    # A mocked HF input model is enough: the fake builder never reads real weights.
    input_model = Mock(spec=HfModelHandler)
    input_model.model_name_or_path = "dummy-model"
    input_model.adapter_path = None
    input_model.model_attributes = {"split_assignments": {"model.layers.0": 1}}

    p = create_pass_from_dict(
        ModelBuilder, {"precision": "fp32", "extra_options": {"filename": "expected.onnx"}}, disable_search=True
    )
    output_folder = tmp_path / "output_model"
    output_model = p.run(input_model, output_folder)

    # The handler must point at the file that actually exists, not "expected.onnx".
    assert isinstance(output_model, ONNXModelHandler)
    assert output_model.onnx_file_name == "actual.onnx"
    # split_assignments must be serialized into the model's metadata_props.
    model_proto = onnx.load(output_folder / "actual.onnx", load_external_data=False)
    metadata_props = {prop.key: prop.value for prop in model_proto.metadata_props}
    assert metadata_props["split_assignments"] == "model.layers.0=1"
    # The model file and its external-data file are excluded from additional_files;
    # unrelated artifacts (tokenizer.json) are included.
    assert str(output_folder / "actual.onnx") not in output_model.model_attributes["additional_files"]
    assert str(output_folder / "actual.onnx.data") not in output_model.model_attributes["additional_files"]
    assert str(output_folder / "tokenizer.json") in output_model.model_attributes["additional_files"]


def test_model_builder_multi_file_output_preserves_component_filenames(tmp_path, monkeypatch):
    """When the builder emits multiple .onnx files (e.g. Whisper encoder/decoder),
    ModelBuilder returns a CompositeModelHandler whose component names match the
    generated filenames, and excludes those files from additional_files."""
    def fake_create_model(
        model_name, input_path, output_dir, precision, execution_provider, cache_dir, filename, **kwargs
    ):
        # Emit two model files plus their external-data siblings to simulate a
        # multi-component export.
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        _create_test_onnx_model(output_dir / "encoder.onnx", "encoder_node")
        _create_test_onnx_model(output_dir / "decoder.onnx", "decoder_node")
        (output_dir / "encoder.onnx.data").write_text("encoder_data")
        (output_dir / "decoder.onnx.data").write_text("decoder_data")
        (output_dir / "tokenizer.json").write_text("{}")
        (output_dir / "genai_config.json").write_text(json.dumps({"search": {}}))

    _mock_genai_builder(monkeypatch, fake_create_model)
    # A mocked HF input model is enough: the fake builder never reads real weights.
    input_model = Mock(spec=HfModelHandler)
    input_model.model_name_or_path = "dummy-model"
    input_model.adapter_path = None
    input_model.model_attributes = {}

    p = create_pass_from_dict(ModelBuilder, {"precision": "fp32"}, disable_search=True)
    output_folder = tmp_path / "output_model"
    output_model = p.run(input_model, output_folder)

    assert isinstance(output_model, CompositeModelHandler)
    # The pass discovers files via a sorted glob, so names come back in sorted order.
    expected_component_names = sorted(["encoder.onnx", "decoder.onnx"])
    assert output_model.model_component_names == expected_component_names
    # Each component handler must reference its own file, not a default "model.onnx".
    component_onnx_files = [component.onnx_file_name for component in output_model.model_components]
    assert component_onnx_files == output_model.model_component_names
    # Component models and their external data are excluded from additional_files;
    # unrelated artifacts (tokenizer.json) are included.
    additional_files = output_model.model_attributes["additional_files"]
    assert str(output_folder / "encoder.onnx") not in additional_files
    assert str(output_folder / "decoder.onnx") not in additional_files
    assert str(output_folder / "encoder.onnx.data") not in additional_files
    assert str(output_folder / "decoder.onnx.data") not in additional_files
    assert str(output_folder / "tokenizer.json") in additional_files
Loading