|
| 1 | +""" |
| 2 | +Unified database connection utilities. |
| 3 | +Provides consistent connection management across all database operations. |
| 4 | +""" |
| 5 | +import os |
| 6 | +import sqlite3 |
| 7 | +from typing import Optional |
| 8 | +from contextlib import contextmanager |
| 9 | +from utils.logger import get_logger |
| 10 | + |
| 11 | +logger = get_logger(__name__) |
| 12 | + |
| 13 | + |
| 14 | +def get_db_connection( |
| 15 | + db_path: str, |
| 16 | + timeout: float = 30.0, |
| 17 | + enable_wal: bool = True, |
| 18 | + enable_vector: bool = False, |
| 19 | + row_factory: bool = True |
| 20 | +) -> sqlite3.Connection: |
| 21 | + """ |
| 22 | + Create a database connection with consistent configuration. |
| 23 | + |
| 24 | + Args: |
| 25 | + db_path: Path to the SQLite database file |
| 26 | + timeout: Timeout in seconds for waiting on locks (default: 30.0) |
| 27 | + enable_wal: Enable Write-Ahead Logging mode (default: True) |
| 28 | + enable_vector: Load sqlite-vector extension (default: False) |
| 29 | + row_factory: Use sqlite3.Row factory for dict-like access (default: True) |
| 30 | + |
| 31 | + Returns: |
| 32 | + sqlite3.Connection object configured for the specified operations |
| 33 | + |
| 34 | + Raises: |
| 35 | + RuntimeError: If vector extension fails to load when enable_vector=True |
| 36 | + """ |
| 37 | + # Create directory if needed |
| 38 | + dirname = os.path.dirname(os.path.abspath(db_path)) |
| 39 | + if dirname and not os.path.isdir(dirname): |
| 40 | + os.makedirs(dirname, exist_ok=True) |
| 41 | + |
| 42 | + # Create connection with consistent settings |
| 43 | + conn = sqlite3.connect(db_path, timeout=timeout, check_same_thread=False) |
| 44 | + |
| 45 | + if row_factory: |
| 46 | + conn.row_factory = sqlite3.Row |
| 47 | + |
| 48 | + # Enable WAL mode for better concurrency |
| 49 | + if enable_wal: |
| 50 | + try: |
| 51 | + conn.execute("PRAGMA journal_mode = WAL;") |
| 52 | + except Exception as e: |
| 53 | + logger.warning(f"Failed to enable WAL mode: {e}") |
| 54 | + |
| 55 | + # Set busy timeout (milliseconds) |
| 56 | + try: |
| 57 | + conn.execute(f"PRAGMA busy_timeout = {int(timeout * 1000)};") |
| 58 | + except Exception as e: |
| 59 | + logger.warning(f"Failed to set busy_timeout: {e}") |
| 60 | + |
| 61 | + # Load vector extension if requested |
| 62 | + if enable_vector: |
| 63 | + from .vector_operations import load_sqlite_vector_extension |
| 64 | + load_sqlite_vector_extension(conn) |
| 65 | + logger.debug(f"Vector extension loaded for connection to {db_path}") |
| 66 | + |
| 67 | + return conn |
| 68 | + |
| 69 | + |
| 70 | +@contextmanager |
| 71 | +def db_connection(db_path: str, **kwargs): |
| 72 | + """ |
| 73 | + Context manager for database connections with automatic cleanup. |
| 74 | + |
| 75 | + Args: |
| 76 | + db_path: Path to the SQLite database file |
| 77 | + **kwargs: Additional arguments passed to get_db_connection() |
| 78 | + |
| 79 | + Yields: |
| 80 | + sqlite3.Connection object |
| 81 | + |
| 82 | + Example: |
| 83 | + with db_connection(db_path) as conn: |
| 84 | + cur = conn.cursor() |
| 85 | + cur.execute("SELECT * FROM files") |
| 86 | + results = cur.fetchall() |
| 87 | + """ |
| 88 | + conn = get_db_connection(db_path, **kwargs) |
| 89 | + try: |
| 90 | + yield conn |
| 91 | + finally: |
| 92 | + try: |
| 93 | + conn.close() |
| 94 | + except Exception as e: |
| 95 | + logger.warning(f"Error closing database connection: {e}") |
0 commit comments