Skip to content

Commit fac4e7d

Browse files
committed
fix: added missing methods
1 parent 6fa11af commit fac4e7d

1 file changed

Lines changed: 108 additions & 8 deletions

File tree

userbot/src/db_manager.py

Lines changed: 108 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,68 +12,145 @@
1212

1313
logger: logging.Logger = logging.getLogger(__name__)
1414

15-
# --- Account Management ---
16-
async def add_account(db: AsyncSession, account_name: str, api_id: str, api_hash: str, lang_code: str, is_enabled: bool, device_model: str, system_version: str, app_version: str, user_telegram_id: Optional[int] = None) -> Optional[Account]:
15+
# --- Account CRUD ---
16+
async def add_account(
17+
db: AsyncSession,
18+
account_name: str,
19+
api_id: str,
20+
api_hash: str,
21+
lang_code: str,
22+
is_enabled: bool,
23+
device_model: str,
24+
system_version: str,
25+
app_version: str,
26+
user_telegram_id: Optional[int] = None
27+
) -> Optional[Account]:
28+
"""Adds a new account to the database with encrypted credentials."""
1729
try:
18-
new_account = Account(account_name=account_name, api_id=encryption_manager.encrypt(str(api_id).encode('utf-8')), api_hash=encryption_manager.encrypt(str(api_hash).encode('utf-8')), lang_code=lang_code, is_enabled=is_enabled, device_model=device_model, system_version=system_version, app_version=app_version, user_telegram_id=user_telegram_id)
30+
new_account = Account(
31+
account_name=account_name,
32+
api_id=encryption_manager.encrypt(str(api_id).encode('utf-8')),
33+
api_hash=encryption_manager.encrypt(str(api_hash).encode('utf-8')),
34+
lang_code=lang_code,
35+
is_enabled=is_enabled,
36+
device_model=device_model,
37+
system_version=system_version,
38+
app_version=app_version,
39+
user_telegram_id=user_telegram_id
40+
)
1941
db.add(new_account)
2042
await db.flush()
43+
logger.info(f"Added account '{account_name}' with ID: {new_account.account_id}")
2144
return new_account
2245
except IntegrityError:
46+
logger.warning(f"Account with name '{account_name}' or user_id '{user_telegram_id}' already exists.")
2347
await db.rollback()
2448
return None
25-
except Exception:
49+
except Exception as e:
50+
logger.error(f"Error adding account '{account_name}': {e}")
2651
await db.rollback()
2752
raise
2853

2954
async def get_account(db: AsyncSession, account_name: str) -> Optional[Account]:
55+
"""Retrieves an account by its name."""
3056
result = await db.execute(select(Account).where(Account.account_name == account_name))
3157
return result.scalars().first()
3258

3359
async def get_account_by_user_id(db: AsyncSession, user_id: int) -> Optional[Account]:
60+
"""Retrieves an account by its Telegram User ID."""
3461
result = await db.execute(select(Account).where(Account.user_telegram_id == user_id))
3562
return result.scalars().first()
3663

3764
async def get_all_accounts(db: AsyncSession) -> List[Account]:
65+
"""Retrieves all accounts from the database, with their sessions for status display."""
3866
result = await db.execute(select(Account).options(selectinload(Account.session)).order_by(Account.account_id))
3967
return result.scalars().all()
4068

4169
async def get_all_active_accounts(db: AsyncSession) -> List[Account]:
70+
"""Retrieves all enabled accounts from the database."""
4271
result = await db.execute(select(Account).where(Account.is_enabled == True))
4372
return result.scalars().all()
4473

4574
async def delete_account(db: AsyncSession, account_name: str) -> bool:
75+
"""Deletes an account by its name."""
4676
account = await get_account(db, account_name)
4777
if not account: return False
4878
await db.delete(account)
4979
await db.flush()
5080
return True
5181

5282
async def toggle_account_status(db: AsyncSession, account_name: str) -> Optional[bool]:
83+
"""Toggles the is_enabled status of an account."""
5384
account = await get_account(db, account_name)
5485
if not account: return None
5586
account.is_enabled = not account.is_enabled
5687
await db.flush()
5788
return account.is_enabled
5889

