-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtest_get_metadata.py
More file actions
37 lines (32 loc) · 1.56 KB
/
test_get_metadata.py
File metadata and controls
37 lines (32 loc) · 1.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import asyncio
from eval_protocol.pytest import evaluation_test
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
@evaluation_test(
input_messages=[
[
[
Message(role="user", content="What is the capital of France?"),
],
[
Message(role="user", content="What is the capital of the moon?"),
],
]
],
completion_params=[{"model": "accounts/fireworks/models/kimi-k2-instruct"}] * 2,
mode="groupwise",
max_concurrent_rollouts=5,
max_concurrent_evaluations=10,
)
def test_pytest_async(rows: list[EvaluationRow]) -> list[EvaluationRow]:
"""Run math evaluation on sample dataset using pytest interface."""
for row in rows:
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
return rows
def test_pytest_func_metainfo():
assert hasattr(test_pytest_async, "_origin_func")
origin_func = test_pytest_async._origin_func # pyright: ignore[reportAny, reportFunctionMemberAccess]
assert not asyncio.iscoroutinefunction(origin_func) # pyright: ignore[reportAny]
assert asyncio.iscoroutinefunction(test_pytest_async)
assert test_pytest_async._metainfo["mode"] == "groupwise" # pyright: ignore[reportAny, reportFunctionMemberAccess]
assert test_pytest_async._metainfo["max_rollout_concurrency"] == 5 # pyright: ignore[reportAny, reportFunctionMemberAccess]
assert test_pytest_async._metainfo["max_evaluation_concurrency"] == 10 # pyright: ignore[reportAny, reportFunctionMemberAccess]