Skip to content

Commit 8560c88

Browse files
author
Shrey Modi
committed
rolloutprocessor
1 parent 81a21e6 commit 8560c88

File tree

1 file changed

+1
-66
lines changed

1 file changed

+1
-66
lines changed

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 1 addition & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -123,72 +123,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
123123
if row.input_metadata.row_id is None:
124124
raise ValueError("Row ID is required in RemoteRolloutProcessor")
125125

126-
# Build request metadata and payload
127-
meta: RolloutMetadata = RolloutMetadata(
128-
invocation_id=row.execution_metadata.invocation_id,
129-
experiment_id=row.execution_metadata.experiment_id,
130-
rollout_id=row.execution_metadata.rollout_id,
131-
run_id=row.execution_metadata.run_id,
132-
row_id=row.input_metadata.row_id,
133-
)
134-
135-
completion_params_dict: Dict[str, Any] = {}
136-
137-
# Start with config-level completion_params
138-
if config.completion_params and isinstance(config.completion_params, dict):
139-
completion_params_dict.update(config.completion_params)
140-
141-
#Override with row-level completion_params
142-
if row.input_metadata and row.input_metadata.completion_params:
143-
row_cp = row.input_metadata.completion_params
144-
if isinstance(row_cp, dict):
145-
completion_params_dict.update(row_cp)
146-
147-
# Validate model is present
148-
if not completion_params_dict.get("model"):
149-
raise ValueError("Model must be provided in completion_params")
150-
151-
# Extract base_url from completion_params if provided
152-
completion_params_base_url: Optional[str] = completion_params_dict.get("base_url")
153-
154-
# Strip non-OpenAI fields from messages before sending to remote
155-
allowed_message_fields = {"role", "content", "tool_calls", "tool_call_id", "name"}
156-
clean_messages = []
157-
for m in row.messages:
158-
md: Dict[str, Any]
159-
if hasattr(m, "model_dump"):
160-
md = m.model_dump() # type: ignore[assignment]
161-
elif isinstance(m, dict):
162-
md = m # type: ignore[assignment]
163-
else:
164-
# Fallback to constructing a dict from Message-like object
165-
md = {
166-
"role": getattr(m, "role", None),
167-
"content": getattr(m, "content", None),
168-
"tool_calls": getattr(m, "tool_calls", None),
169-
"tool_call_id": getattr(m, "tool_call_id", None),
170-
"name": getattr(m, "name", None),
171-
}
172-
clean_messages.append({k: v for k, v in md.items() if k in allowed_message_fields and v is not None})
173-
174-
if row.execution_metadata.rollout_id is None:
175-
raise ValueError("Rollout ID is required in RemoteRolloutProcessor")
176-
177-
final_model_base_url = model_base_url
178-
if model_base_url and (
179-
model_base_url.startswith("https://tracing.fireworks.ai")
180-
or model_base_url.startswith("http://localhost")
181-
):
182-
final_model_base_url = _build_fireworks_tracing_url(model_base_url, meta, completion_params_base_url)
183-
184-
init_payload: InitRequest = InitRequest(
185-
completion_params=completion_params_dict,
186-
messages=clean_messages,
187-
tools=row.tools,
188-
metadata=meta,
189-
model_base_url=final_model_base_url,
190-
elastic_search_config=self._elastic_search_config,
191-
)
126+
init_payload = build_init_request(row, config, model_base_url, self._elastic_search_config)
192127

193128
# Fire-and-poll
194129
def _post_init() -> None:

0 commit comments

Comments
 (0)