diff --git a/eval_protocol/pytest/evaluation_test.py b/eval_protocol/pytest/evaluation_test.py index 4aabd296..4625114a 100644 --- a/eval_protocol/pytest/evaluation_test.py +++ b/eval_protocol/pytest/evaluation_test.py @@ -189,6 +189,7 @@ def evaluation_test( completion_params = parse_ep_completion_params(completion_params) original_completion_params = completion_params passed_threshold = parse_ep_passed_threshold(passed_threshold) + custom_invocation_id = os.environ.get("EP_INVOCATION_ID", None) def decorator( test_func: TestFunction, @@ -228,7 +229,10 @@ def decorator( # Create wrapper function with exact signature that pytest expects def create_wrapper_with_signature() -> Callable[[], None]: # Create the function body that will be used - invocation_id = generate_id() + if custom_invocation_id: + invocation_id = custom_invocation_id + else: + invocation_id = generate_id() async def wrapper_body(**kwargs: Unpack[ParameterizedTestKwargs]) -> None: # Store URL for viewing results (after all postprocessing is complete) diff --git a/tests/pytest/test_pytest_env_overwrite.py b/tests/pytest/test_pytest_env_overwrite.py new file mode 100644 index 00000000..c88dd2b8 --- /dev/null +++ b/tests/pytest/test_pytest_env_overwrite.py @@ -0,0 +1,20 @@ +from eval_protocol.models import EvaluationRow, Message +from eval_protocol.pytest import evaluation_test +from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor +import os +from unittest import mock + + +with mock.patch.dict(os.environ, {"EP_INVOCATION_ID": "test-invocation-123"}): + + @evaluation_test( + input_rows=[[EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])]], + completion_params=[{"model": "no-op"}], + rollout_processor=NoOpRolloutProcessor(), + mode="pointwise", + ) + def test_input_messages_in_decorator(row: EvaluationRow) -> EvaluationRow: + """Run math evaluation on sample dataset using pytest interface.""" + assert row.messages[0].content == "What is the capital of France?" + assert row.execution_metadata.invocation_id == "test-invocation-123" + return row