-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdatabase.py
More file actions
939 lines (801 loc) · 34.8 KB
/
database.py
File metadata and controls
939 lines (801 loc) · 34.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
"""
Database abstraction layer for GGUF Forge.
Supports both SQLite and MSSQL backends with async operations.
Uses aioodbc for async ODBC connections.
"""
import os
import logging
import asyncio
from pathlib import Path
from typing import Optional, Any, List, Dict
from contextlib import asynccontextmanager
from abc import ABC, abstractmethod
logger = logging.getLogger("GGUF_Forge")

# Database configuration
DB_TYPE = os.getenv("DB_TYPE", "sqlite").lower() # sqlite or mssql
DB_PATH = None # For SQLite; populated at startup via set_db_path()

# MSSQL Configuration (read once at import time from the environment)
MSSQL_HOST = os.getenv("MSSQL_HOST", "")
MSSQL_PORT = os.getenv("MSSQL_PORT", "1433")
MSSQL_DATABASE = os.getenv("MSSQL_DATABASE", "")
MSSQL_USER = os.getenv("MSSQL_USER", "")
MSSQL_PASSWORD = os.getenv("MSSQL_PASSWORD", "")
MSSQL_ENCRYPT = os.getenv("MSSQL_ENCRYPT", "yes")
MSSQL_TRUST_CERT = os.getenv("MSSQL_TRUST_CERT", "yes")
MSSQL_CONN_TIMEOUT = os.getenv("MSSQL_CONN_TIMEOUT", "60") # Connection timeout in seconds
MSSQL_LOGIN_TIMEOUT = os.getenv("MSSQL_LOGIN_TIMEOUT", "60") # Login timeout in seconds

# Connection retry settings (used by AsyncMSSQLConnection.execute)
MAX_RETRIES = 3
RETRY_DELAY = 1.0 # seconds (base delay; scaled linearly per retry attempt)

# Admin Users - comma-separated list of HuggingFace usernames who should be admins
# (normalized to lowercase here; is_admin_user compares case-insensitively)
ADMIN_USERS = [u.strip().lower() for u in os.getenv("ADMIN_USERS", "").split(",") if u.strip()]

# Connection pool settings
POOL_MIN_SIZE = int(os.getenv("DB_POOL_MIN", "2"))
POOL_MAX_SIZE = int(os.getenv("DB_POOL_MAX", "10"))

# Global connection pool (initialized on first use)
_mssql_pool = None
_pool_lock = asyncio.Lock()  # serializes lazy pool creation across coroutines
async def _get_mssql_pool():
    """Get or create the MSSQL connection pool.

    Lazily builds a module-wide aioodbc pool on first use. Creation is
    serialized with _pool_lock so concurrent first callers cannot create
    two pools. Raises ImportError with installation guidance when the
    native ODBC libraries are missing.
    """
    global _mssql_pool
    async with _pool_lock:
        if _mssql_pool is None:
            try:
                import aioodbc
                conn_str = AsyncMSSQLConnection._get_connection_string()
                _mssql_pool = await aioodbc.create_pool(
                    dsn=conn_str,
                    minsize=POOL_MIN_SIZE,
                    maxsize=POOL_MAX_SIZE,
                    autocommit=False
                )
                logger.info(f"MSSQL connection pool created (min={POOL_MIN_SIZE}, max={POOL_MAX_SIZE})")
            except ImportError as e:
                error_msg = str(e)
                # "libodbc" in the message means the Python package imported but
                # the unixODBC system libraries are missing -> give install steps.
                if "libodbc" in error_msg:
                    raise ImportError(
                        "Missing ODBC system libraries. Please install unixODBC and the Microsoft ODBC Driver.\n\n"
                        "On Ubuntu/Debian:\n"
                        " sudo apt-get update\n"
                        " sudo apt-get install -y unixodbc-dev unixodbc\n"
                        " curl https://packages.microsoft.com/keys/microsoft.asc | sudo apt-key add -\n"
                        " curl https://packages.microsoft.com/config/ubuntu/$(lsb_release -rs)/prod.list | sudo tee /etc/apt/sources.list.d/mssql-release.list\n"
                        " sudo apt-get update\n"
                        " sudo ACCEPT_EULA=Y apt-get install -y msodbcsql18\n\n"
                        "Alternatively, switch to SQLite by setting DB_TYPE=sqlite in your environment.\n\n"
                        f"Original error: {error_msg}"
                    ) from e
                raise
    return _mssql_pool
async def close_pool():
    """Tear down the MSSQL connection pool; safe to call when no pool exists."""
    global _mssql_pool
    if _mssql_pool is None:
        return
    _mssql_pool.close()
    await _mssql_pool.wait_closed()
    _mssql_pool = None
    logger.info("MSSQL connection pool closed")
