Skip to content

Commit b37b35a

Browse files
committed
add test
1 parent b47a51e commit b37b35a

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

tests/pytest/test_pytest_env_overwrite.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import atexit
2+
import shutil
3+
import tempfile
14
from eval_protocol.models import EvaluationRow, Message
25
from eval_protocol.pytest import evaluation_test
36
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
@@ -18,3 +21,40 @@ def test_input_messages_in_decorator(row: EvaluationRow) -> EvaluationRow:
1821
assert row.messages[0].content == "What is the capital of France?"
1922
assert row.execution_metadata.invocation_id == "test-invocation-123"
2023
return row
24+
25+
26+
27+
with mock.patch.dict(os.environ, {"EP_COMPLETION_PARAMS": "[{\"model\": \"gpt-40\"}]"}):
28+
@evaluation_test(
29+
input_rows=[[EvaluationRow(messages=[Message(role="user", content="What is 5 * 6?")])]],
30+
completion_params=[{"model": "no-op"}], # This should be overridden by the env var
31+
rollout_processor=NoOpRolloutProcessor(),
32+
mode="pointwise",
33+
)
34+
def test_input_messages_in_env(row: EvaluationRow) -> EvaluationRow:
35+
"""Run math evaluation on sample dataset using pytest interface."""
36+
assert row.messages[0].content == "What is 5 * 6?"
37+
assert row.input_metadata.completion_params["model"] == "gpt-40"
38+
return row
39+
40+
41+
42+
_jsonl_tmpdir = tempfile.mkdtemp()
43+
atexit.register(shutil.rmtree, _jsonl_tmpdir, ignore_errors=True)
44+
45+
input_path = os.path.join(_jsonl_tmpdir, "input.jsonl")
46+
with open(input_path, "w") as f:
47+
f.write(
48+
'{"messages": [{"role": "user", "content": "What is 10 / 2?"}], "input_metadata": {"some_key": "some_value"}}\n'
49+
)
50+
print(f"finish prepare input file {input_path}")
51+
with mock.patch.dict(os.environ, {"EP_JSONL_PATH": input_path}):
52+
@evaluation_test(
53+
input_rows=[[EvaluationRow(messages=[Message(role="user", content="This will be ignored")])]],
54+
completion_params=[{"model": "no-op"}],
55+
rollout_processor=NoOpRolloutProcessor(),
56+
mode="pointwise",
57+
)
58+
def test_input_override(row: EvaluationRow) -> EvaluationRow:
59+
assert row.messages[0].content == "What is 10 / 2?"
60+
return row

0 commit comments

Comments
 (0)