diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py
index c4a158533..95343f7c2 100644
--- a/olive/evaluator/lmeval_ort.py
+++ b/olive/evaluator/lmeval_ort.py
@@ -498,6 +498,8 @@ def __init__(
         self.config.set_provider_option(ep, key, value)
         self.model = og.Model(self.config)
         self.tokenizer = og.Tokenizer(self.model)
+        self._pretrained = str(pretrained)
+        self._hf_tokenizer: AutoTokenizer | None = None
 
         # consider adding auto batch sizes
         self.batch_size = int(batch_size)
@@ -521,6 +523,20 @@ def __init__(
         self.device = device
         self._returns_full_logits = self._detect_full_logits()
 
+    @property
+    def tokenizer_name(self) -> str:
+        return self._pretrained.replace("\\", "__").replace("/", "__")
+
+    def apply_chat_template(self, chat_history: list[dict], add_generation_prompt: bool = True) -> str:
+        if self._hf_tokenizer is None:
+            self._hf_tokenizer = AutoTokenizer.from_pretrained(self._pretrained)
+        return self._hf_tokenizer.apply_chat_template(
+            chat_history,
+            tokenize=False,
+            add_generation_prompt=add_generation_prompt,
+            continue_final_message=not add_generation_prompt,
+        )
+
     def _detect_full_logits(self) -> bool:
         """Check if the model returns logits for all input positions or only the last."""
         try:
diff --git a/test/evaluator/test_olive_evaluator.py b/test/evaluator/test_olive_evaluator.py
index e295d069a..251ab619c 100644
--- a/test/evaluator/test_olive_evaluator.py
+++ b/test/evaluator/test_olive_evaluator.py
@@ -510,3 +510,47 @@ def test_lm_evaluator_dispatches_to_requested_backend(
 
     evaluator.evaluate(model, metrics=[], device=Device.CPU, execution_providers=["CPUExecutionProvider"])
     get_model_mock.assert_called_once_with(model_class)
+
+
+class TestLMEvalORTGenAIChatTemplate:
+    def _bare_instance(self, pretrained: str):
+        # pylint: disable=protected-access
+        from olive.evaluator.lmeval_ort import LMEvalORTGenAIEvaluator
+
+        instance = object.__new__(LMEvalORTGenAIEvaluator)
+        instance._pretrained = pretrained
+        instance._hf_tokenizer = None
+        return instance
+
+    @pytest.mark.parametrize(
+        ("pretrained", "expected"),
+        [
+            ("/models/lfm2-350m", "__models__lfm2-350m"),
+            ("relative/path/model", "relative__path__model"),
+            ("C:\\models\\lfm2-350m", "C:__models__lfm2-350m"),
+        ],
+    )
+    def test_tokenizer_name_normalizes_separators(self, pretrained, expected):
+        assert self._bare_instance(pretrained).tokenizer_name == expected
+
+    @patch("olive.evaluator.lmeval_ort.AutoTokenizer")
+    def test_apply_chat_template_lazy_loads_hf_tokenizer(self, auto_tokenizer_mock):
+        chat_history = [{"role": "user", "content": "hello"}]
+        mock_tokenizer = MagicMock()
+        mock_tokenizer.apply_chat_template.return_value = "rendered prompt"
+        auto_tokenizer_mock.from_pretrained.return_value = mock_tokenizer
+
+        instance = self._bare_instance("/models/lfm2")
+
+        auto_tokenizer_mock.from_pretrained.assert_not_called()
+        assert instance.apply_chat_template(chat_history) == "rendered prompt"
+        auto_tokenizer_mock.from_pretrained.assert_called_once_with("/models/lfm2")
+
+        instance.apply_chat_template(chat_history, add_generation_prompt=False)
+        auto_tokenizer_mock.from_pretrained.assert_called_once()
+        mock_tokenizer.apply_chat_template.assert_called_with(
+            chat_history,
+            tokenize=False,
+            add_generation_prompt=False,
+            continue_final_message=True,
+        )
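
Reviewer note (not part of the patch): a minimal usage sketch of the two new members, assuming an already-constructed evaluator instance; the model path and chat history below are illustrative values, not fixtures from this repo.

    # tokenizer_name flattens both POSIX and Windows path separators into
    # "__" so the pretrained path can serve as a filename-safe identifier,
    # e.g. for lm-evaluation-harness results and cache keys:
    #   "/models/lfm2-350m"     -> "__models__lfm2-350m"
    #   "C:\\models\\lfm2-350m" -> "C:__models__lfm2-350m"
    print(evaluator.tokenizer_name)

    # apply_chat_template loads the Hugging Face tokenizer lazily on first
    # call, then renders the conversation to a prompt string. Passing
    # add_generation_prompt=False flips continue_final_message to True,
    # mirroring the convention used by lm-evaluation-harness's HFLM wrapper.
    prompt = evaluator.apply_chat_template(
        [{"role": "user", "content": "hello"}],
        add_generation_prompt=True,
    )

The lazy load keeps the HF tokenizer out of the hot path: og.Tokenizer already handles plain encode/decode, so the heavier AutoTokenizer is only instantiated when a task actually requests chat templating.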