Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 11 additions & 13 deletions cashu/core/db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import datetime
import os
import re
import time
from contextlib import asynccontextmanager
from typing import Optional, Union
Expand Down Expand Up @@ -167,6 +166,7 @@ async def get_connection(
conn: Optional[Connection] = None,
lock_table: Optional[str] = None,
lock_select_statement: Optional[str] = None,
lock_parameters: Optional[dict] = None,
lock_timeout: Optional[float] = None,
):
"""Either yield the existing database connection (passthrough) or create a new one.
Expand All @@ -175,6 +175,7 @@ async def get_connection(
conn (Optional[Connection], optional): Connection object. Defaults to None.
lock_table (Optional[str], optional): Table to lock. Defaults to None.
lock_select_statement (Optional[str], optional): Lock select statement. Defaults to None.
lock_parameters (Optional[dict], optional): Parameters for the lock select statement. Defaults to None.
lock_timeout (Optional[float], optional): Lock timeout. Defaults to None.

Yields:
Expand All @@ -187,7 +188,7 @@ async def get_connection(
else:
logger.trace("get_connection: Creating new connection")
async with self.connect(
lock_table, lock_select_statement, lock_timeout
lock_table, lock_select_statement, lock_parameters, lock_timeout
) as new_conn:
yield new_conn

Expand All @@ -196,6 +197,7 @@ async def connect(
self,
lock_table: Optional[str] = None,
lock_select_statement: Optional[str] = None,
lock_parameters: Optional[dict] = None,
lock_timeout: Optional[float] = None,
):
async def _handle_lock_retry(retry_delay, timeout, start_time) -> float:
Expand Down Expand Up @@ -224,7 +226,7 @@ def _is_lock_exception(e):
wconn = Connection(session, txn, self.type, self.name, self.schema)
if lock_table:
await self.acquire_lock(
wconn, lock_table, lock_select_statement
wconn, lock_table, lock_select_statement, lock_parameters
)
logger.trace(
f"> Yielding connection. Lock: {lock_table} - trial {trial} ({random_int})"
Expand Down Expand Up @@ -255,27 +257,23 @@ async def acquire_lock(
wconn: Connection,
lock_table: str,
lock_select_statement: Optional[str] = None,
lock_parameters: Optional[dict] = None,
):
"""Acquire a lock on a table or a row in a table.

Args:
wconn (Connection): Connection object.
lock_table (str): Table to lock.
lock_select_statement (Optional[str], optional):
lock_timeout (Optional[float], optional):