def is_admin_user(username: str) -> bool:
    """Return True when *username* (compared case-insensitively) is an admin."""
    normalized = username.lower()
    return any(normalized == admin for admin in ADMIN_USERS)
def set_db_path(path: Path):
    """Record *path* as the SQLite database file location (module-wide)."""
    global DB_PATH
    DB_PATH = path
class DatabaseRow:
    """A dict-like row that supports both dict access and attribute access.

    Fix: the docstring always promised attribute access (row.col), but no
    __getattr__ was implemented, so row.col raised AttributeError. It is
    now supported for non-underscore column names.
    """

    def __init__(self, data: dict):
        self._data = data

    def __getitem__(self, key):
        return self._data[key]

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so self._data and
        # real attributes/methods are unaffected. Underscore names are
        # rejected to avoid recursion before __init__ runs (e.g. copy/pickle).
        if key.startswith("_"):
            raise AttributeError(key)
        try:
            return self._data[key]
        except KeyError:
            raise AttributeError(key) from None

    def __contains__(self, key):
        return key in self._data

    def __repr__(self):
        return f"{type(self).__name__}({self._data!r})"

    def get(self, key, default=None):
        """Return the column value, or *default* when the column is absent."""
        return self._data.get(key, default)

    def keys(self):
        return self._data.keys()

    def values(self):
        return self._data.values()

    def items(self):
        return self._data.items()

    def to_dict(self) -> dict:
        """Convert to a plain dict with datetime objects serialized to ISO format strings."""
        from datetime import datetime
        return {
            key: value.isoformat() if isinstance(value, datetime) else value
            for key, value in self._data.items()
        }
class AsyncDatabaseConnection(ABC):
    """Common async interface implemented by every database backend wrapper."""

    @abstractmethod
    async def execute(self, query: str, params: tuple = ()) -> Any:
        """Run *query* with *params*; return the backend's cursor/result."""
        ...

    @abstractmethod
    async def commit(self):
        """Commit the current transaction."""
        ...

    @abstractmethod
    async def close(self):
        """Release or close the underlying connection."""
        ...

    @abstractmethod
    async def fetchone(self) -> Optional[DatabaseRow]:
        """Return the next row of the most recent query, or None."""
        ...

    @abstractmethod
    async def fetchall(self) -> List[DatabaseRow]:
        """Return all remaining rows of the most recent query."""
        ...

    @property
    @abstractmethod
    def lastrowid(self) -> int:
        """Row ID generated by the most recent INSERT."""
        ...
class AsyncSQLiteConnection(AsyncDatabaseConnection):
    """Async SQLite database connection wrapper using aiosqlite."""

    def __init__(self, conn, cursor=None):
        self.conn = conn        # aiosqlite connection
        self.cursor = cursor    # cursor from the most recent execute()
        self._last_id = 0       # lastrowid from the most recent execute()

    async def execute(self, query: str, params: tuple = ()) -> 'AsyncSQLiteConnection':
        """Execute *query*, remember its cursor and lastrowid, return self."""
        self.cursor = await self.conn.execute(query, params)
        self._last_id = self.cursor.lastrowid
        return self

    async def commit(self):
        """Commit the current transaction."""
        await self.conn.commit()

    async def close(self):
        """Close the underlying connection."""
        await self.conn.close()

    async def fetchone(self) -> Optional[DatabaseRow]:
        """Fetch the next row from the last query, or None when exhausted."""
        if self.cursor:
            row = await self.cursor.fetchone()
            if row:
                # Get column names from cursor description
                columns = [desc[0] for desc in self.cursor.description]
                return DatabaseRow(dict(zip(columns, row)))
        return None

    async def fetchall(self) -> List[DatabaseRow]:
        """Fetch all remaining rows from the last query as DatabaseRow objects.

        Fix: also guard on ``cursor.description`` — it is None after
        non-SELECT statements (e.g. INSERT), where the previous code raised
        ``TypeError: 'NoneType' object is not iterable``.
        """
        if self.cursor and self.cursor.description:
            rows = await self.cursor.fetchall()
            columns = [desc[0] for desc in self.cursor.description]
            return [DatabaseRow(dict(zip(columns, row))) for row in rows]
        return []

    @property
    def lastrowid(self) -> int:
        """Row ID from the most recent execute()."""
        return self._last_id
