Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
0a00bcd
Refactor Fireworks client integration
Jan 7, 2026
d465a89
remove launch.json
Jan 7, 2026
348bb58
Add .vscode/launch.json to .gitignore
Jan 7, 2026
acaa901
Enhance environment variable loading in auth module
Jan 7, 2026
4b71ddb
Add evaluator version creation in evaluation module
Jan 7, 2026
3dbcd59
test
Jan 8, 2026
532e071
REVERT this later
Jan 8, 2026
5e7a5fa
Merge branch 'main' into dhuang/dxe-478-implement-evaluator-versions
Jan 8, 2026
060d72c
fix mock tests
Jan 9, 2026
bc31c9f
Add error handling for evaluator creation in evaluation module
Jan 9, 2026
ea08062
Support EP_REMOTE_API_KEY
Jan 9, 2026
f246087
Merge branch 'main' into dhuang/dxe-478-implement-evaluator-versions
Jan 9, 2026
6b53ac1
include launch.json.backup
Jan 12, 2026
ec0c8ca
rename to .example and add docker run extra arg
Jan 12, 2026
fc036f5
use ignore-docker by default
Jan 12, 2026
4566584
delete backup
Jan 13, 2026
f103b69
ignore-docker by default in dev
Jan 13, 2026
9c3e417
Refactor evaluator function calls to use Fireworks directly for metho…
Jan 13, 2026
ea673f4
use in-flight SDK version
Jan 13, 2026
26fbc2d
Enhance evaluator handling by returning version ID on creation and up…
Jan 13, 2026
4702307
update
Jan 13, 2026
9d1bc74
use published a22 of fireworks-ai
Jan 13, 2026
3314bec
uv lock
Jan 13, 2026
66f191a
Refactor dotenv handling in auth module and integrate environment var…
Jan 14, 2026
165afe1
add create rft launch configuration
Jan 14, 2026
838c7a5
Refactor dotenv handling in auth module and integrate environment var…
Jan 14, 2026
71599e6
Merge branch 'pass-dot-env-to-docker-container' into dhuang/dxe-478-i…
Jan 14, 2026
0144c9f
actually not necessary for local test since local-test mounts the wor…
Jan 14, 2026
c8774a6
increase sql retries
Jan 14, 2026
2076f0a
Refactor dotenv loading to use explicit paths in CLI and API modules
Jan 14, 2026
8acdc35
Merge branch 'main' into dhuang/dxe-478-implement-evaluator-versions
Jan 14, 2026
432a649
Refactor dotenv loading to use explicit paths in CLI and API modules
Jan 14, 2026
ab04086
Merge branch 'ensure-explicit-dotenv' into dhuang/dxe-478-implement-e…
Jan 14, 2026
3c2db59
"ep create evj"
Jan 14, 2026
17eb18f
use SDK for Dataset API calls
Jan 14, 2026
1fd66f7
Implement evaluator upload and status polling in create commands
Jan 15, 2026
fc4f913
Add secret management for uploads in CLI
Jan 15, 2026
2f88428
handle existing secrets with caution
Jan 15, 2026
c6a8c51
Integrate secrets upload handling in CLI commands
Jan 15, 2026
a2165fb
Remove unused `_to_pyargs_nodeid` function from `upload.py` to enhanc…
Jan 15, 2026
1445d75
increase sql retries
Jan 14, 2026
7969a6e
Refactor secret loading in CLI to use python-dotenv
Jan 15, 2026
d4a445b
make connection more robust
Jan 16, 2026
b3adfee
Merge branch 'increase-sql-retries' into dhuang/dxe-478-implement-eva…
Jan 16, 2026
37f4856
passes
Jan 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,5 @@ package.json
tau2-bench
*.err
eval-protocol

.vscode/launch.json
39 changes: 0 additions & 39 deletions .vscode/launch.json

This file was deleted.

30 changes: 29 additions & 1 deletion eval_protocol/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,30 @@
from typing import Optional

import requests
from dotenv import find_dotenv, load_dotenv

logger = logging.getLogger(__name__)

