Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions app/commands/adminapi/command.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import threading
from datetime import timedelta
from pathlib import Path
from typing import final

Expand All @@ -6,9 +8,11 @@
import structlog
import yaml

from app.data import repositories
from app.data import model, repositories
from app.domain import adminapi as domain
from app.domain.adminapi import CacheRegistry
from app.lib import auth, clients, commands, config, tracing
from app.lib.cache import BackgroundCache
from app.lib.storage import postgres
from app.lib.tracing import TracingConfig
from app.lib.web import middlewares, server
Expand Down Expand Up @@ -36,13 +40,31 @@ def prepare(self):

authenticator = auth.PostgresAuthenticator(self.pg_storage)

layer0_repo = repositories.Layer0Repository(self.pg_storage, log)

def refresh_rawdata_row_counts() -> model.RawdataTableRowCounts:
names = layer0_repo.list_tables()
return model.RawdataTableRowCounts(rows_by_table={name: layer0_repo.total_rows(name) for name in names})

rawdata_row_counts: BackgroundCache[model.RawdataTableRowCounts] = BackgroundCache(
name="rawdata_table_row_counts",
refresh_func=refresh_rawdata_row_counts,
refresh_frequency=timedelta(minutes=1),
refresh_timeout=timedelta(seconds=30),
)
cache_thread = threading.Thread(target=rawdata_row_counts.run, daemon=True)
cache_thread.start()

cache_registry = CacheRegistry(rawdata_row_counts=rawdata_row_counts)

actions = domain.Actions(
common_repo=repositories.CommonRepository(self.pg_storage, log),
layer0_repo=repositories.Layer0Repository(self.pg_storage, log),
layer0_repo=layer0_repo,
layer1_repo=repositories.Layer1Repository(self.pg_storage, log),
layer2_repo=repositories.Layer2Repository(self.pg_storage, log),
authenticator=authenticator,
clients=clients.Clients(cfg.clients.ads_token),
cache_registry=cache_registry,
)

self.app = presentation.Server(actions, cfg.server, log)
Expand Down
2 changes: 2 additions & 0 deletions app/data/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Layer0RawData,
Layer0TableListItem,
Layer0TableMeta,
RawdataTableRowCounts,
TableRecord,
TableStatistics,
)
Expand All @@ -31,6 +32,7 @@
"Layer0TableMeta",
"Layer0CreationResponse",
"Layer0TableListItem",
"RawdataTableRowCounts",
"ColumnDescription",
"TableStatistics",
"get_object",
Expand Down
6 changes: 5 additions & 1 deletion app/data/model/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any

import pandas
import pydantic
from astropy import units as u

from app.lib.storage import enums
Expand Down Expand Up @@ -48,7 +49,6 @@ class Layer0TableMeta:
class Layer0TableListItem:
table_name: str
description: str
num_entries: int
num_fields: int
modification_dt: datetime.datetime

Expand All @@ -65,3 +65,7 @@ class TableStatistics:
last_modified_dt: datetime.datetime
total_rows: int
total_original_rows: int


class RawdataTableRowCounts(pydantic.BaseModel):
    """Row counts of rawdata tables, keyed by table name."""

    # Maps a table name to the number of rows counted for that table.
    rows_by_table: dict[str, int]
6 changes: 6 additions & 0 deletions app/data/repositories/layer0/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ def register_records(self, table_name: str, record_ids: list[str]) -> None:
def get_table_statistics(self, table_name: str) -> model.TableStatistics:
return self.records_repo.get_table_statistics(table_name)

def list_tables(self) -> list[str]:
    """Return the table names reported by the underlying table repository."""
    table_names = self.table_repo.list_tables()
    return table_names

def total_rows(self, table_name: str) -> int:
    """Return the row count of ``table_name`` as reported by the table repository."""
    row_count = self.table_repo.total_rows(table_name)
    return row_count

