Skip to content

Commit 7022f3e

Browse files
authored
Quality of life improvements (#12)
* Add clear error message for SyntaxError in scanned files
* Add source location comments to generated Query classes
* Include source locations in row_type conflict error message
* Sort generated queries by source location instead of md5 hash
* Show all source locations for deduplicated queries, order by first occurrence
* Restructure tests
* Map sqlc errors to source locations

  When sqlc reports an error referencing the synthetic queries.sql, replace the reference with the original file:line of the _sql() call so the user can find the problematic query directly.
* Simplify generator code
1 parent e62ab0b commit 7022f3e

14 files changed

Lines changed: 736 additions & 623 deletions

example/db/mydb.py

Lines changed: 84 additions & 75 deletions
Large diffs are not rendered by default.

src/iron_sql/codegen/generator.py

Lines changed: 166 additions & 98 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import hashlib
44
import importlib
55
import logging
6+
import re
67
import warnings
78
from collections import defaultdict
89
from collections.abc import Callable
@@ -186,6 +187,24 @@ def collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
186187
}
187188

188189

190+
def map_sqlc_error(
191+
error: str,
192+
block_starts: list[tuple[int, str]],
193+
all_locations: dict[str, list[str]],
194+
) -> str:
195+
def replace(m: re.Match[str]) -> str:
196+
line = int(m.group(1))
197+
name = next((n for start, n in reversed(block_starts) if start <= line), None)
198+
if name is None:
199+
return m.group(0)
200+
locations = all_locations.get(name)
201+
if not locations:
202+
return m.group(0)
203+
return f"{', '.join(locations)}:"
204+
205+
return re.sub(r"queries\.sql:(\d+)(?::\d+)?:", replace, error)
206+
207+
189208
def generate_sql_package( # noqa: PLR0913, PLR0914
190209
*,
191210
schema_path: Path,
@@ -200,21 +219,14 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
200219
src_path: Path = Path(),
201220
tempdir_path: Path | None = None,
202221
) -> bool:
203-
dsn_import_package, dsn_import_path = dsn_import.split(":")
204-
205-
package_name = package_full_name.split(".")[-1] # noqa: PLC0207
222+
package_name = package_full_name.rsplit(".", maxsplit=1)[-1]
206223
sql_fn_name = f"{package_name}_sql"
207224

208-
target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"
225+
queries, all_locations = collect_queries(src_path, sql_fn_name)
209226

210-
queries = list(find_all_queries(src_path, sql_fn_name))
211-
validate_stmt_has_single_row_type(queries)
212-
queries = list({q.name: q for q in queries}.values())
227+
dsn, dsn_import_package, dsn_import_path = resolve_dsn(dsn_import)
213228

