|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import json |
| 4 | +import os |
| 5 | +import shutil |
| 6 | +import subprocess |
| 7 | +import tempfile |
| 8 | +from dataclasses import dataclass |
| 9 | +from pathlib import Path |
| 10 | +from typing import Callable, Optional, cast |
| 11 | +from urllib.parse import quote |
| 12 | + |
| 13 | +import requests |
| 14 | + |
| 15 | +from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key |
| 16 | +from eval_protocol.common_utils import load_jsonl |
| 17 | +from eval_protocol.data_loader.models import ( |
| 18 | + DataLoaderResult, |
| 19 | + DataLoaderVariant, |
| 20 | + EvaluationDataLoader, |
| 21 | +) |
| 22 | +from eval_protocol.models import EvaluationRow, JSONType |
| 23 | + |
| 24 | + |
def _default_dataset_adapter(rows: list[dict[str, object]]) -> list[EvaluationRow]:
    """
    Convert Fireworks dataset rows into EvaluationRow.

    Preferred shape:
        - { messages: [...], ground_truth?: any }

    Fallback (legacy demo):
        - { user_query: str, ground_truth_for_eval?: any }

    Args:
        rows: Raw dataset rows as parsed from JSONL.

    Returns:
        One EvaluationRow per input row, in input order.
    """
    # Defer import to avoid cycles. Hoisted to function scope (it was inside
    # the loop) so the import-machinery lookup runs once, not once per row.
    from eval_protocol.models import Message

    converted: list[EvaluationRow] = []
    for row in rows:
        messages = row.get("messages")
        # Prefer the canonical key; fall back to the legacy demo key.
        ground_truth = cast(JSONType, row.get("ground_truth"))
        if ground_truth is None:
            ground_truth = cast(JSONType, row.get("ground_truth_for_eval"))

        if isinstance(messages, list) and messages:
            normalized_messages: list[Message] = []
            for m in messages:
                if isinstance(m, Message):
                    normalized_messages.append(m)
                elif isinstance(m, dict):
                    # Let Message handle content types (str or list)
                    normalized_messages.append(Message.model_validate(m))
            converted.append(EvaluationRow(messages=normalized_messages, ground_truth=ground_truth))
            continue

        # Fallback: single-turn user_query (a missing key becomes "")
        user_query = str(row.get("user_query", ""))
        converted.append(EvaluationRow(messages=[Message(role="user", content=user_query)], ground_truth=ground_truth))
    return converted
| 60 | + |
| 61 | + |
def _firectl_download_jsonl(firectl_bin: str, dataset_ref: str) -> Path:
    """Download the dataset via the `firectl` CLI; return the chosen JSONL path."""
    tmp_root = Path(tempfile.mkdtemp(prefix="ep_fw_ds_"))
    # firectl requires an explicit --output-dir
    cmd = [firectl_bin, "download", "dataset", dataset_ref, "--output-dir", str(tmp_root)]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"firectl failed: {proc.stderr or proc.stdout}")
    # Expected structure: <tmp_root>/dataset/<name>/*.jsonl
    name_part = dataset_ref.split("/datasets/")[-1] if "/datasets/" in dataset_ref else None
    candidate_dir = tmp_root / "dataset"
    if name_part:
        candidate_dir = candidate_dir / name_part
    jsonl_files = (
        list(candidate_dir.rglob("*.jsonl")) if candidate_dir.exists() else list(tmp_root.rglob("*.jsonl"))
    )
    if not jsonl_files:
        raise RuntimeError("No JSONL files found after firectl download")
    # Prefer the ground_truth JSONL. (Every candidate already has the .jsonl
    # suffix from rglob, so no secondary suffix key is needed.)
    jsonl_files.sort(key=lambda p: 0 if "ground_truth" in p.name else 1)
    return jsonl_files[0]


