Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ckan/model/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@
# SQLAlchemy database engine. Updated by model.init_model()
engine: Optional[Engine] = None


"""
Default Session, it is scoped session which will auto cleanup
"""
Session: AlchemySession = orm.scoped_session(orm.sessionmaker(
autoflush=False,
autocommit=False,
expire_on_commit=False,
))


"""
Manual sessions factory, you MUST ``session.close()`` is Explicitly required as it won't self clean up
"""
create_local_session = orm.sessionmaker(
autoflush=False,
autocommit=False,
Expand Down
14 changes: 8 additions & 6 deletions ckanext/datapusher/logic/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

from ckan.types import Context
import ckan.model.meta as Meta
import logging
import json
import datetime
Expand Down Expand Up @@ -125,12 +126,13 @@ def datapusher_submit(context: Context, data_dict: dict[str, Any]):
context['ignore_auth'] = True
# Use local session for task_status_update, so it can commit its own
# results without messing up with the parent session that contains pending
# updats of dataset/resource/etc.
context.update({
'session': context['model'].meta.create_local_session() # type: ignore
})
p.toolkit.get_action('task_status_update')(context, task)

# updates of dataset/resource/etc.
meta: Meta = context['model'].meta # type: ignore
with meta.create_local_session() as session:
context.update({
'session': session # type: ignore
})
p.toolkit.get_action('task_status_update')(context, task)
timeout = config.get('ckan.requests.timeout')

# This setting is checked on startup
Expand Down
33 changes: 17 additions & 16 deletions ckanext/datastore/backend/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from typing_extensions import TypeAlias

from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.base import Engine, Connection
from sqlalchemy.dialects.postgresql import REGCLASS
from ckan.types import Context, ErrorDict
import copy
Expand Down Expand Up @@ -1807,8 +1807,8 @@ def search_sql(context: Context, data_dict: dict[str, Any]):
backend = DatastorePostgresqlBackend.get_active_backend()
engine = backend._get_read_engine() # type: ignore
_cache_types(engine)

context['connection'] = engine.connect()
connection: Connection = engine.connect()
context['connection'] = connection
timeout = context.get('query_timeout', _TIMEOUT)

sql = data_dict['sql']
Expand All @@ -1820,7 +1820,7 @@ def search_sql(context: Context, data_dict: dict[str, Any]):

try:

context['connection'].execute(sa.text(
connection.execute(sa.text(
f"SET LOCAL statement_timeout TO {timeout}"
))

Expand Down Expand Up @@ -1875,15 +1875,15 @@ def _remove_explain(msg: str):
})
raise
finally:
context['connection'].close()
connection.close()


class DatastorePostgresqlBackend(DatastoreBackend):

def _get_write_engine(self):
def _get_write_engine(self) -> Engine:
return _get_engine_from_url(self.write_url)

def _get_read_engine(self):
def _get_read_engine(self) -> Engine:
return _get_engine_from_url(self.read_url)

def _log_or_raise(self, message: str):
Expand All @@ -1907,7 +1907,7 @@ def _check_urls_and_permissions(self):
if not self._read_connection_has_correct_privileges():
self._log_or_raise('The read-only user has write privileges.')

def _is_postgresql_engine(self):
def _is_postgresql_engine(self) -> bool:
''' Returns True if the read engine is a Postgresql Database.

According to
Expand All @@ -1917,7 +1917,7 @@ def _is_postgresql_engine(self):
drivername = self._get_read_engine().engine.url.drivername
return drivername.startswith('postgres')

def _is_read_only_database(self):
def _is_read_only_database(self) -> bool:
''' Returns True if no connection has CREATE privileges on the public
schema. This is the case if replication is enabled.'''
for url in [self.ckan_url, self.write_url, self.read_url]:
Expand All @@ -1932,7 +1932,7 @@ def _is_read_only_database(self):
return False
return True

def _same_ckan_and_datastore_db(self):
def _same_ckan_and_datastore_db(self) -> bool:
'''Returns True if the CKAN and DataStore db are the same'''
return self._get_db_from_url(self.ckan_url) == self._get_db_from_url(
self.read_url)
Expand Down Expand Up @@ -2152,18 +2152,19 @@ def create(
engine = get_write_engine()
_cache_types(engine)

context['connection'] = engine.connect()
connection = engine.connect()
context['connection'] = connection
timeout = context.get('query_timeout', _TIMEOUT)

_rename_json_field(data_dict)

trans = context['connection'].begin()
trans = connection.begin()
try:
# check if table already exists
context['connection'].execute(sa.text(
connection.execute(sa.text(
f"SET LOCAL statement_timeout TO {timeout}"
))
result = context['connection'].execute(sa.text(
result = connection.execute(sa.text(
'SELECT * FROM pg_tables WHERE tablename = :table'
), {"table": data_dict['resource_id']}).fetchone()
if not result:
Expand All @@ -2175,7 +2176,7 @@ def create(
alter_table(context, data_dict, plugin_data)
if 'triggers' in data_dict:
_create_triggers(
context['connection'],
connection,
data_dict['resource_id'],
data_dict['triggers'])
insert_data(context, data_dict)
Expand Down Expand Up @@ -2211,7 +2212,7 @@ def create(
trans.rollback()
raise
finally:
context['connection'].close()
connection.close()

def upsert(self, context: Context, data_dict: dict[str, Any]):
data_dict['connection_url'] = self.write_url
Expand Down
7 changes: 5 additions & 2 deletions ckanext/datastore/logic/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from contextlib import contextmanager

import sqlalchemy
from sqlalchemy.engine.base import Connection
import sqlalchemy.exc

import ckan.lib.navl.dictization_functions
Expand Down Expand Up @@ -245,17 +246,19 @@ def datastore_run_triggers(context: Context, data_dict: dict[str, Any]) -> int:
res_id = data_dict['resource_id']
p.toolkit.check_access('datastore_run_triggers', context, data_dict)
backend = DatastoreBackend.get_active_backend()
connection = backend._get_write_engine().connect() # type: ignore
connection: Connection = backend._get_write_engine().connect() # type: ignore

sql = sqlalchemy.text(u'''update {0} set _id=_id '''.format(
identifier(res_id)))
try:
results: Any = connection.execute(sql)
return results.rowcount
except sqlalchemy.exc.DatabaseError as err:
message = str(err.args[0].split('\n')[0])
raise p.toolkit.ValidationError({
u'records': [message.split(u') ', 1)[-1]]})
return results.rowcount
finally:
connection.close()


def datastore_upsert(context: Context, data_dict: dict[str, Any]):
Expand Down