Skip to content

Commit 32ff07d

Browse files
committed
fixes #3
1 parent df72a36 commit 32ff07d

4 files changed

Lines changed: 231 additions & 49 deletions

File tree

fastasyncpg/_modidx.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
'fastasyncpg.core.Table.__init__': ('core.html#table.__init__', 'fastasyncpg/core.py'),
3333
'fastasyncpg.core.Table.__repr__': ('core.html#table.__repr__', 'fastasyncpg/core.py'),
3434
'fastasyncpg.core.Table.__str__': ('core.html#table.__str__', 'fastasyncpg/core.py'),
35+
'fastasyncpg.core.Table._rec': ('core.html#table._rec', 'fastasyncpg/core.py'),
36+
'fastasyncpg.core.Table._recs': ('core.html#table._recs', 'fastasyncpg/core.py'),
3537
'fastasyncpg.core.Table.c': ('core.html#table.c', 'fastasyncpg/core.py'),
3638
'fastasyncpg.core.Table.cols': ('core.html#table.cols', 'fastasyncpg/core.py'),
3739
'fastasyncpg.core.Table.count': ('core.html#table.count', 'fastasyncpg/core.py'),
@@ -73,6 +75,7 @@
7375
'fastasyncpg.core._TablesGetter': ('core.html#_tablesgetter', 'fastasyncpg/core.py'),
7476
'fastasyncpg.core._TablesGetter.__init__': ('core.html#_tablesgetter.__init__', 'fastasyncpg/core.py'),
7577
'fastasyncpg.core._add_xtra': ('core.html#_add_xtra', 'fastasyncpg/core.py'),
78+
'fastasyncpg.core._convq': ('core.html#_convq', 'fastasyncpg/core.py'),
7679
'fastasyncpg.core._dataclass': ('core.html#_dataclass', 'fastasyncpg/core.py'),
7780
'fastasyncpg.core._get_flds': ('core.html#_get_flds', 'fastasyncpg/core.py'),
7881
'fastasyncpg.core._pk_where': ('core.html#_pk_where', 'fastasyncpg/core.py'),

fastasyncpg/core.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -167,20 +167,35 @@ def c(self:Table): return _ColsGetter(self)
167167
from itertools import count
168168
from sqlparse import tokens
169169

170-
# %% ../nbs/00_core.ipynb #1ca2b1b9
171-
def conv_placeholders(sql):
172-
"Convert `?` placeholders to PostgreSQL `$n` style"
173-
if '?' not in sql: return sql
170+
# %% ../nbs/00_core.ipynb #d15d86a3
171+
def _convq(parsed, c):
172+
"Replace `?` placeholder tokens with `$n` in parsed SQL"
173+
for s in parsed:
174+
for t in s.flatten():
175+
if t.ttype is tokens.Name.Placeholder and t.value == '?': t.value = f'${next(c)}'
176+
177+
def conv_placeholders(sql, **kwargs):
178+
"Convert `?` and `:name` placeholders to PostgreSQL `$n` style"
179+
if '?' not in sql and not kwargs: return sql, []
180+
parsed = sqlparse.parse(sql)
174181
c = count(1)
175-
def _convpl1(s):
176-
return ''.join(f'${next(c)}' if t.ttype is tokens.Name.Placeholder and t.value=='?' else t.value for t in s.flatten())
177-
return ''.join([_convpl1(s) for s in sqlparse.parse(sql)])
178-
179-
# %% ../nbs/00_core.ipynb #ceeb54fb
182+
_convq(parsed, c)
183+
seen, kw_args = {}, []
184+
for s in parsed:
185+
for t in s.flatten():
186+
if t.ttype is tokens.Name.Placeholder and t.value[0]!='$':
187+
name = t.value.lstrip(':')
188+
if name not in seen:
189+
seen[name] = f'${next(c)}'
190+
kw_args.append(kwargs[name])
191+
t.value = seen[name]
192+
return ''.join(str(s) for s in parsed), kw_args
193+
194+
# %% ../nbs/00_core.ipynb #17fc5af6
180195
@patch
181-
async def q(self:Database, sql, *args):
182-
csql = conv_placeholders(sql)
183-
return Results(await self.fetch(csql, *args))
196+
async def q(self:Database, sql, *args, **kwargs):
197+
csql, kw_args = conv_placeholders(sql, **kwargs)
198+
return Results(await self.fetch(csql, *args, *kw_args))
184199

185200
# %% ../nbs/00_core.ipynb #0fc7310b
186201
from datetime import datetime, date, time, timedelta
@@ -273,6 +288,25 @@ def _add_xtra(tbl, where, args, offset=0):
273288
args.extend(tbl.xtra_id.values())
274289
return where, args
275290

291+
# %% ../nbs/00_core.ipynb #975d0f26
292+
@patch
293+
async def _recs(self:Table, sql, *args):
294+
"Fetch rows and convert to cls"
295+
cls = getattr(self, 'cls', None)
296+
def f(r):
297+
if not cls: return r
298+
res = cls(**r)
299+
res._db = self.db
300+
return res
301+
return [f(r) for r in await self.db.q(sql, *args)]
302+
303+
@patch
304+
async def _rec(self:Table, sql, *args, err=None):
305+
"Fetch one row, optionally raising NotFoundError"
306+
res = await self._recs(sql, *args)
307+
if res: return res[0]
308+
if err: raise NotFoundError(err)
309+
276310
# %% ../nbs/00_core.ipynb #de5874df
277311
class NotFoundError(Exception): pass
278312

0 commit comments

Comments
 (0)