-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdependency.py
More file actions
155 lines (126 loc) · 4.33 KB
/
dependency.py
File metadata and controls
155 lines (126 loc) · 4.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from typing import AsyncGenerator, Generator
from fastapi import Depends, FastAPI, Request, HTTPException, status
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.asyncio import (
create_async_engine,
async_sessionmaker,
AsyncSession,
AsyncEngine,
)
from loguru import logger
from app.core import settings
def _async_engine_kwargs() -> dict:
return {
"echo": False,
"pool_pre_ping": True,
"pool_size": 5,
"max_overflow": 10,
}
def _sync_engine_kwargs() -> dict:
return {
"echo": False,
"pool_pre_ping": True,
"pool_size": 5,
"max_overflow": 10,
}
async def init_database_on_app(app: FastAPI) -> None:
    """Create the database engines and session factories and attach them to ``app.state``.

    Stores four attributes on ``app.state``: ``async_engine``, ``sync_engine``,
    ``async_sessionmaker`` and ``sync_sessionmaker``. Both engines are built
    before anything is assigned. If either engine constructor raises, no
    state attributes are set by this call.

    Args:
        app: The FastAPI application to initialize.
    """
    engine_async = create_async_engine(
        settings.ASYNC_DATABASE_URL, **_async_engine_kwargs()
    )
    engine_sync = create_engine(
        settings.SYNC_DATABASE_URL, **_sync_engine_kwargs()
    )

    state = app.state
    state.async_engine = engine_async
    state.sync_engine = engine_sync

    # Both factories share the same session options: objects stay usable
    # after commit, and autoflush is enabled.
    factory_opts = {"expire_on_commit": False, "autoflush": True}
    state.async_sessionmaker = async_sessionmaker(bind=engine_async, **factory_opts)
    state.sync_sessionmaker = sessionmaker(bind=engine_sync, **factory_opts)

    logger.info("PostgreSQL engines & sessionmakers initialized.")
async def shutdown_database_on_app(app: FastAPI) -> None:
    """Dispose the engines stored on ``app.state`` by ``init_database_on_app``.

    Each engine is disposed only if its attribute exists, so this is safe to
    call when initialization never ran or did not complete. The sync engine
    is disposed in a ``finally`` clause, so an error while disposing the async
    engine cannot leave the sync pool open. That error still propagates to
    the caller, and the completion message is then not logged.

    NOTE(review): the module-level ``_global_async_engine`` is not touched
    here; confirm it is disposed elsewhere if it is ever created.

    Args:
        app: The FastAPI application whose ``state`` holds the engines.
    """
    try:
        if hasattr(app.state, "async_engine"):
            await app.state.async_engine.dispose()
    finally:
        if hasattr(app.state, "sync_engine"):
            app.state.sync_engine.dispose()
    logger.info("PostgreSQL engines disposed.")
async def _ensure(
    app: FastAPI,
) -> tuple[async_sessionmaker[AsyncSession], sessionmaker]:
    """Return both session factories from ``app.state`` or fail with HTTP 503.

    The async factory is checked first, then the sync one. The first missing
    attribute is logged and reported as an ``HTTPException`` with status 503.

    Args:
        app: The application whose ``state`` should hold the factories.

    Returns:
        ``(app.state.async_sessionmaker, app.state.sync_sessionmaker)``.

    Raises:
        HTTPException: 503 "Service Unavailable" if either attribute is absent.
    """
    state = app.state
    required = (
        ("async_sessionmaker", "Async sessionmaker not initialized."),
        ("sync_sessionmaker", "Sync sessionmaker not initialized."),
    )
    for attr_name, error_message in required:
        if hasattr(state, attr_name):
            continue
        logger.error(error_message)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Service Unavailable",
        )
    return state.async_sessionmaker, state.sync_sessionmaker
