Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions eval/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM

# Headroom reserved when capping generation length so prompt + output stays under the model's context limit
SAFETY_BUFFER_TOKENS = 16


class BaseBenchmark(ABC):
"""Abstract base class for implementing LLM evaluation benchmarks."""
Expand Down Expand Up @@ -46,16 +49,53 @@ def _normalize_model_args(self, model: LM, instances: List[Instance]) -> List[In
_ = instance.args[1].pop("seed") if "seed" in instance.args[1] else None
if "max_new_tokens" in instance.args[1]:
max_new_tokens = instance.args[1].pop("max_new_tokens")

max_model_len = None
if isinstance(model, lm_eval_models.vllm_causallms.VLLM):
max_model_len = model.model.llm_engine.model_config.max_model_len
elif isinstance(model, lm_eval_models.huggingface.HFLM):
max_model_len = model.model.config.max_position_embeddings

if max_model_len is not None:
try:
# Get prompt from instance.args[0] (the templated string)
prompt = instance.args[0]
prompt_length = len(model.tokenizer.encode(prompt))

# Check if prompt itself exceeds model capacity
if prompt_length > max_model_len:
self.logger.warning(
f"Prompt length ({prompt_length}) exceeds model max length ({max_model_len}). "
f"Prompt will be truncated with no room for generation."
)

                                # Calculate max allowed generation tokens, reserving SAFETY_BUFFER_TOKENS as headroom
max_allowed = max_model_len - prompt_length - SAFETY_BUFFER_TOKENS
capped_max_new_tokens = min(max_new_tokens, max(1, max_allowed))

if capped_max_new_tokens < max_new_tokens:
self.logger.warning(
f"max_new_tokens ({max_new_tokens}) capped to {capped_max_new_tokens} "
f"(prompt: {prompt_length} tokens, model max: {max_model_len})"
)

max_new_tokens = capped_max_new_tokens
except Exception as e:
self.logger.warning(
f"Failed to calculate max_new_tokens, using original value: {e}"
)

if isinstance(model, lm_eval_models.openai_completions.OpenAIChatCompletion) or isinstance(
model, lm_eval_models.openai_completions.OpenAICompletionsAPI
):
instance.args[1]["max_tokens"] = max_new_tokens
if "4o" in model.model:
instance.args[1]["max_tokens"] = min(max_new_tokens, 16384)
elif isinstance(model, lm_eval_models.vllm_causallms.VLLM):
instance.args[1]["max_gen_toks"] = max_new_tokens
instance.args[1]["max_gen_toks"] = int(max_new_tokens)
else: # Huggingface
instance.args[1]["max_new_tokens"] = max_new_tokens
instance.args[1]["max_new_tokens"] = int(max_new_tokens)

return instances

def _prepare_messages(
Expand Down