-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatabase.py
More file actions
132 lines (102 loc) · 3.77 KB
/
database.py
File metadata and controls
132 lines (102 loc) · 3.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
Database initialization and session management for Resolver.
Copyright (c) 2026 Stefan Kumarasinghe
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
"""
from __future__ import annotations
import logging
import os
import re
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Protocol
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, sessionmaker
from db_models import Base
# Module-level logger, named after this module; init_db() reports table
# creation through it.
logger = logging.getLogger(__name__)
class _SessionFactory(Protocol):
    """Structural type for a zero-argument callable that produces a ``Session``.

    The ``sessionmaker`` built in ``init_database`` is stored under this type.
    """

    def __call__(self) -> Session:
        """Create and return a new session."""
def _new_session(factory: _SessionFactory) -> Session:
    """Invoke *factory* once and hand back the session it produced."""
    session = factory()
    return session
# Process-wide database state. Both are set by init_database() and reset to
# None by dispose_database(); None means "not initialised", which the other
# helpers in this module check before touching the database.
_ENGINE: Engine | None = None
_SESSION_FACTORY: _SessionFactory | None = None
def _postgres_database_exists(admin_engine: Engine, name: str) -> bool:
    """Return True if ``pg_database`` lists a database called *name*.

    Opens its own short-lived connection on *admin_engine*, so it can be used
    safely after a statement on another connection has failed.
    """
    with admin_engine.connect() as conn:
        found = conn.execute(
            text("SELECT 1 FROM pg_database WHERE datname = :name"),
            {"name": name},
        ).scalar()
    return bool(found)


def _ensure_postgres_database_exists(database_url: str) -> None:
    """Create the PostgreSQL database named in *database_url* if it is missing.

    Non-PostgreSQL URLs and URLs without a database name are ignored. The
    existence check and ``CREATE DATABASE`` run against the server's
    ``postgres`` database through a temporary AUTOCOMMIT engine, because
    ``CREATE DATABASE`` cannot run inside a transaction block. The temporary
    engine is always disposed.

    Concurrent callers (e.g. several workers starting at once) are tolerated:
    if the CREATE fails but the database exists afterwards, that counts as
    success.

    Raises:
        RuntimeError: if the database name contains anything other than ASCII
            letters, digits and underscores.
        sqlalchemy.exc.SQLAlchemyError: if the server cannot be queried, or
            the CREATE fails and the database still does not exist.
    """
    url = make_url(database_url)
    if not url.drivername.startswith("postgresql"):
        return
    target_db = (url.database or "").strip()
    if not target_db:
        return
    # Identifiers cannot be bound parameters, so the name is spliced into the
    # DDL below; whitelist its characters instead of trying to escape it.
    if not re.fullmatch(r"[A-Za-z0-9_]+", target_db):
        raise RuntimeError(f"Invalid database name in RESOLVER_DATABASE_URL: {target_db!r}")
    admin_url = url.set(database="postgres")
    admin_engine = create_engine(admin_url, isolation_level="AUTOCOMMIT", pool_pre_ping=True)
    try:
        if _postgres_database_exists(admin_engine, target_db):
            return
        try:
            with admin_engine.connect() as conn:
                conn.exec_driver_sql(f'CREATE DATABASE "{target_db}"')
        except SQLAlchemyError:
            # Another process may have created the database between our check
            # and our CREATE; re-check on a fresh connection before failing.
            if _postgres_database_exists(admin_engine, target_db):
                return
            raise
        logger.info("Created database %s", target_db)
    finally:
        admin_engine.dispose()
def _env_int(name: str, default: int) -> int:
    """Read integer setting *name* from the environment, or return *default* if unset.

    Raises:
        ValueError: if the variable is set but is not a valid integer; the
            message names the offending variable.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError as err:
        raise ValueError(f"{name} must be an integer, got {raw!r}") from err