def get_processed_records(
self,
limit: int,
Expand Down
22 changes: 17 additions & 5 deletions app/data/repositories/layer0/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,6 @@ def search_tables(
t.table_name,
t.modification_dt,
COALESCE(ti.param->>'description', '') AS description,
COALESCE(ps.n_live_tup::bigint, 0)::int AS num_entries,
(
SELECT COUNT(*)::int
FROM meta.column_info c
Expand All @@ -515,8 +514,6 @@ def search_tables(
FROM layer0.tables t
LEFT JOIN meta.table_info ti
ON ti.schema_name = %s AND ti.table_name = t.table_name
LEFT JOIN pg_stat_user_tables ps
ON ps.schemaname = %s AND ps.relname = t.table_name
WHERE t.table_name ILIKE %s OR COALESCE(ti.param->>'description', '') ILIKE %s
ORDER BY t.modification_dt DESC
LIMIT %s OFFSET %s
Expand All @@ -525,7 +522,6 @@ def search_tables(
RAWDATA_SCHEMA,
INTERNAL_ID_COLUMN_NAME,
RAWDATA_SCHEMA,
RAWDATA_SCHEMA,
pattern,
pattern,
page_size,
Expand All @@ -536,13 +532,29 @@ def search_tables(
model.Layer0TableListItem(
table_name=row["table_name"],
description=row["description"] or "",
num_entries=int(row["num_entries"]),
num_fields=int(row["num_fields"]),
modification_dt=row["modification_dt"],
)
for row in rows
]

def list_tables(self) -> list[str]:
    """Return every ``table_name`` stored in ``layer0.tables``, ordered by name."""
    query = "SELECT table_name FROM layer0.tables ORDER BY table_name"
    names: list[str] = []
    for record in self._storage.query(query, params=[]):
        names.append(record["table_name"])
    return names

def total_rows(self, table_name: str) -> int:
    """Count the rows of ``table_name`` inside the rawdata schema.

    Schema and table name are composed as SQL identifiers rather than
    interpolated as plain text.
    """
    schema_ident = sql.Identifier(RAWDATA_SCHEMA)
    table_ident = sql.Identifier(table_name)
    count_query = sql.SQL("SELECT COUNT(1) AS cnt FROM {}.{}").format(schema_ident, table_ident)
    result = self._storage.query_one(count_query, params=[])
    return int(result["cnt"])

def _get_table_id(self, table_name: str) -> tuple[int, bool]:
try:
row = self._storage.query_one(
Expand Down
2 changes: 2 additions & 0 deletions app/domain/adminapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from app.domain.adminapi.actions import Actions
from app.domain.adminapi.cache_registry import CacheRegistry
from app.domain.adminapi.crossmatch import CrossmatchManager
from app.domain.adminapi.login import LoginManager
from app.domain.adminapi.mock import get_mock_actions
Expand All @@ -7,6 +8,7 @@

__all__ = [
"Actions",
"CacheRegistry",
"CrossmatchManager",
"get_mock_actions",
"LoginManager",
Expand Down
6 changes: 5 additions & 1 deletion app/domain/adminapi/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from app.data import repositories
from app.domain.adminapi import crossmatch, layer1_write, login, sources, table_upload
from app.domain.adminapi.cache_registry import CacheRegistry
from app.lib import auth, clients
from app.presentation import adminapi

Expand All @@ -16,10 +17,13 @@ def __init__(
layer2_repo: repositories.Layer2Repository,
authenticator: auth.Authenticator,
clients: clients.Clients,
cache_registry: CacheRegistry,
):
self.source_manager = sources.SourceManager(common_repo)
self.login_manager = login.LoginManager(authenticator)
self.table_upload_manager = table_upload.TableUploadManager(common_repo, layer0_repo, layer1_repo, clients)
self.table_upload_manager = table_upload.TableUploadManager(
common_repo, layer0_repo, layer1_repo, clients, cache_registry
)
self.crossmatch_manager = crossmatch.CrossmatchManager(layer0_repo, layer1_repo, layer2_repo)
self.layer1_writer = layer1_write.Layer1Writer(layer1_repo)

Expand Down
10 changes: 10 additions & 0 deletions app/domain/adminapi/cache_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from app.data.model import RawdataTableRowCounts
from app.lib.cache.cache import BackgroundCache


class CacheRegistry:
    """Bundle of the background caches passed into the admin API domain layer."""

    def __init__(self, rawdata_row_counts: BackgroundCache[RawdataTableRowCounts]) -> None:
        # The registry only keeps a reference; it does not start or stop the cache.
        self.rawdata_row_counts = rawdata_row_counts
8 changes: 7 additions & 1 deletion app/domain/adminapi/mock.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import mock

from app.domain.adminapi import actions
from app.data import model
from app.domain.adminapi import CacheRegistry, actions
from app.lib import auth, clients


Expand All @@ -9,11 +10,16 @@ def get_mock_actions():
c.ads = mock.MagicMock()
c.vizier = mock.MagicMock()

cache_mock = mock.MagicMock()
cache_mock.get.return_value = model.RawdataTableRowCounts(rows_by_table={})
cache_registry = CacheRegistry(rawdata_row_counts=cache_mock)

return actions.Actions(
common_repo=mock.MagicMock(),
layer0_repo=mock.MagicMock(),
layer1_repo=mock.MagicMock(),
layer2_repo=mock.MagicMock(),
authenticator=auth.NoopAuthenticator(),
clients=c,
cache_registry=cache_registry,
)
6 changes: 5 additions & 1 deletion app/domain/adminapi/table_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from app.data import model, repositories
from app.data.repositories.common import ColumnSchemaInfo, TableSchemaInfo
from app.data.repositories.layer0.common import RAWDATA_SCHEMA
from app.domain.adminapi.cache_registry import CacheRegistry
from app.lib import astronomy, clients, concurrency
from app.lib.storage import enums, mapping
from app.lib.web.errors import NotFoundError, RuleValidationError
Expand Down Expand Up @@ -88,11 +89,13 @@ def __init__(
layer0_repo: repositories.Layer0Repository,
layer1_repo: repositories.Layer1Repository,
clients: clients.Clients,
cache_registry: CacheRegistry,
) -> None:
self.common_repo = common_repo
self.layer0_repo = layer0_repo
self.layer1_repo = layer1_repo
self.clients = clients
self._cache_registry = cache_registry

def create_table(self, r: adminapi.CreateTableRequest) -> tuple[adminapi.CreateTableResponse, bool]:
source_id = get_source_id(self.common_repo, self.clients.ads, r.bibcode)
Expand Down Expand Up @@ -162,12 +165,13 @@ def add_data(self, r: adminapi.AddDataRequest) -> adminapi.AddDataResponse:

def get_table_list(self, r: adminapi.GetTableListRequest) -> adminapi.GetTableListResponse:
items = self.layer0_repo.search_tables(r.query, r.page_size, r.page)
row_counts = self._cache_registry.rawdata_row_counts.get().rows_by_table
return adminapi.GetTableListResponse(
tables=[
adminapi.TableListItem(
name=item.table_name,
description=item.description,
num_entries=item.num_entries,
num_entries=row_counts.get(item.table_name, 0),
num_fields=item.num_fields,
modification_dt=item.modification_dt,
)
Expand Down
3 changes: 3 additions & 0 deletions app/lib/cache/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from app.lib.cache.cache import BackgroundCache

__all__ = ["BackgroundCache"]
69 changes: 69 additions & 0 deletions app/lib/cache/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import concurrent.futures
import threading
from collections.abc import Callable
from datetime import timedelta
from time import monotonic

import pydantic
import structlog


class BackgroundCache[T: pydantic.BaseModel]:
def __init__(
self,
name: str,
refresh_func: Callable[[], T],
refresh_frequency: timedelta,
refresh_timeout: timedelta,
) -> None:
self.name = name
self._refresh_func = refresh_func
self.refresh_frequency = refresh_frequency
self._refresh_timeout = refresh_timeout
self._lock = threading.Lock()
self._stop_event = threading.Event()
self._value: T = self._do_refresh()

def _do_refresh(self) -> T:
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(self._refresh_func)
try:
return future.result(timeout=self._refresh_timeout.total_seconds())
except concurrent.futures.TimeoutError as e:
raise TimeoutError(f"refresh_func timed out after {self._refresh_timeout}") from e

def get(self) -> T:
with self._lock:
return self._value

def run(self) -> None:
log = structlog.get_logger().bind(cache_name=self.name)

while not self._stop_event.is_set():
start = monotonic()

try:
new_value = self._do_refresh()
except Exception as e:
log.error("cache refresh failed", reason=str(e), exc_info=True)
self._stop_event.wait(timeout=self.refresh_frequency.total_seconds())
continue

elapsed = monotonic() - start
with self._lock:
old_value = self._value
self._value = new_value

old_json = old_value.model_dump_json()
new_json = new_value.model_dump_json()
changed = old_json != new_json
log.debug(
"cache refreshed",
duration_seconds=round(elapsed, 3),
size_bytes=len(new_json),
changed=changed,
)
self._stop_event.wait(timeout=self.refresh_frequency.total_seconds())

def stop(self) -> None:
self._stop_event.set()
7 changes: 6 additions & 1 deletion tests/integration/create_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@ def setUpClass(cls) -> None:

cls.source_manager = domain.SourceManager(cls.common_repo)
cls.layer1_repo = repositories.Layer1Repository(cls.pg_storage.get_storage(), structlog.get_logger())
cache_registry = lib.get_cache_registry(cls.layer0_repo)
cls.upload_manager = domain.TableUploadManager(
cls.common_repo, cls.layer0_repo, cls.layer1_repo, clients.get_mock_clients()
cls.common_repo,
cls.layer0_repo,
cls.layer1_repo,
clients.get_mock_clients(),
cache_registry,
)

def tearDown(self):
Expand Down
6 changes: 4 additions & 2 deletions tests/integration/rawdata_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ class RawDataTableTest(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.pg_storage = lib.TestPostgresStorage.get()

layer0_repo = repositories.Layer0Repository(cls.pg_storage.get_storage(), structlog.get_logger())
cache_registry = lib.get_cache_registry(layer0_repo)
cls.manager = domain.TableUploadManager(
repositories.CommonRepository(cls.pg_storage.get_storage(), structlog.get_logger()),
repositories.Layer0Repository(cls.pg_storage.get_storage(), structlog.get_logger()),
layer0_repo,
repositories.Layer1Repository(cls.pg_storage.get_storage(), structlog.get_logger()),
clients.get_mock_clients(),
cache_registry,
)

def tearDown(self):
Expand Down
3 changes: 3 additions & 0 deletions tests/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from tests.lib.astronomy import get_synthetic_data
from tests.lib.cache_registry import get_cache_registry, get_mock_cache_registry
from tests.lib.decorators import test_logging_decorator
from tests.lib.mocks import raises, returns
from tests.lib.postgres import TestPostgresStorage
Expand All @@ -8,6 +9,8 @@
"TestPostgresStorage",
"find_free_port",
"TestSession",
"get_cache_registry",
"get_mock_cache_registry",
"returns",
"raises",
"test_logging_decorator",
Expand Down
Loading
Loading