Skip to content

Commit 2d97758

Browse files
authored
make remote rollout processor respect max_concurrent_rollout (#231)
1 parent 76cc1e7 commit 2d97758

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,14 @@ def _load_data():
169169
else:
170170
raise ValueError("RemoteRolloutProcessor's output_data_loader should return exactly one row.")
171171

172-
for r in rows:
173-
tasks.append(asyncio.create_task(_process_row(r)))
172+
semaphore = config.semaphore
174173

174+
async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
175+
async with semaphore:
176+
result = await _process_row(r)
177+
return result
178+
179+
tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
175180
return tasks
176181

177182
def cleanup(self) -> None:

0 commit comments

Comments
 (0)