Skip to content

Commit 4de1f30

Browse files
committed
Add flag for deny split transaction
1 parent 9c456d4 commit 4de1f30

File tree

6 files changed

+231
-7
lines changed

6 files changed

+231
-7
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
* Flag for deny split transaction
2+
13
## 2.12.3 ##
24
* Add six package to requirements
35
* Fixed error while passing date parameter in execute

tests/aio/test_tx.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ async def test_tx_snapshot_ro(driver, database):
8585

8686
await ro_tx.commit()
8787

88+
ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly())
8889
with pytest.raises(ydb.issues.GenericError) as exc_info:
8990
await ro_tx.execute("UPDATE `test` SET value = value + 1")
9091
assert "read only transaction" in exc_info.value.message
@@ -94,3 +95,62 @@ async def test_tx_snapshot_ro(driver, database):
9495
commit_tx=True,
9596
)
9697
assert data[0].rows == [{"value": 2}]
98+
99+
100+
@pytest.mark.asyncio
101+
async def test_split_transactions_deny_split(driver, table_name):
102+
async with ydb.aio.SessionPool(driver, 1) as pool:
103+
104+
async def check_transaction(s: ydb.aio.table.Session):
105+
async with s.transaction(deny_split_transactions=True) as tx:
106+
await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
107+
await tx.commit()
108+
109+
with pytest.raises(RuntimeError):
110+
await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
111+
112+
await tx.commit()
113+
114+
async with s.transaction() as tx:
115+
rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
116+
assert rs[0].rows[0].cnt == 1
117+
118+
await pool.retry_operation(check_transaction)
119+
120+
121+
@pytest.mark.asyncio
122+
async def test_split_transactions_allow_split(driver, table_name):
123+
async with ydb.aio.SessionPool(driver, 1) as pool:
124+
125+
async def check_transaction(s: ydb.aio.table.Session):
126+
async with s.transaction(deny_split_transactions=False) as tx:
127+
await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
128+
await tx.commit()
129+
130+
await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
131+
await tx.commit()
132+
133+
async with s.transaction() as tx:
134+
rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
135+
assert rs[0].rows[0].cnt == 2
136+
137+
await pool.retry_operation(check_transaction)
138+
139+
140+
@pytest.mark.asyncio
141+
async def test_split_transactions_default(driver, table_name):
142+
async with ydb.aio.SessionPool(driver, 1) as pool:
143+
144+
async def check_transaction(s: ydb.aio.table.Session):
145+
async with s.transaction() as tx:
146+
await tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
147+
await tx.commit()
148+
149+
await tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
150+
await tx.commit()
151+
152+
async with s.transaction() as tx:
153+
rs = await tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
154+
assert rs[0].rows[0].cnt == 2
155+
156+
await pool.retry_operation(check_transaction)

tests/conftest.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,30 @@ async def driver_sync(endpoint, database, event_loop):
111111
yield driver
112112

113113
driver.stop(timeout=10)
114+
115+
116+
@pytest.fixture()
117+
def table_name(driver_sync, database):
118+
table_name = "table"
119+
120+
with ydb.SessionPool(driver_sync) as pool:
121+
122+
def create_table(s):
123+
try:
124+
s.drop_table(database + "/" + table_name)
125+
except ydb.SchemeError:
126+
pass
127+
128+
s.execute_scheme(
129+
"""
130+
CREATE TABLE %s (
131+
id Int64 NOT NULL,
132+
i64Val Int64,
133+
PRIMARY KEY(id)
134+
)
135+
"""
136+
% table_name
137+
)
138+
139+
pool.retry_operation_sync(create_table)
140+
return table_name

tests/table/test_tx.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def test_tx_snapshot_ro(driver_sync, database):
8080

8181
ro_tx.commit()
8282

