Skip to content

Commit 8c1c254

Browse files
committed
format
1 parent 370d4da commit 8c1c254

File tree

2 files changed

+12
-16
lines changed

2 files changed

+12
-16
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -584,9 +584,10 @@ def _log_eval_error(
584584
# log the fresh_dataset
585585
for row in fresh_dataset:
586586
active_logger.log(row)
587-
587+
588588
# prepare parallel eval helper function
589589
semaphore = asyncio.Semaphore(max_concurrent_evaluations)
590+
590591
async def _execute_eval_with_semaphore(**kwargs):
591592
async with semaphore:
592593
# NOTE: we will still evaluate errored rows (give users control over this)
@@ -802,23 +803,23 @@ async def dual_mode_wrapper(*args, **kwargs):
802803

803804
# If not a direct call, use the pytest wrapper
804805
return await pytest_wrapper(*args, **kwargs)
805-
806-
dual_mode_wrapper._origin_func = test_func
806+
807+
dual_mode_wrapper._origin_func = test_func
807808
dual_mode_wrapper._evaluator_id = test_func.__name__
808809
# Generate (stable) evaluator ID from function source code hash
809810
try:
810811
func_source = inspect.getsource(test_func)
811812
parsed = ast.parse(func_source)
812813
normalized_source = ast.unparse(parsed)
813-
clean_source = ''.join(normalized_source.split()) + test_func.__name__
814-
func_hash = hashlib.sha256(clean_source.encode('utf-8')).hexdigest()[:12]
814+
clean_source = "".join(normalized_source.split()) + test_func.__name__
815+
func_hash = hashlib.sha256(clean_source.encode("utf-8")).hexdigest()[:12]
815816
dual_mode_wrapper._version = f"{test_func.__name__}_{func_hash}"
816817
except (OSError, TypeError, SyntaxError):
817818
pass
818819
dual_mode_wrapper._metainfo = {
819-
"mode": mode,
820-
"max_rollout_concurrency": max_concurrent_rollouts,
821-
"max_evaluation_concurrency": max_concurrent_evaluations,
820+
"mode": mode,
821+
"max_rollout_concurrency": max_concurrent_rollouts,
822+
"max_evaluation_concurrency": max_concurrent_evaluations,
822823
}
823824

824825
# Copy all attributes from the pytest wrapper to our dual mode wrapper

tests/pytest/test_get_metadata.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from eval_protocol.pytest import evaluation_test
55
from eval_protocol.models import EvaluationRow, Message
66

7+
78
@evaluation_test(
89
input_messages=[
910
[
@@ -23,24 +24,18 @@ def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
2324
return rows
2425

2526

26-
2727
def test_pytest_func_metainfo():
28-
assert hasattr(test_pytest_async, "_origin_func")
28+
assert hasattr(test_pytest_async, "_origin_func")
2929
origin_func = test_pytest_async._origin_func
3030
assert not asyncio.iscoroutinefunction(origin_func)
3131
assert asyncio.iscoroutinefunction(test_pytest_async)
3232
assert test_pytest_async._metainfo["mode"] == "groupwise"
3333
assert test_pytest_async._metainfo["max_rollout_concurrency"] == 5
3434
assert test_pytest_async._metainfo["max_evaluation_concurrency"] == 10
35-
35+
3636
# Test evaluator ID generation
3737
assert hasattr(test_pytest_async, "_evaluator_id")
3838
evaluator_id = test_pytest_async._evaluator_id
3939
assert evaluator_id.startswith("eval_")
4040
assert len(evaluator_id) == 17 # "eval_" + 12 character hash
4141
print(f"Generated evaluator ID: {evaluator_id}")
42-
43-
44-
45-
46-

0 commit comments

Comments
 (0)