def _http_download_jsonl(
    dataset_ref: str,
    api_key: Optional[str],
    api_base: Optional[str],
) -> Path:
    """Download the dataset via the Fireworks HTTP API; return the JSONL path.

    Lists the dataset's files, picks a JSONL (preferring one whose name
    contains "ground_truth"), obtains a signed URL, and streams it to a
    fresh temporary directory.
    """
    resolved_key = api_key or get_fireworks_api_key()
    if not resolved_key:
        raise RuntimeError("FIREWORKS_API_KEY is required to download Fireworks datasets")

    base = (api_base or get_fireworks_api_base()).rstrip("/")
    headers = {"Authorization": f"Bearer {resolved_key}"}

    encoded_ref = quote(dataset_ref, safe="")
    list_url = f"{base}/v1/datasets/{encoded_ref}/files"
    resp = requests.get(list_url, headers=headers, timeout=60)
    resp.raise_for_status()
    payload = resp.json()
    files = payload.get("files", []) if isinstance(payload, dict) else []
    if not files:
        raise RuntimeError(f"No files found for dataset {dataset_ref}")

    def _score(name: str) -> tuple[int, int]:
        # Rank ground_truth files first, then any other .jsonl.
        name_lower = name.lower()
        return (
            0 if "ground_truth" in name_lower else 1,
            0 if name_lower.endswith(".jsonl") else 1,
        )

    files_sorted = sorted(files, key=lambda f: _score(str(f.get("name", ""))))
    chosen = None
    for f in files_sorted:
        name = str(f.get("name", ""))
        if name.endswith(".jsonl"):
            chosen = f
            break
    if not chosen:
        raise RuntimeError(f"No JSONL file found for dataset {dataset_ref}")

    # The API's file identifier field varies; try the known keys in order.
    file_id = chosen.get("id") or chosen.get("file_id") or chosen.get("name")
    encoded_file = quote(str(file_id), safe="")
    dl_url = f"{base}/v1/datasets/{encoded_ref}/files/{encoded_file}:download"
    dl_resp = requests.get(dl_url, headers=headers, timeout=60)
    dl_resp.raise_for_status()
    dl_payload = dl_resp.json()
    signed_url = dl_payload.get("url") or dl_payload.get("signed_url")
    if not signed_url:
        raise RuntimeError("Failed to obtain signed URL for dataset file download")

    tmp_dir = Path(tempfile.mkdtemp(prefix="ep_fw_ds_"))
    out_path = tmp_dir / Path(str(chosen.get("name", "dataset.jsonl"))).name
    # Stream in chunks so large datasets never sit fully in memory.
    with requests.get(str(signed_url), stream=True, timeout=300) as r:
        r.raise_for_status()
        with open(out_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=1 << 16):
                if chunk:
                    f.write(chunk)

    return out_path


def _download_fireworks_dataset_jsonl(
    dataset_ref: str,
    *,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> Path:
    """
    Download a Fireworks dataset to a temporary file and return its path.

    This mirrors `firectl download dataset <ref>` behavior using HTTP APIs.
    We expect a single JSONL file under dataset/<name>/dataset_with_ground_truth_column_*.jsonl

    Args:
        dataset_ref: Fireworks dataset reference, e.g.
            "accounts/fireworks/datasets/demo-gsm8k-math-dataset-1000".
        api_key: Optional override for FIREWORKS_API_KEY resolution.
        api_base: Optional override for the API base URL.

    Returns:
        Path to the downloaded JSONL file inside a temp directory
        (prefix "ep_fw_ds_"); the caller owns cleanup.

    Raises:
        RuntimeError: on download failure or when no JSONL file is found.
    """
    # Prefer firectl if available, as in user's example; otherwise use HTTP.
    firectl_bin = shutil.which("firectl")
    if firectl_bin:
        return _firectl_download_jsonl(firectl_bin, dataset_ref)
    return _http_download_jsonl(dataset_ref, api_key, api_base)
| 151 | + |
| 152 | + |
@dataclass(kw_only=True)
class FireworksDatasetLoader(EvaluationDataLoader):
    """
    Data loader that downloads a dataset from Fireworks and emits `EvaluationRow`s.

    - dataset_ref: e.g. "accounts/fireworks/datasets/demo-gsm8k-math-dataset-1000"
    - dataset_adapter: function to convert list[dict] -> list[EvaluationRow]. If not provided,
      defaults to an adapter that expects OpenAI-style `messages` rows and falls back to legacy demo shape.
    - max_rows: optional limit on number of rows to emit.
    - api_key/api_base: override resolution from environment if needed.
    """

    dataset_ref: str
    dataset_adapter: Callable[[list[dict[str, object]]], list[EvaluationRow]] | None = None
    max_rows: Optional[int] = None
    api_key: Optional[str] = None
    api_base: Optional[str] = None
    id: str = "fireworks"
    description: Optional[str] = None

    def variants(self) -> list[DataLoaderVariant]:
        """Return a single lazy variant that downloads, adapts, and cleans up."""

        def _load() -> DataLoaderResult:
            jsonl_path = _download_fireworks_dataset_jsonl(
                self.dataset_ref, api_key=self.api_key, api_base=self.api_base
            )
            try:
                raw_rows = load_jsonl(str(jsonl_path))
                if self.max_rows is not None:
                    raw_rows = raw_rows[: self.max_rows]
                adapter = self.dataset_adapter or _default_dataset_adapter
                rows = adapter(raw_rows)
                return DataLoaderResult(
                    rows=rows,
                    type=self.__class__.__name__,
                    variant_id=self.id,
                    variant_description=self.description or f"Fireworks dataset {self.dataset_ref}",
                )
            finally:
                # Best-effort cleanup of the temporary download tree. The
                # firectl path nests the file under <tmp>/dataset/<name>/, so
                # removing only the file's parent directory would leak the
                # temp root; instead remove the whole "ep_fw_ds_" tree.
                try:
                    p = Path(jsonl_path)
                    tmp_root = next(
                        (anc for anc in p.parents if anc.name.startswith("ep_fw_ds_")),
                        None,
                    )
                    if tmp_root is not None:
                        shutil.rmtree(tmp_root, ignore_errors=True)
                    elif p.exists():
                        # Unknown layout: at least remove the file itself and
                        # its directory if that leaves it empty.
                        p.unlink(missing_ok=True)
                        try:
                            p.parent.rmdir()
                        except OSError:
                            pass
                except Exception:
                    # Cleanup must never mask the loaded result or the
                    # original exception.
                    pass

        return [_load]
0 commit comments