Skip to content

Commit 9d7040d

Browse files
author
Dylan Huang
authored
fix printing of local ui URL to be at end of pytest session (#204)
* fix printing of local ui URL to be at end of pytest session * fix * fix test_pytest_propagate_error
1 parent fdd76dc commit 9d7040d

File tree

7 files changed

+168
-128
lines changed

7 files changed

+168
-128
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
parse_ep_passed_threshold,
6060
rollout_processor_with_retry,
6161
)
62-
from eval_protocol.utils.show_results_url import show_results_url
62+
from eval_protocol.utils.show_results_url import store_local_ui_results_url
6363

6464
from ..common_utils import load_jsonl
6565

@@ -220,6 +220,9 @@ def create_wrapper_with_signature() -> Callable[[], None]:
220220
# Create the function body that will be used
221221
invocation_id = generate_id()
222222

223+
# Store URL for viewing results (after all postprocessing is complete)
224+
store_local_ui_results_url(invocation_id)
225+
223226
async def wrapper_body(**kwargs: Unpack[ParameterizedTestKwargs]) -> None:
224227
eval_metadata = None
225228

@@ -556,9 +559,6 @@ async def execute_run_with_progress(run_idx: int, config):
556559
experiment_duration_seconds,
557560
)
558561

559-
# Show URL for viewing results (after all postprocessing is complete)
560-
show_results_url(invocation_id)
561-
562562
except AssertionError:
563563
_log_eval_error(
564564
Status.eval_finished(),

eval_protocol/pytest/plugin.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def pytest_configure(config) -> None:
279279
pass
280280

281281

282-
def pytest_sessionfinish(session, exitstatus):
282+
def _print_experiment_links(session):
283283
"""Print all collected Fireworks experiment links from pytest stash."""
284284
try:
285285
# Late import to avoid circulars; if missing key, skip printing
@@ -291,9 +291,8 @@ def pytest_sessionfinish(session, exitstatus):
291291
except Exception:
292292
EXPERIMENT_LINKS_STASH_KEY = None
293293

294-
# Get links from pytest stash using shared key
294+
# Get links from pytest stash
295295
links = []
296-
297296
if EXPERIMENT_LINKS_STASH_KEY is not None and EXPERIMENT_LINKS_STASH_KEY in session.stash:
298297
links = session.stash[EXPERIMENT_LINKS_STASH_KEY]
299298

@@ -309,6 +308,55 @@ def pytest_sessionfinish(session, exitstatus):
309308
print(f"❌ Experiment {link['experiment_id']}: {link['job_link']}", file=sys.__stderr__)
310309

311310
print("=" * 80, file=sys.__stderr__)
311+
return True
312+
return False
313+
except Exception:
314+
return False
315+
316+
317+
def _print_local_ui_results_urls(session):
318+
"""Print all collected evaluation results URLs from pytest stash."""
319+
try:
320+
# Late import to avoid circulars; if missing key, skip printing
321+
RESULTS_URLS_STASH_KEY = None
322+
try:
323+
from .store_results_url import RESULTS_URLS_STASH_KEY as _URL_KEY # type: ignore
324+
325+
RESULTS_URLS_STASH_KEY = _URL_KEY
326+
except Exception:
327+
RESULTS_URLS_STASH_KEY = None
328+
329+
# Get URLs from pytest stash
330+
urls = []
331+
if RESULTS_URLS_STASH_KEY is not None and RESULTS_URLS_STASH_KEY in session.stash:
332+
urls = session.stash[RESULTS_URLS_STASH_KEY]
333+
334+
if urls:
335+
print("\n" + "=" * 80, file=sys.__stderr__)
336+
print("📊 LOCAL UI EVALUATION RESULTS", file=sys.__stderr__)
337+
print("=" * 80, file=sys.__stderr__)
338+
339+
for url_data in urls:
340+
print(f"📊 Invocation {url_data['invocation_id']}:", file=sys.__stderr__)
341+
print(f" 📊 Aggregate scores: {url_data['pivot_url']}", file=sys.__stderr__)
342+
print(f" 📋 Trajectories: {url_data['table_url']}", file=sys.__stderr__)
343+
344+
print("=" * 80, file=sys.__stderr__)
345+
return True
346+
return False
347+
except Exception:
348+
return False
349+
350+
351+
def pytest_sessionfinish(session, exitstatus):
352+
"""Print all collected Fireworks experiment links and evaluation results URLs from pytest stash."""
353+
try:
354+
# Print experiment links and results URLs separately
355+
links_printed = _print_experiment_links(session)
356+
urls_printed = _print_local_ui_results_urls(session)
357+
358+
# Flush stderr if anything was printed
359+
if links_printed or urls_printed:
312360
err_stream = getattr(sys, "__stderr__", None)
313361
if err_stream is not None:
314362
try:
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import TypedDict
2+
from pytest import StashKey
3+
4+
5+
class ResultsUrl(TypedDict):
6+
invocation_id: str
7+
pivot_url: str
8+
table_url: str
9+
10+
11+
RESULTS_URLS_STASH_KEY = StashKey[list[ResultsUrl]]()
12+
13+
14+
def _store_local_ui_url_in_stash(invocation_id: str, pivot_url: str, table_url: str):
15+
"""Store results URL in pytest session stash."""
16+
try:
17+
import sys
18+
19+
# Walk up the call stack to find the pytest session
20+
session = None
21+
frame = sys._getframe() # pyright: ignore[reportPrivateUsage]
22+
while frame:
23+
if "session" in frame.f_locals and hasattr(frame.f_locals["session"], "stash"): # pyright: ignore[reportAny]
24+
session = frame.f_locals["session"] # pyright: ignore[reportAny]
25+
break
26+
frame = frame.f_back
27+
28+
if session is not None:
29+
global RESULTS_URLS_STASH_KEY
30+
31+
if RESULTS_URLS_STASH_KEY not in session.stash: # pyright: ignore[reportAny]
32+
session.stash[RESULTS_URLS_STASH_KEY] = [] # pyright: ignore[reportAny]
33+
34+
session.stash[RESULTS_URLS_STASH_KEY].append( # pyright: ignore[reportAny]
35+
{"invocation_id": invocation_id, "pivot_url": pivot_url, "table_url": table_url}
36+
)
37+
else:
38+
pass
39+
40+
except Exception as e: # pyright: ignore[reportUnusedVariable]
41+
pass
42+
43+
44+
def store_local_ui_url(invocation_id: str, pivot_url: str, table_url: str):
45+
"""Public function to store results URL in pytest session stash."""
46+
_store_local_ui_url_in_stash(invocation_id, pivot_url, table_url)

eval_protocol/utils/show_results_url.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import socket
66
import urllib.parse
77

8+
from eval_protocol.pytest.store_results_url import store_local_ui_url
9+
810

911
def is_server_running(host: str = "localhost", port: int = 8000) -> bool:
1012
"""
@@ -58,25 +60,15 @@ def generate_invocation_filter_url(invocation_id: str, base_url: str = "http://l
5860
return f"{base_url}?filterConfig={encoded_filter}"
5961

6062

61-
def show_results_url(invocation_id: str) -> None:
63+
def store_local_ui_results_url(invocation_id: str) -> None:
6264
"""
63-
Show URLs for viewing evaluation results filtered by invocation_id.
64-
65-
If the server is not running, prints a message to run "ep logs" to start the local UI.
66-
If the server is running, prints URLs to view results filtered by invocation_id.
65+
Store URLs for viewing evaluation results filtered by invocation_id in pytest stash.
6766
6867
Args:
69-
invocation_id: The invocation ID to filter results by
68+
invocation_id: The invocation ID to filter results by
7069
"""
71-
if is_server_running():
72-
pivot_url = generate_invocation_filter_url(invocation_id, "http://localhost:8000/pivot")
73-
table_url = generate_invocation_filter_url(invocation_id, "http://localhost:8000/table")
74-
print("View your evaluation results:")
75-
print(f" 📊 Aggregate scores: {pivot_url}")
76-
print(f" 📋 Trajectories: {table_url}")
77-
else:
78-
pivot_url = generate_invocation_filter_url(invocation_id, "http://localhost:8000/pivot")
79-
table_url = generate_invocation_filter_url(invocation_id, "http://localhost:8000/table")
80-
print("Start the local UI with 'ep logs', then visit:")
81-
print(f" 📊 Aggregate scores: {pivot_url}")
82-
print(f" 📋 Trajectories: {table_url}")
70+
pivot_url = generate_invocation_filter_url(invocation_id, "http://localhost:8000/pivot")
71+
table_url = generate_invocation_filter_url(invocation_id, "http://localhost:8000/table")
72+
73+
# Store URLs in pytest stash for later printing in pytest_sessionfinish
74+
store_local_ui_url(invocation_id, pivot_url, table_url)

tests/chinook/pydantic/test_pydantic_chinook.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,25 @@
2323

2424
def agent_factory(config: RolloutProcessorConfig) -> Agent:
2525
model_name = config.completion_params["model"]
26-
provider = config.completion_params["provider"]
26+
provider = config.completion_params["provider"] if "provider" in config.completion_params else "openai"
2727
model = OpenAIChatModel(model_name, provider=provider)
2828
return setup_agent(model)
2929

3030

31-
@pytest.mark.asyncio
32-
@evaluation_test(
33-
input_messages=[[[Message(role="user", content="What is the total number of tracks in the database?")]]],
34-
completion_params=[
31+
@pytest.mark.parametrize(
32+
"completion_params",
33+
[
3534
{
3635
"model": "accounts/fireworks/models/kimi-k2-instruct",
3736
"provider": "fireworks",
3837
},
38+
{
39+
"model": "gpt-5",
40+
},
3941
],
42+
)
43+
@evaluation_test(
44+
input_messages=[[[Message(role="user", content="What is the total number of tracks in the database?")]]],
4045
rollout_processor=PydanticAgentRolloutProcessor(agent_factory),
4146
mode="pointwise",
4247
)

tests/pytest/test_pytest_propagate_error.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,6 @@ def eval_fn(row: EvaluationRow) -> EvaluationRow:
7272
assert row.eval_metadata.status.is_error()
7373

7474
# make sure the error message includes details of the error
75-
assert all("HTTPStatusError" in row.rollout_status.message for row in rollouts.values())
76-
assert all("405 Method Not Allowed" in row.rollout_status.message for row in rollouts.values())
77-
assert all("https://docs.fireworks.ai/mcp-non-existent" in row.rollout_status.message for row in rollouts.values())
75+
assert any("HTTPStatusError" in row.rollout_status.message for row in rollouts.values())
76+
assert any("405 Method Not Allowed" in row.rollout_status.message for row in rollouts.values())
77+
assert any("https://docs.fireworks.ai/mcp-non-existent" in row.rollout_status.message for row in rollouts.values())

0 commit comments

Comments
 (0)