Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tests/test_eval_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self, api_key=None, base_url=None):
model="gpt-4.1-mini",
api_key_var="OPENAI_API_KEY",
api_base_url="https://api.openai.com/v1",
extra_headers={},
num_examples=1,
rollouts_per_example=1,
max_concurrent=1,
Expand Down Expand Up @@ -104,6 +105,7 @@ def __init__(self, api_key=None, base_url=None):
model="gpt-4.1-mini",
api_key_var="OPENAI_API_KEY",
api_base_url="https://api.openai.com/v1",
extra_headers={},
num_examples=1,
rollouts_per_example=1,
max_concurrent=1,
Expand Down
22 changes: 21 additions & 1 deletion verifiers/scripts/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import uuid
from datetime import datetime
from pathlib import Path
from typing import cast
from typing import cast, Dict

import numpy as np
from datasets import Dataset
Expand Down Expand Up @@ -39,6 +39,7 @@ def eval_environment(
save_dataset: bool,
save_to_hf_hub: bool,
hf_hub_dataset_name: str,
extra_headers: Dict[str, str],
):
logger.setLevel("DEBUG" if verbose else "INFO")
try:
Expand Down Expand Up @@ -92,6 +93,7 @@ def eval_environment(
max_connections=28000, # Number of available ports
max_keepalive_connections=28000, # Number of available ports
max_retries=10, # 10 retries (w/ exponential backoffs)
extra_headers=extra_headers,
)
logger.debug(f"Initialized OpenAI client with base_url: {api_base_url}")
vf_env = vf.load_environment(env_id=env, **env_args)
Expand Down Expand Up @@ -263,6 +265,12 @@ def main():
default="https://api.openai.com/v1",
help="Base URL for API",
)
parser.add_argument(
"--header",
action="append",
default=None,
help="Extra HTTP header to pass to inference API. 'Name: Value'. Repeatable.",
)
parser.add_argument(
"--num-examples",
"-n",
Expand Down Expand Up @@ -330,6 +338,17 @@ def main():
)
args = parser.parse_args()

# Build headers from repeated --header flags
merged_headers: Dict[str, str] = {}
for h in args.header or []:
if ":" not in h:
raise ValueError(f"--header must be 'Name: Value', got: {h!r}")
k, v = h.split(":", 1)
k, v = k.strip(), v.strip()
if not k:
raise ValueError("--header name cannot be empty")
merged_headers[k] = v

eval_environment(
env=args.env,
env_args=args.env_args,
Expand All @@ -348,6 +367,7 @@ def main():
save_dataset=args.save_dataset,
save_to_hf_hub=args.save_to_hf_hub,
hf_hub_dataset_name=args.hf_hub_dataset_name,
extra_headers=merged_headers,
)


Expand Down
8 changes: 7 additions & 1 deletion verifiers/utils/client_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import Dict

