diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py
index fd69b066e..95c34c8d3 100644
--- a/olive/evaluator/lmeval_ort.py
+++ b/olive/evaluator/lmeval_ort.py
@@ -509,7 +509,10 @@ def __init__(
             self.max_length = max_length
         else:
             self.max_length = genai_config["search"]["max_length"]
-        self._eot_token_id = genai_config["model"]["eos_token_id"]
+        eot = genai_config["model"]["eos_token_id"]
+        # eos_token_id can be a list (e.g. [1, 106] for Gemma3) or a scalar.
+        # Use the first element for loglikelihood evaluation.
+        self._eot_token_id = eot[0] if isinstance(eot, list) else eot
         self.params = og.GeneratorParams(self.model)
         self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False)
 
@@ -575,3 +578,70 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor
 
     def complete(self):
         pass
+
+    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
+        """Generate text until a stop sequence is reached.
+
+        Used by benchmarks like MMLU Pro (CoT variant) that score by generating
+        chain-of-thought text and extracting the answer with a regex filter.
+        """
+        results = []
+        for request in tqdm(requests, disable=disable_tqdm, desc="Running generate_until requests"):
+            context = request.args[0]
+            gen_kwargs = request.args[1]
+
+            until = gen_kwargs.get("until", [])
+            max_gen_toks = gen_kwargs.get("max_gen_toks", 256)
+            if isinstance(until, str):
+                until = [until]
+
+            input_ids = self.tok_encode(context)
+            max_new_tokens = min(max_gen_toks, self.max_length - len(input_ids))
+            if max_new_tokens <= 0:
+                results.append("")
+                continue
+
+            params = og.GeneratorParams(self.model)
+            params.set_search_options(
+                max_length=len(input_ids) + max_new_tokens,
+                past_present_share_buffer=False,
+                batch_size=1,
+            )
+            if gen_kwargs.get("temperature", 0.0) == 0.0:
+                params.set_search_options(do_sample=False)
+            else:
+                params.set_search_options(
+                    do_sample=True,
+                    temperature=gen_kwargs["temperature"],
+                )
+
+            generator = og.Generator(self.model, params)
+            generator.append_tokens([input_ids])
+
+            eos_ids = self._eot_token_id if isinstance(self._eot_token_id, (list, tuple)) else [self._eot_token_id]
+
+            generated_ids = []
+            # Decode periodically to check for stop sequences
+            decode_interval = 16
+            while not generator.is_done():
+                generator.generate_next_token()
+                token_id = generator.get_next_tokens()[0]
+                generated_ids.append(token_id)
+                if token_id in eos_ids:
+                    break
+                # Check stop sequences periodically by decoding
+                if until and len(generated_ids) % decode_interval == 0:
+                    partial_text = self.tokenizer.decode(generated_ids)
+                    if any(stop_seq in partial_text for stop_seq in until):
+                        break
+
+            generated_text = self.tokenizer.decode(generated_ids)
+
+            # Truncate at the first stop sequence
+            for stop_seq in until:
+                idx = generated_text.find(stop_seq)
+                if idx != -1:
+                    generated_text = generated_text[:idx]
+
+            results.append(generated_text)
+        return results