diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py
index 50d1f1289..ad841a360 100644
--- a/olive/evaluator/lmeval_ort.py
+++ b/olive/evaluator/lmeval_ort.py
@@ -460,6 +460,155 @@ def initialize_buffers(self, batch_size: int, max_length: int):
         self._batch_size = batch_size
 
 
+@register_model("ort-multimodal")
+class LMEvalORTMultimodalEvaluator(LMEvalOnnxBase):
+    """Evaluate a multimodal ONNX model using direct ORT InferenceSession.
+
+    Designed for ORT GenAI multimodal packages (e.g. Gemma4) that have separate
+    decoder and embedding ONNX models. Uses direct session.run() instead of
+    GenAI's Generator API, avoiding the overhead of loading all sub-models
+    and creating Generator objects per call.
+
+    Supports models with heterogeneous KV cache head dimensions (e.g. Gemma4
+    with head_dim=256 for sliding attention and head_dim=512 for full attention),
+    which the standard 'ort' backend cannot handle.
+    """
+
+    def __init__(
+        self,
+        pretrained: str,
+        batch_size: int | str = 1,
+        max_length: int | None = None,
+        ep: str | None = None,
+        ep_options: dict | None = None,
+        **kwargs,
+    ):
+        """Initialize the evaluator.
+
+        :param pretrained: Path to the ORT GenAI model directory containing
+            genai_config.json, decoder/, embedding/, and tokenizer files.
+        :param batch_size: Batch size for evaluation.
+        :param max_length: Maximum sequence length. Defaults to config value.
+        :param ep: Execution provider (e.g. 'CUDAExecutionProvider').
+        :param ep_options: Provider options dict, applied to ``ep`` only.
+        :param kwargs: Accepted for backend compatibility; not used here.
+        """
+        import onnxruntime as ort
+
+        super().__init__()
+
+        model_dir = Path(pretrained)
+
+        # Load genai_config to find model paths and metadata
+        with (model_dir / "genai_config.json").open() as f:
+            genai_config = json.load(f)
+
+        model_config = genai_config["model"]
+        decoder_config = model_config["decoder"]
+
+        # Resolve max_length
+        if max_length:
+            self._max_length = max_length
+        else:
+            self._max_length = min(
+                genai_config.get("search", {}).get("max_length", 2048),
+                2048,  # Cap at 2048 for eval efficiency
+            )
+
+        # EOS token handling (can be list or scalar)
+        eot = model_config["eos_token_id"]
+        self._eot_token_id = eot[0] if isinstance(eot, list) else eot
+
+        # Set up execution providers. ep_options are attached to the requested
+        # provider; CPU is always kept as the fallback, without duplicating it.
+        providers = []
+        if ep:
+            providers.append((ep, ep_options) if ep_options else ep)
+        if ep != "CPUExecutionProvider":
+            providers.append("CPUExecutionProvider")
+
+        # Load decoder session
+        decoder_path = str(model_dir / decoder_config["filename"])
+        logger.info("Loading decoder from %s", decoder_path)
+        self._decoder_sess = ort.InferenceSession(decoder_path, providers=providers)
+
+        # Detect per-layer KV cache shapes (supports heterogeneous head_dim)
+        self._kv_shapes = {}
+        for inp in self._decoder_sess.get_inputs():
+            if inp.name.startswith("past_key_values"):
+                self._kv_shapes[inp.name] = {
+                    "num_kv_heads": inp.shape[1],
+                    "head_dim": inp.shape[3],
+                }
+
+        # Load embedding session if available
+        self._embedding_sess = None
+        self._hidden_size = decoder_config["hidden_size"]
+        embedding_config = model_config.get("embedding")
+        if embedding_config:
+            emb_path = str(model_dir / embedding_config["filename"])
+            logger.info("Loading embedding from %s", emb_path)
+            self._embedding_sess = ort.InferenceSession(emb_path, providers=providers)
+
+        # Load tokenizer from model directory
+        self._tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
+        self.batch_size = int(batch_size)
+
+    @property
+    def max_length(self) -> int:
+        return self._max_length
+
+    @property
+    def eot_token_id(self) -> int:
+        return self._eot_token_id
+
+    def tok_encode(self, string: str, **kwargs) -> list[int]:
+        return self._tokenizer.encode(string, add_special_tokens=False)
+
+    def prepare(self, requests: list[LogLikelihoodInputs]):
+        pass
+
+    def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
+        """Run a full-sequence forward pass with empty KV caches and return the logits."""
+        import numpy as np
+
+        batch_size, seq_len = input_ids.shape
+        ids_np = input_ids.cpu().numpy().astype(np.int64)
+
+        # Get embeddings if embedding model is available
+        if self._embedding_sess is not None:
+            emb_feed = {
+                "input_ids": ids_np,
+                "image_features": np.zeros((0, self._hidden_size), dtype=np.float16),
+                "audio_features": np.zeros((0, self._hidden_size), dtype=np.float16),
+            }
+            inputs_embeds = self._embedding_sess.run(None, emb_feed)[0]
+        else:
+            inputs_embeds = np.zeros((batch_size, seq_len, self._hidden_size), dtype=np.float16)
+
+        # Build decoder feed with per-layer KV cache shapes
+        dec_feed = {
+            "input_ids": ids_np,
+            "inputs_embeds": inputs_embeds,
+            "attention_mask": np.ones((batch_size, seq_len), dtype=np.int64),
+            "position_ids": np.broadcast_to(
+                np.arange(seq_len, dtype=np.int64).reshape(1, -1),
+                (batch_size, seq_len),
+            ).copy(),
+        }
+        for name, info in self._kv_shapes.items():
+            dec_feed[name] = np.zeros(
+                (batch_size, info["num_kv_heads"], 0, info["head_dim"]),
+                dtype=np.float16,
+            )
+
+        result = self._decoder_sess.run(["logits"], dec_feed)
+        return torch.from_numpy(result[0])
+
+    def complete(self):
+        pass
+
+
 @register_model("ortgenai")
 class LMEvalORTGenAIEvaluator(LMEvalOnnxBase):
     """Evaluate a model using ONNX Runtime GenAI."""
