33import hashlib
44import importlib
55import logging
6+ import re
67import warnings
78from collections import defaultdict
89from collections .abc import Callable
@@ -186,6 +187,24 @@ def collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
186187 }
187188
188189
def map_sqlc_error(
    error: str,
    block_starts: list[tuple[int, str]],
    all_locations: dict[str, list[str]],
) -> str:
    """Rewrite ``queries.sql:LINE[:COL]:`` prefixes in a sqlc error message.

    ``block_starts`` pairs the starting line of each generated query block
    with its query name (assumed sorted ascending by line — TODO confirm at
    the call site); ``all_locations`` maps a query name to its source call
    sites. A prefix whose line cannot be attributed to a known query, or
    whose query has no recorded locations, is left untouched.
    """

    def _to_source_location(match: re.Match[str]) -> str:
        lineno = int(match.group(1))
        # The latest block starting at or before the reported line owns it.
        owner: str | None = None
        for start, query_name in reversed(block_starts):
            if start <= lineno:
                owner = query_name
                break
        if owner is None:
            return match.group(0)
        call_sites = all_locations.get(owner)
        if not call_sites:
            return match.group(0)
        return ", ".join(call_sites) + ":"

    return re.sub(r"queries\.sql:(\d+)(?::\d+)?:", _to_source_location, error)
