Skip to content

Commit c99cf7b

Browse files
committed
Refactor file download logic: enhance filename validation, ensure uniqueness, and update message retrieval methods
1 parent 0c52297 commit c99cf7b

3 files changed

Lines changed: 30 additions & 52 deletions

File tree

src/db.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ class Group(Base):
4141
group_id = Column(Integer, nullable=False)
4242
name = Column(String, nullable=False)
4343
last_update = Column(DateTime, nullable=False, default=datetime.datetime.now(datetime.timezone.utc))
44-
latest_msgid = Column(Integer, nullable=False)
4544

4645

4746
class Message(Base):
@@ -139,8 +138,8 @@ def update_user_session(user_id: int, new_session: str) -> None:
139138
def add_group(server_id: int, group_id: int, name: str) -> None:
140139
with SessionLocal() as session:
141140
group = Group(server_id=server_id, group_id=group_id, name=name,
142-
last_update=datetime.datetime.now(datetime.timezone.utc),
143-
latest_msgid=0)
141+
last_update=datetime.datetime.now(datetime.timezone.utc)
142+
)
144143
session.add(group)
145144
session.commit()
146145

@@ -154,14 +153,6 @@ def update_group_name(group_id: int, server_id: int, new_name: str) -> None:
154153
session.commit()
155154

156155

157-
def update_group_msgid(group_id: int, server_id: int, msgid: int) -> None:
158-
with SessionLocal() as session:
159-
group = session.query(Group).filter_by(group_id=group_id, server_id=server_id).first()
160-
if group:
161-
group.latest_msgid = msgid
162-
session.commit()
163-
164-
165156
def get_group_msgid(group_id: int, server_id: int, latest: bool = True) -> Optional[int]:
166157
if latest:
167158
data = get_messages(server_id, group_id, MAX_MSGID, False, 1)
@@ -175,7 +166,6 @@ def get_group_msgid(group_id: int, server_id: int, latest: bool = True) -> Optio
175166
return MAX_MSGID
176167

177168

178-
179169
async def get_group_name(server_id: int, user: StealthIM.User, group_id: int, force_flush=False) -> StealthIM.group.GroupPublicInfoResult:
180170
with SessionLocal() as session:
181171
group = session.query(Group).filter_by(group_id=group_id, server_id=server_id).first()

src/screens/chat.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import datetime
33
import math
4+
import os
45
from typing import Optional, cast
56

67
import platformdirs
@@ -283,7 +284,22 @@ async def action_download(self):
283284

284285
download_path = platformdirs.user_downloads_path()
285286
hashes = [msg.hash for msg in self.selected]
286-
filenames = [msg.text for msg in files]
287+
filenames = []
288+
for msg in files:
289+
filename = msg.text
290+
# Remove invalid characters
291+
filename = "".join(c for c in filename if c not in r'\/:*?"<>|')
292+
if not filename:
293+
filename = msg.hash
294+
# Ensure the filename is unique
295+
file_path = os.path.join(download_path, filename)
296+
if os.path.exists(file_path):
297+
base, ext = os.path.splitext(filename)
298+
i = 1
299+
while os.path.exists(os.path.join(download_path, f"{base}({i}){ext}")):
300+
i += 1
301+
filename = f"{base}({i}){ext}"
302+
filenames.append(filename)
287303
await self.app.data.group.download_files(hashes, filenames, download_path)
288304

289305
@staticmethod
@@ -298,7 +314,7 @@ class ChatScreen(Screen):
298314
SCREEN_NAME = "Chat"
299315
CSS_PATH = "../../styles/chat.tcss"
300316

301-
LIMIT = 100
317+
LIMIT = 10
302318

303319
BINDINGS = [("ctrl+s", "select_msg", "Select message")]
304320