@@ -520,6 +669,8 @@ def __init__(
         self.device = device
 
         self._returns_full_logits = self._detect_full_logits()
+        self._cached_generator = None
+        self._cached_batch_size = None
 
     def _detect_full_logits(self) -> bool:
         """Check if the model returns logits for all input positions or only the last."""
@@ -546,16 +697,33 @@ def tok_encode(self, string: str, **kwargs) -> list[int]:
     def prepare(self, requests: list[LogLikelihoodInputs]):
         pass
 
-    def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
-        batch_size, seq_len = input_ids.shape
+    def _get_generator(self, batch_size: int) -> "og.Generator":
+        """Get a Generator, reusing via rewind_to(0) when possible.
+
+        A cached Generator is reused only for the batch size it was created
+        with, since batch_size is applied through self.params at construction.
+        """
+        if self._cached_generator is not None and self._cached_batch_size == batch_size:
+            try:
+                self._cached_generator.rewind_to(0)
+                return self._cached_generator
+            except Exception:
+                # rewind_to not supported for this model — fall back to new Generator
+                logger.debug("Generator.rewind_to(0) failed; creating a new Generator", exc_info=True)
+
+        # Release any stale Generator before allocating its replacement.
+        self._cached_generator = None
         self.params.set_search_options(batch_size=batch_size)
         generator = og.Generator(self.model, self.params)
+        self._cached_generator = generator
+        self._cached_batch_size = batch_size
+        return generator
 
-        if self._returns_full_logits:
-            generator.append_tokens(input_ids.tolist())
-            return torch.from_numpy(generator.get_output("logits")).to(self.device)
+    def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
+        """Return logits for the last max(cont_len, 1) positions as [batch, n_logits, vocab]."""
+        batch_size, seq_len = input_ids.shape
+        generator = self._get_generator(batch_size)
 
-        # Model only returns logits for the last appended position.
         if batch_size > 1 and cont_len > 1:
             raise ValueError(
                 "batch_size > 1 is not supported when the model returns single-position logits"
@@ -563,15 +731,18 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
                 " batch elements. Use batch_size=1 instead."
             )
 
-        # Bulk-append context tokens, then step through the last cont_len tokens
-        # one at a time to collect only the logits we actually need.
+        # Use incremental token appending with get_logits() to avoid copying
+        # the full logits tensor from GPU to CPU. get_output("logits") copies
+        # seq_len * vocab_size * 2 bytes (e.g. 472MB for 900 tokens with
+        # 262K vocab), while get_logits() copies only vocab_size * 4 bytes
+        # (~1MB) per position.
         n_logits = max(cont_len, 1)
         prefix_len = seq_len - n_logits
         generator.append_tokens(input_ids[:, : prefix_len + 1].tolist())
-        all_logits = [torch.from_numpy(generator.get_output("logits")).to(self.device)]
+        all_logits = [torch.from_numpy(generator.get_logits()).to(self.device)]
         for i in range(prefix_len + 1, seq_len):
             generator.append_tokens(input_ids[:, i : i + 1].tolist())
-            all_logits.append(torch.from_numpy(generator.get_output("logits")).to(self.device))
+            all_logits.append(torch.from_numpy(generator.get_logits()).to(self.device))
 
         # No need to pad to [batch, seq_len, vocab]. The slicing in _loglikelihood_tokens computes
         # ctx_len = inplen + (logits.shape[0] - padding_len_inp), which adjusts for the shorter
@@ -579,7 +750,8 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
         return torch.cat(all_logits, dim=1)  # [batch, n_logits, vocab]
 
     def complete(self):
-        pass
+        self._cached_generator = None
+        self._cached_batch_size = None
 
     def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
         """Generate text until a stop sequence is reached.
diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py
index 05933b8b6..537aa3b0d 100644
--- a/olive/evaluator/olive_evaluator.py
+++ b/olive/evaluator/olive_evaluator.py
@@ -1572,6 +1572,12 @@ def evaluate(
                 "ep_options": self.ep_options,
                 "device": device,
             }
+        elif self.model_class == "ort-multimodal":
+            init_args = {
+                "pretrained": str(Path(model.model_path).parent),
+                "ep": self.ep or execution_providers,
+                "ep_options": self.ep_options,
+            }
         else:
             raise ValueError(f"Unknown model class: {self.model_class}")
 