From 39b2113a7dbad362e4864196d07129f487cd299c Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Fri, 8 May 2026 16:28:35 +0000
Subject: [PATCH 1/2] =?UTF-8?q?perf:=20use=20get=5Flogits()=20to=20avoid?=
 =?UTF-8?q?=20massive=20GPU=E2=86=92CPU=20logits=20copy?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Replace get_output("logits") with get_logits() in the ortgenai evaluator.

get_output("logits") copies the full logits tensor from GPU to CPU on
every call (e.g. 472MB for 900 tokens × 262K vocab in f16), taking
~410ms. get_logits() returns only the last position's logits (~1MB),
taking ~2ms.

The evaluator now always uses incremental token appending: bulk-prefill
the context, then step through the continuation tokens, collecting the
logits at each position via get_logits(). This is both faster and
simpler than the previous approach, which had separate paths for
full-logits and single-position models.
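
For reference, the access pattern looks roughly like this (a minimal
sketch; "og" is onnxruntime_genai as imported in the evaluator, and the
model path and token ids are placeholders):

    import onnxruntime_genai as og

    model = og.Model("path/to/model")              # placeholder path
    params = og.GeneratorParams(model)
    params.set_search_options(batch_size=1)
    generator = og.Generator(model, params)

    # Bulk-prefill the context in one call, then read only the last
    # position's logits (~vocab_size values) instead of the full
    # [seq_len, vocab_size] tensor that get_output("logits") copies.
    generator.append_tokens([[101, 2023, 2003]])   # placeholder ids
    per_position_logits = [generator.get_logits()]

    # Step through continuation tokens one at a time, collecting the
    # logits produced at each position.
    for tok in (1996, 4248):                       # placeholder ids
        generator.append_tokens([[tok]])
        per_position_logits.append(generator.get_logits())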

Benchmark on Gemma4 E2B-IT MMLU Pro (limit=50, CUDA EP):
- Before: 46.9s (10.6 req/s)
- After: 24.2s (20.7 req/s)
- Speedup: 1.94×

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu
---
 olive/evaluator/lmeval_ort.py | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py
index 50d1f1289..afa48695a 100644
--- a/olive/evaluator/lmeval_ort.py
+++ b/olive/evaluator/lmeval_ort.py
@@ -551,11 +551,6 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
         self.params.set_search_options(batch_size=batch_size)
         generator = og.Generator(self.model, self.params)
 
-        if self._returns_full_logits:
-            generator.append_tokens(input_ids.tolist())
-            return torch.from_numpy(generator.get_output("logits")).to(self.device)
-
-        # Model only returns logits for the last appended position.
         if batch_size > 1 and cont_len > 1:
             raise ValueError(
                 "batch_size > 1 is not supported when the model returns single-position logits"
@@ -563,15 +558,18 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
                 " batch elements. Use batch_size=1 instead."
             )
 
-        # Bulk-append context tokens, then step through the last cont_len tokens
-        # one at a time to collect only the logits we actually need.
+        # Use incremental token appending with get_logits() to avoid copying
+        # the full logits tensor from GPU to CPU. get_output("logits") copies
+        # seq_len * vocab_size * 2 bytes (e.g. 472MB for 900 tokens with
+        # 262K vocab), while get_logits() copies only vocab_size * 4 bytes
+        # (~1MB) per position.
         n_logits = max(cont_len, 1)
         prefix_len = seq_len - n_logits
         generator.append_tokens(input_ids[:, : prefix_len + 1].tolist())
-        all_logits = [torch.from_numpy(generator.get_output("logits")).to(self.device)]
+        all_logits = [torch.from_numpy(generator.get_logits()).to(self.device)]
         for i in range(prefix_len + 1, seq_len):
             generator.append_tokens(input_ids[:, i : i + 1].tolist())
-            all_logits.append(torch.from_numpy(generator.get_output("logits")).to(self.device))
+            all_logits.append(torch.from_numpy(generator.get_logits()).to(self.device))
 
         # No need to pad to [batch, seq_len, vocab]. The slicing in _loglikelihood_tokens computes
         # ctx_len = inplen + (logits.shape[0] - padding_len_inp), which adjusts for the shorter

From f086863a14f81b8f92a369f9994626d0272e8c6f Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Fri, 8 May 2026 17:36:45 +0000
Subject: [PATCH 2/2] fix: reword ortgenai incremental logits batching error
 message

Agent-Logs-Url: https://github.com/microsoft/Olive/sessions/8e4a3ef0-bdf6-4a0c-a2a8-3b3e56650f0a
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
---
 olive/evaluator/lmeval_ort.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py
index afa48695a..c4a158533 100644
--- a/olive/evaluator/lmeval_ort.py
+++ b/olive/evaluator/lmeval_ort.py
@@ -553,7 +553,7 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
 
         if batch_size > 1 and cont_len > 1:
             raise ValueError(
-                "batch_size > 1 is not supported when the model returns single-position logits"
+                "batch_size > 1 is not supported when using incremental get_logits() retrieval"
                 " and continuation length > 1. Right-padding misaligns continuation positions across"
                 " batch elements. Use batch_size=1 instead."
             )