-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdb_setup_shared.py
More file actions
164 lines (140 loc) · 5.84 KB
/
db_setup_shared.py
File metadata and controls
164 lines (140 loc) · 5.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Database setup for single-database multi-user approach.
All users share one database with user_id scoping.
"""
import os
import sqlite3
import psycopg2
from typing import Iterable, Union, Type
from db_schema_definitions import (
MULTI_USER_POSTGRESQL_INDEXES,
MULTI_USER_POSTGRESQL_SCHEMAS,
MULTI_USER_SQLITE_INDEXES,
MULTI_USER_SQLITE_SCHEMAS,
MULTI_USER_DEFAULTS,
)
from error_utils import safe_execute
def _add_column_if_not_exists(
cursor, table: str, column: str, column_def: str, dialect: str
) -> None:
"""Add a column to a table if it doesn't already exist."""
if dialect == "postgres":
cursor.execute(
"""
SELECT 1 FROM information_schema.columns
WHERE table_name = %s AND column_name = %s
""",
(table, column),
)
exists = cursor.fetchone() is not None
elif dialect == "sqlite":
cursor.execute(f"PRAGMA table_info({table})")
exists = any(row[1] == column for row in cursor.fetchall())
else:
raise ValueError(f"Unsupported dialect: {dialect}")
if not exists:
_execute_with_reporting(
cursor, f"ALTER TABLE {table} ADD COLUMN {column} {column_def}"
)
def _execute_with_reporting(
cursor, sql: str, ignore_exceptions: Iterable[Type[BaseException]] = ()
) -> None:
"""Execute a SQL statement and report failures with the offending query."""
try:
cursor.execute(sql)
except tuple(ignore_exceptions):
pass
except Exception as e:
print(f"Error executing query: {e}\nQuery:\n{sql}")
raise
def setup_shared_database(connection: Union[str, object, None] = None) -> bool:
    """Set up the database schema for single-database multi-user mode.

    Args:
        connection: Optional connection string. If omitted (``None``), the
            ``PANTRY_DATABASE_URL`` environment variable is used instead.
            Strings starting with ``postgresql://`` or ``postgres://`` select
            PostgreSQL; any other non-empty string is treated as a SQLite
            file path. Connection objects are not supported.

    Returns:
        bool: True if successful, False otherwise. Errors raised here are
        printed rather than propagated.
    """
    try:
        if connection is None:
            connection = os.getenv("PANTRY_DATABASE_URL")
            if not connection:
                raise ValueError("PANTRY_DATABASE_URL environment variable not set")
        if not isinstance(connection, str):
            # Connection object support not implemented - use connection string
            raise ValueError(
                "Connection objects not supported. Use connection string instead."
            )
        if not connection:
            # An empty SQLite path opens a private temporary database that is
            # discarded on close, so reporting success for it would mislead.
            raise ValueError("Connection string must not be empty")
        if connection.startswith(("postgresql://", "postgres://")):
            return _setup_postgresql_shared(connection)
        return _setup_sqlite_shared(connection)
    except Exception as e:
        print(f"Error setting up shared database: {e}")
        return False
def _setup_shared_database_impl(cursor, dialect: str) -> None:
"""Shared implementation for setting up multi-user database schema.
Args:
cursor: Database cursor (PostgreSQL or SQLite)
dialect: Either 'postgres' or 'sqlite'
"""
# Select appropriate schemas and indexes
if dialect == "postgres":
schemas = MULTI_USER_POSTGRESQL_SCHEMAS
indexes = MULTI_USER_POSTGRESQL_INDEXES
varchar_type = "VARCHAR(50)"
elif dialect == "sqlite":
schemas = MULTI_USER_SQLITE_SCHEMAS
indexes = MULTI_USER_SQLITE_INDEXES
varchar_type = "TEXT"
else:
raise ValueError(f"Unsupported dialect: {dialect}")
# Create all tables using centralized schema definitions
for _, schema in schemas.items():
_execute_with_reporting(cursor, schema)
# Add additional columns to existing users table if they don't exist
timestamp_type = "TIMESTAMP" if dialect == "postgres" else "TIMESTAMP"
boolean_type = "BOOLEAN" if dialect == "postgres" else "BOOLEAN"
columns_to_add = [
("household_id", "INTEGER REFERENCES users(id)"),
("household_adults", "INTEGER DEFAULT 2"),
("household_children", "INTEGER DEFAULT 0"),
("preferred_volume_unit", f"{varchar_type} DEFAULT 'Milliliter'"),
("preferred_weight_unit", f"{varchar_type} DEFAULT 'Gram'"),
("preferred_count_unit", f"{varchar_type} DEFAULT 'Piece'"),
("last_login", f"{timestamp_type} DEFAULT NULL"),
(
"is_admin",
f"{boolean_type} DEFAULT {'FALSE' if dialect == 'postgres' else '0'}",
),
]
for column, definition in columns_to_add:
_add_column_if_not_exists(cursor, "users", column, definition, dialect=dialect)
# Create indexes for better performance
for index_sql in indexes:
_execute_with_reporting(cursor, index_sql)
# Insert default data
for default_sql in MULTI_USER_DEFAULTS:
_execute_with_reporting(cursor, default_sql)
@safe_execute("setup PostgreSQL shared database", default_return=False, log_errors=True)
def _setup_postgresql_shared(connection_string: str) -> bool:
    """Set up the PostgreSQL schema for the shared database.

    Uses the centralized schema definitions via
    ``_setup_shared_database_impl``. Returns True on success. The
    ``safe_execute`` decorator is configured with ``default_return=False``;
    presumably it returns that on exception (confirm in ``error_utils``).
    """
    conn = psycopg2.connect(connection_string)
    try:
        # ``with conn`` only commits (or rolls back on error). psycopg2's
        # connection context manager does not close the connection, so it is
        # closed explicitly in ``finally`` to avoid leaking it.
        with conn:
            with conn.cursor() as cursor:
                _setup_shared_database_impl(cursor, dialect="postgres")
    finally:
        conn.close()
    return True
@safe_execute("setup SQLite shared database", default_return=False, log_errors=True)
def _setup_sqlite_shared(db_path: str) -> bool:
    """Set up the SQLite schema for the shared database at ``db_path``.

    Uses the centralized schema definitions via
    ``_setup_shared_database_impl``. Returns True on success. The
    ``safe_execute`` decorator is configured with ``default_return=False``;
    presumably it returns that on exception (confirm in ``error_utils``).
    """
    conn = sqlite3.connect(db_path)
    try:
        # The sqlite3 connection context manager commits on success and
        # rolls back on error, but it does not close the connection, so the
        # connection is closed explicitly in ``finally``.
        with conn:
            cursor = conn.cursor()
            _setup_shared_database_impl(cursor, dialect="sqlite")
    finally:
        conn.close()
    return True
if __name__ == "__main__":
    # Manual entry point: set up the schema using PANTRY_DATABASE_URL.
    print("Setting up shared database schema...")
    success = setup_shared_database()
    print(f"Setup {'successful' if success else 'failed'}")
    # Report the outcome through the exit status as well, so shell scripts
    # and CI jobs can detect a failed setup.
    raise SystemExit(0 if success else 1)