@@ -167,20 +167,35 @@ def c(self:Table): return _ColsGetter(self)
167167from itertools import count
168168from sqlparse import tokens
169169
170- # %% ../nbs/00_core.ipynb #1ca2b1b9
171- def conv_placeholders (sql ):
172- "Convert `?` placeholders to PostgreSQL `$n` style"
173- if '?' not in sql : return sql
170+ # %% ../nbs/00_core.ipynb #d15d86a3
171+ def _convq (parsed , c ):
172+ "Replace `?` placeholder tokens with `$n` in parsed SQL"
173+ for s in parsed :
174+ for t in s .flatten ():
175+ if t .ttype is tokens .Name .Placeholder and t .value == '?' : t .value = f'${ next (c )} '
176+
177+ def conv_placeholders (sql , ** kwargs ):
178+ "Convert `?` and `:name` placeholders to PostgreSQL `$n` style"
179+ if '?' not in sql and not kwargs : return sql , []
180+ parsed = sqlparse .parse (sql )
174181 c = count (1 )
175- def _convpl1 (s ):
176- return '' .join (f'${ next (c )} ' if t .ttype is tokens .Name .Placeholder and t .value == '?' else t .value for t in s .flatten ())
177- return '' .join ([_convpl1 (s ) for s in sqlparse .parse (sql )])
178-
179- # %% ../nbs/00_core.ipynb #ceeb54fb
182+ _convq (parsed , c )
183+ seen , kw_args = {}, []
184+ for s in parsed :
185+ for t in s .flatten ():
186+ if t .ttype is tokens .Name .Placeholder and t .value [0 ]!= '$' :
187+ name = t .value .lstrip (':' )
188+ if name not in seen :
189+ seen [name ] = f'${ next (c )} '
190+ kw_args .append (kwargs [name ])
191+ t .value = seen [name ]
192+ return '' .join (str (s ) for s in parsed ), kw_args
193+
194+ # %% ../nbs/00_core.ipynb #17fc5af6
180195@patch
181- async def q (self :Database , sql , * args ):
182- csql = conv_placeholders (sql )
183- return Results (await self .fetch (csql , * args ))
196+ async def q (self :Database , sql , * args , ** kwargs ):
197+ csql , kw_args = conv_placeholders (sql , ** kwargs )
198+ return Results (await self .fetch (csql , * args , * kw_args ))
184199
185200# %% ../nbs/00_core.ipynb #0fc7310b
186201from datetime import datetime , date , time , timedelta
@@ -273,6 +288,25 @@ def _add_xtra(tbl, where, args, offset=0):
273288 args .extend (tbl .xtra_id .values ())
274289 return where , args
275290
291+ # %% ../nbs/00_core.ipynb #975d0f26
292+ @patch
293+ async def _recs (self :Table , sql , * args ):
294+ "Fetch rows and convert to cls"
295+ cls = getattr (self , 'cls' , None )
296+ def f (r ):
297+ if not cls : return r
298+ res = cls (** r )
299+ res ._db = self .db
300+ return res
301+ return [f (r ) for r in await self .db .q (sql , * args )]
302+
303+ @patch
304+ async def _rec (self :Table , sql , * args , err = None ):
305+ "Fetch one row, optionally raising NotFoundError"
306+ res = await self ._recs (sql , * args )
307+ if res : return res [0 ]
308+ if err : raise NotFoundError (err )
309+
276310# %% ../nbs/00_core.ipynb #de5874df
277311class NotFoundError (Exception ): pass
278312
0 commit comments