From bb38c03a4186014a07a6bd60a92ed249c3c9ebca Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Wed, 29 Apr 2026 16:04:29 -0700 Subject: [PATCH 01/13] Add MBPP/code generation evaluator support via ortgenai backend - Implement generate_until in LMEvalORTGenAIEvaluator for code generation tasks - Fix EOS token handling for models with multiple EOS token IDs (e.g. Qwen) - Add confirm_run_unsafe_code parameter to CLI and LMEvaluator for code eval tasks - Update base ORT class error message to direct users to ortgenai backend - Add unit tests for generate_until and confirm_run_unsafe_code This enables running MBPP, HumanEval, and other code generation benchmarks through Olive's ortgenai evaluation backend. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/cli/benchmark.py | 11 ++ olive/evaluator/lmeval_ort.py | 95 ++++++++++- olive/evaluator/olive_evaluator.py | 2 + test/evaluator/test_olive_evaluator.py | 217 +++++++++++++++++++++++++ 4 files changed, 323 insertions(+), 2 deletions(-) diff --git a/olive/cli/benchmark.py b/olive/cli/benchmark.py index adad95773..18237bbf3 100644 --- a/olive/cli/benchmark.py +++ b/olive/cli/benchmark.py @@ -76,6 +76,13 @@ def register_subcommand(parser: ArgumentParser): help="Backend for ONNX model evaluation. Use 'auto' to infer backend from model type.", ) + lmeval_group.add_argument( + "--confirm_run_unsafe_code", + action="store_true", + default=False, + help="Allow running tasks that execute model-generated code (e.g., MBPP, HumanEval).", + ) + add_logging_options(sub_parser) add_save_config_file_options(sub_parser) add_shared_cache_options(sub_parser) @@ -117,6 +124,10 @@ def _get_run_config(self, tempdir: str) -> dict: ("evaluators", "evaluator", "model_class"), None if self.args.backend == "auto" else self.args.backend, ), + ( + ("evaluators", "evaluator", "confirm_run_unsafe_code"), + self.args.confirm_run_unsafe_code or None, + ), ] for keys, value in to_replace: diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index fd69b066e..d385d186f 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -190,7 +190,10 @@ def loglikelihood_rolling(self, requests, disable_tqdm: bool = False) -> list[fl raise NotImplementedError("Yet to be implemented!") def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: - raise NotImplementedError("Yet to be implemented!") + raise NotImplementedError( + "generate_until is not supported by this model backend. " + "Use model_class='ortgenai' for generative tasks such as MBPP or HumanEval." + ) @register_model("ort") @@ -509,7 +512,14 @@ def __init__( self.max_length = max_length else: self.max_length = genai_config["search"]["max_length"] - self._eot_token_id = genai_config["model"]["eos_token_id"] + eos = genai_config["model"]["eos_token_id"] + # eos_token_id can be a single int or a list of ints + if isinstance(eos, list): + self._eot_token_id = eos[0] + self._eos_token_ids = set(eos) + else: + self._eot_token_id = eos + self._eos_token_ids = {eos} self.params = og.GeneratorParams(self.model) self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False) @@ -573,5 +583,86 @@ def model_call(self, input_ids: torch.Tensor, cont_len: int = 0) -> torch.Tensor # seq dimension so the continuation slice still lands on the correct positions. 
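        # Concatenating the per-window logits on dim=1 gives one
        # [batch, n_logits, vocab] tensor that downstream scoring can slice uniformly.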
return torch.cat(all_logits, dim=1) # [batch, n_logits, vocab] + def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: + """Generate text until a stop sequence is found or max tokens reached. + + Supports generative evaluation tasks such as MBPP and HumanEval. + Each request is a tuple of (context_string, gen_kwargs_dict). + """ + results = [] + for request in requests: + context = request.args[0] + gen_kwargs = request.args[1] + + # Extract stop sequences + until = gen_kwargs.get("until", []) + if isinstance(until, str): + until = [until] + + # Extract generation parameters + max_gen_toks = gen_kwargs.get( + "max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens", 256)) + ) + temperature = gen_kwargs.get("temperature", 0.0) + do_sample = gen_kwargs.get("do_sample", temperature > 0) + + # Tokenize the prompt + prompt_ids = self.tokenizer.encode(context).tolist() + prompt_len = len(prompt_ids) + + # Compute total max_length: prompt + new tokens, capped by model limit + total_max_length = min(prompt_len + max_gen_toks, self.max_length) + + # Create fresh generator params per request to avoid state leakage + params = og.GeneratorParams(self.model) + search_options = { + "max_length": total_max_length, + "past_present_share_buffer": False, + "batch_size": 1, + } + if do_sample: + search_options["temperature"] = temperature + else: + search_options["temperature"] = 0.0 + params.set_search_options(**search_options) + + # Run generation token by token to check for stop sequences + generator = og.Generator(self.model, params) + generator.append_tokens([prompt_ids]) + + generated_ids = [] + generated_text = "" + stop_found = False + + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_sequence(0)[-1] + + # Check for EOS token(s) + if new_token in self._eos_token_ids: + break + + generated_ids.append(new_token) + generated_text = self.tokenizer.decode(generated_ids) + + # Check stop sequences against generated text + for stop_seq in until: + if stop_seq in generated_text: + # Trim at the stop sequence + generated_text = generated_text[: generated_text.index(stop_seq)] + stop_found = True + break + + if stop_found: + break + + results.append(generated_text) + + # lm-eval cache hook + if hasattr(request, "cache_hook") and request.cache_hook is not None: + request.cache_hook.add_partial("generate_until", request.args, generated_text) + + return results + def complete(self): pass diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index 0814850a1..4cd28f257 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -1029,6 +1029,7 @@ def __init__(self, tasks: list[str], **kwargs): self.ep = kwargs.get("execution_provider") self.ep_options = kwargs.get("provider_options") self.device = kwargs.get("device") + self.confirm_run_unsafe_code = kwargs.get("confirm_run_unsafe_code", False) def evaluate( self, @@ -1108,6 +1109,7 @@ def evaluate( batch_size=self.batch_size, device=device, limit=self.limit, + confirm_run_unsafe_code=self.confirm_run_unsafe_code, ) for task_name in sorted(results["results"].keys()): diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index e295d069a..7e7ec0084 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -510,3 +510,220 @@ def test_lm_evaluator_dispatches_to_requested_backend( evaluator.evaluate(model, metrics=[], device=Device.CPU, 
execution_providers=["CPUExecutionProvider"]) get_model_mock.assert_called_once_with(model_class) + + @patch("lm_eval.utils.setup_logging") + @patch("lm_eval.tasks.TaskManager") + @patch("lm_eval.simple_evaluate") + @patch("lm_eval.api.registry.get_model") + def test_lm_evaluator_passes_confirm_run_unsafe_code( + self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock + ): + from olive.evaluator.olive_evaluator import LMEvaluator + from olive.model.handler.onnx import ONNXModelHandler + + simple_evaluate_mock.return_value = {"results": {}} + get_model_mock.return_value = MagicMock(return_value=MagicMock()) + + evaluator = LMEvaluator( + tasks=["mbpp"], model_class="ortgenai", batch_size=1, max_length=128, confirm_run_unsafe_code=True + ) + + model = MagicMock(spec=ONNXModelHandler) + model.model_path = "/tmp/model.onnx" + + evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) + + # Verify confirm_run_unsafe_code=True was passed to simple_evaluate + call_kwargs = simple_evaluate_mock.call_args[1] + assert call_kwargs["confirm_run_unsafe_code"] is True + + @patch("lm_eval.utils.setup_logging") + @patch("lm_eval.tasks.TaskManager") + @patch("lm_eval.simple_evaluate") + @patch("lm_eval.api.registry.get_model") + def test_lm_evaluator_confirm_run_unsafe_code_defaults_false( + self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock + ): + from olive.evaluator.olive_evaluator import LMEvaluator + from olive.model.handler.onnx import ONNXModelHandler + + simple_evaluate_mock.return_value = {"results": {}} + get_model_mock.return_value = MagicMock(return_value=MagicMock()) + + evaluator = LMEvaluator(tasks=["arc_easy"], model_class="ort", batch_size=1, max_length=128) + + model = MagicMock(spec=ONNXModelHandler) + model.model_path = "/tmp/model.onnx" + + evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) + + # Verify confirm_run_unsafe_code defaults to False + call_kwargs = simple_evaluate_mock.call_args[1] + assert call_kwargs["confirm_run_unsafe_code"] is False + + +@pytest.mark.skipif( + importlib.util.find_spec("lm_eval") is None or importlib.util.find_spec("onnxruntime_genai") is None, + reason="lm_eval or onnxruntime_genai not installed", +) +class TestLMEvalORTGenAIGenerateUntil: + """Unit tests for LMEvalORTGenAIEvaluator.generate_until.""" + + def _make_mock_request(self, context, gen_kwargs): + """Create a mock lm-eval Request object.""" + req = MagicMock() + req.args = (context, gen_kwargs) + req.cache_hook = MagicMock() + return req + + def _mock_encode(self, ids): + """Return a mock that behaves like tokenizer.encode() output (has .tolist()).""" + import numpy as np + + return np.array(ids) + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_stops_on_eos(self, mock_params_cls, mock_gen_cls): + """Test that generation stops when EOS token is produced.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100, 200]) # 3-token prompt + evaluator.tokenizer.decode.return_value = "hello" + + # Generator produces one token then EOS + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] 
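+        # generate_until reads generator.get_sequence(0)[-1] after each step;
+        # the __getitem__ lambdas below stand in for that indexing.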
+ mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 50), # first token + MagicMock(__getitem__=lambda s, k: 2), # EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("def foo():", {"until": ["\n"], "max_gen_toks": 100}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + # After EOS on second token, only first token was appended → decode called once + assert results[0] == "hello" + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_stops_on_stop_sequence(self, mock_params_cls, mock_gen_cls): + """Test that generation stops and trims at stop sequence.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + + evaluator.tokenizer.decode.side_effect = ["he", "hel", "hello\n world"] + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False, False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 50), + MagicMock(__getitem__=lambda s, k: 51), + MagicMock(__getitem__=lambda s, k: 52), + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": ["\n"], "max_gen_toks": 256}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + assert results[0] == "hello" # trimmed at \n + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_respects_max_length(self, mock_params_cls, mock_gen_cls): + """Test that total_max_length = min(prompt_len + max_gen_toks, max_length).""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = {2} + evaluator.max_length = 50 # Small model limit + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode(list(range(40))) # 40-token prompt + evaluator.tokenizer.decode.return_value = "x" + + # Generator immediately done (max_length reached) + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("long prompt", {"until": ["\n"], "max_gen_toks": 100}) + + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + # Verify search options set max_length = min(40+100, 50) = 50 + set_search_call = mock_params_cls.return_value.set_search_options + call_kwargs = set_search_call.call_args[1] + assert call_kwargs["max_length"] == 50 + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_handles_multiple_eos_tokens(self, mock_params_cls, mock_gen_cls): + """Test that any token in _eos_token_ids triggers stop.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = {2, 151645, 151643} # Multiple EOS like Qwen + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + 
evaluator.tokenizer.decode.return_value = "result"
+
+        mock_generator = MagicMock()
+        mock_generator.is_done.side_effect = [False, False]
+        # Second EOS token in the set triggers stop
+        mock_generator.get_sequence.side_effect = [
+            MagicMock(__getitem__=lambda s, k: 50),
+            MagicMock(__getitem__=lambda s, k: 151643),  # alternate EOS
+        ]
+        mock_gen_cls.return_value = mock_generator
+
+        request = self._make_mock_request("prompt", {"until": [], "max_gen_toks": 256})
+
+        results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request])
+
+        assert len(results) == 1
+        assert results[0] == "result"
+
+    def test_generate_until_until_string_converted_to_list(self):
+        """Test that a string 'until' value is converted to a list."""
+        from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator
+
+        evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator)
+        evaluator._eos_token_ids = {2}
+        evaluator.max_length = 1024
+        evaluator.model = MagicMock()
+        evaluator.tokenizer = MagicMock()
+        evaluator.tokenizer.encode.return_value = self._mock_encode([1])
+        evaluator.tokenizer.decode.return_value = "x\n"
+
+        with patch("onnxruntime_genai.GeneratorParams"), patch("onnxruntime_genai.Generator") as mock_gen_cls:
+            mock_generator = MagicMock()
+            mock_generator.is_done.side_effect = [False, False]
+            mock_generator.get_sequence.return_value = MagicMock(__getitem__=lambda s, k: 50)
+            mock_gen_cls.return_value = mock_generator
+
+            # Pass until as string, not list
+            request = self._make_mock_request("p", {"until": "\n", "max_gen_toks": 10})
+
+            results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request])
+
+            # Should still find the stop sequence (string was converted to list)
+            assert "\n" not in results[0]

From a95404d6887c9825b3a47f112ab9e7c75e5e5b99 Mon Sep 17 00:00:00 2001
From: Nat Kershaw
Date: Wed, 29 Apr 2026 16:41:52 -0700
Subject: [PATCH 02/13] Fix ORT GenAI stop handling and generation robustness

- Trim generation at earliest matching stop sequence across all provided stops
- Improve generate_until decoding efficiency by decoding incrementally
- Harden max token parsing for malformed generation kwargs
- Add evaluator tests for stop ordering and token parsing edge cases

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
---
 olive/evaluator/lmeval_ort.py          | 31 +++++----
 test/evaluator/test_olive_evaluator.py | 92 +++++++++++++++++++++++++-
 2 files changed, 110 insertions(+), 13 deletions(-)

diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py
index d385d186f..94be42ba5 100644
--- a/olive/evaluator/lmeval_ort.py
+++ b/olive/evaluator/lmeval_ort.py
@@ -592,17 +592,25 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
         results = []
         for request in requests:
             context = request.args[0]
-            gen_kwargs = request.args[1]
+            gen_kwargs = request.args[1] if len(request.args) > 1 and isinstance(request.args[1], dict) else {}

             # Extract stop sequences
             until = gen_kwargs.get("until", [])
             if isinstance(until, str):
                 until = [until]
+            elif until is None:
+                until = []
+            elif not isinstance(until, list):
+                until = [until]
+            until = [stop_seq for stop_seq in until if isinstance(stop_seq, str) and stop_seq]

             # Extract generation parameters
-            max_gen_toks = gen_kwargs.get(
-                "max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens", 256))
-            )
+            max_gen_toks = gen_kwargs.get("max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens")))
+            try:
+                max_gen_toks = int(max_gen_toks) if max_gen_toks is not None else 256
+
except (TypeError, ValueError): + max_gen_toks = 256 + max_gen_toks = max(max_gen_toks, 0) temperature = gen_kwargs.get("temperature", 0.0) do_sample = gen_kwargs.get("do_sample", temperature > 0) @@ -632,7 +640,6 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: generated_ids = [] generated_text = "" - stop_found = False while not generator.is_done(): generator.generate_next_token() @@ -643,17 +650,17 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: break generated_ids.append(new_token) - generated_text = self.tokenizer.decode(generated_ids) + generated_text += self.tokenizer.decode([new_token]) # Check stop sequences against generated text + earliest_stop_idx = None for stop_seq in until: - if stop_seq in generated_text: - # Trim at the stop sequence - generated_text = generated_text[: generated_text.index(stop_seq)] - stop_found = True - break + stop_idx = generated_text.find(stop_seq) + if stop_idx != -1 and (earliest_stop_idx is None or stop_idx < earliest_stop_idx): + earliest_stop_idx = stop_idx - if stop_found: + if earliest_stop_idx is not None: + generated_text = generated_text[:earliest_stop_idx] break results.append(generated_text) diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index 7e7ec0084..703946fc1 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -626,7 +626,7 @@ def test_generate_until_stops_on_stop_sequence(self, mock_params_cls, mock_gen_c evaluator.tokenizer = MagicMock() evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) - evaluator.tokenizer.decode.side_effect = ["he", "hel", "hello\n world"] + evaluator.tokenizer.decode.side_effect = ["he", "l", "lo\n world"] mock_generator = MagicMock() mock_generator.is_done.side_effect = [False, False, False, False] @@ -727,3 +727,93 @@ def test_generate_until_until_string_converted_to_list(self): # Should still find the stop sequence (string was converted to list) assert "\n" not in results[0] + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_uses_earliest_stop_match(self, mock_params_cls, mock_gen_cls): + """Test that stop trimming uses earliest occurrence across all stop sequences.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.return_value = "hello\nworld" + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.return_value = MagicMock(__getitem__=lambda s, k: 50) + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": ["", "\n"], "max_gen_toks": 256}) + + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert len(results) == 1 + assert results[0] == "hello" + + @pytest.mark.parametrize( + ("gen_kwargs", "expected_max_length"), + [ + (None, 261), # default 256 when gen_kwargs is not a dict + ({"max_gen_toks": "7"}, 12), # parse numeric string + ({"max_new_tokens": "bad"}, 261), # invalid value falls back to default + ({"max_tokens": -8}, 5), # clamp negative to zero + ], + ) + @patch("onnxruntime_genai.Generator") + 
@patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_parses_max_tokens_robustly( + self, mock_params_cls, mock_gen_cls, gen_kwargs, expected_max_length + ): + """Test robust parsing and clamping of max token kwargs.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 2, 3, 4, 5]) # 5-token prompt + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", gen_kwargs) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + assert call_kwargs["max_length"] == expected_max_length + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_decodes_incrementally(self, mock_params_cls, mock_gen_cls): + """Test generation decodes only new tokens while preserving output.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator._eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.side_effect = ["he", "llo"] + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 11), + MagicMock(__getitem__=lambda s, k: 12), + MagicMock(__getitem__=lambda s, k: 2), # EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": []}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results == ["hello"] + decode_inputs = [call.args[0] for call in evaluator.tokenizer.decode.call_args_list] + assert decode_inputs == [[11], [12]] From a3e3930f55990e8042ff0929aa3381a0c1880249 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Thu, 30 Apr 2026 10:36:54 -0700 Subject: [PATCH 03/13] Address PR review comments in lmeval_ort generate_until - Remove dead generated_ids list - Replace string concatenation with list accumulation to avoid quadratic growth - Rename _eos_token_ids to eos_token_ids (public attribute) to fix protected-access lint warnings in tests Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/evaluator/lmeval_ort.py | 12 ++++++------ test/evaluator/test_olive_evaluator.py | 18 +++++++++--------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index 94be42ba5..2af594d42 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -516,10 +516,10 @@ def __init__( # eos_token_id can be a single int or a list of ints if isinstance(eos, list): self._eot_token_id = eos[0] - self._eos_token_ids = set(eos) + self.eos_token_ids = set(eos) else: self._eot_token_id = eos - self._eos_token_ids = {eos} + self.eos_token_ids = {eos} self.params = og.GeneratorParams(self.model) self.params.set_search_options(max_length=self.max_length, past_present_share_buffer=False) @@ -638,7 +638,7 @@ def generate_until(self, requests, disable_tqdm: 
bool = False) -> list[str]: generator = og.Generator(self.model, params) generator.append_tokens([prompt_ids]) - generated_ids = [] + generated_chunks = [] generated_text = "" while not generator.is_done(): @@ -646,11 +646,11 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: new_token = generator.get_sequence(0)[-1] # Check for EOS token(s) - if new_token in self._eos_token_ids: + if new_token in self.eos_token_ids: break - generated_ids.append(new_token) - generated_text += self.tokenizer.decode([new_token]) + generated_chunks.append(self.tokenizer.decode([new_token])) + generated_text = "".join(generated_chunks) # Check stop sequences against generated text earliest_stop_idx = None diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index 703946fc1..fd30dee25 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -589,7 +589,7 @@ def test_generate_until_stops_on_eos(self, mock_params_cls, mock_gen_cls): from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) - evaluator._eos_token_ids = {2} + evaluator.eos_token_ids = {2} evaluator.max_length = 1024 evaluator.model = MagicMock() evaluator.tokenizer = MagicMock() @@ -620,7 +620,7 @@ def test_generate_until_stops_on_stop_sequence(self, mock_params_cls, mock_gen_c from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) - evaluator._eos_token_ids = {2} + evaluator.eos_token_ids = {2} evaluator.max_length = 1024 evaluator.model = MagicMock() evaluator.tokenizer = MagicMock() @@ -651,7 +651,7 @@ def test_generate_until_respects_max_length(self, mock_params_cls, mock_gen_cls) from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) - evaluator._eos_token_ids = {2} + evaluator.eos_token_ids = {2} evaluator.max_length = 50 # Small model limit evaluator.model = MagicMock() evaluator.tokenizer = MagicMock() @@ -675,11 +675,11 @@ def test_generate_until_respects_max_length(self, mock_params_cls, mock_gen_cls) @patch("onnxruntime_genai.Generator") @patch("onnxruntime_genai.GeneratorParams") def test_generate_until_handles_multiple_eos_tokens(self, mock_params_cls, mock_gen_cls): - """Test that any token in _eos_token_ids triggers stop.""" + """Test that any token in eos_token_ids triggers stop.""" from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) - evaluator._eos_token_ids = {2, 151645, 151643} # Multiple EOS like Qwen + evaluator.eos_token_ids = {2, 151645, 151643} # Multiple EOS like Qwen evaluator.max_length = 1024 evaluator.model = MagicMock() evaluator.tokenizer = MagicMock() @@ -707,7 +707,7 @@ def test_generate_until_until_string_converted_to_list(self): from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) - evaluator._eos_token_ids = {2} + evaluator.eos_token_ids = {2} evaluator.max_length = 1024 evaluator.model = MagicMock() evaluator.tokenizer = MagicMock() @@ -735,7 +735,7 @@ def test_generate_until_uses_earliest_stop_match(self, mock_params_cls, mock_gen from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) - evaluator._eos_token_ids = {2} + evaluator.eos_token_ids = {2} evaluator.max_length = 1024 evaluator.model = MagicMock() evaluator.tokenizer = MagicMock() @@ 
-772,7 +772,7 @@ def test_generate_until_parses_max_tokens_robustly( from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) - evaluator._eos_token_ids = {2} + evaluator.eos_token_ids = {2} evaluator.max_length = 1024 evaluator.model = MagicMock() evaluator.tokenizer = MagicMock() @@ -795,7 +795,7 @@ def test_generate_until_decodes_incrementally(self, mock_params_cls, mock_gen_cl from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) - evaluator._eos_token_ids = {2} + evaluator.eos_token_ids = {2} evaluator.max_length = 1024 evaluator.model = MagicMock() evaluator.tokenizer = MagicMock() From 33f93343ba053d44695da98b50e70fe88b8279e3 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Thu, 30 Apr 2026 10:54:59 -0700 Subject: [PATCH 04/13] Fix generate_until edge cases from PR review - Guard against empty eos_token_id list with a clear ValueError - Early-return empty completion when prompt >= max_length or max_gen_toks == 0 to avoid passing invalid max_length to the ORT GenAI generator - Fix generated_text not being set when loop exits via EOS break - Use tail-buffer for stop-sequence checking instead of full join per token - Update test for max_gen_toks=0 to assert early-return behaviour Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/evaluator/lmeval_ort.py | 45 ++++++++++++++++++-------- test/evaluator/test_olive_evaluator.py | 20 +++++++++++- 2 files changed, 50 insertions(+), 15 deletions(-) diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index 2af594d42..7300e78c9 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -515,6 +515,8 @@ def __init__( eos = genai_config["model"]["eos_token_id"] # eos_token_id can be a single int or a list of ints if isinstance(eos, list): + if not eos: + raise ValueError("genai_config model.eos_token_id must not be an empty list") self._eot_token_id = eos[0] self.eos_token_ids = set(eos) else: @@ -621,6 +623,13 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: # Compute total max_length: prompt + new tokens, capped by model limit total_max_length = min(prompt_len + max_gen_toks, self.max_length) + # If the prompt already fills or exceeds the model limit, no generation is possible. 
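+            # (e.g. with max_length=50, a 40-token prompt leaves room for only 10 new
+            # tokens, while a 50-token prompt leaves none; max_gen_toks == 0 arises when
+            # a negative request such as max_tokens=-8 is clamped.)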
+ if prompt_len >= self.max_length or max_gen_toks == 0: + results.append("") + if hasattr(request, "cache_hook") and request.cache_hook is not None: + request.cache_hook.add_partial("generate_until", request.args, "") + continue + # Create fresh generator params per request to avoid state leakage params = og.GeneratorParams(self.model) search_options = { @@ -639,7 +648,9 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: generator.append_tokens([prompt_ids]) generated_chunks = [] - generated_text = "" + stop_idx = None + # Tail buffer wide enough to detect any stop sequence across chunk boundaries + max_stop_len = max((len(s) for s in until), default=0) while not generator.is_done(): generator.generate_next_token() @@ -649,19 +660,25 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: if new_token in self.eos_token_ids: break - generated_chunks.append(self.tokenizer.decode([new_token])) - generated_text = "".join(generated_chunks) - - # Check stop sequences against generated text - earliest_stop_idx = None - for stop_seq in until: - stop_idx = generated_text.find(stop_seq) - if stop_idx != -1 and (earliest_stop_idx is None or stop_idx < earliest_stop_idx): - earliest_stop_idx = stop_idx - - if earliest_stop_idx is not None: - generated_text = generated_text[:earliest_stop_idx] - break + chunk = self.tokenizer.decode([new_token]) + generated_chunks.append(chunk) + + # Check stop sequences against a tail window to avoid O(n²) full join + if until: + tail = "".join(generated_chunks[-(max_stop_len + 1) :]) if max_stop_len else "" + tail_offset = len("".join(generated_chunks)) - len(tail) + earliest = None + for stop_seq in until: + idx = tail.find(stop_seq) + if idx != -1: + abs_idx = tail_offset + idx + if earliest is None or abs_idx < earliest: + earliest = abs_idx + if earliest is not None: + stop_idx = earliest + break + + generated_text = "".join(generated_chunks) if stop_idx is None else "".join(generated_chunks)[:stop_idx] results.append(generated_text) diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index fd30dee25..2e61553b0 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -760,7 +760,6 @@ def test_generate_until_uses_earliest_stop_match(self, mock_params_cls, mock_gen (None, 261), # default 256 when gen_kwargs is not a dict ({"max_gen_toks": "7"}, 12), # parse numeric string ({"max_new_tokens": "bad"}, 261), # invalid value falls back to default - ({"max_tokens": -8}, 5), # clamp negative to zero ], ) @patch("onnxruntime_genai.Generator") @@ -788,6 +787,25 @@ def test_generate_until_parses_max_tokens_robustly( call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] assert call_kwargs["max_length"] == expected_max_length + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_returns_empty_when_max_gen_toks_zero(self, mock_params_cls, mock_gen_cls): + """Test that clamping a negative max_tokens to zero returns an empty completion immediately.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 2, 3, 4, 5]) # 5-token prompt + + request = self._make_mock_request("prompt", {"max_tokens": -8}) + 
results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results == [""] + mock_gen_cls.assert_not_called() # generator should never be created + @patch("onnxruntime_genai.Generator") @patch("onnxruntime_genai.GeneratorParams") def test_generate_until_decodes_incrementally(self, mock_params_cls, mock_gen_cls): From 957d755c7b10919a51d148dd6dff0877799261e6 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Thu, 30 Apr 2026 11:16:31 -0700 Subject: [PATCH 05/13] Fix tail-buffer and temperature coercion in generate_until - Replace chunk-count tail window with character-based rolling tail string, fixing missed stop sequences when stop strings span more tokens than characters - Track generated_len with a running counter to avoid O(n^2) join for tail_offset - Coerce temperature from str/None safely with float() + fallback to 0.0 - Add parametrized tests for temperature coercion edge cases (string, None, zero, float) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/evaluator/lmeval_ort.py | 18 +++++++---- test/evaluator/test_olive_evaluator.py | 41 ++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index 7300e78c9..c6a0014ec 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -613,7 +613,10 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: except (TypeError, ValueError): max_gen_toks = 256 max_gen_toks = max(max_gen_toks, 0) - temperature = gen_kwargs.get("temperature", 0.0) + try: + temperature = float(gen_kwargs.get("temperature", 0.0) or 0.0) + except (TypeError, ValueError): + temperature = 0.0 do_sample = gen_kwargs.get("do_sample", temperature > 0) # Tokenize the prompt @@ -648,9 +651,12 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: generator.append_tokens([prompt_ids]) generated_chunks = [] + generated_len = 0 # running total character count, avoids O(n²) join for offset stop_idx = None - # Tail buffer wide enough to detect any stop sequence across chunk boundaries + # Character-based rolling tail wide enough to catch any stop sequence + # across chunk boundaries, regardless of how many tokens a stop string spans. max_stop_len = max((len(s) for s in until), default=0) + tail = "" while not generator.is_done(): generator.generate_next_token() @@ -662,11 +668,13 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: chunk = self.tokenizer.decode([new_token]) generated_chunks.append(chunk) + generated_len += len(chunk) - # Check stop sequences against a tail window to avoid O(n²) full join + # Maintain a character-based tail of exactly max_stop_len + len(chunk) chars + # so stop sequences that span chunk boundaries are never missed. 
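+                    # A stop string of length max_stop_len can begin at most
+                    # max_stop_len - 1 characters before the newest chunk, so a window
+                    # of max_stop_len + len(chunk) characters always suffices.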
if until: - tail = "".join(generated_chunks[-(max_stop_len + 1) :]) if max_stop_len else "" - tail_offset = len("".join(generated_chunks)) - len(tail) + tail = (tail + chunk)[-(max_stop_len + len(chunk)):] + tail_offset = generated_len - len(tail) earliest = None for stop_seq in until: idx = tail.find(stop_seq) diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index 2e61553b0..685fed910 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -835,3 +835,44 @@ def test_generate_until_decodes_incrementally(self, mock_params_cls, mock_gen_cl assert results == ["hello"] decode_inputs = [call.args[0] for call in evaluator.tokenizer.decode.call_args_list] assert decode_inputs == [[11], [12]] + + @pytest.mark.parametrize( + ("temperature_val", "expect_do_sample"), + [ + ("0.7", True), # string float should be coerced + (None, False), # None should fall back to 0.0 + (0.0, False), # zero means greedy + (0.5, True), # normal float + ], + ) + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_handles_temperature_coercion( + self, mock_params_cls, mock_gen_cls, temperature_val, expect_do_sample + ): + """Test that temperature is safely coerced from string/None without errors.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + gen_kwargs = {"until": [], "max_gen_toks": 10} + if temperature_val is not None: + gen_kwargs["temperature"] = temperature_val + + request = self._make_mock_request("prompt", gen_kwargs) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + if expect_do_sample: + assert call_kwargs["temperature"] > 0 + else: + assert call_kwargs["temperature"] == 0.0 From f548cd532276f9a8203c795e69f4ff58cda28adb Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Thu, 30 Apr 2026 11:59:42 -0700 Subject: [PATCH 06/13] Apply ruff formatting to fix lintrunner CI failure Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/evaluator/lmeval_ort.py | 6 ++++-- test/evaluator/test_olive_evaluator.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index c6a0014ec..5b1899c52 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -607,7 +607,9 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: until = [stop_seq for stop_seq in until if isinstance(stop_seq, str) and stop_seq] # Extract generation parameters - max_gen_toks = gen_kwargs.get("max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens"))) + max_gen_toks = gen_kwargs.get( + "max_gen_toks", gen_kwargs.get("max_new_tokens", gen_kwargs.get("max_tokens")) + ) try: max_gen_toks = int(max_gen_toks) if max_gen_toks is not None else 256 except (TypeError, ValueError): @@ -673,7 +675,7 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: # Maintain a character-based tail of exactly max_stop_len + len(chunk) chars # so stop 
sequences that span chunk boundaries are never missed. if until: - tail = (tail + chunk)[-(max_stop_len + len(chunk)):] + tail = (tail + chunk)[-(max_stop_len + len(chunk)) :] tail_offset = generated_len - len(tail) earliest = None for stop_seq in until: diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index 685fed910..3b2da778b 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -839,10 +839,10 @@ def test_generate_until_decodes_incrementally(self, mock_params_cls, mock_gen_cl @pytest.mark.parametrize( ("temperature_val", "expect_do_sample"), [ - ("0.7", True), # string float should be coerced - (None, False), # None should fall back to 0.0 - (0.0, False), # zero means greedy - (0.5, True), # normal float + ("0.7", True), # string float should be coerced + (None, False), # None should fall back to 0.0 + (0.0, False), # zero means greedy + (0.5, True), # normal float ], ) @patch("onnxruntime_genai.Generator") From b6e89d1df759b8b087ee08e3b1f88034229e869e Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Thu, 30 Apr 2026 12:17:25 -0700 Subject: [PATCH 07/13] Fix batch_size in search_options and confirm_run_unsafe_code compat with older lm-eval - Remove batch_size from set_search_options (not a valid kwarg) - Use try/except TypeError for confirm_run_unsafe_code compat - Remove unused import inspect Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/evaluator/lmeval_ort.py | 1 - olive/evaluator/olive_evaluator.py | 28 ++++++++++++++++++---------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index 5b1899c52..67e90adfa 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -640,7 +640,6 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: search_options = { "max_length": total_max_length, "past_present_share_buffer": False, - "batch_size": 1, } if do_sample: search_options["temperature"] = temperature diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index 4cd28f257..88c7e8a19 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -1101,16 +1101,24 @@ def evaluate( if self.tasks: lmmodel = get_model(self.model_class)(**init_args, batch_size=self.batch_size, max_length=self.max_length) - results = simple_evaluate( - model=lmmodel, - tasks=self.tasks, - task_manager=TaskManager(), - log_samples=False, - batch_size=self.batch_size, - device=device, - limit=self.limit, - confirm_run_unsafe_code=self.confirm_run_unsafe_code, - ) + simple_evaluate_kwargs = { + "model": lmmodel, + "tasks": self.tasks, + "task_manager": TaskManager(), + "log_samples": False, + "batch_size": self.batch_size, + "device": device, + "limit": self.limit, + "confirm_run_unsafe_code": self.confirm_run_unsafe_code, + } + try: + results = simple_evaluate(**simple_evaluate_kwargs) + except TypeError as e: + if "confirm_run_unsafe_code" not in str(e): + raise + # Older lm-eval versions don't support confirm_run_unsafe_code; retry without it + simple_evaluate_kwargs.pop("confirm_run_unsafe_code") + results = simple_evaluate(**simple_evaluate_kwargs) for task_name in sorted(results["results"].keys()): metric_items = sorted(results["results"][task_name].items()) From 206b3c0bb1aa5a74cf99ab2d2a27c9073ee55d04 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Thu, 30 Apr 2026 12:19:28 -0700 Subject: [PATCH 08/13] Add 
test asserting batch_size is not passed to set_search_options Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- test/evaluator/test_olive_evaluator.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index 3b2da778b..bbbe08564 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -787,6 +787,30 @@ def test_generate_until_parses_max_tokens_robustly( call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] assert call_kwargs["max_length"] == expected_max_length + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_does_not_pass_batch_size_to_search_options(self, mock_params_cls, mock_gen_cls): + """batch_size is not a valid set_search_options kwarg for ORT GenAI — must never be passed.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 2]) + evaluator.tokenizer.decode.return_value = "hello" + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": [], "max_gen_toks": 64}) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + assert "batch_size" not in call_kwargs, f"batch_size must not be passed to set_search_options, got: {call_kwargs}" + @patch("onnxruntime_genai.Generator") @patch("onnxruntime_genai.GeneratorParams") def test_generate_until_returns_empty_when_max_gen_toks_zero(self, mock_params_cls, mock_gen_cls): From 0a3a489d0d60c21bcf8e6b3ee734c7e8fdddd826 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Thu, 30 Apr 2026 12:45:43 -0700 Subject: [PATCH 09/13] Fix Copilot review batch 5: inspect.signature, full-seq decode, do_sample coercion - olive_evaluator.py: use inspect.signature to check if lm-eval supports confirm_run_unsafe_code before passing it (avoids brittle try/except TypeError) - lmeval_ort.py: accumulate generated_token_ids and decode full sequence once at end so tokenizer whitespace/punctuation normalisation is applied correctly; keep per-token incremental decode only for stop-sequence tail detection - lmeval_ort.py: coerce do_sample to bool, handling string 'false'/'0'/'no' and int 0/1 so they are not unintentionally truthy - tests: set __signature__ on simple_evaluate mocks; add older-lm-eval compat test; update decode-assertion for full-seq decode; add do_sample coercion test Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/evaluator/lmeval_ort.py | 52 ++++++---- olive/evaluator/olive_evaluator.py | 14 +-- test/evaluator/test_olive_evaluator.py | 125 ++++++++++++++++++++++++- 3 files changed, 159 insertions(+), 32 deletions(-) diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index 67e90adfa..91fb06238 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -619,7 +619,15 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: temperature = float(gen_kwargs.get("temperature", 0.0) or 0.0) except (TypeError, ValueError): temperature = 0.0 - 
do_sample = gen_kwargs.get("do_sample", temperature > 0) + raw_do_sample = gen_kwargs.get("do_sample", None) + if raw_do_sample is None: + do_sample = temperature > 0 + elif isinstance(raw_do_sample, bool): + do_sample = raw_do_sample + elif isinstance(raw_do_sample, str): + do_sample = raw_do_sample.lower() not in ("false", "0", "no", "") + else: + do_sample = bool(raw_do_sample) # Tokenize the prompt prompt_ids = self.tokenizer.encode(context).tolist() @@ -651,9 +659,8 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: generator = og.Generator(self.model, params) generator.append_tokens([prompt_ids]) - generated_chunks = [] - generated_len = 0 # running total character count, avoids O(n²) join for offset - stop_idx = None + generated_token_ids = [] + stop_found = False # Character-based rolling tail wide enough to catch any stop sequence # across chunk boundaries, regardless of how many tokens a stop string spans. max_stop_len = max((len(s) for s in until), default=0) @@ -667,27 +674,34 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: if new_token in self.eos_token_ids: break - chunk = self.tokenizer.decode([new_token]) - generated_chunks.append(chunk) - generated_len += len(chunk) + generated_token_ids.append(new_token) - # Maintain a character-based tail of exactly max_stop_len + len(chunk) chars - # so stop sequences that span chunk boundaries are never missed. + # Decode one token at a time only for stop-sequence tail detection. + # The final text is produced by decoding the full ID sequence so that + # tokenizer whitespace/punctuation normalisation is applied correctly. if until: + chunk = self.tokenizer.decode([new_token]) tail = (tail + chunk)[-(max_stop_len + len(chunk)) :] - tail_offset = generated_len - len(tail) - earliest = None for stop_seq in until: - idx = tail.find(stop_seq) - if idx != -1: - abs_idx = tail_offset + idx - if earliest is None or abs_idx < earliest: - earliest = abs_idx - if earliest is not None: - stop_idx = earliest + if stop_seq in tail: + stop_found = True + break + if stop_found: break - generated_text = "".join(generated_chunks) if stop_idx is None else "".join(generated_chunks)[:stop_idx] + # Decode full token sequence once for correct whitespace/punctuation handling. + full_text = self.tokenizer.decode(generated_token_ids) if generated_token_ids else "" + + # Trim at the earliest stop sequence found in the final decoded text. + generated_text = full_text + if until: + earliest = None + for stop_seq in until: + idx = full_text.find(stop_seq) + if idx != -1 and (earliest is None or idx < earliest): + earliest = idx + if earliest is not None: + generated_text = full_text[:earliest] results.append(generated_text) diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index 88c7e8a19..b057b9c70 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- import collections +import inspect import logging import time from abc import ABC, abstractmethod @@ -1109,16 +1110,11 @@ def evaluate( "batch_size": self.batch_size, "device": device, "limit": self.limit, - "confirm_run_unsafe_code": self.confirm_run_unsafe_code, } - try: - results = simple_evaluate(**simple_evaluate_kwargs) - except TypeError as e: - if "confirm_run_unsafe_code" not in str(e): - raise - # Older lm-eval versions don't support confirm_run_unsafe_code; retry without it - simple_evaluate_kwargs.pop("confirm_run_unsafe_code") - results = simple_evaluate(**simple_evaluate_kwargs) + # Only pass confirm_run_unsafe_code when the installed lm-eval version supports it. + if "confirm_run_unsafe_code" in inspect.signature(simple_evaluate).parameters: + simple_evaluate_kwargs["confirm_run_unsafe_code"] = self.confirm_run_unsafe_code + results = simple_evaluate(**simple_evaluate_kwargs) for task_name in sorted(results["results"].keys()): metric_items = sorted(results["results"][task_name].items()) diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index bbbe08564..5a8601cfc 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -496,9 +496,15 @@ class TestLMEvaluatorModelClass: def test_lm_evaluator_dispatches_to_requested_backend( self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock, model_class ): + import inspect + from olive.evaluator.olive_evaluator import LMEvaluator from olive.model.handler.onnx import ONNXModelHandler + def _fake_evaluate(model, tasks, task_manager=None, log_samples=True, batch_size=1, device="cpu", limit=None): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate) simple_evaluate_mock.return_value = {"results": {}} get_model_mock.return_value = MagicMock(return_value=MagicMock()) @@ -518,9 +524,25 @@ def test_lm_evaluator_dispatches_to_requested_backend( def test_lm_evaluator_passes_confirm_run_unsafe_code( self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock ): + import inspect + from olive.evaluator.olive_evaluator import LMEvaluator from olive.model.handler.onnx import ONNXModelHandler + # Give the mock a signature that includes confirm_run_unsafe_code so inspect.signature works. + def _fake_evaluate( + model, + tasks, + task_manager=None, + log_samples=True, + batch_size=1, + device="cpu", + limit=None, + confirm_run_unsafe_code=False, + ): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate) simple_evaluate_mock.return_value = {"results": {}} get_model_mock.return_value = MagicMock(return_value=MagicMock()) @@ -544,9 +566,25 @@ def test_lm_evaluator_passes_confirm_run_unsafe_code( def test_lm_evaluator_confirm_run_unsafe_code_defaults_false( self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock ): + import inspect + from olive.evaluator.olive_evaluator import LMEvaluator from olive.model.handler.onnx import ONNXModelHandler + # Give the mock a signature that includes confirm_run_unsafe_code so inspect.signature works. 
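+        # (A bare MagicMock exposes only a generic (*args, **kwargs) signature, so
+        # without __signature__ the inspect.signature guard would never see the
+        # confirm_run_unsafe_code parameter.)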
+ def _fake_evaluate( + model, + tasks, + task_manager=None, + log_samples=True, + batch_size=1, + device="cpu", + limit=None, + confirm_run_unsafe_code=False, + ): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate) simple_evaluate_mock.return_value = {"results": {}} get_model_mock.return_value = MagicMock(return_value=MagicMock()) @@ -561,6 +599,41 @@ def test_lm_evaluator_confirm_run_unsafe_code_defaults_false( call_kwargs = simple_evaluate_mock.call_args[1] assert call_kwargs["confirm_run_unsafe_code"] is False + @patch("lm_eval.utils.setup_logging") + @patch("lm_eval.tasks.TaskManager") + @patch("lm_eval.simple_evaluate") + @patch("lm_eval.api.registry.get_model") + def test_lm_evaluator_skips_confirm_run_unsafe_code_for_older_lm_eval( + self, get_model_mock, simple_evaluate_mock, _task_manager_mock, _setup_logging_mock + ): + """When lm-eval lacks confirm_run_unsafe_code, the kwarg must not be passed.""" + import inspect + + from olive.evaluator.olive_evaluator import LMEvaluator + from olive.model.handler.onnx import ONNXModelHandler + + # Mock a signature WITHOUT confirm_run_unsafe_code (simulates older lm-eval). + def _fake_evaluate_old( + model, tasks, task_manager=None, log_samples=True, batch_size=1, device="cpu", limit=None + ): + pass + + simple_evaluate_mock.__signature__ = inspect.signature(_fake_evaluate_old) + simple_evaluate_mock.return_value = {"results": {}} + get_model_mock.return_value = MagicMock(return_value=MagicMock()) + + evaluator = LMEvaluator( + tasks=["mbpp"], model_class="ortgenai", batch_size=1, max_length=128, confirm_run_unsafe_code=True + ) + + model = MagicMock(spec=ONNXModelHandler) + model.model_path = "/tmp/model.onnx" + + evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"]) + + call_kwargs = simple_evaluate_mock.call_args[1] + assert "confirm_run_unsafe_code" not in call_kwargs + @pytest.mark.skipif( importlib.util.find_spec("lm_eval") is None or importlib.util.find_spec("onnxruntime_genai") is None, @@ -626,7 +699,7 @@ def test_generate_until_stops_on_stop_sequence(self, mock_params_cls, mock_gen_c evaluator.tokenizer = MagicMock() evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) - evaluator.tokenizer.decode.side_effect = ["he", "l", "lo\n world"] + evaluator.tokenizer.decode.side_effect = ["he", "l", "lo\n world", "hello\n world"] mock_generator = MagicMock() mock_generator.is_done.side_effect = [False, False, False, False] @@ -809,7 +882,9 @@ def test_generate_until_does_not_pass_batch_size_to_search_options(self, mock_pa LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] - assert "batch_size" not in call_kwargs, f"batch_size must not be passed to set_search_options, got: {call_kwargs}" + assert "batch_size" not in call_kwargs, ( + f"batch_size must not be passed to set_search_options, got: {call_kwargs}" + ) @patch("onnxruntime_genai.Generator") @patch("onnxruntime_genai.GeneratorParams") @@ -842,7 +917,7 @@ def test_generate_until_decodes_incrementally(self, mock_params_cls, mock_gen_cl evaluator.model = MagicMock() evaluator.tokenizer = MagicMock() evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) - evaluator.tokenizer.decode.side_effect = ["he", "llo"] + evaluator.tokenizer.decode.return_value = "hello" # returned for full-sequence decode mock_generator = MagicMock() mock_generator.is_done.side_effect = [False, False, False] @@ -857,8 
+932,9 @@ def test_generate_until_decodes_incrementally(self, mock_params_cls, mock_gen_cl results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) assert results == ["hello"] + # With no stop sequences, tokens are decoded once as a full sequence (not per-token). decode_inputs = [call.args[0] for call in evaluator.tokenizer.decode.call_args_list] - assert decode_inputs == [[11], [12]] + assert decode_inputs == [[11, 12]] @pytest.mark.parametrize( ("temperature_val", "expect_do_sample"), @@ -900,3 +976,44 @@ def test_generate_until_handles_temperature_coercion( assert call_kwargs["temperature"] > 0 else: assert call_kwargs["temperature"] == 0.0 + + @pytest.mark.parametrize( + ("do_sample_val", "expect_sampling"), + [ + (True, True), # bool True → sampling on + (False, False), # bool False → greedy + ("true", True), # string "true" → sampling on + ("false", False), # string "false" → greedy (was truthy before fix) + ("0", False), # string "0" → greedy + ("1", True), # string "1" → sampling + (1, True), # int 1 → sampling + (0, False), # int 0 → greedy + ], + ) + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_coerces_do_sample(self, mock_params_cls, mock_gen_cls, do_sample_val, expect_sampling): + """do_sample must be coerced to a real bool so string 'false'/'0' are not truthy.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + + mock_generator = MagicMock() + mock_generator.is_done.return_value = True + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request( + "prompt", {"until": [], "max_gen_toks": 10, "do_sample": do_sample_val, "temperature": 0.7} + ) + LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] + if expect_sampling: + assert call_kwargs["temperature"] > 0, f"Expected sampling for do_sample={do_sample_val!r}" + else: + assert call_kwargs["temperature"] == 0.0, f"Expected greedy for do_sample={do_sample_val!r}" From 0b62f2d03d108b4b84390e323d8b2006db8d223c Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Thu, 30 Apr 2026 13:42:20 -0700 Subject: [PATCH 10/13] Fix Copilot review batch 6: signature guard, null flag, tuple until, ort_genai stub - olive_evaluator.py: wrap inspect.signature in try/except (TypeError, ValueError) so an unintrospectable wrapper never crashes evaluation - benchmark.py: write confirm_run_unsafe_code as explicit bool (not 'or None') so False is always written rather than silently omitted as null - lmeval_ort.py: normalise 'until' tuple/set/generator to list via list() so stop sequences in tuple form are not silently dropped - test/evaluator/conftest.py: inject onnxruntime_genai stub so generate_until tests run without the real package installed - test: remove onnxruntime_genai from skipif; add tuple-until regression test Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/cli/benchmark.py | 2 +- olive/evaluator/lmeval_ort.py | 7 +++++-- olive/evaluator/olive_evaluator.py | 8 ++++++- test/evaluator/conftest.py | 25 ++++++++++++++++++++++ test/evaluator/test_olive_evaluator.py | 29 ++++++++++++++++++++++++-- 5 files changed, 65 insertions(+), 6 deletions(-) create mode 100644 
test/evaluator/conftest.py diff --git a/olive/cli/benchmark.py b/olive/cli/benchmark.py index 18237bbf3..193394d83 100644 --- a/olive/cli/benchmark.py +++ b/olive/cli/benchmark.py @@ -126,7 +126,7 @@ def _get_run_config(self, tempdir: str) -> dict: ), ( ("evaluators", "evaluator", "confirm_run_unsafe_code"), - self.args.confirm_run_unsafe_code or None, + self.args.confirm_run_unsafe_code, ), ] diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index 91fb06238..9a800b7b7 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -596,14 +596,17 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: context = request.args[0] gen_kwargs = request.args[1] if len(request.args) > 1 and isinstance(request.args[1], dict) else {} - # Extract stop sequences + # Extract stop sequences — normalise str/None/tuple/other-iterables to list[str] until = gen_kwargs.get("until", []) if isinstance(until, str): until = [until] elif until is None: until = [] elif not isinstance(until, list): - until = [until] + try: + until = list(until) # handles tuple, set, generator, etc. + except TypeError: + until = [until] # non-iterable scalar fallback until = [stop_seq for stop_seq in until if isinstance(stop_seq, str) and stop_seq] # Extract generation parameters diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index b057b9c70..e7f9ffd98 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -1112,7 +1112,13 @@ def evaluate( "limit": self.limit, } # Only pass confirm_run_unsafe_code when the installed lm-eval version supports it. - if "confirm_run_unsafe_code" in inspect.signature(simple_evaluate).parameters: + try: + supports_confirm_run_unsafe_code = ( + "confirm_run_unsafe_code" in inspect.signature(simple_evaluate).parameters + ) + except (TypeError, ValueError): + supports_confirm_run_unsafe_code = False + if supports_confirm_run_unsafe_code: simple_evaluate_kwargs["confirm_run_unsafe_code"] = self.confirm_run_unsafe_code results = simple_evaluate(**simple_evaluate_kwargs) diff --git a/test/evaluator/conftest.py b/test/evaluator/conftest.py new file mode 100644 index 000000000..8617534db --- /dev/null +++ b/test/evaluator/conftest.py @@ -0,0 +1,25 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Inject a minimal onnxruntime_genai stub for generate_until unit tests. + +Ensures tests can run in environments where the real package is not installed. +The tests mock all ORT GenAI objects anyway, so the stub only needs to provide +importable names. 
+""" + +import sys +import types +from unittest.mock import MagicMock + + +def _ensure_ort_genai_stub(): + if "onnxruntime_genai" not in sys.modules: + stub = types.ModuleType("onnxruntime_genai") + stub.Generator = MagicMock + stub.GeneratorParams = MagicMock + sys.modules["onnxruntime_genai"] = stub + + +_ensure_ort_genai_stub() diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index 5a8601cfc..ee18d84d9 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -636,8 +636,8 @@ def _fake_evaluate_old( @pytest.mark.skipif( - importlib.util.find_spec("lm_eval") is None or importlib.util.find_spec("onnxruntime_genai") is None, - reason="lm_eval or onnxruntime_genai not installed", + importlib.util.find_spec("lm_eval") is None, + reason="lm_eval not installed", ) class TestLMEvalORTGenAIGenerateUntil: """Unit tests for LMEvalORTGenAIEvaluator.generate_until.""" @@ -1017,3 +1017,28 @@ def test_generate_until_coerces_do_sample(self, mock_params_cls, mock_gen_cls, d assert call_kwargs["temperature"] > 0, f"Expected sampling for do_sample={do_sample_val!r}" else: assert call_kwargs["temperature"] == 0.0, f"Expected greedy for do_sample={do_sample_val!r}" + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_handles_tuple_until(self, mock_params_cls, mock_gen_cls): + """Until as a tuple must not be wrapped as a single element — each string is a stop sequence.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1, 100]) + evaluator.tokenizer.decode.return_value = "hello\n world" + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.return_value = MagicMock(__getitem__=lambda s, k: 50) + mock_gen_cls.return_value = mock_generator + + # Pass until as a tuple — previously this would silently produce no stop enforcement + request = self._make_mock_request("prompt", {"until": ("\n",), "max_gen_toks": 256}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results[0] == "hello", f"Expected stop at \\n but got: {results[0]!r}" From a4ba67a9affbe2049de09d48bd487a6c6f43261b Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Thu, 30 Apr 2026 14:14:08 -0700 Subject: [PATCH 11/13] Fix critical bugs from code review: do_sample flag, config clobber, tqdm, caching MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - lmeval_ort.py: pass do_sample=True to set_search_options when sampling is enabled — without it ORT GenAI ignores temperature and runs greedy regardless - benchmark.py: use default=None for --confirm_run_unsafe_code so False is not written to config when flag is omitted (prevents overriding config-file truth) - lmeval_ort.py: wrap request loop in tqdm respecting disable_tqdm parameter - olive_evaluator.py: extract _simple_evaluate_supports_unsafe_code() with lru_cache so inspect.signature is not called on every evaluate() invocation - tests: assert do_sample=True in search_options when sampling; add multi-request isolation test and cache_hook.add_partial assertion Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- olive/cli/benchmark.py | 4 +- 
olive/evaluator/lmeval_ort.py | 3 +- olive/evaluator/olive_evaluator.py | 19 ++++--- test/evaluator/test_olive_evaluator.py | 72 ++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 11 deletions(-) diff --git a/olive/cli/benchmark.py b/olive/cli/benchmark.py index 193394d83..a3b3f25e8 100644 --- a/olive/cli/benchmark.py +++ b/olive/cli/benchmark.py @@ -79,7 +79,7 @@ def register_subcommand(parser: ArgumentParser): lmeval_group.add_argument( "--confirm_run_unsafe_code", action="store_true", - default=False, + default=None, help="Allow running tasks that execute model-generated code (e.g., MBPP, HumanEval).", ) @@ -126,7 +126,7 @@ def _get_run_config(self, tempdir: str) -> dict: ), ( ("evaluators", "evaluator", "confirm_run_unsafe_code"), - self.args.confirm_run_unsafe_code, + True if self.args.confirm_run_unsafe_code else None, ), ] diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index 9a800b7b7..5f9ac921a 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -592,7 +592,7 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: Each request is a tuple of (context_string, gen_kwargs_dict). """ results = [] - for request in requests: + for request in tqdm(requests, desc="Running generate_until", disable=disable_tqdm): context = request.args[0] gen_kwargs = request.args[1] if len(request.args) > 1 and isinstance(request.args[1], dict) else {} @@ -653,6 +653,7 @@ def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]: "past_present_share_buffer": False, } if do_sample: + search_options["do_sample"] = True search_options["temperature"] = temperature else: search_options["temperature"] = 0.0 diff --git a/olive/evaluator/olive_evaluator.py b/olive/evaluator/olive_evaluator.py index e7f9ffd98..d29629920 100644 --- a/olive/evaluator/olive_evaluator.py +++ b/olive/evaluator/olive_evaluator.py @@ -8,7 +8,7 @@ import time from abc import ABC, abstractmethod from copy import deepcopy -from functools import partial +from functools import lru_cache, partial from numbers import Number from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple, Optional, Union @@ -1017,6 +1017,15 @@ def _prepare_dataloader( return FileListCommonDataLoader(dataloader, model.io_config, batch_size=file_chunk_size) +@lru_cache(maxsize=1) +def _simple_evaluate_supports_unsafe_code(simple_evaluate_fn) -> bool: + """Check (cached) whether lm-eval's simple_evaluate accepts confirm_run_unsafe_code.""" + try: + return "confirm_run_unsafe_code" in inspect.signature(simple_evaluate_fn).parameters + except (TypeError, ValueError): + return False + + @Registry.register("LMEvaluator") class LMEvaluator(OliveEvaluator): def __init__(self, tasks: list[str], **kwargs): @@ -1112,13 +1121,7 @@ def evaluate( "limit": self.limit, } # Only pass confirm_run_unsafe_code when the installed lm-eval version supports it. 
- try: - supports_confirm_run_unsafe_code = ( - "confirm_run_unsafe_code" in inspect.signature(simple_evaluate).parameters - ) - except (TypeError, ValueError): - supports_confirm_run_unsafe_code = False - if supports_confirm_run_unsafe_code: + if _simple_evaluate_supports_unsafe_code(simple_evaluate): simple_evaluate_kwargs["confirm_run_unsafe_code"] = self.confirm_run_unsafe_code results = simple_evaluate(**simple_evaluate_kwargs) diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py index ee18d84d9..5bc819a7b 100644 --- a/test/evaluator/test_olive_evaluator.py +++ b/test/evaluator/test_olive_evaluator.py @@ -974,8 +974,10 @@ def test_generate_until_handles_temperature_coercion( call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] if expect_do_sample: assert call_kwargs["temperature"] > 0 + assert call_kwargs.get("do_sample") is True else: assert call_kwargs["temperature"] == 0.0 + assert "do_sample" not in call_kwargs @pytest.mark.parametrize( ("do_sample_val", "expect_sampling"), @@ -1015,8 +1017,14 @@ def test_generate_until_coerces_do_sample(self, mock_params_cls, mock_gen_cls, d call_kwargs = mock_params_cls.return_value.set_search_options.call_args[1] if expect_sampling: assert call_kwargs["temperature"] > 0, f"Expected sampling for do_sample={do_sample_val!r}" + assert call_kwargs.get("do_sample") is True, ( + f"do_sample=True must be set in search_options for do_sample={do_sample_val!r}" + ) else: assert call_kwargs["temperature"] == 0.0, f"Expected greedy for do_sample={do_sample_val!r}" + assert "do_sample" not in call_kwargs, ( + f"do_sample must not be set when greedy for do_sample={do_sample_val!r}" + ) @patch("onnxruntime_genai.Generator") @patch("onnxruntime_genai.GeneratorParams") @@ -1042,3 +1050,67 @@ def test_generate_until_handles_tuple_until(self, mock_params_cls, mock_gen_cls) results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) assert results[0] == "hello", f"Expected stop at \\n but got: {results[0]!r}" + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_processes_multiple_requests_independently(self, mock_params_cls, mock_gen_cls): + """Multiple requests must not share mutable state (tail, stop_found, token_ids).""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + # First request decodes to text with a stop; second decodes cleanly + evaluator.tokenizer.decode.side_effect = [ + "\n", # per-token tail for req 1 (stop sequence present) + "hello\n", # full-sequence decode for req 1 + "world", # full-sequence decode for req 2 (no stop) + ] + + mock_generator = MagicMock() + # Req 1: is_done=False → generates token 10 → stop seq found → break (no more is_done) + # Req 2: is_done=False → generates token 20 → is_done=True → exit loop + mock_generator.is_done.side_effect = [False, False, True] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 10), # req 1 token + MagicMock(__getitem__=lambda s, k: 20), # req 2 token + ] + mock_gen_cls.return_value = mock_generator + + req1 = self._make_mock_request("p1", {"until": ["\n"], "max_gen_toks": 64}) + req2 = self._make_mock_request("p2", {"until": [], "max_gen_toks": 64}) + results = 
LMEvalORTGenAIEvaluator.generate_until(evaluator, [req1, req2]) + + assert results[0] == "hello" # trimmed at \n + assert results[1] == "world" # no stop, full text + + @patch("onnxruntime_genai.Generator") + @patch("onnxruntime_genai.GeneratorParams") + def test_generate_until_calls_cache_hook(self, mock_params_cls, mock_gen_cls): + """cache_hook.add_partial must be called with the final generated text.""" + from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator + + evaluator = MagicMock(spec=LMEvalORTGenAIEvaluator) + evaluator.eos_token_ids = {2} + evaluator.max_length = 1024 + evaluator.model = MagicMock() + evaluator.tokenizer = MagicMock() + evaluator.tokenizer.encode.return_value = self._mock_encode([1]) + evaluator.tokenizer.decode.return_value = "hello" + + mock_generator = MagicMock() + mock_generator.is_done.side_effect = [False, False] + mock_generator.get_sequence.side_effect = [ + MagicMock(__getitem__=lambda s, k: 10), + MagicMock(__getitem__=lambda s, k: 2), # EOS + ] + mock_gen_cls.return_value = mock_generator + + request = self._make_mock_request("prompt", {"until": [], "max_gen_toks": 64}) + results = LMEvalORTGenAIEvaluator.generate_until(evaluator, [request]) + + assert results == ["hello"] + request.cache_hook.add_partial.assert_called_once_with("generate_until", request.args, "hello") From 973054ec235eba33a66c1e9120ec542a30cad7d1 Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Mon, 4 May 2026 10:29:29 -0700 Subject: [PATCH 12/13] Fix conftest.py stub polluting sys.modules when onnxruntime_genai is installed The previous check 'if onnxruntime_genai not in sys.modules' was wrong: on CI machines where the real package is installed but not yet imported, the check passes and injects the hollow stub. Later imports of onnxruntime_genai.models.builder then fail with 'is not a package'. Fix: attempt a real import and only inject the stub on ImportError. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- test/evaluator/conftest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/evaluator/conftest.py b/test/evaluator/conftest.py index 8617534db..548115082 100644 --- a/test/evaluator/conftest.py +++ b/test/evaluator/conftest.py @@ -15,7 +15,9 @@ def _ensure_ort_genai_stub(): - if "onnxruntime_genai" not in sys.modules: + try: + import onnxruntime_genai # noqa: F401 + except ImportError: stub = types.ModuleType("onnxruntime_genai") stub.Generator = MagicMock stub.GeneratorParams = MagicMock From 814faa164ff48c1580a5557b207f446e28007dfa Mon Sep 17 00:00:00 2001 From: Nat Kershaw Date: Mon, 4 May 2026 16:20:12 -0700 Subject: [PATCH 13/13] Fix conftest.py pylint W0611: use find_spec instead of unused import importlib.util.find_spec("onnxruntime_genai") checks whether the package is installed without importing it, avoiding the unused-import lint warning. This is also cleaner than try/except ImportError for an existence check. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- test/evaluator/conftest.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/evaluator/conftest.py b/test/evaluator/conftest.py index 548115082..50e575714 100644 --- a/test/evaluator/conftest.py +++ b/test/evaluator/conftest.py @@ -9,15 +9,14 @@ importable names. 
""" +import importlib.util import sys import types from unittest.mock import MagicMock def _ensure_ort_genai_stub(): - try: - import onnxruntime_genai # noqa: F401 - except ImportError: + if importlib.util.find_spec("onnxruntime_genai") is None: stub = types.ModuleType("onnxruntime_genai") stub.Generator = MagicMock stub.GeneratorParams = MagicMock