diff --git a/based/database.py b/based/database.py index 12511ab..16e6978 100644 --- a/based/database.py +++ b/based/database.py @@ -58,7 +58,7 @@ def __init__( raise ValueError("Invalid database URL") schema = url_parts[0] - if force_rollback or (schema == "sqlite" and use_lock): + if use_lock and (force_rollback or schema == "sqlite"): self._lock = Lock() if schema == "sqlite": diff --git a/tests/test_backend.py b/tests/test_backend.py index 3c9faaa..b2fe161 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -25,6 +25,27 @@ async def test_database_force_rollback( assert movie is None +async def test_database_force_rollback_with_lock( + table: sqlalchemy.Table, + database_url: str, + gen_movie: typing.Callable[[], typing.Tuple[str, int]], +): + title, year = gen_movie() + + async with based.Database( + database_url, force_rollback=True, use_lock=True, + ) as database: + async with database.session() as session: + query = table.insert().values(title=title, year=year) + await session.execute(query) + + async with based.Database(database_url, force_rollback=True) as database: + async with database.session() as session: + query = table.select().where(table.c.title == title) + movie = await session.fetch_one(query) + assert movie is None + + async def test_database_no_force_rollback( table: sqlalchemy.Table, database_url: str,