Skip to content

Commit 2d1094c

Browse files
author
Dylan Huang
authored
Add model_base_url parameter to RemoteRolloutProcessor constructor (#233)
- Introduced model_base_url as an optional parameter in the RemoteRolloutProcessor class. - Updated internal handling to utilize the model_base_url for configuration consistency.
1 parent 0ec8ba5 commit 2d1094c

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ def __init__(
2222
self,
2323
*,
2424
remote_base_url: Optional[str] = None,
25+
model_base_url: Optional[str] = None,
2526
poll_interval: float = 1.0,
2627
timeout_seconds: float = 120.0,
2728
output_data_loader: Callable[[str], DynamicDataLoader],
2829
):
2930
# Prefer constructor-provided configuration. These can be overridden via
3031
# config.kwargs at call time for backward compatibility.
3132
self._remote_base_url = remote_base_url
33+
self._model_base_url = model_base_url
3234
self._poll_interval = poll_interval
3335
self._timeout_seconds = timeout_seconds
3436
self._output_data_loader = output_data_loader
@@ -38,6 +40,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
3840

3941
# Start with constructor values
4042
remote_base_url: Optional[str] = self._remote_base_url
43+
model_base_url: Optional[str] = self._model_base_url
4144
poll_interval: float = self._poll_interval
4245
timeout_seconds: float = self._timeout_seconds
4346

@@ -112,7 +115,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
112115
messages=clean_messages,
113116
tools=row.tools,
114117
metadata=meta,
115-
model_base_url=config.kwargs.get("model_base_url", None),
118+
model_base_url=model_base_url,
116119
)
117120

118121
# Fire-and-poll

0 commit comments

Comments
 (0)