 from typing import Any, Callable, Iterable, Optional

 import pytest
-from eval_protocol.auth import get_fireworks_account_id, get_fireworks_api_key
+from eval_protocol.auth import (
+    get_fireworks_account_id,
+    get_fireworks_api_key,
+    get_fireworks_api_base,
+    verify_api_key_and_get_account_id,
+)
 from eval_protocol.platform_api import create_or_update_fireworks_secret

 from eval_protocol.evaluation import create_evaluation
@@ -259,7 +264,7 @@ def _parse_entry(entry: str, cwd: str) -> tuple[str, str]:
     raise ValueError("--entry must be in 'module::function', 'path::function', or 'module:function' format")


-def _generate_ts_mode_code_from_entry(entry: str, cwd: str) -> tuple[str, str, str, str]:
+def _resolve_entry_to_qual_and_source(entry: str, cwd: str) -> tuple[str, str]:
     target, func = _parse_entry(entry, cwd)

     # Check if target looks like a file path
@@ -293,47 +298,12 @@ def _generate_ts_mode_code_from_entry(entry: str, cwd: str) -> tuple[str, str, s
         raise ValueError(f"Function '{func}' not found in module '{module_name}'")

     qualname = f"{module_name}.{func}"
-    code, file_name = _generate_ts_mode_code(
-        DiscoveredTest(
-            module_path=module_name,
-            module_name=module_name,
-            qualname=qualname,
-            file_path=getattr(module, "__file__", module_name),
-            lineno=None,
-            has_parametrize=False,
-            param_count=0,
-            nodeids=[],
-        )
-    )
-    return code, file_name, qualname, os.path.abspath(source_file_path) if source_file_path else ""
+    return qualname, os.path.abspath(source_file_path) if source_file_path else ""


 def _generate_ts_mode_code(test: DiscoveredTest) -> tuple[str, str]:
-    # Generate a minimal main.py that imports the test module and calls the function
-    module = test.module_name
-    func = test.qualname.split(".")[-1]
-    code = f"""
-from typing import Any, Dict, List, Optional, Union
-
-from eval_protocol.models import EvaluationRow, Message
-from {module} import {func} as _ep_test
-
-def evaluate(messages: List[Dict[str, Any]], ground_truth: Optional[Union[str, List[Dict[str, Any]]]] = None, tools=None, **kwargs):
-    row = EvaluationRow(messages=[Message(**m) for m in messages], ground_truth=ground_truth)
-    result = _ep_test(row)  # Supports sync/async via decorator's dual-mode
-    if hasattr(result, "__await__"):
-        import asyncio
-        result = asyncio.get_event_loop().run_until_complete(result)
-    if result.evaluation_result is None:
-        return {{"score": 0.0, "reason": "No evaluation_result set"}}
-    out = {{
-        "score": float(result.evaluation_result.score or 0.0),
-        "reason": result.evaluation_result.reason,
-        "metrics": {{k: (v.model_dump() if hasattr(v, "model_dump") else v) for k, v in (result.evaluation_result.metrics or {{}}).items()}},
-    }}
-    return out
-"""
-    return (code, "main.py")
+    # Deprecated: we no longer generate a shim; keep stub for import compatibility
+    return ("", "main.py")


 def _normalize_evaluator_id(evaluator_id: str) -> str:
@@ -522,10 +492,10 @@ def upload_command(args: argparse.Namespace) -> int:
     entries_arg = getattr(args, "entry", None)
     if entries_arg:
         entries = [e.strip() for e in re.split(r"[,\s]+", entries_arg) if e.strip()]
-        selected_specs: list[tuple[str, str, str, str]] = []
+        selected_specs: list[tuple[str, str]] = []
         for e in entries:
-            code, file_name, qualname, resolved_path = _generate_ts_mode_code_from_entry(e, root)
-            selected_specs.append((code, file_name, qualname, resolved_path))
+            qualname, resolved_path = _resolve_entry_to_qual_and_source(e, root)
+            selected_specs.append((qualname, resolved_path))
     else:
         print("Scanning for evaluation tests...")
         tests = _discover_tests(root)
@@ -545,11 +515,7 @@ def upload_command(args: argparse.Namespace) -> int:
             print("  handles all parameter combinations. The evaluator will work with")
             print("  the same logic regardless of which model/parameters are used.")

-        selected_specs = []
-        for t in selected_tests:
-            code, file_name = _generate_ts_mode_code(t)
-            # Store test info for better ID generation
-            selected_specs.append((code, file_name, t.qualname, t.file_path))
+        selected_specs = [(t.qualname, t.file_path) for t in selected_tests]

     base_id = getattr(args, "id", None)
     display_name = getattr(args, "display_name", None)
@@ -560,6 +526,16 @@ def upload_command(args: argparse.Namespace) -> int:
     try:
         fw_account_id = get_fireworks_account_id()
         fw_api_key_value = get_fireworks_api_key()
+        if not fw_account_id and fw_api_key_value:
+            # Attempt to verify and resolve account id from server headers
+            resolved = verify_api_key_and_get_account_id(
+                api_key=fw_api_key_value, api_base=get_fireworks_api_base()
+            )
+            if resolved:
+                fw_account_id = resolved
+                # Propagate to environment so downstream calls use it if needed
+                os.environ["FIREWORKS_ACCOUNT_ID"] = fw_account_id
+                print(f"Resolved FIREWORKS_ACCOUNT_ID via API verification: {fw_account_id}")
         if fw_account_id and fw_api_key_value:
             print("Ensuring FIREWORKS_API_KEY is registered as a secret on Fireworks for rollout...")
             if create_or_update_fireworks_secret(
@@ -579,8 +555,7 @@ def upload_command(args: argparse.Namespace) -> int:
         print(f"Warning: Skipped Fireworks secret registration due to error: {e}")

     exit_code = 0
-    for i, (code, file_name, qualname, source_file_path) in enumerate(selected_specs):
-        # Use ts_mode to upload evaluator
+    for i, (qualname, source_file_path) in enumerate(selected_specs):
         # Generate a short default ID from just the test function name
         if base_id:
             evaluator_id = base_id
@@ -618,17 +593,31 @@ def upload_command(args: argparse.Namespace) -> int:

         print(f"\nUploading evaluator '{evaluator_id}' for {qualname.split('.')[-1]}...")
         try:
-            result = create_evaluation(
-                evaluator_id=evaluator_id,
-                python_code_to_evaluate=code,
-                python_file_name_for_code=file_name,
-                criterion_name_for_code=qualname,
-                criterion_description_for_code=description or f"Evaluator for {qualname}",
-                display_name=display_name or evaluator_id,
-                description=description or f"Evaluator for {qualname}",
-                force=force,
-                entry_point=entry_point,
-            )
+            # Upload full directory of the test as multi-metric if the dir contains multiple files
+            test_dir = os.path.dirname(source_file_path) if source_file_path else root
+            # Use multi_metrics if multiple .py files exist at the root dir; otherwise treat as single metric dir
+            py_files = [f for f in os.listdir(test_dir) if f.endswith(".py")]
+            if len(py_files) > 1:
+                result = create_evaluation(
+                    evaluator_id=evaluator_id,
+                    multi_metrics=True,
+                    folder=test_dir,
+                    display_name=display_name or evaluator_id,
+                    description=description or f"Evaluator for {qualname}",
+                    force=force,
+                    entry_point=entry_point,
+                )
+            else:
+                # Single metric mode: metric name derived from folder name; include all files recursively
+                metric_name = os.path.basename(test_dir) or "metric"
+                result = create_evaluation(
+                    evaluator_id=evaluator_id,
+                    metric_folders=[f"{metric_name}={test_dir}"],
+                    display_name=display_name or evaluator_id,
+                    description=description or f"Evaluator for {qualname}",
+                    force=force,
+                    entry_point=entry_point,
+                )
             name = result.get("name", evaluator_id) if isinstance(result, dict) else evaluator_id

             # Print success message with Fireworks dashboard link
0 commit comments