206+
207+
189208def generate_sql_package ( # noqa: PLR0913, PLR0914
190209 * ,
191210 schema_path : Path ,
@@ -200,21 +219,14 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
200219 src_path : Path = Path (),
201220 tempdir_path : Path | None = None ,
202221) -> bool :
203- dsn_import_package , dsn_import_path = dsn_import .split (":" )
204-
205- package_name = package_full_name .split ("." )[- 1 ] # noqa: PLC0207
222+ package_name = package_full_name .rsplit ("." , maxsplit = 1 )[- 1 ]
206223 sql_fn_name = f"{ package_name } _sql"
207224
208- target_package_path = src_path / f" { package_full_name . replace ( '.' , '/' ) } .py"
225+ queries , all_locations = collect_queries ( src_path , sql_fn_name )
209226
210- queries = list (find_all_queries (src_path , sql_fn_name ))
211- validate_stmt_has_single_row_type (queries )
212- queries = list ({q .name : q for q in queries }.values ())
227+ dsn , dsn_import_package , dsn_import_path = resolve_dsn (dsn_import )
213228
214- dsn_package = importlib .import_module (dsn_import_package )
215- dsn = eval (dsn_import_path , vars (dsn_package )) # noqa: S307
216-
217- sqlc_res = run_sqlc (
229+ sqlc_res , block_starts = run_sqlc (
218230 src_path / schema_path ,
219231 [(q .name , q .stmt ) for q in queries ],
220232 dsn = dsn ,
@@ -223,63 +235,13 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
223235 )
224236
225237 if sqlc_res .error :
226- logger .error ("Error running SQLC:\n %s" , sqlc_res .error )
238+ mapped = map_sqlc_error (sqlc_res .error , block_starts , all_locations )
239+ logger .error (f"Error running SQLC:\n { mapped } " )
227240 return False
228241
229- json_import_block = ""
230- json_col_overrides : dict [tuple [str , str ], str ] = {}
231-
232- if json_model_overrides :
233- json_compatible_types = {"json" , "jsonb" , "text" , "varchar" }
234- col_types = {
235- (table .rel .name , column .name ): column .type .name .removeprefix ("pg_catalog." )
236- for schema in sqlc_res .catalog .schemas
237- for table in schema .tables
238- for column in table .columns
239- }
240- tables = {table for table , _ in col_types }
241-
242- parsed : dict [tuple [str , str ], tuple [str , str ]] = {}
243- for key , import_path in json_model_overrides .items ():
244- table_name , sep , col_name = key .partition ("." )
245- if not sep :
246- msg = f"json_model_overrides key must be 'table.column', got: { key !r} "
247- raise ValueError (msg )
248- if table_name not in tables :
249- msg = f"json_model_overrides: table { table_name !r} not found in catalog"
250- raise ValueError (msg )
251- if (table_name , col_name ) not in col_types :
252- msg = (
253- f"json_model_overrides: column { col_name !r} "
254- f"not found in table { table_name !r} "
255- )
256- raise ValueError (msg )
257-
258- db_type = col_types [table_name , col_name ]
259- if db_type not in json_compatible_types :
260- msg = (
261- f"json_model_overrides: column "
262- f"{ table_name } .{ col_name } has type "
263- f"{ db_type !r} , expected one of "
264- f"{ json_compatible_types } "
265- )
266- raise ValueError (msg )
267-
268- module_path , sep , class_name = import_path .partition (":" )
269- if not sep :
270- msg = (
271- "json_model_overrides value must be "
272- f"'module:Class', got: { import_path !r} "
273- )
274- raise ValueError (msg )
275-
276- parsed [table_name , col_name ] = (module_path , class_name )
277-
278- modules = sorted ({module for module , _ in parsed .values ()})
279- json_import_block = "\n " + "\n " .join (f"import { m } " for m in modules )
280- json_col_overrides = {
281- key : f"{ module } .{ cls } " for key , (module , cls ) in parsed .items ()
282- }
242+ json_import_block , json_col_overrides = resolve_json_model_overrides (
243+ json_model_overrides or {}, sqlc_res .catalog
244+ )
283245
284246 resolver = TypeResolver (
285247 catalog = sqlc_res .catalog ,
@@ -297,18 +259,79 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
297259 resolver ,
298260 )
299261
300- entities = [ render_entity (e .name , e .column_specs ) for e in ordered_entities ]
262+ entities = sorted ( render_entity (e .name , e .column_specs ) for e in ordered_entities )
301263
302264 used_enums = collect_used_enums (sqlc_res )
303265
304- enums = [
266+ enums = sorted (
305267 render_enum_class (e , package_name , to_pascal_fn , to_snake_fn )
306268 for schema in sqlc_res .catalog .schemas
307269 for e in schema .enums
308270 if (schema .name , e .name ) in used_enums
271+ )
272+
273+ query_classes = render_query_classes (
274+ sqlc_res .queries , queries , resolver , result_types , all_locations
275+ )
276+
277+ query_overloads = [
278+ render_query_overload (sql_fn_name , q .name , q .stmt , q .row_type ) for q in queries
309279 ]
310280
311- query_classes = [
281+ query_dict_entries = [render_query_dict_entry (q .name , q .stmt ) for q in queries ]
282+
283+ target_package_path = src_path / f"{ package_full_name .replace ('.' , '/' )} .py"
284+
285+ new_content = render_package (
286+ dsn_import_package ,
287+ dsn_import_path ,
288+ package_name ,
289+ sql_fn_name ,
290+ entities ,
291+ enums ,
292+ query_classes ,
293+ query_overloads ,
294+ query_dict_entries ,
295+ application_name ,
296+ json_import_block ,
297+ )
298+ changed = write_if_changed (target_package_path , new_content + "\n " )
299+ if changed :
300+ logger .info (f"Generated SQL package { package_full_name } " )
301+ return changed
302+
303+
def collect_queries(
    src_path: Path, sql_fn_name: str
) -> tuple[list["CodeQuery"], defaultdict[str, list[str]]]:
    """Gather every query invocation under *src_path*, de-duplicated by name.

    Returns the unique queries (first occurrence of each name wins, then
    ordered by file and line number) together with a name -> [location, ...]
    map that records every call site, duplicates included.
    """
    found = list(find_all_queries(src_path, sql_fn_name))
    validate_stmt_has_single_row_type(found)

    locations_by_name: defaultdict[str, list[str]] = defaultdict(list)
    unique_by_name: dict[str, CodeQuery] = {}
    for query in found:
        locations_by_name[query.name].append(query.location)
        # setdefault keeps the first query seen for each name.
        unique_by_name.setdefault(query.name, query)

    ordered = sorted(unique_by_name.values(), key=lambda q: (q.file, q.lineno))
    return ordered, locations_by_name
317+
318+
def resolve_dsn(dsn_import: str) -> tuple[str, str, str]:
    """Resolve a ``package:attr_expr`` spec to its runtime value.

    Returns ``(dsn_value, package_name, attr_expr)``.

    Raises ValueError with a descriptive message when the spec has no ':'
    separator. Uses ``partition`` rather than ``split`` so an attribute
    *expression* containing further colons (e.g. a slice) still resolves.
    """
    package_name, sep, attr_path = dsn_import.partition(":")
    if not sep:
        msg = f"dsn_import must be 'package:attribute', got: {dsn_import!r}"
        raise ValueError(msg)
    mod = importlib.import_module(package_name)
    # eval allows dotted / computed attribute expressions; the spec comes
    # from developer-supplied config, not untrusted input.
    dsn: str = eval(attr_path, vars(mod))  # noqa: S307
    return dsn, package_name, attr_path
324+
325+
326+ def render_query_classes (
327+ sqlc_queries : tuple [Query , ...],
328+ queries : list ["CodeQuery" ],
329+ resolver : TypeResolver ,
330+ result_types : dict [str , str ],
331+ all_locations : defaultdict [str , list [str ]],
332+ ) -> list [str ]:
333+ query_order = {q .name : i for i , q in enumerate (queries )}
334+ return [
312335 render_query_class (
313336 q .name ,
314337 q .text ,
@@ -327,33 +350,67 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
327350 if len (q .columns ) == 1
328351 else None
329352 ),
353+ all_locations [q .name ],
330354 )
331- for q in sqlc_res . queries
355+ for q in sorted ( sqlc_queries , key = lambda q : query_order [ q . name ])
332356 ]
333357
334- query_overloads = [
335- render_query_overload (sql_fn_name , q .name , q .stmt , q .row_type ) for q in queries
336- ]
337358
338- query_dict_entries = [render_query_dict_entry (q .name , q .stmt ) for q in queries ]
def resolve_json_model_overrides(
    overrides: dict[str, str], catalog: Catalog
) -> tuple[str, dict[tuple[str, str], str]]:
    """Validate JSON model overrides against the sqlc catalog.

    Each key is ``table.column`` and each value ``module:Class``. Returns
    the import block for the referenced modules and a
    ``(table, column) -> "module.Class"`` mapping. Raises ValueError on a
    malformed key or value, an unknown table/column, or a column whose
    database type cannot carry JSON.
    """
    if not overrides:
        return "", {}

    json_compatible_types = {"json", "jsonb", "text", "varchar"}

    # (table, column) -> db type, with the pg_catalog prefix stripped.
    col_types: dict[tuple[str, str], str] = {}
    for schema in catalog.schemas:
        for table in schema.tables:
            for column in table.columns:
                col_types[table.rel.name, column.name] = (
                    column.type.name.removeprefix("pg_catalog.")
                )
    known_tables = {table for table, _ in col_types}

    parsed: dict[tuple[str, str], tuple[str, str]] = {}
    for key, import_path in overrides.items():
        table_name, dot, col_name = key.partition(".")
        if not dot:
            msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
            raise ValueError(msg)
        if table_name not in known_tables:
            msg = f"json_model_overrides: table {table_name!r} not found in catalog"
            raise ValueError(msg)
        if (table_name, col_name) not in col_types:
            msg = (
                f"json_model_overrides: column {col_name!r} "
                f"not found in table {table_name!r}"
            )
            raise ValueError(msg)

        db_type = col_types[table_name, col_name]
        if db_type not in json_compatible_types:
            msg = (
                f"json_model_overrides: column "
                f"{table_name}.{col_name} has type "
                f"{db_type!r}, expected one of "
                f"{json_compatible_types}"
            )
            raise ValueError(msg)

        module_path, colon, class_name = import_path.partition(":")
        if not colon:
            msg = (
                "json_model_overrides value must be "
                f"'module:Class', got: {import_path!r}"
            )
            raise ValueError(msg)

        parsed[table_name, col_name] = (module_path, class_name)

    modules = sorted({module for module, _ in parsed.values()})
    import_block = "\n" + "\n".join(f"import {m}" for m in modules)
    col_overrides: dict[tuple[str, str], str] = {}
    for (table_name, col_name), (module, cls) in parsed.items():
        col_overrides[table_name, col_name] = f"{module}.{cls}"
    return import_block, col_overrides
357414
358415
359416def render_package ( # noqa: PLR0913, PLR0917
@@ -562,7 +619,8 @@ def render_query_class(
562619 query_params : list [ParamSpec ],
563620 result : str ,
564621 columns_num : int ,
565- scalar_json_type : str | None = None ,
622+ scalar_json_type : str | None ,
623+ locations : list [str ],
566624) -> str :
567625 query_params = deduplicate_params (query_params )
568626
@@ -631,6 +689,7 @@ async def execute({", ".join(query_fn_params)}) -> None:
631689 return f"""
632690
633691class { query_name } (Query[{ result } ]):
692+ # See: { ", " .join (locations )}
634693 _stmt = psycopg.sql.SQL({ stmt !r} )
635694 _row_factory = staticmethod({ row_factory } )
636695
@@ -767,7 +826,12 @@ def find_fn_calls(
767826 content = path .read_text (encoding = "utf-8" )
768827 if fn_name not in content :
769828 continue
770- for node in ast .walk (ast .parse (content , filename = str (path ))):
829+ try :
830+ tree = ast .parse (content , filename = str (path ))
831+ except SyntaxError as exc :
832+ msg = f"Failed to parse { path } : { exc .msg } (line { exc .lineno } )"
833+ raise SyntaxError (msg ) from exc
834+ for node in ast .walk (tree ):
771835 match node :
772836 case ast .Call (func = ast .Name (id = id )) if id == fn_name :
773837 yield path , node .lineno , node
@@ -817,11 +881,15 @@ def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]:
817881
818882
def validate_stmt_has_single_row_type(queries: list[CodeQuery]) -> None:
    """Ensure each distinct SQL statement is used with exactly one row_type.

    The first query seen for a statement fixes its expected row_type; any
    later query with the same statement but a different row_type raises
    ValueError naming both call sites.
    """
    seen: dict[str, CodeQuery] = {}
    for query in queries:
        first = seen.setdefault(query.stmt, query)
        if first is query:
            # First occurrence of this statement — nothing to compare yet.
            continue
        if query.row_type != first.row_type:
            msg = (
                f"row_type conflict: {first.location} has {first.row_type!r},"
                f" {query.location} has {query.row_type!r}"
            )
            raise ValueError(msg)
0 commit comments