@@ -396,7 +412,7 @@ async def on_change_group(self, event: ListView.Selected) -> None:
396412
else:
397413
# A new group, we only get the newest LIMIT messages
398414
# from_id=0, old_to_new=False means pull the latest messages
399-
gen = self.group.receive_text(from_id=0, old_to_new=False, sync=False, limit=self.LIMIT)
415+
gen = self.group.receive_latest_text(limit=self.LIMIT)
400416
msgs = [x async for x in gen][::-1]
401417
for msg in msgs:
402418
message = db.add_message(
@@ -442,8 +458,8 @@ async def on_no_more_message(self, _event: TopDetectingScroll.ScrolledToTop):
442458
else:
443459
# There's no more messages in the database, try to pull from server
444460
oldest_msgid = db.get_group_msgid(self.group.group_id, self.app.data.server_db.id, False)
445-
gen = self.group.receive_text(from_id=oldest_msgid, old_to_new=False, sync=False, limit=self.LIMIT)
446-
msgs = [x async for x in gen][::-1]
461+
gen = self.group.receive_text(from_id=oldest_msgid, sync=False, limit=self.LIMIT)
462+
msgs = [x async for x in gen]
447463
self.log(msgs)
448464
for msg in msgs:
449465
message = db.add_message(
@@ -605,23 +621,21 @@ async def get_messages(self, messages: VerticalScroll) -> None:
605621
group_id = self.group.group_id
606622

607623
while True:
608-
latest_msgid = db.get_group_msgid(group_id, server_id)
609624
try:
610-
gen = self.group.receive_text(from_id=latest_msgid)
625+
gen = self.group.receive_new_text(limit=self.LIMIT)
611626
async for message in gen:
612627
msg = db.add_message(
613628
server_id, group_id, message.type.value,
614629
message.msg.replace("\n", "\n\n"),
615630
datetime.datetime.fromtimestamp(int(message.time)), message.username,
616631
message.msgid, message.hash
617632
)
618-
db.update_group_msgid(group_id, server_id, message.msgid)
619633

620-
if message.type != MessageType.Recall:
621-
await self.add_message(messages, self.build_msg_from_db(msg))
622-
else:
623-
db.recall_message(server_id, group_id, message.msgid)
624-
await self.recall_message(messages, message.msgid)
634+
# if message.type != MessageType.Recall:
635+
await self.add_message(messages, self.build_msg_from_db(msg))
636+
# else:
637+
# db.recall_message(server_id, group_id, message.msgid)
638+
# await self.recall_message(messages, message.msgid)
625639
except RuntimeError:
626640
pass
627641
except asyncio.CancelledError:

src/screens/widgets.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -109,32 +109,6 @@ def on_changed(self, _event) -> None:
109109
self.selected = None
110110

111111

112-
# noinspection PyProtectedMember
113-
def add_global_hook(screen):
114-
if not hasattr(screen, "_popup_callback_click"):
115-
screen._popups = []
116-
117-
# noinspection PyUnresolvedReferences,PyProtectedMember
118-
async def callback(_self, event: Click) -> None:
119-
for obj in screen._popups:
120-
await obj.on_global_click(event)
121-
122-
screen._popup_callback_click = callback
123-
screen._decorated_handlers.setdefault(Click, [])
124-
screen._decorated_handlers[Click].append((callback, None))
125-
if not hasattr(screen, "_popup_callback_key"):
126-
screen._popups = []
127-
128-
# noinspection PyUnresolvedReferences,PyProtectedMember
129-
async def callback(_self, event: Key) -> None:
130-
for obj in screen._popups:
131-
await obj.on_global_key(event)
132-
133-
screen._popup_callback_key = callback
134-
screen._decorated_handlers.setdefault(Key, [])
135-
screen._decorated_handlers[Key].append((callback, None))
136-
137-
138112
class AbstractPopup(Widget):
139113
def __init__(self, id):
140114
super().__init__(id=id)
@@ -165,7 +139,7 @@ async def callback(_self, event: Key) -> None:
165139
screen._decorated_handlers[Key].append((callback, None))
166140

167141
def on_mount(self) -> None:
168-
add_global_hook(self.app.screen)
142+
self.add_global_hook(self.app.screen)
169143
# noinspection PyProtectedMember,PyUnresolvedReferences
170144
self.app.screen._popups.append(self)
171145

0 commit comments

Comments
 (0)