# --- Load .env files ---
# Import-time side effect: populate the process environment from a dotenv file.
# .env.dev is preferred; .env is only searched for when no .env.dev is found.
# override=False is passed so that variables already present in the environment
# (e.g., exported in the shell) are NOT replaced by values from the file.
_ENV_DEV_PATH = find_dotenv(filename=".env.dev", raise_error_if_not_found=False, usecwd=True)
if not _ENV_DEV_PATH:
    # No .env.dev located: fall back to a plain .env file.
    _ENV_PATH = find_dotenv(filename=".env", raise_error_if_not_found=False, usecwd=True)
    if not _ENV_PATH:
        logger.debug(
            "eval_protocol.auth: No .env.dev or .env file found. Relying on shell/existing environment variables."
        )
    else:
        load_dotenv(dotenv_path=_ENV_PATH, override=False)
        logger.debug(f"eval_protocol.auth: Loaded environment variables from: {_ENV_PATH}")
else:
    load_dotenv(dotenv_path=_ENV_DEV_PATH, override=False)
    logger.debug(f"eval_protocol.auth: Loaded environment variables from: {_ENV_DEV_PATH}")
# --- End .env loading ---


def get_fireworks_api_key() -> Optional[str]:
"""
Expand Down Expand Up @@ -73,6 +94,8 @@ def verify_api_key_and_get_account_id(
Args:
api_key: Optional explicit API key. When None, resolves via get_fireworks_api_key().
api_base: Optional explicit API base. When None, resolves via get_fireworks_api_base().
If api_base is api.fireworks.ai, it is used directly. Otherwise, defaults to
dev.api.fireworks.ai for the verification call.

Returns:
The resolved account id if verification succeeds and the header is present; otherwise None.
Expand All @@ -81,7 +104,12 @@ def verify_api_key_and_get_account_id(
resolved_key = api_key or get_fireworks_api_key()
if not resolved_key:
return None
resolved_base = api_base or get_fireworks_api_base()
provided_base = api_base or get_fireworks_api_base()
# Use api.fireworks.ai if explicitly provided, otherwise fall back to dev
if "api.fireworks.ai" in provided_base:
resolved_base = provided_base
else:
resolved_base = "https://dev.api.fireworks.ai"
Comment thread
cursor[bot] marked this conversation as resolved.

from .common_utils import get_user_agent

Expand Down
7 changes: 3 additions & 4 deletions eval_protocol/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import sys
from pathlib import Path

from fireworks import Fireworks

from .cli_commands.common import setup_logging
from .cli_commands.utils import add_args_from_callable_signature
from .fireworks_client import create_fireworks_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -88,7 +87,7 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
)

# Auto-generate flags from SDK Fireworks().evaluators.create() signature
create_evaluator_fn = Fireworks().evaluators.create
create_evaluator_fn = create_fireworks_client().evaluators.create
Comment thread
dphuang2 marked this conversation as resolved.
Outdated

upload_skip_fields = {
"__top_level__": {
Expand Down Expand Up @@ -198,7 +197,7 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
"loss_config.method": "RL loss method for underlying trainers. One of {grpo,dapo}.",
}

create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create
create_rft_job_fn = create_fireworks_client().reinforcement_fine_tuning_jobs.create

add_args_from_callable_signature(
rft_parser,
Expand Down
7 changes: 3 additions & 4 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pydantic import ValidationError

from ..auth import get_fireworks_api_base, get_fireworks_api_key
from ..fireworks_client import create_fireworks_client
from ..common_utils import get_user_agent, load_jsonl
from ..fireworks_rft import (
create_dataset_from_jsonl,
Expand All @@ -35,8 +36,6 @@
)
from .local_test import run_evaluator_test

from fireworks import Fireworks


def _extract_dataset_adapter(
test_file_path: str, test_func_name: str
Expand Down Expand Up @@ -672,7 +671,7 @@ def _create_rft_job(
) -> int:
"""Build and submit the RFT job request (via Fireworks SDK)."""

signature = inspect.signature(Fireworks().reinforcement_fine_tuning_jobs.create)
signature = inspect.signature(create_fireworks_client().reinforcement_fine_tuning_jobs.create)
Comment thread
cursor[bot] marked this conversation as resolved.

# Build top-level SDK kwargs
sdk_kwargs: Dict[str, Any] = {
Expand Down Expand Up @@ -711,7 +710,7 @@ def _create_rft_job(
return 0

try:
fw: Fireworks = Fireworks(api_key=api_key, base_url=api_base)
fw: Fireworks = create_fireworks_client(api_key=api_key, base_url=api_base)
job: ReinforcementFineTuningJob = fw.reinforcement_fine_tuning_jobs.create(account_id=account_id, **sdk_kwargs)
job_name = job.name
print(f"\n✅ Created Reinforcement Fine-tuning Job: {job_name}")
Expand Down
35 changes: 30 additions & 5 deletions eval_protocol/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from typing import List, Optional

import fireworks
from fireworks.types import EvaluatorVersionParam

Check failure on line 7 in eval_protocol/evaluation.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

"EvaluatorVersionParam" is unknown import symbol (reportAttributeAccessIssue)
import requests
from fireworks import Fireworks

from eval_protocol.auth import (
get_fireworks_account_id,
get_fireworks_api_key,
verify_api_key_and_get_account_id,
)
from eval_protocol.fireworks_client import create_fireworks_client
from eval_protocol.get_pep440_version import get_pep440_version

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -164,7 +165,11 @@
logger.error("Authentication error: API credentials appear to be invalid or incomplete.")
raise ValueError("Invalid or missing API credentials.")

client = Fireworks(api_key=auth_token, base_url=self.api_base, account_id=account_id)
client = create_fireworks_client(
api_key=auth_token,
base_url=self.api_base,
account_id=account_id,
)

self.display_name = display_name or evaluator_id
self.description = description or f"Evaluator created from {evaluator_id}"
Expand Down Expand Up @@ -230,6 +235,25 @@
f"Cannot proceed with code upload. Response: {result}"
)

evaluator_version_param: EvaluatorVersionParam = {}
if "commit_hash" in evaluator_params:
evaluator_version_param["commit_hash"] = evaluator_params["commit_hash"]
if "entry_point" in evaluator_params:
evaluator_version_param["entry_point"] = evaluator_params["entry_point"]
if "requirements" in evaluator_params:
evaluator_version_param["requirements"] = evaluator_params["requirements"]
Comment thread
cursor[bot] marked this conversation as resolved.

evaluator_version = client.evaluator_versions.create(

Check failure on line 246 in eval_protocol/evaluation.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Cannot access attribute "evaluator_versions" for class "Fireworks"   Attribute "evaluator_versions" is unknown (reportAttributeAccessIssue)
evaluator_id=evaluator_id,
evaluator_version=evaluator_version_param,
)
evaluator_version_id = evaluator_version.name.split("/")[-1] if evaluator_version.name else None
if not evaluator_version_id:
raise ValueError(
"Create evaluator version response missing 'name' field. "
f"Cannot proceed with code upload. Response: {evaluator_version}"
)

try:
# Create tar.gz of current directory
cwd = os.getcwd()
Expand All @@ -241,7 +265,8 @@

# Call GetEvaluatorUploadEndpoint using SDK
logger.info(f"Requesting upload endpoint for {tar_filename}")
upload_response = client.evaluators.get_upload_endpoint(
upload_response = client.evaluator_versions.get_upload_endpoint(

Check failure on line 268 in eval_protocol/evaluation.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Cannot access attribute "evaluator_versions" for class "Fireworks"   Attribute "evaluator_versions" is unknown (reportAttributeAccessIssue)
version_id=evaluator_version_id,
evaluator_id=evaluator_id,
filename_to_size={tar_filename: str(tar_size)},
)
Expand Down Expand Up @@ -322,9 +347,9 @@
raise

# Step 3: Validate upload using SDK
client.evaluators.validate_upload(
client.evaluator_versions.validate_upload(

Check failure on line 350 in eval_protocol/evaluation.py

View workflow job for this annotation

GitHub Actions / Lint & Type Check

Cannot access attribute "evaluator_versions" for class "Fireworks"   Attribute "evaluator_versions" is unknown (reportAttributeAccessIssue)
version_id=evaluator_version_id,
evaluator_id=evaluator_id,
body={},
)
logger.info("Upload validated successfully")

Expand Down
132 changes: 132 additions & 0 deletions eval_protocol/fireworks_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
Consolidated Fireworks client factory.

This module provides a single point of instantiation for the Fireworks SDK client,
ensuring consistent handling of environment variables and configuration across the
eval_protocol codebase.

Environment variables:
FIREWORKS_API_KEY: API key for authentication (required)
FIREWORKS_ACCOUNT_ID: Account ID (optional, can be derived from API key)
FIREWORKS_API_BASE: Base URL for the API (default: https://api.fireworks.ai)
FIREWORKS_EXTRA_HEADERS: JSON-encoded extra headers to include in requests
Example: '{"X-Custom-Header": "value", "X-Another": "another-value"}'
"""

