Skip to content

Commit 18145be

Browse files
committed
#419: basic cache class
1 parent 67cf2b3 commit 18145be

3 files changed

Lines changed: 168 additions & 0 deletions

File tree

app/lib/cache/__init__.py

Whitespace-only changes.

app/lib/cache/cache.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import concurrent.futures
2+
import threading
3+
from collections.abc import Callable
4+
from datetime import timedelta
5+
from time import monotonic
6+
7+
import pydantic
8+
import structlog
9+
10+
11+
class BackgroundCache[T: pydantic.BaseModel]:
12+
def __init__(
13+
self,
14+
name: str,
15+
refresh_func: Callable[[], T],
16+
refresh_frequency: timedelta,
17+
refresh_timeout: timedelta,
18+
) -> None:
19+
self.name = name
20+
self._refresh_func = refresh_func
21+
self.refresh_frequency = refresh_frequency
22+
self._refresh_timeout = refresh_timeout
23+
self._lock = threading.Lock()
24+
self._stop_event = threading.Event()
25+
self._value: T = self._do_refresh()
26+
27+
def _do_refresh(self) -> T:
28+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
29+
future = executor.submit(self._refresh_func)
30+
try:
31+
return future.result(timeout=self._refresh_timeout.total_seconds())
32+
except concurrent.futures.TimeoutError as e:
33+
raise TimeoutError(f"refresh_func timed out after {self._refresh_timeout}") from e
34+
35+
def get(self) -> T:
36+
with self._lock:
37+
return self._value
38+
39+
def run(self) -> None:
40+
log = structlog.get_logger().bind(cache_name=self.name)
41+
42+
while not self._stop_event.is_set():
43+
start = monotonic()
44+
45+
try:
46+
new_value = self._do_refresh()
47+
except Exception as e:
48+
log.error("cache refresh failed", reason=str(e), exc_info=True)
49+
self._stop_event.wait(timeout=self.refresh_frequency.total_seconds())
50+
continue
51+
52+
elapsed = monotonic() - start
53+
with self._lock:
54+
old_value = self._value
55+
self._value = new_value
56+
57+
old_json = old_value.model_dump_json()
58+
new_json = new_value.model_dump_json()
59+
changed = old_json != new_json
60+
log.debug(
61+
"cache refreshed",
62+
duration_seconds=round(elapsed, 3),
63+
size_bytes=len(new_json),
64+
changed=changed,
65+
)
66+
self._stop_event.wait(timeout=self.refresh_frequency.total_seconds())
67+
68+
def stop(self) -> None:
69+
self._stop_event.set()

tests/unit/lib/cache_test.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import threading
2+
import time
3+
import unittest
4+
from datetime import timedelta
5+
6+
import pydantic
7+
8+
from app.lib.cache.cache import BackgroundCache
9+
10+
11+
class DummyModel(pydantic.BaseModel):
12+
value: int
13+
14+
15+
class BackgroundCacheTest(unittest.TestCase):
16+
def test_happy_path_no_change(self):
17+
value = DummyModel(value=1)
18+
19+
def refresh():
20+
return value
21+
22+
cache = BackgroundCache(
23+
name="test_cache",
24+
refresh_func=refresh,
25+
refresh_frequency=timedelta(seconds=0.1),
26+
refresh_timeout=timedelta(seconds=1),
27+
)
28+
self.assertEqual(cache.get().value, 1)
29+
t = threading.Thread(target=cache.run, daemon=True)
30+
t.start()
31+
time.sleep(0.35)
32+
for _ in range(3):
33+
self.assertEqual(cache.get().value, 1)
34+
cache.stop()
35+
t.join(timeout=1)
36+
self.assertEqual(cache.get().value, 1)
37+
38+
def test_happy_path_with_change(self):
39+
state = {"v": 0}
40+
41+
def refresh():
42+
return DummyModel(value=state["v"])
43+
44+
cache = BackgroundCache(
45+
name="test_cache",
46+
refresh_func=refresh,
47+
refresh_frequency=timedelta(seconds=0.05),
48+
refresh_timeout=timedelta(seconds=1),
49+
)
50+
self.assertEqual(cache.get().value, 0)
51+
t = threading.Thread(target=cache.run, daemon=True)
52+
t.start()
53+
self.assertEqual(cache.get().value, 0)
54+
state["v"] = 1
55+
time.sleep(0.15)
56+
self.assertEqual(cache.get().value, 1)
57+
state["v"] = 2
58+
time.sleep(0.15)
59+
self.assertEqual(cache.get().value, 2)
60+
cache.stop()
61+
t.join(timeout=1)
62+
63+
def test_refresh_fails_then_restores_value_unchanged(self):
64+
call_count = 0
65+
66+
def refresh():
67+
nonlocal call_count
68+
call_count += 1
69+
if call_count == 2:
70+
raise RuntimeError("transient failure")
71+
return DummyModel(value=42)
72+
73+
cache = BackgroundCache(
74+
name="test_cache",
75+
refresh_func=refresh,
76+
refresh_frequency=timedelta(seconds=0.05),
77+
refresh_timeout=timedelta(seconds=1),
78+
)
79+
self.assertEqual(cache.get().value, 42)
80+
t = threading.Thread(target=cache.run, daemon=True)
81+
t.start()
82+
time.sleep(0.2)
83+
self.assertEqual(cache.get().value, 42)
84+
cache.stop()
85+
t.join(timeout=1)
86+
self.assertEqual(cache.get().value, 42)
87+
88+
def test_refresh_fails_during_init_raises(self):
89+
def refresh():
90+
raise ValueError("init failed")
91+
92+
with self.assertRaises(ValueError) as ctx:
93+
BackgroundCache(
94+
name="test_cache",
95+
refresh_func=refresh,
96+
refresh_frequency=timedelta(seconds=1),
97+
refresh_timeout=timedelta(seconds=1),
98+
)
99+
self.assertIn("init failed", str(ctx.exception))

0 commit comments

Comments
 (0)