180 changes: 169 additions & 11 deletions olive/evaluator/lmeval_ort.py
@@ -460,6 +460,151 @@
self._batch_size = batch_size


@register_model("ort-multimodal")
class LMEvalORTMultimodalEvaluator(LMEvalOnnxBase):
"""Evaluate a multimodal ONNX model using direct ORT InferenceSession.

Designed for ORT GenAI multimodal packages (e.g. Gemma4) that have separate
decoder and embedding ONNX models. Uses direct session.run() instead of
GenAI's Generator API, avoiding the overhead of loading all sub-models
and creating Generator objects per call.

Supports models with heterogeneous KV cache head dimensions (e.g. Gemma4
with head_dim=256 for sliding attention and head_dim=512 for full attention),
which the standard 'ort' backend cannot handle.
"""

def __init__(
self,
pretrained: str,
batch_size: int | str = 1,
max_length: int | None = None,
ep: str | None = None,
ep_options: dict | None = None,
**kwargs,
):
"""Initialize the evaluator.

:param pretrained: Path to the ORT GenAI model directory containing
genai_config.json, decoder/, embedding/, and tokenizer files.
:param batch_size: Batch size for evaluation.
:param max_length: Maximum sequence length. Defaults to config value.
:param ep: Execution provider (e.g. 'CUDAExecutionProvider').
:param ep_options: Provider options dict.
"""
import onnxruntime as ort

super().__init__()

model_dir = Path(pretrained)

# Load genai_config to find model paths and metadata
with (model_dir / "genai_config.json").open() as f:
genai_config = json.load(f)

model_config = genai_config["model"]
decoder_config = model_config["decoder"]

# Resolve max_length
if max_length:
self._max_length = max_length
else:
self._max_length = min(
genai_config.get("search", {}).get("max_length", 2048),
2048, # Cap at 2048 for eval efficiency
)

# EOS token handling (can be list or scalar)
eot = model_config["eos_token_id"]
self._eot_token_id = eot[0] if isinstance(eot, list) else eot

# Set up execution providers
providers = []
if ep:
providers.append(ep)
providers.append("CPUExecutionProvider")

# Load decoder session
decoder_path = str(model_dir / decoder_config["filename"])
logger.info("Loading decoder from %s", decoder_path)
self._decoder_sess = ort.InferenceSession(decoder_path, providers=providers)

# Detect per-layer KV cache shapes (supports heterogeneous head_dim)
self._kv_shapes = {}
for inp in self._decoder_sess.get_inputs():
if inp.name.startswith("past_key_values"):
self._kv_shapes[inp.name] = {
"num_kv_heads": inp.shape[1],
"head_dim": inp.shape[3],
}
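# e.g. (hypothetical values) a model mixing attention types might yield entries like
# {"num_kv_heads": 4, "head_dim": 256} for sliding-window layers and
# {"num_kv_heads": 8, "head_dim": 512} for full-attention layers.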

# Load embedding session if available
self._embedding_sess = None
self._hidden_size = decoder_config["hidden_size"]
embedding_config = model_config.get("embedding")
if embedding_config:
emb_path = str(model_dir / embedding_config["filename"])
logger.info("Loading embedding from %s", emb_path)
self._embedding_sess = ort.InferenceSession(emb_path, providers=providers)

# Load tokenizer from model directory
self._tokenizer = AutoTokenizer.from_pretrained(str(model_dir))
self.batch_size = int(batch_size)

@property
def max_length(self) -> int:
return self._max_length

@property
def eot_token_id(self) -> int:
return self._eot_token_id

def tok_encode(self, string: str, **kwargs) -> list[int]:
return self._tokenizer.encode(string, add_special_tokens=False)

def prepare(self, requests: list[LogLikelihoodInputs]):
pass

def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
import numpy as np

batch_size, seq_len = input_ids.shape
ids_np = input_ids.cpu().numpy().astype(np.int64)

# Get embeddings if embedding model is available
if self._embedding_sess is not None:
emb_feed = {
"input_ids": ids_np,
"image_features": np.zeros((0, self._hidden_size), dtype=np.float16),
"audio_features": np.zeros((0, self._hidden_size), dtype=np.float16),
}
inputs_embeds = self._embedding_sess.run(None, emb_feed)[0]
else:
inputs_embeds = np.zeros((batch_size, seq_len, self._hidden_size), dtype=np.float16)

# Build decoder feed with per-layer KV cache shapes
dec_feed = {
"input_ids": ids_np,
"inputs_embeds": inputs_embeds,
"attention_mask": np.ones((batch_size, seq_len), dtype=np.int64),
"position_ids": np.broadcast_to(
np.arange(seq_len, dtype=np.int64).reshape(1, -1),
(batch_size, seq_len),
).copy(),
}
for name, info in self._kv_shapes.items():
dec_feed[name] = np.zeros(
(batch_size, info["num_kv_heads"], 0, info["head_dim"]),
dtype=np.float16,
)

result = self._decoder_sess.run(["logits"], dec_feed)
return torch.from_numpy(result[0])

def complete(self):
pass


@register_model("ortgenai")
class LMEvalORTGenAIEvaluator(LMEvalOnnxBase):
"""Evaluate a model using ONNX Runtime GenAI."""
@@ -520,6 +665,7 @@

self.device = device
self._returns_full_logits = self._detect_full_logits()
self._cached_generator = None

def _detect_full_logits(self) -> bool:
"""Check if the model returns logits for all input positions or only the last."""
@@ -546,40 +692,52 @@
def prepare(self, requests: list[LogLikelihoodInputs]):
pass

def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
batch_size, seq_len = input_ids.shape
def _get_generator(self, batch_size: int) -> "og.Generator":
[Code scanning (lintrunner), RUFF UP037 warning: Remove quotes from type annotation. See https://docs.astral.sh/ruff/rules/quoted-annotation]
"""Get a Generator, reusing via rewind_to(0) when possible."""
if self._cached_generator is not None:
try:
self._cached_generator.rewind_to(0)
return self._cached_generator
except Exception:
# rewind_to not supported for this model — fall back to new Generator
self._cached_generator = None

self.params.set_search_options(batch_size=batch_size)
generator = og.Generator(self.model, self.params)
self._cached_generator = generator
return generator

if self._returns_full_logits:
generator.append_tokens(input_ids.tolist())
return torch.from_numpy(generator.get_output("logits")).to(self.device)
def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor:
batch_size, seq_len = input_ids.shape
generator = self._get_generator(batch_size)

# Model only returns logits for the last appended position.
if batch_size > 1 and cont_len > 1:
raise ValueError(
"batch_size > 1 is not supported when the model returns single-position logits"
" and continuation length > 1. Right-padding misaligns continuation positions across"
" batch elements. Use batch_size=1 instead."
)

# Bulk-append context tokens, then step through the last cont_len tokens
# one at a time to collect only the logits we actually need.
# Use incremental token appending with get_logits() to avoid copying
# the full logits tensor from GPU to CPU. get_output("logits") copies
# seq_len * vocab_size * 2 bytes (e.g. 472MB for 900 tokens with
# 262K vocab), while get_logits() copies only vocab_size * 4 bytes
# (~1MB) per position.
n_logits = max(cont_len, 1)
prefix_len = seq_len - n_logits
generator.append_tokens(input_ids[:, : prefix_len + 1].tolist())
all_logits = [torch.from_numpy(generator.get_output("logits")).to(self.device)]
all_logits = [torch.from_numpy(generator.get_logits()).to(self.device)]
for i in range(prefix_len + 1, seq_len):
generator.append_tokens(input_ids[:, i : i + 1].tolist())
all_logits.append(torch.from_numpy(generator.get_output("logits")).to(self.device))
all_logits.append(torch.from_numpy(generator.get_logits()).to(self.device))

# No need to pad to [batch, seq_len, vocab]. The slicing in _loglikelihood_tokens computes
# ctx_len = inplen + (logits.shape[0] - padding_len_inp), which adjusts for the shorter
# seq dimension so the continuation slice still lands on the correct positions.
return torch.cat(all_logits, dim=1) # [batch, n_logits, vocab]

def complete(self):
pass
self._cached_generator = None

def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
"""Generate text until a stop sequence is reached.
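For illustration (not part of this diff): once registered under "ort-multimodal", the evaluator can be selected through lm-eval's model registry like any other backend. The sketch below assumes lm-eval-harness's standard simple_evaluate entry point; the model directory, execution provider, and task name are placeholders.

import lm_eval

results = lm_eval.simple_evaluate(
    model="ort-multimodal",  # name registered by @register_model above
    model_args="pretrained=/path/to/genai_model_dir,ep=CUDAExecutionProvider",
    tasks=["hellaswag"],     # placeholder text task
    batch_size=1,
)
print(results["results"])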
6 changes: 6 additions & 0 deletions olive/evaluator/olive_evaluator.py
@@ -1572,6 +1572,12 @@ def evaluate(
"ep_options": self.ep_options,
"device": device,
}
elif self.model_class == "ort-multimodal":
init_args = {
"pretrained": str(Path(model.model_path).parent),
"ep": self.ep or execution_providers,
"ep_options": self.ep_options,
}
else:
raise ValueError(f"Unknown model class: {self.model_class}")

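For illustration (not part of this diff): the new "ort-multimodal" branch builds init_args that match LMEvalORTMultimodalEvaluator.__init__, so the evaluator ends up constructed roughly as below; the model directory and execution provider are placeholders.

from olive.evaluator.lmeval_ort import LMEvalORTMultimodalEvaluator

evaluator = LMEvalORTMultimodalEvaluator(
    pretrained="/path/to/genai_model_dir",  # parent directory of model.model_path
    ep="CUDAExecutionProvider",
    ep_options=None,
)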