
Commit 46d637f

add upsert retries for tower tables for failed commits (#178)
* add upsert retries for tower tables for failed commits
* tests for commits retry
* adding retry tracking and assertion

1 parent: d1665e3

2 files changed: 203 additions & 18 deletions
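For context, a minimal sketch of how a caller might use the new retry parameters, assuming the tower.tables(...).load().upsert(...) API exercised in the tests below; the table name, schema, data, and catalog configuration are illustrative only and not part of this commit:

import pyarrow as pa
import tower

# Illustrative data only; any schema matching the target table works.
prices = pa.Table.from_pylist(
    [{"ticker": "AAPL", "date": "2024-01-01", "price": 150.0}],
)

# Assumes a table created elsewhere (e.g. via create_if_not_exists) and a
# catalog resolved by the Tower runtime.
table = tower.tables("prices").load()

# New in this commit: on CommitFailedException the upsert refreshes the table
# metadata and retries, up to max_retries attempts, sleeping
# retry_delay_seconds between attempts before re-raising.
table.upsert(
    prices,
    join_cols=["ticker", "date"],
    max_retries=5,            # default
    retry_delay_seconds=0.5,  # default
)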


src/tower/_tables.py

Lines changed: 48 additions & 18 deletions
@@ -2,6 +2,7 @@
 from dataclasses import dataclass

 from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.exceptions import CommitFailedException

 TTable = TypeVar("TTable", bound="Table")

@@ -25,6 +26,9 @@
         namespace_or_default,
     )

+import time
+import random
+

 @dataclass
 class RowsAffectedInformation:
@@ -178,24 +182,38 @@ def insert(self, data: pa.Table) -> TTable:
         self._stats.inserts += data.num_rows
         return self

-    def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTable:
+    def upsert(
+        self,
+        data: pa.Table,
+        join_cols: Optional[list[str]] = None,
+        max_retries: int = 5,
+        retry_delay_seconds: float = 0.5,
+    ) -> TTable:
         """
-        Performs an upsert operation (update or insert) on the Iceberg table.
+        Performs an upsert operation (update or insert) on the Iceberg table. In case of commit conflicts, reloads the metadata and retries.

         This method will:
         - Update existing rows if they match the join columns
         - Insert new rows if no match is found
+        - Retry for max_retries if commits fail
         All operations are case-sensitive by default.

         Args:
             data (pa.Table): The data to upsert into the table. The schema of this table
                 must match the schema of the target table.
             join_cols (Optional[list[str]]): The columns that form the key to match rows on.
                 If not provided, all columns will be used for matching.
+            max_retries (int): Maximum number of retry attempts on commit conflicts.
+                Defaults to 5.
+            retry_delay_seconds (float): Wait time in seconds between retries.
+                Defaults to 0.5 seconds.

         Returns:
             TTable: The table instance with the upserted rows, allowing for method chaining.

+        Raises:
+            CommitFailedException: If all retry attempts are exhausted.
+
         Note:
         - The operation is always case-sensitive
         - When a match is found, all columns are updated
@@ -217,22 +235,34 @@ def upsert(self, data: pa.Table, join_cols: Optional[list[str]] = None) -> TTabl
         >>> print(f"Updated {stats.updates} rows")
         >>> print(f"Inserted {stats.inserts} rows")
         """
-        res = self._table.upsert(
-            data,
-            join_cols=join_cols,
-            # All upserts will always be case sensitive. Perhaps we'll add this
-            # as a parameter in the future?
-            case_sensitive=True,
-            # These are the defaults, but we're including them to be complete.
-            when_matched_update_all=True,
-            when_not_matched_insert_all=True,
-        )
-
-        # Update the stats with the results of the relevant upsert.
-        self._stats.updates += res.rows_updated
-        self._stats.inserts += res.rows_inserted
-
-        return self
+        last_exception = None
+
+        for attempt in range(max_retries + 1):
+            try:
+                if attempt > 0:
+                    self._table.refresh()
+
+                res = self._table.upsert(
+                    data,
+                    join_cols=join_cols,
+                    # All upserts will always be case sensitive. Perhaps we'll add this
+                    # as a parameter in the future?
+                    case_sensitive=True,
+                    # These are the defaults, but we're including them to be complete.
+                    when_matched_update_all=True,
+                    when_not_matched_insert_all=True,
+                )
+
+                self._stats.updates += res.rows_updated
+                self._stats.inserts += res.rows_inserted
+                return self
+
+            except CommitFailedException as e:
+                last_exception = e
+                if attempt < max_retries:
+                    time.sleep(retry_delay_seconds)
+
+        raise last_exception

     def delete(self, filters: Union[str, List[pc.Expression]]) -> TTable:
         """

tests/tower/test_tables.py

Lines changed: 155 additions & 0 deletions
@@ -5,11 +5,15 @@
 import pathlib
 from urllib.parse import urljoin
 from urllib.request import pathname2url
+import threading

 # We import all the things we need from Tower.
 import tower.polars as pl
 import pyarrow as pa
 from pyiceberg.catalog.memory import InMemoryCatalog
+from pyiceberg.catalog.sql import SqlCatalog
+
+import concurrent.futures

 # Imports the library under test
 import tower
@@ -42,6 +46,28 @@ def in_memory_catalog():
         pass


+@pytest.fixture
+def sql_catalog():
+    temp_dir = tempfile.mkdtemp()  # ← Returns string path, no auto-cleanup
+    abs_path = pathlib.Path(temp_dir).absolute()
+    file_url = urljoin("file:", pathname2url(str(abs_path)))
+
+    catalog = SqlCatalog(
+        "test.sql.catalog",
+        **{
+            "uri": f"sqlite:///{abs_path}/catalog.db?check_same_thread=False",
+            "warehouse": file_url,
+        },
+    )
+
+    yield catalog
+
+    try:
+        shutil.rmtree(abs_path)
+    except FileNotFoundError:
+        pass
+
+
 def test_reading_and_writing_to_tables(in_memory_catalog):
     schema = pa.schema(
         [
@@ -166,6 +192,135 @@ def test_upsert_to_tables(in_memory_catalog):
     assert res["age"].item() == 26


+def test_upsert_concurrent_writes_with_retry(sql_catalog):
+    """Test that concurrent upserts succeed with retry logic handling conflicts."""
+    schema = pa.schema(
+        [
+            pa.field("ticker", pa.string()),
+            pa.field("date", pa.string()),
+            pa.field("price", pa.float64()),
+        ]
+    )
+
+    ref = tower.tables("concurrent_test", catalog=sql_catalog)
+    table = ref.create_if_not_exists(schema)
+
+    initial_data = pa.Table.from_pylist(
+        [
+            {"ticker": "AAPL", "date": "2024-01-01", "price": 100.0},
+            {"ticker": "GOOGL", "date": "2024-01-01", "price": 200.0},
+            {"ticker": "MSFT", "date": "2024-01-01", "price": 300.0},
+        ],
+        schema=schema,
+    )
+    table.insert(initial_data)
+
+    retry_count = {"value": 0}
+    retry_lock = threading.Lock()
+
+    def upsert_ticker(ticker: str, new_price: float):
+        t = tower.tables("concurrent_test", catalog=sql_catalog).load()
+
+        original_refresh = t._table.refresh
+
+        def tracked_refresh():
+            with retry_lock:
+                retry_count["value"] += 1
+            return original_refresh()
+
+        t._table.refresh = tracked_refresh
+
+        data = pa.Table.from_pylist(
+            [{"ticker": ticker, "date": "2024-01-01", "price": new_price}],
+            schema=schema,
+        )
+        t.upsert(data, join_cols=["ticker", "date"])
+        return ticker
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
+        futures = [
+            executor.submit(upsert_ticker, "AAPL", 150.0),
+            executor.submit(upsert_ticker, "GOOGL", 250.0),
+            executor.submit(upsert_ticker, "MSFT", 350.0),
+        ]
+        results = [f.result() for f in concurrent.futures.as_completed(futures)]
+
+    assert len(results) == 3
+    assert (
+        retry_count["value"] > 0
+    ), "Expected at least one retry due to concurrent conflicts"
+
+    final_table = tower.tables("concurrent_test", catalog=sql_catalog).load()
+    df = final_table.read()
+
+    assert len(df) == 3
+
+    ticker_prices = {row["ticker"]: row["price"] for row in df.iter_rows(named=True)}
+
+    assert ticker_prices["AAPL"] == 150.0
+    assert ticker_prices["GOOGL"] == 250.0
+    assert ticker_prices["MSFT"] == 350.0
+
+
+def test_upsert_concurrent_writes_same_row(sql_catalog):
+    """Test concurrent upserts to the SAME row - last write wins."""
+    schema = pa.schema(
+        [
+            pa.field("id", pa.int64()),
+            pa.field("counter", pa.int64()),
+        ]
+    )
+
+    ref = tower.tables("concurrent_same_row_test", catalog=sql_catalog)
+    table = ref.create_if_not_exists(schema)
+
+    initial_data = pa.Table.from_pylist(
+        [{"id": 1, "counter": 0}],
+        schema=schema,
+    )
+    table.insert(initial_data)
+
+    retry_count = {"value": 0}
+    retry_lock = threading.Lock()
+
+    def upsert_counter(value: int):
+        t = tower.tables("concurrent_same_row_test", catalog=sql_catalog).load()
+
+        original_refresh = t._table.refresh
+
+        def tracked_refresh():
+            with retry_lock:
+                retry_count["value"] += 1
+            return original_refresh()
+
+        t._table.refresh = tracked_refresh
+
+        data = pa.Table.from_pylist(
+            [{"id": 1, "counter": value}],
+            schema=schema,
+        )
+        t.upsert(data, join_cols=["id"])
+        return value
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+        futures = [executor.submit(upsert_counter, i) for i in range(1, 6)]
+        results = [f.result() for f in concurrent.futures.as_completed(futures)]
+
+    assert len(results) == 5
+
+    assert (
+        retry_count["value"] > 0
+    ), "Expected at least one retry due to concurrent conflicts"
+
+    final_table = tower.tables("concurrent_same_row_test", catalog=sql_catalog).load()
+    df = final_table.read()
+
+    assert len(df) == 1
+
+    final_counter = df.select("counter").item()
+    assert final_counter in [1, 2, 3, 4, 5]
+
+
 def test_delete_from_tables(in_memory_catalog):
     schema = pa.schema(
         [
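The tests above count retries by hand-wrapping the private _table.refresh attribute. An equivalent hypothetical variant (not part of this commit) could use unittest.mock to spy on the same call:

from unittest import mock

def upsert_with_refresh_count(t, data, join_cols):
    # Hypothetical helper: wrap Table.refresh with a mock spy so call_count
    # records how many retry attempts refreshed the table metadata.
    with mock.patch.object(t._table, "refresh", wraps=t._table.refresh) as spy:
        t.upsert(data, join_cols=join_cols)
    return spy.call_count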
