Skip to content

Commit defdd17

Browse files
authored
#221: implement connection pool (#408)
1 parent 8843c56 commit defdd17

4 files changed

Lines changed: 98 additions & 73 deletions

File tree

app/data/repositories/layer0/tables.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def insert_raw_data(self, data: model.Layer0RawData) -> None:
104104
field_identifiers,
105105
placeholders,
106106
)
107-
query_str = query.as_string(self._storage.get_connection())
107+
query_str = self._storage.query_str(query)
108108

109109
rows: list[list[Any]] = []
110110
for row in data.data.to_dict(orient="records"):
@@ -420,12 +420,11 @@ def update_column_metadata(self, table_name: str, column_description: model.Colu
420420

421421
modification_query = "UPDATE layer0.tables SET modification_dt = now() WHERE id = %s"
422422

423-
with self.with_tx():
424-
self._storage.exec(
425-
"SELECT meta.setparams(%s, %s, %s, %s::json)",
426-
params=[RAWDATA_SCHEMA, table_name, column_description.name, json.dumps(column_params)],
427-
)
428-
self._storage.exec(modification_query, params=[table_id])
423+
self._storage.exec(
424+
"SELECT meta.setparams(%s, %s, %s, %s::json)",
425+
params=[RAWDATA_SCHEMA, table_name, column_description.name, json.dumps(column_params)],
426+
)
427+
self._storage.exec(modification_query, params=[table_id])
429428

430429
def search_tables(
431430
self,

app/lib/storage/postgres/postgres_storage.py

Lines changed: 82 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import threading
2+
import time
13
from collections.abc import Sequence
24
from typing import Any
35

@@ -6,6 +8,7 @@
68
import structlog
79
from psycopg import rows, sql
810
from psycopg.types import enum, numeric
11+
from psycopg_pool import ConnectionPool
912

1013
from app.lib.storage import enums
1114
from app.lib.storage.postgres import config
@@ -33,7 +36,7 @@ def dump(self, obj: Any) -> bytes | bytearray | memoryview:
3336
(np.int64, NumpyIntDumper),
3437
]
3538

36-
DEFAULT_ENUMS = [
39+
DEFAULT_ENUMS: list[tuple[type[enum.Enum], str]] = [
3740
(enums.DataType, "common.datatype"),
3841
(enums.RecordCrossmatchStatus, "layer0.crossmatch_status"),
3942
(enums.RecordTriageStatus, "layer0.triage_status"),
@@ -43,89 +46,112 @@ def dump(self, obj: Any) -> bytes | bytearray | memoryview:
4346
class PgStorage:
4447
def __init__(self, cfg: config.PgStorageConfig, logger: structlog.stdlib.BoundLogger) -> None:
4548
self._config = cfg
46-
self._connection: psycopg.Connection | None = None
49+
self._pool: ConnectionPool | None = None
4750
self._logger = logger
51+
self._local = threading.local()
52+
self._extra_enums: list[tuple[type[enum.Enum], str]] = []
4853

49-
def connect(self) -> None:
50-
self._connection = psycopg.connect(self._config.get_dsn(), row_factory=rows.dict_row, autocommit=True)
51-
if self._connection is None:
52-
raise InternalError("unable to create database connection")
54+
def _configure_connection(self, conn: psycopg.Connection) -> None:
55+
for python_type, dumper in DEFAULT_DUMPERS:
56+
conn.adapters.register_dumper(python_type, dumper)
57+
for enum_type, pg_type in DEFAULT_ENUMS + self._extra_enums:
58+
type_info = enum.EnumInfo.fetch(conn, pg_type)
59+
if type_info is None:
60+
raise RuntimeError(f"Unable to find enum {pg_type} in DB")
61+
enum.register_enum(
62+
type_info,
63+
conn,
64+
enum_type,
65+
mapping={m: m.value for m in enum_type},
66+
)
5367

68+
def connect(self) -> None:
5469
self._logger.debug("connecting to Postgres", endpoint=self._config.endpoint, port=self._config.port)
70+
self._pool = ConnectionPool(
71+
self._config.get_dsn(),
72+
min_size=2,
73+
max_size=10,
74+
kwargs={"row_factory": rows.dict_row, "autocommit": True},
75+
configure=self._configure_connection,
76+
)
5577

56-
for python_type, dumper in DEFAULT_DUMPERS:
57-
self._connection.adapters.register_dumper(python_type, dumper)
78+
def register_type(self, enum_type: type[enum.Enum], pg_type: str) -> None:
79+
self._extra_enums.append((enum_type, pg_type))
5880

59-
for python_type, pg_type in DEFAULT_ENUMS:
60-
self.register_type(python_type, pg_type)
81+
def get_thread_conn(self) -> psycopg.Connection | None:
82+
return getattr(self._local, "conn", None)
6183

62-
def register_type(self, enum_type: type[enum.Enum], pg_type: str) -> None:
63-
if self._connection is None:
64-
raise RuntimeError("did not connect to database")
65-
66-
type_info = enum.EnumInfo.fetch(self._connection, pg_type)
67-
if type_info is None:
68-
raise RuntimeError(f"Unable to find enum {pg_type} in DB")
69-
70-
enum.register_enum(
71-
type_info,
72-
self._connection,
73-
enum_type,
74-
mapping={m: m.value for m in enum_type},
75-
)
84+
def set_thread_conn(self, conn: psycopg.Connection | None) -> None:
85+
self._local.conn = conn
7686

77-
def get_connection(self) -> psycopg.Connection:
78-
if self._connection is None:
79-
raise InternalError("unable to create database connection")
87+
def get_pool(self) -> ConnectionPool:
88+
if self._pool is None:
89+
raise InternalError("connection pool is not initialized")
90+
return self._pool
8091

81-
return self._connection
92+
def get_connection(self) -> psycopg.Connection:
93+
conn = self.get_thread_conn()
94+
if conn is not None:
95+
return conn
96+
raise InternalError("no active transaction connection on this thread")
8297

8398
def disconnect(self) -> None:
84-
if self._connection is not None:
99+
if self._pool is not None:
85100
self._logger.debug("disconnecting from Postgres", endpoint=self._config.endpoint, port=self._config.port)
101+
self._pool.close()
86102

87-
self._connection.close()
88-
89-
def _query_str(self, query: str | sql.SQL | sql.Composed) -> str:
103+
def query_str(self, query: str | sql.SQL | sql.Composed) -> str:
90104
if isinstance(query, str):
91105
return query
92-
return query.as_string(self._connection)
106+
conn = self.get_thread_conn()
107+
if conn is not None:
108+
return query.as_string(conn)
109+
with self.get_pool().connection() as c:
110+
return query.as_string(c)
93111

94112
def exec(self, query: str | sql.SQL | sql.Composed, *, params: list[Any] | None = None) -> None:
95113
if params is None:
96114
params = []
97-
if self._connection is None:
98-
raise RuntimeError("Unable to execute query: connection to Postgres was not established")
99-
100-
log.debug("SQL query", query=self._query_str(query).replace("\n", " "), args=params)
101115

102-
cursor = self._connection.cursor()
103-
cursor.execute(query, params)
116+
log.debug("SQL query", query=self.query_str(query).replace("\n", " "), args=params)
104117

105-
def execute_batch(self, query: str, rows: Sequence[Sequence[Any]]) -> None:
106-
if self._connection is None:
107-
raise RuntimeError("Unable to execute query: connection to Postgres was not established")
118+
conn = self.get_thread_conn()
119+
if conn is not None:
120+
conn.cursor().execute(query, params)
121+
else:
122+
with self.get_pool().connection() as c:
123+
c.cursor().execute(query, params)
108124

109-
log.debug("SQL execute batch", query=query.replace("\n", " "), num_rows=len(rows))
125+
def execute_batch(self, query: str, rows_data: Sequence[Sequence[Any]]) -> None:
126+
log.debug("SQL execute batch", query=query.replace("\n", " "), num_rows=len(rows_data))
110127

111-
cursor = self._connection.cursor()
112-
cursor.executemany(query, rows)
128+
conn = self.get_thread_conn()
129+
if conn is not None:
130+
conn.cursor().executemany(query, rows_data)
131+
else:
132+
with self.get_pool().connection() as c:
133+
c.cursor().executemany(query, rows_data)
113134

114135
def query(self, query: str | sql.SQL | sql.Composed, *, params: list[Any] | None = None) -> list[rows.DictRow]:
115136
if params is None:
116137
params = []
117-
if self._connection is None:
118-
raise RuntimeError("Unable to execute query: connection to Postgres was not established")
119-
120-
log.debug("SQL query", query=self._query_str(query).replace("\n", " "), args=params)
121-
122-
cursor = self._connection.cursor()
123-
cursor.execute(query, params)
124-
125-
result = cursor.fetchall()
126-
log.debug("SQL result", num_rows=len(result))
127138

128-
return result
139+
log.debug("SQL query", query=self.query_str(query).replace("\n", " "), args=params)
140+
141+
def _run(conn: psycopg.Connection) -> list[rows.DictRow]:
142+
cursor = conn.cursor()
143+
start = time.monotonic()
144+
cursor.execute(query, params)
145+
result = cursor.fetchall()
146+
elapsed = time.monotonic() - start
147+
log.debug("SQL result", num_rows=len(result), elapsed_seconds=round(elapsed, 4))
148+
return result
149+
150+
conn = self.get_thread_conn()
151+
if conn is not None:
152+
return _run(conn)
153+
with self.get_pool().connection() as c:
154+
return _run(c)
129155

130156
def query_one(self, query: str | sql.SQL | sql.Composed, *, params: list[Any] | None = None) -> rows.DictRow:
131157
result = self.query(query, params=params)

app/lib/storage/postgres/transactional.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ def __init__(self, storage: postgres.PgStorage) -> None:
99

1010
@contextmanager
1111
def with_tx(self):
12-
conn = self._storage.get_connection()
13-
try:
14-
with conn.transaction():
15-
yield
16-
except Exception as e:
17-
raise e
12+
with self._storage.get_pool().connection() as conn:
13+
self._storage.set_thread_conn(conn)
14+
try:
15+
with conn.transaction():
16+
yield
17+
finally:
18+
self._storage.set_thread_conn(None)

tests/integration/transactional_storage_test.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@ def test_several_queries(self):
1818

1919
self.assertEqual(result["id"], 42)
2020

21-
def test_nested_transactions(self):
21+
def test_multiple_statements_in_transaction(self):
2222
repo = transactional.TransactionalPGRepository(self.pg_storage.get_storage())
2323
with repo.with_tx():
2424
self.pg_storage.get_storage().exec("CREATE TABLE test_table2 (id INTEGER)")
25-
with repo.with_tx():
26-
self.pg_storage.get_storage().exec("INSERT INTO test_table2 VALUES (42)")
27-
result = self.pg_storage.get_storage().query_one("SELECT id FROM test_table2 LIMIT 1")
25+
self.pg_storage.get_storage().exec("INSERT INTO test_table2 VALUES (42)")
26+
result = self.pg_storage.get_storage().query_one("SELECT id FROM test_table2 LIMIT 1")
2827

2928
self.assertEqual(result["id"], 42)
3029

0 commit comments

Comments
 (0)