import json
import logging
import os
from typing import Mapping, Optional

from fireworks import Fireworks

from eval_protocol.auth import (
get_fireworks_account_id,
get_fireworks_api_base,
get_fireworks_api_key,
)

logger = logging.getLogger(__name__)


def get_fireworks_extra_headers() -> Optional[Mapping[str, str]]:
    """
    Read optional request headers from the FIREWORKS_EXTRA_HEADERS environment variable.

    The variable is expected to hold a JSON object whose keys and values are all
    strings, for example: '{"X-Custom-Header": "value"}'.

    Any problem with the value (unset, blank, invalid JSON, not a JSON object, or
    a non-string key/value) is logged as a warning where applicable and treated
    as "no extra headers"; this function does not raise for bad input.

    Returns:
        The parsed header dict when the variable is set and valid, otherwise None.
    """
    raw_value = os.environ.get("FIREWORKS_EXTRA_HEADERS")
    # Unset, empty, and whitespace-only values all mean "nothing configured".
    if raw_value is None or not raw_value.strip():
        return None

    try:
        parsed = json.loads(raw_value)
    except json.JSONDecodeError as exc:
        logger.warning("Failed to parse FIREWORKS_EXTRA_HEADERS as JSON: %s. Ignoring.", exc)
        return None

    if not isinstance(parsed, dict):
        logger.warning(
            "FIREWORKS_EXTRA_HEADERS must be a JSON object, got %s. Ignoring.",
            type(parsed).__name__,
        )
        return None

    # Locate the first entry (in dict order) whose key or value is not a string.
    # A single bad entry rejects the whole mapping rather than being dropped.
    offending = next(
        (
            (name, value)
            for name, value in parsed.items()
            if not (isinstance(name, str) and isinstance(value, str))
        ),
        None,
    )
    if offending is not None:
        bad_name, bad_value = offending
        logger.warning(
            "FIREWORKS_EXTRA_HEADERS contains non-string key or value: %s=%s. Ignoring all extra headers.",
            bad_name,
            bad_value,
        )
        return None

    # Only header names are logged here; values are not included in the message.
    logger.debug("Using FIREWORKS_EXTRA_HEADERS: %s", list(parsed))
    return parsed