def init_database(database_url: str) -> None:
    """Create the process-wide engine and session factory for *database_url*.

    Idempotent: if both are already set, the call returns immediately and
    *database_url* is ignored (call ``dispose_database()`` first to switch).
    For PostgreSQL URLs the target database is created if it does not exist.

    Pool settings are read from the environment (defaults in parentheses):
    ``RESOLVER_DB_POOL_SIZE`` (10), ``RESOLVER_DB_MAX_OVERFLOW`` (20),
    ``RESOLVER_DB_POOL_TIMEOUT`` (30), ``RESOLVER_DB_POOL_RECYCLE`` (1800).

    Raises:
        ValueError: if one of the pool environment variables is not an integer.
        RuntimeError: if a PostgreSQL database name contains invalid characters.
    """
    global _ENGINE, _SESSION_FACTORY
    if _ENGINE is not None and _SESSION_FACTORY is not None:
        return
    # Parse configuration before contacting the server so a typo in the
    # environment fails fast instead of after a database round-trip.
    pool_options = {
        "pool_size": _env_int("RESOLVER_DB_POOL_SIZE", 10),
        "max_overflow": _env_int("RESOLVER_DB_MAX_OVERFLOW", 20),
        "pool_timeout": _env_int("RESOLVER_DB_POOL_TIMEOUT", 30),
        "pool_recycle": _env_int("RESOLVER_DB_POOL_RECYCLE", 1800),
    }
    _ensure_postgres_database_exists(database_url)
    engine = create_engine(database_url, pool_pre_ping=True, **pool_options)
    # expire_on_commit=False keeps loaded attributes readable after
    # get_db_session() commits and closes the session.
    factory = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False)
    _ENGINE = engine
    _SESSION_FACTORY = factory
def _require_session_factory() -> _SessionFactory:
    """Return the module's session factory, or fail if there is none.

    Raises:
        RuntimeError: if ``_SESSION_FACTORY`` is unset (``init_database()`` not
            called, or cleared by ``dispose_database()``) or is not callable.
    """
    current = _SESSION_FACTORY
    if current is not None and callable(current):
        return current
    raise RuntimeError("Database not initialized")
@contextmanager
def get_db_session() -> Iterator[Session]:
    """Provide a transactional session scope.

    On normal exit the session is committed. If the ``with`` body (or the
    commit itself) raises an ``Exception``, the session is rolled back and the
    exception re-raised. The session is closed in every case.

    Raises:
        RuntimeError: if ``init_database()`` has not been called.
    """
    if _ENGINE is None:
        raise RuntimeError("Database not initialized")
    factory = _require_session_factory()
    db = _new_session(factory)
    try:
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
    finally:
        db.close()
def init_db() -> None:
    """Create the ORM tables declared on ``Base.metadata`` using the module engine.

    Raises:
        RuntimeError: if ``init_database()`` has not been called.
    """
    engine = _ENGINE
    if engine is None:
        raise RuntimeError("Database not initialized")
    logger.info("Initializing Resolver database tables...")
    Base.metadata.create_all(bind=engine)
    logger.info("Resolver database tables created successfully")
def connection_test() -> bool:
    """Return True if ``SELECT 1`` succeeds on the module engine.

    Returns False when ``init_database()`` has not been called, or when the
    query raises a ``SQLAlchemyError``. In the latter case the error is
    logged, so a failing health check still leaves the cause in the logs.
    """
    engine = _ENGINE
    if engine is None:
        return False
    try:
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
    except SQLAlchemyError as exc:
        logger.warning("Database connection test failed: %s", exc)
        return False
    return True
def dispose_database() -> None:
    """Tear down the module-level engine and session factory.

    Both globals are cleared before the engine's pool is disposed. This leaves
    the module uninitialised even if ``dispose()`` raises, so a later
    ``init_database()`` starts cleanly. Calling this when nothing is
    initialised is a no-op.
    """
    global _ENGINE, _SESSION_FACTORY
    engine = _ENGINE
    _SESSION_FACTORY = None
    _ENGINE = None
    if engine is not None:
        engine.dispose()