|
6 | 6 |
|
7 | 7 | from eval_protocol.models import EvaluationRow, Status |
8 | 8 | from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader |
| 9 | +from eval_protocol.types.remote_rollout_processor import InitRequest, RolloutMetadata |
9 | 10 | from .rollout_processor import RolloutProcessor |
10 | 11 | from .types import RolloutProcessorConfig |
11 | 12 |
|
@@ -71,14 +72,25 @@ def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> |
71 | 72 | async def _process_row(row: EvaluationRow) -> EvaluationRow: |
72 | 73 | start_time = time.perf_counter() |
73 | 74 |
|
| 75 | + if row.execution_metadata.invocation_id is None: |
| 76 | + raise ValueError("Invocation ID is required in RemoteRolloutProcessor") |
| 77 | + if row.execution_metadata.experiment_id is None: |
| 78 | + raise ValueError("Experiment ID is required in RemoteRolloutProcessor") |
| 79 | + if row.execution_metadata.rollout_id is None: |
| 80 | + raise ValueError("Rollout ID is required in RemoteRolloutProcessor") |
| 81 | + if row.execution_metadata.run_id is None: |
| 82 | + raise ValueError("Run ID is required in RemoteRolloutProcessor") |
| 83 | + if row.input_metadata.row_id is None: |
| 84 | + raise ValueError("Row ID is required in RemoteRolloutProcessor") |
| 85 | + |
74 | 86 | # Build request metadata and payload |
75 | | - meta: Dict[str, Any] = { |
76 | | - "invocation_id": row.execution_metadata.invocation_id, |
77 | | - "experiment_id": row.execution_metadata.experiment_id, |
78 | | - "rollout_id": row.execution_metadata.rollout_id, |
79 | | - "run_id": row.execution_metadata.run_id, |
80 | | - "row_id": row.input_metadata.row_id, |
81 | | - } |
| 87 | + meta: RolloutMetadata = RolloutMetadata( |
| 88 | + invocation_id=row.execution_metadata.invocation_id, |
| 89 | + experiment_id=row.execution_metadata.experiment_id, |
| 90 | + rollout_id=row.execution_metadata.rollout_id, |
| 91 | + run_id=row.execution_metadata.run_id, |
| 92 | + row_id=row.input_metadata.row_id, |
| 93 | + ) |
82 | 94 |
|
83 | 95 | model: Optional[str] = None |
84 | 96 | if row.input_metadata and row.input_metadata.completion_params: |
@@ -110,18 +122,22 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: |
110 | 122 | } |
111 | 123 | clean_messages.append({k: v for k, v in md.items() if k in allowed_message_fields and v is not None}) |
112 | 124 |
|
113 | | - init_payload: Dict[str, Any] = { |
114 | | - "rollout_id": row.execution_metadata.rollout_id, |
115 | | - "model": model, |
116 | | - "messages": clean_messages, |
117 | | - "tools": row.tools, |
118 | | - "metadata": meta, |
119 | | - } |
| 125 | + if row.execution_metadata.rollout_id is None: |
| 126 | + raise ValueError("Rollout ID is required in RemoteRolloutProcessor") |
| 127 | + |
| 128 | + init_payload: InitRequest = InitRequest( |
| 129 | + rollout_id=row.execution_metadata.rollout_id, |
| 130 | + model=model, |
| 131 | + messages=clean_messages, |
| 132 | + tools=row.tools, |
| 133 | + metadata=meta, |
| 134 | + model_base_url=config.kwargs.get("model_base_url", None), |
| 135 | + ) |
120 | 136 |
|
121 | 137 | # Fire-and-poll |
122 | 138 | def _post_init() -> None: |
123 | 139 | url = f"{remote_base_url}/init" |
124 | | - r = requests.post(url, json=init_payload, timeout=30) |
| 140 | + r = requests.post(url, json=init_payload.model_dump(), timeout=30) |
125 | 141 | r.raise_for_status() |
126 | 142 |
|
127 | 143 | await asyncio.to_thread(_post_init) |
|
0 commit comments