From a1e8fedf7ff14cc0614519997747ea7cc6820d55 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Mon, 3 Nov 2025 21:55:26 +0000 Subject: [PATCH 01/11] Add User-Agent header to Fireworks API requests Co-authored-by: dhuang --- eval_protocol/adapters/fireworks_tracing.py | 11 +++++++++-- eval_protocol/auth.py | 4 +++- eval_protocol/common_utils.py | 16 ++++++++++++++++ eval_protocol/evaluation.py | 3 +++ eval_protocol/fireworks_rft.py | 16 +++++++++++++--- eval_protocol/generation/clients.py | 3 +++ .../log_utils/fireworks_tracing_http_handler.py | 7 ++++++- eval_protocol/platform_api.py | 6 ++++-- eval_protocol/pytest/handle_persist_flow.py | 9 +++++++-- eval_protocol/reward_function.py | 2 ++ 10 files changed, 66 insertions(+), 11 deletions(-) diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 233371cc..fdee0cf5 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -12,6 +12,7 @@ import os from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message +from eval_protocol.common_utils import get_user_agent from .base import BaseAdapter from .utils import extract_messages_from_data @@ -273,7 +274,10 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) - if not tags: raise ValueError("At least one tag is required to fetch logs") - headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"} + headers = { + "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}", + "User-Agent": get_user_agent(), + } params: Dict[str, Any] = {"tags": tags, "limit": limit, "hours_back": hours_back, "program": "eval_protocol"} # Try /logs first, fall back to /v1/logs if not found @@ -398,7 +402,10 @@ def get_evaluation_rows( else: url = f"{self.base_url}/v1/traces/pointwise" - headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"} + headers = { + "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}", + "User-Agent": get_user_agent(), + } result = None try: diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py index 6b0845bb..10ea57a7 100644 --- a/eval_protocol/auth.py +++ b/eval_protocol/auth.py @@ -6,6 +6,8 @@ import requests +from .common_utils import get_user_agent + logger = logging.getLogger(__name__) # Default locations (used for tests and as fallback). Actual resolution is dynamic via _get_auth_ini_file(). @@ -243,7 +245,7 @@ def verify_api_key_and_get_account_id( return None resolved_base = api_base or get_fireworks_api_base() url = f"{resolved_base.rstrip('/')}/verifyApiKey" - headers = {"Authorization": f"Bearer {resolved_key}"} + headers = {"Authorization": f"Bearer {resolved_key}", "User-Agent": get_user_agent()} resp = requests.get(url, headers=headers, timeout=10) if resp.status_code != 200: logger.debug("verifyApiKey returned status %s", resp.status_code) diff --git a/eval_protocol/common_utils.py b/eval_protocol/common_utils.py index 9b9032ab..65be1f66 100644 --- a/eval_protocol/common_utils.py +++ b/eval_protocol/common_utils.py @@ -5,6 +5,22 @@ import requests +def get_user_agent() -> str: + """ + Returns the user-agent string for eval-protocol CLI requests. + + Format: eval-protocol-cli/{version} + + Returns: + User-agent string identifying the eval-protocol CLI and version. + """ + try: + from . import __version__ + return f"eval-protocol-cli/{__version__}" + except Exception: + return "eval-protocol-cli/unknown" + + def load_jsonl(file_path: str) -> List[Dict[str, Any]]: """ Reads a JSONL file where each line is a valid JSON object and returns a list of these objects. diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 72459828..7d7ef3da 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -20,6 +20,7 @@ get_fireworks_api_key, verify_api_key_and_get_account_id, ) +from eval_protocol.common_utils import get_user_agent from eval_protocol.typed_interface import EvaluationMode from eval_protocol.get_pep440_version import get_pep440_version @@ -405,6 +406,7 @@ def preview(self, sample_file, max_samples=5): headers = { "Authorization": f"Bearer {auth_token}", "Content-Type": "application/json", + "User-Agent": get_user_agent(), } logger.info(f"Previewing evaluator using API endpoint: {url} with account: {account_id}") logger.debug(f"Preview API Request URL: {url}") @@ -748,6 +750,7 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) headers = { "Authorization": f"Bearer {auth_token}", "Content-Type": "application/json", + "User-Agent": get_user_agent(), } self._ensure_requirements_present(os.getcwd()) diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py index 6bd2e62e..4a8b351a 100644 --- a/eval_protocol/fireworks_rft.py +++ b/eval_protocol/fireworks_rft.py @@ -11,6 +11,7 @@ import requests from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key +from .common_utils import get_user_agent def _map_api_host_to_app_host(api_base: str) -> str: @@ -157,7 +158,11 @@ def create_dataset_from_jsonl( display_name: Optional[str], jsonl_path: str, ) -> Tuple[str, Dict[str, Any]]: - headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "User-Agent": get_user_agent(), + } # Count examples quickly example_count = 0 with open(jsonl_path, "r", encoding="utf-8") as f: @@ -181,7 +186,7 @@ def create_dataset_from_jsonl( upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload" with open(jsonl_path, "rb") as f: files = {"file": f} - up_headers = {"Authorization": f"Bearer {api_key}"} + up_headers = {"Authorization": f"Bearer {api_key}", "User-Agent": get_user_agent()} up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600) if up_resp.status_code not in (200, 201): raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}") @@ -195,7 +200,12 @@ def create_reinforcement_fine_tuning_job( body: Dict[str, Any], ) -> Dict[str, Any]: url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/reinforcementFineTuningJobs" - headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "Accept": "application/json"} + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "Accept": "application/json", + "User-Agent": get_user_agent(), + } resp = requests.post(url, json=body, headers=headers, timeout=60) if resp.status_code not in (200, 201): raise RuntimeError(f"RFT job creation failed: {resp.status_code} {resp.text}") diff --git a/eval_protocol/generation/clients.py b/eval_protocol/generation/clients.py index 873f587e..7ac80272 100644 --- a/eval_protocol/generation/clients.py +++ b/eval_protocol/generation/clients.py @@ -13,6 +13,8 @@ from omegaconf import DictConfig from pydantic import BaseModel # Added for new models +from ..common_utils import get_user_agent + logger = logging.getLogger(__name__) @@ -101,6 +103,7 @@ async def generate( "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", "Accept": "application/json", + "User-Agent": get_user_agent(), } debug_payload_log = json.loads(json.dumps(payload)) diff --git a/eval_protocol/log_utils/fireworks_tracing_http_handler.py b/eval_protocol/log_utils/fireworks_tracing_http_handler.py index 2031d13a..7582b6af 100644 --- a/eval_protocol/log_utils/fireworks_tracing_http_handler.py +++ b/eval_protocol/log_utils/fireworks_tracing_http_handler.py @@ -6,6 +6,8 @@ import requests +from ..common_utils import get_user_agent + class FireworksTracingHttpHandler(logging.Handler): """Logging handler that posts structured logs to tracing.fireworks gateway /logs endpoint.""" @@ -22,7 +24,10 @@ def __init__(self, gateway_base_url: Optional[str] = None, rollout_id_env: str = api_key = os.environ.get("FIREWORKS_API_KEY") if api_key: try: - self._session.headers.update({"Authorization": f"Bearer {api_key}"}) + self._session.headers.update({ + "Authorization": f"Bearer {api_key}", + "User-Agent": get_user_agent(), + }) except Exception: pass diff --git a/eval_protocol/platform_api.py b/eval_protocol/platform_api.py index 5158d8e0..9e284924 100644 --- a/eval_protocol/platform_api.py +++ b/eval_protocol/platform_api.py @@ -11,6 +11,7 @@ get_fireworks_api_base, get_fireworks_api_key, ) +from eval_protocol.common_utils import get_user_agent logger = logging.getLogger(__name__) @@ -95,6 +96,7 @@ def create_or_update_fireworks_secret( headers = { "Authorization": f"Bearer {resolved_api_key}", "Content-Type": "application/json", + "User-Agent": get_user_agent(), } # The secret_id for GET/PATCH/DELETE operations is the key_name. @@ -217,7 +219,7 @@ def get_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for getting secret.") return None - headers = {"Authorization": f"Bearer {resolved_api_key}"} + headers = {"Authorization": f"Bearer {resolved_api_key}", "User-Agent": get_user_agent()} resource_id = _normalize_secret_resource_id(key_name) url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets/{resource_id}" @@ -254,7 +256,7 @@ def delete_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for deleting secret.") return False - headers = {"Authorization": f"Bearer {resolved_api_key}"} + headers = {"Authorization": f"Bearer {resolved_api_key}", "User-Agent": get_user_agent()} resource_id = _normalize_secret_resource_id(key_name) url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets/{resource_id}" diff --git a/eval_protocol/pytest/handle_persist_flow.py b/eval_protocol/pytest/handle_persist_flow.py index e2f2a93d..959b0651 100644 --- a/eval_protocol/pytest/handle_persist_flow.py +++ b/eval_protocol/pytest/handle_persist_flow.py @@ -7,6 +7,7 @@ import re from typing import Any +from eval_protocol.common_utils import get_user_agent from eval_protocol.directory_utils import find_eval_protocol_dir from eval_protocol.models import EvaluationRow from eval_protocol.pytest.store_experiment_link import store_experiment_link @@ -127,7 +128,11 @@ def get_auth_value(key: str) -> str | None: ) continue - headers = {"Authorization": f"Bearer {fireworks_api_key}", "Content-Type": "application/json"} + headers = { + "Authorization": f"Bearer {fireworks_api_key}", + "Content-Type": "application/json", + "User-Agent": get_user_agent(), + } # Make dataset first dataset_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets" @@ -160,7 +165,7 @@ def get_auth_value(key: str) -> str | None: upload_url = ( f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload" ) - upload_headers = {"Authorization": f"Bearer {fireworks_api_key}"} + upload_headers = {"Authorization": f"Bearer {fireworks_api_key}", "User-Agent": get_user_agent()} with open(exp_file, "rb") as f: files = {"file": f} diff --git a/eval_protocol/reward_function.py b/eval_protocol/reward_function.py index 6bd11974..c309012f 100644 --- a/eval_protocol/reward_function.py +++ b/eval_protocol/reward_function.py @@ -9,6 +9,7 @@ import requests +from .common_utils import get_user_agent from .models import EvaluateResult, MetricResult from .typed_interface import reward_function @@ -211,6 +212,7 @@ def __call__( headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}" if api_key else "", + "User-Agent": get_user_agent(), } try: From cbc418b0a6aa0bc7e95235688068fb630b898dc8 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Mon, 3 Nov 2025 21:57:10 +0000 Subject: [PATCH 02/11] Remove unused User-Agent header from RewardFunction Co-authored-by: dhuang --- eval_protocol/reward_function.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/eval_protocol/reward_function.py b/eval_protocol/reward_function.py index c309012f..6bd11974 100644 --- a/eval_protocol/reward_function.py +++ b/eval_protocol/reward_function.py @@ -9,7 +9,6 @@ import requests -from .common_utils import get_user_agent from .models import EvaluateResult, MetricResult from .typed_interface import reward_function @@ -212,7 +211,6 @@ def __call__( headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}" if api_key else "", - "User-Agent": get_user_agent(), } try: From 00339e877f2d0a1a09aeb25fe10ec417f3ee8017 Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Mon, 3 Nov 2025 21:58:12 +0000 Subject: [PATCH 03/11] Remove unnecessary User-Agent header from Fireworks adapter Co-authored-by: dhuang --- eval_protocol/adapters/fireworks_tracing.py | 11 ++--------- .../log_utils/fireworks_tracing_http_handler.py | 7 +------ 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index fdee0cf5..233371cc 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -12,7 +12,6 @@ import os from eval_protocol.models import EvaluationRow, InputMetadata, ExecutionMetadata, Message -from eval_protocol.common_utils import get_user_agent from .base import BaseAdapter from .utils import extract_messages_from_data @@ -274,10 +273,7 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) - if not tags: raise ValueError("At least one tag is required to fetch logs") - headers = { - "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}", - "User-Agent": get_user_agent(), - } + headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"} params: Dict[str, Any] = {"tags": tags, "limit": limit, "hours_back": hours_back, "program": "eval_protocol"} # Try /logs first, fall back to /v1/logs if not found @@ -402,10 +398,7 @@ def get_evaluation_rows( else: url = f"{self.base_url}/v1/traces/pointwise" - headers = { - "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}", - "User-Agent": get_user_agent(), - } + headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"} result = None try: diff --git a/eval_protocol/log_utils/fireworks_tracing_http_handler.py b/eval_protocol/log_utils/fireworks_tracing_http_handler.py index 7582b6af..2031d13a 100644 --- a/eval_protocol/log_utils/fireworks_tracing_http_handler.py +++ b/eval_protocol/log_utils/fireworks_tracing_http_handler.py @@ -6,8 +6,6 @@ import requests -from ..common_utils import get_user_agent - class FireworksTracingHttpHandler(logging.Handler): """Logging handler that posts structured logs to tracing.fireworks gateway /logs endpoint.""" @@ -24,10 +22,7 @@ def __init__(self, gateway_base_url: Optional[str] = None, rollout_id_env: str = api_key = os.environ.get("FIREWORKS_API_KEY") if api_key: try: - self._session.headers.update({ - "Authorization": f"Bearer {api_key}", - "User-Agent": get_user_agent(), - }) + self._session.headers.update({"Authorization": f"Bearer {api_key}"}) except Exception: pass From d406ce0f27097968f9bdc94f9c8c1078fd0ab86c Mon Sep 17 00:00:00 2001 From: Cursor Agent Date: Mon, 3 Nov 2025 22:04:40 +0000 Subject: [PATCH 04/11] Refactor: Use FireworksAPIClient for all API requests This change centralizes API request logic into a new FireworksAPIClient class, simplifying and standardizing how the Fireworks API is interacted with across the project. It removes redundant request setup code and ensures consistent headers are sent. Co-authored-by: dhuang --- eval_protocol/auth.py | 12 +- eval_protocol/evaluation.py | 29 ++--- eval_protocol/fireworks_api_client.py | 127 ++++++++++++++++++++ eval_protocol/fireworks_rft.py | 30 ++--- eval_protocol/platform_api.py | 29 ++--- eval_protocol/pytest/handle_persist_flow.py | 21 ++-- 6 files changed, 171 insertions(+), 77 deletions(-) create mode 100644 eval_protocol/fireworks_api_client.py diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py index 10ea57a7..bdbd2912 100644 --- a/eval_protocol/auth.py +++ b/eval_protocol/auth.py @@ -4,10 +4,6 @@ from pathlib import Path from typing import Dict, Optional # Added Dict -import requests - -from .common_utils import get_user_agent - logger = logging.getLogger(__name__) # Default locations (used for tests and as fallback). Actual resolution is dynamic via _get_auth_ini_file(). @@ -244,9 +240,11 @@ def verify_api_key_and_get_account_id( if not resolved_key: return None resolved_base = api_base or get_fireworks_api_base() - url = f"{resolved_base.rstrip('/')}/verifyApiKey" - headers = {"Authorization": f"Bearer {resolved_key}", "User-Agent": get_user_agent()} - resp = requests.get(url, headers=headers, timeout=10) + + from .fireworks_api_client import FireworksAPIClient + client = FireworksAPIClient(api_key=resolved_key, api_base=resolved_base) + resp = client.get("verifyApiKey", timeout=10) + if resp.status_code != 200: logger.debug("verifyApiKey returned status %s", resp.status_code) return None diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 7d7ef3da..7371f68f 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -20,7 +20,7 @@ get_fireworks_api_key, verify_api_key_and_get_account_id, ) -from eval_protocol.common_utils import get_user_agent +from eval_protocol.fireworks_api_client import FireworksAPIClient from eval_protocol.typed_interface import EvaluationMode from eval_protocol.get_pep440_version import get_pep440_version @@ -402,20 +402,15 @@ def preview(self, sample_file, max_samples=5): if "dev.api.fireworks.ai" in api_base and account_id == "fireworks": account_id = "pyroworks-dev" - url = f"{api_base}/v1/accounts/{account_id}/evaluators:previewEvaluator" - headers = { - "Authorization": f"Bearer {auth_token}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } - logger.info(f"Previewing evaluator using API endpoint: {url} with account: {account_id}") - logger.debug(f"Preview API Request URL: {url}") - logger.debug(f"Preview API Request Headers: {json.dumps(headers, indent=2)}") + client = FireworksAPIClient(api_key=auth_token, api_base=api_base) + path = f"v1/accounts/{account_id}/evaluators:previewEvaluator" + + logger.info(f"Previewing evaluator using API endpoint: {api_base}/{path} with account: {account_id}") logger.debug(f"Preview API Request Payload: {json.dumps(payload, indent=2)}") global used_preview_api try: - response = requests.post(url, json=payload, headers=headers) + response = client.post(path, json=payload) response.raise_for_status() result = response.json() used_preview_api = True @@ -746,12 +741,8 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) if "dev.api.fireworks.ai" in self.api_base and account_id == "fireworks": account_id = "pyroworks-dev" - base_url = f"{self.api_base}/v1/{parent}/evaluatorsV2" - headers = { - "Authorization": f"Bearer {auth_token}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + client = FireworksAPIClient(api_key=auth_token, api_base=self.api_base) + path = f"v1/{parent}/evaluatorsV2" self._ensure_requirements_present(os.getcwd()) @@ -813,7 +804,7 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) upload_payload = {"name": evaluator_name, "filename_to_size": {tar_filename: tar_size}} logger.info(f"Requesting upload endpoint for {tar_filename}") - upload_response = requests.post(upload_endpoint_url, json=upload_payload, headers=headers) + upload_response = client.post(upload_endpoint_url, json=upload_payload) upload_response.raise_for_status() # Check for signed URLs @@ -895,7 +886,7 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) # Step 3: Validate upload validate_url = f"{self.api_base}/v1/{evaluator_name}:validateUpload" validate_payload = {"name": evaluator_name} - validate_response = requests.post(validate_url, json=validate_payload, headers=headers) + validate_response = client.post(validate_url, json=validate_payload) validate_response.raise_for_status() validate_data = validate_response.json() diff --git a/eval_protocol/fireworks_api_client.py b/eval_protocol/fireworks_api_client.py new file mode 100644 index 00000000..960e3f7c --- /dev/null +++ b/eval_protocol/fireworks_api_client.py @@ -0,0 +1,127 @@ +"""Centralized client for making requests to Fireworks API with consistent headers.""" + +import os +from typing import Any, Dict, Optional + +import requests + +from .common_utils import get_user_agent + + +class FireworksAPIClient: + """Client for making authenticated requests to Fireworks API with proper headers. + + This client automatically includes: + - Authorization header (Bearer token) + - User-Agent header for tracking eval-protocol CLI usage + """ + + def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None): + """Initialize the Fireworks API client. + + Args: + api_key: Fireworks API key. If None, will be read from environment. + api_base: API base URL. If None, defaults to https://api.fireworks.ai + """ + self.api_key = api_key + self.api_base = api_base or os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai") + self._session = requests.Session() + + def _get_headers(self, content_type: Optional[str] = "application/json", + additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """Build headers for API requests. + + Args: + content_type: Content-Type header value. If None, Content-Type won't be set. + additional_headers: Additional headers to merge in. + + Returns: + Dictionary of headers including authorization and user-agent. + """ + headers = { + "User-Agent": get_user_agent(), + } + + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + if content_type: + headers["Content-Type"] = content_type + + if additional_headers: + headers.update(additional_headers) + + return headers + + def get(self, path: str, params: Optional[Dict[str, Any]] = None, + timeout: int = 30, **kwargs) -> requests.Response: + """Make a GET request to the Fireworks API. + + Args: + path: API path (relative to api_base) + params: Query parameters + timeout: Request timeout in seconds + **kwargs: Additional arguments passed to requests.get + + Returns: + Response object + """ + url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" + headers = self._get_headers(content_type=None) + if "headers" in kwargs: + headers.update(kwargs.pop("headers")) + return self._session.get(url, params=params, headers=headers, timeout=timeout, **kwargs) + + def post(self, path: str, json: Optional[Dict[str, Any]] = None, + data: Optional[Any] = None, files: Optional[Dict[str, Any]] = None, + timeout: int = 60, **kwargs) -> requests.Response: + """Make a POST request to the Fireworks API. + + Args: + path: API path (relative to api_base) + json: JSON payload + data: Form data payload + files: Files to upload + timeout: Request timeout in seconds + **kwargs: Additional arguments passed to requests.post + + Returns: + Response object + """ + url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" + + # For file uploads, don't set Content-Type (let requests handle multipart/form-data) + content_type = None if files else "application/json" + headers = self._get_headers(content_type=content_type) + + if "headers" in kwargs: + headers.update(kwargs.pop("headers")) + + return self._session.post(url, json=json, data=data, files=files, + headers=headers, timeout=timeout, **kwargs) + + def put(self, path: str, json: Optional[Dict[str, Any]] = None, + timeout: int = 60, **kwargs) -> requests.Response: + """Make a PUT request to the Fireworks API.""" + url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" + headers = self._get_headers() + if "headers" in kwargs: + headers.update(kwargs.pop("headers")) + return self._session.put(url, json=json, headers=headers, timeout=timeout, **kwargs) + + def patch(self, path: str, json: Optional[Dict[str, Any]] = None, + timeout: int = 60, **kwargs) -> requests.Response: + """Make a PATCH request to the Fireworks API.""" + url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" + headers = self._get_headers() + if "headers" in kwargs: + headers.update(kwargs.pop("headers")) + return self._session.patch(url, json=json, headers=headers, timeout=timeout, **kwargs) + + def delete(self, path: str, timeout: int = 30, **kwargs) -> requests.Response: + """Make a DELETE request to the Fireworks API.""" + url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" + headers = self._get_headers(content_type=None) + if "headers" in kwargs: + headers.update(kwargs.pop("headers")) + return self._session.delete(url, headers=headers, timeout=timeout, **kwargs) diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py index 4a8b351a..0f2a1706 100644 --- a/eval_protocol/fireworks_rft.py +++ b/eval_protocol/fireworks_rft.py @@ -11,7 +11,7 @@ import requests from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key -from .common_utils import get_user_agent +from .fireworks_api_client import FireworksAPIClient def _map_api_host_to_app_host(api_base: str) -> str: @@ -158,17 +158,14 @@ def create_dataset_from_jsonl( display_name: Optional[str], jsonl_path: str, ) -> Tuple[str, Dict[str, Any]]: - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + client = FireworksAPIClient(api_key=api_key, api_base=api_base) + # Count examples quickly example_count = 0 with open(jsonl_path, "r", encoding="utf-8") as f: for _ in f: example_count += 1 - dataset_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets" + payload = { "dataset": { "displayName": display_name or dataset_id, @@ -178,16 +175,15 @@ def create_dataset_from_jsonl( }, "datasetId": dataset_id, } - resp = requests.post(dataset_url, json=payload, headers=headers, timeout=60) + resp = client.post(f"v1/accounts/{account_id}/datasets", json=payload, timeout=60) if resp.status_code not in (200, 201): raise RuntimeError(f"Dataset creation failed: {resp.status_code} {resp.text}") ds = resp.json() - upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload" with open(jsonl_path, "rb") as f: files = {"file": f} - up_headers = {"Authorization": f"Bearer {api_key}", "User-Agent": get_user_agent()} - up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600) + up_resp = client.post(f"v1/accounts/{account_id}/datasets/{dataset_id}:upload", + files=files, timeout=600) if up_resp.status_code not in (200, 201): raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}") return dataset_id, ds @@ -199,14 +195,10 @@ def create_reinforcement_fine_tuning_job( api_base: str, body: Dict[str, Any], ) -> Dict[str, Any]: - url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/reinforcementFineTuningJobs" - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "Accept": "application/json", - "User-Agent": get_user_agent(), - } - resp = requests.post(url, json=body, headers=headers, timeout=60) + client = FireworksAPIClient(api_key=api_key, api_base=api_base) + resp = client.post(f"v1/accounts/{account_id}/reinforcementFineTuningJobs", + json=body, timeout=60, + headers={"Accept": "application/json"}) if resp.status_code not in (200, 201): raise RuntimeError(f"RFT job creation failed: {resp.status_code} {resp.text}") return resp.json() diff --git a/eval_protocol/platform_api.py b/eval_protocol/platform_api.py index 9e284924..32dc141e 100644 --- a/eval_protocol/platform_api.py +++ b/eval_protocol/platform_api.py @@ -11,7 +11,7 @@ get_fireworks_api_base, get_fireworks_api_key, ) -from eval_protocol.common_utils import get_user_agent +from eval_protocol.fireworks_api_client import FireworksAPIClient logger = logging.getLogger(__name__) @@ -93,11 +93,7 @@ def create_or_update_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for creating/updating secret.") return False - headers = { - "Authorization": f"Bearer {resolved_api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + client = FireworksAPIClient(api_key=resolved_api_key, api_base=resolved_api_base) # The secret_id for GET/PATCH/DELETE operations is the key_name. # The 'name' field in the gatewaySecret model for POST/PATCH is a bit ambiguous. @@ -109,10 +105,9 @@ def create_or_update_fireworks_secret( # Check if secret exists using GET (path uses normalized resource id) resource_id = _normalize_secret_resource_id(key_name) - get_url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets/{resource_id}" secret_exists = False try: - response = requests.get(get_url, headers=headers, timeout=10) + response = client.get(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", timeout=10) if response.status_code == 200: secret_exists = True logger.info(f"Secret '{key_name}' already exists. Will attempt to update.") @@ -133,7 +128,6 @@ def create_or_update_fireworks_secret( if secret_exists: # Update existing secret (PATCH) - patch_url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets/{resource_id}" # Body for PATCH requires 'keyName' and 'value'. # Transform key_name for payload: uppercase and underscores payload_key_name = key_name.upper().replace("-", "_") @@ -148,7 +142,8 @@ def create_or_update_fireworks_secret( payload = {"keyName": payload_key_name, "value": secret_value} try: logger.debug(f"PATCH payload for '{key_name}': {payload}") - response = requests.patch(patch_url, headers=headers, json=payload, timeout=30) + response = client.patch(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", + json=payload, timeout=30) response.raise_for_status() logger.info(f"Successfully updated secret '{key_name}' on Fireworks platform.") return True @@ -160,7 +155,6 @@ def create_or_update_fireworks_secret( return False else: # Create new secret (POST) - post_url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets" # Body for POST is gatewaySecret. 'name' field in payload is the resource path. # Let's assume for POST, the 'name' in payload can be omitted or is the key_name. # The API should ideally use 'keyName' from URL or a specific 'secretId' in payload for creation if 'name' is server-assigned. @@ -185,7 +179,8 @@ def create_or_update_fireworks_secret( } try: logger.debug(f"POST payload for '{key_name}': {payload}") - response = requests.post(post_url, headers=headers, json=payload, timeout=30) + response = client.post(f"v1/accounts/{resolved_account_id}/secrets", + json=payload, timeout=30) response.raise_for_status() logger.info( f"Successfully created secret '{key_name}' on Fireworks platform. Full name: {response.json().get('name')}" @@ -219,12 +214,11 @@ def get_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for getting secret.") return None - headers = {"Authorization": f"Bearer {resolved_api_key}", "User-Agent": get_user_agent()} + client = FireworksAPIClient(api_key=resolved_api_key, api_base=resolved_api_base) resource_id = _normalize_secret_resource_id(key_name) - url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets/{resource_id}" try: - response = requests.get(url, headers=headers, timeout=10) + response = client.get(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", timeout=10) if response.status_code == 200: logger.info(f"Successfully retrieved secret '{key_name}'.") return response.json() @@ -256,12 +250,11 @@ def delete_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for deleting secret.") return False - headers = {"Authorization": f"Bearer {resolved_api_key}", "User-Agent": get_user_agent()} + client = FireworksAPIClient(api_key=resolved_api_key, api_base=resolved_api_base) resource_id = _normalize_secret_resource_id(key_name) - url = f"{resolved_api_base.rstrip('/')}/v1/accounts/{resolved_account_id}/secrets/{resource_id}" try: - response = requests.delete(url, headers=headers, timeout=30) + response = client.delete(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", timeout=30) if response.status_code == 200 or response.status_code == 204: # 204 No Content is also success for DELETE logger.info(f"Successfully deleted secret '{key_name}'.") return True diff --git a/eval_protocol/pytest/handle_persist_flow.py b/eval_protocol/pytest/handle_persist_flow.py index 959b0651..50a2aed8 100644 --- a/eval_protocol/pytest/handle_persist_flow.py +++ b/eval_protocol/pytest/handle_persist_flow.py @@ -7,8 +7,8 @@ import re from typing import Any -from eval_protocol.common_utils import get_user_agent from eval_protocol.directory_utils import find_eval_protocol_dir +from eval_protocol.fireworks_api_client import FireworksAPIClient from eval_protocol.models import EvaluationRow from eval_protocol.pytest.store_experiment_link import store_experiment_link import requests @@ -128,14 +128,10 @@ def get_auth_value(key: str) -> str | None: ) continue - headers = { - "Authorization": f"Bearer {fireworks_api_key}", - "Content-Type": "application/json", - "User-Agent": get_user_agent(), - } + client = FireworksAPIClient(api_key=fireworks_api_key, + api_base="https://api.fireworks.ai") # Make dataset first - dataset_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets" dataset_payload = { # pyright: ignore[reportUnknownVariableType] "dataset": { @@ -162,14 +158,10 @@ def get_auth_value(key: str) -> str | None: dataset_id = dataset_data.get("datasetId", dataset_name) # pyright: ignore[reportAny] # Upload the JSONL file content - upload_url = ( - f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload" - ) - upload_headers = {"Authorization": f"Bearer {fireworks_api_key}", "User-Agent": get_user_agent()} - with open(exp_file, "rb") as f: files = {"file": f} - upload_response = requests.post(upload_url, files=files, headers=upload_headers) + upload_response = client.post(f"v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload", + files=files) # Skip if upload failed if upload_response.status_code not in [200, 201]: @@ -199,7 +191,8 @@ def get_auth_value(key: str) -> str | None: }, } - eval_response = requests.post(eval_job_url, json=eval_job_payload, headers=headers) + eval_response = client.post(f"v1/accounts/{fireworks_account_id}/evaluationJobs", + json=eval_job_payload) if eval_response.status_code in [200, 201]: eval_job_data = eval_response.json() # pyright: ignore[reportAny] From 0ba3df5a5836e008e848b637e59b41e96f6cb97a Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 3 Nov 2025 14:38:00 -0800 Subject: [PATCH 05/11] fix lint errors --- eval_protocol/evaluation.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 7371f68f..2985fd66 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -404,7 +404,7 @@ def preview(self, sample_file, max_samples=5): client = FireworksAPIClient(api_key=auth_token, api_base=api_base) path = f"v1/accounts/{account_id}/evaluators:previewEvaluator" - + logger.info(f"Previewing evaluator using API endpoint: {api_base}/{path} with account: {account_id}") logger.debug(f"Preview API Request Payload: {json.dumps(payload, indent=2)}") @@ -750,16 +750,16 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) try: if force: - check_url = f"{self.api_base}/v1/{parent}/evaluators/{evaluator_id}" + check_path = f"v1/{parent}/evaluators/{evaluator_id}" try: - logger.info(f"Checking if evaluator exists: {check_url}") - check_response = requests.get(check_url, headers=headers) + logger.info(f"Checking if evaluator exists: {self.api_base}/{check_path}") + check_response = client.get(check_path) if check_response.status_code == 200: logger.info(f"Evaluator '{evaluator_id}' already exists, deleting and recreating...") - delete_url = f"{self.api_base}/v1/{parent}/evaluators/{evaluator_id}" + delete_path = f"v1/{parent}/evaluators/{evaluator_id}" try: - delete_response = requests.delete(delete_url, headers=headers) + delete_response = client.delete(delete_path) if delete_response.status_code < 400: logger.info(f"Successfully deleted evaluator '{evaluator_id}'") else: @@ -768,14 +768,14 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) ) except Exception as e_del: logger.warning(f"Error deleting evaluator: {str(e_del)}") - response = requests.post(base_url, json=payload_data, headers=headers) + response = client.post(path, json=payload_data) else: - response = requests.post(base_url, json=payload_data, headers=headers) + response = client.post(path, json=payload_data) except requests.exceptions.RequestException: - response = requests.post(base_url, json=payload_data, headers=headers) + response = client.post(path, json=payload_data) else: - logger.info(f"Creating evaluator at: {base_url}") - response = requests.post(base_url, json=payload_data, headers=headers) + logger.info(f"Creating evaluator at: {self.api_base}/{path}") + response = client.post(path, json=payload_data) response.raise_for_status() result = response.json() From ff67affebb366a969399e4183b9213b740f1f28e Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 3 Nov 2025 14:40:14 -0800 Subject: [PATCH 06/11] fix lint errors --- eval_protocol/pytest/handle_persist_flow.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/eval_protocol/pytest/handle_persist_flow.py b/eval_protocol/pytest/handle_persist_flow.py index 50a2aed8..184c0077 100644 --- a/eval_protocol/pytest/handle_persist_flow.py +++ b/eval_protocol/pytest/handle_persist_flow.py @@ -11,7 +11,6 @@ from eval_protocol.fireworks_api_client import FireworksAPIClient from eval_protocol.models import EvaluationRow from eval_protocol.pytest.store_experiment_link import store_experiment_link -import requests def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name: str): @@ -128,8 +127,7 @@ def get_auth_value(key: str) -> str | None: ) continue - client = FireworksAPIClient(api_key=fireworks_api_key, - api_base="https://api.fireworks.ai") + client = FireworksAPIClient(api_key=fireworks_api_key, api_base="https://api.fireworks.ai") # Make dataset first @@ -143,7 +141,9 @@ def get_auth_value(key: str) -> str | None: "datasetId": dataset_name, } - dataset_response = requests.post(dataset_url, json=dataset_payload, headers=headers) # pyright: ignore[reportUnknownArgumentType] + dataset_response = client.post( + f"v1/accounts/{fireworks_account_id}/datasets", json=dataset_payload + ) # pyright: ignore[reportUnknownArgumentType] # Skip if dataset creation failed if dataset_response.status_code not in [200, 201]: @@ -160,8 +160,9 @@ def get_auth_value(key: str) -> str | None: # Upload the JSONL file content with open(exp_file, "rb") as f: files = {"file": f} - upload_response = client.post(f"v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload", - files=files) + upload_response = client.post( + f"v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload", files=files + ) # Skip if upload failed if upload_response.status_code not in [200, 201]: @@ -173,7 +174,6 @@ def get_auth_value(key: str) -> str | None: continue # Create evaluation job (optional - don't skip experiment if this fails) - eval_job_url = f"https://api.fireworks.ai/v1/accounts/{fireworks_account_id}/evaluationJobs" # Truncate job ID to fit 63 character limit job_id_base = f"{dataset_name}-job" if len(job_id_base) > 63: @@ -191,8 +191,9 @@ def get_auth_value(key: str) -> str | None: }, } - eval_response = client.post(f"v1/accounts/{fireworks_account_id}/evaluationJobs", - json=eval_job_payload) + eval_response = client.post( + f"v1/accounts/{fireworks_account_id}/evaluationJobs", json=eval_job_payload + ) if eval_response.status_code in [200, 201]: eval_job_data = eval_response.json() # pyright: ignore[reportAny] From 1b10a8f20a7919282ec57f1baf4cbe53a30af30d Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 3 Nov 2025 14:41:34 -0800 Subject: [PATCH 07/11] Update User-Agent format in get_user_agent function to use 'eval-protocol' prefix --- eval_protocol/common_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/eval_protocol/common_utils.py b/eval_protocol/common_utils.py index 65be1f66..3bca887f 100644 --- a/eval_protocol/common_utils.py +++ b/eval_protocol/common_utils.py @@ -8,17 +8,18 @@ def get_user_agent() -> str: """ Returns the user-agent string for eval-protocol CLI requests. - + Format: eval-protocol-cli/{version} - + Returns: User-agent string identifying the eval-protocol CLI and version. """ try: from . import __version__ - return f"eval-protocol-cli/{__version__}" + + return f"eval-protocol/{__version__}" except Exception: - return "eval-protocol-cli/unknown" + return "eval-protocol/unknown" def load_jsonl(file_path: str) -> List[Dict[str, Any]]: From 1a14b73edd4f7f80089f5397cecba30de33d573f Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 3 Nov 2025 14:43:07 -0800 Subject: [PATCH 08/11] added tests --- tests/test_fireworks_api_client.py | 229 +++++++++++++++++++++++++++++ 1 file changed, 229 insertions(+) create mode 100644 tests/test_fireworks_api_client.py diff --git a/tests/test_fireworks_api_client.py b/tests/test_fireworks_api_client.py new file mode 100644 index 00000000..b4acaeba --- /dev/null +++ b/tests/test_fireworks_api_client.py @@ -0,0 +1,229 @@ +"""Tests for FireworksAPIClient user-agent header functionality.""" + +import re +from unittest.mock import MagicMock, patch + +import pytest + +from eval_protocol.common_utils import get_user_agent +from eval_protocol.fireworks_api_client import FireworksAPIClient + + +class TestFireworksAPIClientUserAgent: + """Test that FireworksAPIClient correctly sets the User-Agent header.""" + + def test_get_user_agent_format(self): + """Test that get_user_agent returns the expected format.""" + user_agent = get_user_agent() + # Should match format: eval-protocol/{version} + # Version can be actual version or "unknown" + assert user_agent.startswith("eval-protocol/") + assert len(user_agent) > len("eval-protocol/") + + def test_get_user_agent_fallback_logic(self): + """Test that get_user_agent has fallback logic for when version can't be imported. + + This test verifies the code structure, since actually triggering an import + failure during the import statement is difficult to test reliably. + The important behavior (User-Agent header being set) is verified in other tests. + """ + # Verify the function exists and can be called normally + user_agent = get_user_agent() + # The function should always return a valid user agent string + assert isinstance(user_agent, str) + assert user_agent.startswith("eval-protocol/") + + # The actual fallback ("eval-protocol/unknown") happens when the import + # fails, which is hard to simulate without patching at a very low level. + # The try/except block in the implementation handles this gracefully. + + def test_get_headers_includes_user_agent(self): + """Test that _get_headers includes the User-Agent header.""" + client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") + headers = client._get_headers() + + assert "User-Agent" in headers + assert headers["User-Agent"] == get_user_agent() + + def test_get_request_includes_user_agent(self): + """Test that GET requests include the User-Agent header.""" + client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "get", return_value=mock_response) as mock_get: + client.get("test/path") + + mock_get.assert_called_once() + call_kwargs = mock_get.call_args[1] + headers = call_kwargs["headers"] + + assert "User-Agent" in headers + assert headers["User-Agent"] == get_user_agent() + assert headers["Authorization"] == "Bearer test_key" + + def test_post_request_includes_user_agent(self): + """Test that POST requests include the User-Agent header.""" + client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "post", return_value=mock_response) as mock_post: + client.post("test/path", json={"key": "value"}) + + mock_post.assert_called_once() + call_kwargs = mock_post.call_args[1] + headers = call_kwargs["headers"] + + assert "User-Agent" in headers + assert headers["User-Agent"] == get_user_agent() + assert headers["Authorization"] == "Bearer test_key" + assert headers["Content-Type"] == "application/json" + + def test_post_with_files_excludes_content_type(self): + """Test that POST requests with files exclude Content-Type header.""" + client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "post", return_value=mock_response) as mock_post: + client.post("test/path", files={"file": MagicMock()}) + + mock_post.assert_called_once() + call_kwargs = mock_post.call_args[1] + headers = call_kwargs["headers"] + + assert "User-Agent" in headers + assert headers["User-Agent"] == get_user_agent() + # Content-Type should not be set when files are present + assert "Content-Type" not in headers + + def test_put_request_includes_user_agent(self): + """Test that PUT requests include the User-Agent header.""" + client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "put", return_value=mock_response) as mock_put: + client.put("test/path", json={"key": "value"}) + + mock_put.assert_called_once() + call_kwargs = mock_put.call_args[1] + headers = call_kwargs["headers"] + + assert "User-Agent" in headers + assert headers["User-Agent"] == get_user_agent() + assert headers["Authorization"] == "Bearer test_key" + + def test_patch_request_includes_user_agent(self): + """Test that PATCH requests include the User-Agent header.""" + client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "patch", return_value=mock_response) as mock_patch: + client.patch("test/path", json={"key": "value"}) + + mock_patch.assert_called_once() + call_kwargs = mock_patch.call_args[1] + headers = call_kwargs["headers"] + + assert "User-Agent" in headers + assert headers["User-Agent"] == get_user_agent() + assert headers["Authorization"] == "Bearer test_key" + + def test_delete_request_includes_user_agent(self): + """Test that DELETE requests include the User-Agent header.""" + client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "delete", return_value=mock_response) as mock_delete: + client.delete("test/path") + + mock_delete.assert_called_once() + call_kwargs = mock_delete.call_args[1] + headers = call_kwargs["headers"] + + assert "User-Agent" in headers + assert headers["User-Agent"] == get_user_agent() + assert headers["Authorization"] == "Bearer test_key" + # DELETE requests shouldn't have Content-Type + assert "Content-Type" not in headers + + def test_additional_headers_merged(self): + """Test that additional headers passed to requests are merged with User-Agent.""" + client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "get", return_value=mock_response) as mock_get: + client.get("test/path", headers={"X-Custom-Header": "custom-value"}) + + mock_get.assert_called_once() + call_kwargs = mock_get.call_args[1] + headers = call_kwargs["headers"] + + assert "User-Agent" in headers + assert headers["User-Agent"] == get_user_agent() + assert headers["X-Custom-Header"] == "custom-value" + + def test_user_agent_consistent_across_methods(self): + """Test that User-Agent is consistent across all HTTP methods.""" + client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") + + mock_response = MagicMock() + mock_response.status_code = 200 + + expected_user_agent = get_user_agent() + + # Test all methods + methods = [ + ("get", lambda: client.get("test/path")), + ("post", lambda: client.post("test/path", json={})), + ("put", lambda: client.put("test/path", json={})), + ("patch", lambda: client.patch("test/path", json={})), + ("delete", lambda: client.delete("test/path")), + ] + + for method_name, method_call in methods: + with patch.object(client._session, method_name, return_value=mock_response) as mock_method: + method_call() + + call_kwargs = mock_method.call_args[1] + headers = call_kwargs["headers"] + + assert "User-Agent" in headers, f"{method_name} should include User-Agent" + assert headers["User-Agent"] == expected_user_agent, ( + f"{method_name} User-Agent should match expected value" + ) + + def test_user_agent_without_api_key(self): + """Test that User-Agent is still included even without API key.""" + client = FireworksAPIClient(api_key=None, api_base="https://api.fireworks.ai") + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "get", return_value=mock_response) as mock_get: + client.get("test/path") + + mock_get.assert_called_once() + call_kwargs = mock_get.call_args[1] + headers = call_kwargs["headers"] + + assert "User-Agent" in headers + assert headers["User-Agent"] == get_user_agent() + # Authorization should not be present + assert "Authorization" not in headers + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 432021aca4323ad788a1f3044a168d1c7d8af96c Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 3 Nov 2025 14:46:11 -0800 Subject: [PATCH 09/11] update --- eval_protocol/evaluation.py | 8 +- tests/test_fireworks_api_client.py | 239 +++++++++++++++++++++++++++++ 2 files changed, 243 insertions(+), 4 deletions(-) diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index 2985fd66..af582bc2 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -800,11 +800,11 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) tar_size = self._create_tar_gz_with_ignores(tar_path, cwd) # Call GetEvaluatorUploadEndpoint - upload_endpoint_url = f"{self.api_base}/v1/{evaluator_name}:getUploadEndpoint" + upload_endpoint_path = f"v1/{evaluator_name}:getUploadEndpoint" upload_payload = {"name": evaluator_name, "filename_to_size": {tar_filename: tar_size}} logger.info(f"Requesting upload endpoint for {tar_filename}") - upload_response = client.post(upload_endpoint_url, json=upload_payload) + upload_response = client.post(upload_endpoint_path, json=upload_payload) upload_response.raise_for_status() # Check for signed URLs @@ -884,9 +884,9 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) raise # Step 3: Validate upload - validate_url = f"{self.api_base}/v1/{evaluator_name}:validateUpload" + validate_path = f"v1/{evaluator_name}:validateUpload" validate_payload = {"name": evaluator_name} - validate_response = client.post(validate_url, json=validate_payload) + validate_response = client.post(validate_path, json=validate_payload) validate_response.raise_for_status() validate_data = validate_response.json() diff --git a/tests/test_fireworks_api_client.py b/tests/test_fireworks_api_client.py index b4acaeba..72626b8f 100644 --- a/tests/test_fireworks_api_client.py +++ b/tests/test_fireworks_api_client.py @@ -225,5 +225,244 @@ def test_user_agent_without_api_key(self): assert "Authorization" not in headers +class TestFireworksAPIClientPathHandling: + """Test that FireworksAPIClient correctly handles relative paths and prevents URL construction bugs.""" + + def test_post_relative_path_combines_with_api_base(self): + """Test that POST requests correctly combine relative paths with api_base.""" + api_base = "https://api.fireworks.ai" + client = FireworksAPIClient(api_key="test_key", api_base=api_base) + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "post", return_value=mock_response) as mock_post: + relative_path = "v1/test/evaluator:getUploadEndpoint" + client.post(relative_path, json={"name": "test"}) + + mock_post.assert_called_once() + call_args = mock_post.call_args + # Check the URL passed to requests.post + assert call_args[0][0] == f"{api_base}/{relative_path}" + assert not call_args[0][0].startswith(f"{api_base}/{api_base}") + + def test_get_relative_path_combines_with_api_base(self): + """Test that GET requests correctly combine relative paths with api_base.""" + api_base = "https://api.fireworks.ai" + client = FireworksAPIClient(api_key="test_key", api_base=api_base) + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "get", return_value=mock_response) as mock_get: + relative_path = "verifyApiKey" + client.get(relative_path) + + mock_get.assert_called_once() + call_args = mock_get.call_args + assert call_args[0][0] == f"{api_base}/{relative_path}" + + def test_post_get_upload_endpoint_path(self): + """Test the specific getUploadEndpoint path that was buggy. + + This ensures relative paths like 'v1/{name}:getUploadEndpoint' are handled correctly + and don't get double-prefixed with api_base. + """ + api_base = "https://api.fireworks.ai" + client = FireworksAPIClient(api_key="test_key", api_base=api_base) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"filenameToSignedUrls": {"test.tar.gz": "https://signed.url"}} + + with patch.object(client._session, "post", return_value=mock_response) as mock_post: + evaluator_name = "test-evaluator" + # This is the correct pattern - relative path, not full URL + upload_endpoint_path = f"v1/{evaluator_name}:getUploadEndpoint" + client.post(upload_endpoint_path, json={"name": evaluator_name}) + + mock_post.assert_called_once() + call_args = mock_post.call_args + expected_url = f"{api_base}/{upload_endpoint_path}" + actual_url = call_args[0][0] + assert actual_url == expected_url, f"Expected {expected_url}, got {actual_url}" + # Ensure it doesn't have the buggy double-prefix + assert not actual_url.startswith(f"{api_base}/{api_base}") + + def test_post_validate_upload_path(self): + """Test the specific validateUpload path that was buggy. + + This ensures relative paths like 'v1/{name}:validateUpload' are handled correctly. + """ + api_base = "https://api.fireworks.ai" + client = FireworksAPIClient(api_key="test_key", api_base=api_base) + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"status": "validated"} + + with patch.object(client._session, "post", return_value=mock_response) as mock_post: + evaluator_name = "test-evaluator" + # This is the correct pattern - relative path, not full URL + validate_path = f"v1/{evaluator_name}:validateUpload" + client.post(validate_path, json={"name": evaluator_name}) + + mock_post.assert_called_once() + call_args = mock_post.call_args + expected_url = f"{api_base}/{validate_path}" + actual_url = call_args[0][0] + assert actual_url == expected_url, f"Expected {expected_url}, got {actual_url}" + # Ensure it doesn't have the buggy double-prefix + assert not actual_url.startswith(f"{api_base}/{api_base}") + + def test_path_with_leading_slash_stripped(self): + """Test that leading slashes in paths are correctly handled.""" + api_base = "https://api.fireworks.ai" + client = FireworksAPIClient(api_key="test_key", api_base=api_base) + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "get", return_value=mock_response) as mock_get: + # Path with leading slash should be handled correctly + client.get("/v1/test/path") + + mock_get.assert_called_once() + call_args = mock_get.call_args + # Should not have double slash + assert call_args[0][0] == f"{api_base}/v1/test/path" + + def test_api_base_with_trailing_slash(self): + """Test that api_base with trailing slash is handled correctly.""" + api_base = "https://api.fireworks.ai/" + client = FireworksAPIClient(api_key="test_key", api_base=api_base) + + mock_response = MagicMock() + mock_response.status_code = 200 + + with patch.object(client._session, "post", return_value=mock_response) as mock_post: + relative_path = "v1/test/path" + client.post(relative_path, json={}) + + mock_post.assert_called_once() + call_args = mock_post.call_args + # Should not have double slash + assert call_args[0][0] == f"https://api.fireworks.ai/{relative_path}" + + def test_all_http_methods_with_relative_paths(self): + """Test that all HTTP methods correctly handle relative paths.""" + api_base = "https://api.fireworks.ai" + client = FireworksAPIClient(api_key="test_key", api_base=api_base) + + mock_response = MagicMock() + mock_response.status_code = 200 + + test_path = "v1/accounts/test/evaluators" + + methods = [ + ("get", lambda p: client.get(p)), + ("post", lambda p: client.post(p, json={})), + ("put", lambda p: client.put(p, json={})), + ("patch", lambda p: client.patch(p, json={})), + ("delete", lambda p: client.delete(p)), + ] + + for method_name, method_call in methods: + with patch.object(client._session, method_name, return_value=mock_response) as mock_method: + method_call(test_path) + + mock_method.assert_called_once() + call_args = mock_method.call_args + expected_url = f"{api_base}/{test_path}" + actual_url = call_args[0][0] + assert actual_url == expected_url, f"{method_name.upper()} expected {expected_url}, got {actual_url}" + # Ensure no double-prefix bug + assert not actual_url.startswith(f"{api_base}/{api_base}"), ( + f"{method_name.upper()} URL has double-prefix bug: {actual_url}" + ) + + def test_paths_containing_v1_pattern(self): + """Test various v1 API paths to ensure correct URL construction.""" + api_base = "https://api.fireworks.ai" + client = FireworksAPIClient(api_key="test_key", api_base=api_base) + + mock_response = MagicMock() + mock_response.status_code = 200 + + test_cases = [ + "v1/accounts/test/evaluators", + "v1/accounts/test/evaluators/eval-id", + "v1/accounts/test/evaluatorsV2", + "v1/accounts/test/evaluators:previewEvaluator", + "v1/test-evaluator:getUploadEndpoint", + "v1/test-evaluator:validateUpload", + ] + + with patch.object(client._session, "post", return_value=mock_response) as mock_post: + for path in test_cases: + client.post(path, json={}) + + call_args = mock_post.call_args + actual_url = call_args[0][0] + expected_url = f"{api_base}/{path}" + + assert actual_url == expected_url, ( + f"Path '{path}' resulted in URL '{actual_url}', expected '{expected_url}'" + ) + assert not actual_url.startswith(f"{api_base}/{api_base}"), ( + f"Path '{path}' has double-prefix bug: {actual_url}" + ) + + mock_post.reset_mock() + + def test_full_url_passed_by_mistake_detected(self): + """Test that accidentally passing a full URL instead of relative path is detected. + + This test documents the bug pattern: if a full URL like '{api_base}/v1/path' + is passed instead of a relative path like 'v1/path', it will result in a + malformed URL like '{api_base}/{api_base}/v1/path'. + + This test verifies that our code correctly handles relative paths (which prevents + the bug), and documents what would happen if the bug occurred. + """ + api_base = "https://api.fireworks.ai" + client = FireworksAPIClient(api_key="test_key", api_base=api_base) + + mock_response = MagicMock() + mock_response.status_code = 200 + + # CORRECT: Relative path (what we should use) + with patch.object(client._session, "post", return_value=mock_response) as mock_post: + correct_relative_path = "v1/test-evaluator:getUploadEndpoint" + client.post(correct_relative_path, json={}) + + call_args = mock_post.call_args + correct_url = call_args[0][0] + expected_correct_url = f"{api_base}/{correct_relative_path}" + assert correct_url == expected_correct_url + + # INCORRECT: Full URL (this would cause the bug - but we're not actually testing this, + # just documenting that our current implementation would create a malformed URL) + # If someone accidentally did: client.post(f"{api_base}/v1/path", ...) + # The result would be: f"{api_base}/{api_base}/v1/path" which is wrong. + # Our tests above ensure we use relative paths, preventing this bug. + mock_post.reset_mock() + with patch.object(client._session, "post", return_value=mock_response) as mock_post: + # Simulating what WOULD happen if buggy code passed full URL + buggy_full_url = f"{api_base}/v1/test-evaluator:getUploadEndpoint" + client.post(buggy_full_url, json={}) + + call_args = mock_post.call_args + buggy_url = call_args[0][0] + # This shows what the buggy URL would look like + buggy_expected = f"{api_base}/{buggy_full_url}" + + # This assertion documents the bug pattern - the URL would be malformed + assert buggy_url == buggy_expected + assert buggy_url.startswith(f"{api_base}/{api_base}"), ( + "This documents the bug: passing full URL creates double-prefix. Always use relative paths!" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 04a2a2faf1676930f702d5ebfc83068be3311c99 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 3 Nov 2025 14:47:38 -0800 Subject: [PATCH 10/11] reject absolute within path arg --- eval_protocol/fireworks_api_client.py | 98 +++++++++++++++++---------- tests/test_fireworks_api_client.py | 72 ++++++++++++-------- 2 files changed, 106 insertions(+), 64 deletions(-) diff --git a/eval_protocol/fireworks_api_client.py b/eval_protocol/fireworks_api_client.py index 960e3f7c..0e24fb9b 100644 --- a/eval_protocol/fireworks_api_client.py +++ b/eval_protocol/fireworks_api_client.py @@ -10,15 +10,15 @@ class FireworksAPIClient: """Client for making authenticated requests to Fireworks API with proper headers. - + This client automatically includes: - Authorization header (Bearer token) - User-Agent header for tracking eval-protocol CLI usage """ - + def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None): """Initialize the Fireworks API client. - + Args: api_key: Fireworks API key. If None, will be read from environment. api_base: API base URL. If None, defaults to https://api.fireworks.ai @@ -26,57 +26,82 @@ def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None self.api_key = api_key self.api_base = api_base or os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai") self._session = requests.Session() - - def _get_headers(self, content_type: Optional[str] = "application/json", - additional_headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + + def _validate_path_is_relative(self, path: str) -> None: + """Validate that the path is relative, not an absolute URL. + + Args: + path: The path to validate + + Raises: + ValueError: If path appears to be an absolute URL (starts with http:// or https://) + """ + if path.startswith(("http://", "https://")): + raise ValueError( + f"Absolute URL detected: '{path}'. FireworksAPIClient methods expect relative paths only. " + f"Use a relative path like 'v1/path' instead of '{path}'. " + f"The client will automatically prepend the api_base: '{self.api_base}'" + ) + + def _get_headers( + self, content_type: Optional[str] = "application/json", additional_headers: Optional[Dict[str, str]] = None + ) -> Dict[str, str]: """Build headers for API requests. - + Args: content_type: Content-Type header value. If None, Content-Type won't be set. additional_headers: Additional headers to merge in. - + Returns: Dictionary of headers including authorization and user-agent. """ headers = { "User-Agent": get_user_agent(), } - + if self.api_key: headers["Authorization"] = f"Bearer {self.api_key}" - + if content_type: headers["Content-Type"] = content_type - + if additional_headers: headers.update(additional_headers) - + return headers - - def get(self, path: str, params: Optional[Dict[str, Any]] = None, - timeout: int = 30, **kwargs) -> requests.Response: + + def get( + self, path: str, params: Optional[Dict[str, Any]] = None, timeout: int = 30, **kwargs + ) -> requests.Response: """Make a GET request to the Fireworks API. - + Args: path: API path (relative to api_base) params: Query parameters timeout: Request timeout in seconds **kwargs: Additional arguments passed to requests.get - + Returns: Response object """ + self._validate_path_is_relative(path) url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" headers = self._get_headers(content_type=None) if "headers" in kwargs: headers.update(kwargs.pop("headers")) return self._session.get(url, params=params, headers=headers, timeout=timeout, **kwargs) - - def post(self, path: str, json: Optional[Dict[str, Any]] = None, - data: Optional[Any] = None, files: Optional[Dict[str, Any]] = None, - timeout: int = 60, **kwargs) -> requests.Response: + + def post( + self, + path: str, + json: Optional[Dict[str, Any]] = None, + data: Optional[Any] = None, + files: Optional[Dict[str, Any]] = None, + timeout: int = 60, + **kwargs, + ) -> requests.Response: """Make a POST request to the Fireworks API. - + Args: path: API path (relative to api_base) json: JSON payload @@ -84,42 +109,45 @@ def post(self, path: str, json: Optional[Dict[str, Any]] = None, files: Files to upload timeout: Request timeout in seconds **kwargs: Additional arguments passed to requests.post - + Returns: Response object """ + self._validate_path_is_relative(path) url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" - + # For file uploads, don't set Content-Type (let requests handle multipart/form-data) content_type = None if files else "application/json" headers = self._get_headers(content_type=content_type) - + if "headers" in kwargs: headers.update(kwargs.pop("headers")) - - return self._session.post(url, json=json, data=data, files=files, - headers=headers, timeout=timeout, **kwargs) - - def put(self, path: str, json: Optional[Dict[str, Any]] = None, - timeout: int = 60, **kwargs) -> requests.Response: + + return self._session.post(url, json=json, data=data, files=files, headers=headers, timeout=timeout, **kwargs) + + def put(self, path: str, json: Optional[Dict[str, Any]] = None, timeout: int = 60, **kwargs) -> requests.Response: """Make a PUT request to the Fireworks API.""" + self._validate_path_is_relative(path) url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" headers = self._get_headers() if "headers" in kwargs: headers.update(kwargs.pop("headers")) return self._session.put(url, json=json, headers=headers, timeout=timeout, **kwargs) - - def patch(self, path: str, json: Optional[Dict[str, Any]] = None, - timeout: int = 60, **kwargs) -> requests.Response: + + def patch( + self, path: str, json: Optional[Dict[str, Any]] = None, timeout: int = 60, **kwargs + ) -> requests.Response: """Make a PATCH request to the Fireworks API.""" + self._validate_path_is_relative(path) url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" headers = self._get_headers() if "headers" in kwargs: headers.update(kwargs.pop("headers")) return self._session.patch(url, json=json, headers=headers, timeout=timeout, **kwargs) - + def delete(self, path: str, timeout: int = 30, **kwargs) -> requests.Response: """Make a DELETE request to the Fireworks API.""" + self._validate_path_is_relative(path) url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" headers = self._get_headers(content_type=None) if "headers" in kwargs: diff --git a/tests/test_fireworks_api_client.py b/tests/test_fireworks_api_client.py index 72626b8f..4965c8ae 100644 --- a/tests/test_fireworks_api_client.py +++ b/tests/test_fireworks_api_client.py @@ -415,23 +415,19 @@ def test_paths_containing_v1_pattern(self): mock_post.reset_mock() - def test_full_url_passed_by_mistake_detected(self): - """Test that accidentally passing a full URL instead of relative path is detected. + def test_full_url_passed_by_mistake_raises_error(self): + """Test that accidentally passing a full URL instead of relative path raises ValueError. - This test documents the bug pattern: if a full URL like '{api_base}/v1/path' - is passed instead of a relative path like 'v1/path', it will result in a - malformed URL like '{api_base}/{api_base}/v1/path'. - - This test verifies that our code correctly handles relative paths (which prevents - the bug), and documents what would happen if the bug occurred. + This test verifies that our code correctly catches the bug early by raising an error + when an absolute URL is passed instead of a relative path. """ api_base = "https://api.fireworks.ai" client = FireworksAPIClient(api_key="test_key", api_base=api_base) + # CORRECT: Relative path (what we should use) - should work fine mock_response = MagicMock() mock_response.status_code = 200 - # CORRECT: Relative path (what we should use) with patch.object(client._session, "post", return_value=mock_response) as mock_post: correct_relative_path = "v1/test-evaluator:getUploadEndpoint" client.post(correct_relative_path, json={}) @@ -441,27 +437,45 @@ def test_full_url_passed_by_mistake_detected(self): expected_correct_url = f"{api_base}/{correct_relative_path}" assert correct_url == expected_correct_url - # INCORRECT: Full URL (this would cause the bug - but we're not actually testing this, - # just documenting that our current implementation would create a malformed URL) - # If someone accidentally did: client.post(f"{api_base}/v1/path", ...) - # The result would be: f"{api_base}/{api_base}/v1/path" which is wrong. - # Our tests above ensure we use relative paths, preventing this bug. - mock_post.reset_mock() - with patch.object(client._session, "post", return_value=mock_response) as mock_post: - # Simulating what WOULD happen if buggy code passed full URL - buggy_full_url = f"{api_base}/v1/test-evaluator:getUploadEndpoint" - client.post(buggy_full_url, json={}) + # INCORRECT: Full URL should raise ValueError + full_url_with_http = "https://api.fireworks.ai/v1/test-evaluator:getUploadEndpoint" + with pytest.raises(ValueError, match="Absolute URL detected"): + client.post(full_url_with_http, json={}) + + full_url_with_http_scheme = "http://api.fireworks.ai/v1/test-evaluator:getUploadEndpoint" + with pytest.raises(ValueError, match="Absolute URL detected"): + client.post(full_url_with_http_scheme, json={}) + + # Test that error message is helpful + with pytest.raises(ValueError) as exc_info: + client.post(full_url_with_http, json={}) + error_msg = str(exc_info.value) + assert "Absolute URL detected" in error_msg + assert full_url_with_http in error_msg + assert "relative paths only" in error_msg + assert api_base in error_msg # Should mention api_base in the help message + + def test_all_methods_reject_absolute_urls(self): + """Test that all HTTP methods reject absolute URLs.""" + api_base = "https://api.fireworks.ai" + client = FireworksAPIClient(api_key="test_key", api_base=api_base) - call_args = mock_post.call_args - buggy_url = call_args[0][0] - # This shows what the buggy URL would look like - buggy_expected = f"{api_base}/{buggy_full_url}" - - # This assertion documents the bug pattern - the URL would be malformed - assert buggy_url == buggy_expected - assert buggy_url.startswith(f"{api_base}/{api_base}"), ( - "This documents the bug: passing full URL creates double-prefix. Always use relative paths!" - ) + absolute_url = f"{api_base}/v1/test/path" + + methods = [ + ("get", lambda url: client.get(url)), + ("post", lambda url: client.post(url, json={})), + ("put", lambda url: client.put(url, json={})), + ("patch", lambda url: client.patch(url, json={})), + ("delete", lambda url: client.delete(url)), + ] + + for method_name, method_call in methods: + with pytest.raises(ValueError, match="Absolute URL detected") as exc_info: + method_call(absolute_url) + error_msg = str(exc_info.value) + assert "Absolute URL detected" in error_msg, f"{method_name.upper()} should reject absolute URL" + assert absolute_url in error_msg if __name__ == "__main__": From 4528c556ad4d1a100832d2b80f78136595d6e8cf Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Mon, 3 Nov 2025 16:42:34 -0800 Subject: [PATCH 11/11] revert --- eval_protocol/adapters/fireworks_tracing.py | 14 +- eval_protocol/auth.py | 17 +- eval_protocol/evaluation.py | 53 ++- eval_protocol/fireworks_api_client.py | 155 ------- eval_protocol/fireworks_rft.py | 34 +- eval_protocol/platform_api.py | 35 +- eval_protocol/pytest/handle_persist_flow.py | 30 +- tests/test_fireworks_api_client.py | 482 -------------------- 8 files changed, 121 insertions(+), 699 deletions(-) delete mode 100644 eval_protocol/fireworks_api_client.py delete mode 100644 tests/test_fireworks_api_client.py diff --git a/eval_protocol/adapters/fireworks_tracing.py b/eval_protocol/adapters/fireworks_tracing.py index 233371cc..218f9d1d 100644 --- a/eval_protocol/adapters/fireworks_tracing.py +++ b/eval_protocol/adapters/fireworks_tracing.py @@ -273,7 +273,12 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) - if not tags: raise ValueError("At least one tag is required to fetch logs") - headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"} + from ..common_utils import get_user_agent + + headers = { + "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}", + "User-Agent": get_user_agent(), + } params: Dict[str, Any] = {"tags": tags, "limit": limit, "hours_back": hours_back, "program": "eval_protocol"} # Try /logs first, fall back to /v1/logs if not found @@ -398,7 +403,12 @@ def get_evaluation_rows( else: url = f"{self.base_url}/v1/traces/pointwise" - headers = {"Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}"} + from ..common_utils import get_user_agent + + headers = { + "Authorization": f"Bearer {os.environ.get('FIREWORKS_API_KEY')}", + "User-Agent": get_user_agent(), + } result = None try: diff --git a/eval_protocol/auth.py b/eval_protocol/auth.py index bdbd2912..3002bd4a 100644 --- a/eval_protocol/auth.py +++ b/eval_protocol/auth.py @@ -4,6 +4,8 @@ from pathlib import Path from typing import Dict, Optional # Added Dict +import requests + logger = logging.getLogger(__name__) # Default locations (used for tests and as fallback). Actual resolution is dynamic via _get_auth_ini_file(). @@ -240,11 +242,16 @@ def verify_api_key_and_get_account_id( if not resolved_key: return None resolved_base = api_base or get_fireworks_api_base() - - from .fireworks_api_client import FireworksAPIClient - client = FireworksAPIClient(api_key=resolved_key, api_base=resolved_base) - resp = client.get("verifyApiKey", timeout=10) - + + from .common_utils import get_user_agent + + url = f"{resolved_base.rstrip('/')}/verifyApiKey" + headers = { + "Authorization": f"Bearer {resolved_key}", + "User-Agent": get_user_agent(), + } + resp = requests.get(url, headers=headers, timeout=10) + if resp.status_code != 200: logger.debug("verifyApiKey returned status %s", resp.status_code) return None diff --git a/eval_protocol/evaluation.py b/eval_protocol/evaluation.py index af582bc2..7d7ef3da 100644 --- a/eval_protocol/evaluation.py +++ b/eval_protocol/evaluation.py @@ -20,7 +20,7 @@ get_fireworks_api_key, verify_api_key_and_get_account_id, ) -from eval_protocol.fireworks_api_client import FireworksAPIClient +from eval_protocol.common_utils import get_user_agent from eval_protocol.typed_interface import EvaluationMode from eval_protocol.get_pep440_version import get_pep440_version @@ -402,15 +402,20 @@ def preview(self, sample_file, max_samples=5): if "dev.api.fireworks.ai" in api_base and account_id == "fireworks": account_id = "pyroworks-dev" - client = FireworksAPIClient(api_key=auth_token, api_base=api_base) - path = f"v1/accounts/{account_id}/evaluators:previewEvaluator" - - logger.info(f"Previewing evaluator using API endpoint: {api_base}/{path} with account: {account_id}") + url = f"{api_base}/v1/accounts/{account_id}/evaluators:previewEvaluator" + headers = { + "Authorization": f"Bearer {auth_token}", + "Content-Type": "application/json", + "User-Agent": get_user_agent(), + } + logger.info(f"Previewing evaluator using API endpoint: {url} with account: {account_id}") + logger.debug(f"Preview API Request URL: {url}") + logger.debug(f"Preview API Request Headers: {json.dumps(headers, indent=2)}") logger.debug(f"Preview API Request Payload: {json.dumps(payload, indent=2)}") global used_preview_api try: - response = client.post(path, json=payload) + response = requests.post(url, json=payload, headers=headers) response.raise_for_status() result = response.json() used_preview_api = True @@ -741,8 +746,12 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) if "dev.api.fireworks.ai" in self.api_base and account_id == "fireworks": account_id = "pyroworks-dev" - client = FireworksAPIClient(api_key=auth_token, api_base=self.api_base) - path = f"v1/{parent}/evaluatorsV2" + base_url = f"{self.api_base}/v1/{parent}/evaluatorsV2" + headers = { + "Authorization": f"Bearer {auth_token}", + "Content-Type": "application/json", + "User-Agent": get_user_agent(), + } self._ensure_requirements_present(os.getcwd()) @@ -750,16 +759,16 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) try: if force: - check_path = f"v1/{parent}/evaluators/{evaluator_id}" + check_url = f"{self.api_base}/v1/{parent}/evaluators/{evaluator_id}" try: - logger.info(f"Checking if evaluator exists: {self.api_base}/{check_path}") - check_response = client.get(check_path) + logger.info(f"Checking if evaluator exists: {check_url}") + check_response = requests.get(check_url, headers=headers) if check_response.status_code == 200: logger.info(f"Evaluator '{evaluator_id}' already exists, deleting and recreating...") - delete_path = f"v1/{parent}/evaluators/{evaluator_id}" + delete_url = f"{self.api_base}/v1/{parent}/evaluators/{evaluator_id}" try: - delete_response = client.delete(delete_path) + delete_response = requests.delete(delete_url, headers=headers) if delete_response.status_code < 400: logger.info(f"Successfully deleted evaluator '{evaluator_id}'") else: @@ -768,14 +777,14 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) ) except Exception as e_del: logger.warning(f"Error deleting evaluator: {str(e_del)}") - response = client.post(path, json=payload_data) + response = requests.post(base_url, json=payload_data, headers=headers) else: - response = client.post(path, json=payload_data) + response = requests.post(base_url, json=payload_data, headers=headers) except requests.exceptions.RequestException: - response = client.post(path, json=payload_data) + response = requests.post(base_url, json=payload_data, headers=headers) else: - logger.info(f"Creating evaluator at: {self.api_base}/{path}") - response = client.post(path, json=payload_data) + logger.info(f"Creating evaluator at: {base_url}") + response = requests.post(base_url, json=payload_data, headers=headers) response.raise_for_status() result = response.json() @@ -800,11 +809,11 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) tar_size = self._create_tar_gz_with_ignores(tar_path, cwd) # Call GetEvaluatorUploadEndpoint - upload_endpoint_path = f"v1/{evaluator_name}:getUploadEndpoint" + upload_endpoint_url = f"{self.api_base}/v1/{evaluator_name}:getUploadEndpoint" upload_payload = {"name": evaluator_name, "filename_to_size": {tar_filename: tar_size}} logger.info(f"Requesting upload endpoint for {tar_filename}") - upload_response = client.post(upload_endpoint_path, json=upload_payload) + upload_response = requests.post(upload_endpoint_url, json=upload_payload, headers=headers) upload_response.raise_for_status() # Check for signed URLs @@ -884,9 +893,9 @@ def create(self, evaluator_id, display_name=None, description=None, force=False) raise # Step 3: Validate upload - validate_path = f"v1/{evaluator_name}:validateUpload" + validate_url = f"{self.api_base}/v1/{evaluator_name}:validateUpload" validate_payload = {"name": evaluator_name} - validate_response = client.post(validate_path, json=validate_payload) + validate_response = requests.post(validate_url, json=validate_payload, headers=headers) validate_response.raise_for_status() validate_data = validate_response.json() diff --git a/eval_protocol/fireworks_api_client.py b/eval_protocol/fireworks_api_client.py deleted file mode 100644 index 0e24fb9b..00000000 --- a/eval_protocol/fireworks_api_client.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Centralized client for making requests to Fireworks API with consistent headers.""" - -import os -from typing import Any, Dict, Optional - -import requests - -from .common_utils import get_user_agent - - -class FireworksAPIClient: - """Client for making authenticated requests to Fireworks API with proper headers. - - This client automatically includes: - - Authorization header (Bearer token) - - User-Agent header for tracking eval-protocol CLI usage - """ - - def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None): - """Initialize the Fireworks API client. - - Args: - api_key: Fireworks API key. If None, will be read from environment. - api_base: API base URL. If None, defaults to https://api.fireworks.ai - """ - self.api_key = api_key - self.api_base = api_base or os.environ.get("FIREWORKS_API_BASE", "https://api.fireworks.ai") - self._session = requests.Session() - - def _validate_path_is_relative(self, path: str) -> None: - """Validate that the path is relative, not an absolute URL. - - Args: - path: The path to validate - - Raises: - ValueError: If path appears to be an absolute URL (starts with http:// or https://) - """ - if path.startswith(("http://", "https://")): - raise ValueError( - f"Absolute URL detected: '{path}'. FireworksAPIClient methods expect relative paths only. " - f"Use a relative path like 'v1/path' instead of '{path}'. " - f"The client will automatically prepend the api_base: '{self.api_base}'" - ) - - def _get_headers( - self, content_type: Optional[str] = "application/json", additional_headers: Optional[Dict[str, str]] = None - ) -> Dict[str, str]: - """Build headers for API requests. - - Args: - content_type: Content-Type header value. If None, Content-Type won't be set. - additional_headers: Additional headers to merge in. - - Returns: - Dictionary of headers including authorization and user-agent. - """ - headers = { - "User-Agent": get_user_agent(), - } - - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - - if content_type: - headers["Content-Type"] = content_type - - if additional_headers: - headers.update(additional_headers) - - return headers - - def get( - self, path: str, params: Optional[Dict[str, Any]] = None, timeout: int = 30, **kwargs - ) -> requests.Response: - """Make a GET request to the Fireworks API. - - Args: - path: API path (relative to api_base) - params: Query parameters - timeout: Request timeout in seconds - **kwargs: Additional arguments passed to requests.get - - Returns: - Response object - """ - self._validate_path_is_relative(path) - url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" - headers = self._get_headers(content_type=None) - if "headers" in kwargs: - headers.update(kwargs.pop("headers")) - return self._session.get(url, params=params, headers=headers, timeout=timeout, **kwargs) - - def post( - self, - path: str, - json: Optional[Dict[str, Any]] = None, - data: Optional[Any] = None, - files: Optional[Dict[str, Any]] = None, - timeout: int = 60, - **kwargs, - ) -> requests.Response: - """Make a POST request to the Fireworks API. - - Args: - path: API path (relative to api_base) - json: JSON payload - data: Form data payload - files: Files to upload - timeout: Request timeout in seconds - **kwargs: Additional arguments passed to requests.post - - Returns: - Response object - """ - self._validate_path_is_relative(path) - url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" - - # For file uploads, don't set Content-Type (let requests handle multipart/form-data) - content_type = None if files else "application/json" - headers = self._get_headers(content_type=content_type) - - if "headers" in kwargs: - headers.update(kwargs.pop("headers")) - - return self._session.post(url, json=json, data=data, files=files, headers=headers, timeout=timeout, **kwargs) - - def put(self, path: str, json: Optional[Dict[str, Any]] = None, timeout: int = 60, **kwargs) -> requests.Response: - """Make a PUT request to the Fireworks API.""" - self._validate_path_is_relative(path) - url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" - headers = self._get_headers() - if "headers" in kwargs: - headers.update(kwargs.pop("headers")) - return self._session.put(url, json=json, headers=headers, timeout=timeout, **kwargs) - - def patch( - self, path: str, json: Optional[Dict[str, Any]] = None, timeout: int = 60, **kwargs - ) -> requests.Response: - """Make a PATCH request to the Fireworks API.""" - self._validate_path_is_relative(path) - url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" - headers = self._get_headers() - if "headers" in kwargs: - headers.update(kwargs.pop("headers")) - return self._session.patch(url, json=json, headers=headers, timeout=timeout, **kwargs) - - def delete(self, path: str, timeout: int = 30, **kwargs) -> requests.Response: - """Make a DELETE request to the Fireworks API.""" - self._validate_path_is_relative(path) - url = f"{self.api_base.rstrip('/')}/{path.lstrip('/')}" - headers = self._get_headers(content_type=None) - if "headers" in kwargs: - headers.update(kwargs.pop("headers")) - return self._session.delete(url, headers=headers, timeout=timeout, **kwargs) diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py index 0f2a1706..05b49291 100644 --- a/eval_protocol/fireworks_rft.py +++ b/eval_protocol/fireworks_rft.py @@ -11,7 +11,7 @@ import requests from .auth import get_fireworks_account_id, get_fireworks_api_base, get_fireworks_api_key -from .fireworks_api_client import FireworksAPIClient +from .common_utils import get_user_agent def _map_api_host_to_app_host(api_base: str) -> str: @@ -158,14 +158,18 @@ def create_dataset_from_jsonl( display_name: Optional[str], jsonl_path: str, ) -> Tuple[str, Dict[str, Any]]: - client = FireworksAPIClient(api_key=api_key, api_base=api_base) - + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "User-Agent": get_user_agent(), + } # Count examples quickly example_count = 0 with open(jsonl_path, "r", encoding="utf-8") as f: for _ in f: example_count += 1 - + + dataset_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets" payload = { "dataset": { "displayName": display_name or dataset_id, @@ -175,15 +179,19 @@ def create_dataset_from_jsonl( }, "datasetId": dataset_id, } - resp = client.post(f"v1/accounts/{account_id}/datasets", json=payload, timeout=60) + resp = requests.post(dataset_url, json=payload, headers=headers, timeout=60) if resp.status_code not in (200, 201): raise RuntimeError(f"Dataset creation failed: {resp.status_code} {resp.text}") ds = resp.json() + upload_url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/datasets/{dataset_id}:upload" with open(jsonl_path, "rb") as f: files = {"file": f} - up_resp = client.post(f"v1/accounts/{account_id}/datasets/{dataset_id}:upload", - files=files, timeout=600) + up_headers = { + "Authorization": f"Bearer {api_key}", + "User-Agent": get_user_agent(), + } + up_resp = requests.post(upload_url, files=files, headers=up_headers, timeout=600) if up_resp.status_code not in (200, 201): raise RuntimeError(f"Dataset upload failed: {up_resp.status_code} {up_resp.text}") return dataset_id, ds @@ -195,10 +203,14 @@ def create_reinforcement_fine_tuning_job( api_base: str, body: Dict[str, Any], ) -> Dict[str, Any]: - client = FireworksAPIClient(api_key=api_key, api_base=api_base) - resp = client.post(f"v1/accounts/{account_id}/reinforcementFineTuningJobs", - json=body, timeout=60, - headers={"Accept": "application/json"}) + url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/reinforcementFineTuningJobs" + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "Accept": "application/json", + "User-Agent": get_user_agent(), + } + resp = requests.post(url, json=body, headers=headers, timeout=60) if resp.status_code not in (200, 201): raise RuntimeError(f"RFT job creation failed: {resp.status_code} {resp.text}") return resp.json() diff --git a/eval_protocol/platform_api.py b/eval_protocol/platform_api.py index 32dc141e..81754e13 100644 --- a/eval_protocol/platform_api.py +++ b/eval_protocol/platform_api.py @@ -11,7 +11,7 @@ get_fireworks_api_base, get_fireworks_api_key, ) -from eval_protocol.fireworks_api_client import FireworksAPIClient +from eval_protocol.common_utils import get_user_agent logger = logging.getLogger(__name__) @@ -93,7 +93,11 @@ def create_or_update_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for creating/updating secret.") return False - client = FireworksAPIClient(api_key=resolved_api_key, api_base=resolved_api_base) + headers = { + "Authorization": f"Bearer {resolved_api_key}", + "Content-Type": "application/json", + "User-Agent": get_user_agent(), + } # The secret_id for GET/PATCH/DELETE operations is the key_name. # The 'name' field in the gatewaySecret model for POST/PATCH is a bit ambiguous. @@ -107,7 +111,8 @@ def create_or_update_fireworks_secret( resource_id = _normalize_secret_resource_id(key_name) secret_exists = False try: - response = client.get(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", timeout=10) + url = f"{resolved_api_base}/v1/accounts/{resolved_account_id}/secrets/{resource_id}" + response = requests.get(url, headers=headers, timeout=10) if response.status_code == 200: secret_exists = True logger.info(f"Secret '{key_name}' already exists. Will attempt to update.") @@ -142,8 +147,8 @@ def create_or_update_fireworks_secret( payload = {"keyName": payload_key_name, "value": secret_value} try: logger.debug(f"PATCH payload for '{key_name}': {payload}") - response = client.patch(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", - json=payload, timeout=30) + url = f"{resolved_api_base}/v1/accounts/{resolved_account_id}/secrets/{resource_id}" + response = requests.patch(url, json=payload, headers=headers, timeout=30) response.raise_for_status() logger.info(f"Successfully updated secret '{key_name}' on Fireworks platform.") return True @@ -179,8 +184,8 @@ def create_or_update_fireworks_secret( } try: logger.debug(f"POST payload for '{key_name}': {payload}") - response = client.post(f"v1/accounts/{resolved_account_id}/secrets", - json=payload, timeout=30) + url = f"{resolved_api_base}/v1/accounts/{resolved_account_id}/secrets" + response = requests.post(url, json=payload, headers=headers, timeout=30) response.raise_for_status() logger.info( f"Successfully created secret '{key_name}' on Fireworks platform. Full name: {response.json().get('name')}" @@ -214,11 +219,15 @@ def get_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for getting secret.") return None - client = FireworksAPIClient(api_key=resolved_api_key, api_base=resolved_api_base) + headers = { + "Authorization": f"Bearer {resolved_api_key}", + "User-Agent": get_user_agent(), + } resource_id = _normalize_secret_resource_id(key_name) try: - response = client.get(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", timeout=10) + url = f"{resolved_api_base}/v1/accounts/{resolved_account_id}/secrets/{resource_id}" + response = requests.get(url, headers=headers, timeout=10) if response.status_code == 200: logger.info(f"Successfully retrieved secret '{key_name}'.") return response.json() @@ -250,11 +259,15 @@ def delete_fireworks_secret( logger.error("Missing Fireworks API key, base URL, or account ID for deleting secret.") return False - client = FireworksAPIClient(api_key=resolved_api_key, api_base=resolved_api_base) + headers = { + "Authorization": f"Bearer {resolved_api_key}", + "User-Agent": get_user_agent(), + } resource_id = _normalize_secret_resource_id(key_name) try: - response = client.delete(f"v1/accounts/{resolved_account_id}/secrets/{resource_id}", timeout=30) + url = f"{resolved_api_base}/v1/accounts/{resolved_account_id}/secrets/{resource_id}" + response = requests.delete(url, headers=headers, timeout=30) if response.status_code == 200 or response.status_code == 204: # 204 No Content is also success for DELETE logger.info(f"Successfully deleted secret '{key_name}'.") return True diff --git a/eval_protocol/pytest/handle_persist_flow.py b/eval_protocol/pytest/handle_persist_flow.py index 184c0077..07a627f3 100644 --- a/eval_protocol/pytest/handle_persist_flow.py +++ b/eval_protocol/pytest/handle_persist_flow.py @@ -7,11 +7,13 @@ import re from typing import Any +from eval_protocol.common_utils import get_user_agent from eval_protocol.directory_utils import find_eval_protocol_dir -from eval_protocol.fireworks_api_client import FireworksAPIClient from eval_protocol.models import EvaluationRow from eval_protocol.pytest.store_experiment_link import store_experiment_link +import requests + def handle_persist_flow(all_results: list[list[EvaluationRow]], test_func_name: str): try: @@ -127,7 +129,12 @@ def get_auth_value(key: str) -> str | None: ) continue - client = FireworksAPIClient(api_key=fireworks_api_key, api_base="https://api.fireworks.ai") + api_base = "https://api.fireworks.ai" + headers = { + "Authorization": f"Bearer {fireworks_api_key}", + "Content-Type": "application/json", + "User-Agent": get_user_agent(), + } # Make dataset first @@ -141,9 +148,8 @@ def get_auth_value(key: str) -> str | None: "datasetId": dataset_name, } - dataset_response = client.post( - f"v1/accounts/{fireworks_account_id}/datasets", json=dataset_payload - ) # pyright: ignore[reportUnknownArgumentType] + dataset_url = f"{api_base}/v1/accounts/{fireworks_account_id}/datasets" + dataset_response = requests.post(dataset_url, json=dataset_payload, headers=headers) # pyright: ignore[reportUnknownArgumentType] # Skip if dataset creation failed if dataset_response.status_code not in [200, 201]: @@ -158,11 +164,14 @@ def get_auth_value(key: str) -> str | None: dataset_id = dataset_data.get("datasetId", dataset_name) # pyright: ignore[reportAny] # Upload the JSONL file content + upload_url = f"{api_base}/v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload" with open(exp_file, "rb") as f: files = {"file": f} - upload_response = client.post( - f"v1/accounts/{fireworks_account_id}/datasets/{dataset_id}:upload", files=files - ) + upload_headers = { + "Authorization": f"Bearer {fireworks_api_key}", + "User-Agent": get_user_agent(), + } + upload_response = requests.post(upload_url, files=files, headers=upload_headers) # Skip if upload failed if upload_response.status_code not in [200, 201]: @@ -191,9 +200,8 @@ def get_auth_value(key: str) -> str | None: }, } - eval_response = client.post( - f"v1/accounts/{fireworks_account_id}/evaluationJobs", json=eval_job_payload - ) + eval_job_url = f"{api_base}/v1/accounts/{fireworks_account_id}/evaluationJobs" + eval_response = requests.post(eval_job_url, json=eval_job_payload, headers=headers) if eval_response.status_code in [200, 201]: eval_job_data = eval_response.json() # pyright: ignore[reportAny] diff --git a/tests/test_fireworks_api_client.py b/tests/test_fireworks_api_client.py deleted file mode 100644 index 4965c8ae..00000000 --- a/tests/test_fireworks_api_client.py +++ /dev/null @@ -1,482 +0,0 @@ -"""Tests for FireworksAPIClient user-agent header functionality.""" - -import re -from unittest.mock import MagicMock, patch - -import pytest - -from eval_protocol.common_utils import get_user_agent -from eval_protocol.fireworks_api_client import FireworksAPIClient - - -class TestFireworksAPIClientUserAgent: - """Test that FireworksAPIClient correctly sets the User-Agent header.""" - - def test_get_user_agent_format(self): - """Test that get_user_agent returns the expected format.""" - user_agent = get_user_agent() - # Should match format: eval-protocol/{version} - # Version can be actual version or "unknown" - assert user_agent.startswith("eval-protocol/") - assert len(user_agent) > len("eval-protocol/") - - def test_get_user_agent_fallback_logic(self): - """Test that get_user_agent has fallback logic for when version can't be imported. - - This test verifies the code structure, since actually triggering an import - failure during the import statement is difficult to test reliably. - The important behavior (User-Agent header being set) is verified in other tests. - """ - # Verify the function exists and can be called normally - user_agent = get_user_agent() - # The function should always return a valid user agent string - assert isinstance(user_agent, str) - assert user_agent.startswith("eval-protocol/") - - # The actual fallback ("eval-protocol/unknown") happens when the import - # fails, which is hard to simulate without patching at a very low level. - # The try/except block in the implementation handles this gracefully. - - def test_get_headers_includes_user_agent(self): - """Test that _get_headers includes the User-Agent header.""" - client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") - headers = client._get_headers() - - assert "User-Agent" in headers - assert headers["User-Agent"] == get_user_agent() - - def test_get_request_includes_user_agent(self): - """Test that GET requests include the User-Agent header.""" - client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "get", return_value=mock_response) as mock_get: - client.get("test/path") - - mock_get.assert_called_once() - call_kwargs = mock_get.call_args[1] - headers = call_kwargs["headers"] - - assert "User-Agent" in headers - assert headers["User-Agent"] == get_user_agent() - assert headers["Authorization"] == "Bearer test_key" - - def test_post_request_includes_user_agent(self): - """Test that POST requests include the User-Agent header.""" - client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "post", return_value=mock_response) as mock_post: - client.post("test/path", json={"key": "value"}) - - mock_post.assert_called_once() - call_kwargs = mock_post.call_args[1] - headers = call_kwargs["headers"] - - assert "User-Agent" in headers - assert headers["User-Agent"] == get_user_agent() - assert headers["Authorization"] == "Bearer test_key" - assert headers["Content-Type"] == "application/json" - - def test_post_with_files_excludes_content_type(self): - """Test that POST requests with files exclude Content-Type header.""" - client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "post", return_value=mock_response) as mock_post: - client.post("test/path", files={"file": MagicMock()}) - - mock_post.assert_called_once() - call_kwargs = mock_post.call_args[1] - headers = call_kwargs["headers"] - - assert "User-Agent" in headers - assert headers["User-Agent"] == get_user_agent() - # Content-Type should not be set when files are present - assert "Content-Type" not in headers - - def test_put_request_includes_user_agent(self): - """Test that PUT requests include the User-Agent header.""" - client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "put", return_value=mock_response) as mock_put: - client.put("test/path", json={"key": "value"}) - - mock_put.assert_called_once() - call_kwargs = mock_put.call_args[1] - headers = call_kwargs["headers"] - - assert "User-Agent" in headers - assert headers["User-Agent"] == get_user_agent() - assert headers["Authorization"] == "Bearer test_key" - - def test_patch_request_includes_user_agent(self): - """Test that PATCH requests include the User-Agent header.""" - client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "patch", return_value=mock_response) as mock_patch: - client.patch("test/path", json={"key": "value"}) - - mock_patch.assert_called_once() - call_kwargs = mock_patch.call_args[1] - headers = call_kwargs["headers"] - - assert "User-Agent" in headers - assert headers["User-Agent"] == get_user_agent() - assert headers["Authorization"] == "Bearer test_key" - - def test_delete_request_includes_user_agent(self): - """Test that DELETE requests include the User-Agent header.""" - client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "delete", return_value=mock_response) as mock_delete: - client.delete("test/path") - - mock_delete.assert_called_once() - call_kwargs = mock_delete.call_args[1] - headers = call_kwargs["headers"] - - assert "User-Agent" in headers - assert headers["User-Agent"] == get_user_agent() - assert headers["Authorization"] == "Bearer test_key" - # DELETE requests shouldn't have Content-Type - assert "Content-Type" not in headers - - def test_additional_headers_merged(self): - """Test that additional headers passed to requests are merged with User-Agent.""" - client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "get", return_value=mock_response) as mock_get: - client.get("test/path", headers={"X-Custom-Header": "custom-value"}) - - mock_get.assert_called_once() - call_kwargs = mock_get.call_args[1] - headers = call_kwargs["headers"] - - assert "User-Agent" in headers - assert headers["User-Agent"] == get_user_agent() - assert headers["X-Custom-Header"] == "custom-value" - - def test_user_agent_consistent_across_methods(self): - """Test that User-Agent is consistent across all HTTP methods.""" - client = FireworksAPIClient(api_key="test_key", api_base="https://api.fireworks.ai") - - mock_response = MagicMock() - mock_response.status_code = 200 - - expected_user_agent = get_user_agent() - - # Test all methods - methods = [ - ("get", lambda: client.get("test/path")), - ("post", lambda: client.post("test/path", json={})), - ("put", lambda: client.put("test/path", json={})), - ("patch", lambda: client.patch("test/path", json={})), - ("delete", lambda: client.delete("test/path")), - ] - - for method_name, method_call in methods: - with patch.object(client._session, method_name, return_value=mock_response) as mock_method: - method_call() - - call_kwargs = mock_method.call_args[1] - headers = call_kwargs["headers"] - - assert "User-Agent" in headers, f"{method_name} should include User-Agent" - assert headers["User-Agent"] == expected_user_agent, ( - f"{method_name} User-Agent should match expected value" - ) - - def test_user_agent_without_api_key(self): - """Test that User-Agent is still included even without API key.""" - client = FireworksAPIClient(api_key=None, api_base="https://api.fireworks.ai") - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "get", return_value=mock_response) as mock_get: - client.get("test/path") - - mock_get.assert_called_once() - call_kwargs = mock_get.call_args[1] - headers = call_kwargs["headers"] - - assert "User-Agent" in headers - assert headers["User-Agent"] == get_user_agent() - # Authorization should not be present - assert "Authorization" not in headers - - -class TestFireworksAPIClientPathHandling: - """Test that FireworksAPIClient correctly handles relative paths and prevents URL construction bugs.""" - - def test_post_relative_path_combines_with_api_base(self): - """Test that POST requests correctly combine relative paths with api_base.""" - api_base = "https://api.fireworks.ai" - client = FireworksAPIClient(api_key="test_key", api_base=api_base) - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "post", return_value=mock_response) as mock_post: - relative_path = "v1/test/evaluator:getUploadEndpoint" - client.post(relative_path, json={"name": "test"}) - - mock_post.assert_called_once() - call_args = mock_post.call_args - # Check the URL passed to requests.post - assert call_args[0][0] == f"{api_base}/{relative_path}" - assert not call_args[0][0].startswith(f"{api_base}/{api_base}") - - def test_get_relative_path_combines_with_api_base(self): - """Test that GET requests correctly combine relative paths with api_base.""" - api_base = "https://api.fireworks.ai" - client = FireworksAPIClient(api_key="test_key", api_base=api_base) - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "get", return_value=mock_response) as mock_get: - relative_path = "verifyApiKey" - client.get(relative_path) - - mock_get.assert_called_once() - call_args = mock_get.call_args - assert call_args[0][0] == f"{api_base}/{relative_path}" - - def test_post_get_upload_endpoint_path(self): - """Test the specific getUploadEndpoint path that was buggy. - - This ensures relative paths like 'v1/{name}:getUploadEndpoint' are handled correctly - and don't get double-prefixed with api_base. - """ - api_base = "https://api.fireworks.ai" - client = FireworksAPIClient(api_key="test_key", api_base=api_base) - - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = {"filenameToSignedUrls": {"test.tar.gz": "https://signed.url"}} - - with patch.object(client._session, "post", return_value=mock_response) as mock_post: - evaluator_name = "test-evaluator" - # This is the correct pattern - relative path, not full URL - upload_endpoint_path = f"v1/{evaluator_name}:getUploadEndpoint" - client.post(upload_endpoint_path, json={"name": evaluator_name}) - - mock_post.assert_called_once() - call_args = mock_post.call_args - expected_url = f"{api_base}/{upload_endpoint_path}" - actual_url = call_args[0][0] - assert actual_url == expected_url, f"Expected {expected_url}, got {actual_url}" - # Ensure it doesn't have the buggy double-prefix - assert not actual_url.startswith(f"{api_base}/{api_base}") - - def test_post_validate_upload_path(self): - """Test the specific validateUpload path that was buggy. - - This ensures relative paths like 'v1/{name}:validateUpload' are handled correctly. - """ - api_base = "https://api.fireworks.ai" - client = FireworksAPIClient(api_key="test_key", api_base=api_base) - - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = {"status": "validated"} - - with patch.object(client._session, "post", return_value=mock_response) as mock_post: - evaluator_name = "test-evaluator" - # This is the correct pattern - relative path, not full URL - validate_path = f"v1/{evaluator_name}:validateUpload" - client.post(validate_path, json={"name": evaluator_name}) - - mock_post.assert_called_once() - call_args = mock_post.call_args - expected_url = f"{api_base}/{validate_path}" - actual_url = call_args[0][0] - assert actual_url == expected_url, f"Expected {expected_url}, got {actual_url}" - # Ensure it doesn't have the buggy double-prefix - assert not actual_url.startswith(f"{api_base}/{api_base}") - - def test_path_with_leading_slash_stripped(self): - """Test that leading slashes in paths are correctly handled.""" - api_base = "https://api.fireworks.ai" - client = FireworksAPIClient(api_key="test_key", api_base=api_base) - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "get", return_value=mock_response) as mock_get: - # Path with leading slash should be handled correctly - client.get("/v1/test/path") - - mock_get.assert_called_once() - call_args = mock_get.call_args - # Should not have double slash - assert call_args[0][0] == f"{api_base}/v1/test/path" - - def test_api_base_with_trailing_slash(self): - """Test that api_base with trailing slash is handled correctly.""" - api_base = "https://api.fireworks.ai/" - client = FireworksAPIClient(api_key="test_key", api_base=api_base) - - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "post", return_value=mock_response) as mock_post: - relative_path = "v1/test/path" - client.post(relative_path, json={}) - - mock_post.assert_called_once() - call_args = mock_post.call_args - # Should not have double slash - assert call_args[0][0] == f"https://api.fireworks.ai/{relative_path}" - - def test_all_http_methods_with_relative_paths(self): - """Test that all HTTP methods correctly handle relative paths.""" - api_base = "https://api.fireworks.ai" - client = FireworksAPIClient(api_key="test_key", api_base=api_base) - - mock_response = MagicMock() - mock_response.status_code = 200 - - test_path = "v1/accounts/test/evaluators" - - methods = [ - ("get", lambda p: client.get(p)), - ("post", lambda p: client.post(p, json={})), - ("put", lambda p: client.put(p, json={})), - ("patch", lambda p: client.patch(p, json={})), - ("delete", lambda p: client.delete(p)), - ] - - for method_name, method_call in methods: - with patch.object(client._session, method_name, return_value=mock_response) as mock_method: - method_call(test_path) - - mock_method.assert_called_once() - call_args = mock_method.call_args - expected_url = f"{api_base}/{test_path}" - actual_url = call_args[0][0] - assert actual_url == expected_url, f"{method_name.upper()} expected {expected_url}, got {actual_url}" - # Ensure no double-prefix bug - assert not actual_url.startswith(f"{api_base}/{api_base}"), ( - f"{method_name.upper()} URL has double-prefix bug: {actual_url}" - ) - - def test_paths_containing_v1_pattern(self): - """Test various v1 API paths to ensure correct URL construction.""" - api_base = "https://api.fireworks.ai" - client = FireworksAPIClient(api_key="test_key", api_base=api_base) - - mock_response = MagicMock() - mock_response.status_code = 200 - - test_cases = [ - "v1/accounts/test/evaluators", - "v1/accounts/test/evaluators/eval-id", - "v1/accounts/test/evaluatorsV2", - "v1/accounts/test/evaluators:previewEvaluator", - "v1/test-evaluator:getUploadEndpoint", - "v1/test-evaluator:validateUpload", - ] - - with patch.object(client._session, "post", return_value=mock_response) as mock_post: - for path in test_cases: - client.post(path, json={}) - - call_args = mock_post.call_args - actual_url = call_args[0][0] - expected_url = f"{api_base}/{path}" - - assert actual_url == expected_url, ( - f"Path '{path}' resulted in URL '{actual_url}', expected '{expected_url}'" - ) - assert not actual_url.startswith(f"{api_base}/{api_base}"), ( - f"Path '{path}' has double-prefix bug: {actual_url}" - ) - - mock_post.reset_mock() - - def test_full_url_passed_by_mistake_raises_error(self): - """Test that accidentally passing a full URL instead of relative path raises ValueError. - - This test verifies that our code correctly catches the bug early by raising an error - when an absolute URL is passed instead of a relative path. - """ - api_base = "https://api.fireworks.ai" - client = FireworksAPIClient(api_key="test_key", api_base=api_base) - - # CORRECT: Relative path (what we should use) - should work fine - mock_response = MagicMock() - mock_response.status_code = 200 - - with patch.object(client._session, "post", return_value=mock_response) as mock_post: - correct_relative_path = "v1/test-evaluator:getUploadEndpoint" - client.post(correct_relative_path, json={}) - - call_args = mock_post.call_args - correct_url = call_args[0][0] - expected_correct_url = f"{api_base}/{correct_relative_path}" - assert correct_url == expected_correct_url - - # INCORRECT: Full URL should raise ValueError - full_url_with_http = "https://api.fireworks.ai/v1/test-evaluator:getUploadEndpoint" - with pytest.raises(ValueError, match="Absolute URL detected"): - client.post(full_url_with_http, json={}) - - full_url_with_http_scheme = "http://api.fireworks.ai/v1/test-evaluator:getUploadEndpoint" - with pytest.raises(ValueError, match="Absolute URL detected"): - client.post(full_url_with_http_scheme, json={}) - - # Test that error message is helpful - with pytest.raises(ValueError) as exc_info: - client.post(full_url_with_http, json={}) - error_msg = str(exc_info.value) - assert "Absolute URL detected" in error_msg - assert full_url_with_http in error_msg - assert "relative paths only" in error_msg - assert api_base in error_msg # Should mention api_base in the help message - - def test_all_methods_reject_absolute_urls(self): - """Test that all HTTP methods reject absolute URLs.""" - api_base = "https://api.fireworks.ai" - client = FireworksAPIClient(api_key="test_key", api_base=api_base) - - absolute_url = f"{api_base}/v1/test/path" - - methods = [ - ("get", lambda url: client.get(url)), - ("post", lambda url: client.post(url, json={})), - ("put", lambda url: client.put(url, json={})), - ("patch", lambda url: client.patch(url, json={})), - ("delete", lambda url: client.delete(url)), - ] - - for method_name, method_call in methods: - with pytest.raises(ValueError, match="Absolute URL detected") as exc_info: - method_call(absolute_url) - error_msg = str(exc_info.value) - assert "Absolute URL detected" in error_msg, f"{method_name.upper()} should reject absolute URL" - assert absolute_url in error_msg - - -if __name__ == "__main__": - pytest.main([__file__, "-v"])