Skip to content

Commit 8cb080c

Browse files
committed
removing pyright
1 parent 903584b commit 8cb080c

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

eval_protocol/pytest/evaluation_test.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def evaluation_test(
7272
input_dataset: Sequence[DatasetPathParam] | None = None,
7373
input_rows: Sequence[list[EvaluationRow]] | None = None,
7474
data_loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None = None,
75-
dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter, # pyright: ignore[reportExplicitAny]
75+
dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter,
7676
rollout_processor: RolloutProcessor | None = None,
7777
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None = None,
7878
rollout_processor_kwargs: RolloutProcessorInputParam | None = None,
@@ -418,9 +418,7 @@ async def _execute_groupwise_eval_with_semaphore(
418418
all_results[run_idx] = results
419419
elif mode == "groupwise":
420420
# rollout all the completion_params for the same row at once, and then send the output to the test_func
421-
row_groups = defaultdict( # pyright: ignore[reportUnknownVariableType]
422-
list
423-
) # key: row_id, value: list of rollout_result
421+
row_groups = defaultdict(list) # key: row_id, value: list of rollout_result
424422
tasks: list[asyncio.Task[list[EvaluationRow]]] = []
425423
# completion_groups = []
426424
for idx, cp in enumerate(original_completion_params):
@@ -435,13 +433,13 @@ async def _execute_groupwise_eval_with_semaphore(
435433
)
436434
lst = []
437435

438-
async def _collect_result(config, lst): # pyright: ignore[reportUnknownParameterType, reportMissingParameterType]
436+
async def _collect_result(config, lst):
439437
result = []
440438
async for row in rollout_processor_with_retry(
441439
rollout_processor, lst, config, run_idx
442440
): # pyright: ignore[reportUnknownArgumentType]
443-
result.append(row) # pyright: ignore[reportUnknownMemberType]
444-
return result # pyright: ignore[reportUnknownVariableType]
441+
result.append(row)
442+
return result
445443

446444
for ori_row in fresh_dataset:
447445
copied_row = ori_row.model_copy(deep=True)
@@ -450,33 +448,32 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
450448
str(ori_row.execution_metadata.rollout_id) + "_" + str(idx)
451449
)
452450
copied_row.input_metadata.completion_params = cp if cp is not None else {}
453-
lst.append(copied_row) # pyright: ignore[reportUnknownMemberType]
454-
tasks.append(asyncio.create_task(_collect_result(config, lst))) # pyright: ignore[reportUnknownArgumentType]
451+
lst.append(copied_row)
452+
tasks.append(asyncio.create_task(_collect_result(config, lst)))
455453
rollout_results = await asyncio.gather(*tasks)
456454
for result in rollout_results:
457455
for row in result:
458-
row_groups[row.input_metadata.row_id].append(row) # pyright: ignore[reportUnknownMemberType]
456+
row_groups[row.input_metadata.row_id].append(row)
459457
tasks = []
460-
for _, rows in row_groups.items(): # pyright: ignore[reportUnknownVariableType]
461-
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows))) # pyright: ignore[reportUnknownArgumentType]
458+
for _, rows in row_groups.items():
459+
tasks.append(asyncio.create_task(_execute_groupwise_eval_with_semaphore(rows=rows)))
462460
results = []
463461
for task in tasks:
464462
res = await task
465-
results.extend(res) # pyright: ignore[reportUnknownMemberType]
463+
results.extend(res)
466464
all_results[run_idx] = results
467465
else:
468466
# Batch mode: collect all results first, then evaluate (no pipelining)
469467
input_dataset = []
470468
async for row in rollout_processor_with_retry(
471469
rollout_processor, fresh_dataset, config, run_idx
472470
):
473-
input_dataset.append(row) # pyright: ignore[reportUnknownMemberType]
474-
471+
input_dataset.append(row)
475472
# NOTE: we will still evaluate errored rows (give users control over this)
476473
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
477474
results = await execute_pytest(
478475
test_func,
479-
processed_dataset=input_dataset, # pyright: ignore[reportUnknownArgumentType]
476+
processed_dataset=input_dataset,
480477
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
481478
)
482479
if (
@@ -539,16 +536,16 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
539536
# for groupwise mode, the result contains eval output from multiple completion_params, so we need to differentiate them
540537
# rollout_id is used to differentiate the result from different completion_params
541538
if mode == "groupwise":
542-
results_by_group = [ # pyright: ignore[reportUnknownVariableType]
539+
results_by_group = [
543540
[[] for _ in range(num_runs)] for _ in range(len(original_completion_params))
544541
]
545542
for i_run, result in enumerate(all_results):
546543
for r in result:
547544
completion_param_idx = int(r.execution_metadata.rollout_id.split("_")[1]) # pyright: ignore[reportOptionalMemberAccess]
548-
results_by_group[completion_param_idx][i_run].append(r) # pyright: ignore[reportUnknownMemberType]
549-
for rollout_id, result in enumerate(results_by_group): # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
545+
results_by_group[completion_param_idx][i_run].append(r)
546+
for rollout_id, result in enumerate(results_by_group):
550547
postprocess(
551-
result, # pyright: ignore[reportUnknownArgumentType]
548+
result,
552549
aggregation_method,
553550
passed_threshold,
554551
active_logger,
@@ -600,7 +597,7 @@ async def _collect_result(config, lst): # pyright: ignore[reportUnknownParamete
600597
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)
601598

602599
# Create the dual mode wrapper
603-
dual_mode_wrapper = create_dual_mode_wrapper( # pyright: ignore[reportUnknownVariableType]
600+
dual_mode_wrapper = create_dual_mode_wrapper(
604601
test_func, mode, max_concurrent_rollouts, max_concurrent_evaluations, pytest_wrapper
605602
)
606603

0 commit comments

Comments
 (0)