Skip to content

Commit b42b3e1

Browse files
authored
support custom invocation id (#220)
* support custom invocation id
* format
* add test
* format
* format
1 parent f9c6f1b commit b42b3e1

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def evaluation_test(
189189
completion_params = parse_ep_completion_params(completion_params)
190190
original_completion_params = completion_params
191191
passed_threshold = parse_ep_passed_threshold(passed_threshold)
192+
custom_invocation_id = os.environ.get("EP_INVOCATION_ID", None)
192193

193194
def decorator(
194195
test_func: TestFunction,
@@ -228,7 +229,10 @@ def decorator(
228229
# Create wrapper function with exact signature that pytest expects
229230
def create_wrapper_with_signature() -> Callable[[], None]:
230231
# Create the function body that will be used
231-
invocation_id = generate_id()
232+
if custom_invocation_id:
233+
invocation_id = custom_invocation_id
234+
else:
235+
invocation_id = generate_id()
232236

233237
async def wrapper_body(**kwargs: Unpack[ParameterizedTestKwargs]) -> None:
234238
# Store URL for viewing results (after all postprocessing is complete)
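(New file added by this commit — a test for the custom invocation id; its path was not captured in this extract.)

Lines changed: 20 additions & 0 deletions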
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from eval_protocol.models import EvaluationRow, Message
2+
from eval_protocol.pytest import evaluation_test
3+
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
4+
import os
5+
from unittest import mock
6+
7+
8+
# EP_INVOCATION_ID is read from os.environ when evaluation_test(...) is called
# (i.e. at decoration time), so the decorated test must be defined inside the
# patched environment for the custom id to be picked up.
with mock.patch.dict(os.environ, {"EP_INVOCATION_ID": "test-invocation-123"}):
9+
10+
@evaluation_test(
11+
input_rows=[[EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])]],
12+
completion_params=[{"model": "no-op"}],
13+
rollout_processor=NoOpRolloutProcessor(),
14+
mode="pointwise",
15+
)
16+
def test_input_messages_in_decorator(row: EvaluationRow) -> EvaluationRow:
17+
"""Check that the EP_INVOCATION_ID environment value is used as the row's invocation id.

The row is expected to carry the input message unchanged and to have
``execution_metadata.invocation_id`` equal to the patched
``EP_INVOCATION_ID`` value ("test-invocation-123").
"""
18+
assert row.messages[0].content == "What is the capital of France?"
19+
assert row.execution_metadata.invocation_id == "test-invocation-123"
20+
return row

0 commit comments

Comments
 (0)