|
5 | 5 | import pathlib |
6 | 6 | from urllib.parse import urljoin |
7 | 7 | from urllib.request import pathname2url |
| 8 | +import threading |
8 | 9 |
|
9 | 10 | # We import all the things we need from Tower. |
10 | 11 | import tower.polars as pl |
11 | 12 | import pyarrow as pa |
12 | 13 | from pyiceberg.catalog.memory import InMemoryCatalog |
| 14 | +from pyiceberg.catalog.sql import SqlCatalog |
| 15 | + |
| 16 | +import concurrent.futures |
13 | 17 |
|
14 | 18 | # Imports the library under test |
15 | 19 | import tower |
@@ -42,6 +46,28 @@ def in_memory_catalog(): |
42 | 46 | pass |
43 | 47 |
|
44 | 48 |
|
@pytest.fixture
def sql_catalog():
    """Yield a SqlCatalog backed by a SQLite database in a fresh temp directory.

    The directory also serves as the Iceberg warehouse location. Cleanup is
    guaranteed via try/finally, so the directory is removed even if catalog
    construction fails before the yield.
    """
    # tempfile.mkdtemp() returns a plain string path and does NOT clean up
    # after itself, hence the explicit rmtree in the finally block below.
    temp_dir = tempfile.mkdtemp()
    abs_path = pathlib.Path(temp_dir).absolute()
    try:
        # The warehouse must be given as a file:// URL.
        file_url = urljoin("file:", pathname2url(str(abs_path)))

        catalog = SqlCatalog(
            "test.sql.catalog",
            **{
                # check_same_thread=False allows the worker threads spawned by
                # the concurrency tests to share the SQLite connection.
                "uri": f"sqlite:///{abs_path}/catalog.db?check_same_thread=False",
                "warehouse": file_url,
            },
        )

        yield catalog
    finally:
        try:
            shutil.rmtree(abs_path)
        except FileNotFoundError:
            pass
| 69 | + |
| 70 | + |
45 | 71 | def test_reading_and_writing_to_tables(in_memory_catalog): |
46 | 72 | schema = pa.schema( |
47 | 73 | [ |
@@ -166,6 +192,135 @@ def test_upsert_to_tables(in_memory_catalog): |
166 | 192 | assert res["age"].item() == 26 |
167 | 193 |
|
168 | 194 |
|
def test_upsert_concurrent_writes_with_retry(sql_catalog):
    """Test that concurrent upserts succeed with retry logic handling conflicts."""
    schema = pa.schema(
        [
            pa.field("ticker", pa.string()),
            pa.field("date", pa.string()),
            pa.field("price", pa.float64()),
        ]
    )

    table = tower.tables("concurrent_test", catalog=sql_catalog).create_if_not_exists(
        schema
    )

    # Seed one row per ticker so each worker upserts a distinct key.
    seed_rows = [
        {"ticker": "AAPL", "date": "2024-01-01", "price": 100.0},
        {"ticker": "GOOGL", "date": "2024-01-01", "price": 200.0},
        {"ticker": "MSFT", "date": "2024-01-01", "price": 300.0},
    ]
    table.insert(pa.Table.from_pylist(seed_rows, schema=schema))

    # Shared counter of table refreshes, used as a proxy for retry activity.
    refresh_calls = {"value": 0}
    counter_guard = threading.Lock()

    def upsert_ticker(ticker: str, new_price: float):
        tbl = tower.tables("concurrent_test", catalog=sql_catalog).load()

        real_refresh = tbl._table.refresh

        # Wrap refresh so every invocation bumps the shared counter.
        def counting_refresh():
            with counter_guard:
                refresh_calls["value"] += 1
            return real_refresh()

        tbl._table.refresh = counting_refresh

        row = {"ticker": ticker, "date": "2024-01-01", "price": new_price}
        tbl.upsert(
            pa.Table.from_pylist([row], schema=schema),
            join_cols=["ticker", "date"],
        )
        return ticker

    updates = [("AAPL", 150.0), ("GOOGL", 250.0), ("MSFT", 350.0)]
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        pending = [executor.submit(upsert_ticker, sym, px) for sym, px in updates]
        results = [f.result() for f in concurrent.futures.as_completed(pending)]

    assert len(results) == 3
    assert (
        refresh_calls["value"] > 0
    ), "Expected at least one retry due to concurrent conflicts"

    # Re-load the table and verify every ticker carries its updated price.
    df = tower.tables("concurrent_test", catalog=sql_catalog).load().read()

    assert len(df) == 3

    ticker_prices = {row["ticker"]: row["price"] for row in df.iter_rows(named=True)}

    assert ticker_prices["AAPL"] == 150.0
    assert ticker_prices["GOOGL"] == 250.0
    assert ticker_prices["MSFT"] == 350.0
| 263 | + |
| 264 | + |
def test_upsert_concurrent_writes_same_row(sql_catalog):
    """Test concurrent upserts to the SAME row - last write wins."""
    schema = pa.schema(
        [
            pa.field("id", pa.int64()),
            pa.field("counter", pa.int64()),
        ]
    )

    table = tower.tables(
        "concurrent_same_row_test", catalog=sql_catalog
    ).create_if_not_exists(schema)

    # Single seed row that all workers will contend over.
    table.insert(pa.Table.from_pylist([{"id": 1, "counter": 0}], schema=schema))

    # Shared counter of table refreshes, used as a proxy for retry activity.
    refresh_calls = {"value": 0}
    counter_guard = threading.Lock()

    def upsert_counter(value: int):
        tbl = tower.tables("concurrent_same_row_test", catalog=sql_catalog).load()

        real_refresh = tbl._table.refresh

        # Wrap refresh so every invocation bumps the shared counter.
        def counting_refresh():
            with counter_guard:
                refresh_calls["value"] += 1
            return real_refresh()

        tbl._table.refresh = counting_refresh

        tbl.upsert(
            pa.Table.from_pylist([{"id": 1, "counter": value}], schema=schema),
            join_cols=["id"],
        )
        return value

    with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
        pending = [executor.submit(upsert_counter, v) for v in range(1, 6)]
        results = [f.result() for f in concurrent.futures.as_completed(pending)]

    assert len(results) == 5

    assert (
        refresh_calls["value"] > 0
    ), "Expected at least one retry due to concurrent conflicts"

    # Exactly one row must survive; its counter is whichever write landed last.
    df = tower.tables("concurrent_same_row_test", catalog=sql_catalog).load().read()

    assert len(df) == 1

    final_counter = df.select("counter").item()
    assert final_counter in [1, 2, 3, 4, 5]
| 322 | + |
| 323 | + |
169 | 324 | def test_delete_from_tables(in_memory_catalog): |
170 | 325 | schema = pa.schema( |
171 | 326 | [ |
|
0 commit comments