5990
async def update_account_lang(db: AsyncSession, account_id: int, lang_code: str) -> bool:
91+
"""Updates the language for a specific account."""
6092
stmt = update(Account).where(Account.account_id == account_id).values(lang_code=lang_code)
6193
result = await db.execute(stmt)
6294
return result.rowcount > 0
6395

96+
# --- Session CRUD ---
97+
async def get_session(db: AsyncSession, account_id: int) -> Optional[Session]:
98+
"""Retrieves a session for a given account and decrypts its auth key."""
99+
result = await db.execute(select(Session).where(Session.account_id == account_id))
100+
session = result.scalars().first()
101+
if session and session.auth_key_data:
102+
try:
103+
session.auth_key_data = encryption_manager.decrypt(session.auth_key_data)
104+
except Exception as e:
105+
logger.error(f"Failed to decrypt session auth_key for account {account_id}: {e}")
106+
return None
107+
return session
108+
109+
async def add_or_update_session(db: AsyncSession, **kwargs) -> Optional[Session]:
110+
"""Adds or updates a session in the database, encrypting the auth key."""
111+
account_id = kwargs.get("account_id")
112+
if not account_id: return None
113+
114+
# Use a raw select to avoid decryption/re-encryption cycle of get_session
115+
result = await db.execute(select(Session).where(Session.account_id == account_id))
116+
session = result.scalars().first()
117+
118+
if not session:
119+
session = Session(account_id=account_id)
120+
db.add(session)
121+
122+
for key, value in kwargs.items():
123+
if key == "auth_key_data" and value is not None:
124+
value = encryption_manager.encrypt(value)
125+
setattr(session, key, value)
126+
127+
session.last_used_at = datetime.now(timezone.utc)
128+
await db.flush()
129+
return session
130+
131+
async def delete_session(db: AsyncSession, account_id: int) -> bool:
132+
"""Deletes a session from the database."""
133+
stmt = delete(Session).where(Session.account_id == account_id)
134+
result = await db.execute(stmt)
135+
await db.flush()
136+
return result.rowcount > 0
137+
64138
# --- Module Management ---
65139
async def get_module(db: AsyncSession, module_name: str) -> Optional[Module]:
140+
"""Retrieves a module by its name."""
66141
result = await db.execute(select(Module).where(Module.module_name == module_name))
67142
return result.scalars().first()
68143

69144
async def get_all_modules(db: AsyncSession) -> List[Module]:
145+
"""Retrieves all modules from the database."""
70146
result = await db.execute(select(Module))
71147
return result.scalars().all()
72148

73149
async def add_module(db: AsyncSession, module_name: str, module_path: str) -> Optional[Module]:
150+
"""Adds a new module or returns the existing one."""
74151
module = await get_module(db, module_name)
75152
if module:
76-
module.module_path = module_path # Update path if it changed
153+
module.module_path = module_path
77154
await db.flush()
78155
return module
79156
new_module = Module(module_name=module_name, module_path=module_path)
@@ -82,10 +159,14 @@ async def add_module(db: AsyncSession, module_name: str, module_path: str) -> Op
82159
return new_module
83160

84161
async def get_account_module(db: AsyncSession, account_id: int, module_id: int) -> Optional[AccountModule]:
85-
result = await db.execute(select(AccountModule).where(AccountModule.account_id == account_id, AccountModule.module_id == module_id))
162+
"""Retrieves the link between an account and a module."""
163+
result = await db.execute(
164+
select(AccountModule).where(AccountModule.account_id == account_id, AccountModule.module_id == module_id)
165+
)
86166
return result.scalars().first()
87167

