Skip to content

Commit a3baa0a

Browse files
authored
proper end to end upload (#269)
* proper end to end upload * fix multi metrics issue * keep multi metric and rollout
1 parent 4d01e1d commit a3baa0a

File tree

3 files changed

+191
-94
lines changed

3 files changed

+191
-94
lines changed

eval_protocol/auth.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from pathlib import Path
55
from typing import Dict, Optional # Added Dict
66

7+
import requests
8+
79
logger = logging.getLogger(__name__)
810

911
# Default locations (used for tests and as fallback). Actual resolution is dynamic via _get_auth_ini_file().
@@ -218,3 +220,40 @@ def get_fireworks_api_base() -> str:
218220
else:
219221
logger.debug("FIREWORKS_API_BASE not set in environment, defaulting to %s.", api_base)
220222
return api_base
223+
224+
225+
def verify_api_key_and_get_account_id(
226+
api_key: Optional[str] = None,
227+
api_base: Optional[str] = None,
228+
) -> Optional[str]:
229+
"""
230+
Calls the Fireworks API verify endpoint to validate the API key and returns the
231+
account id from response headers when available.
232+
233+
Args:
234+
api_key: Optional explicit API key. When None, resolves via get_fireworks_api_key().
235+
api_base: Optional explicit API base. When None, resolves via get_fireworks_api_base().
236+
237+
Returns:
238+
The resolved account id if verification succeeds and the header is present; otherwise None.
239+
"""
240+
try:
241+
resolved_key = api_key or get_fireworks_api_key()
242+
if not resolved_key:
243+
return None
244+
resolved_base = api_base or get_fireworks_api_base()
245+
url = f"{resolved_base.rstrip('/')}/verifyApiKey"
246+
headers = {"Authorization": f"Bearer {resolved_key}"}
247+
resp = requests.get(url, headers=headers, timeout=10)
248+
if resp.status_code != 200:
249+
logger.debug("verifyApiKey returned status %s", resp.status_code)
250+
return None
251+
# Header keys could vary in case; requests provides case-insensitive dict
252+
account_id = resp.headers.get("x-fireworks-account-id") or resp.headers.get("X-Fireworks-Account-Id")
253+
if account_id and account_id.strip():
254+
logger.debug("Resolved FIREWORKS_ACCOUNT_ID via verifyApiKey: %s", account_id)
255+
return account_id.strip()
256+
return None
257+
except Exception as e:
258+
logger.debug("Failed to verify API key for account id resolution: %s", e)
259+
return None

eval_protocol/cli_commands/upload.py

Lines changed: 27 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
from typing import Any, Callable, Iterable, Optional
1313

1414
import pytest
15-
from eval_protocol.auth import get_fireworks_account_id, get_fireworks_api_key
15+
from eval_protocol.auth import (
16+
get_fireworks_account_id,
17+
get_fireworks_api_key,
18+
get_fireworks_api_base,
19+
verify_api_key_and_get_account_id,
20+
)
1621
from eval_protocol.platform_api import create_or_update_fireworks_secret
1722

1823
from eval_protocol.evaluation import create_evaluation
@@ -259,7 +264,7 @@ def _parse_entry(entry: str, cwd: str) -> tuple[str, str]:
259264
raise ValueError("--entry must be in 'module::function', 'path::function', or 'module:function' format")
260265

261266

262-
def _generate_ts_mode_code_from_entry(entry: str, cwd: str) -> tuple[str, str, str, str]:
267+
def _resolve_entry_to_qual_and_source(entry: str, cwd: str) -> tuple[str, str]:
263268
target, func = _parse_entry(entry, cwd)
264269

265270
# Check if target looks like a file path
@@ -293,47 +298,12 @@ def _generate_ts_mode_code_from_entry(entry: str, cwd: str) -> tuple[str, str, s
293298
raise ValueError(f"Function '{func}' not found in module '{module_name}'")
294299

295300
qualname = f"{module_name}.{func}"
296-
code, file_name = _generate_ts_mode_code(
297-
DiscoveredTest(
298-
module_path=module_name,
299-
module_name=module_name,
300-
qualname=qualname,
301-
file_path=getattr(module, "__file__", module_name),
302-
lineno=None,
303-
has_parametrize=False,
304-
param_count=0,
305-
nodeids=[],
306-
)
307-
)
308-
return code, file_name, qualname, os.path.abspath(source_file_path) if source_file_path else ""
301+
return qualname, os.path.abspath(source_file_path) if source_file_path else ""
309302

310303

311304
def _generate_ts_mode_code(test: DiscoveredTest) -> tuple[str, str]:
312-
# Generate a minimal main.py that imports the test module and calls the function
313-
module = test.module_name
314-
func = test.qualname.split(".")[-1]
315-
code = f"""
316-
from typing import Any, Dict, List, Optional, Union
317-
318-
from eval_protocol.models import EvaluationRow, Message
319-
from {module} import {func} as _ep_test
320-
321-
def evaluate(messages: List[Dict[str, Any]], ground_truth: Optional[Union[str, List[Dict[str, Any]]]] = None, tools=None, **kwargs):
322-
row = EvaluationRow(messages=[Message(**m) for m in messages], ground_truth=ground_truth)
323-
result = _ep_test(row) # Supports sync/async via decorator's dual-mode
324-
if hasattr(result, "__await__"):
325-
import asyncio
326-
result = asyncio.get_event_loop().run_until_complete(result)
327-
if result.evaluation_result is None:
328-
return {{"score": 0.0, "reason": "No evaluation_result set"}}
329-
out = {{
330-
"score": float(result.evaluation_result.score or 0.0),
331-
"reason": result.evaluation_result.reason,
332-
"metrics": {{k: (v.model_dump() if hasattr(v, "model_dump") else v) for k, v in (result.evaluation_result.metrics or {{}}).items()}},
333-
}}
334-
return out
335-
"""
336-
return (code, "main.py")
305+
# Deprecated: we no longer generate a shim; keep stub for import compatibility
306+
return ("", "main.py")
337307

338308

339309
def _normalize_evaluator_id(evaluator_id: str) -> str:
@@ -522,10 +492,10 @@ def upload_command(args: argparse.Namespace) -> int:
522492
entries_arg = getattr(args, "entry", None)
523493
if entries_arg:
524494
entries = [e.strip() for e in re.split(r"[,\s]+", entries_arg) if e.strip()]
525-
selected_specs: list[tuple[str, str, str, str]] = []
495+
selected_specs: list[tuple[str, str]] = []
526496
for e in entries:
527-
code, file_name, qualname, resolved_path = _generate_ts_mode_code_from_entry(e, root)
528-
selected_specs.append((code, file_name, qualname, resolved_path))
497+
qualname, resolved_path = _resolve_entry_to_qual_and_source(e, root)
498+
selected_specs.append((qualname, resolved_path))
529499
else:
530500
print("Scanning for evaluation tests...")
531501
tests = _discover_tests(root)
@@ -545,11 +515,7 @@ def upload_command(args: argparse.Namespace) -> int:
545515
print(" handles all parameter combinations. The evaluator will work with")
546516
print(" the same logic regardless of which model/parameters are used.")
547517

548-
selected_specs = []
549-
for t in selected_tests:
550-
code, file_name = _generate_ts_mode_code(t)
551-
# Store test info for better ID generation
552-
selected_specs.append((code, file_name, t.qualname, t.file_path))
518+
selected_specs = [(t.qualname, t.file_path) for t in selected_tests]
553519

554520
base_id = getattr(args, "id", None)
555521
display_name = getattr(args, "display_name", None)
@@ -560,6 +526,14 @@ def upload_command(args: argparse.Namespace) -> int:
560526
try:
561527
fw_account_id = get_fireworks_account_id()
562528
fw_api_key_value = get_fireworks_api_key()
529+
if not fw_account_id and fw_api_key_value:
530+
# Attempt to verify and resolve account id from server headers
531+
resolved = verify_api_key_and_get_account_id(api_key=fw_api_key_value, api_base=get_fireworks_api_base())
532+
if resolved:
533+
fw_account_id = resolved
534+
# Propagate to environment so downstream calls use it if needed
535+
os.environ["FIREWORKS_ACCOUNT_ID"] = fw_account_id
536+
print(f"Resolved FIREWORKS_ACCOUNT_ID via API verification: {fw_account_id}")
563537
if fw_account_id and fw_api_key_value:
564538
print("Ensuring FIREWORKS_API_KEY is registered as a secret on Fireworks for rollout...")
565539
if create_or_update_fireworks_secret(
@@ -579,8 +553,7 @@ def upload_command(args: argparse.Namespace) -> int:
579553
print(f"Warning: Skipped Fireworks secret registration due to error: {e}")
580554

581555
exit_code = 0
582-
for i, (code, file_name, qualname, source_file_path) in enumerate(selected_specs):
583-
# Use ts_mode to upload evaluator
556+
for i, (qualname, source_file_path) in enumerate(selected_specs):
584557
# Generate a short default ID from just the test function name
585558
if base_id:
586559
evaluator_id = base_id
@@ -618,12 +591,12 @@ def upload_command(args: argparse.Namespace) -> int:
618591

619592
print(f"\nUploading evaluator '{evaluator_id}' for {qualname.split('.')[-1]}...")
620593
try:
594+
# Always treat as a single evaluator (single-metric) even if folder has helper modules
595+
test_dir = os.path.dirname(source_file_path) if source_file_path else root
596+
metric_name = os.path.basename(test_dir) or "metric"
621597
result = create_evaluation(
622598
evaluator_id=evaluator_id,
623-
python_code_to_evaluate=code,
624-
python_file_name_for_code=file_name,
625-
criterion_name_for_code=qualname,
626-
criterion_description_for_code=description or f"Evaluator for {qualname}",
599+
metric_folders=[f"{metric_name}={test_dir}"],
627600
display_name=display_name or evaluator_id,
628601
description=description or f"Evaluator for {qualname}",
629602
force=force,

0 commit comments

Comments
 (0)