diff --git a/eval/task.py b/eval/task.py
index 70962115..10f949f6 100644
--- a/eval/task.py
+++ b/eval/task.py
@@ -15,6 +15,9 @@
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 
+# Safety buffer for generation-length (max_gen_toks / max_new_tokens) calculation
+SAFETY_BUFFER_TOKENS = 16
+
 
 class BaseBenchmark(ABC):
     """Abstract base class for implementing LLM evaluation benchmarks."""
@@ -46,6 +49,42 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[In
             _ = instance.args[1].pop("seed") if "seed" in instance.args[1] else None
             if "max_new_tokens" in instance.args[1]:
                 max_new_tokens = instance.args[1].pop("max_new_tokens")
+
+                max_model_len = None
+                if isinstance(model, lm_eval_models.vllm_causallms.VLLM):
+                    max_model_len = model.model.llm_engine.model_config.max_model_len
+                elif isinstance(model, lm_eval_models.huggingface.HFLM):
+                    max_model_len = model.model.config.max_position_embeddings
+
+                if max_model_len is not None:
+                    try:
+                        # Get prompt from instance.args[0] (the templated string)
+                        prompt = instance.args[0]
+                        prompt_length = len(model.tokenizer.encode(prompt))
+
+                        # Check if prompt itself exceeds model capacity
+                        if prompt_length > max_model_len:
+                            self.logger.warning(
+                                f"Prompt length ({prompt_length}) exceeds model max length ({max_model_len}). "
+                                f"Prompt will be truncated with no room for generation."
+                            )
+
+                        # Calculate max allowed generation tokens (SAFETY_BUFFER_TOKENS safety buffer)
+                        max_allowed = max_model_len - prompt_length - SAFETY_BUFFER_TOKENS
+                        capped_max_new_tokens = min(max_new_tokens, max(1, max_allowed))
+
+                        if capped_max_new_tokens < max_new_tokens:
+                            self.logger.warning(
+                                f"max_new_tokens ({max_new_tokens}) capped to {capped_max_new_tokens} "
+                                f"(prompt: {prompt_length} tokens, model max: {max_model_len})"
+                            )
+
+                        max_new_tokens = capped_max_new_tokens
+                    except Exception as e:
+                        self.logger.warning(
+                            f"Failed to calculate max_new_tokens, using original value: {e}"
+                        )
+
             if isinstance(model, lm_eval_models.openai_completions.OpenAIChatCompletion) or isinstance(
                 model, lm_eval_models.openai_completions.OpenAICompletionsAPI
             ):
@@ -53,9 +92,10 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[In
                 if "4o" in model.model:
                     instance.args[1]["max_tokens"] = min(max_new_tokens, 16384)
                 elif isinstance(model, lm_eval_models.vllm_causallms.VLLM):
-                    instance.args[1]["max_gen_toks"] = max_new_tokens
+                    instance.args[1]["max_gen_toks"] = int(max_new_tokens)
                 else:  # Huggingface
-                    instance.args[1]["max_new_tokens"] = max_new_tokens
+                    instance.args[1]["max_new_tokens"] = int(max_new_tokens)
+
         return instances
 
     def _prepare_messages(