Raises:
Exception: _description_
lock_parameters (Optional[dict], optional): Parameters to pass to the lock select query.
"""
if lock_select_statement:
assert (
len(re.findall(r"^[^=]+='[^']+'$", lock_select_statement)) == 1
), "lock_select_statement must have exactly one {column}='{value}' pattern."
try:
logger.trace(
f"Acquiring lock on {lock_table} with statement {self.lock_table(lock_table, lock_select_statement)}"
f"Acquiring lock on {lock_table} with statement {self.lock_table(lock_table, lock_select_statement)} parameters: {lock_parameters}"
)
await wconn.execute(
self.lock_table(lock_table, lock_select_statement), lock_parameters or {}
)
await wconn.execute(self.lock_table(lock_table, lock_select_statement))
logger.trace(f"Success: Acquired lock on {lock_table}")
return
except Exception as e:
Expand Down
34 changes: 27 additions & 7 deletions cashu/mint/db/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ async def _set_mint_quote_pending(self, quote_id: str) -> MintQuote:
"""
quote: Union[MintQuote, None] = None
async with self.db.get_connection(
lock_table="mint_quotes", lock_select_statement=f"quote='{quote_id}'"
lock_table="mint_quotes",
lock_select_statement="quote = :quote",
lock_parameters={"quote": quote_id},
) as conn:
# get mint quote from db and check if it is already pending
quote = await self.crud.get_mint_quote(
Expand Down Expand Up @@ -177,7 +179,11 @@ async def _unset_mint_quote_pending(
state (MintQuoteState): New state of the mint quote.
"""
quote: Union[MintQuote, None] = None
async with self.db.get_connection(lock_table="mint_quotes") as conn:
async with self.db.get_connection(
lock_table="mint_quotes",
lock_select_statement="quote = :quote",
lock_parameters={"quote": quote_id},
) as conn:
# get mint quote from db and check if it is pending
quote = await self.crud.get_mint_quote(
quote_id=quote_id, db=self.db, conn=conn
Expand Down Expand Up @@ -209,7 +215,8 @@ async def _set_melt_quote_pending(self, quote: MeltQuote) -> MeltQuote:
raise TransactionError("Melt quote doesn't have checking ID.")
async with self.db.get_connection(
lock_table="melt_quotes",
lock_select_statement=f"checking_id='{quote.checking_id}'",
lock_select_statement="checking_id = :checking_id",
lock_parameters={"checking_id": quote.checking_id},
) as conn:
# get all melt quotes with same checking_id from db and check if there is one already pending or paid
quotes_db = await self.crud.get_melt_quotes_by_checking_id(
Expand Down Expand Up @@ -243,7 +250,11 @@ async def _unset_melt_quote_pending(
TransactionError: If the melt quote is not found or not pending.
"""
quote_copy = quote.model_copy()
async with self.db.get_connection(lock_table="melt_quotes") as conn:
async with self.db.get_connection(
lock_table="melt_quotes",
lock_select_statement="quote = :quote",
lock_parameters={"quote": quote.quote},
) as conn:
# get melt quote from db and check if it is pending
quote_db = await self.crud.get_melt_quote(
quote_id=quote.quote, db=self.db, conn=conn
Expand All @@ -263,7 +274,11 @@ async def _unset_melt_quote_pending(
return quote_copy

async def _update_mint_quote_state(self, quote_id: str, state: MintQuoteState):
async with self.db.get_connection(lock_table="mint_quotes") as conn:
async with self.db.get_connection(
lock_table="mint_quotes",
lock_select_statement="quote = :quote",
lock_parameters={"quote": quote_id},
) as conn:
mint_quote = await self.crud.get_mint_quote(
quote_id=quote_id, db=self.db, conn=conn
)
Expand All @@ -286,7 +301,11 @@ async def _update_melt_quote_state(
Raises:
TransactionError: If the melt quote is not found.
"""
async with self.db.get_connection(lock_table="melt_quotes") as conn:
async with self.db.get_connection(
lock_table="melt_quotes",
lock_select_statement="quote = :quote",
lock_parameters={"quote": quote_id},
) as conn:
melt_quote = await self.crud.get_melt_quote(
quote_id=quote_id, db=self.db, conn=conn
)
Expand All @@ -306,7 +325,8 @@ async def _store_melt_quote(self, quote: MeltQuote):
"""
async with self.db.get_connection(
lock_table="melt_quotes",
lock_select_statement=f"checking_id='{quote.checking_id}'",
lock_select_statement="checking_id = :checking_id",
lock_parameters={"checking_id": quote.checking_id},
) as conn:
# get all melt quotes with same checking_id from db and check if there is one already pending or paid
quotes_db = await self.crud.get_melt_quotes_by_checking_id(
Expand Down
20 changes: 14 additions & 6 deletions cashu/mint/ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,10 +446,11 @@ async def get_mint_quote(self, quote_id: str) -> MintQuote:
if status.settled:
# change state to paid in one transaction, it could have been marked paid
# by the invoice listener in the mean time
async with self.db.get_connection(
lock_table="mint_quotes",
lock_select_statement=f"quote='{quote_id}'",
) as conn:
async with self.db.get_connection(
lock_table="mint_quotes",
lock_select_statement="quote = :quote",
lock_parameters={"quote": quote_id},
) as conn:
quote = await self.crud.get_mint_quote(
quote_id=quote_id, db=self.db, conn=conn
)
Expand Down Expand Up @@ -1082,8 +1083,15 @@ async def swap(
await self.db_write._verify_spent_proofs_and_set_pending(
proofs, keysets=self.keysets
)
try:
async with self.db.get_connection(lock_table="proofs_pending") as conn:
try:
Ys = [p.Y for p in proofs]
lock_parameters = {f"y{i}": y for i, y in enumerate(Ys)}
ys_list = ", ".join(f":y{i}" for i in range(len(Ys)))
async with self.db.get_connection(
lock_table="proofs_pending",
lock_select_statement=f"y IN ({ys_list})",
lock_parameters=lock_parameters,
) as conn:
await self._store_blinded_messages(outputs, keyset=keyset, conn=conn)
await self._invalidate_proofs(proofs=proofs, conn=conn)
promises = await self._sign_blinded_messages(outputs, conn)
Expand Down
3 changes: 2 additions & 1 deletion cashu/mint/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ async def invoice_callback_dispatcher(self, checking_id: str) -> None:
logger.debug(f"Invoice callback dispatcher: {checking_id}")
async with self.db.get_connection(
lock_table="mint_quotes",
lock_select_statement=f"checking_id='{checking_id}'",
lock_select_statement="checking_id = :checking_id",
lock_parameters={"checking_id": checking_id},
lock_timeout=5,
) as conn:
quote = await self.crud.get_mint_quote(
Expand Down
1 change: 0 additions & 1 deletion tests/mint/test_mint_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,4 +507,3 @@ async def test_mint_quote_paid_time_update(wallet: Wallet, ledger: Ledger):
assert quote.paid_time >= quote.created_time
# Ensure it's recent (within last minute)
assert quote.paid_time > int(time.time()) - 60

154 changes: 154 additions & 0 deletions tests/mint/test_mint_db_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,3 +805,157 @@ async def test_concurrent_set_melt_quote_pending_same_checking_id(ledger: Ledger
# The error should be about the quote already being pending
error = next(r for r in results if isinstance(r, Exception))
assert "Melt quote already paid or pending." in str(error)

# === CONCURRENCY TESTS FOR PARAMETERIZED ROW LOCKS ===

@pytest.mark.asyncio
@pytest.mark.skipif(
is_github_actions and is_regtest and not is_postgres,
reason=("Fails on GitHub Actions for regtest + SQLite"),
)
async def test_concurrent_set_mint_quote_pending_same_quote(wallet: Wallet, ledger: Ledger):
mint_quote = await wallet.request_mint(64)
await pay_if_regtest(mint_quote.request)
_ = await ledger.get_mint_quote(mint_quote.quote)
# Get quote object
quote = await ledger.crud.get_mint_quote(quote_id=mint_quote.quote, db=ledger.db)

results = await asyncio.gather(
ledger.db_write._set_mint_quote_pending(quote.quote),
ledger.db_write._set_mint_quote_pending(quote.quote),
return_exceptions=True
)
success = sum(1 for r in results if not isinstance(r, Exception))
errors = [r for r in results if isinstance(r, Exception)]
assert success == 1, f"Expected 1 success, got {success}. Errors: {errors}"
assert len(errors) == 1, f"Expected 1 error, got {len(errors)}"
err_str = str(errors[0])
assert "lock" in err_str.lower() or "pending" in err_str.lower() or "locked" in err_str.lower(), f"Unexpected error: {err_str}"


@pytest.mark.asyncio
@pytest.mark.skipif(
is_github_actions and is_regtest and not is_postgres,
reason=("Fails on GitHub Actions for regtest + SQLite"),
)
async def test_concurrent_set_mint_quote_pending_different_quotes(wallet: Wallet, ledger: Ledger):
mint_quote1 = await wallet.request_mint(64)
mint_quote2 = await wallet.request_mint(64)
await pay_if_regtest(mint_quote1.request)
await pay_if_regtest(mint_quote2.request)
_ = await ledger.get_mint_quote(mint_quote1.quote)
_ = await ledger.get_mint_quote(mint_quote2.quote)

results = await asyncio.gather(
ledger.db_write._set_mint_quote_pending(mint_quote1.quote),
ledger.db_write._set_mint_quote_pending(mint_quote2.quote),
return_exceptions=True
)
errors = [r for r in results if isinstance(r, Exception)]
assert len(errors) == 0, f"Expected 0 errors, got: {errors}"


@pytest.mark.asyncio
@pytest.mark.skipif(
is_github_actions and is_regtest and not is_postgres,
reason=("Fails on GitHub Actions for regtest + SQLite"),
)
async def test_concurrent_set_melt_quote_pending_same_quote(wallet: Wallet, ledger: Ledger):
mint_quote = await wallet.request_mint(64)
melt_quote = await ledger.melt_quote(
PostMeltQuoteRequest(request=mint_quote.request, unit="sat")
)
quote_db = await ledger.crud.get_melt_quote(quote_id=melt_quote.quote, db=ledger.db)

results = await asyncio.gather(
ledger.db_write._set_melt_quote_pending(quote_db),
ledger.db_write._set_melt_quote_pending(quote_db),
return_exceptions=True
)
success = sum(1 for r in results if not isinstance(r, Exception))
errors = [r for r in results if isinstance(r, Exception)]
assert success == 1, f"Expected 1 success, got {success}. Errors: {errors}"
assert len(errors) == 1, f"Expected 1 error, got {len(errors)}"
err_str = str(errors[0])
assert "lock" in err_str.lower() or "pending" in err_str.lower() or "locked" in err_str.lower() or "paid" in err_str.lower(), f"Unexpected error: {err_str}"


@pytest.mark.asyncio
@pytest.mark.skipif(
is_github_actions and is_regtest and not is_postgres,
reason=("Fails on GitHub Actions for regtest + SQLite"),
)
async def test_concurrent_set_melt_quote_pending_different_quotes(wallet: Wallet, ledger: Ledger):
mint_quote1 = await wallet.request_mint(64)
mint_quote2 = await wallet.request_mint(64)
melt_quote1 = await ledger.melt_quote(
PostMeltQuoteRequest(request=mint_quote1.request, unit="sat")
)
melt_quote2 = await ledger.melt_quote(
PostMeltQuoteRequest(request=mint_quote2.request, unit="sat")
)
quote_db1 = await ledger.crud.get_melt_quote(quote_id=melt_quote1.quote, db=ledger.db)
quote_db2 = await ledger.crud.get_melt_quote(quote_id=melt_quote2.quote, db=ledger.db)

results = await asyncio.gather(
ledger.db_write._set_melt_quote_pending(quote_db1),
ledger.db_write._set_melt_quote_pending(quote_db2),
return_exceptions=True
)
errors = [r for r in results if isinstance(r, Exception)]
assert len(errors) == 0, f"Expected 0 errors, got: {errors}"


@pytest.mark.asyncio
@pytest.mark.skipif(
is_github_actions and is_regtest and not is_postgres,
reason=("Fails on GitHub Actions for regtest + SQLite"),
)
async def test_concurrent_swap_same_proofs(wallet: Wallet, ledger: Ledger):
mint_quote = await wallet.request_mint(64)
await pay_if_regtest(mint_quote.request)
await wallet.mint(64, quote_id=mint_quote.quote)

secrets, rs, _ = await wallet.generate_n_secrets(2)
outputs, _ = wallet._construct_outputs([32, 32], secrets, rs)

results = await asyncio.gather(
ledger.swap(proofs=wallet.proofs, outputs=outputs),
ledger.swap(proofs=wallet.proofs, outputs=outputs),
return_exceptions=True
)

success = sum(1 for r in results if not isinstance(r, Exception))
errors = [r for r in results if isinstance(r, Exception)]
assert success == 1, f"Expected 1 success, got {success}. Errors: {errors}"
assert len(errors) == 1, f"Expected 1 error, got {len(errors)}"
err_str = str(errors[0])
assert "lock" in err_str.lower() or "pending" in err_str.lower() or "locked" in err_str.lower() or "spent" in err_str.lower(), f"Unexpected error: {err_str}"


@pytest.mark.asyncio
@pytest.mark.skipif(
is_github_actions and is_regtest and not is_postgres,
reason=("Fails on GitHub Actions for regtest + SQLite"),
)
async def test_concurrent_swap_different_proofs(wallet: Wallet, ledger: Ledger):
mint_quote = await wallet.request_mint(64)
await pay_if_regtest(mint_quote.request)
await wallet.mint(64, quote_id=mint_quote.quote, split=[32, 32])

proofs1 = wallet.proofs[:1]
proofs2 = wallet.proofs[1:]

secrets1, rs1, _ = await wallet.generate_n_secrets(1)
outputs1, _ = wallet._construct_outputs([32], secrets1, rs1)

secrets2, rs2, _ = await wallet.generate_n_secrets(1)
outputs2, _ = wallet._construct_outputs([32], secrets2, rs2)

results = await asyncio.gather(
ledger.swap(proofs=proofs1, outputs=outputs1),
ledger.swap(proofs=proofs2, outputs=outputs2),
return_exceptions=True
)
errors = [r for r in results if isinstance(r, Exception)]
assert len(errors) == 0, f"Expected 0 errors, got: {errors}"
Loading