Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.0-alpha-10
current_version = 0.4.0-alpha-11
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(-(?P<release>.*)-(?P<build>\d+))?
serialize =
{major}.{minor}.{patch}-{release}-{build}
Expand Down
16 changes: 11 additions & 5 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@ jobs:
publish:
runs-on: ubuntu-latest
env:
POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN_PYPI }}
UV_PUBLISH_TOKEN: ${{ secrets.POETRY_PYPI_TOKEN_PYPI }}
steps:
- uses: actions/checkout@v4
- run: pipx install poetry

- uses: astral-sh/setup-uv@v5
with:
version: "0.6.6"
enable-cache: true
python-version: '3.12'

- uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: poetry
- run: poetry build
- run: poetry publish

- run: uv build
- run: uv publish
21 changes: 14 additions & 7 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
- '3.10'
- '3.11'
- '3.12'
- '3.13'
runs-on: ubuntu-latest
services:
postgres:
Expand All @@ -26,16 +27,22 @@ jobs:
PGPORT: 5432
steps:
- uses: actions/checkout@v4
- run: pipx install poetry

- uses: astral-sh/setup-uv@v5
with:
version: "0.6.6"
enable-cache: true
python-version: ${{ matrix.python-version }}

- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: poetry
- run: poetry install
- run: poetry run ruff check
- run: poetry run ruff format --diff

