diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 4c8df2df..cafcb978 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -22,6 +22,7 @@ def __init__( self, *, remote_base_url: Optional[str] = None, + model_base_url: Optional[str] = None, poll_interval: float = 1.0, timeout_seconds: float = 120.0, output_data_loader: Callable[[str], DynamicDataLoader], @@ -29,6 +30,7 @@ def __init__( # Prefer constructor-provided configuration. These can be overridden via # config.kwargs at call time for backward compatibility. self._remote_base_url = remote_base_url + self._model_base_url = model_base_url self._poll_interval = poll_interval self._timeout_seconds = timeout_seconds self._output_data_loader = output_data_loader @@ -38,6 +40,7 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> # Start with constructor values remote_base_url: Optional[str] = self._remote_base_url + model_base_url: Optional[str] = self._model_base_url poll_interval: float = self._poll_interval timeout_seconds: float = self._timeout_seconds @@ -112,7 +115,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: messages=clean_messages, tools=row.tools, metadata=meta, - model_base_url=config.kwargs.get("model_base_url", None), + model_base_url=model_base_url, ) # Fire-and-poll