def create_fireworks_client(
    *,
    api_key: Optional[str] = None,
    account_id: Optional[str] = None,
    base_url: Optional[str] = None,
    extra_headers: Optional[Mapping[str, str]] = None,
) -> Fireworks:
    """
    Build a Fireworks SDK client from explicit arguments plus environment fallbacks.

    Every argument is keyword-only. A truthy explicit argument is used as given;
    otherwise the value comes from the matching helper in eval_protocol.auth
    (get_fireworks_api_key, get_fireworks_account_id, get_fireworks_api_base).

    Headers are combined from two sources: those parsed from the
    FIREWORKS_EXTRA_HEADERS environment variable are applied first, then
    ``extra_headers`` is laid on top, so an explicit header wins on a name clash.
    When the combined mapping is empty, None is passed to the SDK instead.

    Args:
        api_key: Fireworks API key; falls back to get_fireworks_api_key().
        account_id: Fireworks account ID; falls back to get_fireworks_account_id().
        base_url: API base URL; falls back to get_fireworks_api_base().
        extra_headers: Headers to send with requests, merged over any headers
            from FIREWORKS_EXTRA_HEADERS.

    Returns:
        A Fireworks client constructed with the resolved settings.

    Raises:
        Whatever the Fireworks constructor raises. NOTE(review): the SDK is
        presumably the component that rejects a missing API key
        (fireworks.FireworksError); this function performs no check itself.
    """
    # Explicit value first, environment-backed helper second. The helpers are
    # only invoked when the explicit value is falsy.
    client_api_key = api_key if api_key else get_fireworks_api_key()
    client_account_id = account_id if account_id else get_fireworks_account_id()
    client_base_url = base_url if base_url else get_fireworks_api_base()

    # Environment headers form the base layer; explicit headers override them.
    combined_headers = dict(get_fireworks_extra_headers() or {})
    if extra_headers:
        combined_headers.update(extra_headers)
    default_headers: Optional[Mapping[str, str]] = combined_headers or None

    # Header names only: header values are kept out of the log message.
    logger.debug(
        "Creating Fireworks client: base_url=%s, account_id=%s, extra_headers=%s",
        client_base_url,
        client_account_id,
        list(default_headers) if default_headers else None,
    )

    return Fireworks(
        api_key=client_api_key,
        account_id=client_account_id,
        base_url=client_base_url,
        default_headers=default_headers,
    )
Loading
Loading