- run: uv sync
- run: uv run ruff check
- run: uv run ruff format --diff
if: success() || failure()
- run: poetry run mypy sql_athame/**.py tests/**.py
- run: uv run mypy sql_athame/**.py tests/**.py
if: success() || failure()
- run: poetry run pytest
- run: uv run pytest
if: success() || failure()
1,255 changes: 0 additions & 1,255 deletions poetry.lock

This file was deleted.

6 changes: 0 additions & 6 deletions poetry.toml

This file was deleted.

59 changes: 32 additions & 27 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,37 +1,43 @@
[tool.poetry]
[project]
authors = [
{name = "Brian Downing", email = "bdowning@lavos.net"},
]
license = {text = "MIT"}
requires-python = "<4.0,>=3.9"
dependencies = [
"typing-extensions",
]
name = "sql-athame"
version = "0.4.0-alpha-10"
version = "0.4.0-alpha-11"
description = "Python tool for slicing and dicing SQL"
authors = ["Brian Downing <bdowning@lavos.net>"]
license = "MIT"
readme = "README.md"

[project.urls]
homepage = "https://github.com/bdowning/sql-athame"
repository = "https://github.com/bdowning/sql-athame"

[tool.poetry.extras]
asyncpg = ["asyncpg"]

[tool.poetry.dependencies]
python = "^3.9"
asyncpg = { version = "*", optional = true }
typing-extensions = "*"
[project.optional-dependencies]
asyncpg = [
"asyncpg",
]

[tool.poetry.group.dev.dependencies]
pytest = "*"
mypy = "*"
flake8 = "*"
ipython = "*"
pytest-cov = "*"
bump2version = "*"
asyncpg = "*"
pytest-asyncio = "*"
grip = "*"
SQLAlchemy = "*"
ruff = "*"
[dependency-groups]
dev = [
"SQLAlchemy",
"asyncpg",
"bump2version",
"flake8",
"grip",
"ipython",
"mypy",
"pytest",
"pytest-asyncio",
"pytest-cov",
"ruff",
]

[build-system]
requires = ["poetry>=0.12"]
build-backend = "poetry.masonry.api"
[tool.setuptools]
packages = ["sql_athame"]

[tool.ruff]
target-version = "py39"
Expand Down Expand Up @@ -62,7 +68,6 @@ ignore = [
"E501", # line too long
"E721", # type checks, currently broken
"ISC001", # conflicts with ruff format
"PT004", # Fixture `...` does not return anything, add leading underscore
"RET505", # Unnecessary `else` after `return` statement
"RET506", # Unnecessary `else` after `raise` statement
]
Expand Down
14 changes: 7 additions & 7 deletions run
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,26 @@ usage=()

usage+=(" $0 tests - run tests")
tests() {
poetry run pytest "$@"
uv run pytest "$@"
lint
}

usage+=(" $0 refmt - reformat code")
refmt() {
poetry run ruff check --select I --fix
poetry run ruff format
uv run ruff check --select I --fix
uv run ruff format
}

usage+=(" $0 lint - run linting")
lint() {
poetry run ruff check
poetry run ruff format --diff
poetry run mypy sql_athame/**.py tests/**.py
uv run ruff check
uv run ruff format --diff
uv run mypy sql_athame/**.py tests/**.py
}

usage+=(" $0 bump2version {major|minor|patch} - bump version number")
bump2version() {
poetry run bump2version "$@"
uv run bump2version "$@"
}

cmd=$1
Expand Down
104 changes: 85 additions & 19 deletions sql_athame/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class ModelBase:
_cache: dict[tuple, Any]
table_name: str
primary_key_names: tuple[str, ...]
array_safe_insert: bool
insert_multiple_mode: str

def __init_subclass__(
cls,
Expand All @@ -169,12 +169,9 @@ def __init_subclass__(
):
cls._cache = {}
cls.table_name = table_name
if insert_multiple_mode == "array_safe":
cls.array_safe_insert = True
elif insert_multiple_mode == "unnest":
cls.array_safe_insert = False
else:
if insert_multiple_mode not in ("array_safe", "unnest", "executemany"):
raise ValueError("Unknown `insert_multiple_mode`")
cls.insert_multiple_mode = insert_multiple_mode
if isinstance(primary_key, str):
cls.primary_key_names = (primary_key,)
else:
Expand Down Expand Up @@ -357,19 +354,37 @@ def select_sql(
return query

@classmethod
async def select_cursor(
async def cursor_from(
cls: type[T],
connection: Connection,
query: Fragment,
prefetch: int = 1000,
) -> AsyncGenerator[T, None]:
async for row in connection.cursor(*query, prefetch=prefetch):
yield cls.from_mapping(row)

@classmethod
def select_cursor(
cls: type[T],
connection: Connection,
order_by: Union[FieldNames, str] = (),
for_update: bool = False,
where: Where = (),
prefetch: int = 1000,
) -> AsyncGenerator[T, None]:
async for row in connection.cursor(
*cls.select_sql(order_by=order_by, for_update=for_update, where=where),
return cls.cursor_from(
connection,
cls.select_sql(order_by=order_by, for_update=for_update, where=where),
prefetch=prefetch,
):
yield cls.from_mapping(row)
)

@classmethod
async def fetch_from(
cls: type[T],
connection_or_pool: Union[Connection, Pool],
query: Fragment,
) -> list[T]:
return [cls.from_mapping(row) for row in await connection_or_pool.fetch(*query)]

@classmethod
async def select(
Expand All @@ -379,12 +394,10 @@ async def select(
for_update: bool = False,
where: Where = (),
) -> list[T]:
return [
cls.from_mapping(row)
for row in await connection_or_pool.fetch(
*cls.select_sql(order_by=order_by, for_update=for_update, where=where)
)
]
return await cls.fetch_from(
connection_or_pool,
cls.select_sql(order_by=order_by, for_update=for_update, where=where),
)

@classmethod
def create_sql(cls: type[T], **kwargs: Any) -> Fragment:
Expand Down Expand Up @@ -506,6 +519,37 @@ def insert_multiple_array_safe_sql(cls: type[T], rows: Iterable[T]) -> Fragment:
),
)

@classmethod
def insert_multiple_executemany_chunk_sql(
cls: type[T], chunk_size: int
) -> Fragment:
def generate() -> Fragment:
columns = len(cls.column_info())
values = ", ".join(
f"({', '.join(f'${i}' for i in chunk)})"
for chunk in chunked(range(1, columns * chunk_size + 1), columns)
)
return sql(
"INSERT INTO {table} ({fields}) VALUES {values}",
table=cls.table_name_sql(),
fields=sql.list(cls.field_names_sql()),
values=sql.literal(values),
).flatten()

return cls._cached(
("insert_multiple_executemany_chunk", chunk_size),
generate,
)

@classmethod
async def insert_multiple_executemany(
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
) -> None:
args = [r.field_values() for r in rows]
query = cls.insert_multiple_executemany_chunk_sql(1).query()[0]
if args:
await connection_or_pool.executemany(query, args)

@classmethod
async def insert_multiple_unnest(
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
Expand All @@ -527,11 +571,28 @@ async def insert_multiple_array_safe(
async def insert_multiple(
cls: type[T], connection_or_pool: Union[Connection, Pool], rows: Iterable[T]
) -> str:
if cls.array_safe_insert:
if cls.insert_multiple_mode == "executemany":
await cls.insert_multiple_executemany(connection_or_pool, rows)
return "INSERT"
elif cls.insert_multiple_mode == "array_safe":
return await cls.insert_multiple_array_safe(connection_or_pool, rows)
else:
return await cls.insert_multiple_unnest(connection_or_pool, rows)

@classmethod
async def upsert_multiple_executemany(
cls: type[T],
connection_or_pool: Union[Connection, Pool],
rows: Iterable[T],
insert_only: FieldNamesSet = (),
) -> None:
args = [r.field_values() for r in rows]
query = cls.upsert_sql(
cls.insert_multiple_executemany_chunk_sql(1), exclude=insert_only
).query()[0]
if args:
await connection_or_pool.executemany(query, args)

@classmethod
async def upsert_multiple_unnest(
cls: type[T],
Expand Down Expand Up @@ -566,7 +627,12 @@ async def upsert_multiple(
rows: Iterable[T],
insert_only: FieldNamesSet = (),
) -> str:
if cls.array_safe_insert:
if cls.insert_multiple_mode == "executemany":
await cls.upsert_multiple_executemany(
connection_or_pool, rows, insert_only=insert_only
)
return "INSERT"
elif cls.insert_multiple_mode == "array_safe":
return await cls.upsert_multiple_array_safe(
connection_or_pool, rows, insert_only=insert_only
)
Expand Down
5 changes: 3 additions & 2 deletions tests/test_asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,14 @@ class Test(ModelBase, table_name="test", primary_key="id"):
assert db.updated == new.updated


async def test_replace_multiple_arrays(conn):
@pytest.mark.parametrize("insert_multiple_mode", ["array_safe", "executemany"])
async def test_replace_multiple_arrays(conn, insert_multiple_mode):
@dataclass(order=True)
class Test(
ModelBase,
table_name="test",
primary_key="id",
insert_multiple_mode="array_safe",
insert_multiple_mode=insert_multiple_mode,
):
id: int
a: Annotated[list[int], ColumnInfo(type="INT[]")]
Expand Down
8 changes: 8 additions & 0 deletions tests/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ class Test(ModelBase, table_name="table"):
"hi",
]

assert list(Test.insert_multiple_executemany_chunk_sql(1)) == [
'INSERT INTO "table" ("foo", "bar") VALUES ($1, $2)'
]

assert list(Test.insert_multiple_executemany_chunk_sql(3)) == [
'INSERT INTO "table" ("foo", "bar") VALUES ($1, $2), ($3, $4), ($5, $6)'
]

assert sql(
"INSERT INTO table ({}) VALUES ({})",
sql(",").join(t.field_names_sql()),
Expand Down
Loading