Skip to content

Commit 9761e85

Browse files
committed
support extra headers
1 parent 8219c44 commit 9761e85

File tree

7 files changed

+111
-67
lines changed

7 files changed

+111
-67
lines changed

.env.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
FIREWORKS_API_KEY="your_fireworks_api_key_here"
88
FIREWORKS_ACCOUNT_ID="your_fireworks_account_id_here" # e.g., "fireworks" or your specific account
99

10+
# OpenAI Credentials (for using OpenAI models as judge)
11+
OPENAI_API_KEY="your_openai_api_key_here"
12+
1013
# Optional: If targeting a non-production Fireworks API endpoint
1114
# FIREWORKS_API_BASE="https://dev.api.fireworks.ai"
1215

eval_protocol/auth.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,79 @@ def get_fireworks_api_base() -> str:
287287
return api_base
288288

289289

290+
def get_extra_headers() -> Dict[str, str]:
291+
"""
292+
Retrieves extra headers from the FIREWORKS_EXTRA_HEADERS environment variable.
293+
294+
The value should be a JSON object mapping header names to values.
295+
Example: FIREWORKS_EXTRA_HEADERS='{"x-custom-header": "value", "x-another": "value2"}'
296+
297+
Returns:
298+
Dictionary of extra headers, or empty dict if not set or invalid.
299+
"""
300+
import json
301+
302+
extra_headers_str = os.environ.get("FIREWORKS_EXTRA_HEADERS")
303+
if not extra_headers_str:
304+
return {}
305+
306+
try:
307+
extra_headers = json.loads(extra_headers_str)
308+
if isinstance(extra_headers, dict):
309+
# Ensure all values are strings
310+
return {str(k): str(v) for k, v in extra_headers.items()}
311+
else:
312+
logger.warning("FIREWORKS_EXTRA_HEADERS must be a JSON object, got %s", type(extra_headers).__name__)
313+
return {}
314+
except json.JSONDecodeError as e:
315+
logger.warning("Failed to parse FIREWORKS_EXTRA_HEADERS as JSON: %s", e)
316+
return {}
317+
318+
319+
def get_platform_headers(
320+
api_key: Optional[str] = None,
321+
content_type: Optional[str] = "application/json",
322+
include_extra_headers: bool = True,
323+
) -> Dict[str, str]:
324+
"""
325+
Builds standard headers for Fireworks platform API requests.
326+
327+
This centralizes header construction including:
328+
- Authorization bearer token
329+
- Content-Type
330+
- User-Agent
331+
- Extra headers from FIREWORKS_EXTRA_HEADERS env var (JSON format)
332+
333+
Args:
334+
api_key: The API key for authorization. If None, resolves via get_fireworks_api_key().
335+
content_type: The Content-Type header value. Set to None to omit.
336+
include_extra_headers: Whether to include extra headers from FIREWORKS_EXTRA_HEADERS env var.
337+
338+
Returns:
339+
Dictionary of headers for platform API requests.
340+
"""
341+
from .common_utils import get_user_agent
342+
343+
resolved_api_key = api_key or get_fireworks_api_key()
344+
345+
headers: Dict[str, str] = {
346+
"User-Agent": get_user_agent(),
347+
}
348+
349+
if resolved_api_key:
350+
headers["Authorization"] = f"Bearer {resolved_api_key}"
351+
352+
if content_type:
353+
headers["Content-Type"] = content_type
354+
355+
# Include extra headers if set in environment
356+
if include_extra_headers:
357+
extra = get_extra_headers()
358+
headers.update(extra)
359+
360+
return headers
361+
362+
290363
def verify_api_key_and_get_account_id(
291364
api_key: Optional[str] = None,
292365
api_base: Optional[str] = None,

eval_protocol/cli_commands/create_rft.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import requests
99
from pydantic import ValidationError
1010

11-
from ..auth import get_fireworks_api_base, get_fireworks_api_key
11+
from ..auth import get_fireworks_api_base, get_fireworks_api_key, get_platform_headers
1212
from ..common_utils import get_user_agent
1313
from ..fireworks_rft import (
1414
build_default_output_model,
@@ -175,11 +175,7 @@ def _poll_evaluator_status(
175175
Returns:
176176
True if evaluator becomes ACTIVE, False if timeout or BUILD_FAILED
177177
"""
178-
headers = {
179-
"Authorization": f"Bearer {api_key}",
180-
"Content-Type": "application/json",
181-
"User-Agent": get_user_agent(),
182-
}
178+
headers = get_platform_headers(api_key=api_key, content_type="application/json")
183179

184180
check_url = f"{api_base}/v1/{evaluator_resource_name}"
185181
timeout_seconds = timeout_minutes * 60
@@ -517,11 +513,7 @@ def _upload_and_ensure_evaluator(
517513
# Optional short-circuit: if evaluator already exists and not forcing, skip upload path
518514
if not force:
519515
try:
520-
headers = {
521-
"Authorization": f"Bearer {api_key}",
522-
"Content-Type": "application/json",
523-
"User-Agent": get_user_agent(),
524-
}
516+
headers = get_platform_headers(api_key=api_key, content_type="application/json")
525517
resp = requests.get(f"{api_base}/v1/{evaluator_resource_name}", headers=headers, timeout=10)
526518
if resp.ok:
527519
state = resp.json().get("state", "STATE_UNSPECIFIED")
@@ -702,7 +694,7 @@ def _create_rft_job(
702694
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
703695
if getattr(args, "evaluation_dataset", None):
704696
body["evaluationDataset"] = args.evaluation_dataset
705-
697+
706698
output_model_arg = getattr(args, "output_model", None)
707699
if output_model_arg:
708700
if len(output_model_arg) > 63:

eval_protocol/evaluation.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from eval_protocol.auth import (
1919
get_fireworks_account_id,
2020
get_fireworks_api_key,
21+
get_platform_headers,
2122
verify_api_key_and_get_account_id,
2223
)
2324
from eval_protocol.common_utils import get_user_agent
@@ -403,11 +404,7 @@ def preview(self, sample_file, max_samples=5):
403404
account_id = "pyroworks-dev"
404405

405406
url = f"{api_base}/v1/accounts/{account_id}/evaluators:previewEvaluator"
406-
headers = {
407-
"Authorization": f"Bearer {auth_token}",
408-
"Content-Type": "application/json",
409-
"User-Agent": get_user_agent(),
410-
}
407+
headers = get_platform_headers(api_key=auth_token, content_type="application/json")
411408
logger.info(f"Previewing evaluator using API endpoint: {url} with account: {account_id}")
412409
logger.debug(f"Preview API Request URL: {url}")
413410
logger.debug(f"Preview API Request Headers: {json.dumps(headers, indent=2)}")
@@ -749,11 +746,7 @@ def create(self, evaluator_id, display_name=None, description=None, force=False)
749746
account_id = "pyroworks-dev"
750747

751748
base_url = f"{self.api_base}/v1/{parent}/evaluatorsV2"
752-
headers = {
753-
"Authorization": f"Bearer {auth_token}",
754-
"Content-Type": "application/json",
755-
"User-Agent": get_user_agent(),
756-
}
749+
headers = get_platform_headers(api_key=auth_token, content_type="application/json")
757750

758751
self._ensure_requirements_present(os.getcwd())
759752

eval_protocol/fireworks_rft.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313

1414
import requests
1515

16-
from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key
17-
from .common_utils import get_user_agent
16+
from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key, get_platform_headers
1817

1918

2019
def _map_api_host_to_app_host(api_base: str) -> str:
@@ -142,11 +141,17 @@ def create_dataset_from_jsonl(
142141
display_name: Optional[str],
143142
jsonl_path: str,
144143
) -> Tuple[str, Dict[str, Any]]:
145-
headers = {
146-
"Authorization": f"Bearer {api_key}",
147-
"Content-Type": "application/json",
148-
"User-Agent": get_user_agent(),
149-
}
144+
import os
145+
146+
# DEBUG: Check environment variable
147+
extra_headers_env = os.environ.get("FIREWORKS_EXTRA_HEADERS", "<NOT SET>")
148+
print(f"[DEBUG] FIREWORKS_EXTRA_HEADERS env: {extra_headers_env}")
149+
150+
headers = get_platform_headers(api_key=api_key, content_type="application/json")
151+
152+
# DEBUG: Print headers (mask auth token)
153+
debug_headers = {k: (v[:20] + "..." if k == "Authorization" else v) for k, v in headers.items()}
154+
print(f"[DEBUG] Headers being sent: {debug_headers}")
150155
# Count examples quickly
151156
example_count = 0
152157
with open(jsonl_path, "r", encoding="utf-8") as f:
@@ -171,10 +176,8 @@ def create_dataset_from_jsonl(
171176
upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload"
172177
with open(jsonl_path, "rb") as f:
173178
files = {"file": f}
174-
up_headers = {
175-
"Authorization": f"Bearer {api_key}",
176-
"User-Agent": get_user_agent(),
177-
}
179+
# For file uploads, omit Content-Type (let requests set multipart boundary)
180+
up_headers = get_platform_headers(api_key=api_key, content_type=None)
178181
up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600)
179182
if up_resp.status_code not in (200, 201):
180183
raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}")
@@ -196,12 +199,8 @@ def create_reinforcement_fine_tuning_job(
196199
# Remove from body and append as query param
197200
body.pop("jobId", None)
198201
url = f"{url}?{urlencode({'reinforcementFineTuningJobId': job_id})}"
199-
headers = {
200-
"Authorization": f"Bearer {api_key}",
201-
"Content-Type": "application/json",
202-
"Accept": "application/json",
203-
"User-Agent": get_user_agent(),
204-
}
202+
headers = get_platform_headers(api_key=api_key, content_type="application/json")
203+
headers["Accept"] = "application/json"
205204
resp = requests.post(url, json=body, headers=headers, timeout=60)
206205
if resp.status_code not in (200, 201):
207206
raise RuntimeError(f"RFT job creation failed: {resp.status_code} {resp.text}")
@@ -217,22 +216,22 @@ def build_default_dataset_id(evaluator_id: str) -> str:
217216
def build_default_output_model(evaluator_id: str) -> str:
218217
base = evaluator_id.lower().replace("_", "-")
219218
uuid_suffix = str(uuid.uuid4())[:4]
220-
219+
221220
# suffix is "-rft-{4chars}" -> 9 chars
222221
suffix_len = 9
223222
max_len = 63
224-
223+
225224
# Check if we need to truncate
226225
if len(base) + suffix_len > max_len:
227226
# Calculate hash of the full base to preserve uniqueness
228227
hash_digest = hashlib.sha256(base.encode("utf-8")).hexdigest()[:6]
229228
# New structure: {truncated_base}-{hash}-{uuid_suffix}
230229
# Space needed for "-{hash}" is 1 + 6 = 7
231230
hash_part_len = 7
232-
231+
233232
allowed_base_len = max_len - suffix_len - hash_part_len
234233
truncated_base = base[:allowed_base_len].strip("-")
235-
234+
236235
return f"{truncated_base}-{hash_digest}-rft-{uuid_suffix}"
237236

238237
return f"{base}-rft-{uuid_suffix}"

eval_protocol/platform_api.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
get_fireworks_account_id,
1111
get_fireworks_api_base,
1212
get_fireworks_api_key,
13+
get_platform_headers,
1314
)
14-
from eval_protocol.common_utils import get_user_agent
1515

1616
logger = logging.getLogger(__name__)
1717

@@ -93,11 +93,7 @@ def create_or_update_fireworks_secret(
9393
logger.error("Missing Fireworks API key, base URL, or account ID for creating/updating secret.")
9494
return False
9595

96-
headers = {
97-
"Authorization": f"Bearer {resolved_api_key}",
98-
"Content-Type": "application/json",
99-
"User-Agent": get_user_agent(),
100-
}
96+
headers = get_platform_headers(api_key=resolved_api_key, content_type="application/json")
10197

10298
# The secret_id for GET/PATCH/DELETE operations is the key_name.
10399
# The 'name' field in the gatewaySecret model for POST/PATCH is a bit ambiguous.
@@ -219,10 +215,7 @@ def get_fireworks_secret(
219215
logger.error("Missing Fireworks API key, base URL, or account ID for getting secret.")
220216
return None
221217

222-
headers = {
223-
"Authorization": f"Bearer {resolved_api_key}",
224-
"User-Agent": get_user_agent(),
225-
}
218+
headers = get_platform_headers(api_key=resolved_api_key, content_type=None)
226219
resource_id = _normalize_secret_resource_id(key_name)
227220

228221
try:
@@ -259,10 +252,7 @@ def delete_fireworks_secret(
259252
logger.error("Missing Fireworks API key, base URL, or account ID for deleting secret.")
260253
return False
261254

262-
headers = {
263-
"Authorization": f"Bearer {resolved_api_key}",
264-
"User-Agent": get_user_agent(),
265-
}
255+
headers = get_platform_headers(api_key=resolved_api_key, content_type=None)
266256
resource_id = _normalize_secret_resource_id(key_name)
267257

268258
try:

eval_protocol/pytest/handle_persist_flow.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import re
88
from typing import Any
99

10-
from eval_protocol.common_utils import get_user_agent
1110
from eval_protocol.directory_utils import find_eval_protocol_dir
1211
from eval_protocol.models import EvaluationRow
1312
from eval_protocol.pytest.store_experiment_link import store_experiment_link
@@ -16,6 +15,7 @@
1615
get_fireworks_account_id,
1716
verify_api_key_and_get_account_id,
1817
get_fireworks_api_base,
18+
get_platform_headers,
1919
)
2020

2121
import requests
@@ -130,11 +130,7 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name:
130130
continue
131131

132132
api_base = get_fireworks_api_base()
133-
headers = {
134-
"Authorization": f"Bearer {fireworks_api_key}",
135-
"Content-Type": "application/json",
136-
"User-Agent": get_user_agent(),
137-
}
133+
headers = get_platform_headers(api_key=fireworks_api_key, content_type="application/json")
138134

139135
# Make dataset first
140136

@@ -167,10 +163,8 @@ def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name:
167163
upload_url = f"{api_base}/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload"
168164
with open(exp_file, "rb") as f:
169165
files = {"file": f}
170-
upload_headers = {
171-
"Authorization": f"Bearer {fireworks_api_key}",
172-
"User-Agent": get_user_agent(),
173-
}
166+
# For file uploads, omit Content-Type (let requests set multipart boundary)
167+
upload_headers = get_platform_headers(api_key=fireworks_api_key, content_type=None)
174168
upload_response = requests.post(upload_url, files=files, headers=upload_headers)
175169

176170
# Skip if upload failed

0 commit comments

Comments
 (0)