Skip to content

Commit 3b84fe4

Browse files
author
Dylan Huang
authored
ElasticSearch logging in RemoteRolloutProcessor (part 4) (#249)
* logs can be used to update status * remote rollout processor works with logging! TODO: propagate non-running statuses in evaluation_test * vite build * vite build * less frame shift * Enhance ElasticsearchClient to support sorting in search_by_match method and update LogsServer to utilize this feature for retrieving logs sorted by timestamp. * vite build * fix ordering of logs * fix imports
1 parent 77bbb11 commit 3b84fe4

20 files changed

+387
-307
lines changed

eval_protocol/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
from .quickstart import aha_judge, multi_turn_assistant_to_ground_truth, assistant_to_ground_truth
3333
from .pytest import evaluation_test, SingleTurnRolloutProcessor, RemoteRolloutProcessor
3434
from .pytest.parameterize import DefaultParameterIdGenerator
35+
from .log_utils.elasticsearch_direct_http_handler import ElasticsearchDirectHttpHandler
36+
from .log_utils.rollout_id_filter import RolloutIdFilter
3537

3638
from .types.remote_rollout_processor import (
3739
InitRequest,
@@ -63,6 +65,8 @@
6365
warnings.filterwarnings("default", category=DeprecationWarning, module="eval_protocol")
6466

6567
__all__ = [
68+
"ElasticsearchDirectHttpHandler",
69+
"RolloutIdFilter",
6670
"Status",
6771
"RemoteRolloutProcessor",
6872
"InputMetadata",

eval_protocol/log_utils/__init__.py

Whitespace-only changes.

eval_protocol/logging/elasticsearch_client.py renamed to eval_protocol/log_utils/elasticsearch_client.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
and search functionality.
77
"""
88

9-
import json
109
import requests
1110
from typing import Any, Dict, List, Optional, Union
12-
from urllib.parse import urlparse
11+
from eval_protocol.models import Status
1312
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig
1413

1514

@@ -203,33 +202,39 @@ def search(
203202
except Exception:
204203
return None
205204

206-
def search_by_term(self, field: str, value: Any, size: int = 10) -> Optional[Dict[str, Any]]:
205+
def search_by_term(
206+
self, field: str, value: Any, size: int = 10, sort: Optional[List[Dict[str, Any]]] = None
207+
) -> Optional[Dict[str, Any]]:
207208
"""Search documents by exact term match.
208209
209210
Args:
210211
field: Field name to search
211212
value: Value to match
212213
size: Number of results to return
214+
sort: Sort specification
213215
214216
Returns:
215217
Dict containing search results, or None if failed
216218
"""
217219
query = {"term": {field: value}}
218-
return self.search(query, size=size)
220+
return self.search(query, size=size, sort=sort)
219221

220-
def search_by_match(self, field: str, value: str, size: int = 10) -> Optional[Dict[str, Any]]:
222+
def search_by_match(
223+
self, field: str, value: str, size: int = 10, sort: Optional[List[Dict[str, Any]]] = None
224+
) -> Optional[Dict[str, Any]]:
221225
"""Search documents by text match.
222226
223227
Args:
224228
field: Field name to search
225229
value: Text to match
226230
size: Number of results to return
231+
sort: Sort specification (e.g., [{"@timestamp": {"order": "desc"}}])
227232
228233
Returns:
229234
Dict containing search results, or None if failed
230235
"""
231236
query = {"match": {field: value}}
232-
return self.search(query, size=size)
237+
return self.search(query, size=size, sort=sort)
233238

234239
def search_by_match_phrase_prefix(self, field: str, value: str, size: int = 10) -> Optional[Dict[str, Any]]:
235240
"""Search documents by phrase prefix match.
@@ -257,6 +262,34 @@ def search_all(self, size: int = 10) -> Optional[Dict[str, Any]]:
257262
query = {"match_all": {}}
258263
return self.search(query, size=size)
259264

265+
def search_by_status_code_not_in(
266+
self,
267+
rollout_id: str,
268+
excluded_codes: list[Status.Code],
269+
size: int = 10,
270+
) -> Optional[Dict[str, Any]]:
271+
"""
272+
Search documents where status_code does NOT match any of the provided status codes.
273+
274+
Args:
275+
excluded_codes: List of status codes to exclude (i.e., find logs NOT having these codes)
276+
size: Number of results to return
277+
rollout_id: Rollout ID to filter by (required)
278+
279+
Returns:
280+
Dict containing search results, or None if failed
281+
"""
282+
# Build the query with must_not for status code exclusion
283+
bool_query: dict[str, list[dict[str, Any]]] = {
284+
"must_not": [{"terms": {"status_code": [code.value for code in excluded_codes]}}]
285+
}
286+
287+
# Add rollout_id filter and ensure status_code exists
288+
bool_query["must"] = [{"term": {"rollout_id": rollout_id}}, {"exists": {"field": "status_code"}}]
289+
290+
query = {"bool": bool_query}
291+
return self.search(query, size=size)
292+
260293
# Health and Status Operations
261294

262295
def health_check(self) -> bool:

eval_protocol/logging/elasticsearch_direct_http_handler.py renamed to eval_protocol/log_utils/elasticsearch_direct_http_handler.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,51 @@
1-
import json
21
import logging
3-
import asyncio
42
import os
5-
import threading
63
from concurrent.futures import ThreadPoolExecutor
7-
from typing import Optional, Tuple, Any, Dict
4+
from typing import Optional, Any, Dict
85
from datetime import datetime
96

107
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig
118
from .elasticsearch_client import ElasticsearchClient
129

10+
import logging
11+
12+
logger = logging.getLogger(__name__)
13+
14+
# do not inherit root logger since we are a handler ourselves
15+
logger.propagate = False
16+
17+
logger.addHandler(logging.StreamHandler())
18+
19+
if os.environ.get("EP_DEBUG") == "true":
20+
logger.setLevel(logging.DEBUG)
21+
logger.debug("EP_DEBUG=true detected, set log level to DEBUG")
22+
1323

1424
class ElasticsearchDirectHttpHandler(logging.Handler):
15-
def __init__(self, elasticsearch_config: ElasticsearchConfig) -> None:
25+
def __init__(self, elasticsearch_config: ElasticsearchConfig | None = None) -> None:
1626
super().__init__()
17-
self.config = ElasticsearchConfig(
18-
url=elasticsearch_config.url,
19-
api_key=elasticsearch_config.api_key,
20-
index_name=elasticsearch_config.index_name,
21-
)
22-
self.client = ElasticsearchClient(self.config)
27+
self.config = elasticsearch_config
28+
self.client = ElasticsearchClient(self.config) if self.config else None
2329
self.formatter: logging.Formatter = logging.Formatter()
2430
self._executor = None
2531

32+
def configure(self, elasticsearch_config: ElasticsearchConfig) -> None:
33+
self.config = elasticsearch_config
34+
self.client = ElasticsearchClient(self.config)
35+
2636
def emit(self, record: logging.LogRecord) -> None:
2737
"""Emit a log record by scheduling it for async transmission."""
2838
try:
2939
# Create proper ISO 8601 timestamp
3040
timestamp = datetime.fromtimestamp(record.created).strftime("%Y-%m-%dT%H:%M:%S.%fZ")
3141

3242
rollout_id = self._get_rollout_id(record)
43+
logger.debug(f"Emitting log record: {record.getMessage()} with rollout_id: {rollout_id}")
44+
if not rollout_id:
45+
logger.debug(
46+
"No rollout_id provided in extra data for ElasticsearchDirectHttpHandler through EP_ROLLOUT_ID environment variable or rollout_id extra data. Skipping log record."
47+
)
48+
return
3349
status_info = self._get_status_info(record)
3450

3551
data: Dict[str, Any] = {
@@ -50,18 +66,14 @@ def emit(self, record: logging.LogRecord) -> None:
5066
self.handleError(record)
5167
print(f"Error preparing log for Elasticsearch: {e}")
5268

53-
def _get_rollout_id(self, record: logging.LogRecord) -> str:
69+
def _get_rollout_id(self, record: logging.LogRecord) -> str | None:
5470
"""Get the rollout ID from record extra data or environment variables."""
5571
# Check if rollout_id is provided in the extra data first
5672
if hasattr(record, "rollout_id") and record.rollout_id is not None: # type: ignore
5773
return str(record.rollout_id) # type: ignore
5874

5975
# Fall back to environment variable
6076
rollout_id = os.getenv("EP_ROLLOUT_ID")
61-
if rollout_id is None:
62-
raise ValueError(
63-
"EP_ROLLOUT_ID environment variable is not set and no rollout_id provided in extra data for ElasticsearchDirectHttpHandler"
64-
)
6577
return rollout_id
6678

6779
def _get_status_info(self, record: logging.LogRecord) -> Optional[Dict[str, Any]]:
@@ -105,6 +117,9 @@ def _schedule_async_send(self, data: Dict[str, Any], record: logging.LogRecord)
105117

106118
def _send_to_elasticsearch(self, data: Dict[str, Any], record: logging.LogRecord) -> None:
107119
"""Send data to Elasticsearch (runs in thread pool)."""
120+
if not self.client:
121+
logger.warning("No Elasticsearch client configured, skipping log record")
122+
return
108123
try:
109124
success = self.client.index_document(data)
110125
if not success:
File renamed without changes.
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import logging
2+
import os
3+
4+
logger = logging.getLogger(__name__)
5+
6+
# do not inherit root logger since we are a handler ourselves
7+
logger.propagate = False
8+
9+
logger.addHandler(logging.StreamHandler())
10+
11+
if os.environ.get("EP_DEBUG") == "true":
12+
logger.setLevel(logging.DEBUG)
13+
logger.debug("EP_DEBUG=true detected, set log level to DEBUG")
14+
15+
16+
class RolloutIdFilter(logging.Filter):
17+
"""
18+
A filter that simply adds the rollout_id to the record so that you don't
19+
have to pass it as extra data every time you log.
20+
"""
21+
22+
def __init__(self, rollout_id: str):
23+
self.rollout_id = rollout_id
24+
25+
def filter(self, record):
26+
logger.debug(f"Filtering record with rollout_id: {self.rollout_id}")
27+
record.rollout_id = self.rollout_id
28+
return True

eval_protocol/pytest/elasticsearch_setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dotenv import load_dotenv
88
from eval_protocol.directory_utils import find_eval_protocol_dir
99
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig
10-
from eval_protocol.logging.elasticsearch_index_manager import ElasticsearchIndexManager
10+
from eval_protocol.log_utils.elasticsearch_index_manager import ElasticsearchIndexManager
1111

1212
logger = logging.getLogger(__name__)
1313

eval_protocol/pytest/remote_rollout_processor.py

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import requests
66

7-
from eval_protocol.logging.elasticsearch_client import ElasticsearchClient
7+
from eval_protocol.log_utils.elasticsearch_client import ElasticsearchClient
88
from eval_protocol.models import EvaluationRow, Status
99
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
1010
from eval_protocol.types.remote_rollout_processor import ElasticsearchConfig, InitRequest, RolloutMetadata
@@ -188,27 +188,50 @@ def _get_status() -> Dict[str, Any]:
188188
ElasticsearchClient(self._elastic_search_config) if self._elastic_search_config else None
189189
)
190190

191+
continue_polling_status = True
191192
while time.time() < deadline:
192193
try:
193-
status = await asyncio.to_thread(_get_status)
194-
terminated = bool(status.get("terminated", False))
195-
if terminated:
196-
break
194+
if continue_polling_status:
195+
status = await asyncio.to_thread(_get_status)
196+
terminated = bool(status.get("terminated", False))
197+
if terminated:
198+
break
197199
except requests.exceptions.HTTPError as e:
198200
if e.response is not None and e.response.status_code == 404:
199201
# 404 means server doesn't implement /status endpoint, stop polling
200202
logger.info(
201203
f"Server doesn't implement /status endpoint (404), stopping status polling for rollout {row.execution_metadata.rollout_id}"
202204
)
203-
break
205+
continue_polling_status = False
204206
else:
205207
raise
206208
except Exception:
207209
# For all other exceptions, raise them
208210
raise
209211

212+
if not elasticsearch_client:
213+
continue
214+
215+
search_results = elasticsearch_client.search_by_status_code_not_in(
216+
row.execution_metadata.rollout_id, [Status.Code.RUNNING]
217+
)
218+
hits = search_results["hits"]["hits"] if search_results else []
219+
220+
if hits:
221+
# log all statuses found
222+
for hit in hits:
223+
document = hit["_source"]
224+
logger.info(
225+
f"Found log for rollout {row.execution_metadata.rollout_id} with status code {document['status_code']}"
226+
)
227+
logger.info("Stopping status polling for rollout %s", row.execution_metadata.rollout_id)
228+
break
229+
210230
await asyncio.sleep(poll_interval)
211231
else:
232+
logger.info(
233+
f"Loop completed without breaking for {row.execution_metadata.rollout_id}, which means we timed out"
234+
)
212235
# Loop completed without breaking, which means we timed out
213236
row.rollout_status = Status.rollout_error(
214237
f"Rollout {row.execution_metadata.rollout_id} timed out after {timeout_seconds} seconds"
@@ -234,23 +257,20 @@ def _load_data():
234257
return row
235258
elif len(output_rows) == 1: # Return the Langfuse row
236259
langfuse_row = output_rows[0]
237-
langfuse_row.input_metadata.completion_params = row.input_metadata.completion_params
238-
# merge dataset_info dicts on input_metadata
239-
if langfuse_row.input_metadata.dataset_info and row.input_metadata.dataset_info:
240-
langfuse_row.input_metadata.dataset_info = {
241-
**row.input_metadata.dataset_info,
242-
**langfuse_row.input_metadata.dataset_info,
243-
}
244-
elif row.input_metadata.dataset_info:
245-
langfuse_row.input_metadata.dataset_info = row.input_metadata.dataset_info
246-
langfuse_row.eval_metadata = row.eval_metadata
247-
langfuse_row.ground_truth = row.ground_truth
248260

249-
# this is useful to detect stopped evaluations so we can update
250-
# the status in the logs server
251-
langfuse_row.pid = row.pid
261+
# if the langfuse_row has the same number of messages as the original row,
262+
# something went wrong
263+
if len(langfuse_row.messages) == len(row.messages):
264+
row.rollout_status = Status.rollout_error(
265+
"Rollout finished with the same number of messages as the original row"
266+
)
267+
return row
252268

253-
return langfuse_row
269+
row.messages = langfuse_row.messages
270+
row.tools = langfuse_row.tools
271+
row.input_metadata.session_data = langfuse_row.input_metadata.session_data
272+
row.execution_metadata = langfuse_row.execution_metadata
273+
return row
254274
else:
255275
raise ValueError("RemoteRolloutProcessor's output_data_loader should return exactly one row.")
256276

0 commit comments

Comments
 (0)