Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions eval_protocol/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from pathlib import Path
from typing import Dict, Optional # Added Dict

import requests

logger = logging.getLogger(__name__)

# Default locations (used for tests and as fallback). Actual resolution is dynamic via _get_auth_ini_file().
Expand Down Expand Up @@ -218,3 +220,40 @@ def get_fireworks_api_base() -> str:
else:
logger.debug("FIREWORKS_API_BASE not set in environment, defaulting to %s.", api_base)
return api_base


def verify_api_key_and_get_account_id(
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> Optional[str]:
    """
    Validate a Fireworks API key against the ``verifyApiKey`` endpoint and
    return the account id reported in the response headers.

    This is deliberately best-effort: any failure (missing key, network
    error, non-200 status, absent/blank header) yields ``None`` instead of
    raising, so callers can fall back to other account-id resolution paths.

    Args:
        api_key: Explicit API key. When None, resolved via get_fireworks_api_key().
        api_base: Explicit API base URL. When None, resolved via get_fireworks_api_base().

    Returns:
        The stripped account id when verification succeeds and the
        ``x-fireworks-account-id`` header is present and non-blank;
        otherwise ``None``.
    """
    try:
        resolved_key = api_key or get_fireworks_api_key()
        if not resolved_key:
            # No key available from argument or config — nothing to verify.
            return None
        resolved_base = api_base or get_fireworks_api_base()
        url = f"{resolved_base.rstrip('/')}/verifyApiKey"
        headers = {"Authorization": f"Bearer {resolved_key}"}
        resp = requests.get(url, headers=headers, timeout=10)
        if resp.status_code != 200:
            logger.debug("verifyApiKey returned status %s", resp.status_code)
            return None
        # requests exposes response headers through a case-insensitive
        # mapping, so a single lowercase lookup covers every capitalization
        # of the header name — no second lookup needed.
        account_id = resp.headers.get("x-fireworks-account-id")
        if account_id and account_id.strip():
            logger.debug("Resolved FIREWORKS_ACCOUNT_ID via verifyApiKey: %s", account_id)
            return account_id.strip()
        return None
    except Exception as e:
        # Broad on purpose: this helper must never break the caller's flow;
        # failures are logged at debug level and reported as "not resolved".
        logger.debug("Failed to verify API key for account id resolution: %s", e)
        return None
81 changes: 27 additions & 54 deletions eval_protocol/cli_commands/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@
from typing import Any, Callable, Iterable, Optional

import pytest
from eval_protocol.auth import get_fireworks_account_id, get_fireworks_api_key
from eval_protocol.auth import (
get_fireworks_account_id,
get_fireworks_api_key,
get_fireworks_api_base,
verify_api_key_and_get_account_id,
)
from eval_protocol.platform_api import create_or_update_fireworks_secret

from eval_protocol.evaluation import create_evaluation
Expand Down Expand Up @@ -259,7 +264,7 @@ def _parse_entry(entry: str, cwd: str) -> tuple[str, str]:
raise ValueError("--entry must be in 'module::function', 'path::function', or 'module:function' format")


def _generate_ts_mode_code_from_entry(entry: str, cwd: str) -> tuple[str, str, str, str]:
def _resolve_entry_to_qual_and_source(entry: str, cwd: str) -> tuple[str, str]:
target, func = _parse_entry(entry, cwd)

# Check if target looks like a file path
Expand Down Expand Up @@ -293,47 +298,12 @@ def _generate_ts_mode_code_from_entry(entry: str, cwd: str) -> tuple[str, str, s
raise ValueError(f"Function '{func}' not found in module '{module_name}'")

qualname = f"{module_name}.{func}"
code, file_name = _generate_ts_mode_code(
DiscoveredTest(
module_path=module_name,
module_name=module_name,
qualname=qualname,
file_path=getattr(module, "__file__", module_name),
lineno=None,
has_parametrize=False,
param_count=0,
nodeids=[],
)
)
return code, file_name, qualname, os.path.abspath(source_file_path) if source_file_path else ""
return qualname, os.path.abspath(source_file_path) if source_file_path else ""


def _generate_ts_mode_code(test: DiscoveredTest) -> tuple[str, str]:
# Generate a minimal main.py that imports the test module and calls the function
module = test.module_name
func = test.qualname.split(".")[-1]
code = f"""
from typing import Any, Dict, List, Optional, Union

from eval_protocol.models import EvaluationRow, Message
from {module} import {func} as _ep_test

def evaluate(messages: List[Dict[str, Any]], ground_truth: Optional[Union[str, List[Dict[str, Any]]]] = None, tools=None, **kwargs):
row = EvaluationRow(messages=[Message(**m) for m in messages], ground_truth=ground_truth)
result = _ep_test(row) # Supports sync/async via decorator's dual-mode
if hasattr(result, "__await__"):
import asyncio
result = asyncio.get_event_loop().run_until_complete(result)
if result.evaluation_result is None:
return {{"score": 0.0, "reason": "No evaluation_result set"}}
out = {{
"score": float(result.evaluation_result.score or 0.0),
"reason": result.evaluation_result.reason,
"metrics": {{k: (v.model_dump() if hasattr(v, "model_dump") else v) for k, v in (result.evaluation_result.metrics or {{}}).items()}},
}}
return out
"""
return (code, "main.py")
# Deprecated: we no longer generate a shim; keep stub for import compatibility
return ("", "main.py")


def _normalize_evaluator_id(evaluator_id: str) -> str:
Expand Down Expand Up @@ -522,10 +492,10 @@ def upload_command(args: argparse.Namespace) -> int:
entries_arg = getattr(args, "entry", None)
if entries_arg:
entries = [e.strip() for e in re.split(r"[,\s]+", entries_arg) if e.strip()]
selected_specs: list[tuple[str, str, str, str]] = []
selected_specs: list[tuple[str, str]] = []
for e in entries:
code, file_name, qualname, resolved_path = _generate_ts_mode_code_from_entry(e, root)
selected_specs.append((code, file_name, qualname, resolved_path))
qualname, resolved_path = _resolve_entry_to_qual_and_source(e, root)
selected_specs.append((qualname, resolved_path))
else:
print("Scanning for evaluation tests...")
tests = _discover_tests(root)
Expand All @@ -545,11 +515,7 @@ def upload_command(args: argparse.Namespace) -> int:
print(" handles all parameter combinations. The evaluator will work with")
print(" the same logic regardless of which model/parameters are used.")

selected_specs = []
for t in selected_tests:
code, file_name = _generate_ts_mode_code(t)
# Store test info for better ID generation
selected_specs.append((code, file_name, t.qualname, t.file_path))
selected_specs = [(t.qualname, t.file_path) for t in selected_tests]

base_id = getattr(args, "id", None)
display_name = getattr(args, "display_name", None)
Expand All @@ -560,6 +526,14 @@ def upload_command(args: argparse.Namespace) -> int:
try:
fw_account_id = get_fireworks_account_id()
fw_api_key_value = get_fireworks_api_key()
if not fw_account_id and fw_api_key_value:
# Attempt to verify and resolve account id from server headers
resolved = verify_api_key_and_get_account_id(api_key=fw_api_key_value, api_base=get_fireworks_api_base())
if resolved:
fw_account_id = resolved
# Propagate to environment so downstream calls use it if needed
os.environ["FIREWORKS_ACCOUNT_ID"] = fw_account_id
print(f"Resolved FIREWORKS_ACCOUNT_ID via API verification: {fw_account_id}")
if fw_account_id and fw_api_key_value:
print("Ensuring FIREWORKS_API_KEY is registered as a secret on Fireworks for rollout...")
if create_or_update_fireworks_secret(
Expand All @@ -579,8 +553,7 @@ def upload_command(args: argparse.Namespace) -> int:
print(f"Warning: Skipped Fireworks secret registration due to error: {e}")

exit_code = 0
for i, (code, file_name, qualname, source_file_path) in enumerate(selected_specs):
# Use ts_mode to upload evaluator
for i, (qualname, source_file_path) in enumerate(selected_specs):
# Generate a short default ID from just the test function name
if base_id:
evaluator_id = base_id
Expand Down Expand Up @@ -618,12 +591,12 @@ def upload_command(args: argparse.Namespace) -> int:

print(f"\nUploading evaluator '{evaluator_id}' for {qualname.split('.')[-1]}...")
try:
# Always treat as a single evaluator (single-metric) even if folder has helper modules
test_dir = os.path.dirname(source_file_path) if source_file_path else root
metric_name = os.path.basename(test_dir) or "metric"
result = create_evaluation(
evaluator_id=evaluator_id,
python_code_to_evaluate=code,
python_file_name_for_code=file_name,
criterion_name_for_code=qualname,
criterion_description_for_code=description or f"Evaluator for {qualname}",
metric_folders=[f"{metric_name}={test_dir}"],
display_name=display_name or evaluator_id,
description=description or f"Evaluator for {qualname}",
force=force,
Expand Down
Loading
Loading