Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions eval_protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from .quickstart import aha_judge, multi_turn_assistant_to_ground_truth, assistant_to_ground_truth
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor
from .pytest.parameterize import DefaultParameterIdGenerator
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
from .log_utils.rollout_id_filter import RolloutIdFilter

from .types.remote_rollout_processor import (
InitRequest,
Expand Down Expand Up @@ -63,6 +65,8 @@
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")

__all__ = [
"ElasticsearchDirectHttpHandler",
"RolloutIdFilter",
"Status",
"RemoteRolloutProcessor",
"InputMetadata",
Expand Down
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
and search functionality.
"""

import json
import requests
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
from eval_protocol.models import Status
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig


Expand Down Expand Up @@ -203,33 +202,39 @@ def search(
except Exception:
return None

def search_by_term(
    self, field: str, value: Any, size: int = 10, sort: Optional[List[Dict[str, Any]]] = None
) -> Optional[Dict[str, Any]]:
    """Search documents by exact term match.

    Args:
        field: Field name to search
        value: Value to match (compared against the exact, non-analyzed field value)
        size: Number of results to return
        sort: Sort specification (e.g., [{"@timestamp": {"order": "desc"}}])

    Returns:
        Dict containing search results, or None if failed
    """
    # A "term" query matches the field's exact value, unlike "match"
    # which goes through text analysis.
    query = {"term": {field: value}}
    return self.search(query, size=size, sort=sort)

def search_by_match(
    self, field: str, value: str, size: int = 10, sort: Optional[List[Dict[str, Any]]] = None
) -> Optional[Dict[str, Any]]:
    """Search documents by text match.

    Args:
        field: Field name to search
        value: Text to match (analyzed full-text match, unlike a term query)
        size: Number of results to return
        sort: Sort specification (e.g., [{"@timestamp": {"order": "desc"}}])

    Returns:
        Dict containing search results, or None if failed
    """
    # A "match" query runs the value through the field's analyzer.
    query = {"match": {field: value}}
    return self.search(query, size=size, sort=sort)

def search_by_match_phrase_prefix(self, field: str, value: str, size: int = 10) -> Optional[Dict[str, Any]]:
"""Search documents by phrase prefix match.
Expand Down Expand Up @@ -257,6 +262,34 @@ def search_all(self, size: int = 10) -> Optional[Dict[str, Any]]:
query = {"match_all": {}}
return self.search(query, size=size)

def search_by_status_code_not_in(
    self,
    rollout_id: str,
    excluded_codes: List[Status.Code],
    size: int = 10,
    sort: Optional[List[Dict[str, Any]]] = None,
) -> Optional[Dict[str, Any]]:
    """
    Search documents for a rollout whose status_code does NOT match any of
    the provided status codes.

    Only documents that actually carry a ``status_code`` field are
    considered, so records logged without a status never match.

    Args:
        rollout_id: Rollout ID to filter by (required)
        excluded_codes: List of status codes to exclude (i.e., find logs NOT having these codes)
        size: Number of results to return
        sort: Sort specification (e.g., [{"@timestamp": {"order": "desc"}}])

    Returns:
        Dict containing search results, or None if failed
    """
    bool_query: Dict[str, List[Dict[str, Any]]] = {
        # must_not drops every hit whose status_code is in the excluded set
        "must_not": [{"terms": {"status_code": [code.value for code in excluded_codes]}}],
        # restrict to this rollout and require status_code to exist at all
        "must": [{"term": {"rollout_id": rollout_id}}, {"exists": {"field": "status_code"}}],
    }
    return self.search({"bool": bool_query}, size=size, sort=sort)

# Health and Status Operations

def health_check(self) -> bool:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,35 +1,51 @@
import json
import logging
import asyncio
import os
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Tuple, Any, Dict
from typing import Optional, Any, Dict
from datetime import datetime

from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig
from .elasticsearch_client import ElasticsearchClient

import logging

logger = logging.getLogger(__name__)

# do not inherit root logger since we are a handler ourselves
logger.propagate = False

logger.addHandler(logging.StreamHandler())

if os.environ.get("EP_DEBUG") == "true":
logger.setLevel(logging.DEBUG)
logger.debug("EP_DEBUG=true detected, set log level to DEBUG")


class ElasticsearchDirectHttpHandler(logging.Handler):
def __init__(self, elasticsearch_config: ElasticsearchConfig | None = None) -> None:
    """Create the handler, optionally unconfigured.

    Args:
        elasticsearch_config: Connection settings for Elasticsearch. When
            None, no client is built and the handler stays dormant until
            configure() is called; _send_to_elasticsearch skips records
            while the client is missing.
    """
    super().__init__()
    self.config = elasticsearch_config
    # Only build a client when a config is available.
    self.client = ElasticsearchClient(self.config) if self.config else None
    # Default formatter; callers may replace it via setFormatter().
    self.formatter: logging.Formatter = logging.Formatter()
    # Background executor for async sends; presumably created lazily by
    # _schedule_async_send — TODO confirm.
    self._executor = None

def configure(self, elasticsearch_config: ElasticsearchConfig) -> None:
    """Attach *elasticsearch_config* and (re)build the Elasticsearch client from it."""
    self.config = elasticsearch_config
    self.client = ElasticsearchClient(elasticsearch_config)

def emit(self, record: logging.LogRecord) -> None:
"""Emit a log record by scheduling it for async transmission."""
try:
# Create proper ISO 8601 timestamp
timestamp = datetime.fromtimestamp(record.created).strftime("%Y-%m-%dT%H:%M:%S.%fZ")

rollout_id = self._get_rollout_id(record)
logger.debug(f"Emitting log record: {record.getMessage()} with rollout_id: {rollout_id}")
if not rollout_id:
logger.debug(
"No rollout_id provided in extra data for ElasticsearchDirectHttpHandler through EP_ROLLOUT_ID environment variable or rollout_id extra data. Skipping log record."
)
return
status_info = self._get_status_info(record)

data: Dict[str, Any] = {
Expand All @@ -50,18 +66,14 @@ def emit(self, record: logging.LogRecord) -> None:
self.handleError(record)
print(f"Error preparing log for Elasticsearch: {e}")

def _get_rollout_id(self, record: logging.LogRecord) -> str:
def _get_rollout_id(self, record: logging.LogRecord) -> str | None:
"""Get the rollout ID from record extra data or environment variables."""
# Check if rollout_id is provided in the extra data first
if hasattr(record, "rollout_id") and record.rollout_id is not None: # type: ignore
return str(record.rollout_id) # type: ignore

# Fall back to environment variable
rollout_id = os.getenv("EP_ROLLOUT_ID")
if rollout_id is None:
raise ValueError(
"EP_ROLLOUT_ID environment variable is not set and no rollout_id provided in extra data for ElasticsearchDirectHttpHandler"
)
return rollout_id

def _get_status_info(self, record: logging.LogRecord) -> Optional[Dict[str, Any]]:
Expand Down Expand Up @@ -105,6 +117,9 @@ def _schedule_async_send(self, data: Dict[str, Any], record: logging.LogRecord)

def _send_to_elasticsearch(self, data: Dict[str, Any], record: logging.LogRecord) -> None:
"""Send data to Elasticsearch (runs in thread pool)."""
if not self.client:
logger.warning("No Elasticsearch client configured, skipping log record")
return
try:
success = self.client.index_document(data)
if not success:
Expand Down
28 changes: 28 additions & 0 deletions eval_protocol/log_utils/rollout_id_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import logging
import os

logger = logging.getLogger(__name__)

# do not propagate to the root logger, so this filter's internal debug logs stay out of user-configured handlers
logger.propagate = False

logger.addHandler(logging.StreamHandler())

if os.environ.get("EP_DEBUG") == "true":
logger.setLevel(logging.DEBUG)
logger.debug("EP_DEBUG=true detected, set log level to DEBUG")


class RolloutIdFilter(logging.Filter):
    """
    A filter that simply adds the rollout_id to the record so that you don't
    have to pass it as extra data every time you log.
    """

    def __init__(self, rollout_id: str) -> None:
        # Initialize logging.Filter so base-class state (e.g. the name
        # prefix used by Filter.filter) is set up correctly.
        super().__init__()
        # The id stamped onto every record passing through this filter.
        self.rollout_id = rollout_id

    def filter(self, record: logging.LogRecord) -> bool:
        # Lazy %-style args avoid string formatting when DEBUG is disabled
        # (this runs once per log record).
        logger.debug("Filtering record with rollout_id: %s", self.rollout_id)
        record.rollout_id = self.rollout_id
        # Never drop records; this filter only annotates them.
        return True
2 changes: 1 addition & 1 deletion eval_protocol/pytest/elasticsearch_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from dotenv import load_dotenv
from eval_protocol.directory_utils import find_eval_protocol_dir
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig
from eval_protocol.logging.elasticsearch_index_manager import ElasticsearchIndexManager
from eval_protocol.log_utils.elasticsearch_index_manager import ElasticsearchIndexManager

logger = logging.getLogger(__name__)

Expand Down
62 changes: 41 additions & 21 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import requests

from eval_protocol.logging.elasticsearch_client import ElasticsearchClient
from eval_protocol.log_utils.elasticsearch_client import ElasticsearchClient
from eval_protocol.models import EvaluationRow, Status
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig, InitRequest, RolloutMetadata
Expand Down Expand Up @@ -188,27 +188,50 @@ def _get_status() -> Dict[str, Any]:
ElasticsearchClient(self._elastic_search_config) if self._elastic_search_config else None
)

continue_polling_status = True
while time.time() < deadline:
try:
status = await asyncio.to_thread(_get_status)
terminated = bool(status.get("terminated", False))
if terminated:
break
if continue_polling_status:
status = await asyncio.to_thread(_get_status)
terminated = bool(status.get("terminated", False))
if terminated:
break
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code == 404:
# 404 means server doesn't implement /status endpoint, stop polling
logger.info(
f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}"
)
break
continue_polling_status = False
else:
raise
except Exception:
# For all other exceptions, raise them
raise

if not elasticsearch_client:
continue

search_results = elasticsearch_client.search_by_status_code_not_in(
row.execution_metadata.rollout_id, [Status.Code.RUNNING]
)
hits = search_results["hits"]["hits"] if search_results else []

if hits:
# log all statuses found
for hit in hits:
document = hit["_source"]
logger.info(
f"Found log for rollout {row.execution_metadata.rollout_id} with status code {document['status_code']}"
)
logger.info("Stopping status polling for rollout %s", row.execution_metadata.rollout_id)
break

await asyncio.sleep(poll_interval)
else:
logger.info(
f"Loop completed without breaking for {row.execution_metadata.rollout_id}, which means we timed out"
)
# Loop completed without breaking, which means we timed out
row.rollout_status = Status.rollout_error(
f"Rollout {row.execution_metadata.rollout_id} timed out after {timeout_seconds} seconds"
Expand All @@ -234,23 +257,20 @@ def _load_data():
return row
elif len(output_rows) == 1: # Return the Langfuse row
langfuse_row = output_rows[0]
langfuse_row.input_metadata.completion_params = row.input_metadata.completion_params
# merge dataset_info dicts on input_metadata
if langfuse_row.input_metadata.dataset_info and row.input_metadata.dataset_info:
langfuse_row.input_metadata.dataset_info = {
**row.input_metadata.dataset_info,
**langfuse_row.input_metadata.dataset_info,
}
elif row.input_metadata.dataset_info:
langfuse_row.input_metadata.dataset_info = row.input_metadata.dataset_info
langfuse_row.eval_metadata = row.eval_metadata
langfuse_row.ground_truth = row.ground_truth

# this is useful to detect stopped evaluations so we can update
# the status in the logs server
langfuse_row.pid = row.pid
# if the langfuse_row has the same number of messages as the original row,
# something went wrong
if len(langfuse_row.messages) == len(row.messages):
row.rollout_status = Status.rollout_error(
"Rollout finished with the same number of messages as the original row"
)
return row

return langfuse_row
row.messages = langfuse_row.messages
row.tools = langfuse_row.tools
row.input_metadata.session_data = langfuse_row.input_metadata.session_data
row.execution_metadata = langfuse_row.execution_metadata
return row
else:
raise ValueError("RemoteRolloutProcessor's output_data_loader should return exactly one row.")

Expand Down
Loading
Loading