diff --git a/.gitignore b/.gitignore index 60b4f3d..d66bae9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ bin .direnv .devenv* devenv.local.nix + +.idea/ diff --git a/README.md b/README.md index c9f2531..d6d59f0 100644 --- a/README.md +++ b/README.md @@ -76,3 +76,49 @@ class Status(str, enum.Enum): OPEN = "op!en" CLOSED = "clo@sed" ``` + +### Bulk Inserts with `:copyfrom` + +Use the `:copyfrom` command to generate batch insert methods that leverage SQLAlchemy’s executemany behavior via `Connection.execute()` with a list of parameter mappings. + +SQL (example): + +```sql +-- name: CreateUsersBatch :copyfrom +INSERT INTO users (email, name) VALUES ($1, $2); +``` + +Generated methods: + +```py +def create_users_batch(self, arg_list: List[Any]) -> int +async def create_users_batch(self, arg_list: List[Any]) -> int +``` + +Call with a list of dicts using positional parameter keys `p1..pN` (the generator converts `$1`/`@name` to `:pN`): + +```py +rows = [ + {"p1": "a@example.com", "p2": "Alice"}, + {"p1": "b@example.com", "p2": "Bob"}, +] +count = queries.create_users_batch(rows) # returns affected rowcount (int) +``` + +When a typed params struct is emitted (e.g., many parameters or config thresholds), the method accepts `List[Params]`. The generator converts items to dicts internally: + +```py +@dataclasses.dataclass() +class CreateUsersWithDetailsParams: + email: str + name: str + bio: Optional[str] + age: Optional[int] + active: Optional[bool] + +count = queries.create_users_with_details([ + CreateUsersWithDetailsParams("a@example.com", "Alice", None, None, True), +]) +``` + +Implementation note: sync and async use `conn.execute(sqlalchemy.text(SQL), list_of_dicts)` and `await async_conn.execute(...)` respectively; SQLAlchemy performs efficient batch inserts under the hood. diff --git a/internal/endtoend/testdata/copyfrom/python/models.py b/internal/endtoend/testdata/copyfrom/python/models.py new file mode 100644 index 0000000..5bb8f99 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/python/models.py @@ -0,0 +1,24 @@ +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.30.0 +import dataclasses +import datetime +from typing import Optional + + +@dataclasses.dataclass() +class Author: + id: int + name: str + bio: str + + +@dataclasses.dataclass() +class User: + id: int + email: str + name: str + bio: Optional[str] + age: Optional[int] + active: Optional[bool] + created_at: datetime.datetime diff --git a/internal/endtoend/testdata/copyfrom/python/query.py b/internal/endtoend/testdata/copyfrom/python/query.py new file mode 100644 index 0000000..1e6307d --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/python/query.py @@ -0,0 +1,198 @@ +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.30.0 +# source: query.sql +import dataclasses +from typing import Any, List, Optional + +import sqlalchemy +import sqlalchemy.ext.asyncio + +from copyfrom import models + + +CREATE_AUTHOR = """-- name: create_author \\:one +INSERT INTO authors (name, bio) VALUES (:p1, :p2) RETURNING id, name, bio +""" + + +CREATE_AUTHORS = """-- name: create_authors \\:copyfrom +INSERT INTO authors (name, bio) VALUES (:p1, :p2) +""" + + +CREATE_AUTHORS_NAMED = """-- name: create_authors_named \\:copyfrom +INSERT INTO authors (name, bio) VALUES (:p1, :p2) +""" + + +CREATE_USER = """-- name: create_user \\:one +INSERT INTO users (email, name) VALUES (:p1, :p2) RETURNING id, email, name, bio, age, active, created_at +""" + + +CREATE_USERS_BATCH = """-- name: create_users_batch \\:copyfrom +INSERT INTO users (email, name) VALUES (:p1, :p2) +""" + + +CREATE_USERS_SHUFFLED = """-- name: create_users_shuffled \\:copyfrom +INSERT INTO users (email, name, bio, age, active) VALUES (:p2, :p1, :p3, :p4, :p5) +""" + + +@dataclasses.dataclass() +class CreateUsersShuffledParams: + name: str + email: str + bio: Optional[str] + age: Optional[int] + active: Optional[bool] + + +CREATE_USERS_WITH_DETAILS = """-- name: create_users_with_details \\:copyfrom +INSERT INTO users (email, name, bio, age, active) VALUES (:p1, :p2, :p3, :p4, :p5) +""" + + +@dataclasses.dataclass() +class CreateUsersWithDetailsParams: + email: str + name: str + bio: Optional[str] + age: Optional[int] + active: Optional[bool] + + +class Querier: + def __init__(self, conn: sqlalchemy.engine.Connection): + self._conn = conn + + def create_author(self, *, name: str, bio: str) -> Optional[models.Author]: + row = self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio}).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + def create_authors(self, arg_list: List[Any]) -> int: + result = self._conn.execute(sqlalchemy.text(CREATE_AUTHORS), arg_list) + return result.rowcount + + def create_authors_named(self, arg_list: List[Any]) -> int: + result = self._conn.execute(sqlalchemy.text(CREATE_AUTHORS_NAMED), arg_list) + return result.rowcount + + def create_user(self, *, email: str, name: str) -> Optional[models.User]: + row = self._conn.execute(sqlalchemy.text(CREATE_USER), {"p1": email, "p2": name}).first() + if row is None: + return None + return models.User( + id=row[0], + email=row[1], + name=row[2], + bio=row[3], + age=row[4], + active=row[5], + created_at=row[6], + ) + + def create_users_batch(self, arg_list: List[Any]) -> int: + result = self._conn.execute(sqlalchemy.text(CREATE_USERS_BATCH), arg_list) + return result.rowcount + + def create_users_shuffled(self, arg_list: List[CreateUsersShuffledParams]) -> int: + data = list() + for item in arg_list: + data.append({ + "p1": item.name, + "p2": item.email, + "p3": item.bio, + "p4": item.age, + "p5": item.active, + }) + result = self._conn.execute(sqlalchemy.text(CREATE_USERS_SHUFFLED), data) + return result.rowcount + + def create_users_with_details(self, arg_list: List[CreateUsersWithDetailsParams]) -> int: + data = list() + for item in arg_list: + data.append({ + "p1": item.email, + "p2": item.name, + "p3": item.bio, + "p4": item.age, + "p5": item.active, + }) + result = self._conn.execute(sqlalchemy.text(CREATE_USERS_WITH_DETAILS), data) + return result.rowcount + + +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def create_author(self, *, name: str, bio: str) -> Optional[models.Author]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio})).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + async def create_authors(self, arg_list: List[Any]) -> int: + result = await self._conn.execute(sqlalchemy.text(CREATE_AUTHORS), arg_list) + return result.rowcount + + async def create_authors_named(self, arg_list: List[Any]) -> int: + result = await self._conn.execute(sqlalchemy.text(CREATE_AUTHORS_NAMED), arg_list) + return result.rowcount + + async def create_user(self, *, email: str, name: str) -> Optional[models.User]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_USER), {"p1": email, "p2": name})).first() + if row is None: + return None + return models.User( + id=row[0], + email=row[1], + name=row[2], + bio=row[3], + age=row[4], + active=row[5], + created_at=row[6], + ) + + async def create_users_batch(self, arg_list: List[Any]) -> int: + result = await self._conn.execute(sqlalchemy.text(CREATE_USERS_BATCH), arg_list) + return result.rowcount + + async def create_users_shuffled(self, arg_list: List[CreateUsersShuffledParams]) -> int: + data = list() + for item in arg_list: + data.append({ + "p1": item.name, + "p2": item.email, + "p3": item.bio, + "p4": item.age, + "p5": item.active, + }) + result = await self._conn.execute(sqlalchemy.text(CREATE_USERS_SHUFFLED), data) + return result.rowcount + + async def create_users_with_details(self, arg_list: List[CreateUsersWithDetailsParams]) -> int: + data = list() + for item in arg_list: + data.append({ + "p1": item.email, + "p2": item.name, + "p3": item.bio, + "p4": item.age, + "p5": item.active, + }) + result = await self._conn.execute(sqlalchemy.text(CREATE_USERS_WITH_DETAILS), data) + return result.rowcount diff --git a/internal/endtoend/testdata/copyfrom/query.sql b/internal/endtoend/testdata/copyfrom/query.sql new file mode 100644 index 0000000..3afd4af --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/query.sql @@ -0,0 +1,20 @@ +-- name: CreateAuthors :copyfrom +INSERT INTO authors (name, bio) VALUES ($1, $2); + +-- name: CreateAuthor :one +INSERT INTO authors (name, bio) VALUES ($1, $2) RETURNING *; + +-- name: CreateAuthorsNamed :copyfrom +INSERT INTO authors (name, bio) VALUES (@name, @bio); + +-- name: CreateUser :one +INSERT INTO users (email, name) VALUES (@email, @name) RETURNING *; + +-- name: CreateUsersBatch :copyfrom +INSERT INTO users (email, name) VALUES (@email, @name); + +-- name: CreateUsersWithDetails :copyfrom +INSERT INTO users (email, name, bio, age, active) VALUES ($1, $2, $3, $4, $5); + +-- name: CreateUsersShuffled :copyfrom +INSERT INTO users (email, name, bio, age, active) VALUES ($2, $1, $3, $4, $5); diff --git a/internal/endtoend/testdata/copyfrom/schema.sql b/internal/endtoend/testdata/copyfrom/schema.sql new file mode 100644 index 0000000..3ce88a6 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/schema.sql @@ -0,0 +1,15 @@ +CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + name text NOT NULL, + bio text NOT NULL +); + +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + email text NOT NULL, + name text NOT NULL, + bio text, + age int, + active boolean DEFAULT true, + created_at timestamp NOT NULL DEFAULT NOW() +); \ No newline at end of file diff --git a/internal/endtoend/testdata/copyfrom/sqlc.yaml b/internal/endtoend/testdata/copyfrom/sqlc.yaml new file mode 100644 index 0000000..b45c289 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/sqlc.yaml @@ -0,0 +1,17 @@ +version: "2" +plugins: + - name: py + wasm: + url: file://../../../../bin/sqlc-gen-python.wasm + sha256: "d587b6af0bfe07dbb2c316bffb5f7f5d0073a60788c5d533b84373067c9c1916" +sql: + - schema: "schema.sql" + queries: "query.sql" + engine: postgresql + codegen: + - out: python + plugin: py + options: + package: copyfrom + emit_sync_querier: true + emit_async_querier: true \ No newline at end of file diff --git a/internal/endtoend/testdata/emit_pydantic_models/db/models.py b/internal/endtoend/testdata/emit_pydantic_models/db/models.py index 7676e5c..2300f77 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/db/models.py +++ b/internal/endtoend/testdata/emit_pydantic_models/db/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 import pydantic from typing import Optional diff --git a/internal/endtoend/testdata/emit_pydantic_models/db/query.py b/internal/endtoend/testdata/emit_pydantic_models/db/query.py index 6f5b76f..946674d 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/db/query.py +++ b/internal/endtoend/testdata/emit_pydantic_models/db/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 # source: query.sql from typing import AsyncIterator, Iterator, Optional diff --git a/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml b/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml index beae200..2bb6af9 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml +++ b/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "d587b6af0bfe07dbb2c316bffb5f7f5d0073a60788c5d533b84373067c9c1916" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/emit_str_enum/db/models.py b/internal/endtoend/testdata/emit_str_enum/db/models.py index 5fdf754..5fd5508 100644 --- a/internal/endtoend/testdata/emit_str_enum/db/models.py +++ b/internal/endtoend/testdata/emit_str_enum/db/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 import dataclasses import enum from typing import Optional diff --git a/internal/endtoend/testdata/emit_str_enum/db/query.py b/internal/endtoend/testdata/emit_str_enum/db/query.py index 8082889..c02a9ec 100644 --- a/internal/endtoend/testdata/emit_str_enum/db/query.py +++ b/internal/endtoend/testdata/emit_str_enum/db/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 # source: query.sql from typing import AsyncIterator, Iterator, Optional diff --git a/internal/endtoend/testdata/emit_str_enum/sqlc.yaml b/internal/endtoend/testdata/emit_str_enum/sqlc.yaml index 04e3feb..d56db19 100644 --- a/internal/endtoend/testdata/emit_str_enum/sqlc.yaml +++ b/internal/endtoend/testdata/emit_str_enum/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "d587b6af0bfe07dbb2c316bffb5f7f5d0073a60788c5d533b84373067c9c1916" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/exec_result/python/models.py b/internal/endtoend/testdata/exec_result/python/models.py index 034fb2d..ced3715 100644 --- a/internal/endtoend/testdata/exec_result/python/models.py +++ b/internal/endtoend/testdata/exec_result/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 import dataclasses diff --git a/internal/endtoend/testdata/exec_result/python/query.py b/internal/endtoend/testdata/exec_result/python/query.py index b68ce39..c063868 100644 --- a/internal/endtoend/testdata/exec_result/python/query.py +++ b/internal/endtoend/testdata/exec_result/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/exec_result/sqlc.yaml b/internal/endtoend/testdata/exec_result/sqlc.yaml index ddffc83..c4d60a2 100644 --- a/internal/endtoend/testdata/exec_result/sqlc.yaml +++ b/internal/endtoend/testdata/exec_result/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "d587b6af0bfe07dbb2c316bffb5f7f5d0073a60788c5d533b84373067c9c1916" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/exec_rows/python/models.py b/internal/endtoend/testdata/exec_rows/python/models.py index 034fb2d..ced3715 100644 --- a/internal/endtoend/testdata/exec_rows/python/models.py +++ b/internal/endtoend/testdata/exec_rows/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 import dataclasses diff --git a/internal/endtoend/testdata/exec_rows/python/query.py b/internal/endtoend/testdata/exec_rows/python/query.py index 7a9b2a6..c5a936d 100644 --- a/internal/endtoend/testdata/exec_rows/python/query.py +++ b/internal/endtoend/testdata/exec_rows/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/exec_rows/sqlc.yaml b/internal/endtoend/testdata/exec_rows/sqlc.yaml index ddffc83..c4d60a2 100644 --- a/internal/endtoend/testdata/exec_rows/sqlc.yaml +++ b/internal/endtoend/testdata/exec_rows/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "d587b6af0bfe07dbb2c316bffb5f7f5d0073a60788c5d533b84373067c9c1916" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py b/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py index 8ba8803..0614ac0 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py +++ b/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 import dataclasses diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py b/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py index 1e1e161..8b9eb26 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py +++ b/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 # source: query.sql from typing import Optional diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml b/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml index efbb150..44f056e 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml +++ b/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "d587b6af0bfe07dbb2c316bffb5f7f5d0073a60788c5d533b84373067c9c1916" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_two/python/models.py b/internal/endtoend/testdata/query_parameter_limit_two/python/models.py index 059675d..2ddf019 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_two/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_two/python/query.py b/internal/endtoend/testdata/query_parameter_limit_two/python/query.py index e8b723e..5a97c59 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_two/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml index 336bca7..6a8a76c 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "d587b6af0bfe07dbb2c316bffb5f7f5d0073a60788c5d533b84373067c9c1916" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py b/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py index 30e80db..77bdfe5 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py b/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py index 5a1fbbc..6380dce 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml index c20cd57..3149851 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "d587b6af0bfe07dbb2c316bffb5f7f5d0073a60788c5d533b84373067c9c1916" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py b/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py index 059675d..2ddf019 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py b/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py index 47bd6a9..5edcd9c 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.30.0 # source: query.sql import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml index 6e2cdeb..b02f0d6 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "d587b6af0bfe07dbb2c316bffb5f7f5d0073a60788c5d533b84373067c9c1916" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml b/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml index c432e4f..98ad219 100644 --- a/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "d587b6af0bfe07dbb2c316bffb5f7f5d0073a60788c5d533b84373067c9c1916" sql: - schema: schema.sql queries: query.sql diff --git a/internal/gen.go b/internal/gen.go index 6e50fae..7fc0c1f 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -2,7 +2,7 @@ package python import ( "context" - json "encoding/json" + "encoding/json" "errors" "fmt" "log" @@ -10,7 +10,6 @@ import ( "sort" "strings" - "github.com/sqlc-dev/plugin-sdk-go/metadata" "github.com/sqlc-dev/plugin-sdk-go/plugin" "github.com/sqlc-dev/plugin-sdk-go/sdk" @@ -50,9 +49,10 @@ func (t pyType) Annotation() *pyast.Node { } type Field struct { - Name string - Type pyType - Comment string + Name string + Type pyType + Comment string + ParamNumber int32 } type Struct struct { @@ -135,6 +135,33 @@ type Query struct { } func (q Query) AddArgs(args *pyast.Arguments) { + switch q.Cmd { + case ":copyfrom": + q.addCopyFromArgs(args) + default: + q.addRegularArgs(args) + } +} + +// addCopyFromArgs adds arguments for :copyfrom commands +func (q Query) addCopyFromArgs(args *pyast.Arguments) { + // Check if we have a struct parameter + if len(q.Args) == 1 && q.Args[0].IsStruct() { + args.Args = append(args.Args, &pyast.Arg{ + Arg: q.Args[0].Name + "_list", + Annotation: subscriptNode("List", q.Args[0].Annotation()), + }) + } else { + // Fall back to List[Any] for individual parameters + args.Args = append(args.Args, &pyast.Arg{ + Arg: "arg_list", + Annotation: subscriptNode("List", poet.Name("Any")), + }) + } +} + +// addRegularArgs adds arguments for regular (non-copyfrom) commands +func (q Query) addRegularArgs(args *pyast.Arguments) { // A single struct arg does not need to be passed as a keyword argument if len(q.Args) == 1 && q.Args[0].IsStruct() { args.Args = append(args.Args, &pyast.Arg{ @@ -143,6 +170,8 @@ func (q Query) AddArgs(args *pyast.Arguments) { }) return } + + // Multiple args or non-struct args are passed as keyword arguments for _, a := range q.Args { args.KwOnlyArgs = append(args.KwOnlyArgs, &pyast.Arg{ Arg: a.Name, @@ -180,6 +209,83 @@ func (q Query) ArgDictNode() *pyast.Node { } } +// BuildCopyFromBody generates the method body for :copyfrom commands. +func (q Query) BuildCopyFromBody(isAsync bool) []*pyast.Node { + var body []*pyast.Node + + dataVar := "arg_list" + if len(q.Args) == 1 && q.Args[0].IsStruct() { + argName := q.Args[0].Name + "_list" + dataVar = "data" + body = append(body, q.buildStructToDictList(argName, dataVar)...) + } + + sqlText := poet.Node(&pyast.Call{ + Func: poet.Attribute(poet.Name("sqlalchemy"), "text"), + Args: []*pyast.Node{poet.Name(q.ConstantName)}, + }) + + execCall := poet.Node(&pyast.Call{ + Func: poet.Attribute(poet.Name("self._conn"), "execute"), + Args: []*pyast.Node{ + sqlText, + poet.Name(dataVar), + }, + }) + + if isAsync { + execCall = poet.Await(execCall) + } + + body = append(body, + assignNode("result", execCall), + poet.Return(poet.Attribute(poet.Name("result"), "rowcount")), + ) + + return body +} + +// buildStructToDictList converts a list of parameter structs to a list of dicts for SQLAlchemy +func (q Query) buildStructToDictList(sourceVar, targetVar string) []*pyast.Node { + var body []*pyast.Node + + body = append(body, assignNode(targetVar, poet.Node(&pyast.Call{ + Func: poet.Name("list"), + Args: []*pyast.Node{}, + }))) + + loopVar := "item" + dict := &pyast.Dict{} + for i, field := range q.Args[0].Struct.Fields { + paramNumber := field.ParamNumber + if paramNumber == 0 { + paramNumber = int32(i + 1) + } + + dict.Keys = append(dict.Keys, poet.Constant(fmt.Sprintf("p%v", paramNumber))) + dict.Values = append(dict.Values, poet.Attribute(poet.Name(loopVar), field.Name)) + } + + body = append(body, poet.Node(&pyast.For{ + Target: poet.Name(loopVar), + Iter: poet.Name(sourceVar), + Body: []*pyast.Node{ + poet.Node(&pyast.Call{ + Func: poet.Attribute(poet.Name(targetVar), "append"), + Args: []*pyast.Node{ + { + Node: &pyast.Node_Dict{ + Dict: dict, + }, + }, + }, + }), + }, + })) + + return body +} + func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType { typ := pyInnerType(req, col) return pyType{ @@ -343,8 +449,9 @@ func columnsToStruct(req *plugin.GenerateRequest, name string, columns []pyColum fieldName = fmt.Sprintf("%s_%d", fieldName, suffix) } gs.Fields = append(gs.Fields, Field{ - Name: fieldName, - Type: makePyType(req, c.Column), + Name: fieldName, + Type: makePyType(req, c.Column), + ParamNumber: c.id, }) seen[colName]++ } @@ -372,9 +479,6 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ if query.Cmd == "" { continue } - if query.Cmd == metadata.CmdCopyFrom { - return nil, errors.New("Support for CopyFrom in Python is not implemented") - } methodName := methodName(query.Name) @@ -403,11 +507,13 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ Column: p.Column, }) } - gq.Args = []QueryValue{{ - Emit: true, - Name: "arg", - Struct: columnsToStruct(req, query.Name+"Params", cols), - }} + gq.Args = []QueryValue{ + { + Emit: true, + Name: "arg", + Struct: columnsToStruct(req, query.Name+"Params", cols), + }, + } } else { args := make([]QueryValue, 0, len(query.Params)) for _, p := range query.Params { @@ -959,6 +1065,9 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { poet.Return(exec), ) f.Returns = typeRefNode("sqlalchemy", "engine", "Result") + case ":copyfrom": + f.Body = append(f.Body, q.BuildCopyFromBody(false)...) + f.Returns = poet.Name("int") default: panic("unknown cmd " + q.Cmd) } @@ -1052,6 +1161,9 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { poet.Return(poet.Await(exec)), ) f.Returns = typeRefNode("sqlalchemy", "engine", "Result") + case ":copyfrom": + f.Body = append(f.Body, q.BuildCopyFromBody(true)...) + f.Returns = poet.Name("int") default: panic("unknown cmd " + q.Cmd) } diff --git a/internal/imports.go b/internal/imports.go index b88c58c..b066b6f 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -161,6 +161,13 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} } } + if q.Cmd == ":copyfrom" { + std["typing.List"] = importSpec{Module: "typing", Name: "List"} + // Add Any if non-struct args + if !(len(q.Args) == 1 && q.Args[0].IsStruct()) { + std["typing.Any"] = importSpec{Module: "typing", Name: "Any"} + } + } queryValueModelImports(q.Ret) for _, qv := range q.Args { queryValueModelImports(qv)