Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Elasticsearch Tests
name: Fireworks Tracing Tests

on:
push:
Expand All @@ -13,8 +13,8 @@ on:
workflow_dispatch: # Allow manual triggering

jobs:
elasticsearch-tests:
name: Elasticsearch Integration Tests
fireworks-tracing-tests:
name: Fireworks Tracing Integration Tests
runs-on: ubuntu-latest

steps:
Expand All @@ -36,14 +36,15 @@ jobs:
- name: Install the project
run: uv sync --locked --all-extras --dev

- name: Run Elasticsearch Tests
- name: Run Fireworks Tracing Tests
env:
FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
PYTHONWARNINGS: "ignore::DeprecationWarning,ignore::RuntimeWarning"
run: |
# Run Elasticsearch direct HTTP handler tests
uv run pytest tests/logging/test_elasticsearch_direct_http_handler.py -v --tb=short
# Run RemoteRolloutProcessor End-to-End Test (auto server startup)
uv run pytest tests/remote_server/test_remote_fireworks.py::test_remote_rollout_and_fetch_fireworks \
-v --tb=short

# Run RemoteRolloutProcessor Propagate Status Smoke Test (also uses Elasticsearch)
# Run RemoteRolloutProcessor Propagate Status Test (auto server startup)
uv run pytest tests/remote_server/test_remote_fireworks_propagate_status.py::test_remote_rollout_and_fetch_fireworks_propagate_status \
-v --tb=short
2 changes: 0 additions & 2 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
filter_longest_conversation,
)
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor, GithubActionRolloutProcessor
from .pytest.remote_rollout_processor import create_elasticsearch_config_from_env
from .pytest.parameterize import DefaultParameterIdGenerator
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
from .log_utils.rollout_id_filter import RolloutIdFilter
Expand Down Expand Up @@ -90,7 +89,6 @@
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")

__all__ = [
"create_elasticsearch_config_from_env",
"ElasticsearchConfig",
"ElasticsearchDirectHttpHandler",
"RolloutIdFilter",
Expand Down
1 change: 1 addition & 0 deletions eval_protocol/adapters/fireworks_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def search_logs(self, tags: List[str], limit: int = 100, hours_back: int = 24) -
"message": e.get("message"),
"severity": e.get("severity", "INFO"),
"tags": e.get("tags", []),
"status": e.get("status"),
}
)
return results
Expand Down
43 changes: 0 additions & 43 deletions eval_protocol/cli_commands/logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,49 +39,6 @@ def logs_command(args):
or os.environ.get("GATEWAY_URL")
or "https://tracing.fireworks.ai"
)
try:
if not use_fireworks:
if getattr(args, "use_env_elasticsearch_config", False):
# Use environment variables for configuration
print("⚙️ Using environment variables for Elasticsearch config")
from eval_protocol.pytest.remote_rollout_processor import (
create_elasticsearch_config_from_env,
)

elasticsearch_config = create_elasticsearch_config_from_env()
# Ensure index exists with correct mapping, mirroring Docker setup path
try:
from eval_protocol.log_utils.elasticsearch_index_manager import (
ElasticsearchIndexManager,
)

index_manager = ElasticsearchIndexManager(
elasticsearch_config.url,
elasticsearch_config.index_name,
elasticsearch_config.api_key,
)
created = index_manager.create_logging_index_mapping()
if created:
print(
f"🧭 Verified Elasticsearch index '{elasticsearch_config.index_name}' mapping (created or already correct)"
)
else:
print(
f"⚠️ Could not verify/create mapping for index '{elasticsearch_config.index_name}'. Searches may behave unexpectedly."
)
except Exception as e:
print(f"⚠️ Failed to ensure index mapping via IndexManager: {e}")
elif not getattr(args, "disable_elasticsearch_setup", False):
# Default behavior: start or connect to local Elasticsearch via Docker helper
from eval_protocol.pytest.elasticsearch_setup import ElasticsearchSetup

print("🧰 Auto-configuring local Elasticsearch (Docker)")
elasticsearch_config = ElasticsearchSetup().setup_elasticsearch()
else:
print("🚫 Elasticsearch setup disabled; running without Elasticsearch integration")
except Exception as e:
print(f"❌ Failed to configure Elasticsearch: {e}")
return 1