class AsyncMSSQLConnection(AsyncDatabaseConnection):
    """Async MSSQL database connection wrapper using aioodbc.

    Adds three behaviors on top of a raw aioodbc connection:
    - SQLite-to-T-SQL query rewriting (_adapt_query) so one query dialect
      can serve both backends;
    - parameter adaptation for ISO datetime strings (_adapt_params);
    - automatic reconnect-and-retry on connection/timeout errors (execute).
    """
    # Cache the detected driver to avoid repeated lookups and logging
    _detected_driver = None

    # Connection timeout errors that indicate session timeout
    TIMEOUT_ERROR_CODES = [
        '08S01', # Communication link failure
        '08001', # Unable to connect
        'HYT00', # Timeout expired
        '40001', # Deadlock
    ]

    @classmethod
    def _get_driver(cls):
        """Auto-detect available ODBC driver or use environment override."""
        # Return cached driver if already detected
        if cls._detected_driver:
            return cls._detected_driver
        # Allow override via environment variable
        driver_override = os.getenv("MSSQL_DRIVER", "")
        if driver_override:
            cls._detected_driver = driver_override
            logger.info(f"Using ODBC driver (override): {driver_override}")
            return driver_override
        # Try to auto-detect available driver (prefer newer versions)
        try:
            import pyodbc
            drivers = pyodbc.drivers()
        except ImportError as e:
            raise ImportError(
                "Missing ODBC Python package. This should have been installed with requirements.txt.\n"
                "Try: pip install pyodbc\n\n"
                f"Original error: {str(e)}"
            ) from e
        # Check for drivers in order of preference
        preferred_drivers = [
            "ODBC Driver 18 for SQL Server",
            "ODBC Driver 17 for SQL Server",
            "ODBC Driver 13 for SQL Server",
            "SQL Server",
        ]
        for driver in preferred_drivers:
            if driver in drivers:
                cls._detected_driver = driver
                logger.info(f"Using ODBC driver: {driver}")
                return driver
        # No driver found - provide helpful error message
        available_drivers = ", ".join(drivers) if drivers else "none"
        logger.warning(
            f"No Microsoft SQL Server ODBC driver found. Available drivers: {available_drivers}\n"
            "Please install the Microsoft ODBC Driver for SQL Server.\n"
            "See: https://docs.microsoft.com/en-us/sql/connect/odbc/download-odbc-driver-for-sql-server"
        )
        # Fallback to 18 if nothing found (will error with helpful message later)
        cls._detected_driver = "ODBC Driver 18 for SQL Server"
        return cls._detected_driver

    @classmethod
    def _get_connection_string(cls):
        """Build the ODBC connection string with timeout and keepalive settings."""
        driver = cls._get_driver()
        return (
            f"DRIVER={{{driver}}};"
            f"SERVER={MSSQL_HOST},{MSSQL_PORT};"
            f"DATABASE={MSSQL_DATABASE};"
            f"UID={MSSQL_USER};"
            f"PWD={MSSQL_PASSWORD};"
            f"Encrypt={MSSQL_ENCRYPT};"
            f"TrustServerCertificate={MSSQL_TRUST_CERT};"
            f"Connection Timeout={MSSQL_CONN_TIMEOUT};"
            f"Login Timeout={MSSQL_LOGIN_TIMEOUT};"
        )

    def __init__(self, conn, pool=None):
        self.conn = conn
        self._pool = pool # Reference to pool for releasing connection
        self.cursor = None
        self._columns = [] # column names of the most recent result set
        self._last_id = 0 # SCOPE_IDENTITY() of the most recent INSERT
        self._retry_count = 0 # retries used by the current/last execute()

    def _adapt_params(self, params: tuple) -> tuple:
        """Adapt parameter values for MSSQL compatibility.

        datetime objects pass through unchanged; ISO-format datetime strings
        are rewritten into the "YYYY-MM-DD HH:MM:SS.mmm" form SQL Server's
        DATETIME accepts. All other values pass through unchanged.
        """
        from datetime import datetime as dt
        import re
        if not params:
            return params
        adapted = []
        for p in params:
            # Convert datetime objects directly - pyodbc handles these properly
            if isinstance(p, dt):
                # Pass datetime objects as-is, pyodbc will handle conversion
                adapted.append(p)
                continue
            if isinstance(p, str):
                # Convert ISO datetime format (2025-12-29T07:02:59.123456) to MSSQL format
                # Check if it looks like an ISO datetime
                if len(p) >= 19 and p[4] == '-' and p[7] == '-' and p[10] == 'T':
                    # Replace T with space
                    p = p.replace('T', ' ')
                    # Truncate microseconds to milliseconds (SQL Server DATETIME only supports 3 decimal places)
                    # Match pattern like ".123456" and truncate to ".123"
                    p = re.sub(r'\.(\d{3})\d*', r'.\1', p)
            adapted.append(p)
        return tuple(adapted)

    @classmethod
    def _is_connection_timeout(cls, error) -> bool:
        """Check if the error is a connection/session timeout."""
        error_str = str(error).upper()
        # Check for known error codes
        for code in cls.TIMEOUT_ERROR_CODES:
            if code in error_str:
                return True
        # Check for timeout-related messages
        timeout_keywords = [
            'TIMEOUT', 'TIMED OUT', 'CONNECTION LOST',
            'CONNECTION RESET', 'COMMUNICATION LINK FAILURE',
            'LOGIN FAILED', 'CONNECTION FAILURE', 'NETWORK ERROR',
            'BROKEN PIPE', 'CONNECTION CLOSED'
        ]
        for keyword in timeout_keywords:
            if keyword in error_str:
                return True
        return False

    async def execute(self, query: str, params: tuple = ()) -> 'AsyncMSSQLConnection':
        """Execute a query with automatic retry on connection timeout.

        On a timeout-class failure the connection is re-established and the
        statement re-executed, up to MAX_RETRIES times with a linearly
        increasing delay. Non-timeout errors are raised immediately.
        """
        # Convert SQLite-style placeholders and syntax
        query = self._adapt_query(query)
        # Adapt parameters for MSSQL compatibility
        params = self._adapt_params(params)
        retries = 0
        last_error = None
        while retries <= MAX_RETRIES:
            try:
                self.cursor = await self.conn.cursor()
                if params:
                    await self.cursor.execute(query, params)
                else:
                    await self.cursor.execute(query)
                # Get column names if available
                if self.cursor.description:
                    self._columns = [column[0] for column in self.cursor.description]
                # Handle INSERT to get last row ID
                if query.strip().upper().startswith('INSERT'):
                    try:
                        await self.cursor.execute("SELECT SCOPE_IDENTITY()")
                        result = await self.cursor.fetchone()
                        if result and result[0]:
                            self._last_id = int(result[0])
                    except:
                        # Best-effort: not every INSERT target has an identity column
                        pass
                # Reset retry count on success
                self._retry_count = 0
                return self
            except Exception as e:
                last_error = e
                if self._is_connection_timeout(e) and retries < MAX_RETRIES:
                    retries += 1
                    self._retry_count = retries
                    logger.warning(f"Connection timeout detected, attempting reconnect (attempt {retries}/{MAX_RETRIES}): {e}")
                    # Wait before retrying
                    await asyncio.sleep(RETRY_DELAY * retries)
                    # Try to reconnect
                    try:
                        await self._reconnect()
                    except Exception as reconnect_error:
                        logger.error(f"Reconnection failed: {reconnect_error}")
                        # Continue to next retry attempt
                else:
                    # Not a timeout error or max retries exceeded
                    raise
        # Max retries exceeded
        raise last_error

    async def _reconnect(self):
        """Attempt to reconnect to the database.

        NOTE(review): the fresh connection is standalone (not re-acquired
        from the pool), so close() will still release the old pool slot —
        confirm the pool tolerates releasing a replaced connection.
        """
        import aioodbc
        # Close existing connection if possible
        try:
            await self.conn.close()
        except:
            pass
        # Create new connection
        conn_str = self._get_connection_string()
        self.conn = await aioodbc.connect(dsn=conn_str)
        logger.info("Successfully reconnected to MSSQL database")

    def _adapt_query(self, query: str) -> str:
        """Adapt SQLite query syntax to MSSQL.

        Text-level rewrites only: AUTOINCREMENT->IDENTITY, TEXT->NVARCHAR(MAX),
        CURRENT_TIMESTAMP->GETDATE(), CREATE TABLE IF NOT EXISTS guard,
        and a trailing "LIMIT n" becomes "SELECT TOP n".
        """
        import re
        # Replace AUTOINCREMENT with IDENTITY
        query = query.replace("AUTOINCREMENT", "IDENTITY(1,1)")
        # Replace INTEGER PRIMARY KEY AUTOINCREMENT with INT PRIMARY KEY IDENTITY
        query = query.replace("INTEGER PRIMARY KEY IDENTITY(1,1)", "INT PRIMARY KEY IDENTITY(1,1)")
        # Replace DATETIME DEFAULT CURRENT_TIMESTAMP
        query = query.replace("DATETIME DEFAULT CURRENT_TIMESTAMP", "DATETIME DEFAULT GETDATE()")
        # Replace TEXT with NVARCHAR(MAX) for better Unicode support
        query = query.replace(" TEXT ", " NVARCHAR(MAX) ")
        query = query.replace(" TEXT,", " NVARCHAR(MAX),")
        query = query.replace(" TEXT)", " NVARCHAR(MAX))")
        # Handle INSERT OR REPLACE -> MERGE (simplified: just use INSERT for now)
        if "INSERT OR REPLACE" in query.upper():
            # For simplicity, we'll handle this with a DELETE + INSERT pattern
            # In production, you'd want proper MERGE statements
            query = query.replace("INSERT OR REPLACE", "INSERT")
        # Handle CREATE TABLE IF NOT EXISTS
        if "CREATE TABLE IF NOT EXISTS" in query:
            table_name = query.split("CREATE TABLE IF NOT EXISTS")[1].split("(")[0].strip()
            query = f"""
            IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='{table_name}' AND xtype='U')
            {query.replace('CREATE TABLE IF NOT EXISTS', 'CREATE TABLE')}
            """
        # Handle LIMIT -> TOP (MSSQL uses TOP instead of LIMIT)
        # Pattern: SELECT ... FROM ... LIMIT N -> SELECT TOP N ... FROM ...
        limit_match = re.search(r'\bLIMIT\s+(\d+)\s*$', query, re.IGNORECASE)
        if limit_match:
            limit_num = limit_match.group(1)
            # Remove the LIMIT clause
            query = re.sub(r'\bLIMIT\s+\d+\s*$', '', query, flags=re.IGNORECASE)
            # Add TOP after SELECT
            query = re.sub(r'^(\s*SELECT\s+)', rf'\1TOP {limit_num} ', query, flags=re.IGNORECASE)
        return query

    async def commit(self):
        """Commit the current transaction."""
        await self.conn.commit()

    async def close(self):
        """Release connection back to pool or close it."""
        if self._pool:
            await self._pool.release(self.conn)
        else:
            await self.conn.close()

    async def fetchone(self) -> Optional[DatabaseRow]:
        """Fetch the next row of the last query, keyed by cached column names."""
        if self.cursor:
            row = await self.cursor.fetchone()
            if row:
                return DatabaseRow(dict(zip(self._columns, row)))
        return None

    async def fetchall(self) -> List[DatabaseRow]:
        """Fetch all remaining rows of the last query as DatabaseRow objects."""
        if self.cursor:
            rows = await self.cursor.fetchall()
            return [DatabaseRow(dict(zip(self._columns, row))) for row in rows]
        return []

    @property
    def lastrowid(self) -> int:
        """SCOPE_IDENTITY() captured from the most recent INSERT (0 if none)."""
        return self._last_id
