diff --git a/README.md b/README.md index f8c185d..0777cf4 100644 --- a/README.md +++ b/README.md @@ -3,10 +3,10 @@ Fall 2025 group work with: - Oliver Boorstein - - table, page, merge, and secondary indexes + - table, page, merge, secondary indexes, and concurrency control - Emily Mayer - - querying, persistence, and bufferpool + - querying, persistence, bufferpool, and concurrency control - Jack Lund - - B+-tree, indexing, persistence, and bufferpool + - B+-tree, indexing, persistence, bufferpool, transactions - Inna Gruneva - - querying and merge + - querying, merge, and transactions diff --git a/config.py b/config.py index 83905fe..d4817f4 100644 --- a/config.py +++ b/config.py @@ -21,3 +21,4 @@ class Config: tail_meta_columns = 5 null_value = -2**63 deleted_record_value = -1 + max_attempts = 10 diff --git a/lstore/index.py b/lstore/index.py index f2d91ab..fab2440 100644 --- a/lstore/index.py +++ b/lstore/index.py @@ -6,6 +6,7 @@ from config import Config from lstore.bplus import BPlusTree +import threading if TYPE_CHECKING: from lstore.table import Table @@ -16,6 +17,7 @@ class Index: def __init__(self, table: "Table") -> None: self.table = table self.indices: List[Optional[BPlusTree]] = [None] * table.num_columns + self.locks = [threading.Lock() for _ in range(table.num_columns)] # Always build an index for the primary key column. self.create_index(table.key) @@ -23,16 +25,20 @@ def __init__(self, table: "Table") -> None: # Lookup helpers # ------------------------------------------------------------------ def locate(self, column: int, value: int) -> List[int]: - tree = self.indices[column] - if tree is None: - return [] - return tree.find(value) + lock = self.locks[column] + with lock: + tree = self.indices[column] + if tree is None: + return [] + return tree.find(value) def locate_range(self, begin: int, end: int, column: int) -> List[int]: - tree = self.indices[column] - if tree is None: - return [] - return tree.find_range(begin, end) + lock = self.locks[column] + with lock: + tree = self.indices[column] + if tree is None: + return [] + return tree.find_range(begin, end) # ------------------------------------------------------------------ # Mutation helpers @@ -44,7 +50,8 @@ def add(self, rid: int, columns: List[Optional[int]]) -> None: value = columns[column] if value is None: continue - tree.insert(value, rid) + with self.locks[column]: + tree.insert(value, rid) def remove(self, rid: int, columns: List[Optional[int]]) -> None: for column, tree in enumerate(self.indices): @@ -53,7 +60,8 @@ def remove(self, rid: int, columns: List[Optional[int]]) -> None: value = columns[column] if value is None: continue - tree.remove(value, rid) + with self.locks[column]: + tree.remove(value, rid) def update(self, rid: int, old_values: List[Optional[int]], new_values: List[Optional[int]]) -> None: for column, tree in enumerate(self.indices): @@ -63,8 +71,9 @@ def update(self, rid: int, old_values: List[Optional[int]], new_values: List[Opt new = new_values[column] if old == new or old is None or new is None: continue - tree.remove(old, rid) - tree.insert(new, rid) + with self.locks[column]: + tree.remove(old, rid) + tree.insert(new, rid) # ------------------------------------------------------------------ # Index lifecycle @@ -76,8 +85,9 @@ def create_index(self, column_number: int) -> bool: return False tree = BPlusTree() - self.indices[column_number] = tree - self._bulk_load(column_number, tree) + with self.locks[column_number]: + self.indices[column_number] = tree + self._bulk_load(column_number, tree) return True def drop_index(self, column_number: int) -> bool: @@ -133,4 +143,4 @@ def _iterate_existing_rows(self) -> Iterable[Tuple[int, List[Optional[int]]]]: and record[Config.indirection_column] == Config.deleted_record_value ): continue - yield rid, record[data_offset : data_offset + self.table.num_columns] + yield rid, record[data_offset : data_offset + self.table.num_columns] \ No newline at end of file diff --git a/lstore/table.py b/lstore/table.py index 95ebd1b..f113f05 100644 --- a/lstore/table.py +++ b/lstore/table.py @@ -2,6 +2,7 @@ from contextlib import ExitStack, contextmanager from time import time from typing import Optional +import threading from config import Config from lstore.bufferpool import Bufferpool @@ -283,6 +284,16 @@ def int_keys(values: dict) -> dict: self._ensure_logical_page(range_id, segment_offset, page_index) + def peek_base_rid(self): + """ + Compute the next base RID without mutating offsets. + Callers must hold the same critical section used for add_record to avoid mismatch. + """ + range_id = self.num_base_records // Config.records_per_range + offset = self.base_offsets[range_id] + return self.encode_rid(range_id, 0, offset) + + def add_record( self, columns: list[int], @@ -611,7 +622,8 @@ def __init__( metadata=directory_metadata, ) self.index = Index(self) - pass + self._insert_lock = threading.RLock() + def to_metadata(self) -> dict: return { @@ -676,6 +688,13 @@ def insert_record(self, columns: list[int], is_tail: bool = False, base_rid: int :param base_rid: int - the RID of the base record, only used for tail records :return: int - the RID of the record """ + with self._insert_lock: + return self._insert_record_locked(columns, is_tail=is_tail, base_rid=base_rid) + + def _insert_record_locked(self, columns: list[int], is_tail: bool = False, base_rid: int = Config.null_value): + """ + Internal helper that assumes _insert_lock is held. + """ prior_data = None if is_tail: prior_data = self.get_cumulative_updated_record(base_rid)[ @@ -715,7 +734,6 @@ def delete_record(self, rid: int): except ValueError: return False - def __merge(self): - print("merge is happening") - pass - + + def get_rid_for_lock(self): + return self.page_directory.peek_base_rid() diff --git a/lstore/tests/test_transactions.py b/lstore/tests/test_transactions.py new file mode 100644 index 0000000..caf7d0e --- /dev/null +++ b/lstore/tests/test_transactions.py @@ -0,0 +1,85 @@ +import os +import sys +import unittest + +# Ensure repository root is on sys.path when running this file directly. +ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if ROOT_DIR not in sys.path: + sys.path.insert(0, ROOT_DIR) + +from lstore.db import Database +from lstore.query import Query +from lstore.transaction import Transaction +from lstore.transaction_worker import TransactionWorker + + +class TransactionSemanticsTests(unittest.TestCase): + def setUp(self): + self.db = Database() + self.table = self.db.create_table("grades", 5, 0) + self.query = Query(self.table) + + def test_insert_commits_and_persists(self): + t = Transaction() + t.add_query(self.query.insert, self.table, 1, 2, 3, 4, 5) + + self.assertTrue(t.run()) + row = self.query.select(1, 0, [1, 1, 1, 1, 1])[0].columns + self.assertEqual(row, [1, 2, 3, 4, 5]) + + def test_update_abort_rolls_back_prior_changes(self): + # seed a record + self.query.insert(1, 10, 20, 30, 40) + + t = Transaction() + # valid update (would change col4 to 99) + t.add_query(self.query.update, self.table, 1, None, None, None, 99) + # invalid update: wrong column count -> returns False -> abort + t.add_query(self.query.update, self.table, 1, None, None, None) + + self.assertFalse(t.run(max_attempts=1)) + row = self.query.select(1, 0, [1, 1, 1, 1, 1])[0].columns + # values should remain the original ones + self.assertEqual(row, [1, 10, 20, 30, 40]) + + def test_delete_abort_restores_row_and_index(self): + self.query.insert(1, 2, 3, 4, 5) + + t = Transaction() + t.add_query(self.query.delete, self.table, 1) + # force abort with invalid update (column count mismatch) + t.add_query(self.query.update, self.table, 1, None, None, None) + + self.assertFalse(t.run(max_attempts=1)) + rows = self.query.select(1, 0, [1, 1, 1, 1, 1]) + self.assertEqual(len(rows), 1) + self.assertEqual(rows[0].columns, [1, 2, 3, 4, 5]) + + def test_concurrent_updates_retry_and_commit(self): + # start with a known value + self.query.insert(1, 5, 6, 7, 8) + + t1 = Transaction() + t1.add_query(self.query.update, self.table, 1, None, 50, None, None, None) + + t2 = Transaction() + t2.add_query(self.query.update, self.table, 1, None, 75, None, None, None) + + w1 = TransactionWorker([t1]) + w2 = TransactionWorker([t2]) + + w1.run() + w2.run() + w1.join() + w2.join() + + self.assertEqual(w1.result, 1) + self.assertEqual(w2.result, 1) + + # final value should reflect one of the successful updates + row = self.query.select(1, 0, [1, 1, 1, 1, 1])[0].columns + self.assertIn(row[1], (50, 75)) + + +if __name__ == "__main__": + unittest.main() diff --git a/lstore/transaction.py b/lstore/transaction.py index 7fd7611..8a69d7a 100644 --- a/lstore/transaction.py +++ b/lstore/transaction.py @@ -1,5 +1,64 @@ -from lstore.table import Table, Record -from lstore.index import Index +import time +import random +from config import Config +import threading +from enum import Enum + +class L(Enum): + S = 0 + X = 1 + + +class _LockManager: + def __init__(self): + self._lock = threading.Lock() + self._locks = {} + + + def _new_lock(self, mode, owner): + return {"mode": mode, "owners": {owner}} + + + def acquire(self, txn_id, rid, mode): + with self._lock: + lock = self._locks.get(rid) + if not lock: + self._locks[rid] = self._new_lock(mode, txn_id) + return True + owners = lock["owners"] + cur_mode = lock["mode"] + + if cur_mode == L.S: + if mode == L.S: + owners.add(txn_id) + return True + if mode == L.X and owners == {txn_id}: + lock["mode"] = L.X + return True + + elif cur_mode == L.X and owners == {txn_id}: + return True + return False + + + def release_all(self, txn_id): + with self._lock: + empty = [] + for rid, entry in self._locks.items(): + owners = entry["owners"] + owners.discard(txn_id) + if not owners: + empty.append(rid) + for rid in empty: + del self._locks[rid] + + +_GLOBAL_LOCK_MANAGER = _LockManager() + + +class _AbortTransaction(Exception): + """Internal control-flow exception to trigger an abort and retry.""" + pass class Transaction: @@ -7,8 +66,9 @@ class Transaction: # Creates a transaction object. """ def __init__(self): - self.queries = [] - pass + self.queries = [] # (query_fn, table, args) + self.undo_log = [] # list of undo entries + self.locked_rids = {} # rid -> held mode """ # Adds the given query to this transaction @@ -18,26 +78,225 @@ def __init__(self): # t.add_query(q.update, grades_table, 0, *[None, 1, None, 2, None]) """ def add_query(self, query, table, *args): - self.queries.append((query, args)) - # use grades_table for aborting + self.queries.append((query, table, args)) + # store table for lock/undo context # If you choose to implement this differently this method must still return True if transaction commits or False on abort - def run(self): - for query, args in self.queries: - result = query(*args) - # If the query has failed the transaction should abort - if result == False: - return self.abort() - return self.commit() + def run(self, max_attempts=Config.max_attempts): + # keep retrying until we successfully commit or attempts exhausted + attempts = 0 + current_op = None + while True: + self.undo_log.clear() + self.locked_rids.clear() + try: + for query_fn, table, args in self.queries: + op_name = query_fn.__name__ + current_op = f"{op_name}@{getattr(table, 'name', 'unknown')}" + if op_name == "insert": + self._execute_insert(query_fn, table, args) + elif op_name == "update": + self._execute_update(query_fn, table, args) + elif op_name == "delete": + self._execute_delete(query_fn, table, args) + else: + self._execute_read(query_fn, table, args) + return self.commit() + except _AbortTransaction as e: + print(f"[txn {id(self)}] abort attempt {attempts + 1} on {current_op}: {e}") + self.abort() + attempts += 1 + if max_attempts is not None and attempts >= max_attempts: + return False + time.sleep(random.uniform(0.001, 0.01 * attempts)) + continue + except Exception as e: + print( + f"[txn {id(self)}] exception attempt {attempts + 1} on {current_op}: " + f"{type(e).__name__}: {e}" + ) + self.abort() + attempts += 1 + if max_attempts is not None and attempts >= max_attempts: + return False + continue def abort(self): - #TODO: do roll-back and any other necessary operations + # rollback in reverse order + for entry in reversed(self.undo_log): + etype = entry["type"] + table = entry["table"] + if etype == "insert": + # remove the inserted record + rid = entry["rid"] + try: + table.delete_record(rid) + except Exception as e: + print(f"[txn {id(self)}] rollback insert failed for rid {rid}: {e}") + elif etype == "update": + rid = entry["rid"] + old_indirection = entry["old_indirection"] + old_schema = entry["old_schema"] + try: + self._restore_base_metadata(table, rid, old_indirection, old_schema) + except Exception as e: + print(f"[txn {id(self)}] rollback update failed for rid {rid}: {e}") + elif etype == "delete": + rid = entry["rid"] + old_values = entry["old_values"] + old_indirection = entry["old_indirection"] + old_schema = entry["old_schema"] + try: + # restore index entry + table.index.add(rid, old_values) + # restore base metadata + self._restore_base_metadata(table, rid, old_indirection, old_schema) + except Exception as e: + print(f"[txn {id(self)}] rollback delete failed for rid {rid}: {e}") + + + self._release_locks() return False def commit(self): - # TODO: commit to database + self.undo_log.clear() + self._release_locks() return True + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _acquire_lock(self, rid, mode=L.X): + held = self.locked_rids.get(rid) + if held: + # already have a lock; allow if sufficient or attempt upgrade + if held == L.X or mode == held: + return True + # upgrade S -> X + ok = _GLOBAL_LOCK_MANAGER.acquire(id(self), rid, mode) + if ok: + self.locked_rids[rid] = mode + return ok + ok = _GLOBAL_LOCK_MANAGER.acquire(id(self), rid, mode) + if ok: + self.locked_rids[rid] = mode + return ok + + def _release_locks(self): + if self.locked_rids: + _GLOBAL_LOCK_MANAGER.release_all(id(self)) + self.locked_rids.clear() + + def _execute_read(self, query_fn, table, args): + op_name = query_fn.__name__ + rids = [] + if op_name in ("select", "select_version"): + search_key, search_key_index = args[0], args[1] + rids = table.index.locate(search_key_index, search_key) + elif op_name in ("sum", "sum_version"): + start, end = args[0], args[1] + rids = table.index.locate_range(start, end, table.key) + + for rid in rids: + if not self._acquire_lock(rid, L.S): + raise _AbortTransaction(f"{op_name}: shared lock failed for rid {rid}") + + result = query_fn(*args) + if result is False: + raise _AbortTransaction(f"{op_name} returned False") + return result + + def _execute_insert(self, query_fn, table, args): + # lock allocation and write together to keep RID consistent + with table._insert_lock: + new_rid = table.page_directory.peek_base_rid() + if not self._acquire_lock(new_rid, L.X): + raise _AbortTransaction("lock acquisition failed after insert") + + result = query_fn(*args) + if result is False: + raise _AbortTransaction("insert returned False") + + self.undo_log.append({"type": "insert", "table": table, "rid": new_rid}) + + def _execute_update(self, query_fn, table, args): + primary_key = args[0] + rids = table.index.locate(table.key, primary_key) + if not rids: + raise _AbortTransaction(f"update: no rid for key {primary_key}") + rid = rids[0] + if not self._acquire_lock(rid, L.X): + raise _AbortTransaction(f"update: lock acquisition failed for rid {rid}") + + try: + base_record = table.get_record(rid) + cumulative = table.get_cumulative_updated_record(rid) + except Exception: + raise _AbortTransaction(f"update: failed to fetch record for rid {rid}") + + old_values = cumulative[Config.tail_meta_columns : Config.tail_meta_columns + table.num_columns] + old_indirection = base_record[Config.indirection_column] + old_schema = base_record[Config.schema_encoding_column] + + result = query_fn(*args) + if result is False: + raise _AbortTransaction("update returned False") + + self.undo_log.append( + { + "type": "update", + "table": table, + "rid": rid, + "old_values": old_values, + "old_indirection": old_indirection, + "old_schema": old_schema, + } + ) + + def _execute_delete(self, query_fn, table, args): + primary_key = args[0] + rids = table.index.locate(table.key, primary_key) + if not rids: + raise _AbortTransaction(f"delete: no rid for key {primary_key}") + rid = rids[0] + if not self._acquire_lock(rid, L.X): + raise _AbortTransaction(f"delete: lock acquisition failed for rid {rid}") + + try: + base_record = table.get_record(rid) + cumulative = table.get_cumulative_updated_record(rid) + except Exception: + raise _AbortTransaction(f"delete: failed to fetch record for rid {rid}") + + old_values = cumulative[Config.tail_meta_columns : Config.tail_meta_columns + table.num_columns] + old_indirection = base_record[Config.indirection_column] + old_schema = base_record[Config.schema_encoding_column] + + result = query_fn(*args) + if result is False: + raise _AbortTransaction("delete returned False") + + self.undo_log.append( + { + "type": "delete", + "table": table, + "rid": rid, + "old_values": old_values, + "old_indirection": old_indirection, + "old_schema": old_schema, + } + ) + + def _restore_base_metadata(self, table, rid, indirection, schema): + """Write the base record's indirection and schema back to prior values.""" + range_id, segment, page_index, slot_index = table.page_directory.decode_rid(rid) + if segment != 0: + return False + with table.page_directory._column(range_id, 0, page_index, Config.indirection_column) as ind_page: + ind_page.write_slot(slot_index, indirection) + with table.page_directory._column(range_id, 0, page_index, Config.schema_encoding_column) as schema_page: + schema_page.write_slot(slot_index, schema) + return True diff --git a/lstore/transaction_worker.py b/lstore/transaction_worker.py index c53ea49..f36c5e5 100644 --- a/lstore/transaction_worker.py +++ b/lstore/transaction_worker.py @@ -1,16 +1,17 @@ from lstore.table import Table, Record from lstore.index import Index +import threading class TransactionWorker: """ # Creates a transaction worker object. """ - def __init__(self, transactions = []): + def __init__(self, transactions=None): self.stats = [] - self.transactions = transactions + self.transactions = list(transactions) if transactions is not None else [] self.result = 0 - pass + """ @@ -24,15 +25,17 @@ def add_transaction(self, t): Runs all transaction as a thread """ def run(self): - pass - # here you need to create a thread and call __run + self._thread = threading.Thread(target=self.__run) + self._thread.start() """ Waits for the worker to finish """ def join(self): - pass + thread = getattr(self, "_thread", None) + if thread: + thread.join() def __run(self): @@ -41,4 +44,3 @@ def __run(self): self.stats.append(transaction.run()) # stores the number of transactions that committed self.result = len(list(filter(lambda x: x, self.stats))) -