Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from eval_protocol.models import EvaluationRow, Status
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.types.remote_rollout_processor import InitRequest, RolloutMetadata
from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig

Expand Down Expand Up @@ -71,14 +72,25 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
async def _process_row(row: EvaluationRow) -> EvaluationRow:
start_time = time.perf_counter()

if row.execution_metadata.invocation_id is None:
raise ValueError("Invocation ID is required in RemoteRolloutProcessor")
if row.execution_metadata.experiment_id is None:
raise ValueError("Experiment ID is required in RemoteRolloutProcessor")
if row.execution_metadata.rollout_id is None:
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
if row.execution_metadata.run_id is None:
raise ValueError("Run ID is required in RemoteRolloutProcessor")
if row.input_metadata.row_id is None:
raise ValueError("Row ID is required in RemoteRolloutProcessor")

# Build request metadata and payload
meta: Dict[str, Any] = {
"invocation_id": row.execution_metadata.invocation_id,
"experiment_id": row.execution_metadata.experiment_id,
"rollout_id": row.execution_metadata.rollout_id,
"run_id": row.execution_metadata.run_id,
"row_id": row.input_metadata.row_id,
}
meta: RolloutMetadata = RolloutMetadata(
invocation_id=row.execution_metadata.invocation_id,
experiment_id=row.execution_metadata.experiment_id,
rollout_id=row.execution_metadata.rollout_id,
run_id=row.execution_metadata.run_id,
row_id=row.input_metadata.row_id,
)

model: Optional[str] = None
if row.input_metadata and row.input_metadata.completion_params:
Expand Down Expand Up @@ -110,18 +122,22 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
}
clean_messages.append({k: v for k, v in md.items() if k in allowed_message_fields and v is not None})

init_payload: Dict[str, Any] = {
"rollout_id": row.execution_metadata.rollout_id,
"model": model,
"messages": clean_messages,
"tools": row.tools,
"metadata": meta,
}
if row.execution_metadata.rollout_id is None:
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")

init_payload: InitRequest = InitRequest(
rollout_id=row.execution_metadata.rollout_id,
model=model,
messages=clean_messages,
tools=row.tools,
metadata=meta,
model_base_url=config.kwargs.get("model_base_url", None),
)

# Fire-and-poll
def _post_init() -> None:
url = f"{remote_base_url}/init"
r = requests.post(url, json=init_payload, timeout=30)
r = requests.post(url, json=init_payload.model_dump(), timeout=30)
r.raise_for_status()

await asyncio.to_thread(_post_init)
Expand Down
8 changes: 8 additions & 0 deletions eval_protocol/types/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ class InitRequest(BaseModel):
model: str
messages: List[Message] = Field(min_length=1)
tools: Optional[List[Dict[str, Any]]] = None

model_base_url: Optional[str] = None
"""
A Base URL that the remote server can use to make LLM calls. This is useful
to configure on the eval-protocol side for flexibility in
development/traning.
"""

metadata: RolloutMetadata


Expand Down
Loading