214-
dsn_package = importlib.import_module(dsn_import_package)
215-
dsn = eval(dsn_import_path, vars(dsn_package)) # noqa: S307
216-
217-
sqlc_res = run_sqlc(
229+
sqlc_res, block_starts = run_sqlc(
218230
src_path / schema_path,
219231
[(q.name, q.stmt) for q in queries],
220232
dsn=dsn,
@@ -223,63 +235,13 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
223235
)
224236

225237
if sqlc_res.error:
226-
logger.error("Error running SQLC:\n%s", sqlc_res.error)
238+
mapped = map_sqlc_error(sqlc_res.error, block_starts, all_locations)
239+
logger.error(f"Error running SQLC:\n{mapped}")
227240
return False
228241

229-
json_import_block = ""
230-
json_col_overrides: dict[tuple[str, str], str] = {}
231-
232-
if json_model_overrides:
233-
json_compatible_types = {"json", "jsonb", "text", "varchar"}
234-
col_types = {
235-
(table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
236-
for schema in sqlc_res.catalog.schemas
237-
for table in schema.tables
238-
for column in table.columns
239-
}
240-
tables = {table for table, _ in col_types}
241-
242-
parsed: dict[tuple[str, str], tuple[str, str]] = {}
243-
for key, import_path in json_model_overrides.items():
244-
table_name, sep, col_name = key.partition(".")
245-
if not sep:
246-
msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
247-
raise ValueError(msg)
248-
if table_name not in tables:
249-
msg = f"json_model_overrides: table {table_name!r} not found in catalog"
250-
raise ValueError(msg)
251-
if (table_name, col_name) not in col_types:
252-
msg = (
253-
f"json_model_overrides: column {col_name!r} "
254-
f"not found in table {table_name!r}"
255-
)
256-
raise ValueError(msg)
257-
258-
db_type = col_types[table_name, col_name]
259-
if db_type not in json_compatible_types:
260-
msg = (
261-
f"json_model_overrides: column "
262-
f"{table_name}.{col_name} has type "
263-
f"{db_type!r}, expected one of "
264-
f"{json_compatible_types}"
265-
)
266-
raise ValueError(msg)
267-
268-
module_path, sep, class_name = import_path.partition(":")
269-
if not sep:
270-
msg = (
271-
"json_model_overrides value must be "
272-
f"'module:Class', got: {import_path!r}"
273-
)
274-
raise ValueError(msg)
275-
276-
parsed[table_name, col_name] = (module_path, class_name)
277-
278-
modules = sorted({module for module, _ in parsed.values()})
279-
json_import_block = "\n" + "\n".join(f"import {m}" for m in modules)
280-
json_col_overrides = {
281-
key: f"{module}.{cls}" for key, (module, cls) in parsed.items()
282-
}
242+
json_import_block, json_col_overrides = resolve_json_model_overrides(
243+
json_model_overrides or {}, sqlc_res.catalog
244+
)
283245

284246
resolver = TypeResolver(
285247
catalog=sqlc_res.catalog,
@@ -297,18 +259,79 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
297259
resolver,
298260
)
299261

300-
entities = [render_entity(e.name, e.column_specs) for e in ordered_entities]
262+
entities = sorted(render_entity(e.name, e.column_specs) for e in ordered_entities)
301263

302264
used_enums = collect_used_enums(sqlc_res)
303265

304-
enums = [
266+
enums = sorted(
305267
render_enum_class(e, package_name, to_pascal_fn, to_snake_fn)
306268
for schema in sqlc_res.catalog.schemas
307269
for e in schema.enums
308270
if (schema.name, e.name) in used_enums
271+
)
272+
273+
query_classes = render_query_classes(
274+
sqlc_res.queries, queries, resolver, result_types, all_locations
275+
)
276+
277+
query_overloads = [
278+
render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
309279
]
310280

311-
query_classes = [
281+
query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]
282+
283+
target_package_path = src_path / f"{package_full_name.replace('.', '/')}.py"
284+
285+
new_content = render_package(
286+
dsn_import_package,
287+
dsn_import_path,
288+
package_name,
289+
sql_fn_name,
290+
entities,
291+
enums,
292+
query_classes,
293+
query_overloads,
294+
query_dict_entries,
295+
application_name,
296+
json_import_block,
297+
)
298+
changed = write_if_changed(target_package_path, new_content + "\n")
299+
if changed:
300+
logger.info(f"Generated SQL package {package_full_name}")
301+
return changed
302+
303+
304+
def collect_queries(
305+
src_path: Path, sql_fn_name: str
306+
) -> tuple[list["CodeQuery"], defaultdict[str, list[str]]]:
307+
raw = list(find_all_queries(src_path, sql_fn_name))
308+
validate_stmt_has_single_row_type(raw)
309+
all_locations: defaultdict[str, list[str]] = defaultdict(list)
310+
first_occurrence: dict[str, CodeQuery] = {}
311+
for q in raw:
312+
all_locations[q.name].append(q.location)
313+
if q.name not in first_occurrence:
314+
first_occurrence[q.name] = q
315+
queries = sorted(first_occurrence.values(), key=lambda q: (q.file, q.lineno))
316+
return queries, all_locations
317+
318+
319+
def resolve_dsn(dsn_import: str) -> tuple[str, str, str]:
320+
package_name, attr_path = dsn_import.split(":")
321+
mod = importlib.import_module(package_name)
322+
dsn: str = eval(attr_path, vars(mod)) # noqa: S307
323+
return dsn, package_name, attr_path
324+
325+
326+
def render_query_classes(
327+
sqlc_queries: tuple[Query, ...],
328+
queries: list["CodeQuery"],
329+
resolver: TypeResolver,
330+
result_types: dict[str, str],
331+
all_locations: defaultdict[str, list[str]],
332+
) -> list[str]:
333+
query_order = {q.name: i for i, q in enumerate(queries)}
334+
return [
312335
render_query_class(
313336
q.name,
314337
q.text,
@@ -327,33 +350,67 @@ def generate_sql_package( # noqa: PLR0913, PLR0914
327350
if len(q.columns) == 1
328351
else None
329352
),
353+
all_locations[q.name],
330354
)
331-
for q in sqlc_res.queries
355+
for q in sorted(sqlc_queries, key=lambda q: query_order[q.name])
332356
]
333357

334-
query_overloads = [
335-
render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
336-
]
337358

338-
query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]
359+
def resolve_json_model_overrides(
360+
overrides: dict[str, str], catalog: Catalog
361+
) -> tuple[str, dict[tuple[str, str], str]]:
362+
if not overrides:
363+
return "", {}
339364

340-
new_content = render_package(
341-
dsn_import_package,
342-
dsn_import_path,
343-
package_name,
344-
sql_fn_name,
345-
sorted(entities),
346-
sorted(enums),
347-
sorted(query_classes),
348-
sorted(query_overloads),
349-
sorted(query_dict_entries),
350-
application_name,
351-
json_import_block,
352-
)
353-
changed = write_if_changed(target_package_path, new_content + "\n")
354-
if changed:
355-
logger.info(f"Generated SQL package {package_full_name}")
356-
return changed
365+
json_compatible_types = {"json", "jsonb", "text", "varchar"}
366+
col_types = {
367+
(table.rel.name, column.name): column.type.name.removeprefix("pg_catalog.")
368+
for schema in catalog.schemas
369+
for table in schema.tables
370+
for column in table.columns
371+
}
372+
tables = {table for table, _ in col_types}
373+
374+
parsed: dict[tuple[str, str], tuple[str, str]] = {}
375+
for key, import_path in overrides.items():
376+
table_name, sep, col_name = key.partition(".")
377+
if not sep:
378+
msg = f"json_model_overrides key must be 'table.column', got: {key!r}"
379+
raise ValueError(msg)
380+
if table_name not in tables:
381+
msg = f"json_model_overrides: table {table_name!r} not found in catalog"
382+
raise ValueError(msg)
383+
if (table_name, col_name) not in col_types:
384+
msg = (
385+
f"json_model_overrides: column {col_name!r} "
386+
f"not found in table {table_name!r}"
387+
)
388+
raise ValueError(msg)
389+
390+
db_type = col_types[table_name, col_name]
391+
if db_type not in json_compatible_types:
392+
msg = (
393+
f"json_model_overrides: column "
394+
f"{table_name}.{col_name} has type "
395+
f"{db_type!r}, expected one of "
396+
f"{json_compatible_types}"
397+
)
398+
raise ValueError(msg)
399+
400+
module_path, sep, class_name = import_path.partition(":")
401+
if not sep:
402+
msg = (
403+
"json_model_overrides value must be "
404+
f"'module:Class', got: {import_path!r}"
405+
)
406+
raise ValueError(msg)
407+
408+
parsed[table_name, col_name] = (module_path, class_name)
409+
410+
modules = sorted({module for module, _ in parsed.values()})
411+
import_block = "\n" + "\n".join(f"import {m}" for m in modules)
412+
col_overrides = {key: f"{module}.{cls}" for key, (module, cls) in parsed.items()}
413+
return import_block, col_overrides
357414

358415

359416
def render_package( # noqa: PLR0913, PLR0917
@@ -562,7 +619,8 @@ def render_query_class(
562619
query_params: list[ParamSpec],
563620
result: str,
564621
columns_num: int,
565-
scalar_json_type: str | None = None,
622+
scalar_json_type: str | None,
623+
locations: list[str],
566624
) -> str:
567625
query_params = deduplicate_params(query_params)
568626

@@ -631,6 +689,7 @@ async def execute({", ".join(query_fn_params)}) -> None:
631689
return f"""
632690
633691
class {query_name}(Query[{result}]):
692+
# See: {", ".join(locations)}
634693
_stmt = psycopg.sql.SQL({stmt!r})
635694
_row_factory = staticmethod({row_factory})
636695
@@ -767,7 +826,12 @@ def find_fn_calls(
767826
content = path.read_text(encoding="utf-8")
768827
if fn_name not in content:
769828
continue
770-
for node in ast.walk(ast.parse(content, filename=str(path))):
829+
try:
830+
tree = ast.parse(content, filename=str(path))
831+
except SyntaxError as exc:
832+
msg = f"Failed to parse {path}: {exc.msg} (line {exc.lineno})"
833+
raise SyntaxError(msg) from exc
834+
for node in ast.walk(tree):
771835
match node:
772836
case ast.Call(func=ast.Name(id=id)) if id == fn_name:
773837
yield path, node.lineno, node
@@ -817,11 +881,15 @@ def find_all_queries(src_path: Path, sql_fn_name: str) -> Iterator[CodeQuery]:
817881

818882

819883
def validate_stmt_has_single_row_type(queries: list[CodeQuery]) -> None:
820-
row_type_by_stmt: dict[str, str | None] = {}
884+
first_by_stmt: dict[str, CodeQuery] = {}
821885
for query in queries:
822-
if query.stmt in row_type_by_stmt:
823-
if query.row_type != row_type_by_stmt[query.stmt]:
824-
msg = f"row_type conflict (existing={row_type_by_stmt[query.stmt]!r})"
886+
if query.stmt in first_by_stmt:
887+
first = first_by_stmt[query.stmt]
888+
if query.row_type != first.row_type:
889+
msg = (
890+
f"row_type conflict: {first.location} has {first.row_type!r},"
891+
f" {query.location} has {query.row_type!r}"
892+
)
825893
raise ValueError(msg)
826894
else:
827-
row_type_by_stmt[query.stmt] = query.row_type
895+
first_by_stmt[query.stmt] = query

0 commit comments

Comments (0)