Skip to content

Commit 4491f17

Browse files
author
Dylan Huang
committed
add llm base url to initrequest
1 parent d4d6cef commit 4491f17

File tree

2 files changed

+39
-15
lines changed

2 files changed

+39
-15
lines changed

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from eval_protocol.models import EvaluationRow, Status
88
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
9+
from eval_protocol.types.remote_rollout_processor import InitRequest, RolloutMetadata
910
from .rollout_processor import RolloutProcessor
1011
from .types import RolloutProcessorConfig
1112

@@ -71,14 +72,25 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) ->
7172
async def _process_row(row: EvaluationRow) -> EvaluationRow:
7273
start_time = time.perf_counter()
7374

75+
if row.execution_metadata.invocation_id is None:
76+
raise ValueError("Invocation ID is required in RemoteRolloutProcessor")
77+
if row.execution_metadata.experiment_id is None:
78+
raise ValueError("Experiment ID is required in RemoteRolloutProcessor")
79+
if row.execution_metadata.rollout_id is None:
80+
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
81+
if row.execution_metadata.run_id is None:
82+
raise ValueError("Run ID is required in RemoteRolloutProcessor")
83+
if row.input_metadata.row_id is None:
84+
raise ValueError("Row ID is required in RemoteRolloutProcessor")
85+
7486
# Build request metadata and payload
75-
meta: Dict[str, Any] = {
76-
"invocation_id": row.execution_metadata.invocation_id,
77-
"experiment_id": row.execution_metadata.experiment_id,
78-
"rollout_id": row.execution_metadata.rollout_id,
79-
"run_id": row.execution_metadata.run_id,
80-
"row_id": row.input_metadata.row_id,
81-
}
87+
meta: RolloutMetadata = RolloutMetadata(
88+
invocation_id=row.execution_metadata.invocation_id,
89+
experiment_id=row.execution_metadata.experiment_id,
90+
rollout_id=row.execution_metadata.rollout_id,
91+
run_id=row.execution_metadata.run_id,
92+
row_id=row.input_metadata.row_id,
93+
)
8294

8395
model: Optional[str] = None
8496
if row.input_metadata and row.input_metadata.completion_params:
@@ -110,18 +122,22 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
110122
}
111123
clean_messages.append({k: v for k, v in md.items() if k in allowed_message_fields and v is not None})
112124

113-
init_payload: Dict[str, Any] = {
114-
"rollout_id": row.execution_metadata.rollout_id,
115-
"model": model,
116-
"messages": clean_messages,
117-
"tools": row.tools,
118-
"metadata": meta,
119-
}
125+
if row.execution_metadata.rollout_id is None:
126+
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
127+
128+
init_payload: InitRequest = InitRequest(
129+
rollout_id=row.execution_metadata.rollout_id,
130+
model=model,
131+
messages=clean_messages,
132+
tools=row.tools,
133+
metadata=meta,
134+
model_base_url=config.kwargs.get("model_base_url", None),
135+
)
120136

121137
# Fire-and-poll
122138
def _post_init() -> None:
123139
url = f"{remote_base_url}/init"
124-
r = requests.post(url, json=init_payload, timeout=30)
140+
r = requests.post(url, json=init_payload.model_dump(), timeout=30)
125141
r.raise_for_status()
126142

127143
await asyncio.to_thread(_post_init)

eval_protocol/types/remote_rollout_processor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ class InitRequest(BaseModel):
2424
model: str
2525
messages: List[Message] = Field(min_length=1)
2626
tools: Optional[List[Dict[str, Any]]] = None
27+
28+
model_base_url: Optional[str] = None
29+
"""
30+
A Base URL that the remote server can use to make LLM calls. This is useful
31+
to configure on the eval-protocol side for flexibility in
32+
development/traning.
33+
"""
34+
2735
metadata: RolloutMetadata
2836

2937

0 commit comments

Comments
 (0)