async def get_db_connection() -> AsyncDatabaseConnection:
    """Get an async database connection based on configuration.

    MSSQL connections are drawn from the shared pool for performance;
    SQLite simply opens a fresh connection (pooling buys nothing there).
    """
    if DB_TYPE != "mssql":
        import aiosqlite
        raw = await aiosqlite.connect(DB_PATH)
        return AsyncSQLiteConnection(raw)
    pool = await _get_mssql_pool()
    raw = await pool.acquire()
    return AsyncMSSQLConnection(raw, pool=pool)
async def init_db():
    """Create all tables, migrations and indexes for the configured backend."""
    conn = await get_db_connection()
    try:
        initializer = _init_mssql_tables if DB_TYPE == "mssql" else _init_sqlite_tables
        await initializer(conn)
        await conn.commit()
        logger.info(f"Database initialized successfully ({DB_TYPE})")
    except Exception as e:
        logger.error(f"Failed to initialize database: {e}")
        raise
    finally:
        await conn.close()
async def _sqlite_exec_best_effort(conn: AsyncDatabaseConnection, query: str):
    """Execute *query*, ignoring failures (idempotent schema migrations).

    Fix: uses ``except Exception`` instead of a bare ``except`` so that
    KeyboardInterrupt/SystemExit still propagate.
    """
    try:
        await conn.execute(query)
    except Exception:
        pass  # e.g. ALTER TABLE failing because the column already exists


