diff --git a/docs/source/how-to/configure-workflows/metrics-configuration.md b/docs/source/how-to/configure-workflows/metrics-configuration.md index b71e465d7f..027eb3c902 100644 --- a/docs/source/how-to/configure-workflows/metrics-configuration.md +++ b/docs/source/how-to/configure-workflows/metrics-configuration.md @@ -128,3 +128,28 @@ If you have multiple metrics to evaluate, you can configure them in the followin ```{Note} If you have more than one metric, you need to specify `priority: {RANK}`, which Olive will use to determine the best model. ``` + +## Speech Evaluation Metrics (WER and RTFx) + +Olive supports Word Error Rate (WER) and Real-Time Factor (RTFx) as built-in accuracy sub-types for evaluating speech/ASR models. + +### Using WER with the accuracy metric type + +WER can be used as an accuracy sub-type when your data pipeline returns text predictions and references: + +```json +{ + "name": "speech_accuracy", + "type": "accuracy", + "data_config": "speech_data_config", + "sub_types": [ + {"name": "wer", "priority": 1, "higher_is_better": false}, + {"name": "rtfx", "priority": 2, "higher_is_better": true} + ] +} +``` + +```{Note} +- `wer` (Word Error Rate): Measures transcription errors. Lower is better (defaults to `higher_is_better: false`). +- `rtfx` (Real-Time Factor): Ratio of audio duration to inference time. Higher means faster (defaults to `higher_is_better: true`). +``` diff --git a/docs/source/how-to/extending/custom-scripts.md b/docs/source/how-to/extending/custom-scripts.md index 8e8961a5b6..5e78149fe1 100644 --- a/docs/source/how-to/extending/custom-scripts.md +++ b/docs/source/how-to/extending/custom-scripts.md @@ -36,7 +36,7 @@ class MyDataLoader: @Registry.register_dataloader() def my_dataloader(dataset, batch_size): - return MyDataloader(dataset, batch_size) + return MyDataLoader(dataset, batch_size) @Registry.register_post_process() def my_post_process(output): diff --git a/olive/cache.py b/olive/cache.py index 22b13eae5b..ceb64e1528 100644 --- a/olive/cache.py +++ b/olive/cache.py @@ -439,13 +439,19 @@ def save_model( else: from olive.passes.onnx.common import resave_model + component_output_name = ( + component_name + if Path(component_name).suffix == ".onnx" + else f"{component_name}.onnx" + ) + resave_model( ModelConfig.model_validate(component_model_json).create_model().model_path, - actual_output_dir / f"{component_name}.onnx", + actual_output_dir / component_output_name, saved_external_files=saved_external_files, ) component_model_json["config"][resource_name] = str(actual_output_dir) - component_model_json["config"]["onnx_file_name"] = f"{component_name}.onnx" + component_model_json["config"]["onnx_file_name"] = component_output_name copied_components.append(component_model_json) diff --git a/olive/cli/auto_opt.py b/olive/cli/auto_opt.py index 2e0f73444f..ef8f5fed8d 100644 --- a/olive/cli/auto_opt.py +++ b/olive/cli/auto_opt.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
# -------------------------------------------------------------------------- +import logging from argparse import ArgumentParser from collections import OrderedDict from copy import deepcopy @@ -25,13 +26,19 @@ from olive.package_config import OlivePackageConfig from olive.telemetry import action +logger = logging.getLogger(__name__) + class AutoOptCommand(BaseOliveCLICommand): @staticmethod def register_subcommand(parser: ArgumentParser): sub_parser = parser.add_parser( "auto-opt", - help="Automatically optimize the performance of the input model.", + help=( + "Automatically optimize the performance of the input model.\n" + "**** DEPRECATION WARNING ****\n" + '"auto-opt" command is deprecated in favor of "optimize".' + ), ) # Model options @@ -174,6 +181,11 @@ def register_subcommand(parser: ArgumentParser): @action def run(self): + logger.warning( + "**** DEPRECATION WARNING ****\n" + '"auto-opt" command is deprecated in favor of "optimize". Please switch to using "optimize".\n' + "Deprecated commands will be removed entirely in future release." + ) return self._run_workflow() def _get_run_config(self, tempdir) -> dict: diff --git a/olive/cli/benchmark.py b/olive/cli/benchmark.py index adad957730..a3b3f25e8e 100644 --- a/olive/cli/benchmark.py +++ b/olive/cli/benchmark.py @@ -76,6 +76,13 @@ def register_subcommand(parser: ArgumentParser): help="Backend for ONNX model evaluation. Use 'auto' to infer backend from model type.", ) + lmeval_group.add_argument( + "--confirm_run_unsafe_code", + action="store_true", + default=None, + help="Allow running tasks that execute model-generated code (e.g., MBPP, HumanEval).", + ) + add_logging_options(sub_parser) add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) @@ -117,6 +124,10 @@ def _get_run_config(self, tempdir: str) -> dict: ("evaluators", "evaluator", "model_class"), None if self.args.backend == "auto" else self.args.backend, ), + ( + ("evaluators", "evaluator", "confirm_run_unsafe_code"), + True if self.args.confirm_run_unsafe_code else None, + ), ] for keys, value in to_replace: diff --git a/olive/data/component/pre_process_data.py b/olive/data/component/pre_process_data.py index c4a4d38a07..2881ea891f 100644 --- a/olive/data/component/pre_process_data.py +++ b/olive/data/component/pre_process_data.py @@ -291,3 +291,91 @@ def _tokenizer_and_align_labels(examples): tokenized_datasets = _huggingface_pre_process_helper(dataset, _tokenizer_and_align_labels, max_samples, **kwargs) return ClassificationDataset(tokenized_datasets, label_col="label", max_samples=max_samples) + + +@Registry.register_pre_process() +def speech_transcription_pre_process( + dataset, + audio_col: str = "audio", + text_col: str = "text", + sample_rate: int = 16000, + max_samples: Optional[int] = None, + limit: Optional[float] = None, + seed: int = 42, + **kwargs, +): + """Pre-process data for speech transcription (ASR) evaluation. + + Loads audio arrays and reference transcription text from a HuggingFace dataset. + Returns a dataset of (audio_array, reference_text) pairs suitable for WER evaluation. + + Args: + dataset: HuggingFace dataset with audio and text columns. + audio_col: Name of the audio column. Defaults to "audio". + text_col: Name of the reference text column. Defaults to "text". + sample_rate: Target sample rate for audio. Defaults to 16000. + max_samples: Maximum number of samples (deprecated, use limit). Defaults to None. + limit: Sampling limit following Olive convention: + If >= 1: use first N samples. 
+ If 0 < limit < 1: randomly sample that percentage. + If 0 or None: use all samples. + seed: Random seed for percentage-based sampling. Defaults to 42. + **kwargs: Additional arguments. + + """ + from datasets import Audio + + dataset = dataset.cast_column(audio_col, Audio(sampling_rate=sample_rate)) + + # Apply sampling: prefer limit over max_samples + effective_limit = limit if limit is not None else (max_samples if max_samples else 0) + if effective_limit and effective_limit != 0: + from random import Random + + total = len(dataset) + if 0 < effective_limit < 1: + n = max(1, int(total * effective_limit)) + rng = Random(seed) + indices = sorted(rng.sample(range(total), min(n, total))) + dataset = dataset.select(indices) + elif effective_limit >= 1: + n = min(int(effective_limit), total) + dataset = dataset.select(range(n)) + + class SpeechTranscriptionDataset: + """Dataset that returns (audio_array, reference_text) pairs. + + Note: Use batch_size=1 in dataloader config as audio samples have variable lengths. + """ + + def __init__(self, hf_dataset, audio_column, text_column): + self.dataset = hf_dataset + self.audio_column = audio_column + self.text_column = text_column + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + item = self.dataset[idx] + import numpy as np + + audio_array = np.array(item[self.audio_column]["array"], dtype=np.float32) + reference_text = item[self.text_column] + return audio_array, reference_text + + @staticmethod + def collate_fn(batch): + """Collate variable-length audio batches. Use with batch_size=1 or pad audio.""" + import numpy as np + + # batch_size=1 is expected for speech evaluation (variable-length audio) + if len(batch) == 1: + audio, text = batch[0] + return (np.expand_dims(audio, 0), [text]) + # For batch_size > 1, return as lists (no padding) + audios = [item[0] for item in batch] + texts = [item[1] for item in batch] + return (audios, texts) + + return SpeechTranscriptionDataset(dataset, audio_col, text_col) diff --git a/olive/data/container/huggingface_container.py b/olive/data/container/huggingface_container.py index 2c702d8908..9a6bb81e3f 100644 --- a/olive/data/container/huggingface_container.py +++ b/olive/data/container/huggingface_container.py @@ -38,4 +38,7 @@ class HuggingfaceContainer(DataContainer): DataComponentType.PRE_PROCESS_DATA.value: "audio_classification_pre_process", DataComponentType.POST_PROCESS_DATA.value: "text_classification_post_process", }, + "speech-transcription": { + DataComponentType.PRE_PROCESS_DATA.value: "speech_transcription_pre_process", + }, } diff --git a/olive/evaluator/accuracy.py b/olive/evaluator/accuracy.py index cd2ddf757c..db5e297754 100644 --- a/olive/evaluator/accuracy.py +++ b/olive/evaluator/accuracy.py @@ -26,6 +26,7 @@ class AccuracyBase(AutoConfigClass): "recall": torchmetrics.Recall, "auroc": torchmetrics.AUROC, "perplexity": torchmetrics.text.perplexity.Perplexity, + "wer": torchmetrics.text.WordErrorRate, } def __init__(self, config: Optional[Union[ConfigBase, dict[str, Any]]] = None) -> None: @@ -157,3 +158,62 @@ def measure(self, model_output, target): perplexity.update(logits, targets) result = perplexity.compute() return result.item() + + +class WordErrorRate(AccuracyBase): + """Word Error Rate metric for speech/ASR evaluation. + + Expects model_output.preds to be a list of predicted transcription strings + and target to be a list of reference transcription strings. 
+ """ + + name: Optional[str] = "wer" + + @classmethod + def _default_config(cls) -> dict[str, ConfigParam]: + return {} + + def measure(self, model_output, target): + preds = model_output.preds + refs = target + # Ensure inputs are lists of strings + if isinstance(preds, str): + preds = [preds] + elif not isinstance(preds, list): + preds = list(preds) + if isinstance(refs, str): + refs = [refs] + elif not isinstance(refs, list): + refs = list(refs) + + wer = torchmetrics.text.WordErrorRate(**self.config_dict) + result = wer(preds, refs) + return result.item() + + +class RealTimeFactor(AccuracyBase): + """Real-Time Factor (RTFx) metric for speech/ASR evaluation. + + RTFx = total_audio_duration / total_inference_time. + A value > 1 means faster than real-time (e.g., RTFx=5 means 5x faster). + Timing metadata is provided via model_output.logits dict. + """ + + name: Optional[str] = "rtfx" + + @classmethod + def _default_config(cls) -> dict[str, ConfigParam]: + return {} + + def measure(self, model_output, target): + timing = model_output.logits + if not isinstance(timing, dict) or "total_audio_duration" not in timing: + raise ValueError( + "RTFx metric requires timing metadata from text-based inference path. " + "Ensure the metric is used with speech evaluation (WER + RTFx together)." + ) + total_audio = timing["total_audio_duration"] + total_inference = timing["total_inference_time"] + if total_inference == 0: + return float("inf") + return round(total_audio / total_inference, 2) diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index fd69b066e7..f50da11b46 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -190,7 +190,10 @@ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> list[fl raise NotImplementedError("Yet to be implemented!") def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: - raise NotImplementedError("Yet to be implemented!") + raise NotImplementedError( + "generate_until is not supported by this model backend. " + "Use model_class='ortgenai' for generative tasks such as MBPP or HumanEval." + ) @register_model("ort") @@ -509,7 +512,14 @@ def __init__( self.max_length = max_length else: self.max_length = genai_config["search"]["max_length"] - self._eot_token_id = genai_config["model"]["eos_token_id"] + eot = genai_config["model"]["eos_token_id"] + # eos_token_id can be a list (e.g. [1, 106] for Gemma4) or a scalar. + # Store all EOS IDs for generate_until stop detection, + # and first/scalar for loglikelihood (TemplateLM.eot_token_id expects int). + self._eos_token_ids = list(eot) if isinstance(eot, list) else [eot] + if not self._eos_token_ids: + raise ValueError("genai_config model.eos_token_id must not be an empty list") + self._eot_token_id = self._eos_token_ids[0] self.params = og.GeneratorParams(self.model) self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False) @@ -546,32 +556,227 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor self.params.set_search_options(batch_size=batch_size) generator = og.Generator(self.model, self.params) - if self._returns_full_logits: - generator.append_tokens(input_ids.tolist()) - return torch.from_numpy(generator.get_output("logits")).to(self.device) - - # Model only returns logits for the last appended position. 
if batch_size > 1 and cont_len > 1: raise ValueError( - "batch_size > 1 is not supported when the model returns single-position logits" + "batch_size > 1 is not supported when using incremental get_logits() retrieval" " and continuation length > 1. Right-padding misaligns continuation positions across" " batch elements. Use batch_size=1 instead." ) - # Bulk-append context tokens, then step through the last cont_len tokens - # one at a time to collect only the logits we actually need. + # Use incremental token appending with get_logits() to avoid copying + # the full logits tensor from GPU to CPU. get_output("logits") copies + # seq_len * vocab_size * 2 bytes (e.g. 472MB for 900 tokens with + # 262K vocab), while get_logits() copies only vocab_size * 4 bytes + # (~1MB) per position. n_logits = max(cont_len, 1) prefix_len = seq_len - n_logits generator.append_tokens(input_ids[:, : prefix_len + 1].tolist()) - all_logits = [torch.from_numpy(generator.get_output("logits")).to(self.device)] + all_logits = [torch.from_numpy(generator.get_logits()).to(self.device)] for i in range(prefix_len + 1, seq_len): generator.append_tokens(input_ids[:, i : i + 1].tolist()) - all_logits.append(torch.from_numpy(generator.get_output("logits")).to(self.device)) + all_logits.append(torch.from_numpy(generator.get_logits()).to(self.device)) # No need to pad to [batch, seq_len, vocab]. The slicing in _loglikelihood_tokens computes # ctx_len = inplen + (logits.shape[0] - padding_len_inp), which adjusts for the shorter # seq dimension so the continuation slice still lands on the correct positions. return torch.cat(all_logits, dim=1) # [batch, n_logits, vocab] + def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: + """Generate text until a stop sequence is found or max tokens reached. + + Supports generative evaluation tasks such as MBPP and HumanEval. + Each request is a tuple of (context_string, gen_kwargs_dict). + """ + results = [] + for request in tqdm(requests, desc="Running generate_until", disable=disable_tqdm): + context = request.args[0] + gen_kwargs = request.args[1] if len(request.args) > 1 and isinstance(request.args[1], dict) else {} + + # Extract stop sequences — normalise str/None/tuple/other-iterables to list[str] + until = gen_kwargs.get("until", []) + if isinstance(until, str): + until = [until] + elif until is None: + until = [] + elif not isinstance(until, list): + try: + until = list(until) # handles tuple, set, generator, etc. 
+ except TypeError: + until = [until] # non-iterable scalar fallback + until = [stop_seq for stop_seq in until if isinstance(stop_seq, str) and stop_seq] + + # Extract generation parameters + max_gen_toks = gen_kwargs.get( + "max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens")) + ) + try: + max_gen_toks = int(max_gen_toks) if max_gen_toks is not None else 256 + except (TypeError, ValueError): + max_gen_toks = 256 + max_gen_toks = max(max_gen_toks, 0) + try: + temperature = float(gen_kwargs.get("temperature", 0.0) or 0.0) + except (TypeError, ValueError): + temperature = 0.0 + raw_do_sample = gen_kwargs.get("do_sample", None) + if raw_do_sample is None: + do_sample = temperature > 0 + elif isinstance(raw_do_sample, bool): + do_sample = raw_do_sample + elif isinstance(raw_do_sample, str): + do_sample = raw_do_sample.lower() not in ("false", "0", "no", "") + else: + do_sample = bool(raw_do_sample) + + # Tokenize the prompt + prompt_ids = self.tokenizer.encode(context).tolist() + prompt_len = len(prompt_ids) + + # Compute total max_length: prompt + new tokens, capped by model limit + total_max_length = min(prompt_len + max_gen_toks, self.max_length) + + # If the prompt already fills or exceeds the model limit, no generation is possible. + if prompt_len >= self.max_length or max_gen_toks == 0: + results.append("") + if hasattr(request, "cache_hook") and request.cache_hook is not None: + request.cache_hook.add_partial("generate_until", request.args, "") + continue + + # Create fresh generator params per request to avoid state leakage + params = og.GeneratorParams(self.model) + search_options = { + "max_length": total_max_length, + "past_present_share_buffer": False, + } + if do_sample: + search_options["do_sample"] = True + search_options["temperature"] = temperature + else: + search_options["temperature"] = 0.0 + params.set_search_options(**search_options) + + # Run generation token by token to check for stop sequences + generator = og.Generator(self.model, params) + generator.append_tokens([prompt_ids]) + + generated_token_ids = [] + stop_found = False + # Character-based rolling tail wide enough to catch any stop sequence + # across chunk boundaries, regardless of how many tokens a stop string spans. + max_stop_len = max((len(s) for s in until), default=0) + tail = "" + + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_sequence(0)[-1] + + # Check for EOS token(s) + if new_token in self._eos_token_ids: + break + + generated_token_ids.append(new_token) + + # Decode one token at a time only for stop-sequence tail detection. + # The final text is produced by decoding the full ID sequence so that + # tokenizer whitespace/punctuation normalisation is applied correctly. + if until: + chunk = self.tokenizer.decode([new_token]) + tail = (tail + chunk)[-(max_stop_len + len(chunk)) :] + for stop_seq in until: + if stop_seq in tail: + stop_found = True + break + if stop_found: + break + + # Decode full token sequence once for correct whitespace/punctuation handling. + full_text = self.tokenizer.decode(generated_token_ids) if generated_token_ids else "" + + # Trim at the earliest stop sequence found in the final decoded text. 
+ generated_text = full_text + if until: + earliest = None + for stop_seq in until: + idx = full_text.find(stop_seq) + if idx != -1 and (earliest is None or idx < earliest): + earliest = idx + if earliest is not None: + generated_text = full_text[:earliest] + + results.append(generated_text) + + # lm-eval cache hook + if hasattr(request, "cache_hook") and request.cache_hook is not None: + request.cache_hook.add_partial("generate_until", request.args, generated_text) + + return results + def complete(self): pass + + def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: + """Generate text until a stop sequence is reached. + + Used by benchmarks like MMLU Pro (CoT variant) that score by generating + chain-of-thought text and extracting the answer with a regex filter. + """ + results = [] + for request in tqdm(requests, disable=disable_tqdm, desc="Running generate_until requests"): + context = request.args[0] + gen_kwargs = request.args[1] + + until = gen_kwargs.get("until", []) + max_gen_toks = gen_kwargs.get("max_gen_toks", 256) + if isinstance(until, str): + until = [until] + + input_ids = self.tok_encode(context) + max_new_tokens = min(max_gen_toks, self.max_length - len(input_ids)) + if max_new_tokens <= 0: + results.append("") + continue + + params = og.GeneratorParams(self.model) + params.set_search_options( + max_length=len(input_ids) + max_new_tokens, + past_present_share_buffer=False, + batch_size=1, + ) + if gen_kwargs.get("temperature", 0.0) == 0.0: + params.set_search_options(do_sample=False) + else: + params.set_search_options( + do_sample=True, + temperature=gen_kwargs["temperature"], + ) + + generator = og.Generator(self.model, params) + generator.append_tokens([input_ids]) + + eos_ids = self._eos_token_ids + + generated_ids = [] + # Decode periodically to check for stop sequences + decode_interval = 16 + while not generator.is_done(): + generator.generate_next_token() + token_id = generator.get_next_tokens()[0] + generated_ids.append(token_id) + if token_id in eos_ids: + break + # Check stop sequences periodically by decoding + if until and len(generated_ids) % decode_interval == 0: + partial_text = self.tokenizer.decode(generated_ids) + if any(stop_seq in partial_text for stop_seq in until): + break + + generated_text = self.tokenizer.decode(generated_ids) + + # Truncate at the first stop sequence + for stop_seq in until: + idx = generated_text.find(stop_seq) + if idx != -1: + generated_text = generated_text[:idx] + + results.append(generated_text) + return results diff --git a/olive/evaluator/metric.py b/olive/evaluator/metric.py index 520d5f8b1f..24ae3f88db 100644 --- a/olive/evaluator/metric.py +++ b/olive/evaluator/metric.py @@ -38,6 +38,8 @@ class AccuracySubType(StrEnumBase): RECALL = "recall" AUROC = "auroc" PERPLEXITY = "perplexity" + WER = "wer" + RTFX = "rtfx" class LatencySubType(StrEnumBase): @@ -206,7 +208,13 @@ def validate_sub_types(cls, v, info): # metric_config metric_config_cls = None if info.data["type"] == MetricType.ACCURACY: - item["higher_is_better"] = item.get("higher_is_better", True) + # Error rate metrics (WER) default to higher_is_better=False + _error_rate_metrics = {"wer"} + item_name = item["name"] if isinstance(item["name"], str) else item["name"].value + if item_name in _error_rate_metrics: + item["higher_is_better"] = item.get("higher_is_better", False) + else: + item["higher_is_better"] = item.get("higher_is_better", True) if info.data["backend"] == "torch_metrics": metric_config_cls = 
AccuracyBase.registry[item["name"]].get_config_class() elif info.data["backend"] == "huggingface_metrics": diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index 0814850a17..602ce29991 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -3,11 +3,12 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import collections +import inspect import logging import time from abc import ABC, abstractmethod from copy import deepcopy -from functools import partial +from functools import lru_cache, partial from numbers import Number from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Optional, Union @@ -26,6 +27,7 @@ from olive.data.container.dummy_data_container import TRANSFORMER_DUMMY_DATA_CONTAINER from olive.data.template import dummy_data_config_template from olive.evaluator.metric import ( + AccuracySubType, LatencySubType, Metric, MetricType, @@ -58,6 +60,27 @@ class OliveModelOutput(NamedTuple): logits: Any +# Text-based accuracy sub-types that work with string predictions/targets +_TEXT_BASED_ACCURACY_SUBTYPES = {AccuracySubType.WER, AccuracySubType.RTFX} + + +def _is_text_based_metric(metric: "Metric") -> bool: + """Check if metric uses text-based accuracy sub-types (WER, RTFx). + + Raises ValueError if text-based and tensor-based sub-types are mixed, + as they require different inference paths. + """ + if metric.type != MetricType.ACCURACY: + return False + text_based = [sub.name in _TEXT_BASED_ACCURACY_SUBTYPES for sub in metric.sub_types] + if any(text_based) and not all(text_based): + raise ValueError( + "Cannot mix text-based accuracy sub-types (WER, RTFx) with tensor-based sub-types " + "(accuracy_score, f1_score, etc.) in the same metric. Please define them as separate metrics." + ) + return all(text_based) + + class OliveEvaluator(ABC): def __init__(self, **kwargs): super().__init__() @@ -496,9 +519,401 @@ def _evaluate_onnx_accuracy( device: Device = Device.CPU, execution_providers: Optional[Union[str, list[str]]] = None, ) -> MetricResult: - inference_output, targets = self._inference(model, metric, dataloader, post_func, device, execution_providers) + if _is_text_based_metric(metric): + # Auto-detect genai model by checking for genai_config.json + genai_config_path = Path(model.model_path).parent / "genai_config.json" + if genai_config_path.exists(): + import json + + with genai_config_path.open() as f: + genai_config = json.load(f) + model_type = genai_config.get("model", {}).get("type", "") + + if model_type == "whisper": + inference_output, targets = self._inference_text_genai( + model, metric, dataloader, device, execution_providers + ) + elif model_type == "nemotron_speech": + inference_output, targets = self._inference_text_genai_streaming( + model, metric, dataloader, device, execution_providers + ) + else: + raise ValueError( + f"Unsupported genai model type '{model_type}' for speech evaluation. " + f"Supported types: 'whisper' (offline), 'nemotron_speech' (streaming). " + f"For unsupported model types, use a custom evaluation script." 
+ ) + else: + inference_output, targets = self._inference_text( + model, metric, dataloader, post_func, device, execution_providers + ) + else: + inference_output, targets = self._inference( + model, metric, dataloader, post_func, device, execution_providers + ) return OliveEvaluator.compute_accuracy(metric, inference_output, targets) + def _inference_text( + self, + model: ONNXModelHandler, + metric: Metric, + dataloader: "DataLoader", + post_func=None, + device: Device = Device.CPU, + execution_providers: Optional[Union[str, list[str]]] = None, + ) -> tuple[OliveModelOutput, Any]: + """Text-based inference for speech/ASR metrics (WER, RTFx). + + The post_func must return a list of predicted text strings per batch. + Labels from the dataloader must be a list of reference text strings. + Tracks total inference time and audio duration for RTFx computation. + """ + session, inference_settings = OnnxEvaluator.get_session_wrapper( + model, metric, dataloader, device, execution_providers + ) + io_config = model.io_config + run_kwargs = metric.get_run_kwargs() + + all_preds = [] + all_targets = [] + total_audio_duration = 0.0 + total_inference_time = 0.0 + output_names = io_config["output_names"] + is_single_tensor_output = len(output_names) == 1 + sample_rate = ( + metric.data_config.pre_process_data_config.params.get("sample_rate", 16000) + if (metric.data_config and metric.data_config.pre_process_data_config) + else 16000 + ) + + for batch in dataloader: + input_data, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) + # Track audio duration from input data + if isinstance(input_data, (np.ndarray, torch.Tensor)): + audio_samples = input_data.shape[-1] if len(input_data.shape) > 1 else input_data.shape[0] + total_audio_duration += audio_samples / sample_rate + elif isinstance(input_data, dict): + for v in input_data.values(): + if isinstance(v, (np.ndarray, torch.Tensor)) and v.ndim >= 1: + total_audio_duration += v.shape[-1] / sample_rate + break + + input_feed = format_data(input_data, io_config) + start_time = time.perf_counter() + result = model.run_session(session, input_feed, **run_kwargs) + if is_single_tensor_output: + result = torch.from_numpy(result[0]) if hasattr(result[0], "__array__") else torch.tensor(result[0]) + else: + result = { + name: torch.from_numpy(result[i]) if hasattr(result[i], "__array__") else torch.tensor(result[i]) + for i, name in enumerate(output_names) + } + # post_func must decode model output to text strings + outputs = post_func(result) if post_func else result + total_inference_time += time.perf_counter() - start_time + + if isinstance(outputs, str): + all_preds.append(outputs) + elif isinstance(outputs, (list, tuple)): + if not outputs: + continue + if not isinstance(outputs[0], str): + raise ValueError( + f"post_func must return str or list[str] for text-based metrics (WER), " + f"but got list of {type(outputs[0]).__name__}. " + f"Ensure your post_func decodes model output to text." + ) + all_preds.extend(outputs) + else: + raise ValueError( + f"post_func must return str or list[str] for text-based metrics (WER), " + f"but got {type(outputs).__name__}. " + f"Ensure your post_func decodes model output to text." 
+ ) + # labels should be reference text strings + if isinstance(labels, (list, tuple)): + all_targets.extend(labels) + else: + all_targets.append(labels) + + tuning_result_file = inference_settings.get("tuning_result_file") + if tuning_result_file: + dump_tuning_result(session.session, tuning_result_file) + + # Store timing metadata for RTFx computation + timing_metadata = { + "total_audio_duration": total_audio_duration, + "total_inference_time": total_inference_time, + } + return OliveModelOutput(preds=all_preds, logits=timing_metadata), all_targets + + def _inference_text_genai( + self, + model: ONNXModelHandler, + metric: Metric, + dataloader: "DataLoader", + device: Device = Device.CPU, + execution_providers: Optional[Union[str, list[str]]] = None, + ) -> tuple[OliveModelOutput, Any]: + """Text-based inference for speech/ASR metrics using onnxruntime-genai. + + Auto-detected when the model directory contains genai_config.json. + Uses og.Model with multimodal processor for Whisper-style models. + Automatically chunks audio longer than 30 seconds. + """ + try: + import onnxruntime_genai as og + except ImportError: + raise ImportError( + "onnxruntime-genai is required for genai-based speech evaluation. " + "Install it with: pip install onnxruntime-genai" + ) from None + + import io + import json + + import soundfile as sf + + model_dir = str(Path(model.model_path).parent) + + # Read genai_config to determine model properties + with (Path(model_dir) / "genai_config.json").open() as f: + genai_config = json.load(f) + + # Build og.Model with appropriate execution provider + config = og.Config(model_dir) + config.clear_providers() + if device == Device.GPU: + config.append_provider("cuda") + og_model = og.Model(config) + processor = og_model.create_multimodal_processor() + + # Determine decoder prompt tokens from model config + # English-only models (vocab_size=51864) use shorter prompt + vocab_size = genai_config.get("model", {}).get("vocab_size", 51865) + is_english_only = vocab_size == 51864 + if is_english_only: + decoder_prompt_tokens = ["<|startoftranscript|>", "<|notimestamps|>"] + else: + decoder_prompt_tokens = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"] + + sample_rate = ( + metric.data_config.pre_process_data_config.params.get("sample_rate", 16000) + if (metric.data_config and metric.data_config.pre_process_data_config) + else 16000 + ) + max_length = genai_config.get("search", {}).get("max_length", 448) + + # Whisper encoder supports max 30s (3000 mel frames) + max_chunk_seconds = 30 + max_chunk_samples = max_chunk_seconds * sample_rate + + prompt = "".join(decoder_prompt_tokens) + + def _transcribe_chunks(audio_arr: np.ndarray, genai_model) -> str: + """Transcribe a single audio array, chunking if longer than 30s.""" + if len(audio_arr) <= max_chunk_samples: + chunks = [audio_arr] + else: + # Split into non-overlapping 30s chunks + chunks = [] + for start in range(0, len(audio_arr), max_chunk_samples): + chunks.append(audio_arr[start : start + max_chunk_samples]) + + transcriptions = [] + for chunk in chunks: + buffer = io.BytesIO() + sf.write(buffer, chunk, samplerate=sample_rate, format="WAV") + audios = og.Audios.open_bytes(buffer.getvalue()) + inputs = processor([prompt], audios=audios) + + params = og.GeneratorParams(genai_model) + params.set_search_options(do_sample=False, max_length=max_length, min_length=0, batch_size=1) + + generator = og.Generator(genai_model, params) + generator.set_inputs(inputs) + + while not generator.is_done(): + 
generator.generate_next_token() + + tokens = generator.get_sequence(0) + transcriptions.append(processor.decode(tokens).strip()) + + return " ".join(transcriptions) + + all_preds = [] + all_targets = [] + total_audio_duration = 0.0 + total_inference_time = 0.0 + + for batch in dataloader: + input_data, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) + + # Convert input to list of audio arrays + audio_arrays = [] + if isinstance(input_data, (np.ndarray, torch.Tensor)): + arr = np.array(input_data) if isinstance(input_data, torch.Tensor) else input_data + if arr.ndim == 1: + audio_arrays = [arr] + else: + audio_arrays = [arr[i] for i in range(arr.shape[0])] + elif isinstance(input_data, list): + audio_arrays = [np.array(a) if not isinstance(a, np.ndarray) else a for a in input_data] + + if not audio_arrays: + continue + + start_time = time.perf_counter() + for arr in audio_arrays: + total_audio_duration += len(arr) / sample_rate + transcription = _transcribe_chunks(arr, og_model) + all_preds.append(transcription) + total_inference_time += time.perf_counter() - start_time + + # Collect reference texts + if isinstance(labels, (list, tuple)): + all_targets.extend(labels) + else: + all_targets.append(labels) + + del og_model + + timing_metadata = { + "total_audio_duration": total_audio_duration, + "total_inference_time": total_inference_time, + } + return OliveModelOutput(preds=all_preds, logits=timing_metadata), all_targets + + def _inference_text_genai_streaming( + self, + model: ONNXModelHandler, + metric: Metric, + dataloader: "DataLoader", + device: Device = Device.CPU, + execution_providers: Optional[Union[str, list[str]]] = None, + ) -> tuple[OliveModelOutput, Any]: + """Text-based inference for streaming ASR models using onnxruntime-genai. + + Auto-detected when genai_config.json has model.type = "nemotron_speech". + Uses og.StreamingProcessor for stateful chunked inference with silence padding + for right-context flushing. + """ + try: + import onnxruntime_genai as og + except ImportError: + raise ImportError( + "onnxruntime-genai is required for genai-based speech evaluation. 
" + "Install it with: pip install onnxruntime-genai" + ) from None + + import json + + model_dir = str(Path(model.model_path).parent) + + with (Path(model_dir) / "genai_config.json").open() as f: + genai_config = json.load(f) + + sample_rate = genai_config["model"].get("sample_rate", 16000) + chunk_samples = genai_config["model"].get("chunk_samples", 8960) + + # Build og.Model with appropriate execution provider + config = og.Config(model_dir) + config.clear_providers() + if device == Device.GPU: + config.append_provider("cuda") + og_model = og.Model(config) + tokenizer = og.Tokenizer(og_model) + + # Number of silence chunks for right-context flushing + num_silence_chunks = 4 + + def _transcribe_streaming(audio_arr: np.ndarray, genai_model) -> str: + """Transcribe audio using stateful streaming processor.""" + audio = audio_arr.astype(np.float32) + stream_processor = og.StreamingProcessor(genai_model) + tokenizer_stream = tokenizer.create_stream() + params = og.GeneratorParams(genai_model) + generator = og.Generator(genai_model, params) + + transcript = "" + + def decode_tokens(): + nonlocal transcript + while not generator.is_done(): + generator.generate_next_token() + tokens = generator.get_next_tokens() + if len(tokens) > 0: + text = tokenizer_stream.decode(tokens[0]) + if text: + transcript += text + + # Feed audio chunks + for start in range(0, len(audio), chunk_samples): + chunk = audio[start : start + chunk_samples].astype(np.float32) + inputs = stream_processor.process(chunk) + if inputs is not None: + generator.set_inputs(inputs) + decode_tokens() + + # Flush remaining audio in the processor + inputs = stream_processor.flush() + if inputs is not None: + generator.set_inputs(inputs) + decode_tokens() + + # Feed silence chunks for right-context flushing + for _ in range(num_silence_chunks): + silence = np.zeros(chunk_samples, dtype=np.float32) + inputs = stream_processor.process(silence) + if inputs is not None: + generator.set_inputs(inputs) + decode_tokens() + + return transcript + + all_preds = [] + all_targets = [] + total_audio_duration = 0.0 + total_inference_time = 0.0 + + for batch in dataloader: + input_data, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) + + # Convert input to list of audio arrays + audio_arrays = [] + if isinstance(input_data, (np.ndarray, torch.Tensor)): + arr = np.array(input_data) if isinstance(input_data, torch.Tensor) else input_data + if arr.ndim == 1: + audio_arrays = [arr] + else: + audio_arrays = [arr[i] for i in range(arr.shape[0])] + elif isinstance(input_data, list): + audio_arrays = [np.array(a) if not isinstance(a, np.ndarray) else a for a in input_data] + + if not audio_arrays: + continue + + start_time = time.perf_counter() + for arr in audio_arrays: + total_audio_duration += len(arr) / sample_rate + transcription = _transcribe_streaming(arr, og_model) + all_preds.append(transcription) + total_inference_time += time.perf_counter() - start_time + + # Collect reference texts + if isinstance(labels, (list, tuple)): + all_targets.extend(labels) + else: + all_targets.append(labels) + + del og_model + + timing_metadata = { + "total_audio_duration": total_audio_duration, + "total_inference_time": total_inference_time, + } + return OliveModelOutput(preds=all_preds, logits=timing_metadata), all_targets + def _evaluate_onnx_latency( self, model: ONNXModelHandler, @@ -802,9 +1217,92 @@ def _evaluate_accuracy( device: Device = Device.CPU, execution_providers: Optional[Union[str, list[str]]] = None, ) -> MetricResult: - inference_output, 
targets = self._inference(model, metric, dataloader, post_func, device, execution_providers) + if _is_text_based_metric(metric): + inference_output, targets = self._inference_text( + model, metric, dataloader, post_func, device, execution_providers + ) + else: + inference_output, targets = self._inference( + model, metric, dataloader, post_func, device, execution_providers + ) return OliveEvaluator.compute_accuracy(metric, inference_output, targets) + @torch.no_grad() + def _inference_text( + self, + model: "PyTorchModelHandler", + metric: Metric, + dataloader: "DataLoader", + post_func=None, + device: Device = Device.CPU, + execution_providers: Optional[Union[str, list[str]]] = None, + ) -> tuple[OliveModelOutput, Any]: + """Text-based inference for speech/ASR metrics (WER, RTFx).""" + session = model.prepare_session() + all_preds = [] + all_targets = [] + total_audio_duration = 0.0 + total_inference_time = 0.0 + device = _OliveEvaluator.device_string_to_torch_device(device) + run_kwargs = metric.get_run_kwargs() + session.to(device) + sample_rate = ( + metric.data_config.pre_process_data_config.params.get("sample_rate", 16000) + if (metric.data_config and metric.data_config.pre_process_data_config) + else 16000 + ) + + for batch in dataloader: + input_data_i, labels = OliveEvaluator.unpack_batch_for_accuracy(batch) + # Track audio duration from input data + if isinstance(input_data_i, (np.ndarray, torch.Tensor)): + audio_samples = input_data_i.shape[-1] if len(input_data_i.shape) > 1 else input_data_i.shape[0] + total_audio_duration += audio_samples / sample_rate + elif isinstance(input_data_i, dict): + for v in input_data_i.values(): + if isinstance(v, (np.ndarray, torch.Tensor)) and v.ndim >= 1: + total_audio_duration += v.shape[-1] / sample_rate + break + + input_data = tensor_data_to_device(input_data_i, device) + start_time = time.perf_counter() + result = model.run_session(session, input_data, **run_kwargs) + outputs = post_func(result) if post_func else result + total_inference_time += time.perf_counter() - start_time + + if isinstance(outputs, str): + all_preds.append(outputs) + elif isinstance(outputs, (list, tuple)): + if not outputs: + continue + if not isinstance(outputs[0], str): + raise ValueError( + f"post_func must return str or list[str] for text-based metrics (WER), " + f"but got list of {type(outputs[0]).__name__}. " + f"Ensure your post_func decodes model output to text." + ) + all_preds.extend(outputs) + else: + raise ValueError( + f"post_func must return str or list[str] for text-based metrics (WER), " + f"but got {type(outputs).__name__}. " + f"Ensure your post_func decodes model output to text." 
+ ) + if isinstance(labels, (list, tuple)): + all_targets.extend(labels) + else: + all_targets.append(labels) + if device: + session.to("cpu") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + timing_metadata = { + "total_audio_duration": total_audio_duration, + "total_inference_time": total_inference_time, + } + return OliveModelOutput(preds=all_preds, logits=timing_metadata), all_targets + @torch.no_grad() def _evaluate_raw_latency( self, @@ -1016,6 +1514,15 @@ def _prepare_dataloader( return FileListCommonDataLoader(dataloader, model.io_config, batch_size=file_chunk_size) +@lru_cache(maxsize=1) +def _simple_evaluate_supports_unsafe_code(simple_evaluate_fn) -> bool: + """Check (cached) whether lm-eval's simple_evaluate accepts confirm_run_unsafe_code.""" + try: + return "confirm_run_unsafe_code" in inspect.signature(simple_evaluate_fn).parameters + except (TypeError, ValueError): + return False + + @Registry.register("LMEvaluator") class LMEvaluator(OliveEvaluator): def __init__(self, tasks: list[str], **kwargs): @@ -1029,6 +1536,7 @@ def __init__(self, tasks: list[str], **kwargs): self.ep = kwargs.get("execution_provider") self.ep_options = kwargs.get("provider_options") self.device = kwargs.get("device") + self.confirm_run_unsafe_code = kwargs.get("confirm_run_unsafe_code", False) def evaluate( self, @@ -1100,15 +1608,19 @@ def evaluate( if self.tasks: lmmodel = get_model(self.model_class)(**init_args, batch_size=self.batch_size, max_length=self.max_length) - results = simple_evaluate( - model=lmmodel, - tasks=self.tasks, - task_manager=TaskManager(), - log_samples=False, - batch_size=self.batch_size, - device=device, - limit=self.limit, - ) + simple_evaluate_kwargs = { + "model": lmmodel, + "tasks": self.tasks, + "task_manager": TaskManager(), + "log_samples": False, + "batch_size": self.batch_size, + "device": device, + "limit": self.limit, + } + # Only pass confirm_run_unsafe_code when the installed lm-eval version supports it. 
+ if _simple_evaluate_supports_unsafe_code(simple_evaluate): + simple_evaluate_kwargs["confirm_run_unsafe_code"] = self.confirm_run_unsafe_code + results = simple_evaluate(**simple_evaluate_kwargs) for task_name in sorted(results["results"].keys()): metric_items = sorted(results["results"][task_name].items()) diff --git a/olive/hardware/constants.py b/olive/hardware/constants.py index 0d19ba18d4..5e8330000f 100644 --- a/olive/hardware/constants.py +++ b/olive/hardware/constants.py @@ -37,6 +37,19 @@ class ExecutionProvider(StrEnumBase): ExecutionProvider.DmlExecutionProvider: "onnxruntime-directml", } +EXECUTION_PROVIDER_TO_MOBIUS_EP = { + ExecutionProvider.CPUExecutionProvider: "cpu", + ExecutionProvider.CUDAExecutionProvider: "cuda", + ExecutionProvider.DmlExecutionProvider: "webgpu", + ExecutionProvider.MIGraphXExecutionProvider: "onnx-standard", + ExecutionProvider.NvTensorRTRTXExecutionProvider: "trt-rtx", + ExecutionProvider.OpenVINOExecutionProvider: "default", + ExecutionProvider.QNNExecutionProvider: "onnx-standard", + ExecutionProvider.ROCMExecutionProvider: "onnx-standard", + ExecutionProvider.VitisAIExecutionProvider: "onnx-standard", + ExecutionProvider.WebGpuExecutionProvider: "webgpu", +} + DEVICE_TO_EXECUTION_PROVIDERS = { "cpu": {ExecutionProvider.CPUExecutionProvider, ExecutionProvider.OpenVINOExecutionProvider}, "gpu": { diff --git a/olive/olive_config.json b/olive/olive_config.json index 6e70efc7a5..50e1f36d6c 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -205,6 +205,15 @@ "supported_algorithms": [ ], "supported_quantization_encodings": [ ] }, + "MobiusBuilder": { + "module_path": "olive.passes.onnx.mobius_model_builder.MobiusBuilder", + "supported_providers": [ "*" ], + "supported_accelerators": [ "*" ], + "supported_precisions": [ "fp32", "fp16", "bf16" ], + "supported_algorithms": [ ], + "supported_quantization_encodings": [ ], + "extra_dependencies": [ "mobius-ai" ] + }, "LoftQ": { "module_path": "olive.passes.pytorch.lora.LoftQ", "supported_providers": [ "*" ], @@ -682,6 +691,7 @@ "inc": [ "neural-compressor" ], "lora": [ "accelerate>=0.30.0", "peft", "scipy" ], "diffusers": [ "accelerate>=0.30.0", "peft", "diffusers" ], + "mobius-ai": [ "mobius-ai" ], "nvmo": [ "nvidia-modelopt[onnx]" ], "openvino": [ "openvino>=2025.4.1", @@ -693,6 +703,7 @@ "optimum": [ "optimum" ], "qairt": [ "qairt-dev[onnx]" ], "qnn": [ "onnxruntime-qnn" ], + "speech": [ "jiwer", "librosa", "soundfile" ], "tf": [ "tensorflow==1.15.0" ], "torch-tensorrt": [ "torch-tensorrt" ], "tune-session-params": [ "psutil" ] diff --git a/olive/passes/onnx/kquant_quantization.py b/olive/passes/onnx/kquant_quantization.py index 5d75016ecc..4263406afd 100644 --- a/olive/passes/onnx/kquant_quantization.py +++ b/olive/passes/onnx/kquant_quantization.py @@ -256,7 +256,15 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon def _run_for_config( self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: - output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) + # For composite model components (e.g., Whisper encoder.onnx/decoder.onnx), + # output_model_path already includes .onnx extension. Strip it so ir.save doesn't + # create a double extension (.onnx.onnx). For other cases, resolve normally. 
+ output_path_obj = Path(output_model_path) + if output_path_obj.suffix == ".onnx": + output_model_path = str(output_path_obj.with_suffix("")) + else: + output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) + ir_model = model.load_ir_model() ir.external_data.load_to_model(ir_model) ir_model.graph.opset_imports[MSFT_DOMAIN] = 1 diff --git a/olive/passes/onnx/mobius_model_builder.py b/olive/passes/onnx/mobius_model_builder.py new file mode 100644 index 0000000000..7b7e52bb34 --- /dev/null +++ b/olive/passes/onnx/mobius_model_builder.py @@ -0,0 +1,261 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Build ONNX models from HuggingFace model IDs using the mobius package.""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar + +from olive.common.utils import StrEnumBase +from olive.constants import Precision +from olive.hardware.constants import EXECUTION_PROVIDER_TO_MOBIUS_EP, ExecutionProvider +from olive.model import HfModelHandler, ONNXModelHandler +from olive.model.handler.composite import CompositeModelHandler +from olive.passes import Pass +from olive.passes.olive_pass import PassConfigParam + +if TYPE_CHECKING: + from olive.hardware.accelerator import AcceleratorSpec + from olive.passes.pass_config import BasePassConfig + +logger = logging.getLogger(__name__) + +# Maps Olive Precision values to mobius dtype strings. +# "f32" = 32-bit float (torch.float32), standard full precision. +# "f16" = 16-bit float (torch.float16), half precision — good for GPU inference. +# "bf16" = bfloat16 (torch.bfloat16), brain float — preferred over f16 on newer hardware. +# For INT4/INT8 quantization, use a downstream Olive quantization pass (e.g. OnnxMatMulNBits) +# after this pass rather than setting precision here. +_PRECISION_TO_DTYPE: dict[str, str] = { + Precision.FP32: "f32", + Precision.FP16: "f16", + Precision.BF16: "bf16", +} + + +class MobiusBuilder(Pass): + """Olive pass that uses mobius to build ONNX models from HuggingFace model IDs. + + Supports all model architectures registered in mobius (LLMs, VLMs, speech + models, diffusion models). For multi-component models (e.g. vision-language + models that produce ``model``, ``vision``, and ``embedding`` sub-graphs) the + pass returns a :class:`~olive.model.handler.composite.CompositeModelHandler` + whose components are individual :class:`~olive.model.ONNXModelHandler` objects. + Single-component models return a plain :class:`~olive.model.ONNXModelHandler`. + + Requires ``mobius-ai`` to be installed:: + + pip install mobius-ai + + See https://github.com/onnxruntime/mobius + """ + + class MobiusRuntime(StrEnumBase): + """Target runtimes for genai config generation.""" + + NONE = "none" + ORT_GENAI = "ort-genai" + + class MobiusEP(StrEnumBase): + """Execution providers supported by mobius.""" + + DEFAULT = "default" + CPU = "cpu" + CUDA = "cuda" + WEBGPU = "webgpu" + TRT_RTX = "trt-rtx" + ONNX_STANDARD = "onnx-standard" + + # Maps Olive ExecutionProvider enum values to mobius EP names. 
+ EP_MAP: ClassVar[dict[ExecutionProvider, str]] = { + ExecutionProvider.CPUExecutionProvider: "cpu", + ExecutionProvider.CUDAExecutionProvider: "cuda", + ExecutionProvider.DmlExecutionProvider: "dml", + ExecutionProvider.WebGpuExecutionProvider: "webgpu", + } + + @classmethod + def is_accelerator_agnostic(cls, accelerator_spec: AcceleratorSpec) -> bool: + # EP selection determines which fused ops are emitted, so this pass is + # EP-specific. + return False + + @classmethod + def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: + return { + "precision": PassConfigParam( + type_=Precision, + required=False, + default_value=Precision.FP32, + description=( + "Model weight / compute precision. One of: fp32, fp16, bf16. " + "Defaults to fp32. For INT4 quantization, run an Olive " + "quantization pass (e.g. OnnxMatMulNBits) after this pass." + ), + ), + "runtime": PassConfigParam( + type_=MobiusBuilder.MobiusRuntime, + required=False, + default_value=MobiusBuilder.MobiusRuntime.ORT_GENAI, + description=( + "Target runtime. 'ort-genai' (default) generates " + "genai_config.json, tokenizer files, and processor " + "configs alongside the ONNX models. 'none' to skip." + ), + ), + } + + def _run_for_config( + self, + model: HfModelHandler, + config: type[BasePassConfig], + output_model_path: str, + ) -> ONNXModelHandler | CompositeModelHandler: + try: + from mobius import build + except ImportError as exc: + raise ImportError( + "mobius-ai is required to run MobiusBuilder. Install with: pip install mobius-ai" + ) from exc + + if not isinstance(model, HfModelHandler): + raise ValueError(f"MobiusBuilder requires an HfModelHandler input, got {type(model).__name__}.") + + # Map Olive EP to mobius EP. If unsupported/unknown, fall back to mobius default EP. + requested_ep = self.accelerator_spec.execution_provider + ep_str: str = EXECUTION_PROVIDER_TO_MOBIUS_EP.get(requested_ep, self.MobiusEP.DEFAULT) + if ep_str == self.MobiusEP.DEFAULT: + logger.warning( + "MobiusBuilder: execution provider '%s' on accelerator '%s' is not explicitly supported; " + "falling back to mobius default EP.", + requested_ep, + self.accelerator_spec.accelerator_type, + ) + + dtype_str: str = _PRECISION_TO_DTYPE.get(config.precision, "f32") + model_id: str = model.model_name_or_path + + # Read trust_remote_code from the model's HuggingFace load kwargs. + trust_remote_code: bool = model.get_load_kwargs().get("trust_remote_code", False) + + logger.info( + "MobiusBuilder: building '%s' (ep=%s, dtype=%s)", + model_id, + ep_str, + dtype_str, + ) + + if trust_remote_code: + logger.warning("MobiusBuilder: trust_remote_code=True — only use with trusted model sources.") + + output_dir = Path(output_model_path) + output_dir.mkdir(parents=True, exist_ok=True) + + pkg = build( + model_id, + dtype=dtype_str, + execution_provider=ep_str, + load_weights=True, + trust_remote_code=trust_remote_code, + ) + + # ModelPackage.save() handles both single and multi-component layouts: + # single component → /model.onnx + # multi-component → //model.onnx for each key + pkg.save(str(output_dir)) + + # Generate ORT GenAI config artifacts (genai_config.json, tokenizer + # files, processor configs) when runtime is set to ort-genai. 
+ genai_artifacts = {} + if config.runtime == self.MobiusRuntime.ORT_GENAI: + genai_artifacts = self._write_genai_config(pkg, str(output_dir), model_id, ep_str) + + package_keys = list(pkg.keys()) + logger.info("MobiusBuilder: saved components %s to '%s'", package_keys, output_dir) + + if len(package_keys) == 1: + # Single-component model (most LLMs): return a plain ONNXModelHandler. + onnx_path = output_dir / "model.onnx" + if not onnx_path.exists(): + raise RuntimeError( + f"MobiusBuilder: expected output file not found: {onnx_path}. " + "mobius.build() may have failed silently or saved to an unexpected path." + ) + additional_files = sorted( + {str(fp) for fp in output_dir.iterdir()} - {str(onnx_path), str(onnx_path) + ".data"} + ) + # Include ORT GenAI artifacts (genai_config.json, tokenizer files, etc.) + additional_files = sorted(set(additional_files) | set(genai_artifacts.values())) + return ONNXModelHandler( + model_path=str(output_dir), + onnx_file_name="model.onnx", + model_attributes={ + "mobius_package_keys": package_keys, + "additional_files": additional_files, + **(model.model_attributes or {}), + }, + ) + + # Multi-component model (VLMs, encoder-decoders, diffusion pipelines): + # mobius saves each component to //model.onnx. + components = [] + for key in package_keys: + component_dir = output_dir / key + onnx_path = component_dir / "model.onnx" + if not onnx_path.exists(): + raise RuntimeError( + f"MobiusBuilder: expected output file not found: {onnx_path}. " + f"mobius.build() may have failed silently for component '{key}'." + ) + additional_files = sorted( + {str(fp) for fp in component_dir.iterdir()} - {str(onnx_path), str(onnx_path) + ".data"} + ) + # Include ORT GenAI artifacts from root output_dir (shared across components) + additional_files = sorted(set(additional_files) | set(genai_artifacts.values())) + components.append( + ONNXModelHandler( + model_path=str(component_dir), + onnx_file_name="model.onnx", + model_attributes={ + "mobius_component": key, + "additional_files": additional_files, + **(model.model_attributes or {}), + }, + ) + ) + + return CompositeModelHandler( + model_components=components, + model_component_names=package_keys, + model_path=str(output_dir), + model_attributes={ + "mobius_package_keys": package_keys, + **(model.model_attributes or {}), + }, + ) + + @staticmethod + def _write_genai_config(pkg, output_dir: str, model_id: str, ep: str) -> dict[str, str]: + """Generate ORT GenAI config artifacts alongside the ONNX models. + + Returns: + Dict mapping artifact names to their file paths (e.g., genai_config.json, tokenizer.json). 
+ + """ + from mobius.integrations.ort_genai import write_ort_genai_config + + genai_artifacts = write_ort_genai_config( + pkg, + output_dir, + hf_model_id=model_id, + ep=ep, + ) + logger.info( + "MobiusBuilder: wrote ORT GenAI config: %s", + list(genai_artifacts.keys()), + ) + return genai_artifacts diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index f704579ba2..e1062e02b9 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -20,7 +20,7 @@ from olive.constants import Precision from olive.hardware.accelerator import AcceleratorSpec, Device from olive.hardware.constants import ExecutionProvider -from olive.model import HfModelHandler, ONNXModelHandler +from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler from olive.model.utils import resolve_onnx_path from olive.passes import Pass from olive.passes.olive_pass import PassConfigParam @@ -264,8 +264,9 @@ def _run_for_config( if config.extra_options: extra_args.update(config.extra_options) - # Ensure output_model_filepath matches the final filename in extra_args - output_model_filepath = Path(output_model_path) / extra_args["filename"] + # Ensure output_model_filepath matches the final filename in extra_args while preserving + # the resolved output directory selected above. + output_model_filepath = output_model_filepath.parent / extra_args["filename"] model_attributes = copy.deepcopy(model.model_attributes or {}) @@ -283,26 +284,6 @@ def _run_for_config( **extra_args, ) - # Apply post-processing annotations (split assignments and/or layer annotations) - # in a single load/save cycle to avoid redundant disk I/O. - split_assignments = model_attributes.get("split_assignments") if not metadata_only else None - layer_annotations = model_attributes.get("layer_annotations") if not metadata_only else None - - if split_assignments or layer_annotations: - model_proto = onnx.load(output_model_filepath, load_external_data=False) - - if split_assignments: - # NOTE: currently the model builder renames modules to it's own naming convention - # so the assignments for the renamed modules won't match - split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()]) - onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str}) - - if layer_annotations: - from olive.passes.onnx.layer_annotation import annotate_proto_model - - annotate_proto_model(model_proto, layer_annotations) - - onnx.save(model_proto, output_model_filepath) except Exception: # if model building fails, clean up the intermediate files in the cache_dir cache_dir = Path(HF_HUB_CACHE) @@ -328,6 +309,58 @@ def _run_for_config( # tokenizer and generation configs are skipped since they are already saved by the model builder model.save_metadata(output_model_filepath.parent) + generated_onnx_files = sorted(output_model_filepath.parent.glob("*.onnx")) if not metadata_only else [] + + # For multi-file models (e.g., Whisper), preserve component file names and process each file independently + # in subsequent passes by returning a CompositeModelHandler. 
+ is_multi_file_model = not metadata_only and len(generated_onnx_files) > 1 + resolved_single_model_filepath = output_model_filepath + if ( + not metadata_only + and not is_multi_file_model + and not output_model_filepath.exists() + and len(generated_onnx_files) == 1 + ): + logger.info( + "ONNX model file %s does not exist, using %s instead", + output_model_filepath, + generated_onnx_files[0].name, + ) + resolved_single_model_filepath = generated_onnx_files[0] + + # Apply post-processing annotations (split assignments and/or layer annotations) + # in a single load/save cycle to avoid redundant disk I/O. + split_assignments = model_attributes.get("split_assignments") if not metadata_only else None + layer_annotations = model_attributes.get("layer_annotations") if not metadata_only else None + if is_multi_file_model: + primary_onnx_files = generated_onnx_files + elif resolved_single_model_filepath.exists(): + primary_onnx_files = [resolved_single_model_filepath] + else: + primary_onnx_files = [] + if split_assignments or layer_annotations: + if primary_onnx_files: + for primary_onnx_file in primary_onnx_files: + model_proto = onnx.load(primary_onnx_file, load_external_data=False) + + if split_assignments: + # NOTE: currently the model builder renames modules to it's own naming convention + # so the assignments for the renamed modules won't match + split_assignment_str = ";".join([f"{k}={v}" for k, v in split_assignments.items()]) + onnx.helper.set_model_props(model_proto, {"split_assignments": split_assignment_str}) + + if layer_annotations: + from olive.passes.onnx.layer_annotation import annotate_proto_model + + annotate_proto_model(model_proto, layer_annotations) + + onnx.save(model_proto, primary_onnx_file) + else: + logger.warning( + "Skipping split_assignments/layer_annotations because no ONNX file was generated in %s.", + output_model_filepath.parent, + ) + # add additional files generated by model builder to model_attributes additional_files = model_attributes.get("additional_files") or [] if metadata_only: @@ -338,20 +371,36 @@ def _run_for_config( str(output_model_filepath.parent / "genai_config.json"), ] else: + primary_model_paths = {str(fp) for fp in primary_onnx_files} model_attributes["additional_files"] = sorted( set(additional_files) # all files in the output directory except the model and model.data files | {str(fp) for fp in output_model_filepath.parent.iterdir()} - - {str(output_model_filepath), str(output_model_filepath) + ".data"} + - primary_model_paths + - {f"{path}.data" for path in primary_model_paths} ) if metadata_only: output_model = copy.copy(model) output_model.model_attributes = model_attributes + elif is_multi_file_model: + # Use the ONNX filenames as component names so child passes write back to encoder.onnx/decoder.onnx + # instead of defaulting to model.onnx. 
+ component_names = [fp.name for fp in generated_onnx_files] + components = [ + ONNXModelHandler(output_model_filepath.parent, onnx_file_name=component_name) + for component_name in component_names + ] + output_model = CompositeModelHandler( + components, + component_names, + model_path=output_model_filepath.parent, + model_attributes=model_attributes, + ) else: output_model = ONNXModelHandler( output_model_filepath.parent, - onnx_file_name=output_model_filepath.name, + onnx_file_name=resolved_single_model_filepath.name, model_attributes=model_attributes, ) diff --git a/test/evaluator/conftest.py b/test/evaluator/conftest.py new file mode 100644 index 0000000000..50e5757140 --- /dev/null +++ b/test/evaluator/conftest.py @@ -0,0 +1,26 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Inject a minimal onnxruntime_genai stub for generate_until unit tests. + +Ensures tests can run in environments where the real package is not installed. +The tests mock all ORT GenAI objects anyway, so the stub only needs to provide +importable names. +""" + +import importlib.util +import sys +import types +from unittest.mock import MagicMock + + +def _ensure_ort_genai_stub(): + if importlib.util.find_spec("onnxruntime_genai") is None: + stub = types.ModuleType("onnxruntime_genai") + stub.Generator = MagicMock + stub.GeneratorParams = MagicMock + sys.modules["onnxruntime_genai"] = stub + + +_ensure_ort_genai_stub() diff --git a/test/evaluator/test_accuracy.py b/test/evaluator/test_accuracy.py index 5d6e44118b..f2c2754dd2 100644 --- a/test/evaluator/test_accuracy.py +++ b/test/evaluator/test_accuracy.py @@ -8,7 +8,16 @@ import pytest import torch -from olive.evaluator.accuracy import AUROC, AccuracyScore, F1Score, Perplexity, Precision, Recall +from olive.evaluator.accuracy import ( + AUROC, + AccuracyScore, + F1Score, + Perplexity, + Precision, + RealTimeFactor, + Recall, + WordErrorRate, +) from olive.evaluator.olive_evaluator import OliveModelOutput @@ -149,3 +158,58 @@ def test_evaluate_perplexity(mock_torchmetrics, mock_torch_tensor): mock_torch_tensor.assert_any_call(model_output.preds[i], dtype=torch.float) mock_torch_tensor.assert_any_call(targets[i], dtype=torch.long) assert actual_res == expected_res + + +class TestWordErrorRate: + def test_perfect_transcription(self): + wer = WordErrorRate({}) + model_output = OliveModelOutput(preds=["hello world", "test sentence"], logits=None) + targets = ["hello world", "test sentence"] + result = wer.measure(model_output, targets) + assert result == 0.0 + + def test_completely_wrong(self): + wer = WordErrorRate({}) + model_output = OliveModelOutput(preds=["completely wrong words here"], logits=None) + targets = ["the correct reference text"] + result = wer.measure(model_output, targets) + assert result > 0.0 + + def test_single_string_input(self): + """Test that a single string is wrapped in a list, not split into chars.""" + wer = WordErrorRate({}) + model_output = OliveModelOutput(preds="hello world", logits=None) + targets = "hello world" + result = wer.measure(model_output, targets) + assert result == 0.0 + + def test_partial_error(self): + wer = WordErrorRate({}) + model_output = OliveModelOutput(preds=["hello world"], logits=None) + targets = ["hello earth"] + result = wer.measure(model_output, targets) + assert 0.0 < result < 1.0 + + +class 
TestRealTimeFactor: + def test_rtfx_computation(self): + rtfx = RealTimeFactor({}) + # 10 seconds of audio processed in 2 seconds = RTFx 5.0 + timing = {"total_audio_duration": 10.0, "total_inference_time": 2.0} + model_output = OliveModelOutput(preds=["some text"], logits=timing) + result = rtfx.measure(model_output, ["some text"]) + assert result == 5.0 + + def test_rtfx_realtime(self): + rtfx = RealTimeFactor({}) + # 5 seconds of audio processed in 5 seconds = RTFx 1.0 + timing = {"total_audio_duration": 5.0, "total_inference_time": 5.0} + model_output = OliveModelOutput(preds=["text"], logits=timing) + result = rtfx.measure(model_output, ["text"]) + assert result == 1.0 + + def test_rtfx_missing_metadata(self): + rtfx = RealTimeFactor({}) + model_output = OliveModelOutput(preds=["text"], logits=None) + with pytest.raises(ValueError, match="RTFx metric requires timing metadata"): + rtfx.measure(model_output, ["text"]) diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index e295d069ad..c10a442703 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -496,9 +496,15 @@ class TestLMEvaluatorModelClass: def test_lm_evaluator_dispatches_to_requested_backend( self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock, model_class ): + import inspect + from olive.evaluator.olive_evaluator import LMEvaluator from olive.model.handler.onnx import ONNXModelHandler + def _fake_evaluate(model, tasks, task_manager=None, log_samples=True, batch_size=1, device="cpu", limit=None): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate) simple_evaluate_mock.return_value = {"results": {}} get_model_mock.return_value = MagicMock(return_value=MagicMock()) @@ -510,3 +516,601 @@ def test_lm_evaluator_dispatches_to_requested_backend( evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) get_model_mock.assert_called_once_with(model_class) + + @patch("lm_eval.utils.setup_logging") + @patch("lm_eval.tasks.TaskManager") + @patch("lm_eval.simple_evaluate") + @patch("lm_eval.api.registry.get_model") + def test_lm_evaluator_passes_confirm_run_unsafe_code( + self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock + ): + import inspect + + from olive.evaluator.olive_evaluator import LMEvaluator + from olive.model.handler.onnx import ONNXModelHandler + + # Give the mock a signature that includes confirm_run_unsafe_code so inspect.signature works. 
+ def _fake_evaluate( + model, + tasks, + task_manager=None, + log_samples=True, + batch_size=1, + device="cpu", + limit=None, + confirm_run_unsafe_code=False, + ): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate) + simple_evaluate_mock.return_value = {"results": {}} + get_model_mock.return_value = MagicMock(return_value=MagicMock()) + + evaluator = LMEvaluator( + tasks=["mbpp"], model_class="ortgenai", batch_size=1, max_length=128, confirm_run_unsafe_code=True + ) + + model = MagicMock(spec=ONNXModelHandler) + model.model_path = "/tmp/model.onnx" + + evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) + + # Verify confirm_run_unsafe_code=True was passed to simple_evaluate + call_kwargs = simple_evaluate_mock.call_args[1] + assert call_kwargs["confirm_run_unsafe_code"] is True + + @patch("lm_eval.utils.setup_logging") + @patch("lm_eval.tasks.TaskManager") + @patch("lm_eval.simple_evaluate") + @patch("lm_eval.api.registry.get_model") + def test_lm_evaluator_confirm_run_unsafe_code_defaults_false( + self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock + ): + import inspect + + from olive.evaluator.olive_evaluator import LMEvaluator + from olive.model.handler.onnx import ONNXModelHandler + + # Give the mock a signature that includes confirm_run_unsafe_code so inspect.signature works. + def _fake_evaluate( + model, + tasks, + task_manager=None, + log_samples=True, + batch_size=1, + device="cpu", + limit=None, + confirm_run_unsafe_code=False, + ): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate) + simple_evaluate_mock.return_value = {"results": {}} + get_model_mock.return_value = MagicMock(return_value=MagicMock()) + + evaluator = LMEvaluator(tasks=["arc_easy"], model_class="ort", batch_size=1, max_length=128) + + model = MagicMock(spec=ONNXModelHandler) + model.model_path = "/tmp/model.onnx" + + evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) + + # Verify confirm_run_unsafe_code defaults to False + call_kwargs = simple_evaluate_mock.call_args[1] + assert call_kwargs["confirm_run_unsafe_code"] is False + + @patch("lm_eval.utils.setup_logging") + @patch("lm_eval.tasks.TaskManager") + @patch("lm_eval.simple_evaluate") + @patch("lm_eval.api.registry.get_model") + def test_lm_evaluator_skips_confirm_run_unsafe_code_for_older_lm_eval( + self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock + ): + """When lm-eval lacks confirm_run_unsafe_code, the kwarg must not be passed.""" + import inspect + + from olive.evaluator.olive_evaluator import LMEvaluator + from olive.model.handler.onnx import ONNXModelHandler + + # Mock a signature WITHOUT confirm_run_unsafe_code (simulates older lm-eval). 
+ def _fake_evaluate_old( + model, tasks, task_manager=None, log_samples=True, batch_size=1, device="cpu", limit=None + ): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate_old) + simple_evaluate_mock.return_value = {"results": {}} + get_model_mock.return_value = MagicMock(return_value=MagicMock()) + + evaluator = LMEvaluator( + tasks=["mbpp"], model_class="ortgenai", batch_size=1, max_length=128, confirm_run_unsafe_code=True + ) + + model = MagicMock(spec=ONNXModelHandler) + model.model_path = "/tmp/model.onnx" + + evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) + + call_kwargs = simple_evaluate_mock.call_args[1] + assert "confirm_run_unsafe_code" not in call_kwargs + + +@pytest.mark.skipif( + importlib.util.find_spec("lm_eval") is None, + reason="lm_eval not installed", +) +class TestLMEvalORTGenAIGenerateUntil: + """Unit tests for LMEvalORTGenAIEvaluator.generate_until.""" + + def _make_mock_request(self, context, gen_kwargs): + """Create a mock lm-eval Request object.""" + req = MagicMock() + req.args = (context, gen_kwargs) + req.cache_hook = MagicMock() + return req + + def _mock_encode(self, ids): + """Return a mock that behaves like tokenizer.encode() output (has .tolist()).""" + import numpy as np + + return np.array(ids) + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_stops_on_eos(self, mock_params_cls, mock_gen_cls): + """Test that generation stops when EOS token is produced.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100, 200]) # 3-token prompt + evaluator.tokenizer.decode.return_value = "hello" + + # Generator produces one token then EOS + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 50), # first token + MagicMock(__getitem__=lambda s, k: 2), # EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("def foo():", {"until": ["\n"], "max_gen_toks": 100}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + # After EOS on second token, only first token was appended → decode called once + assert results[0] == "hello" + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_stops_on_stop_sequence(self, mock_params_cls, mock_gen_cls): + """Test that generation stops and trims at stop sequence.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + + evaluator.tokenizer.decode.side_effect = ["he", "l", "lo\n world", "hello\n world"] + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False, False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 50), + MagicMock(__getitem__=lambda s, k: 51), + MagicMock(__getitem__=lambda s, k: 52), + ] + mock_gen_cls.return_value 
= mock_generator + + request = self._make_mock_request("prompt", {"until": ["\n"], "max_gen_toks": 256}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + assert results[0] == "hello" # trimmed at \n + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_respects_max_length(self, mock_params_cls, mock_gen_cls): + """Test that total_max_length = min(prompt_len + max_gen_toks, max_length).""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 50 # Small model limit + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode(list(range(40))) # 40-token prompt + evaluator.tokenizer.decode.return_value = "x" + + # Generator immediately done (max_length reached) + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("long prompt", {"until": ["\n"], "max_gen_toks": 100}) + + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + # Verify search options set max_length = min(40+100, 50) = 50 + set_search_call = mock_params_cls.return_value.set_search_options + call_kwargs = set_search_call.call_args[1] + assert call_kwargs["max_length"] == 50 + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_handles_multiple_eos_tokens(self, mock_params_cls, mock_gen_cls): + """Test that any token in eos_token_ids triggers stop.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2, 151645, 151643] # Multiple EOS like Qwen + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.return_value = "result" + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + # Second EOS token in the set triggers stop + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 50), + MagicMock(__getitem__=lambda s, k: 151643), # alternate EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": [], "max_gen_toks": 256}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + assert results[0] == "result" + + def test_generate_until_until_string_converted_to_list(self): + """Test that a string 'until' value is converted to a list.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + evaluator.tokenizer.decode.return_value = "x\n" + + with patch("onnxruntime_genai.GeneratorParams"), patch("onnxruntime_genai.Generator") as mock_gen_cls: + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.return_value = MagicMock(__getitem__=lambda s, k: 50) + mock_gen_cls.return_value = mock_generator + + # Pass until as string, not list + request 
= self._make_mock_request("p", {"until": "\n", "max_gen_toks": 10}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + # Should still find the stop sequence (string was converted to list) + assert "\n" not in results[0] + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_uses_earliest_stop_match(self, mock_params_cls, mock_gen_cls): + """Test that stop trimming uses earliest occurrence across all stop sequences.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.return_value = "hello\nworld" + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.return_value = MagicMock(__getitem__=lambda s, k: 50) + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": ["", "\n"], "max_gen_toks": 256}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + assert results[0] == "hello" + + @pytest.mark.parametrize( + ("gen_kwargs", "expected_max_length"), + [ + (None, 261), # default 256 when gen_kwargs is not a dict + ({"max_gen_toks": "7"}, 12), # parse numeric string + ({"max_new_tokens": "bad"}, 261), # invalid value falls back to default + ], + ) + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_parses_max_tokens_robustly( + self, mock_params_cls, mock_gen_cls, gen_kwargs, expected_max_length + ): + """Test robust parsing and clamping of max token kwargs.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 2, 3, 4, 5]) # 5-token prompt + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", gen_kwargs) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + assert call_kwargs["max_length"] == expected_max_length + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_does_not_pass_batch_size_to_search_options(self, mock_params_cls, mock_gen_cls): + """batch_size is not a valid set_search_options kwarg for ORT GenAI — must never be passed.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 2]) + evaluator.tokenizer.decode.return_value = "hello" + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": [], "max_gen_toks": 64}) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + 
call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + assert "batch_size" not in call_kwargs, ( + f"batch_size must not be passed to set_search_options, got: {call_kwargs}" + ) + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_returns_empty_when_max_gen_toks_zero(self, mock_params_cls, mock_gen_cls): + """Test that clamping a negative max_tokens to zero returns an empty completion immediately.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 2, 3, 4, 5]) # 5-token prompt + + request = self._make_mock_request("prompt", {"max_tokens": -8}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results == [""] + mock_gen_cls.assert_not_called() # generator should never be created + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_decodes_incrementally(self, mock_params_cls, mock_gen_cls): + """Test generation decodes only new tokens while preserving output.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.return_value = "hello" # returned for full-sequence decode + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 11), + MagicMock(__getitem__=lambda s, k: 12), + MagicMock(__getitem__=lambda s, k: 2), # EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": []}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results == ["hello"] + # With no stop sequences, tokens are decoded once as a full sequence (not per-token). 
+ decode_inputs = [call.args[0] for call in evaluator.tokenizer.decode.call_args_list] + assert decode_inputs == [[11, 12]] + + @pytest.mark.parametrize( + ("temperature_val", "expect_do_sample"), + [ + ("0.7", True), # string float should be coerced + (None, False), # None should fall back to 0.0 + (0.0, False), # zero means greedy + (0.5, True), # normal float + ], + ) + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_handles_temperature_coercion( + self, mock_params_cls, mock_gen_cls, temperature_val, expect_do_sample + ): + """Test that temperature is safely coerced from string/None without errors.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + gen_kwargs = {"until": [], "max_gen_toks": 10} + if temperature_val is not None: + gen_kwargs["temperature"] = temperature_val + + request = self._make_mock_request("prompt", gen_kwargs) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + if expect_do_sample: + assert call_kwargs["temperature"] > 0 + assert call_kwargs.get("do_sample") is True + else: + assert call_kwargs["temperature"] == 0.0 + assert "do_sample" not in call_kwargs + + @pytest.mark.parametrize( + ("do_sample_val", "expect_sampling"), + [ + (True, True), # bool True → sampling on + (False, False), # bool False → greedy + ("true", True), # string "true" → sampling on + ("false", False), # string "false" → greedy (was truthy before fix) + ("0", False), # string "0" → greedy + ("1", True), # string "1" → sampling + (1, True), # int 1 → sampling + (0, False), # int 0 → greedy + ], + ) + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_coerces_do_sample(self, mock_params_cls, mock_gen_cls, do_sample_val, expect_sampling): + """do_sample must be coerced to a real bool so string 'false'/'0' are not truthy.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request( + "prompt", {"until": [], "max_gen_toks": 10, "do_sample": do_sample_val, "temperature": 0.7} + ) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + if expect_sampling: + assert call_kwargs["temperature"] > 0, f"Expected sampling for do_sample={do_sample_val!r}" + assert call_kwargs.get("do_sample") is True, ( + f"do_sample=True must be set in search_options for do_sample={do_sample_val!r}" + ) + else: + assert call_kwargs["temperature"] == 0.0, f"Expected greedy for do_sample={do_sample_val!r}" + assert "do_sample" not in call_kwargs, ( + f"do_sample must not be set when greedy for do_sample={do_sample_val!r}" + 
) + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_handles_tuple_until(self, mock_params_cls, mock_gen_cls): + """Until as a tuple must not be wrapped as a single element — each string is a stop sequence.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.return_value = "hello\n world" + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.return_value = MagicMock(__getitem__=lambda s, k: 50) + mock_gen_cls.return_value = mock_generator + + # Pass until as a tuple — previously this would silently produce no stop enforcement + request = self._make_mock_request("prompt", {"until": ("\n",), "max_gen_toks": 256}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results[0] == "hello", f"Expected stop at \\n but got: {results[0]!r}" + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_processes_multiple_requests_independently(self, mock_params_cls, mock_gen_cls): + """Multiple requests must not share mutable state (tail, stop_found, token_ids).""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + # First request decodes to text with a stop; second decodes cleanly + evaluator.tokenizer.decode.side_effect = [ + "\n", # per-token tail for req 1 (stop sequence present) + "hello\n", # full-sequence decode for req 1 + "world", # full-sequence decode for req 2 (no stop) + ] + + mock_generator = MagicMock() + # Req 1: is_done=False → generates token 10 → stop seq found → break (no more is_done) + # Req 2: is_done=False → generates token 20 → is_done=True → exit loop + mock_generator.is_done.side_effect = [False, False, True] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 10), # req 1 token + MagicMock(__getitem__=lambda s, k: 20), # req 2 token + ] + mock_gen_cls.return_value = mock_generator + + req1 = self._make_mock_request("p1", {"until": ["\n"], "max_gen_toks": 64}) + req2 = self._make_mock_request("p2", {"until": [], "max_gen_toks": 64}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [req1, req2]) + + assert results[0] == "hello" # trimmed at \n + assert results[1] == "world" # no stop, full text + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_calls_cache_hook(self, mock_params_cls, mock_gen_cls): + """cache_hook.add_partial must be called with the final generated text.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = [2] + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + evaluator.tokenizer.decode.return_value = "hello" + + mock_generator = MagicMock() + 
mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 10), + MagicMock(__getitem__=lambda s, k: 2), # EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": [], "max_gen_toks": 64}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results == ["hello"] + request.cache_hook.add_partial.assert_called_once_with("generate_until", request.args, "hello") diff --git a/test/passes/onnx/test_mobius_model_builder.py b/test/passes/onnx/test_mobius_model_builder.py new file mode 100644 index 0000000000..cd3d7338d8 --- /dev/null +++ b/test/passes/onnx/test_mobius_model_builder.py @@ -0,0 +1,456 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the MobiusBuilder Olive pass.""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from olive.hardware.accelerator import AcceleratorSpec, Device +from olive.hardware.constants import ExecutionProvider +from olive.model import HfModelHandler, ONNXModelHandler +from olive.model.handler.composite import CompositeModelHandler +from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.onnx.mobius_model_builder import MobiusBuilder + +_HAS_REAL_MOBIUS = importlib.util.find_spec("mobius") is not None + + +@pytest.fixture(autouse=True, scope="module") +def _stub_mobius_module(): + """Stub the optional mobius package into sys.modules for the duration of this module. + + patch("mobius.build") resolves the module via sys.modules, so it works correctly + even in environments where mobius-ai is not installed (e.g. Olive CI). + The stub is only injected when mobius is absent; if the real package is installed, + this fixture is a no-op. + """ + if "mobius" in sys.modules: + yield + return + fake = types.ModuleType("mobius") + fake.build = None # overridden per-test by patch("mobius.build") + sys.modules["mobius"] = fake + yield + sys.modules.pop("mobius", None) + + +@pytest.fixture(autouse=True) +def mock_hf_config(): + """Prevent HfModelHandler.__init__ from making network calls to resolve model configs.""" + mock_cfg = MagicMock() + mock_cfg.to_dict.return_value = {} + with ( + patch.object(HfModelHandler, "get_hf_model_config", return_value=mock_cfg), + patch.object(HfModelHandler, "get_load_kwargs", return_value={}), + ): + yield + + +def _make_hf_model(model_path: str, load_kwargs: dict | None = None) -> HfModelHandler: + model = HfModelHandler(model_path=model_path) + if load_kwargs: + # Patch get_load_kwargs on the instance to return the given kwargs. 
+ model.get_load_kwargs = lambda: load_kwargs + return model + + +def _make_pass(ep: str = ExecutionProvider.CPUExecutionProvider) -> MobiusBuilder: + accelerator_spec = AcceleratorSpec(accelerator_type=Device.CPU, execution_provider=ep) + return create_pass_from_dict( + MobiusBuilder, + {"precision": "fp32"}, + disable_search=True, + accelerator_spec=accelerator_spec, + ) + + +def _fake_pkg(keys: list[str], _output_dir: Path) -> MagicMock: + """Create a fake ModelPackage that writes dummy .onnx files when .save() is called.""" + + def _save(directory: str, **_kwargs): + out = Path(directory) + if len(keys) == 1: + # Single-component: saved as /model.onnx + (out / "model.onnx").write_text("dummy") + else: + # Multi-component: saved as //model.onnx + for k in keys: + (out / k).mkdir(parents=True, exist_ok=True) + (out / k / "model.onnx").write_text("dummy") + + pkg = MagicMock() + pkg.keys.return_value = keys + pkg.__iter__ = MagicMock(return_value=iter(keys)) + pkg.items.return_value = [(k, MagicMock()) for k in keys] + pkg.save.side_effect = _save + return pkg + + +def _patch_build(pkg: MagicMock): + # Patch mobius.build directly — lazy import inside _run_for_config means + # patching the module attribute, not the local binding. + # Also patch _write_genai_config since the default runtime is ort-genai. + return _CombinePatches( + patch("mobius.build", return_value=pkg), + patch.object(MobiusBuilder, "_write_genai_config"), + ) + + +class _CombinePatches: + """Combine multiple patch context managers into one.""" + + def __init__(self, *patches): + self._patches = patches + self._mocks = [] + + def __enter__(self): + self._mocks = [p.__enter__() for p in self._patches] + return self._mocks[0] # return the build mock + + def __exit__(self, *args): + for p in reversed(self._patches): + p.__exit__(*args) + + +# --------------------------------------------------------------------------- +# Configuration tests +# --------------------------------------------------------------------------- + + +def test_default_config_params(): + """MobiusBuilder must declare precision and runtime, and must not declare execution_provider or trust_remote_code.""" + accelerator_spec = AcceleratorSpec( + accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider + ) + config = MobiusBuilder._default_config(accelerator_spec) # pylint: disable=protected-access + assert "precision" in config + assert "runtime" in config + assert "execution_provider" not in config + assert "trust_remote_code" not in config + + +def test_is_not_accelerator_agnostic(): + """Pass must be EP-specific because it chooses fused ops based on the EP.""" + accelerator_spec = AcceleratorSpec( + accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider + ) + assert MobiusBuilder.is_accelerator_agnostic(accelerator_spec) is False + + +def test_ep_map_covers_common_providers(): + assert ExecutionProvider.CPUExecutionProvider in MobiusBuilder.EP_MAP + assert ExecutionProvider.CUDAExecutionProvider in MobiusBuilder.EP_MAP + assert ExecutionProvider.DmlExecutionProvider in MobiusBuilder.EP_MAP + assert ExecutionProvider.WebGpuExecutionProvider in MobiusBuilder.EP_MAP + assert MobiusBuilder.EP_MAP[ExecutionProvider.CPUExecutionProvider] == "cpu" + assert MobiusBuilder.EP_MAP[ExecutionProvider.CUDAExecutionProvider] == "cuda" + assert MobiusBuilder.EP_MAP[ExecutionProvider.DmlExecutionProvider] == "dml" + assert MobiusBuilder.EP_MAP[ExecutionProvider.WebGpuExecutionProvider] == "webgpu" + + +# 
--------------------------------------------------------------------------- +# Single-component model tests +# --------------------------------------------------------------------------- + + +def test_single_component_returns_onnx_handler(tmp_path): + """Single-component package (e.g. LLM) → ONNXModelHandler.""" + out = tmp_path / "out" + pkg = _fake_pkg(["model"], out) + + with _patch_build(pkg) as mock_build: + p = _make_pass() + result = p.run(_make_hf_model("meta-llama/Llama-3-8B"), out) + + assert isinstance(result, ONNXModelHandler) + assert not isinstance(result, CompositeModelHandler) + assert Path(result.model_path).exists() + mock_build.assert_called_once() + call_kwargs = mock_build.call_args.kwargs + assert call_kwargs["execution_provider"] == "cpu" + assert call_kwargs["dtype"] == "f32" + + +def test_model_onnx_exists_after_run(tmp_path): + """The saved model.onnx file must exist on disk.""" + out = tmp_path / "out" + pkg = _fake_pkg(["model"], out) + + with _patch_build(pkg): + p = _make_pass() + result = p.run(_make_hf_model("org/model"), out) + + # ONNXModelHandler.model_path already points to the .onnx file + assert Path(result.model_path).exists() + + +def test_genai_artifacts_in_single_component(tmp_path): + """ORT GenAI artifacts must be included in single-component model's additional_files.""" + out = tmp_path / "out" + out.mkdir(parents=True, exist_ok=True) + pkg = _fake_pkg(["model"], out) + + # Mock genai artifact files that would be created + genai_config = str(out / "genai_config.json") + tokenizer_file = str(out / "tokenizer.json") + (out / "genai_config.json").write_text("{}") + (out / "tokenizer.json").write_text("{}") + + # Mock _write_genai_config to return the artifact paths + mock_genai_artifacts = {"genai_config": genai_config, "tokenizer.json": tokenizer_file} + + with _patch_build(pkg), patch.object(MobiusBuilder, "_write_genai_config", return_value=mock_genai_artifacts): + p = _make_pass() + result = p.run(_make_hf_model("meta-llama/Llama-3-8B"), out) + + assert isinstance(result, ONNXModelHandler) + # Verify genai artifacts are in additional_files + additional_files = result.model_attributes.get("additional_files", []) + assert genai_config in additional_files + assert tokenizer_file in additional_files + + +def test_genai_artifacts_in_multi_component(tmp_path): + """ORT GenAI artifacts must be included in all components of multi-component models.""" + out = tmp_path / "out" + out.mkdir(parents=True, exist_ok=True) + keys = ["model", "vision", "embedding"] + pkg = _fake_pkg(keys, out) + + # Mock genai artifact files + genai_config = str(out / "genai_config.json") + image_processor = str(out / "image_processor.json") + (out / "genai_config.json").write_text("{}") + (out / "image_processor.json").write_text("{}") + + # Mock _write_genai_config to return the artifact paths + mock_genai_artifacts = {"genai_config": genai_config, "image_processor": image_processor} + + with _patch_build(pkg), patch.object(MobiusBuilder, "_write_genai_config", return_value=mock_genai_artifacts): + p = _make_pass() + result = p.run(_make_hf_model("microsoft/phi-4-vision"), out) + + assert isinstance(result, CompositeModelHandler) + # Verify all components include genai artifacts + for component in result.model_components: + additional_files = component.model_attributes.get("additional_files", []) + assert genai_config in additional_files + assert image_processor in additional_files + + +# --------------------------------------------------------------------------- +# 
Multi-component model tests
+# ---------------------------------------------------------------------------
+
+
+def test_multi_component_returns_composite_handler(tmp_path):
+    """Multi-component package (VLM) → CompositeModelHandler with one component per key."""
+    out = tmp_path / "out"
+    keys = ["model", "vision", "embedding"]
+    pkg = _fake_pkg(keys, out)
+
+    with _patch_build(pkg):
+        p = _make_pass()
+        result = p.run(_make_hf_model("microsoft/phi-4-vision"), out)
+
+    assert isinstance(result, CompositeModelHandler)
+    assert result.model_component_names == keys
+    components = list(result.model_components)
+    assert len(components) == 3
+    for comp in components:
+        assert isinstance(comp, ONNXModelHandler)
+
+
+# ---------------------------------------------------------------------------
+# EP auto-detection tests
+# ---------------------------------------------------------------------------
+
+
+def test_ep_auto_detected_from_accelerator(tmp_path):
+    """Execution provider is determined by the Olive accelerator spec."""
+    out = tmp_path / "out"
+    pkg = _fake_pkg(["model"], out)
+
+    accelerator_spec = AcceleratorSpec(
+        accelerator_type=Device.GPU, execution_provider=ExecutionProvider.CUDAExecutionProvider
+    )
+    p = create_pass_from_dict(
+        MobiusBuilder,
+        {"precision": "fp16"},
+        disable_search=True,
+        accelerator_spec=accelerator_spec,
+    )
+
+    with _patch_build(pkg) as mock_build:
+        p.run(_make_hf_model("org/model"), out)
+
+    call_kwargs = mock_build.call_args.kwargs
+    assert call_kwargs["execution_provider"] == "cuda"
+    assert call_kwargs["dtype"] == "f16"
+
+
+def test_unsupported_ep_falls_back_to_default(tmp_path):
+    """If accelerator EP is unsupported, pass should fall back to mobius default EP."""
+    out = tmp_path / "out"
+    pkg = _fake_pkg(["model"], out)
+
+    # Create a pass with an unsupported EP (one not in EP_MAP).
+    # JsExecutionProvider is defined in Olive's ExecutionProvider enum but is intentionally absent from
+    # MobiusBuilder.EP_MAP.
+    accelerator_spec = AcceleratorSpec(
+        accelerator_type=Device.NPU, execution_provider=ExecutionProvider.JsExecutionProvider
+    )
+    p = create_pass_from_dict(
+        MobiusBuilder,
+        {"precision": "fp32"},
+        disable_search=True,
+        accelerator_spec=accelerator_spec,
+    )
+
+    with _patch_build(pkg) as mock_build:
+        p.run(_make_hf_model("org/model"), out)
+
+    call_kwargs = mock_build.call_args.kwargs
+    assert call_kwargs["execution_provider"] == MobiusBuilder.MobiusEP.DEFAULT
+
+
+def test_none_execution_provider_falls_back_to_default(tmp_path):
+    """If execution_provider is None, pass should fall back to mobius default EP."""
+    out = tmp_path / "out"
+    pkg = _fake_pkg(["model"], out)
+
+    # Create a pass with execution_provider=None (unspecified).
+    accelerator_spec = AcceleratorSpec(accelerator_type=Device.CPU, execution_provider=None)
+    p = create_pass_from_dict(
+        MobiusBuilder,
+        {"precision": "fp32"},
+        disable_search=True,
+        accelerator_spec=accelerator_spec,
+    )
+
+    with _patch_build(pkg) as mock_build:
+        p.run(_make_hf_model("org/model"), out)
+
+    call_kwargs = mock_build.call_args.kwargs
+    assert call_kwargs["execution_provider"] == MobiusBuilder.MobiusEP.DEFAULT
+
+
+@pytest.mark.skipif(not _HAS_REAL_MOBIUS, reason="mobius-ai is not publicly available in CI yet")
+def test_write_genai_config_requires_real_mobius(tmp_path):
+    """Integration smoke test for _write_genai_config when real mobius is installed."""
+    # This test is intentionally lightweight and only verifies the import path.
+    # Unit behavior is covered by tests that patch _write_genai_config. 
+ from mobius.integrations.ort_genai import write_ort_genai_config + + assert callable(write_ort_genai_config) + + +# --------------------------------------------------------------------------- +# Input validation tests +# --------------------------------------------------------------------------- + + +def test_non_hf_model_raises(tmp_path): + """Passing a non-HfModelHandler must raise ValueError.""" + out = tmp_path / "out" + out.mkdir() + (out / "model.onnx").write_bytes(b"") + + onnx_model = ONNXModelHandler(model_path=str(out), onnx_file_name="model.onnx") + p = _make_pass() + with pytest.raises(ValueError, match="HfModelHandler"): + p.run(onnx_model, tmp_path / "result") + + +def test_import_error_raised_when_mobius_missing(tmp_path): + """ImportError must surface clearly when mobius is not installed.""" + p = _make_pass() + with patch.dict(sys.modules, {"mobius": None}), pytest.raises(ImportError, match="mobius"): + p.run(_make_hf_model("org/model"), tmp_path / "out") + + +# --------------------------------------------------------------------------- +# Output validation tests +# --------------------------------------------------------------------------- + + +def test_missing_output_file_raises_runtime_error(tmp_path): + """RuntimeError must be raised if pkg.save() does not produce model.onnx.""" + out = tmp_path / "out" + # _fake_pkg normally writes the file; use a pkg whose save() does nothing. + pkg = MagicMock() + pkg.keys.return_value = ["model"] + pkg.__iter__ = MagicMock(return_value=iter(["model"])) + pkg.save.return_value = None # save() succeeds but writes nothing + + with _patch_build(pkg), pytest.raises(RuntimeError, match="expected output file not found"): + _make_pass().run(_make_hf_model("org/model"), out) + + +def test_missing_component_file_raises_runtime_error(tmp_path): + """RuntimeError for multi-component if any component's model.onnx is missing.""" + out = tmp_path / "out" + keys = ["model", "vision", "embedding"] + pkg = MagicMock() + pkg.keys.return_value = keys + pkg.__iter__ = MagicMock(return_value=iter(keys)) + + # save() only creates 'model' component, skips 'vision' and 'embedding' + def _partial_save(directory: str, **_kwargs): + d = Path(directory) / "model" + d.mkdir(parents=True) + (d / "model.onnx").write_text("dummy") + + pkg.save.side_effect = _partial_save + + with _patch_build(pkg), pytest.raises(RuntimeError, match="expected output file not found"): + _make_pass().run(_make_hf_model("org/vlm"), out) + + +# --------------------------------------------------------------------------- +# Security / trust_remote_code tests +# --------------------------------------------------------------------------- + + +def test_trust_remote_code_warning_logged(tmp_path): + """trust_remote_code=True on the model must emit a warning about trusted model sources.""" + out = tmp_path / "out" + pkg = _fake_pkg(["model"], out) + p = create_pass_from_dict( + MobiusBuilder, + {"precision": "fp32"}, + disable_search=True, + accelerator_spec=AcceleratorSpec( + accelerator_type=Device.CPU, execution_provider=ExecutionProvider.CPUExecutionProvider + ), + ) + with ( + _patch_build(pkg), + patch("olive.passes.onnx.mobius_model_builder.logger") as mock_logger, + ): + p.run(_make_hf_model("org/model", load_kwargs={"trust_remote_code": True}), out) + + warning_messages = [call.args[0] for call in mock_logger.warning.call_args_list] + assert any("trust_remote_code" in msg for msg in warning_messages) + + +def test_no_warning_when_trust_remote_code_false(tmp_path): + """No 
trust_remote_code warning must be emitted when the model does not set trust_remote_code.""" + out = tmp_path / "out" + pkg = _fake_pkg(["model"], out) + with ( + _patch_build(pkg), + patch("olive.passes.onnx.mobius_model_builder.logger") as mock_logger, + ): + _make_pass().run(_make_hf_model("org/model"), out) + + warning_messages = [call.args[0] for call in mock_logger.warning.call_args_list] + assert not any("trust_remote_code" in msg for msg in warning_messages) diff --git a/test/passes/onnx/test_model_builder.py b/test/passes/onnx/test_model_builder.py index ba62005e4b..be5b728c65 100644 --- a/test/passes/onnx/test_model_builder.py +++ b/test/passes/onnx/test_model_builder.py @@ -2,18 +2,45 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +import json +import sys +import types from pathlib import Path +from unittest.mock import Mock import onnx import pytest -from olive.model import ONNXModelHandler +from olive.model import CompositeModelHandler, HfModelHandler, ONNXModelHandler from olive.passes.olive_pass import create_pass_from_dict from olive.passes.onnx.model_builder import ModelBuilder from olive.passes.pytorch.rtn import Rtn from test.utils import make_local_tiny_llama +def _create_test_onnx_model(model_path: Path, node_name: str): + input_info = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 1]) + output_info = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 1]) + node = onnx.helper.make_node("Identity", ["input"], ["output"], name=node_name) + graph = onnx.helper.make_graph([node], "test_graph", [input_info], [output_info]) + model = onnx.helper.make_model(graph) + onnx.save(model, model_path) + + +def _mock_genai_builder(monkeypatch, create_model_fn): + builder_module = types.ModuleType("onnxruntime_genai.models.builder") + builder_module.create_model = create_model_fn + models_module = types.ModuleType("onnxruntime_genai.models") + models_module.builder = builder_module + genai_module = types.ModuleType("onnxruntime_genai") + genai_module.__version__ = "0.8.0" + genai_module.models = models_module + monkeypatch.setitem(sys.modules, "onnxruntime_genai", genai_module) + monkeypatch.setitem(sys.modules, "onnxruntime_genai.models", models_module) + monkeypatch.setitem(sys.modules, "onnxruntime_genai.models.builder", builder_module) + monkeypatch.setattr(ModelBuilder, "maybe_patch_quant", staticmethod(lambda: None)) + + @pytest.mark.parametrize("metadata_only", [True, False]) def test_model_builder(tmp_path, metadata_only): input_model = make_local_tiny_llama(tmp_path / "input_model", "onnx" if metadata_only else "hf") @@ -100,3 +127,72 @@ def test_model_builder_layer_annotations(tmp_path, layer_annotations): assert len(node_names_with_metadata) > 0, ( "Expected nodes with metadata_props when layer_annotations are provided" ) + + +def test_model_builder_apply_annotations_on_single_file_fallback(tmp_path, monkeypatch): + def fake_create_model( + model_name, input_path, output_dir, precision, execution_provider, cache_dir, filename, **kwargs + ): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + _create_test_onnx_model(output_dir / "actual.onnx", "test_node") + (output_dir / "actual.onnx.data").write_text("external_data") + (output_dir / "tokenizer.json").write_text("{}") + (output_dir / "genai_config.json").write_text(json.dumps({"search": {}})) + + 
_mock_genai_builder(monkeypatch, fake_create_model) + input_model = Mock(spec=HfModelHandler) + input_model.model_name_or_path = "dummy-model" + input_model.adapter_path = None + input_model.model_attributes = {"split_assignments": {"model.layers.0": 1}} + + p = create_pass_from_dict( + ModelBuilder, {"precision": "fp32", "extra_options": {"filename": "expected.onnx"}}, disable_search=True + ) + output_folder = tmp_path / "output_model" + output_model = p.run(input_model, output_folder) + + assert isinstance(output_model, ONNXModelHandler) + assert output_model.onnx_file_name == "actual.onnx" + model_proto = onnx.load(output_folder / "actual.onnx", load_external_data=False) + metadata_props = {prop.key: prop.value for prop in model_proto.metadata_props} + assert metadata_props["split_assignments"] == "model.layers.0=1" + assert str(output_folder / "actual.onnx") not in output_model.model_attributes["additional_files"] + assert str(output_folder / "actual.onnx.data") not in output_model.model_attributes["additional_files"] + assert str(output_folder / "tokenizer.json") in output_model.model_attributes["additional_files"] + + +def test_model_builder_multi_file_output_preserves_component_filenames(tmp_path, monkeypatch): + def fake_create_model( + model_name, input_path, output_dir, precision, execution_provider, cache_dir, filename, **kwargs + ): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + _create_test_onnx_model(output_dir / "encoder.onnx", "encoder_node") + _create_test_onnx_model(output_dir / "decoder.onnx", "decoder_node") + (output_dir / "encoder.onnx.data").write_text("encoder_data") + (output_dir / "decoder.onnx.data").write_text("decoder_data") + (output_dir / "tokenizer.json").write_text("{}") + (output_dir / "genai_config.json").write_text(json.dumps({"search": {}})) + + _mock_genai_builder(monkeypatch, fake_create_model) + input_model = Mock(spec=HfModelHandler) + input_model.model_name_or_path = "dummy-model" + input_model.adapter_path = None + input_model.model_attributes = {} + + p = create_pass_from_dict(ModelBuilder, {"precision": "fp32"}, disable_search=True) + output_folder = tmp_path / "output_model" + output_model = p.run(input_model, output_folder) + + assert isinstance(output_model, CompositeModelHandler) + expected_component_names = sorted(["encoder.onnx", "decoder.onnx"]) + assert output_model.model_component_names == expected_component_names + component_onnx_files = [component.onnx_file_name for component in output_model.model_components] + assert component_onnx_files == output_model.model_component_names + additional_files = output_model.model_attributes["additional_files"] + assert str(output_folder / "encoder.onnx") not in additional_files + assert str(output_folder / "decoder.onnx") not in additional_files + assert str(output_folder / "encoder.onnx.data") not in additional_files + assert str(output_folder / "decoder.onnx.data") not in additional_files + assert str(output_folder / "tokenizer.json") in additional_files