Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion eval_protocol/pytest/evaluation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def evaluation_test(
completion_params = parse_ep_completion_params(completion_params)
original_completion_params = completion_params
passed_threshold = parse_ep_passed_threshold(passed_threshold)
custom_invocation_id = os.environ.get("EP_INVOCATION_ID", None)

def decorator(
test_func: TestFunction,
Expand Down Expand Up @@ -228,7 +229,10 @@ def decorator(
# Create wrapper function with exact signature that pytest expects
def create_wrapper_with_signature() -> Callable[[], None]:
# Create the function body that will be used
invocation_id = generate_id()
if custom_invocation_id:
invocation_id = custom_invocation_id
else:
invocation_id = generate_id()

async def wrapper_body(**kwargs: Unpack[ParameterizedTestKwargs]) -> None:
# Store URL for viewing results (after all postprocessing is complete)
Expand Down
20 changes: 20 additions & 0 deletions tests/pytest/test_pytest_env_overwrite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
import os
from unittest import mock


with mock.patch.dict(os.environ, {"EP_INVOCATION_ID": "test-invocation-123"}):

@evaluation_test(
input_rows=[[EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])]],
completion_params=[{"model": "no-op"}],
rollout_processor=NoOpRolloutProcessor(),
mode="pointwise",
)
def test_input_messages_in_decorator(row: EvaluationRow) -> EvaluationRow:
"""Run math evaluation on sample dataset using pytest interface."""
assert row.messages[0].content == "What is the capital of France?"
assert row.execution_metadata.invocation_id == "test-invocation-123"
return row
Loading