async def _init_sqlite_tables(conn: AsyncDatabaseConnection):
    """Initialize SQLite tables, apply column migrations, create indexes."""
    # Users table
    await conn.execute('''
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE NOT NULL,
            hashed_password TEXT NOT NULL,
            role TEXT NOT NULL,
            api_key TEXT
        )
    ''')
    # Models table
    await conn.execute('''
        CREATE TABLE IF NOT EXISTS models (
            id TEXT PRIMARY KEY,
            hf_repo_id TEXT NOT NULL,
            status TEXT NOT NULL,
            progress INTEGER DEFAULT 0,
            log TEXT DEFAULT '',
            error_details TEXT DEFAULT '',
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            completed_at DATETIME,
            completed_quants TEXT DEFAULT '',
            ignore_space_check INTEGER DEFAULT 0,
            quants_to_run TEXT DEFAULT '',
            enable_shard_merging INTEGER DEFAULT 1,
            requested_by TEXT
        )
    ''')
    # Migrations: add models columns introduced after the first release.
    # Each ALTER fails harmlessly when the column already exists.
    for migration in (
        "ALTER TABLE models ADD COLUMN completed_quants TEXT DEFAULT ''",
        "ALTER TABLE models ADD COLUMN ignore_space_check INTEGER DEFAULT 0",
        "ALTER TABLE models ADD COLUMN quants_to_run TEXT DEFAULT ''",
        "ALTER TABLE models ADD COLUMN enable_shard_merging INTEGER DEFAULT 1",
        "ALTER TABLE models ADD COLUMN requested_by TEXT",
    ):
        await _sqlite_exec_best_effort(conn, migration)
    # Requests table
    await conn.execute('''
        CREATE TABLE IF NOT EXISTS requests (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            hf_repo_id TEXT NOT NULL,
            requested_by TEXT,
            status TEXT DEFAULT 'pending',
            decline_reason TEXT DEFAULT '',
            requested_quants TEXT DEFAULT '',
            approved_quants TEXT DEFAULT '',
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    ''')
    # Migrations: add quant columns for existing databases.
    for migration in (
        "ALTER TABLE requests ADD COLUMN requested_quants TEXT DEFAULT ''",
        "ALTER TABLE requests ADD COLUMN approved_quants TEXT DEFAULT ''",
    ):
        await _sqlite_exec_best_effort(conn, migration)
    # OAuth users table
    await conn.execute('''
        CREATE TABLE IF NOT EXISTS oauth_users (
            id TEXT PRIMARY KEY,
            username TEXT NOT NULL,
            email TEXT,
            avatar_url TEXT,
            session_token TEXT,
            role TEXT DEFAULT 'user',
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    ''')
    # Tickets table
    await conn.execute('''
        CREATE TABLE IF NOT EXISTS tickets (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            request_id INTEGER NOT NULL,
            status TEXT DEFAULT 'open',
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            closed_at DATETIME,
            FOREIGN KEY (request_id) REFERENCES requests(id)
        )
    ''')
    # Ticket messages table
    await conn.execute('''
        CREATE TABLE IF NOT EXISTS ticket_messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            ticket_id INTEGER NOT NULL,
            sender TEXT NOT NULL,
            sender_role TEXT NOT NULL,
            message TEXT NOT NULL,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (ticket_id) REFERENCES tickets(id)
        )
    ''')
    # User preferences table
    await conn.execute('''
        CREATE TABLE IF NOT EXISTS user_preferences (
            hf_username TEXT PRIMARY KEY,
            default_quants TEXT DEFAULT '',
            theme TEXT DEFAULT 'dark',
            notifications_enabled INTEGER DEFAULT 1,
            auto_refresh_interval INTEGER DEFAULT 30
        )
    ''')
    # Quant priority table for admin-configurable quantization order
    # (CHECK (id = 1) makes it a single-row table)
    await conn.execute('''
        CREATE TABLE IF NOT EXISTS quant_priority (
            id INTEGER PRIMARY KEY CHECK (id = 1),
            priority_order TEXT DEFAULT '',
            updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    ''')
    # Initialize default quant priority if not exists
    await conn.execute('''
        INSERT OR IGNORE INTO quant_priority (id, priority_order) VALUES (1, '')
    ''')
    # Performance indexes for frequently queried columns
    for index_sql in (
        'CREATE INDEX IF NOT EXISTS idx_models_status ON models(status)',
        'CREATE INDEX IF NOT EXISTS idx_models_hf_repo_id ON models(hf_repo_id)',
        'CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at DESC)',
        'CREATE INDEX IF NOT EXISTS idx_requests_status ON requests(status)',
        'CREATE INDEX IF NOT EXISTS idx_requests_requested_by ON requests(requested_by)',
        'CREATE INDEX IF NOT EXISTS idx_requests_created_at ON requests(created_at DESC)',
        'CREATE INDEX IF NOT EXISTS idx_tickets_status ON tickets(status)',
        'CREATE INDEX IF NOT EXISTS idx_tickets_request_id ON tickets(request_id)',
        'CREATE INDEX IF NOT EXISTS idx_ticket_messages_ticket_id ON ticket_messages(ticket_id)',
        'CREATE INDEX IF NOT EXISTS idx_oauth_users_session_token ON oauth_users(session_token)',
        'CREATE INDEX IF NOT EXISTS idx_users_api_key ON users(api_key)',
    ):
        await conn.execute(index_sql)
