From faf9457b0160368e5210c5bd299c8cb1648934d9 Mon Sep 17 00:00:00 2001 From: Manabu Niseki Date: Mon, 12 Jun 2023 16:02:23 +0900 Subject: [PATCH] feat: support query params --- prestodb/__init__.py | 1 + prestodb/dbapi.py | 4 +- prestodb/escaper.py | 89 +++++++++++++++++++++++++++++++++++++++++++ tests/test_escaper.py | 31 +++++++++++++++ 4 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 prestodb/escaper.py create mode 100644 tests/test_escaper.py diff --git a/prestodb/__init__.py b/prestodb/__init__.py index 13c3bdc..7d2d22b 100644 --- a/prestodb/__init__.py +++ b/prestodb/__init__.py @@ -18,5 +18,6 @@ from . import client from . import constants from . import exceptions +from . import escaper __version__ = "0.8.3" diff --git a/prestodb/dbapi.py b/prestodb/dbapi.py index 71229d3..f13e211 100644 --- a/prestodb/dbapi.py +++ b/prestodb/dbapi.py @@ -27,6 +27,7 @@ from prestodb import constants import prestodb.exceptions +import prestodb.escaper import prestodb.client import prestodb.redirect from prestodb.transaction import Transaction, IsolationLevel, NO_TRANSACTION @@ -232,7 +233,8 @@ def setoutputsize(self, size, column): raise prestodb.exceptions.NotSupportedError def execute(self, operation, params=None): - self._query = prestodb.client.PrestoQuery(self._request, sql=operation) + sql = operation if params is None else operation % prestodb.escaper.escape(params) + self._query = prestodb.client.PrestoQuery(self._request, sql=sql) result = self._query.execute() self._iterator = iter(result) return result diff --git a/prestodb/escaper.py b/prestodb/escaper.py new file mode 100644 index 0000000..751f654 --- /dev/null +++ b/prestodb/escaper.py @@ -0,0 +1,89 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This code is forked from https://github.com/dropbox/PyHive (the Apache License, Version 2.0) +from __future__ import absolute_import + +import datetime + +try: + from collections.abc import Iterable +except ImportError: + from collections import Iterable + +import prestodb.exceptions + +class ParamsEscaper(object): + _DATE_FORMAT = "%Y-%m-%d" + _TIME_FORMAT = "%H:%M:%S.%f" + _DATETIME_FORMAT = "{} {}".format(_DATE_FORMAT, _TIME_FORMAT) + + def escape_args(self, parameters): + if isinstance(parameters, dict): + return {k: self.escape_item(v) for k, v in parameters.items()} + + if isinstance(parameters, (list, tuple)): + return tuple(self.escape_item(x) for x in parameters) + + raise prestodb.exceptions.ProgrammingError("Unsupported param format: {}".format(parameters)) + + def escape_number(self, item): + return item + + def escape_bytes(self, item): + return self.escape_string(item.decode("utf-8")) + + def escape_string(self, item): + # This is good enough when backslashes are literal, newlines are just followed, and the way + # to escape a single quote is to put two single quotes. + # (i.e. only special character is single quote) + return "'{}'".format(item.replace("'", "''")) + + def escape_sequence(self, item): + l = map(str, map(self.escape_item, item)) + return '(' + ','.join(l) + ')' + + def escape_datetime(self, item, format, cutoff=0): + dt_str = item.strftime(format) + formatted = dt_str[:-cutoff] if cutoff and format.endswith(".%f") else dt_str + + _type = "timestamp" if isinstance(item, datetime.datetime) else "date" + return "{} {}".format(_type, formatted) + + def escape_item(self, item): + if item is None: + return 'NULL' + + if isinstance(item, (int, float)): + return self.escape_number(item) + + if isinstance(item, bytes): + return self.escape_bytes(item) + + if isinstance(item, str): + return self.escape_string(item) + + if isinstance(item, Iterable): + return self.escape_sequence(item) + + if isinstance(item, datetime.datetime): + return self.escape_datetime(item, self._DATETIME_FORMAT) + + if isinstance(item, datetime.date): + return self.escape_datetime(item, self._DATE_FORMAT) + + raise prestodb.exceptions.ProgrammingError("Unsupported object {}".format(item)) + +escaper = ParamsEscaper() + +def escape(params): + return escaper.escape_args(params) diff --git a/tests/test_escaper.py b/tests/test_escaper.py new file mode 100644 index 0000000..377be47 --- /dev/null +++ b/tests/test_escaper.py @@ -0,0 +1,31 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +import datetime +import prestodb.escaper + +def test_escape_args(): + escaper = prestodb.escaper.ParamsEscaper() + + assert escaper.escape_args({'foo': 'bar'}) == {'foo': "'bar'"} + assert escaper.escape_args({'foo': 123}) == {'foo': 123} + assert escaper.escape_args({'foo': 123.456}) == {'foo': 123.456} + assert escaper.escape_args({'foo': ['a', 'b', 'c']}) == {'foo': "('a','b','c')"} + assert escaper.escape_args({'foo': ('a', 'b', 'c')}) == {'foo': "('a','b','c')"} + assert escaper.escape_args({'foo': {'a', 'b'}}) in ({'foo': "('a','b')"}, {'foo': "('b','a')"}) + assert escaper.escape_args(('bar',)) == ("'bar'",) + assert escaper.escape_args([123]) == (123,) + assert escaper.escape_args((123.456,)) == (123.456,) + assert escaper.escape_args((['a', 'b', 'c'],)) == ("('a','b','c')",) + + assert escaper.escape_args((datetime.date(2020, 4, 17),)) == ('date 2020-04-17',) + assert escaper.escape_args((datetime.datetime(2020, 4, 17, 12, 0, 0, 123456),)) == ('timestamp 2020-04-17 12:00:00.123456',)