From 3dbad621d423e059b6675f09eba501d5ebc650fb Mon Sep 17 00:00:00 2001 From: ansipunk Date: Tue, 11 Feb 2025 03:47:55 +0500 Subject: [PATCH 1/4] Database URL parsing and building --- README.md | 2 +- based/backends/__init__.py | 14 ++++++++-- based/backends/mysql.py | 28 +++++++++++++++++-- based/backends/postgresql.py | 37 ++++++++++++++++++++++-- based/backends/sqlite.py | 2 +- based/database.py | 54 ++++++++++++++++++++++++++++++------ tests/test_mysql.py | 21 ++++++++++++++ tests/test_postgresql.py | 21 ++++++++++++++ 8 files changed, 160 insertions(+), 19 deletions(-) create mode 100644 tests/test_mysql.py create mode 100644 tests/test_postgresql.py diff --git a/README.md b/README.md index e8fa8d2..a8655e9 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ BASED_TEST_DB_URLS='postgresql://postgres:postgres@localhost:5432/postgres,mysql - [x] CI/CD - [x] Building and uploading packages to PyPi - [x] Testing with multiple Python versions -- [ ] Database URL parsing and building +- [x] Database URL parsing and building - [x] MySQL backend - [x] Add comments and docstrings - [x] Add lock for PostgreSQL in `force_rollback` mode and SQLite in both modes diff --git a/based/backends/__init__.py b/based/backends/__init__.py index 2ee6e0f..ac41a8a 100644 --- a/based/backends/__init__.py +++ b/based/backends/__init__.py @@ -3,6 +3,7 @@ import typing from contextlib import asynccontextmanager +from sqlalchemy import URL, make_url from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement @@ -19,9 +20,18 @@ class Backend: _connected: bool = False _connected_before: bool = False - def __init__(self, url: str, *, force_rollback: bool = False) -> None: + def __init__( + self, + url: str | None = None, + *, + host: str | None = None, + port: str | None = None, + username: str | None = None, + password: str | None = None, + database: str | None = None, + force_rollback: bool = False, + ) -> None: """Details of this method should be implementation specific.""" - _ = url self._force_rollback = force_rollback async def _connect(self) -> None: diff --git a/based/backends/mysql.py b/based/backends/mysql.py index 0277d45..4eda375 100644 --- a/based/backends/mysql.py +++ b/based/backends/mysql.py @@ -19,10 +19,32 @@ class MySQL(Backend): _force_rollback_connection: asyncmy.Connection _dialect: Dialect - def __init__(self, url: str, *, force_rollback: bool = False) -> None: # noqa: D107 - self._url = make_url(url) + def __init__( # noqa: D107 + self, + url: str | None = None, + *, + host: str | None = None, + port: str | None = None, + username: str | None = None, + password: str | None = None, + database: str | None = None, + force_rollback: bool = False, + ) -> None: + if url: + self._url = make_url(url) + else: + self._url = URL.create( + username=username, + password=password, + host=host, + port=port, + database=database, + drivername="asyncmy", + query={}, + ) + self._force_rollback = force_rollback - self._dialect = dialect() # type: ignore + self._dialect = dialect() async def _connect(self) -> None: self._pool = await asyncmy.create_pool( diff --git a/based/backends/postgresql.py b/based/backends/postgresql.py index e112f61..fc8fdeb 100644 --- a/based/backends/postgresql.py +++ b/based/backends/postgresql.py @@ -3,6 +3,7 @@ from psycopg import AsyncConnection from psycopg_pool import AsyncConnectionPool +from sqlalchemy import URL, make_url from sqlalchemy.dialects import postgresql from sqlalchemy.engine.interfaces import Dialect @@ -17,10 +18,40 @@ class PostgreSQL(Backend): _force_rollback_connection: AsyncConnection _dialect: Dialect - def __init__(self, url: str, *, force_rollback: bool = False) -> None: # noqa: D107 - self._pool = AsyncConnectionPool(url, open=False) + def __init__( # noqa: D107 + self, + url: str | None = None, + *, + host: str | None = None, + port: str | None = None, + username: str | None = None, + password: str | None = None, + database: str | None = None, + force_rollback: bool = False, + ) -> None: + if url: + self._url = make_url(url) + else: + self._url = URL.create( + username=username, + password=password, + host=host, + port=port, + database=database, + drivername="psycopg", + query={}, + ) + + conninfo = ( + f"user={self._url.username} " + f"password={self._url.password} " + f"host={self._url.host} " + f"port={self._url.port} " + f"dbname={self._url.database}" + ) + self._pool = AsyncConnectionPool(conninfo, open=False) self._force_rollback = force_rollback - self._dialect = postgresql.dialect() # type: ignore + self._dialect = postgresql.dialect() async def _connect(self) -> None: await self._pool.open() diff --git a/based/backends/sqlite.py b/based/backends/sqlite.py index 5762bf6..4bfa321 100644 --- a/based/backends/sqlite.py +++ b/based/backends/sqlite.py @@ -22,7 +22,7 @@ class SQLite(Backend): def __init__(self, url: str, *, force_rollback: bool = False) -> None: # noqa: D107 self._conn = connect(url, isolation_level=None) self._force_rollback = force_rollback - self._dialect = sqlite.dialect() # type: ignore + self._dialect = sqlite.dialect() async def _connect(self) -> None: await self._conn diff --git a/based/database.py b/based/database.py index b4f808e..892c657 100644 --- a/based/database.py +++ b/based/database.py @@ -1,7 +1,7 @@ from asyncio import Lock from contextlib import asynccontextmanager from types import TracebackType -from typing import AsyncGenerator, Optional, Type +from typing import AsyncGenerator, Literal, Optional, Type from based.backends import Backend, Session @@ -15,8 +15,14 @@ class Database: def __init__( self, - url: str, + url: str | None = None, *, + host: str | None = None, + port: str | None = None, + username: str | None = None, + password: str | None = None, + database: str | None = None, + schema: Literal["postgresql", "mysql", "sqlite"] | None = None, force_rollback: bool = False, use_lock: bool = False, ) -> None: @@ -36,7 +42,20 @@ def __init__( Args: url: Database URL should be a URL defined by RFC 1738, containing the correct - schema like `postgresql://user:password@host:port/database`. + schema like `postgresql://user:password@host:port/database`. Can be + omitted in favor of passing parameters separately. + username: + Database username. + password: + Database password. + host: + Database host. + port: + Database port. + database: + Database name. + schema: + Used database schema. Can be `postgresql` or `mysql`. force_rollback: If this flag is set to True, then all the queries to the database will be made in one single transaction which will be rolled back when the @@ -53,10 +72,11 @@ def __init__( Can be raised when an invalid database URL is provided or the database schema is not supported. """ - url_parts = url.split("://") - if len(url_parts) != 2: - raise ValueError("Invalid database URL") - schema = url_parts[0] + if url is not None: + url_parts = url.split("://") + if len(url_parts) != 2: + raise ValueError("Invalid database URL") + schema = url_parts[0] if use_lock and (force_rollback or schema == "sqlite"): self._lock = Lock() @@ -72,11 +92,27 @@ def __init__( elif schema == "postgresql": from based.backends.postgresql import PostgreSQL - self._backend = PostgreSQL(url, force_rollback=force_rollback) + self._backend = PostgreSQL( + url=url, + username=username, + password=password, + host=host, + port=port, + database=database, + force_rollback=force_rollback, + ) elif schema == "mysql": from based.backends.mysql import MySQL - self._backend = MySQL(url, force_rollback=force_rollback) + self._backend = MySQL( + url=url, + username=username, + password=password, + host=host, + port=port, + database=database, + force_rollback=force_rollback, + ) else: raise ValueError(f"Unknown database schema: {schema}") diff --git a/tests/test_mysql.py b/tests/test_mysql.py new file mode 100644 index 0000000..ce27cfc --- /dev/null +++ b/tests/test_mysql.py @@ -0,0 +1,21 @@ +import sqlalchemy as sa + +import based + + +async def test_mysql_url_building(database_url: str): + if not database_url.startswith("mysql"): + return + + url = sa.make_url(database_url) + + async with based.Database( + username=url.username, + password=url.password, + host=url.host, + port=url.port, + database=url.database, + schema="mysql", + ) as database: + async with database.session() as session: + await session.execute("SELECT 1;") diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py new file mode 100644 index 0000000..2f77150 --- /dev/null +++ b/tests/test_postgresql.py @@ -0,0 +1,21 @@ +import sqlalchemy as sa + +import based + + +async def test_postgresql_url_building(database_url: str): + if not database_url.startswith("postgresql"): + return + + url = sa.make_url(database_url) + + async with based.Database( + username=url.username, + password=url.password, + host=url.host, + port=url.port, + database=url.database, + schema="postgresql", + ) as database: + async with database.session() as session: + await session.execute("SELECT 1;") From a82ea1d89ae83aa7560734d1e87b956b2cbe2fbe Mon Sep 17 00:00:00 2001 From: ansipunk Date: Tue, 11 Feb 2025 03:48:32 +0500 Subject: [PATCH 2/4] Bump version --- based/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/based/__init__.py b/based/__init__.py index 5a3c7a8..fbc8eba 100644 --- a/based/__init__.py +++ b/based/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.0post1" +__version__ = "0.6.0" from based.backends import Session from based.database import Database From 3a51e11877cce6ab9616df78f40a827eb5443602 Mon Sep 17 00:00:00 2001 From: ansipunk Date: Tue, 11 Feb 2025 03:50:31 +0500 Subject: [PATCH 3/4] Fix formatting --- based/backends/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/based/backends/__init__.py b/based/backends/__init__.py index ac41a8a..7d05a65 100644 --- a/based/backends/__init__.py +++ b/based/backends/__init__.py @@ -3,7 +3,6 @@ import typing from contextlib import asynccontextmanager -from sqlalchemy import URL, make_url from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql import ClauseElement From 93ece9d6ec10e51a36bef55f669f03995a123e58 Mon Sep 17 00:00:00 2001 From: ansipunk Date: Tue, 11 Feb 2025 03:54:52 +0500 Subject: [PATCH 4/4] 3.8, 3.9 friendly annotations --- based/backends/__init__.py | 12 ++++++------ based/backends/mysql.py | 12 ++++++------ based/backends/postgresql.py | 12 ++++++------ based/database.py | 14 +++++++------- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/based/backends/__init__.py b/based/backends/__init__.py index 7d05a65..5237e71 100644 --- a/based/backends/__init__.py +++ b/based/backends/__init__.py @@ -21,13 +21,13 @@ class Backend: def __init__( self, - url: str | None = None, + url: typing.Optional[str] = None, *, - host: str | None = None, - port: str | None = None, - username: str | None = None, - password: str | None = None, - database: str | None = None, + host: typing.Optional[str] = None, + port: typing.Optional[str] = None, + username: typing.Optional[str] = None, + password: typing.Optional[str] = None, + database: typing.Optional[str] = None, force_rollback: bool = False, ) -> None: """Details of this method should be implementation specific.""" diff --git a/based/backends/mysql.py b/based/backends/mysql.py index 4eda375..c1cee19 100644 --- a/based/backends/mysql.py +++ b/based/backends/mysql.py @@ -21,13 +21,13 @@ class MySQL(Backend): def __init__( # noqa: D107 self, - url: str | None = None, + url: typing.Optional[str] = None, *, - host: str | None = None, - port: str | None = None, - username: str | None = None, - password: str | None = None, - database: str | None = None, + host: typing.Optional[str] = None, + port: typing.Optional[str] = None, + username: typing.Optional[str] = None, + password: typing.Optional[str] = None, + database: typing.Optional[str] = None, force_rollback: bool = False, ) -> None: if url: diff --git a/based/backends/postgresql.py b/based/backends/postgresql.py index fc8fdeb..9c6557c 100644 --- a/based/backends/postgresql.py +++ b/based/backends/postgresql.py @@ -20,13 +20,13 @@ class PostgreSQL(Backend): def __init__( # noqa: D107 self, - url: str | None = None, + url: typing.Optional[str] = None, *, - host: str | None = None, - port: str | None = None, - username: str | None = None, - password: str | None = None, - database: str | None = None, + host: typing.Optional[str] = None, + port: typing.Optional[str] = None, + username: typing.Optional[str] = None, + password: typing.Optional[str] = None, + database: typing.Optional[str] = None, force_rollback: bool = False, ) -> None: if url: diff --git a/based/database.py b/based/database.py index 892c657..d146134 100644 --- a/based/database.py +++ b/based/database.py @@ -15,14 +15,14 @@ class Database: def __init__( self, - url: str | None = None, + url: Optional[str] = None, *, - host: str | None = None, - port: str | None = None, - username: str | None = None, - password: str | None = None, - database: str | None = None, - schema: Literal["postgresql", "mysql", "sqlite"] | None = None, + host: Optional[str] = None, + port: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + database: Optional[str] = None, + schema: Optional[Literal["postgresql", "mysql", "sqlite"]] = None, force_rollback: bool = False, use_lock: bool = False, ) -> None: