From a9a4fa04373021e1012d99778278f9198b658d2c Mon Sep 17 00:00:00 2001 From: Ali Elganzory Date: Sat, 3 Jan 2026 14:40:59 +0100 Subject: [PATCH 1/3] fix: cap max_gen_toks to prevent vLLM crash when exceeding context window --- eval/task.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/eval/task.py b/eval/task.py index 70962115..1c90ee20 100644 --- a/eval/task.py +++ b/eval/task.py @@ -53,7 +53,26 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[In if "4o" in model.model: instance.args[1]["max_tokens"] = min(max_new_tokens, 16384) elif isinstance(model, lm_eval_models.vllm_causallms.VLLM): - instance.args[1]["max_gen_toks"] = max_new_tokens + # Get prompt from instance.args[0] (the templated string) + prompt = instance.args[0] + prompt_length = len(model.tokenizer.encode(prompt)) + + # Get max model length from vLLM engine + max_model_len = model.model.llm_engine.model_config.max_model_len + + # Calculate max allowed generation tokens (16 token safety buffer) + max_allowed = max_model_len - prompt_length - 16 + + # Cap to available space + capped_max_new_tokens = min(max_new_tokens, max(1, max_allowed)) + + if capped_max_new_tokens < max_new_tokens: + self.logger.warning( + f"max_new_tokens ({max_new_tokens}) capped to {capped_max_new_tokens} " + f"(prompt: {prompt_length} tokens, model max: {max_model_len})" + ) + + instance.args[1]["max_gen_toks"] = capped_max_new_tokens else: # Huggingface instance.args[1]["max_new_tokens"] = max_new_tokens return instances From 977d755841e5ab515db2321c800b535374fdd42d Mon Sep 17 00:00:00 2001 From: Ali Elganzory Date: Sun, 4 Jan 2026 11:27:41 +0100 Subject: [PATCH 2/3] Address PR feedback: add try-catch, prompt_length check, and constant --- eval/task.py | 50 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/eval/task.py b/eval/task.py index 1c90ee20..204c2a12 100644 --- a/eval/task.py +++ b/eval/task.py @@ -15,6 +15,9 @@ from lm_eval.api.instance import Instance from lm_eval.api.model import LM +# Safety buffer for vLLM max_gen_toks calculation +VLLM_SAFETY_BUFFER_TOKENS = 16 + class BaseBenchmark(ABC): """Abstract base class for implementing LLM evaluation benchmarks.""" @@ -53,26 +56,37 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[In if "4o" in model.model: instance.args[1]["max_tokens"] = min(max_new_tokens, 16384) elif isinstance(model, lm_eval_models.vllm_causallms.VLLM): - # Get prompt from instance.args[0] (the templated string) - prompt = instance.args[0] - prompt_length = len(model.tokenizer.encode(prompt)) - - # Get max model length from vLLM engine - max_model_len = model.model.llm_engine.model_config.max_model_len - - # Calculate max allowed generation tokens (16 token safety buffer) - max_allowed = max_model_len - prompt_length - 16 - - # Cap to available space - capped_max_new_tokens = min(max_new_tokens, max(1, max_allowed)) - - if capped_max_new_tokens < max_new_tokens: + try: + # Get prompt from instance.args[0] (the templated string) + prompt = instance.args[0] + prompt_length = len(model.tokenizer.encode(prompt)) + + # Get max model length from vLLM engine + max_model_len = model.model.llm_engine.model_config.max_model_len + + # Check if prompt itself exceeds model capacity + if prompt_length > max_model_len: + self.logger.warning( + f"Prompt length ({prompt_length}) exceeds model max length ({max_model_len}). " + f"Prompt will be truncated with no room for generation." + ) + + # Calculate max allowed generation tokens (16 token safety buffer) + max_allowed = max_model_len - prompt_length - VLLM_SAFETY_BUFFER_TOKENS + capped_max_new_tokens = min(max_new_tokens, max(1, max_allowed)) + + if capped_max_new_tokens < max_new_tokens: + self.logger.warning( + f"max_new_tokens ({max_new_tokens}) capped to {capped_max_new_tokens} " + f"(prompt: {prompt_length} tokens, model max: {max_model_len})" + ) + + instance.args[1]["max_gen_toks"] = capped_max_new_tokens + except Exception as e: self.logger.warning( - f"max_new_tokens ({max_new_tokens}) capped to {capped_max_new_tokens} " - f"(prompt: {prompt_length} tokens, model max: {max_model_len})" + f"Failed to calculate max_gen_toks for vLLM, using original value: {e}" ) - - instance.args[1]["max_gen_toks"] = capped_max_new_tokens + instance.args[1]["max_gen_toks"] = max_new_tokens else: # Huggingface instance.args[1]["max_new_tokens"] = max_new_tokens return instances From e93d565bbf9ea128d6200896efe86f94c80a9ea9 Mon Sep 17 00:00:00 2001 From: Ali Elganzory Date: Fri, 6 Feb 2026 15:10:43 +0100 Subject: [PATCH 3/3] fix: cap max_new_tokens to avoid degraded/undefined behavior when exceeding context window - HuggingFace's transformers does NOT fail when the `context_length < prompt_length + max_new_tokens`; the behavior is rather undefined or degraded. - Cap the max number of generated tokens for both HF and vLLM. --- eval/task.py | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/eval/task.py b/eval/task.py index 204c2a12..10f949f6 100644 --- a/eval/task.py +++ b/eval/task.py @@ -15,8 +15,8 @@ from lm_eval.api.instance import Instance from lm_eval.api.model import LM -# Safety buffer for vLLM max_gen_toks calculation -VLLM_SAFETY_BUFFER_TOKENS = 16 +# Safety buffer for max_gen_toks calculation +SAFETY_BUFFER_TOKENS = 16 class BaseBenchmark(ABC): @@ -49,21 +49,19 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[In _ = instance.args[1].pop("seed") if "seed" in instance.args[1] else None if "max_new_tokens" in instance.args[1]: max_new_tokens = instance.args[1].pop("max_new_tokens") - if isinstance(model, lm_eval_models.openai_completions.OpenAIChatCompletion) or isinstance( - model, lm_eval_models.openai_completions.OpenAICompletionsAPI - ): - instance.args[1]["max_tokens"] = max_new_tokens - if "4o" in model.model: - instance.args[1]["max_tokens"] = min(max_new_tokens, 16384) - elif isinstance(model, lm_eval_models.vllm_causallms.VLLM): + + max_model_len = None + if isinstance(model, lm_eval_models.vllm_causallms.VLLM): + max_model_len = model.model.llm_engine.model_config.max_model_len + elif isinstance(model, lm_eval_models.huggingface.HFLM): + max_model_len = model.model.config.max_position_embeddings + + if max_model_len is not None: try: # Get prompt from instance.args[0] (the templated string) prompt = instance.args[0] prompt_length = len(model.tokenizer.encode(prompt)) - # Get max model length from vLLM engine - max_model_len = model.model.llm_engine.model_config.max_model_len - # Check if prompt itself exceeds model capacity if prompt_length > max_model_len: self.logger.warning( @@ -72,7 +70,7 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[In ) # Calculate max allowed generation tokens (16 token safety buffer) - max_allowed = max_model_len - prompt_length - VLLM_SAFETY_BUFFER_TOKENS + max_allowed = max_model_len - prompt_length - SAFETY_BUFFER_TOKENS capped_max_new_tokens = min(max_new_tokens, max(1, max_allowed)) if capped_max_new_tokens < max_new_tokens: @@ -81,14 +79,23 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[In f"(prompt: {prompt_length} tokens, model max: {max_model_len})" ) - instance.args[1]["max_gen_toks"] = capped_max_new_tokens + max_new_tokens = capped_max_new_tokens except Exception as e: self.logger.warning( - f"Failed to calculate max_gen_toks for vLLM, using original value: {e}" + f"Failed to calculate max_new_tokens, using original value: {e}" ) - instance.args[1]["max_gen_toks"] = max_new_tokens + + if isinstance(model, lm_eval_models.openai_completions.OpenAIChatCompletion) or isinstance( + model, lm_eval_models.openai_completions.OpenAICompletionsAPI + ): + instance.args[1]["max_tokens"] = max_new_tokens + if "4o" in model.model: + instance.args[1]["max_tokens"] = min(max_new_tokens, 16384) + elif isinstance(model, lm_eval_models.vllm_causallms.VLLM): + instance.args[1]["max_gen_toks"] = int(max_new_tokens) else: # Huggingface - instance.args[1]["max_new_tokens"] = max_new_tokens + instance.args[1]["max_new_tokens"] = int(max_new_tokens) + return instances def _prepare_messages(