88168
async def link_module_to_account(db: AsyncSession, account_id: int, module_id: int, is_active: bool, configuration: Optional[Dict[str, Any]]) -> Optional[AccountModule]:
169+
"""Links a module to an account or updates the existing link."""
89170
link = await get_account_module(db, account_id, module_id)
90171
if not link:
91172
link = AccountModule(account_id=account_id, module_id=module_id)
@@ -97,23 +178,38 @@ async def link_module_to_account(db: AsyncSession, account_id: int, module_id: i
97178
return link
98179

99180
async def get_active_modules_for_account(db: AsyncSession, account_id: int) -> List[Dict[str, Any]]:
100-
stmt = select(Module, AccountModule.is_trusted, AccountModule.configuration).join(AccountModule, Module.module_id == AccountModule.module_id).where(AccountModule.account_id == account_id, AccountModule.is_active == True)
181+
"""Retrieves all active modules for a specific account."""
182+
stmt = (
183+
select(Module, AccountModule.is_trusted, AccountModule.configuration)
184+
.join(AccountModule, Module.module_id == AccountModule.module_id)
185+
.where(AccountModule.account_id == account_id, AccountModule.is_active == True)
186+
)
101187
results = await db.execute(stmt)
102-
return [{'module': module, 'is_trusted': is_trusted, 'configuration': configuration} for module, is_trusted, configuration in results.all()]
188+
return [
189+
{
190+
'module': module,
191+
'is_trusted': is_trusted,
192+
'configuration': configuration
193+
}
194+
for module, is_trusted, configuration in results.all()
195+
]
103196

104197
async def unlink_module_from_account(db: AsyncSession, account_id: int, module_id: int) -> bool:
198+
"""Unlinks a module from an account."""
105199
stmt = delete(AccountModule).where(AccountModule.account_id == account_id, AccountModule.module_id == module_id)
106200
result = await db.execute(stmt)
107201
return result.rowcount > 0
108202

109203
async def set_module_trust_status(db: AsyncSession, account_id: int, module_id: int, is_trusted: bool) -> bool:
204+
"""Sets the trust status for a module link."""
110205
link = await get_account_module(db, account_id, module_id)
111206
if not link: return False
112207
link.is_trusted = is_trusted
113208
await db.flush()
114209
return True
115210

116211
async def update_module_config(db: AsyncSession, account_id: int, module_id: int, config_key: str, config_value: Any) -> bool:
212+
"""Updates a single key in a module's JSONB configuration."""
117213
link = await get_account_module(db, account_id, module_id)
118214
if not link: return False
119215
if link.configuration is None:
@@ -126,6 +222,7 @@ async def update_module_config(db: AsyncSession, account_id: int, module_id: int
126222

127223
# --- Log Management ---
128224
async def add_logs_bulk(db: AsyncSession, logs: List[Dict[str, Any]]) -> None:
225+
"""Adds a batch of log entries to the database."""
129226
if not logs: return
130227
try:
131228
db.add_all([Log(**log_data) for log_data in logs])
@@ -134,6 +231,7 @@ async def add_logs_bulk(db: AsyncSession, logs: List[Dict[str, Any]]) -> None:
134231
print(f"CRITICAL: Error during bulk log insert: {e}")
135232

136233
async def get_logs_filtered(db: AsyncSession, limit: int, level: Optional[str] = None, source: Optional[str] = None) -> List[Log]:
234+
"""Retrieves logs from the database with optional filters."""
137235
stmt = select(Log).order_by(Log.timestamp.desc()).limit(limit)
138236
if level:
139237
stmt = stmt.where(Log.level == level.upper())
@@ -143,12 +241,14 @@ async def get_logs_filtered(db: AsyncSession, limit: int, level: Optional[str] =
143241
return result.scalars().all()
144242

145243
async def delete_old_logs(db: AsyncSession, days_to_keep: int) -> int:
244+
"""Deletes log entries older than a specified number of days."""
146245
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
147246
stmt = delete(Log).where(Log.timestamp < cutoff_date)
148247
result = await db.execute(stmt)
149248
return result.rowcount
150249

151250
async def purge_logs(db: AsyncSession) -> int:
251+
"""Deletes all log entries from the database."""
152252
stmt = delete(Log)
153253
result = await db.execute(stmt)
154254
return result.rowcount

0 commit comments

Comments
 (0)