From 18145be7b42ade9a45de14ebe7706ef076d1f041 Mon Sep 17 00:00:00 2001 From: kraysent Date: Sat, 14 Mar 2026 23:43:11 +0000 Subject: [PATCH 1/2] #419: basic cache class --- app/lib/cache/__init__.py | 0 app/lib/cache/cache.py | 72 +++++++++++++++++++++++++ tests/unit/lib/cache_test.py | 99 ++++++++++++++++++++++++++++++++++++ 3 files changed, 171 insertions(+) create mode 100644 app/lib/cache/__init__.py create mode 100644 app/lib/cache/cache.py create mode 100644 tests/unit/lib/cache_test.py diff --git a/app/lib/cache/__init__.py b/app/lib/cache/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/lib/cache/cache.py b/app/lib/cache/cache.py new file mode 100644 index 00000000..528789bb --- /dev/null +++ b/app/lib/cache/cache.py @@ -0,0 +1,72 @@ +import concurrent.futures +import threading +from collections.abc import Callable +from datetime import timedelta +from time import monotonic + +import pydantic +import structlog + + +class BackgroundCache[T: pydantic.BaseModel]: + def __init__( + self, + name: str, + refresh_func: Callable[[], T], + refresh_frequency: timedelta, + refresh_timeout: timedelta, + ) -> None: + self.name = name + self._refresh_func = refresh_func + self.refresh_frequency = refresh_frequency + self._refresh_timeout = refresh_timeout + self._lock = threading.Lock() + self._stop_event = threading.Event() + self._value: T = self._do_refresh() + + def _do_refresh(self) -> T: + # Deliberately no "with" block: ThreadPoolExecutor.__exit__ calls shutdown(wait=True), + # which would block until refresh_func finishes and make the timeout below ineffective. + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + future = executor.submit(self._refresh_func) + try: + return future.result(timeout=self._refresh_timeout.total_seconds()) + except concurrent.futures.TimeoutError as e: + raise TimeoutError(f"refresh_func timed out after {self._refresh_timeout}") from e + finally: + executor.shutdown(wait=False) + + def get(self) -> T: + with self._lock: + return self._value + + def run(self) -> None: + log = structlog.get_logger().bind(cache_name=self.name) + + while not self._stop_event.is_set(): + start = monotonic() + + try: + new_value
= self._do_refresh() + except Exception as e: + log.error("cache refresh failed", reason=str(e), exc_info=True) + self._stop_event.wait(timeout=self.refresh_frequency.total_seconds()) + continue + + elapsed = monotonic() - start + with self._lock: + old_value = self._value + self._value = new_value + + old_json = old_value.model_dump_json() + new_json = new_value.model_dump_json() + changed = old_json != new_json + log.debug( + "cache refreshed", + duration_seconds=round(elapsed, 3), + size_bytes=len(new_json), + changed=changed, + ) + self._stop_event.wait(timeout=self.refresh_frequency.total_seconds()) + + def stop(self) -> None: + self._stop_event.set() diff --git a/tests/unit/lib/cache_test.py b/tests/unit/lib/cache_test.py new file mode 100644 index 00000000..acf4ef34 --- /dev/null +++ b/tests/unit/lib/cache_test.py @@ -0,0 +1,99 @@ +import threading +import time +import unittest +from datetime import timedelta + +import pydantic + +from app.lib.cache.cache import BackgroundCache + + +class DummyModel(pydantic.BaseModel): + value: int + + +class BackgroundCacheTest(unittest.TestCase): + def test_happy_path_no_change(self): + value = DummyModel(value=1) + + def refresh(): + return value + + cache = BackgroundCache( + name="test_cache", + refresh_func=refresh, + refresh_frequency=timedelta(seconds=0.1), + refresh_timeout=timedelta(seconds=1), + ) + self.assertEqual(cache.get().value, 1) + t = threading.Thread(target=cache.run, daemon=True) + t.start() + time.sleep(0.35) + for _ in range(3): + self.assertEqual(cache.get().value, 1) + cache.stop() + t.join(timeout=1) + self.assertEqual(cache.get().value, 1) + + def test_happy_path_with_change(self): + state = {"v": 0} + + def refresh(): + return DummyModel(value=state["v"]) + + cache = BackgroundCache( + name="test_cache", + refresh_func=refresh, + refresh_frequency=timedelta(seconds=0.05), + refresh_timeout=timedelta(seconds=1), + ) + self.assertEqual(cache.get().value, 0) + t = threading.Thread(target=cache.run, 
daemon=True) + t.start() + self.assertEqual(cache.get().value, 0) + state["v"] = 1 + time.sleep(0.15) + self.assertEqual(cache.get().value, 1) + state["v"] = 2 + time.sleep(0.15) + self.assertEqual(cache.get().value, 2) + cache.stop() + t.join(timeout=1) + + def test_refresh_fails_then_restores_value_unchanged(self): + call_count = 0 + + def refresh(): + nonlocal call_count + call_count += 1 + if call_count == 2: + raise RuntimeError("transient failure") + return DummyModel(value=42) + + cache = BackgroundCache( + name="test_cache", + refresh_func=refresh, + refresh_frequency=timedelta(seconds=0.05), + refresh_timeout=timedelta(seconds=1), + ) + self.assertEqual(cache.get().value, 42) + t = threading.Thread(target=cache.run, daemon=True) + t.start() + time.sleep(0.2) + self.assertEqual(cache.get().value, 42) + cache.stop() + t.join(timeout=1) + self.assertEqual(cache.get().value, 42) + + def test_refresh_fails_during_init_raises(self): + def refresh(): + raise ValueError("init failed") + + with self.assertRaises(ValueError) as ctx: + BackgroundCache( + name="test_cache", + refresh_func=refresh, + refresh_frequency=timedelta(seconds=1), + refresh_timeout=timedelta(seconds=1), + ) + self.assertIn("init failed", str(ctx.exception)) From d10c796fd14ea4bbd3b4213e085c7c0c92bd1a6a Mon Sep 17 00:00:00 2001 From: kraysent Date: Sat, 14 Mar 2026 23:59:05 +0000 Subject: [PATCH 2/2] implement rawdata sizes cache --- app/commands/adminapi/command.py | 26 ++++++++++++++++++++-- app/data/model/__init__.py | 2 ++ app/data/model/table.py | 6 ++++- app/data/repositories/layer0/repository.py | 6 +++++ app/data/repositories/layer0/tables.py | 22 +++++++++++++----- app/domain/adminapi/__init__.py | 2 ++ app/domain/adminapi/actions.py | 6 ++++- app/domain/adminapi/cache_registry.py | 10 +++++++++ app/domain/adminapi/mock.py | 8 ++++++- app/domain/adminapi/table_upload.py | 6 ++++- app/lib/cache/__init__.py | 3 +++ tests/integration/create_table_test.py | 7 +++++- 
tests/integration/rawdata_table_test.py | 6 +++-- tests/lib/__init__.py | 3 +++ tests/lib/cache_registry.py | 26 ++++++++++++++++++++++ tests/unit/data/layer0_repository_test.py | 2 -- tests/unit/domain/table_upload_test.py | 3 +++ 17 files changed, 128 insertions(+), 16 deletions(-) create mode 100644 app/domain/adminapi/cache_registry.py create mode 100644 tests/lib/cache_registry.py diff --git a/app/commands/adminapi/command.py b/app/commands/adminapi/command.py index 07fcd446..47da0288 100644 --- a/app/commands/adminapi/command.py +++ b/app/commands/adminapi/command.py @@ -1,3 +1,5 @@ +import threading +from datetime import timedelta from pathlib import Path from typing import final @@ -6,9 +8,11 @@ import structlog import yaml -from app.data import repositories +from app.data import model, repositories from app.domain import adminapi as domain +from app.domain.adminapi import CacheRegistry from app.lib import auth, clients, commands, config, tracing +from app.lib.cache import BackgroundCache from app.lib.storage import postgres from app.lib.tracing import TracingConfig from app.lib.web import middlewares, server @@ -36,13 +40,31 @@ def prepare(self): authenticator = auth.PostgresAuthenticator(self.pg_storage) + layer0_repo = repositories.Layer0Repository(self.pg_storage, log) + + def refresh_rawdata_row_counts() -> model.RawdataTableRowCounts: + names = layer0_repo.list_tables() + return model.RawdataTableRowCounts(rows_by_table={name: layer0_repo.total_rows(name) for name in names}) + + rawdata_row_counts: BackgroundCache[model.RawdataTableRowCounts] = BackgroundCache( + name="rawdata_table_row_counts", + refresh_func=refresh_rawdata_row_counts, + refresh_frequency=timedelta(minutes=1), + refresh_timeout=timedelta(seconds=30), + ) + cache_thread = threading.Thread(target=rawdata_row_counts.run, daemon=True) + cache_thread.start() + + cache_registry = CacheRegistry(rawdata_row_counts=rawdata_row_counts) + actions = domain.Actions( 
common_repo=repositories.CommonRepository(self.pg_storage, log), - layer0_repo=repositories.Layer0Repository(self.pg_storage, log), + layer0_repo=layer0_repo, layer1_repo=repositories.Layer1Repository(self.pg_storage, log), layer2_repo=repositories.Layer2Repository(self.pg_storage, log), authenticator=authenticator, clients=clients.Clients(cfg.clients.ads_token), + cache_registry=cache_registry, ) self.app = presentation.Server(actions, cfg.server, log) diff --git a/app/data/model/__init__.py b/app/data/model/__init__.py index be6aba96..715d9608 100644 --- a/app/data/model/__init__.py +++ b/app/data/model/__init__.py @@ -21,6 +21,7 @@ Layer0RawData, Layer0TableListItem, Layer0TableMeta, + RawdataTableRowCounts, TableRecord, TableStatistics, ) @@ -31,6 +32,7 @@ "Layer0TableMeta", "Layer0CreationResponse", "Layer0TableListItem", + "RawdataTableRowCounts", "ColumnDescription", "TableStatistics", "get_object", diff --git a/app/data/model/table.py b/app/data/model/table.py index d4436d00..de057b41 100644 --- a/app/data/model/table.py +++ b/app/data/model/table.py @@ -3,6 +3,7 @@ from typing import Any import pandas +import pydantic from astropy import units as u from app.lib.storage import enums @@ -48,7 +49,6 @@ class Layer0TableMeta: class Layer0TableListItem: table_name: str description: str - num_entries: int num_fields: int modification_dt: datetime.datetime @@ -65,3 +65,7 @@ class TableStatistics: last_modified_dt: datetime.datetime total_rows: int total_original_rows: int + + +class RawdataTableRowCounts(pydantic.BaseModel): + rows_by_table: dict[str, int] diff --git a/app/data/repositories/layer0/repository.py b/app/data/repositories/layer0/repository.py index 545525e6..161daf54 100644 --- a/app/data/repositories/layer0/repository.py +++ b/app/data/repositories/layer0/repository.py @@ -92,6 +92,12 @@ def register_records(self, table_name: str, record_ids: list[str]) -> None: def get_table_statistics(self, table_name: str) -> model.TableStatistics: return 
self.records_repo.get_table_statistics(table_name) + def list_tables(self) -> list[str]: + return self.table_repo.list_tables() + + def total_rows(self, table_name: str) -> int: + return self.table_repo.total_rows(table_name) + def get_processed_records( self, limit: int, diff --git a/app/data/repositories/layer0/tables.py b/app/data/repositories/layer0/tables.py index ce2eed10..15005b75 100644 --- a/app/data/repositories/layer0/tables.py +++ b/app/data/repositories/layer0/tables.py @@ -504,7 +504,6 @@ def search_tables( t.table_name, t.modification_dt, COALESCE(ti.param->>'description', '') AS description, - COALESCE(ps.n_live_tup::bigint, 0)::int AS num_entries, ( SELECT COUNT(*)::int FROM meta.column_info c @@ -515,8 +514,6 @@ def search_tables( FROM layer0.tables t LEFT JOIN meta.table_info ti ON ti.schema_name = %s AND ti.table_name = t.table_name - LEFT JOIN pg_stat_user_tables ps - ON ps.schemaname = %s AND ps.relname = t.table_name WHERE t.table_name ILIKE %s OR COALESCE(ti.param->>'description', '') ILIKE %s ORDER BY t.modification_dt DESC LIMIT %s OFFSET %s @@ -525,7 +522,6 @@ def search_tables( RAWDATA_SCHEMA, INTERNAL_ID_COLUMN_NAME, RAWDATA_SCHEMA, - RAWDATA_SCHEMA, pattern, pattern, page_size, @@ -536,13 +532,29 @@ def search_tables( model.Layer0TableListItem( table_name=row["table_name"], description=row["description"] or "", - num_entries=int(row["num_entries"]), num_fields=int(row["num_fields"]), modification_dt=row["modification_dt"], ) for row in rows ] + def list_tables(self) -> list[str]: + rows = self._storage.query( + "SELECT table_name FROM layer0.tables ORDER BY table_name", + params=[], + ) + return [row["table_name"] for row in rows] + + def total_rows(self, table_name: str) -> int: + row = self._storage.query_one( + sql.SQL("SELECT COUNT(1) AS cnt FROM {}.{}").format( + sql.Identifier(RAWDATA_SCHEMA), + sql.Identifier(table_name), + ), + params=[], + ) + return int(row["cnt"]) + def _get_table_id(self, table_name: str) -> tuple[int, 
bool]: try: row = self._storage.query_one( diff --git a/app/domain/adminapi/__init__.py b/app/domain/adminapi/__init__.py index fbc924c2..ac7986ab 100644 --- a/app/domain/adminapi/__init__.py +++ b/app/domain/adminapi/__init__.py @@ -1,4 +1,5 @@ from app.domain.adminapi.actions import Actions +from app.domain.adminapi.cache_registry import CacheRegistry from app.domain.adminapi.crossmatch import CrossmatchManager from app.domain.adminapi.login import LoginManager from app.domain.adminapi.mock import get_mock_actions @@ -7,6 +8,7 @@ __all__ = [ "Actions", + "CacheRegistry", "CrossmatchManager", "get_mock_actions", "LoginManager", diff --git a/app/domain/adminapi/actions.py b/app/domain/adminapi/actions.py index 978ec8ef..c3040111 100644 --- a/app/domain/adminapi/actions.py +++ b/app/domain/adminapi/actions.py @@ -2,6 +2,7 @@ from app.data import repositories from app.domain.adminapi import crossmatch, layer1_write, login, sources, table_upload +from app.domain.adminapi.cache_registry import CacheRegistry from app.lib import auth, clients from app.presentation import adminapi @@ -16,10 +17,13 @@ def __init__( layer2_repo: repositories.Layer2Repository, authenticator: auth.Authenticator, clients: clients.Clients, + cache_registry: CacheRegistry, ): self.source_manager = sources.SourceManager(common_repo) self.login_manager = login.LoginManager(authenticator) - self.table_upload_manager = table_upload.TableUploadManager(common_repo, layer0_repo, layer1_repo, clients) + self.table_upload_manager = table_upload.TableUploadManager( + common_repo, layer0_repo, layer1_repo, clients, cache_registry + ) self.crossmatch_manager = crossmatch.CrossmatchManager(layer0_repo, layer1_repo, layer2_repo) self.layer1_writer = layer1_write.Layer1Writer(layer1_repo) diff --git a/app/domain/adminapi/cache_registry.py b/app/domain/adminapi/cache_registry.py new file mode 100644 index 00000000..abbf9bdd --- /dev/null +++ b/app/domain/adminapi/cache_registry.py @@ -0,0 +1,10 @@ +from 
app.data.model import RawdataTableRowCounts +from app.lib.cache.cache import BackgroundCache + + +class CacheRegistry: + def __init__( + self, + rawdata_row_counts: BackgroundCache[RawdataTableRowCounts], + ) -> None: + self.rawdata_row_counts = rawdata_row_counts diff --git a/app/domain/adminapi/mock.py b/app/domain/adminapi/mock.py index 10e1b347..0a905dae 100644 --- a/app/domain/adminapi/mock.py +++ b/app/domain/adminapi/mock.py @@ -1,6 +1,7 @@ from unittest import mock -from app.domain.adminapi import actions +from app.data import model +from app.domain.adminapi import CacheRegistry, actions from app.lib import auth, clients @@ -9,6 +10,10 @@ def get_mock_actions(): c.ads = mock.MagicMock() c.vizier = mock.MagicMock() + cache_mock = mock.MagicMock() + cache_mock.get.return_value = model.RawdataTableRowCounts(rows_by_table={}) + cache_registry = CacheRegistry(rawdata_row_counts=cache_mock) + return actions.Actions( common_repo=mock.MagicMock(), layer0_repo=mock.MagicMock(), @@ -16,4 +21,5 @@ def get_mock_actions(): layer2_repo=mock.MagicMock(), authenticator=auth.NoopAuthenticator(), clients=c, + cache_registry=cache_registry, ) diff --git a/app/domain/adminapi/table_upload.py b/app/domain/adminapi/table_upload.py index bd77a42e..a4e45b24 100644 --- a/app/domain/adminapi/table_upload.py +++ b/app/domain/adminapi/table_upload.py @@ -14,6 +14,7 @@ from app.data import model, repositories from app.data.repositories.common import ColumnSchemaInfo, TableSchemaInfo from app.data.repositories.layer0.common import RAWDATA_SCHEMA +from app.domain.adminapi.cache_registry import CacheRegistry from app.lib import astronomy, clients, concurrency from app.lib.storage import enums, mapping from app.lib.web.errors import NotFoundError, RuleValidationError @@ -88,11 +89,13 @@ def __init__( layer0_repo: repositories.Layer0Repository, layer1_repo: repositories.Layer1Repository, clients: clients.Clients, + cache_registry: CacheRegistry, ) -> None: self.common_repo = common_repo 
self.layer0_repo = layer0_repo self.layer1_repo = layer1_repo self.clients = clients + self._cache_registry = cache_registry def create_table(self, r: adminapi.CreateTableRequest) -> tuple[adminapi.CreateTableResponse, bool]: source_id = get_source_id(self.common_repo, self.clients.ads, r.bibcode) @@ -162,12 +165,13 @@ def add_data(self, r: adminapi.AddDataRequest) -> adminapi.AddDataResponse: def get_table_list(self, r: adminapi.GetTableListRequest) -> adminapi.GetTableListResponse: items = self.layer0_repo.search_tables(r.query, r.page_size, r.page) + row_counts = self._cache_registry.rawdata_row_counts.get().rows_by_table return adminapi.GetTableListResponse( tables=[ adminapi.TableListItem( name=item.table_name, description=item.description, - num_entries=item.num_entries, + num_entries=row_counts.get(item.table_name, 0), num_fields=item.num_fields, modification_dt=item.modification_dt, ) diff --git a/app/lib/cache/__init__.py b/app/lib/cache/__init__.py index e69de29b..c9432983 100644 --- a/app/lib/cache/__init__.py +++ b/app/lib/cache/__init__.py @@ -0,0 +1,3 @@ +from app.lib.cache.cache import BackgroundCache + +__all__ = ["BackgroundCache"] diff --git a/tests/integration/create_table_test.py b/tests/integration/create_table_test.py index acaedb44..7bf3245d 100644 --- a/tests/integration/create_table_test.py +++ b/tests/integration/create_table_test.py @@ -20,8 +20,13 @@ def setUpClass(cls) -> None: cls.source_manager = domain.SourceManager(cls.common_repo) cls.layer1_repo = repositories.Layer1Repository(cls.pg_storage.get_storage(), structlog.get_logger()) + cache_registry = lib.get_cache_registry(cls.layer0_repo) cls.upload_manager = domain.TableUploadManager( - cls.common_repo, cls.layer0_repo, cls.layer1_repo, clients.get_mock_clients() + cls.common_repo, + cls.layer0_repo, + cls.layer1_repo, + clients.get_mock_clients(), + cache_registry, ) def tearDown(self): diff --git a/tests/integration/rawdata_table_test.py b/tests/integration/rawdata_table_test.py 
index aeb5561c..fe9014a4 100644 --- a/tests/integration/rawdata_table_test.py +++ b/tests/integration/rawdata_table_test.py @@ -18,12 +18,14 @@ class RawDataTableTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.pg_storage = lib.TestPostgresStorage.get() - + layer0_repo = repositories.Layer0Repository(cls.pg_storage.get_storage(), structlog.get_logger()) + cache_registry = lib.get_cache_registry(layer0_repo) cls.manager = domain.TableUploadManager( repositories.CommonRepository(cls.pg_storage.get_storage(), structlog.get_logger()), - repositories.Layer0Repository(cls.pg_storage.get_storage(), structlog.get_logger()), + layer0_repo, repositories.Layer1Repository(cls.pg_storage.get_storage(), structlog.get_logger()), clients.get_mock_clients(), + cache_registry, ) def tearDown(self): diff --git a/tests/lib/__init__.py b/tests/lib/__init__.py index 451e5d47..d6a530ad 100644 --- a/tests/lib/__init__.py +++ b/tests/lib/__init__.py @@ -1,4 +1,5 @@ from tests.lib.astronomy import get_synthetic_data +from tests.lib.cache_registry import get_cache_registry, get_mock_cache_registry from tests.lib.decorators import test_logging_decorator from tests.lib.mocks import raises, returns from tests.lib.postgres import TestPostgresStorage @@ -8,6 +9,8 @@ "TestPostgresStorage", "find_free_port", "TestSession", + "get_cache_registry", + "get_mock_cache_registry", "returns", "raises", "test_logging_decorator", diff --git a/tests/lib/cache_registry.py b/tests/lib/cache_registry.py new file mode 100644 index 00000000..574fbff6 --- /dev/null +++ b/tests/lib/cache_registry.py @@ -0,0 +1,26 @@ +from datetime import timedelta +from unittest import mock + +from app.data import model +from app.domain.adminapi import CacheRegistry +from app.lib.cache import BackgroundCache + + +def get_mock_cache_registry() -> CacheRegistry: + cache = mock.MagicMock() + cache.get.return_value = model.RawdataTableRowCounts(rows_by_table={}) + return CacheRegistry(rawdata_row_counts=cache) + + 
+def get_cache_registry(layer0_repo): + def refresh(): + names = layer0_repo.list_tables() + return model.RawdataTableRowCounts(rows_by_table={name: layer0_repo.total_rows(name) for name in names}) + + rawdata_row_counts = BackgroundCache( + name="rawdata_table_row_counts", + refresh_func=refresh, + refresh_frequency=timedelta(minutes=1), + refresh_timeout=timedelta(seconds=30), + ) + return CacheRegistry(rawdata_row_counts=rawdata_row_counts) diff --git a/tests/unit/data/layer0_repository_test.py b/tests/unit/data/layer0_repository_test.py index df320c0f..f2effba1 100644 --- a/tests/unit/data/layer0_repository_test.py +++ b/tests/unit/data/layer0_repository_test.py @@ -67,7 +67,6 @@ def test_search_tables_calls_query_with_expected_structure(self): { "table_name": "my_table", "description": "A test table", - "num_entries": 100, "num_fields": 6, "modification_dt": datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC), } @@ -89,5 +88,4 @@ def test_search_tables_calls_query_with_expected_structure(self): self.assertEqual(len(result), 1) self.assertEqual(result[0].table_name, "my_table") self.assertEqual(result[0].description, "A test table") - self.assertEqual(result[0].num_entries, 100) self.assertEqual(result[0].num_fields, 6) diff --git a/tests/unit/domain/table_upload_test.py b/tests/unit/domain/table_upload_test.py index 5c122977..bd3b53f6 100644 --- a/tests/unit/domain/table_upload_test.py +++ b/tests/unit/domain/table_upload_test.py @@ -22,6 +22,7 @@ def setUp(self): layer0_repo=mock.MagicMock(), layer1_repo=mock.MagicMock(), clients=clients.get_mock_clients(), + cache_registry=lib.get_mock_cache_registry(), ) def test_add_data(self): @@ -147,6 +148,7 @@ def setUp(self): layer0_repo=mock.MagicMock(), layer1_repo=mock.MagicMock(), clients=clients.get_mock_clients(), + cache_registry=lib.get_mock_cache_registry(), ) @parameterized.expand( @@ -270,6 +272,7 @@ def setUp(self) -> None: layer0_repo=mock.MagicMock(), layer1_repo=layer1_repo, 
clients=clients.get_mock_clients(), + cache_registry=lib.get_mock_cache_registry(), ) def test_get_records_returns_records_with_pgc(self) -> None: