Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 55 additions & 12 deletions src/utils_flask_sqla/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from dateutil import parser
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData
from sqlalchemy import MetaData, inspect
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.types import Boolean, Date, DateTime, Integer, Numeric
from werkzeug.exceptions import BadRequest
Expand Down Expand Up @@ -102,13 +102,46 @@ def __init__(self, tableName, schemaName, engine):
- engine : sqlalchemy instance engine
for exemple : DB.engine if DB = Sqlalchemy()
"""
meta = MetaData(schema=schemaName)
meta.reflect(views=True, bind=engine)
# En sqlalchemy 2.0, il faut utiliser MetaData()
meta = MetaData()

# Try to reflect just the specific table instead of all tables
try:
self.tableDef = meta.tables["{}.{}".format(schemaName, tableName)]
except KeyError:
raise KeyError("table {}.{} doesn't exists".format(schemaName, tableName))

meta.reflect(only=[tableName], schema=schemaName, views=True, bind=engine)
table_key = f"{schemaName}.{tableName}"

if table_key in meta.tables:
self.tableDef = meta.tables[table_key]
# If not found with schema, try without schema
elif tableName in meta.tables:
self.tableDef = meta.tables[tableName]
else:
# Si on ne trouve pas la table, en essaye de la trouver dans le schema et les vues

inspector = inspect(engine)
available_views = inspector.get_view_names(schema=schemaName)
available_tables = inspector.get_table_names(schema=schemaName)
if tableName in available_views or tableName in available_tables:
# Force reflection with explicit view flag
meta = MetaData()
meta.reflect(
only=[tableName],
schema=schemaName,
views=True,
bind=engine,
extend_existing=True,
)
table_key = f"{schemaName}.{tableName}"
if table_key in meta.tables:
self.tableDef = meta.tables[table_key]
else:
raise KeyError(f"table {schemaName}.{tableName} doesn't exist")
else:
raise KeyError(f"table {schemaName}.{tableName} doesn't exist")
except Exception as e:
# If any error occurs, provide a detailed error message
raise KeyError(f"Error accessing table {schemaName}.{tableName}: {str(e)}")

# Mise en place d'un mapping des colonnes en vue d'une sérialisation
self.serialize_columns, self.db_cols = self.get_serialized_columns()
Expand Down Expand Up @@ -250,8 +283,8 @@ def raw_query(self, process_filter=True):
Renvoie la requete 'brute' (sans .all)
- process_filter: application des filtres (et du sort)
"""

q = self.DB.session.query(self.view.tableDef)
# Use select() instead of query()
q = self.DB.select(self.view.tableDef)

if not process_filter:
return q
Expand All @@ -268,13 +301,23 @@ def query(self):
"""
Lance la requete et retourne l'objet sqlalchemy
"""
q = self.DB.session.query(self.view.tableDef)
nb_result_without_filter = q.count()
# Use select() instead of query()
nb_result_without_filter = self.DB.session.scalar(
self.DB.select(self.DB.func.count()).select_from(self.view.tableDef)
)

# Get filtered query using raw_query
q = self.raw_query(process_filter=True)
total_filtered = q.count() if self.filters else nb_result_without_filter

data = q.all()
# Calculate total filtered rows
if self.filters:
count_stmt = self.DB.select(self.DB.func.count()).select_from(q.subquery())
total_filtered = self.DB.session.scalar(count_stmt)
else:
total_filtered = nb_result_without_filter

# Execute query
data = self.DB.session.execute(q).all()

return data, nb_result_without_filter, total_filtered

Expand Down
34 changes: 19 additions & 15 deletions src/utils_flask_sqla/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from itertools import chain
from functools import lru_cache
from uuid import UUID
from flask import current_app

from sqlalchemy.orm import ColumnProperty
from sqlalchemy import inspect
from sqlalchemy import inspect, select
from sqlalchemy.ext.hybrid import hybrid_property, HYBRID_PROPERTY
from sqlalchemy.types import DateTime, Date, Time
from sqlalchemy.dialects.postgresql.base import UUID
Expand Down Expand Up @@ -355,7 +356,7 @@ def populatefn(self, dict_in, recursif=False):
recursif: si on renseigne les relationships

"""

db = current_app.extensions["sqlalchemy"]
cls_db_columns_key = list(map(lambda x: x[0], get_cls_db_columns()))

# populate cls_db_columns
Expand Down Expand Up @@ -400,27 +401,30 @@ def populatefn(self, dict_in, recursif=False):

# preload with id
# pour faire une seule requête
ids = filter(lambda x: x, map(lambda x: x.get(id_field_name), values))
preload_res_with_ids = Model.query.where(
getattr(Model, id_field_name).in_(ids)
).all()
ids = set(filter(lambda x: x, map(lambda x: x.get(id_field_name), values)))

stmt = select(Model).where(getattr(Model, id_field_name).in_(ids))

preload_res_with_ids = db.session.execute(stmt).scalars().all()

# resul
v_obj = []

for data in values:
id_value = data.pop(id_field_name, None)

# On filtre la liste des objets préchargés
filtered_results = list(
filter(
lambda x: getattr(x, id_field_name) == id_value,
preload_res_with_ids,
)
)

res = (
# si on a une id -> on recupère dans la liste preload_res_with_ids
# TODO trouver un find plus propre ?
list(
filter(
lambda x: getattr(x, id_field_name) == id_value,
preload_res_with_ids,
)
)[0]
if id_value and len(preload_res_with_ids)
# si on a une id et qu'on a trouvé au moins un résultat
filtered_results[0]
if id_value and filtered_results
# sinon on cree une nouvelle instance
else Model()
)
Expand Down