-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatabase.py
More file actions
94 lines (74 loc) · 3.29 KB
/
database.py
File metadata and controls
94 lines (74 loc) · 3.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from datetime import datetime

from sqlalchemy import (
    BigInteger,
    Column,
    DateTime,
    ForeignKey,
    Integer,
    String,
    Table,
    Text,
    create_engine,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from sqlalchemy.sql import func
# Shared declarative base: every model in this module registers its table
# on Base.metadata, which Database.create_tables() later hands to create_all().
Base = declarative_base()


# Association table behind the Meme.tags / Tag.memes relationships.
# Both foreign keys are flagged primary_key, giving a composite primary key,
# so a given (meme_id, tag_id) pair can be stored at most once.
meme_tags = Table(
    "meme_tags",
    Base.metadata,
    Column("meme_id", Integer, ForeignKey("memes.id"), primary_key=True),
    Column("tag_id", Integer, ForeignKey("tags.id"), primary_key=True),
)
class Meme(Base):
    """ORM model mapped to the ``memes`` table.

    Going by the column names, a row pairs a Telegram file id with the S3 key
    the file is stored under, plus file metadata, the submitting Telegram
    user, and optional analysis output. The code that fills these columns is
    not part of this module.

    ``tags`` is a many-to-many relationship to ``Tag`` through the
    ``meme_tags`` association table.
    """

    __tablename__ = 'memes'

    id = Column(Integer, primary_key=True)

    # Required and unique: two rows can never share a Telegram file id.
    telegram_file_id = Column(String(255), unique=True, nullable=False)
    s3_key = Column(String(500), nullable=False)
    file_type = Column(String(50), nullable=False)

    # Optional file metadata.
    file_name = Column(String(255))
    mime_type = Column(String(100))
    file_size = Column(Integer)
    caption = Column(Text)

    # Telegram documents user ids as having up to 52 significant bits, so a
    # 32-bit INTEGER (what Integer maps to on PostgreSQL/MySQL) can overflow.
    # BigInteger is 64-bit there; SQLite stores integers in up to 64 bits
    # either way. Databases created before this change keep their old column
    # type until the column is altered manually.
    telegram_user_id = Column(BigInteger, nullable=False)
    telegram_username = Column(String(255))

    # func.now() is a SQL expression, so the timestamp is produced by the
    # database at INSERT time when no value is supplied.
    created_at = Column(DateTime, default=func.now())

    # Nullable analysis columns. These are the same five columns that
    # Database._sqlite_migrate() adds to an existing SQLite table when absent.
    ocr_text = Column(Text)
    # Stored as text; the serialisation format is not defined in this module.
    clip_embedding = Column(Text)
    image_description = Column(Text)
    processed_at = Column(DateTime)
    processing_error = Column(Text)

    tags = relationship("Tag", secondary=meme_tags, back_populates="memes")

    def __repr__(self):
        """Return a short debug representation with id, file type and creation time."""
        return f"<Meme(id={self.id}, file_type={self.file_type}, created_at={self.created_at})>"
class Tag(Base):
    """ORM model mapped to the ``tags`` table.

    ``memes`` is the reverse side of ``Meme.tags``; both go through the
    ``meme_tags`` association table.
    """

    __tablename__ = "tags"

    id = Column(Integer, primary_key=True)
    # Required, unique and indexed.
    name = Column(String(100), unique=True, nullable=False, index=True)
    # func.now() is a SQL expression, so the database supplies the timestamp
    # at INSERT time when no value is given.
    created_at = Column(DateTime, default=func.now())

    memes = relationship(
        "Meme",
        secondary=meme_tags,
        back_populates="tags",
    )

    def __repr__(self):
        """Return a short debug representation with the tag's id and name."""
        tag_id, tag_name = self.id, self.name
        return f"<Tag(id={tag_id}, name={tag_name})>"
class Database:
    """Hold the SQLAlchemy engine and session factory for this module's models.

    Attributes:
        engine: Engine created from the URL passed to ``__init__``.
        SessionLocal: ``sessionmaker`` bound to ``engine``; calling it returns
            a new session.
    """

    def __init__(self, database_url: str) -> None:
        """Create the engine and the session factory.

        Args:
            database_url: SQLAlchemy connection URL.
        """
        # pool_pre_ping makes the pool test a connection when it is checked
        # out and replace it if it has gone away. It used to be enabled only
        # for SQLite URLs, leaving networked backends (where stale pooled
        # connections actually occur) without it. It is now applied to every
        # backend. echo=False is SQLAlchemy's default and is kept explicit.
        self.engine = create_engine(
            database_url,
            echo=False,
            pool_pre_ping=True,
        )
        self.SessionLocal = sessionmaker(
            autocommit=False,
            autoflush=False,
            bind=self.engine,
        )

    def create_tables(self) -> None:
        """Create the tables registered on ``Base.metadata``.

        On a SQLite backend this also runs ``_sqlite_migrate`` so that an
        already existing ``memes`` table gains any columns it is missing.
        """
        Base.metadata.create_all(bind=self.engine)
        if self.engine.url.get_backend_name() == 'sqlite':
            self._sqlite_migrate()

    def _sqlite_migrate(self) -> None:
        """Add the analysis columns to ``memes`` when they are missing.

        ``create_all`` only creates tables that do not exist yet, so a
        ``memes`` table created earlier would otherwise lack these columns.
        All statements run inside a single ``engine.begin()`` transaction.
        Columns that already exist are skipped, so the method can be called
        repeatedly.
        """
        # (column name, SQLite column type) pairs, applied in this order.
        wanted_columns = (
            ('ocr_text', 'TEXT'),
            ('clip_embedding', 'TEXT'),
            ('image_description', 'TEXT'),
            ('processed_at', 'DATETIME'),
            ('processing_error', 'TEXT'),
        )
        with self.engine.begin() as conn:
            table_info = conn.exec_driver_sql("PRAGMA table_info(memes)").fetchall()
            # PRAGMA table_info rows are (cid, name, type, ...); index 1 is the name.
            existing = {row[1] for row in table_info}
            for column_name, column_type in wanted_columns:
                if column_name in existing:
                    continue
                # Names and types come from the constant tuple above, never
                # from external input, so building the DDL string is safe.
                conn.exec_driver_sql(
                    f'ALTER TABLE memes ADD COLUMN {column_name} {column_type}'
                )

    def get_session(self):
        """Return a new session from ``SessionLocal``; the caller must close it."""
        return self.SessionLocal()

    def close_session(self, session) -> None:
        """Close *session* by calling its ``close()`` method."""
        session.close()