async def _get_async_sessionmaker(
    request: Request,
) -> async_sessionmaker[AsyncSession]:
    """FastAPI dependency returning the async session factory of the request's app.

    Raises:
        HTTPException: 503, propagated from ``_ensure``, if the app's session
            factories are not initialized.
    """
    factories = await _ensure(request.app)
    return factories[0]
async def _get_sync_sessionmaker(
    request: Request,
) -> sessionmaker:
    """FastAPI dependency returning the sync session factory of the request's app.

    Raises:
        HTTPException: 503, propagated from ``_ensure``, if the app's session
            factories are not initialized.
    """
    factories = await _ensure(request.app)
    return factories[1]
async def get_async_session(
    session_maker: async_sessionmaker[AsyncSession] = Depends(
        _get_async_sessionmaker
    ),
) -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` scoped to one request.

    After the consumer finishes without error, the session is committed. If
    an ``Exception`` is raised (by the consumer or by the commit), the session
    is rolled back and the exception re-raised. The ``async with`` block
    closes the session in every case.
    """
    async with session_maker() as db:
        try:
            yield db
            await db.commit()
        except Exception:
            await db.rollback()
            raise
def get_sync_session(
    session_maker: sessionmaker = Depends(_get_sync_sessionmaker),
) -> Generator[Session, None, None]:
    """Yield a synchronous ``Session`` scoped to one request.

    The session is committed after the consumer finishes without error. On
    any ``Exception`` (including one raised by the commit) it is rolled back
    and the exception re-raised. The ``with`` block closes the session
    (``Session.__exit__`` calls ``close()``) in every case.
    """
    with session_maker() as db:
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
# Process-wide async engine, created lazily by _init_global_async_engine().
# It is independent of app.state.async_engine, and shutdown_database_on_app()
# does not dispose it. NOTE(review): confirm it is disposed elsewhere at exit.
_global_async_engine: AsyncEngine | None = None
# Process-wide session factory bound to _global_async_engine, created lazily
# by _init_global_async_sessionmaker().
_global_async_sessionmaker: async_sessionmaker[AsyncSession] | None = None
def _init_global_async_engine() -> AsyncEngine:
    """Return the module-level async engine, creating it on the first call.

    There is no ``await`` between the ``None`` check and the assignment, so
    coroutines on one event loop cannot interleave here. No lock is taken,
    so concurrent first calls from separate threads are not guarded.
    """
    global _global_async_engine
    if _global_async_engine is not None:
        return _global_async_engine
    _global_async_engine = create_async_engine(
        settings.ASYNC_DATABASE_URL, **_async_engine_kwargs()
    )
    logger.info("Initialized global async engine.")
    return _global_async_engine
def _init_global_async_sessionmaker() -> async_sessionmaker[AsyncSession]:
    """Return the module-level async session factory, creating it on first call.

    The factory is bound to the engine from ``_init_global_async_engine``,
    which is itself created lazily. It is configured with
    ``expire_on_commit=False`` and ``autoflush=False``.

    NOTE(review): ``autoflush=False`` differs from the ``autoflush=True`` used
    for ``app.state.async_sessionmaker`` in ``init_database_on_app``. Confirm
    the divergence is intentional.
    """
    global _global_async_sessionmaker
    if _global_async_sessionmaker is None:
        engine = _init_global_async_engine()
        # ``autocommit`` is intentionally not passed. In SQLAlchemy 2.x it is
        # only a backwards-compatibility parameter that must stay at its
        # default (False), so spelling it out was a legacy no-op.
        _global_async_sessionmaker = async_sessionmaker(
            bind=engine,
            expire_on_commit=False,
            autoflush=False,
        )
        logger.info("Initialized global async sessionmaker.")
    return _global_async_sessionmaker
def get_global_async_sessionmaker() -> async_sessionmaker[AsyncSession]:
    """Public accessor for the lazily created module-level async session factory."""
    factory = _init_global_async_sessionmaker()
    return factory