72 changes: 71 additions & 1 deletion olive/evaluator/lmeval_ort.py
@@ -509,7 +509,10 @@ def __init__(
             self.max_length = max_length
         else:
             self.max_length = genai_config["search"]["max_length"]
-        self._eot_token_id = genai_config["model"]["eos_token_id"]
+        eot = genai_config["model"]["eos_token_id"]
+        # eos_token_id can be a list (e.g. [1, 106] for Gemma 3) or a scalar.
+        # Use the first element for loglikelihood evaluation.
+        self._eot_token_id = eot[0] if isinstance(eot, list) else eot
         self.params = og.GeneratorParams(self.model)
         self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False)

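As a quick illustration of the normalization above, here is a minimal sketch of the two eos_token_id shapes it handles. The config fragments are made up for the example, not taken from a real genai_config.json:

# Illustrative genai_config fragments; values are hypothetical.
for cfg in (
    {"model": {"eos_token_id": 2}},         # scalar form
    {"model": {"eos_token_id": [1, 106]}},  # list form, e.g. Gemma 3
):
    eot = cfg["model"]["eos_token_id"]
    eot_token_id = eot[0] if isinstance(eot, list) else eot
    print(eot_token_id)  # prints 2, then 1
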
@@ -575,3 +578,70 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor

     def complete(self):
         pass
+
+    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
+        """Generate text until a stop sequence is reached.
+
+        Used by benchmarks like MMLU Pro (CoT variant) that score by generating
+        chain-of-thought text and extracting the answer with a regex filter.
+        """
+        results = []
+        for request in tqdm(requests, disable=disable_tqdm, desc="Running generate_until requests"):
+            context = request.args[0]
+            gen_kwargs = request.args[1]
+
+            until = gen_kwargs.get("until", [])
+            max_gen_toks = gen_kwargs.get("max_gen_toks", 256)
+            if isinstance(until, str):
+                until = [until]
+
+            input_ids = self.tok_encode(context)
+            max_new_tokens = min(max_gen_toks, self.max_length - len(input_ids))
+            if max_new_tokens <= 0:
+                results.append("")
+                continue
+
+            params = og.GeneratorParams(self.model)
+            params.set_search_options(
+                max_length=len(input_ids) + max_new_tokens,
+                past_present_share_buffer=False,
+                batch_size=1,
+            )
+            if gen_kwargs.get("temperature", 0.0) == 0.0:
+                params.set_search_options(do_sample=False)
+            else:
+                params.set_search_options(
+                    do_sample=True,
+                    temperature=gen_kwargs["temperature"],
+                )
+
+            generator = og.Generator(self.model, params)
+            generator.append_tokens([input_ids])
+
+            eos_ids = self._eot_token_id if isinstance(self._eot_token_id, (list, tuple)) else [self._eot_token_id]
+
+            generated_ids = []
+            # Decode periodically to check for stop sequences
+            decode_interval = 16
+            while not generator.is_done():
+                generator.generate_next_token()
+                token_id = generator.get_next_tokens()[0]
+                generated_ids.append(token_id)
+                if token_id in eos_ids:
+                    break
+                # Check stop sequences periodically by decoding
+                if until and len(generated_ids) % decode_interval == 0:
+                    partial_text = self.tokenizer.decode(generated_ids)
+                    if any(stop_seq in partial_text for stop_seq in until):
+                        break
+
+            generated_text = self.tokenizer.decode(generated_ids)
+
+            # Truncate at the first stop sequence
+            for stop_seq in until:
+                idx = generated_text.find(stop_seq)
+                if idx != -1:
+                    generated_text = generated_text[:idx]
+
+            results.append(generated_text)
+        return results
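
To show how this hooks into the harness, here is a minimal sketch of a caller. It assumes `lm` is an already-constructed instance of this ORT GenAI wrapper; the `SimpleNamespace` request is a hypothetical stand-in for an lm-evaluation-harness `Instance`, of which `generate_until` only reads `.args`:

from types import SimpleNamespace

# Hypothetical stand-in for an lm-eval Instance; only .args is read.
request = SimpleNamespace(
    args=(
        "Question: What is 2 + 2?\nAnswer:",  # context string to continue
        {"until": ["\n\n"], "max_gen_toks": 64, "temperature": 0.0},
    )
)
outputs = lm.generate_until([request])  # lm: the wrapper, assumed built elsewhere
print(outputs[0])  # generated text, truncated at the first "\n\n"

Checking stop sequences every 16 tokens rather than on every step is a throughput trade-off: multi-token stop strings like "\n\n" never match a single EOS id, so periodic decoding is what catches them, and the final find-based truncation removes any tokens generated past the stop.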