try:
serve_logs(
Expand Down
48 changes: 30 additions & 18 deletions eval_protocol/log_utils/fireworks_tracing_http_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,34 @@ def _get_rollout_id(self, record: logging.LogRecord) -> Optional[str]:
return str(cast(Any, getattr(record, "rollout_id")))
return os.getenv(self.rollout_id_env)

def _get_status_info(self, record: logging.LogRecord) -> Optional[Dict[str, Any]]:
"""Extract status information from the log record's extra data."""
# Check if 'status' is in the extra data (passed via extra parameter)
if hasattr(record, "status") and record.status is not None: # type: ignore
status = record.status # type: ignore

# Handle Status class instances (Pydantic BaseModel)
if hasattr(status, "code") and hasattr(status, "message"):
# Status object - extract code and message
status_code = status.code
# Handle both enum values and direct integer values
if hasattr(status_code, "value"):
status_code = status_code.value

return {
"code": status_code,
"message": status.message,
"details": getattr(status, "details", []),
}
elif isinstance(status, dict):
# Dictionary representation of status
return {
"code": status.get("code"),
"message": status.get("message"),
"details": status.get("details", []),
}
return None

def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str, Any]:
timestamp = datetime.fromtimestamp(record.created, tz=timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
message = record.getMessage()
Expand All @@ -96,28 +124,12 @@ def _build_payload(self, record: logging.LogRecord, rollout_id: str) -> Dict[str
except Exception:
pass
program = cast(Optional[str], getattr(record, "program", None)) or "eval_protocol"
status_val = cast(Any, getattr(record, "status", None))
status = status_val if isinstance(status_val, str) else None
# Capture optional structured status fields if present
metadata: Dict[str, Any] = {}
status_code = cast(Any, getattr(record, "status_code", None))
if isinstance(status_code, int):
metadata["status_code"] = status_code
status_message = cast(Any, getattr(record, "status_message", None))
if isinstance(status_message, str):
metadata["status_message"] = status_message
status_details = getattr(record, "status_details", None)
if status_details is not None:
metadata["status_details"] = status_details
extra_metadata = cast(Any, getattr(record, "metadata", None))
if isinstance(extra_metadata, dict):
metadata.update(extra_metadata)

return {
"program": program,
"status": status,
"status": self._get_status_info(record),
"message": message,
"tags": tags,
"metadata": metadata or None,
"extras": {
"logger_name": record.name,
"level": record.levelname,
Expand Down
102 changes: 37 additions & 65 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@

import requests

from eval_protocol.log_utils.elasticsearch_client import ElasticsearchClient
from eval_protocol.models import EvaluationRow, Status
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.types.remote_rollout_processor import (
DataLoaderConfig,
ElasticsearchConfig,
)
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter

from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
from .elasticsearch_setup import ElasticsearchSetup
from .tracing_utils import default_fireworks_output_data_loader, build_init_request, update_row_with_remote_trace
import logging

Expand All @@ -22,25 +21,6 @@
logger = logging.getLogger(__name__)


def create_elasticsearch_config_from_env() -> ElasticsearchConfig:
    """Build an ``ElasticsearchConfig`` from ``ELASTICSEARCH_*`` environment variables.

    Reads ``ELASTICSEARCH_URL``, ``ELASTICSEARCH_API_KEY`` and
    ``ELASTICSEARCH_INDEX_NAME``.

    Returns:
        An ``ElasticsearchConfig`` populated with the three values.

    Raises:
        ValueError: If any of the variables is unset. They are checked in
            the order listed above, so the first missing one is reported.
    """
    # (config field, environment variable, error raised when it is unset)
    required = (
        ("url", "ELASTICSEARCH_URL", "ELASTICSEARCH_URL must be set"),
        ("api_key", "ELASTICSEARCH_API_KEY", "ELASTICSEARCH_API_KEY must be set"),
        ("index_name", "ELASTICSEARCH_INDEX_NAME", "ELASTICSEARCH_INDEX_NAME must be set"),
    )

    settings = {}
    for field, env_name, error in required:
        value = os.getenv(env_name)
        # Only an unset variable is rejected; an empty string is accepted.
        if value is None:
            raise ValueError(error)
        settings[field] = value

    return ElasticsearchConfig(**settings)


class RemoteRolloutProcessor(RolloutProcessor):
"""
Rollout processor that triggers a remote HTTP server to perform the rollout.
Expand All @@ -59,8 +39,6 @@ def __init__(
poll_interval: float = 1.0,
timeout_seconds: float = 120.0,
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
disable_elastic_search_setup: bool = False,
elastic_search_config: Optional[ElasticsearchConfig] = None,
):
# Prefer constructor-provided configuration. These can be overridden via
# config.kwargs at call time for backward compatibility.
Expand All @@ -74,21 +52,7 @@ def __init__(
self._poll_interval = poll_interval
self._timeout_seconds = timeout_seconds
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
self._disable_elastic_search_setup = disable_elastic_search_setup
self._elastic_search_config = elastic_search_config

def setup(self) -> None:
    """Provision Elasticsearch for this processor unless setup is disabled.

    When ``_disable_elastic_search_setup`` is false, the result of
    ``_setup_elastic_search()`` is stored on ``_elastic_search_config``;
    otherwise only an informational message is logged.
    """
    if not self._disable_elastic_search_setup:
        logger.info("Setting up Elasticsearch")
        self._elastic_search_config = self._setup_elastic_search()
        logger.info("Elasticsearch setup complete")
        return

    logger.info("Elasticsearch is disabled, skipping setup")

def _setup_elastic_search(self) -> ElasticsearchConfig:
    """Return the config produced by a fresh ``ElasticsearchSetup`` helper.

    All of the work is delegated to ``ElasticsearchSetup.setup_elasticsearch()``.
    """
    return ElasticsearchSetup().setup_elasticsearch()
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
tasks: List[asyncio.Task[EvaluationRow]] = []
Expand Down Expand Up @@ -123,7 +87,7 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow:
if row.input_metadata.row_id is None:
raise ValueError("Row ID is required in RemoteRolloutProcessor")

init_payload = build_init_request(row, config, model_base_url, self._elastic_search_config)
init_payload = build_init_request(row, config, model_base_url)

# Fire-and-poll
def _post_init() -> None:
Expand Down Expand Up @@ -153,10 +117,6 @@ def _get_status() -> Dict[str, Any]:
r.raise_for_status()
return r.json()

elasticsearch_client = (
ElasticsearchClient(self._elastic_search_config) if self._elastic_search_config else None
)

continue_polling_status = True
while time.time() < deadline:
try:
Expand All @@ -178,29 +138,41 @@ def _get_status() -> Dict[str, Any]:
# For all other exceptions, raise them
raise

if not elasticsearch_client:
continue

search_results = elasticsearch_client.search_by_status_code_not_in(
row.execution_metadata.rollout_id, [Status.Code.RUNNING]
# Search Fireworks tracing logs for completion
completed_logs = self._tracing_adapter.search_logs(
tags=[f"rollout_id:{row.execution_metadata.rollout_id}"]
)
hits = search_results["hits"]["hits"] if search_results else []
# Filter for logs that actually have status information
status_logs = []
for log in completed_logs:
status_dict = log.get("status")
if status_dict and isinstance(status_dict, dict) and "code" in status_dict:
status_logs.append(log)

if status_logs:
# Use the first log with status information
status_log = status_logs[0]
status_dict = status_log.get("status")

logger.info(
f"Found status log for rollout {row.execution_metadata.rollout_id}: {status_log.get('message', '')}"
)

if hits:
# log all statuses found and update rollout status from the last hit
for hit in hits:
document = hit["_source"]
logger.info(
f"Found log for rollout {row.execution_metadata.rollout_id} with status code {document['status_code']}"
)
# Update rollout status from the document
if "status_code" in document:
row.rollout_status = Status(
code=Status.Code(document["status_code"]),
message=document.get("status_message", ""),
details=document.get("status_details", []),
)
logger.info("Stopping status polling for rollout %s", row.execution_metadata.rollout_id)
status_code = status_dict.get("code")
status_message = status_dict.get("message", "")
status_details = status_dict.get("details", [])

logger.info(
f"Found Fireworks log for rollout {row.execution_metadata.rollout_id} with status code {status_code}"
)

row.rollout_status = Status(
code=Status.Code(status_code),
message=status_message,
details=status_details,
)

logger.info("Stopping polling for rollout %s", row.execution_metadata.rollout_id)
break

await asyncio.sleep(poll_interval)
Expand Down
2 changes: 0 additions & 2 deletions eval_protocol/pytest/tracing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def build_init_request(
row: EvaluationRow,
config: RolloutProcessorConfig,
model_base_url: str,
elastic_search_config: Optional[Any] = None,
) -> InitRequest:
"""Build an InitRequest from an EvaluationRow and config (shared logic)."""
# Validation
Expand Down Expand Up @@ -129,7 +128,6 @@ def build_init_request(
tools=row.tools,
metadata=meta,
model_base_url=final_model_base_url,
elastic_search_config=elastic_search_config,
)


Expand Down
Loading
Loading