From c76ff8cd5b9f9b7acd23dfdfb4009bee69103905 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 7 May 2026 22:01:49 +0000
Subject: [PATCH] Implement generate_until for LMEvalORTGenAIEvaluator

Add generate_until method to the ortgenai evaluator, enabling
chain-of-thought (CoT) benchmarks like MMLU Pro that generate text and
extract answers via regex filters.

Previously, generate_until raised NotImplementedError, limiting the
evaluator to log-likelihood-only benchmarks. This blocked CoT-scored
benchmarks which are the standard methodology for instruction-tuned
models like Gemma4.

The implementation:
- Tokenizes the prompt and generates token-by-token using og.Generator
- Supports multiple EOS token IDs (common in modern models)
- Checks stop sequences periodically during generation for early exit
- Handles temperature/sampling and max_gen_toks from gen_kwargs
- Truncates output at the first matching stop sequence

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu
---
 olive/evaluator/lmeval_ort.py | 72 ++++++++++++++++++++++++++++++++++-
 1 file changed, 71 insertions(+), 1 deletion(-)

diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py
index fd69b066e..95c34c8d3 100644
--- a/olive/evaluator/lmeval_ort.py
+++ b/olive/evaluator/lmeval_ort.py
@@ -509,7 +509,10 @@ def __init__(
             self.max_length = max_length
         else:
             self.max_length = genai_config["search"]["max_length"]
-        self._eot_token_id = genai_config["model"]["eos_token_id"]
+        eot = genai_config["model"]["eos_token_id"]
+        # eos_token_id can be a list (e.g. [1, 106] for Gemma4) or a scalar.
+        # Use the first element for loglikelihood evaluation.
+        self._eot_token_id = eot[0] if isinstance(eot, list) else eot
 
         self.params = og.GeneratorParams(self.model)
         self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False)
@@ -575,3 +578,70 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
 
     def complete(self):
         pass
+
+    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
+        """Generate text until a stop sequence is reached.
+
+        Used by benchmarks like MMLU Pro (CoT variant) that score by generating
+        chain-of-thought text and extracting the answer with a regex filter.
+        """
+        results = []
+        for request in tqdm(requests, disable=disable_tqdm, desc="Running generate_until requests"):
+            context = request.args[0]
+            gen_kwargs = request.args[1]
+
+            until = gen_kwargs.get("until", [])
+            max_gen_toks = gen_kwargs.get("max_gen_toks", 256)
+            if isinstance(until, str):
+                until = [until]
+
+            input_ids = self.tok_encode(context)
+            max_new_tokens = min(max_gen_toks, self.max_length - len(input_ids))
+            if max_new_tokens <= 0:
+                results.append("")
+                continue
+
+            params = og.GeneratorParams(self.model)
+            params.set_search_options(
+                max_length=len(input_ids) + max_new_tokens,
+                past_present_share_buffer=False,
+                batch_size=1,
+            )
+            if gen_kwargs.get("temperature", 0.0) == 0.0:
+                params.set_search_options(do_sample=False)
+            else:
+                params.set_search_options(
+                    do_sample=True,
+                    temperature=gen_kwargs["temperature"],
+                )
+
+            generator = og.Generator(self.model, params)
+            generator.append_tokens([input_ids])
+
+            eos_ids = self._eot_token_id if isinstance(self._eot_token_id, (list, tuple)) else [self._eot_token_id]
+
+            generated_ids = []
+            # Decode periodically to check for stop sequences
+            decode_interval = 16
+            while not generator.is_done():
+                generator.generate_next_token()
+                token_id = generator.get_next_tokens()[0]
+                generated_ids.append(token_id)
+                if token_id in eos_ids:
+                    break
+                # Check stop sequences periodically by decoding
+                if until and len(generated_ids) % decode_interval == 0:
+                    partial_text = self.tokenizer.decode(generated_ids)
+                    if any(stop_seq in partial_text for stop_seq in until):
+                        break
+
+            generated_text = self.tokenizer.decode(generated_ids)
+
+            # Truncate at the first stop sequence
+            for stop_seq in until:
+                idx = generated_text.find(stop_seq)
+                if idx != -1:
+                    generated_text = generated_text[:idx]
+
+            results.append(generated_text)
+        return results