Skip to content

Commit dbab765

Browse files
author
Dylan Huang
committed
try
1 parent ae7a2d3 commit dbab765

File tree

4 files changed

+156
-56
lines changed

4 files changed

+156
-56
lines changed

eval_protocol/dataset_logger/tinydb_evaluation_row_store.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
import json
2+
import logging
13
import os
4+
import time
25
from typing import List, Optional
36

47
from tinydb import Query, TinyDB
58
from tinyrecord.transaction import transaction
69

710
from eval_protocol.dataset_logger.evaluation_row_store import EvaluationRowStore
811

12+
logger = logging.getLogger(__name__)
13+
914

1015
class TinyDBEvaluationRowStore(EvaluationRowStore):
1116
"""
@@ -24,9 +29,30 @@ def __init__(self, db_path: str):
2429
if db_dir:
2530
os.makedirs(db_dir, exist_ok=True)
2631
self._db_path = db_path
27-
self._db = TinyDB(db_path)
32+
self._db = self._open_db_with_retry()
2833
self._table = self._db.table("evaluation_rows")
2934

35+
def _open_db_with_retry(self, max_retries: int = 3) -> TinyDB:
36+
"""Open TinyDB with retry logic to handle transient JSON decode errors."""
37+
last_error: Exception | None = None
38+
for attempt in range(max_retries):
39+
try:
40+
return TinyDB(self._db_path)
41+
except json.JSONDecodeError as e:
42+
last_error = e
43+
logger.warning(f"TinyDB JSON decode error on attempt {attempt + 1}: {e}")
44+
# Wait a bit and retry - the file might be mid-write
45+
time.sleep(0.1 * (attempt + 1))
46+
# Try to recover by removing the corrupted file
47+
if attempt == max_retries - 1 and os.path.exists(self._db_path):
48+
try:
49+
logger.warning(f"Removing corrupted TinyDB file: {self._db_path}")
50+
os.remove(self._db_path)
51+
return TinyDB(self._db_path)
52+
except Exception:
53+
pass
54+
raise last_error if last_error else RuntimeError("Failed to open TinyDB")
55+
3056
@property
3157
def db_path(self) -> str:
3258
return self._db_path
@@ -54,12 +80,25 @@ def upsert_row(self, data: dict) -> None:
5480
tr.insert(data)
5581

5682
def read_rows(self, rollout_id: Optional[str] = None) -> List[dict]:
57-
# Clear cache to ensure fresh read in multi-process scenarios
58-
self._table.clear_cache()
59-
if rollout_id is not None:
60-
Row = Query()
61-
return list(self._table.search(Row.execution_metadata.rollout_id == rollout_id))
62-
return list(self._table.all())
83+
"""Read rows with retry logic for transient JSON decode errors."""
84+
max_retries = 3
85+
for attempt in range(max_retries):
86+
try:
87+
# Clear cache to ensure fresh read in multi-process scenarios
88+
self._table.clear_cache()
89+
if rollout_id is not None:
90+
Row = Query()
91+
return list(self._table.search(Row.execution_metadata.rollout_id == rollout_id))
92+
return list(self._table.all())
93+
except json.JSONDecodeError as e:
94+
logger.warning(f"TinyDB JSON decode error on read attempt {attempt + 1}: {e}")
95+
if attempt < max_retries - 1:
96+
time.sleep(0.1 * (attempt + 1))
97+
else:
98+
# Return empty list on final failure rather than crash
99+
logger.warning("Failed to read TinyDB after retries, returning empty list")
100+
return []
101+
return []
63102

64103
def delete_row(self, rollout_id: str) -> int:
65104
Row = Query()

eval_protocol/event_bus/tinydb_event_bus_database.py

Lines changed: 63 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
import time
34
from typing import Any, List
@@ -27,9 +28,30 @@ def __init__(self, db_path: str):
2728
if db_dir:
2829
os.makedirs(db_dir, exist_ok=True)
2930
self._db_path = db_path
30-
self._db = TinyDB(db_path)
31+
self._db = self._open_db_with_retry()
3132
self._table = self._db.table("events")
3233

34+
def _open_db_with_retry(self, max_retries: int = 3) -> TinyDB:
35+
"""Open TinyDB with retry logic to handle transient JSON decode errors."""
36+
last_error: Exception | None = None
37+
for attempt in range(max_retries):
38+
try:
39+
return TinyDB(self._db_path)
40+
except json.JSONDecodeError as e:
41+
last_error = e
42+
logger.warning(f"TinyDB JSON decode error on attempt {attempt + 1}: {e}")
43+
# Wait a bit and retry - the file might be mid-write
44+
time.sleep(0.1 * (attempt + 1))
45+
# Try to recover by removing the corrupted file
46+
if attempt == max_retries - 1 and os.path.exists(self._db_path):
47+
try:
48+
logger.warning(f"Removing corrupted TinyDB file: {self._db_path}")
49+
os.remove(self._db_path)
50+
return TinyDB(self._db_path)
51+
except Exception:
52+
pass
53+
raise last_error if last_error else RuntimeError("Failed to open TinyDB")
54+
3355
def publish_event(self, event_type: str, data: Any, process_id: str) -> None:
3456
"""Publish an event to the database using atomic transaction."""
3557
try:
@@ -55,38 +77,48 @@ def publish_event(self, event_type: str, data: Any, process_id: str) -> None:
5577
logger.warning(f"Failed to publish event to database: {e}")
5678

5779
def get_unprocessed_events(self, process_id: str) -> List[dict]:
58-
"""Get unprocessed events from other processes."""
59-
try:
60-
# Clear query cache to force fresh read from disk
61-
# TinyDB caches query results, so we need to clear cache to see
62-
# events written by other processes. The search() method will
63-
# automatically call _read_table() on a cache miss.
64-
self._table.clear_cache()
65-
66-
Event = Query()
67-
results = self._table.search((Event.process_id != process_id) & (Event.processed == False)) # noqa: E712
68-
69-
logger.debug(
70-
f"TinyDBEventBusDatabase: Found {len(results)} unprocessed events for process_id: {process_id} in database: {self._db_path}"
71-
)
72-
73-
events = []
74-
# Sort by timestamp
75-
for event in sorted(results, key=lambda x: x.get("timestamp", 0)):
76-
events.append(
77-
{
78-
"event_id": event["event_id"],
79-
"event_type": event["event_type"],
80-
"data": event["data"],
81-
"timestamp": event["timestamp"],
82-
"process_id": event["process_id"],
83-
}
80+
"""Get unprocessed events from other processes with retry logic."""
81+
max_retries = 3
82+
for attempt in range(max_retries):
83+
try:
84+
# Clear query cache to force fresh read from disk
85+
# TinyDB caches query results, so we need to clear cache to see
86+
# events written by other processes. The search() method will
87+
# automatically call _read_table() on a cache miss.
88+
self._table.clear_cache()
89+
90+
Event = Query()
91+
results = self._table.search((Event.process_id != process_id) & (Event.processed == False)) # noqa: E712
92+
93+
logger.debug(
94+
f"TinyDBEventBusDatabase: Found {len(results)} unprocessed events for process_id: {process_id} in database: {self._db_path}"
8495
)
8596

86-
return events
87-
except Exception as e:
88-
logger.warning(f"Failed to get unprocessed events: {e}")
89-
return []
97+
events = []
98+
# Sort by timestamp
99+
for event in sorted(results, key=lambda x: x.get("timestamp", 0)):
100+
events.append(
101+
{
102+
"event_id": event["event_id"],
103+
"event_type": event["event_type"],
104+
"data": event["data"],
105+
"timestamp": event["timestamp"],
106+
"process_id": event["process_id"],
107+
}
108+
)
109+
110+
return events
111+
except json.JSONDecodeError as e:
112+
logger.warning(f"TinyDB JSON decode error on get_unprocessed_events attempt {attempt + 1}: {e}")
113+
if attempt < max_retries - 1:
114+
time.sleep(0.1 * (attempt + 1))
115+
else:
116+
logger.warning("Failed to read events after retries, returning empty list")
117+
return []
118+
except Exception as e:
119+
logger.warning(f"Failed to get unprocessed events: {e}")
120+
return []
121+
return []
90122

91123
def mark_event_processed(self, event_id: str) -> None:
92124
"""Mark an event as processed using atomic transaction."""

tests/conftest.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import sys
3-
import tempfile
43
from pathlib import Path
54

65
import pytest
@@ -26,23 +25,26 @@
2625
# concurrent test workers from corrupting the shared logs.json file.
2726
# This is especially important in CI where pytest-xdist runs tests in parallel.
2827

28+
# Store the original function before any patching
29+
import eval_protocol.directory_utils as dir_utils
30+
31+
_original_find_eval_protocol_dir = dir_utils.find_eval_protocol_dir
32+
2933

3034
@pytest.fixture(scope="session", autouse=True)
31-
def isolated_eval_protocol_dir(tmp_path_factory):
35+
def isolated_eval_protocol_dir(tmp_path_factory, request):
3236
"""
3337
Create an isolated .eval_protocol directory for the test session.
3438
3539
This prevents concurrent test workers from corrupting the shared
3640
~/.eval_protocol/logs.json file when using TinyDB storage.
41+
42+
Note: Tests in test_directory_utils.py are excluded from this fixture
43+
as they need to test the actual find_eval_protocol_dir behavior.
3744
"""
3845
# Create a unique temp directory for this test session/worker
3946
isolated_dir = tmp_path_factory.mktemp("eval_protocol")
4047

41-
# Monkeypatch the find_eval_protocol_dir function to return our isolated dir
42-
import eval_protocol.directory_utils as dir_utils
43-
44-
original_find_eval_protocol_dir = dir_utils.find_eval_protocol_dir
45-
4648
def isolated_find_eval_protocol_dir() -> str:
4749
os.makedirs(str(isolated_dir), exist_ok=True)
4850
return str(isolated_dir)
@@ -52,4 +54,18 @@ def isolated_find_eval_protocol_dir() -> str:
5254
yield isolated_dir
5355

5456
# Restore original function after tests
55-
dir_utils.find_eval_protocol_dir = original_find_eval_protocol_dir
57+
dir_utils.find_eval_protocol_dir = _original_find_eval_protocol_dir
58+
59+
60+
@pytest.fixture
61+
def restore_original_find_eval_protocol_dir():
62+
"""
63+
Fixture to restore the original find_eval_protocol_dir for tests that
64+
need to test the actual implementation (e.g., test_directory_utils.py).
65+
66+
Use this fixture in tests that need to test the real directory behavior.
67+
"""
68+
# Temporarily restore the original function
69+
dir_utils.find_eval_protocol_dir = _original_find_eval_protocol_dir
70+
yield _original_find_eval_protocol_dir
71+
# The session fixture will clean up when tests complete

tests/test_directory_utils.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
11
import os
22
import tempfile
33
from unittest.mock import patch
4+
45
import pytest
56

6-
from eval_protocol.directory_utils import find_eval_protocol_dir, find_eval_protocol_datasets_dir
7+
import eval_protocol.directory_utils as dir_utils
8+
9+
10+
@pytest.fixture(autouse=True)
def use_real_directory_utils(restore_original_find_eval_protocol_dir):
    """
    Ensure every test in this module sees the real find_eval_protocol_dir.

    The session-scoped isolated_eval_protocol_dir fixture replaces
    find_eval_protocol_dir globally; by depending on
    restore_original_find_eval_protocol_dir, this autouse fixture puts the
    genuine implementation back in place for the duration of each test here,
    which is required because these tests exercise the actual directory
    behavior.
    """
    yield None
720

821

922
class TestDirectoryUtils:
@@ -13,7 +26,7 @@ def test_find_eval_protocol_dir_uses_home_folder(self):
1326
"""Test that find_eval_protocol_dir always maps to home folder."""
1427
with tempfile.TemporaryDirectory() as temp_dir:
1528
with patch.dict(os.environ, {"HOME": temp_dir}):
16-
result = find_eval_protocol_dir()
29+
result = dir_utils.find_eval_protocol_dir()
1730
expected = os.path.expanduser("~/.eval_protocol")
1831
assert result == expected
1932
assert result.endswith(".eval_protocol")
@@ -29,7 +42,7 @@ def test_find_eval_protocol_dir_creates_directory(self):
2942
os.rmdir(eval_protocol_dir)
3043

3144
# Call the function
32-
result = find_eval_protocol_dir()
45+
result = dir_utils.find_eval_protocol_dir()
3346

3447
# Verify the directory was created
3548
assert result == eval_protocol_dir
@@ -40,7 +53,7 @@ def test_find_eval_protocol_dir_handles_tilde_expansion(self):
4053
"""Test that find_eval_protocol_dir properly handles tilde expansion."""
4154
with tempfile.TemporaryDirectory() as temp_dir:
4255
with patch.dict(os.environ, {"HOME": temp_dir}):
43-
result = find_eval_protocol_dir()
56+
result = dir_utils.find_eval_protocol_dir()
4457
expected = os.path.expanduser("~/.eval_protocol")
4558
assert result == expected
4659
assert result.startswith(temp_dir)
@@ -49,7 +62,7 @@ def test_find_eval_protocol_datasets_dir_uses_home_folder(self):
4962
"""Test that find_eval_protocol_datasets_dir also uses home folder."""
5063
with tempfile.TemporaryDirectory() as temp_dir:
5164
with patch.dict(os.environ, {"HOME": temp_dir}):
52-
result = find_eval_protocol_datasets_dir()
65+
result = dir_utils.find_eval_protocol_datasets_dir()
5366
expected = os.path.expanduser("~/.eval_protocol/datasets")
5467
assert result == expected
5568
assert result.endswith(".eval_protocol/datasets")
@@ -69,7 +82,7 @@ def test_find_eval_protocol_datasets_dir_creates_directory(self):
6982
os.rmdir(eval_protocol_dir)
7083

7184
# Call the function
72-
result = find_eval_protocol_datasets_dir()
85+
result = dir_utils.find_eval_protocol_datasets_dir()
7386

7487
# Verify both directories were created
7588
assert result == datasets_dir
@@ -82,14 +95,14 @@ def test_find_eval_protocol_dir_consistency(self):
8295
"""Test that multiple calls to find_eval_protocol_dir return the same path."""
8396
with tempfile.TemporaryDirectory() as temp_dir:
8497
with patch.dict(os.environ, {"HOME": temp_dir}):
85-
result1 = find_eval_protocol_dir()
86-
result2 = find_eval_protocol_dir()
98+
result1 = dir_utils.find_eval_protocol_dir()
99+
result2 = dir_utils.find_eval_protocol_dir()
87100
assert result1 == result2
88101

89102
def test_find_eval_protocol_datasets_dir_consistency(self):
90103
"""Test that multiple calls to find_eval_protocol_datasets_dir return the same path."""
91104
with tempfile.TemporaryDirectory() as temp_dir:
92105
with patch.dict(os.environ, {"HOME": temp_dir}):
93-
result1 = find_eval_protocol_datasets_dir()
94-
result2 = find_eval_protocol_datasets_dir()
106+
result1 = dir_utils.find_eval_protocol_datasets_dir()
107+
result2 = dir_utils.find_eval_protocol_datasets_dir()
95108
assert result1 == result2

0 commit comments

Comments
 (0)