import httpx
from httpx import AsyncClient
Expand All @@ -12,6 +13,7 @@ def setup_client(
max_connections: int = 1000, # OAI default, larger value recommended for evals
max_keepalive_connections: int = 100, # OAI default, larger value recommended for evals
max_retries: int = 2, # OAI default, larger value recommended for evals
extra_headers: Dict[str, str] | None = None,
) -> AsyncOpenAI:
"""
A helper function to setup an AsyncOpenAI client.
Expand All @@ -24,7 +26,11 @@ def setup_client(
)

# Setup client
http_client = AsyncClient(limits=limits, timeout=http_timeout)
http_client = AsyncClient(
limits=limits,
timeout=http_timeout,
headers=extra_headers,
)
client = AsyncOpenAI(
base_url=api_base_url,
api_key=os.getenv(api_key_var, "EMPTY"),
Expand Down
158 changes: 89 additions & 69 deletions verifiers/utils/env_utils.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,106 @@
from __future__ import annotations

import importlib
import inspect
import logging
from importlib.metadata import entry_points
from typing import Callable

from verifiers.envs.environment import Environment

LOGGER = logging.getLogger("verifiers.utils.env_utils")


def _call_loader(
    func: Callable[..., Environment], env_id: str, **env_args
) -> Environment:
    """Invoke an environment loader callable and log the effective arguments.

    Logs the explicitly provided ``env_args`` and, separately, the loader's
    signature defaults that will be used for any parameter not overridden.

    Args:
        func: The loader callable (e.g. a ``load_environment`` function).
        env_id: Identifier of the environment, used only for log messages.
        **env_args: Keyword arguments forwarded verbatim to ``func``.

    Returns:
        The ``Environment`` instance produced by ``func``.
    """
    sig = inspect.signature(func)

    if env_args:
        LOGGER.info(
            "Using provided args: "
            + ", ".join(f"{k}={v!r}" for k, v in env_args.items())
        )

    # Report which signature defaults will apply (params the caller did not set).
    # Use the public Parameter.empty sentinel rather than the private inspect._empty.
    defaults = []
    for name, p in sig.parameters.items():
        if name not in env_args and p.default is not inspect.Parameter.empty:
            defaults.append(f"{name}={p.default!r}")
    if defaults:
        LOGGER.info("Using default args: " + ", ".join(defaults))

    env = func(**env_args)
    LOGGER.info(f"Successfully loaded environment '{env_id}'")
    return env


def _load_from_target_spec(target: str, env_id: str, **env_args) -> Environment:
    """Load an environment from an explicit ``module:attr`` target string.

    Raises:
        AttributeError: If *target* is not of the form ``module:attr``.
        TypeError: If the resolved attribute is not callable.
    """
    module_name, colon, attr_name = target.partition(":")
    if not (colon and attr_name):
        raise AttributeError(f"Invalid target spec '{target}'. Expected 'module:attr'.")
    loader = getattr(importlib.import_module(module_name), attr_name)
    if not callable(loader):
        raise TypeError(f"Target '{target}' is not callable")
    return _call_loader(loader, env_id, **env_args)


def _load_via_entry_point_exact(env_id: str, **env_args) -> Environment | None:
    """Resolve *env_id* by exact name in the 'verifiers' entry-point group.

    Returns ``None`` when no entry point with that exact name exists. No
    aliasing, prefix matching, or name splitting is performed.

    Raises:
        RuntimeError: If more than one entry point carries the same name.
        TypeError: If the matched entry point does not load a callable.
    """
    candidates = [
        ep for ep in entry_points(group="verifiers") if ep.name == env_id
    ]
    if not candidates:
        return None
    if len(candidates) > 1:
        details = ", ".join(ep.value for ep in candidates)
        raise RuntimeError(
            f"Multiple 'verifiers' entry points named '{env_id}' found: {details}"
        )
    loader = candidates[0].load()
    if not callable(loader):
        raise TypeError(
            f"Entry point '{env_id}' did not load a callable; got {type(loader)!r}"
        )
    return _call_loader(loader, env_id, **env_args)


def load_environment(env_id: str, **env_args) -> Environment:
logger = logging.getLogger("verifiers.utils.env_utils")
logger.info(f"Loading environment: {env_id}")
LOGGER.info(f"Loading environment: {env_id}")

module_name = env_id.replace("-", "_")
try:
module = importlib.import_module(module_name)
# 1) Explicit module target: "pkg.mod:callable"
if ":" in env_id:
try:
return _load_from_target_spec(env_id, env_id, **env_args)
except Exception as e:
LOGGER.error(f"Failed to load environment {env_id} via target spec: {e}")
raise RuntimeError(f"Failed to load environment '{env_id}': {e}") from e

if not hasattr(module, "load_environment"):
raise AttributeError(
f"Module '{module_name}' does not have a 'load_environment' function. "
f"This usually means there's a package name collision. Please either:\n"
f"1. Rename your environment (e.g. suffix with '-env')\n"
f"2. Remove unneeded files with the same name\n"
f"3. Check that you've installed the correct environment package"
)

env_load_func: Callable[..., Environment] = getattr(module, "load_environment")
sig = inspect.signature(env_load_func)
defaults_info = []
for param_name, param in sig.parameters.items():
if param.default != inspect.Parameter.empty:
if isinstance(param.default, (dict, list)):
defaults_info.append(f"{param_name}={param.default}")
elif isinstance(param.default, str):
defaults_info.append(f"{param_name}='{param.default}'")
else:
defaults_info.append(f"{param_name}={param.default}")
else:
defaults_info.append(f"{param_name}=<required>")

if defaults_info:
logger.debug(f"Environment defaults: {', '.join(defaults_info)}")

if env_args:
provided_params = set(env_args.keys())
else:
provided_params = set()

all_params = set(sig.parameters.keys())
default_params = all_params - provided_params

if provided_params:
provided_values = []
for param_name in provided_params:
provided_values.append(f"{param_name}={env_args[param_name]}")
logger.info(f"Using provided args: {', '.join(provided_values)}")

if default_params:
default_values = []
for param_name in default_params:
param = sig.parameters[param_name]
if param.default != inspect.Parameter.empty:
if isinstance(param.default, str):
default_values.append(f"{param_name}='{param.default}'")
else:
default_values.append(f"{param_name}={param.default}")
if default_values:
logger.info(f"Using default args: {', '.join(default_values)}")

env_instance: Environment = env_load_func(**env_args)

logger.info(f"Successfully loaded environment '{env_id}'")

return env_instance
# 2) Prefer entry points (exact match only)
try:
ep_env = _load_via_entry_point_exact(env_id, **env_args)
if ep_env is not None:
return ep_env
except Exception as e:
LOGGER.error(f"Failed to load environment {env_id} via entry point: {e}")
raise RuntimeError(f"Failed to load environment '{env_id}': {e}") from e

# 3) Back-compat fallback: import by module name (slug or namespaced ID's tail)
module_name = env_id.split("/")[-1].replace("-", "_")
try:
module = importlib.import_module(module_name)
except ImportError as e:
logger.error(
f"Failed to import environment module {module_name} for env_id {env_id}: {str(e)}"
LOGGER.error(
f"Failed to import environment module {module_name} for env_id {env_id}: {e}"
)
raise ValueError(
f"Could not import '{env_id}' environment. Ensure the package for the '{env_id}' environment is installed."
f"Could not import '{env_id}'. Install a package that exposes a matching "
f"[project.entry-points.verifiers] = \"{env_id}\" entry or provide 'module:attr'."
) from e
except Exception as e:
logger.error(
f"Failed to load environment {env_id} with args {env_args}: {str(e)}"

if not hasattr(module, "load_environment"):
raise AttributeError(
f"Module '{module_name}' has no 'load_environment'. "
f"Prefer registering an entry point named '{env_id}' under the 'verifiers' group."
)
raise RuntimeError(f"Failed to load environment '{env_id}': {str(e)}") from e

return _call_loader(getattr(module, "load_environment"), env_id, **env_args)
Loading