From 4491f177268443b115924e54060cdc63b3302da8 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Fri, 26 Sep 2025 13:18:59 -0700 Subject: [PATCH] add llm base url to initrequest --- .../pytest/remote_rollout_processor.py | 46 +++++++++++++------ .../types/remote_rollout_processor.py | 8 ++++ 2 files changed, 39 insertions(+), 15 deletions(-) diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 5efa793e..2359b6e1 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -6,6 +6,7 @@ from eval_protocol.models import EvaluationRow, Status from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader +from eval_protocol.types.remote_rollout_processor import InitRequest, RolloutMetadata from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig @@ -71,14 +72,25 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> async def _process_row(row: EvaluationRow) -> EvaluationRow: start_time = time.perf_counter() + if row.execution_metadata.invocation_id is None: + raise ValueError("Invocation ID is required in RemoteRolloutProcessor") + if row.execution_metadata.experiment_id is None: + raise ValueError("Experiment ID is required in RemoteRolloutProcessor") + if row.execution_metadata.rollout_id is None: + raise ValueError("Rollout ID is required in RemoteRolloutProcessor") + if row.execution_metadata.run_id is None: + raise ValueError("Run ID is required in RemoteRolloutProcessor") + if row.input_metadata.row_id is None: + raise ValueError("Row ID is required in RemoteRolloutProcessor") + # Build request metadata and payload - meta: Dict[str, Any] = { - "invocation_id": row.execution_metadata.invocation_id, - "experiment_id": row.execution_metadata.experiment_id, - "rollout_id": row.execution_metadata.rollout_id, - "run_id": row.execution_metadata.run_id, - "row_id": row.input_metadata.row_id, - } + meta: RolloutMetadata = RolloutMetadata( + invocation_id=row.execution_metadata.invocation_id, + experiment_id=row.execution_metadata.experiment_id, + rollout_id=row.execution_metadata.rollout_id, + run_id=row.execution_metadata.run_id, + row_id=row.input_metadata.row_id, + ) model: Optional[str] = None if row.input_metadata and row.input_metadata.completion_params: @@ -110,18 +122,22 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: } clean_messages.append({k: v for k, v in md.items() if k in allowed_message_fields and v is not None}) - init_payload: Dict[str, Any] = { - "rollout_id": row.execution_metadata.rollout_id, - "model": model, - "messages": clean_messages, - "tools": row.tools, - "metadata": meta, - } + if row.execution_metadata.rollout_id is None: + raise ValueError("Rollout ID is required in RemoteRolloutProcessor") + + init_payload: InitRequest = InitRequest( + rollout_id=row.execution_metadata.rollout_id, + model=model, + messages=clean_messages, + tools=row.tools, + metadata=meta, + model_base_url=config.kwargs.get("model_base_url", None), + ) # Fire-and-poll def _post_init() -> None: url = f"{remote_base_url}/init" - r = requests.post(url, json=init_payload, timeout=30) + r = requests.post(url, json=init_payload.model_dump(), timeout=30) r.raise_for_status() await asyncio.to_thread(_post_init) diff --git a/eval_protocol/types/remote_rollout_processor.py b/eval_protocol/types/remote_rollout_processor.py index bdc1f9f2..21d93ceb 100644 --- a/eval_protocol/types/remote_rollout_processor.py +++ b/eval_protocol/types/remote_rollout_processor.py @@ -24,6 +24,14 @@ class InitRequest(BaseModel): model: str messages: List[Message] = Field(min_length=1) tools: Optional[List[Dict[str, Any]]] = None + + model_base_url: Optional[str] = None + """ + A Base URL that the remote server can use to make LLM calls. This is useful + to configure on the eval-protocol side for flexibility in + development/traning. + """ + metadata: RolloutMetadata