83+
ro_tx = session.transaction(tx_mode=ydb.SnapshotReadOnly())
8384
with pytest.raises(ydb.issues.GenericError) as exc_info:
8485
ro_tx.execute("UPDATE `test` SET value = value + 1")
8586
assert "read only transaction" in exc_info.value.message
@@ -89,3 +90,59 @@ def test_tx_snapshot_ro(driver_sync, database):
8990
commit_tx=True,
9091
)
9192
assert data[0].rows == [{"value": 2}]
93+
94+
95+
def test_split_transactions_deny_split(driver_sync, table_name):
96+
with ydb.SessionPool(driver_sync, 1) as pool:
97+
98+
def check_transaction(s: ydb.table.Session):
99+
with s.transaction(deny_split_transactions=True) as tx:
100+
tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
101+
tx.commit()
102+
103+
with pytest.raises(RuntimeError):
104+
tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
105+
106+
tx.commit()
107+
108+
with s.transaction() as tx:
109+
rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
110+
assert rs[0].rows[0].cnt == 1
111+
112+
pool.retry_operation_sync(check_transaction)
113+
114+
115+
def test_split_transactions_allow_split(driver_sync, table_name):
116+
with ydb.SessionPool(driver_sync, 1) as pool:
117+
118+
def check_transaction(s: ydb.table.Session):
119+
with s.transaction(deny_split_transactions=False) as tx:
120+
tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
121+
tx.commit()
122+
123+
tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
124+
tx.commit()
125+
126+
with s.transaction() as tx:
127+
rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
128+
assert rs[0].rows[0].cnt == 2
129+
130+
pool.retry_operation_sync(check_transaction)
131+
132+
133+
def test_split_transactions_default(driver_sync, table_name):
134+
with ydb.SessionPool(driver_sync, 1) as pool:
135+
136+
def check_transaction(s: ydb.table.Session):
137+
with s.transaction() as tx:
138+
tx.execute("INSERT INTO %s (id) VALUES (1)" % table_name)
139+
tx.commit()
140+
141+
tx.execute("INSERT INTO %s (id) VALUES (2)" % table_name)
142+
tx.commit()
143+
144+
with s.transaction() as tx:
145+
rs = tx.execute("SELECT COUNT(*) as cnt FROM %s" % table_name)
146+
assert rs[0].rows[0].cnt == 2
147+
148+
pool.retry_operation_sync(check_transaction)

ydb/aio/table.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,14 @@ async def alter_table(
120120
set_read_replicas_settings,
121121
)
122122

123-
def transaction(self, tx_mode=None):
124-
return TxContext(self._driver, self._state, self, tx_mode)
123+
def transaction(self, tx_mode=None, *, deny_split_transactions=False):
124+
return TxContext(
125+
self._driver,
126+
self._state,
127+
self,
128+
tx_mode,
129+
deny_split_transactions=deny_split_transactions,
130+
)
125131

126132
async def describe_table(self, path, settings=None): # pylint: disable=W0236
127133
return await super().describe_table(path, settings)
@@ -184,6 +190,9 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
184190
async def execute(
185191
self, query, parameters=None, commit_tx=False, settings=None
186192
): # pylint: disable=W0236
193+
194+
self._check_split()
195+
187196
return await super().execute(query, parameters, commit_tx, settings)
188197

189198
async def commit(self, settings=None): # pylint: disable=W0236

ydb/table.py

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,7 +1176,7 @@ def execute_scheme(self, yql_text, settings=None):
11761176
pass
11771177

11781178
@abstractmethod
1179-
def transaction(self, tx_mode=None):
1179+
def transaction(self, tx_mode=None, deny_split_transactions=False):
11801180
pass
11811181

11821182
@abstractmethod
@@ -1681,8 +1681,14 @@ def execute_scheme(self, yql_text, settings=None):
16811681
self._state.endpoint,
16821682
)
16831683

1684-
def transaction(self, tx_mode=None):
1685-
return TxContext(self._driver, self._state, self, tx_mode)
1684+
def transaction(self, tx_mode=None, deny_split_transactions=False):
1685+
return TxContext(
1686+
self._driver,
1687+
self._state,
1688+
self,
1689+
tx_mode,
1690+
deny_split_transactions=deny_split_transactions,
1691+
)
16861692

16871693
def has_prepared(self, query):
16881694
return query in self._state
@@ -2194,9 +2200,27 @@ def begin(self, settings=None):
21942200

21952201

21962202
class BaseTxContext(ITxContext):
2197-
__slots__ = ("_tx_state", "_session_state", "_driver", "session")
2203+
__slots__ = (
2204+
"_tx_state",
2205+
"_session_state",
2206+
"_driver",
2207+
"session",
2208+
"_finished",
2209+
"_deny_split_transactions",
2210+
)
21982211

