Skip to content

Commit 2f2bf26

Browse files
author
Dylan Huang
committed
fix test_math_dataset
1 parent d20f83e commit 2f2bf26

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -322,10 +322,11 @@ async def _execute_pointwise_eval_with_semaphore(
322322
row: EvaluationRow,
323323
) -> EvaluationRow:
324324
async with semaphore:
325+
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
325326
result = await execute_pytest(
326327
test_func,
327328
processed_row=row,
328-
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
329+
evaluation_test_kwargs=evaluation_test_kwargs,
329330
)
330331
if not isinstance(result, EvaluationRow):
331332
raise ValueError(
@@ -337,10 +338,11 @@ async def _execute_groupwise_eval_with_semaphore(
337338
rows: list[EvaluationRow],
338339
) -> list[EvaluationRow]:
339340
async with semaphore:
341+
evaluation_test_kwargs = kwargs.get("evaluation_test_kwargs") or {}
340342
results = await execute_pytest(
341343
test_func,
342344
processed_dataset=rows,
343-
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
345+
evaluation_test_kwargs=evaluation_test_kwargs,
344346
)
345347
if not isinstance(results, list):
346348
raise ValueError(

eval_protocol/pytest/execution.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,23 +19,25 @@ async def execute_pytest(
1919
raise ValueError("'row' is a reserved parameter for the evaluation function")
2020
if "rows" in evaluation_test_kwargs:
2121
raise ValueError("'rows' is a reserved parameter for the evaluation function")
22+
else:
23+
evaluation_test_kwargs = {}
2224

2325
# Handle both sync and async test functions
2426
if asyncio.iscoroutinefunction(test_func):
2527
if processed_row is not None:
2628
test_func = cast(Callable[[EvaluationRow], Awaitable[EvaluationRow]], test_func)
27-
return await test_func(processed_row)
29+
return await test_func(processed_row, **evaluation_test_kwargs)
2830
if processed_dataset is not None:
2931
test_func = cast(Callable[[list[EvaluationRow]], Awaitable[list[EvaluationRow]]], test_func)
30-
return await test_func(processed_dataset)
32+
return await test_func(processed_dataset, **evaluation_test_kwargs)
3133
test_func = cast(Callable[[], Awaitable[EvaluationRow]], test_func)
32-
return await test_func()
34+
return await test_func(**evaluation_test_kwargs)
3335
else:
3436
if processed_row is not None:
3537
test_func = cast(Callable[[EvaluationRow], EvaluationRow], test_func)
36-
return test_func(processed_row)
38+
return test_func(processed_row, **evaluation_test_kwargs)
3739
if processed_dataset is not None:
3840
test_func = cast(Callable[[Dataset], Dataset], test_func)
39-
return test_func(processed_dataset)
41+
return test_func(processed_dataset, **evaluation_test_kwargs)
4042
test_func = cast(Callable[[], EvaluationRow], test_func)
41-
return test_func()
43+
return test_func(**evaluation_test_kwargs)

0 commit comments

Comments
 (0)