1+ import threading
2+ import time
13from collections .abc import Sequence
24from typing import Any
35
68import structlog
79from psycopg import rows , sql
810from psycopg .types import enum , numeric
11+ from psycopg_pool import ConnectionPool
912
1013from app .lib .storage import enums
1114from app .lib .storage .postgres import config
@@ -33,7 +36,7 @@ def dump(self, obj: Any) -> bytes | bytearray | memoryview:
3336 (np .int64 , NumpyIntDumper ),
3437]
3538
36- DEFAULT_ENUMS = [
39+ DEFAULT_ENUMS : list [ tuple [ type [ enum . Enum ], str ]] = [
3740 (enums .DataType , "common.datatype" ),
3841 (enums .RecordCrossmatchStatus , "layer0.crossmatch_status" ),
3942 (enums .RecordTriageStatus , "layer0.triage_status" ),
@@ -43,89 +46,112 @@ def dump(self, obj: Any) -> bytes | bytearray | memoryview:
4346class PgStorage :
4447 def __init__ (self , cfg : config .PgStorageConfig , logger : structlog .stdlib .BoundLogger ) -> None :
4548 self ._config = cfg
46- self ._connection : psycopg . Connection | None = None
49+ self ._pool : ConnectionPool | None = None
4750 self ._logger = logger
51+ self ._local = threading .local ()
52+ self ._extra_enums : list [tuple [type [enum .Enum ], str ]] = []
4853
49- def connect (self ) -> None :
50- self ._connection = psycopg .connect (self ._config .get_dsn (), row_factory = rows .dict_row , autocommit = True )
51- if self ._connection is None :
52- raise InternalError ("unable to create database connection" )
54+ def _configure_connection (self , conn : psycopg .Connection ) -> None :
55+ for python_type , dumper in DEFAULT_DUMPERS :
56+ conn .adapters .register_dumper (python_type , dumper )
57+ for enum_type , pg_type in DEFAULT_ENUMS + self ._extra_enums :
58+ type_info = enum .EnumInfo .fetch (conn , pg_type )
59+ if type_info is None :
60+ raise RuntimeError (f"Unable to find enum { pg_type } in DB" )
61+ enum .register_enum (
62+ type_info ,
63+ conn ,
64+ enum_type ,
65+ mapping = {m : m .value for m in enum_type },
66+ )
5367
68+ def connect (self ) -> None :
5469 self ._logger .debug ("connecting to Postgres" , endpoint = self ._config .endpoint , port = self ._config .port )
70+ self ._pool = ConnectionPool (
71+ self ._config .get_dsn (),
72+ min_size = 2 ,
73+ max_size = 10 ,
74+ kwargs = {"row_factory" : rows .dict_row , "autocommit" : True },
75+ configure = self ._configure_connection ,
76+ )
5577
56- for python_type , dumper in DEFAULT_DUMPERS :
57- self ._connection . adapters . register_dumper ( python_type , dumper )
78+ def register_type ( self , enum_type : type [ enum . Enum ], pg_type : str ) -> None :
79+ self ._extra_enums . append (( enum_type , pg_type ) )
5880
59- for python_type , pg_type in DEFAULT_ENUMS :
60- self .register_type ( python_type , pg_type )
81+ def get_thread_conn ( self ) -> psycopg . Connection | None :
82+ return getattr ( self ._local , "conn" , None )
6183
62- def register_type (self , enum_type : type [enum .Enum ], pg_type : str ) -> None :
63- if self ._connection is None :
64- raise RuntimeError ("did not connect to database" )
65-
66- type_info = enum .EnumInfo .fetch (self ._connection , pg_type )
67- if type_info is None :
68- raise RuntimeError (f"Unable to find enum { pg_type } in DB" )
69-
70- enum .register_enum (
71- type_info ,
72- self ._connection ,
73- enum_type ,
74- mapping = {m : m .value for m in enum_type },
75- )
84+ def set_thread_conn (self , conn : psycopg .Connection | None ) -> None :
85+ self ._local .conn = conn
7686
77- def get_connection (self ) -> psycopg .Connection :
78- if self ._connection is None :
79- raise InternalError ("unable to create database connection" )
87+ def get_pool (self ) -> ConnectionPool :
88+ if self ._pool is None :
89+ raise InternalError ("connection pool is not initialized" )
90+ return self ._pool
8091
81- return self ._connection
92+ def get_connection (self ) -> psycopg .Connection :
93+ conn = self .get_thread_conn ()
94+ if conn is not None :
95+ return conn
96+ raise InternalError ("no active transaction connection on this thread" )
8297
8398 def disconnect (self ) -> None :
84- if self ._connection is not None :
99+ if self ._pool is not None :
85100 self ._logger .debug ("disconnecting from Postgres" , endpoint = self ._config .endpoint , port = self ._config .port )
101+ self ._pool .close ()
86102
87- self ._connection .close ()
88-
89- def _query_str (self , query : str | sql .SQL | sql .Composed ) -> str :
103+ def query_str (self , query : str | sql .SQL | sql .Composed ) -> str :
90104 if isinstance (query , str ):
91105 return query
92- return query .as_string (self ._connection )
106+ conn = self .get_thread_conn ()
107+ if conn is not None :
108+ return query .as_string (conn )
109+ with self .get_pool ().connection () as c :
110+ return query .as_string (c )
93111
94112 def exec (self , query : str | sql .SQL | sql .Composed , * , params : list [Any ] | None = None ) -> None :
95113 if params is None :
96114 params = []
97- if self ._connection is None :
98- raise RuntimeError ("Unable to execute query: connection to Postgres was not established" )
99-
100- log .debug ("SQL query" , query = self ._query_str (query ).replace ("\n " , " " ), args = params )
101115
102- cursor = self ._connection .cursor ()
103- cursor .execute (query , params )
116+ log .debug ("SQL query" , query = self .query_str (query ).replace ("\n " , " " ), args = params )
104117
105- def execute_batch (self , query : str , rows : Sequence [Sequence [Any ]]) -> None :
106- if self ._connection is None :
107- raise RuntimeError ("Unable to execute query: connection to Postgres was not established" )
118+ conn = self .get_thread_conn ()
119+ if conn is not None :
120+ conn .cursor ().execute (query , params )
121+ else :
122+ with self .get_pool ().connection () as c :
123+ c .cursor ().execute (query , params )
108124
109- log .debug ("SQL execute batch" , query = query .replace ("\n " , " " ), num_rows = len (rows ))
125+ def execute_batch (self , query : str , rows_data : Sequence [Sequence [Any ]]) -> None :
126+ log .debug ("SQL execute batch" , query = query .replace ("\n " , " " ), num_rows = len (rows_data ))
110127
111- cursor = self ._connection .cursor ()
112- cursor .executemany (query , rows )
128+ conn = self .get_thread_conn ()
129+ if conn is not None :
130+ conn .cursor ().executemany (query , rows_data )
131+ else :
132+ with self .get_pool ().connection () as c :
133+ c .cursor ().executemany (query , rows_data )
113134
114135 def query (self , query : str | sql .SQL | sql .Composed , * , params : list [Any ] | None = None ) -> list [rows .DictRow ]:
115136 if params is None :
116137 params = []
117- if self ._connection is None :
118- raise RuntimeError ("Unable to execute query: connection to Postgres was not established" )
119-
120- log .debug ("SQL query" , query = self ._query_str (query ).replace ("\n " , " " ), args = params )
121-
122- cursor = self ._connection .cursor ()
123- cursor .execute (query , params )
124-
125- result = cursor .fetchall ()
126- log .debug ("SQL result" , num_rows = len (result ))
127138
128- return result
139+ log .debug ("SQL query" , query = self .query_str (query ).replace ("\n " , " " ), args = params )
140+
141+ def _run (conn : psycopg .Connection ) -> list [rows .DictRow ]:
142+ cursor = conn .cursor ()
143+ start = time .monotonic ()
144+ cursor .execute (query , params )
145+ result = cursor .fetchall ()
146+ elapsed = time .monotonic () - start
147+ log .debug ("SQL result" , num_rows = len (result ), elapsed_seconds = round (elapsed , 4 ))
148+ return result
149+
150+ conn = self .get_thread_conn ()
151+ if conn is not None :
152+ return _run (conn )
153+ with self .get_pool ().connection () as c :
154+ return _run (c )
129155
130156 def query_one (self , query : str | sql .SQL | sql .Composed , * , params : list [Any ] | None = None ) -> rows .DictRow :
131157 result = self .query (query , params = params )
0 commit comments