2199-
def __init__(self, driver, session_state, session, tx_mode=None):
2212+
_COMMIT = "commit"
2213+
_ROLLBACK = "rollback"
2214+
2215+
def __init__(
2216+
self,
2217+
driver,
2218+
session_state,
2219+
session,
2220+
tx_mode=None,
2221+
*,
2222+
deny_split_transactions=False
2223+
):
22002224
"""
22012225
An object that provides a simple transaction context manager that allows statements execution
22022226
in a transaction. You don't have to open transaction explicitly, because context manager encapsulates
@@ -2219,6 +2243,8 @@ def __init__(self, driver, session_state, session, tx_mode=None):
22192243
self._tx_state = _tx_ctx_impl.TxState(tx_mode)
22202244
self._session_state = session_state
22212245
self.session = session
2246+
self._finished = ""
2247+
self._deny_split_transactions = deny_split_transactions
22222248

22232249
def __enter__(self):
22242250
"""
@@ -2276,6 +2302,9 @@ def execute(self, query, parameters=None, commit_tx=False, settings=None):
22762302
22772303
:return: A result sets or exception in case of execution errors
22782304
"""
2305+
2306+
self._check_split()
2307+
22792308
return self._driver(
22802309
_tx_ctx_impl.execute_request_factory(
22812310
self._session_state,
@@ -2302,8 +2331,12 @@ def commit(self, settings=None):
23022331
23032332
:return: A committed transaction or exception if commit is failed
23042333
"""
2334+
2335+
self._set_finish(self._COMMIT)
2336+
23052337
if self._tx_state.tx_id is None and not self._tx_state.dead:
23062338
return self
2339+
23072340
return self._driver(
23082341
_tx_ctx_impl.commit_request_factory(self._session_state, self._tx_state),
23092342
_apis.TableService.Stub,
@@ -2323,8 +2356,12 @@ def rollback(self, settings=None):
23232356
23242357
:return: A rolled back transaction or exception if rollback is failed
23252358
"""
2359+
2360+
self._set_finish(self._ROLLBACK)
2361+
23262362
if self._tx_state.tx_id is None and not self._tx_state.dead:
23272363
return self
2364+
23282365
return self._driver(
23292366
_tx_ctx_impl.rollback_request_factory(self._session_state, self._tx_state),
23302367
_apis.TableService.Stub,
@@ -2345,6 +2382,9 @@ def begin(self, settings=None):
23452382
"""
23462383
if self._tx_state.tx_id is not None:
23472384
return self
2385+
2386+
self._check_split()
2387+
23482388
return self._driver(
23492389
_tx_ctx_impl.begin_request_factory(self._session_state, self._tx_state),
23502390
_apis.TableService.Stub,
@@ -2355,6 +2395,21 @@ def begin(self, settings=None):
23552395
self._session_state.endpoint,
23562396
)
23572397

2398+
def _set_finish(self, val):
2399+
self._check_split(val)
2400+
self._finished = val
2401+
2402+
def _check_split(self, allow=""):
2403+
"""
2404+
Deny all operaions with transaction after commit/rollback.
2405+
Exception: double commit and double rollbacks, because it is safe
2406+
"""
2407+
if not self._deny_split_transactions:
2408+
return
2409+
2410+
if self._finished != "" and self._finished != allow:
2411+
raise RuntimeError("Any operation with finished transaction is denied")
2412+
23582413

23592414
class TxContext(BaseTxContext):
23602415
@_utilities.wrap_async_call_exceptions
@@ -2370,6 +2425,9 @@ def async_execute(self, query, parameters=None, commit_tx=False, settings=None):
23702425
23712426
:return: A future of query execution
23722427
"""
2428+
2429+
self._check_split()
2430+
23732431
return self._driver.future(
23742432
_tx_ctx_impl.execute_request_factory(
23752433
self._session_state,
@@ -2401,8 +2459,12 @@ def async_commit(self, settings=None):
24012459
24022460
:return: A future of commit call
24032461
"""
2462+
self._check_split()
2463+
self._finished = True
2464+
24042465
if self._tx_state.tx_id is None and not self._tx_state.dead:
24052466
return _utilities.wrap_result_in_future(self)
2467+
24062468
return self._driver.future(
24072469
_tx_ctx_impl.commit_request_factory(self._session_state, self._tx_state),
24082470
_apis.TableService.Stub,
@@ -2423,8 +2485,12 @@ def async_rollback(self, settings=None):
24232485
24242486
:return: A future of rollback call
24252487
"""
2488+
self._check_split()
2489+
self._finished = True
2490+
24262491
if self._tx_state.tx_id is None and not self._tx_state.dead:
24272492
return _utilities.wrap_result_in_future(self)
2493+
24282494
return self._driver.future(
24292495
_tx_ctx_impl.rollback_request_factory(self._session_state, self._tx_state),
24302496
_apis.TableService.Stub,
@@ -2446,6 +2512,9 @@ def async_begin(self, settings=None):
24462512
"""
24472513
if self._tx_state.tx_id is not None:
24482514
return _utilities.wrap_result_in_future(self)
2515+
2516+
self._check_split()
2517+
24492518
return self._driver.future(
24502519
_tx_ctx_impl.begin_request_factory(self._session_state, self._tx_state),
24512520
_apis.TableService.Stub,

0 commit comments

Comments
 (0)