async def _mssql_exec_best_effort(conn: AsyncDatabaseConnection, query: str):
    """Execute *query* and commit, ignoring failures (idempotent DDL).

    Fix: uses ``except Exception`` instead of a bare ``except`` so that
    KeyboardInterrupt/SystemExit still propagate.
    """
    try:
        await conn.execute(query)
        await conn.commit()
    except Exception:
        pass  # e.g. column or index already exists


async def _init_mssql_tables(conn: AsyncDatabaseConnection):
    """Initialize MSSQL tables, apply column migrations, create indexes."""
    # Users table
    await conn.execute('''
        IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='users' AND xtype='U')
        CREATE TABLE users (
            id INT PRIMARY KEY IDENTITY(1,1),
            username NVARCHAR(255) UNIQUE NOT NULL,
            hashed_password NVARCHAR(MAX) NOT NULL,
            role NVARCHAR(50) NOT NULL,
            api_key NVARCHAR(255)
        )
    ''')
    await conn.commit()
    # Models table
    await conn.execute('''
        IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='models' AND xtype='U')
        CREATE TABLE models (
            id NVARCHAR(255) PRIMARY KEY,
            hf_repo_id NVARCHAR(500) NOT NULL,
            status NVARCHAR(50) NOT NULL,
            progress INT DEFAULT 0,
            log NVARCHAR(MAX) DEFAULT '',
            error_details NVARCHAR(MAX) DEFAULT '',
            created_at DATETIME DEFAULT GETDATE(),
            completed_at DATETIME,
            completed_quants NVARCHAR(MAX) DEFAULT '',
            ignore_space_check BIT DEFAULT 0,
            quants_to_run NVARCHAR(MAX) DEFAULT '',
            enable_shard_merging BIT DEFAULT 1,
            requested_by NVARCHAR(255)
        )
    ''')
    await conn.commit()
    # Migrations: add models columns introduced after the first release.
    # Each statement is guarded by INFORMATION_SCHEMA and committed; failures
    # are ignored (best-effort, e.g. insufficient DDL permission).
    for migration in (
        '''
        IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'models' AND COLUMN_NAME = 'completed_quants')
        ALTER TABLE models ADD completed_quants NVARCHAR(MAX) DEFAULT ''
        ''',
        '''
        IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'models' AND COLUMN_NAME = 'ignore_space_check')
        ALTER TABLE models ADD ignore_space_check BIT DEFAULT 0
        ''',
        '''
        IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'models' AND COLUMN_NAME = 'quants_to_run')
        ALTER TABLE models ADD quants_to_run NVARCHAR(MAX) DEFAULT ''
        ''',
        '''
        IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'models' AND COLUMN_NAME = 'enable_shard_merging')
        ALTER TABLE models ADD enable_shard_merging BIT DEFAULT 1
        ''',
        '''
        IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'models' AND COLUMN_NAME = 'requested_by')
        ALTER TABLE models ADD requested_by NVARCHAR(255)
        ''',
    ):
        await _mssql_exec_best_effort(conn, migration)
    # Requests table
    await conn.execute('''
        IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='requests' AND xtype='U')
        CREATE TABLE requests (
            id INT PRIMARY KEY IDENTITY(1,1),
            hf_repo_id NVARCHAR(500) NOT NULL,
            requested_by NVARCHAR(255),
            status NVARCHAR(50) DEFAULT 'pending',
            decline_reason NVARCHAR(MAX) DEFAULT '',
            requested_quants NVARCHAR(MAX) DEFAULT '',
            approved_quants NVARCHAR(MAX) DEFAULT '',
            created_at DATETIME DEFAULT GETDATE()
        )
    ''')
    await conn.commit()
    # Migrations: add quant columns for existing databases.
    for migration in (
        '''
        IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'requests' AND COLUMN_NAME = 'requested_quants')
        ALTER TABLE requests ADD requested_quants NVARCHAR(MAX) DEFAULT ''
        ''',
        '''
        IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'requests' AND COLUMN_NAME = 'approved_quants')
        ALTER TABLE requests ADD approved_quants NVARCHAR(MAX) DEFAULT ''
        ''',
    ):
        await _mssql_exec_best_effort(conn, migration)
    # OAuth users table
    await conn.execute('''
        IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='oauth_users' AND xtype='U')
        CREATE TABLE oauth_users (
            id NVARCHAR(255) PRIMARY KEY,
            username NVARCHAR(255) NOT NULL,
            email NVARCHAR(255),
            avatar_url NVARCHAR(500),
            session_token NVARCHAR(255),
            role NVARCHAR(50) DEFAULT 'user',
            created_at DATETIME DEFAULT GETDATE()
        )
    ''')
    await conn.commit()
    # Migration: add role column to oauth_users if it doesn't exist
    await _mssql_exec_best_effort(conn, '''
        IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = 'oauth_users' AND COLUMN_NAME = 'role')
        ALTER TABLE oauth_users ADD role NVARCHAR(50) DEFAULT 'user'
    ''')
    # Tickets table
    await conn.execute('''
        IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='tickets' AND xtype='U')
        CREATE TABLE tickets (
            id INT PRIMARY KEY IDENTITY(1,1),
            request_id INT NOT NULL,
            status NVARCHAR(50) DEFAULT 'open',
            created_at DATETIME DEFAULT GETDATE(),
            closed_at DATETIME,
            FOREIGN KEY (request_id) REFERENCES requests(id)
        )
    ''')
    await conn.commit()
    # Ticket messages table
    await conn.execute('''
        IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='ticket_messages' AND xtype='U')
        CREATE TABLE ticket_messages (
            id INT PRIMARY KEY IDENTITY(1,1),
            ticket_id INT NOT NULL,
            sender NVARCHAR(255) NOT NULL,
            sender_role NVARCHAR(50) NOT NULL,
            message NVARCHAR(MAX) NOT NULL,
            created_at DATETIME DEFAULT GETDATE(),
            FOREIGN KEY (ticket_id) REFERENCES tickets(id)
        )
    ''')
    await conn.commit()
    # User preferences table
    await conn.execute('''
        IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='user_preferences' AND xtype='U')
        CREATE TABLE user_preferences (
            hf_username NVARCHAR(255) PRIMARY KEY,
            default_quants NVARCHAR(MAX) DEFAULT '',
            theme NVARCHAR(20) DEFAULT 'dark',
            notifications_enabled BIT DEFAULT 1,
            auto_refresh_interval INT DEFAULT 30
        )
    ''')
    await conn.commit()
    # Quant priority table for admin-configurable quantization order
    # (CHECK (id = 1) makes it a single-row table)
    await conn.execute('''
        IF NOT EXISTS (SELECT * FROM sysobjects WHERE name='quant_priority' AND xtype='U')
        CREATE TABLE quant_priority (
            id INT PRIMARY KEY CHECK (id = 1),
            priority_order NVARCHAR(MAX) DEFAULT '',
            updated_at DATETIME DEFAULT GETDATE()
        )
    ''')
    await conn.commit()
    # Initialize default quant priority if not exists
    await _mssql_exec_best_effort(conn, '''
        IF NOT EXISTS (SELECT * FROM quant_priority WHERE id = 1)
        INSERT INTO quant_priority (id, priority_order) VALUES (1, '')
    ''')
    # Performance indexes for frequently queried columns
    index_queries = [
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_models_status') CREATE INDEX idx_models_status ON models(status)",
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_models_hf_repo_id') CREATE INDEX idx_models_hf_repo_id ON models(hf_repo_id)",
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_models_created_at') CREATE INDEX idx_models_created_at ON models(created_at DESC)",
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_requests_status') CREATE INDEX idx_requests_status ON requests(status)",
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_requests_requested_by') CREATE INDEX idx_requests_requested_by ON requests(requested_by)",
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_requests_created_at') CREATE INDEX idx_requests_created_at ON requests(created_at DESC)",
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_tickets_status') CREATE INDEX idx_tickets_status ON tickets(status)",
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_tickets_request_id') CREATE INDEX idx_tickets_request_id ON tickets(request_id)",
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_ticket_messages_ticket_id') CREATE INDEX idx_ticket_messages_ticket_id ON ticket_messages(ticket_id)",
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_oauth_users_session_token') CREATE INDEX idx_oauth_users_session_token ON oauth_users(session_token)",
        "IF NOT EXISTS (SELECT * FROM sys.indexes WHERE name = 'idx_users_api_key') CREATE INDEX idx_users_api_key ON users(api_key)",
    ]
    for query in index_queries:
        await _mssql_exec_best_effort(conn, query)
async def test_connection() -> tuple[bool, str]:
    """Test database connection. Returns (success, message).

    Fix: the connection is now closed even when the probe query raises,
    so a failed health check no longer leaks a pooled MSSQL connection.
    """
    try:
        conn = await get_db_connection()
        try:
            await conn.execute("SELECT 1")
        finally:
            await conn.close()
        return True, f"Successfully connected to {DB_TYPE} database"
    except Exception as e:
        return False, f"Failed to connect to {DB_TYPE} database: {str(e)}"