Skip to content

Commit 0c52297

Browse files
committed
Add message selection and file download features: implement MessageSelectContainer, enhance chat UI, and integrate file size retrieval
1 parent 5c82964 commit 0c52297

9 files changed

Lines changed: 352 additions & 97 deletions

File tree

SDK/StealthIM/apis/file.py

Whitespace-only changes.

src/db.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import StealthIM
99
import codes
10-
import log
10+
from StealthIM.apis.message import MessageType
1111

1212
DB_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../data/configs.sqlite"))
1313
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
@@ -66,6 +66,15 @@ class Nickname(Base):
6666
last_update = Column(DateTime, nullable=False, default=datetime.datetime.now(datetime.timezone.utc))
6767

6868

69+
class FileHash(Base):
70+
__tablename__ = "file_hashes"
71+
id = Column(Integer, primary_key=True, autoincrement=True)
72+
server_id = Column(Integer, nullable=False)
73+
group_id = Column(Integer, nullable=False)
74+
hash = Column(String, nullable=False)
75+
size = Column(Integer, nullable=False)
76+
77+
6978
Base.metadata.create_all(bind=engine)
7079

7180

@@ -74,6 +83,11 @@ def load_servers_from_db() -> list[Server]:
7483
return cast(list[Server], session.query(Server).all())
7584

7685

86+
def get_server_from_db(url: str) -> Optional[Server]:
87+
with SessionLocal() as session:
88+
return session.query(Server).filter_by(url=url).first()
89+
90+
7791
def save_server_to_db(name: str, url: str) -> None:
7892
with SessionLocal() as session:
7993
server = Server(name=name, url=url)
@@ -249,6 +263,21 @@ def add_message(
249263
return msg
250264

251265

266+
def recall_message(
267+
server_id: int,
268+
group_id: int,
269+
msgid: int,
270+
):
271+
with SessionLocal() as session:
272+
msg = session.query(Message).filter_by(server_id=server_id, group_id=group_id, msgid=msgid).first()
273+
if not msg:
274+
return
275+
msg.type = MessageType.Recall
276+
msg.msg = ""
277+
session.add(msg)
278+
session.commit()
279+
280+
252281
def get_latest_messages(
253282
server_id: int,
254283
group_id: int,
@@ -299,3 +328,31 @@ def get_messages(
299328
.from_statement(subquery.select().order_by(subquery.c.msgid.asc()))
300329
.all()
301330
)
331+
332+
333+
def add_file_size(server_id: int, group_id: int, hash_: str, size: int) -> None:
334+
with SessionLocal() as session:
335+
count = session.query(FileHash).count()
336+
if count >= 1000:
337+
# 删除最旧的项
338+
oldest = session.query(FileHash).order_by(FileHash.id.asc()).limit(count - 999).all()
339+
for item in oldest:
340+
session.delete(item)
341+
file_hash = FileHash(server_id=server_id, group_id=group_id, hash=hash_, size=size)
342+
session.add(file_hash)
343+
session.commit()
344+
345+
346+
async def get_file_size(group: StealthIM.Group, hash_str: str) -> int:
347+
server_id = cast(int, get_server_from_db(group.user.server.url).id)
348+
with SessionLocal() as session:
349+
res = session.query(FileHash).filter_by(server_id=server_id, group_id=group.group_id, hash=hash_str).first()
350+
if res:
351+
return cast(int, res.size)
352+
353+
size_res = await group.get_file_info(hash_str)
354+
if size_res.result.code == codes.SUCCESS:
355+
add_file_size(server_id, group.group_id, hash_str, size_res.size)
356+
return size_res.size
357+
return 0
358+

src/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class IMApp(App):
4343
SCREENS = {
4444
screen.SCREEN_NAME: screen for screen in ALL_SCREENS
4545
}
46-
BINDINGS = [("ctrl+b", "app_back", "返回上一屏")]
46+
BINDINGS = [("ctrl+b", "app_back", "Back")]
4747

4848
def __init__(self):
4949
super().__init__()

src/screens/chat.py

Lines changed: 143 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,28 @@
11
import asyncio
22
import datetime
3-
from typing import Optional
3+
import math
4+
from typing import Optional, cast
45

5-
from textual import on, work
6+
import platformdirs
7+
from textual import events, on, work
68
from textual.app import ComposeResult
9+
from textual.binding import Binding
710
from textual.containers import Horizontal, Right, Vertical, VerticalScroll
811
from textual.events import Click, Event, Key
9-
from textual.reactive import reactive
10-
from textual.widgets import Button, Label, ListItem, ListView, TextArea
12+
from textual.widgets import Button, Checkbox, Footer, Label, ListItem, ListView, TextArea
1113
from textual.worker import Worker
1214

1315
import StealthIM
1416
import codes
1517
import db
1618
import log
19+
import tools
20+
from StealthIM.apis.message import MessageType
1721
from patch import Screen, Container
1822
from .common import MessageData
1923
from .group_manage import InviteMemberScreen, JoinGroupScreen, CreateGroupScreen, ModifyGroupNameScreen, \
2024
ModifyGroupPasswordScreen, SetMemberScreen
21-
from .widgets import ChatMessage, FocusableLabel, PopupMenu, PopupPlane, TopDetectingScroll
25+
from .widgets import ChatMessage, FocusableLabel, Popup, PopupMenu, PopupPlane, TopDetectingScroll
2226

2327

2428
class GroupManagerContainer(Container):
@@ -183,13 +187,120 @@ async def on_set_member(self, _event) -> None:
183187
self.parent.flush_groups()
184188

185189

190+
class MessageSelectContainer(Container):
191+
DEFAULT_CSS = """
192+
MessageSelectContainer {
193+
width: 25%;
194+
padding: 1 2;
195+
border: round grey;
196+
background: $panel;
197+
}
198+
#message-select {
199+
height: 7fr;
200+
border: solid gray;
201+
padding-top: 3;
202+
padding-right: 1;
203+
}
204+
#message-select-keys {
205+
height: 3fr;
206+
}
207+
"""
208+
BINDINGS = [
209+
Binding("ctrl+z", "recall", "Recall", show=False),
210+
Binding("ctrl+d", "download", "Download", show=False),
211+
]
212+
213+
def __init__(self, message_list: VerticalScroll):
214+
super().__init__()
215+
self.message_count: Label | None = None
216+
self.checkbox_container: Container | None = None
217+
self.message_list = message_list
218+
self.last_int = None
219+
self.selected: list[ChatMessage] = []
220+
221+
def compose(self) -> ComposeResult:
222+
with Horizontal():
223+
yield Label("Selected: ")
224+
yield (message_count := Label("0"))
225+
self.message_count = message_count
226+
yield (checkbox_container := Container(id="message-select"))
227+
with Vertical(id="message-select-keys"):
228+
yield Label("Keys:")
229+
yield Label("Ctrl+z: Recall")
230+
yield Label("Ctrl+d: Download file")
231+
self.checkbox_container = checkbox_container
232+
233+
async def on_mount(self, event: events.Mount) -> None:
234+
await self.callback_scroll(0, 0)
235+
236+
async def callback_scroll(self, _, new):
237+
rounded = round(new)
238+
if not math.isclose(new, rounded, abs_tol=1e-6) or self.last_int == rounded:
239+
return
240+
self.last_int = rounded
241+
242+
if not self.is_mounted:
243+
return
244+
245+
messages = cast(list[ChatMessage], self.visible_children(self.message_list))
246+
247+
# 清空旧的 checkbox
248+
await self.checkbox_container.remove_children()
249+
250+
# 根据可见消息重新生成 checkbox
251+
for msg in messages:
252+
# 计算相对 y 位置(消息的 virtual_region 是在 scroll 坐标系下的)
253+
y = int(msg.virtual_region.y - self.message_list.scroll_y)
254+
255+
checkbox = Checkbox("Select", value=msg in self.selected)
256+
# 用 inline-style 定位 checkbox
257+
checkbox.styles.offset = (0, y)
258+
checkbox.styles.position = "absolute"
259+
checkbox.msg = msg
260+
261+
await self.checkbox_container.mount(checkbox)
262+
263+
@on(Checkbox.Changed)
264+
def on_check(self, event: Checkbox.Changed) -> None:
265+
checkbox = event.checkbox
266+
# noinspection PyUnresolvedReferences
267+
msg = checkbox.msg
268+
if event.value:
269+
self.selected.append(msg)
270+
else:
271+
self.selected.remove(msg)
272+
self.message_count.update(str(len(self.selected)))
273+
274+
def action_recall(self):
275+
return
276+
277+
async def action_download(self):
278+
if not (
279+
files := [msg for msg in self.selected if msg.type == MessageType.File.value]
280+
):
281+
self.notify("No message to download", severity="error")
282+
return
283+
284+
download_path = platformdirs.user_downloads_path()
285+
hashes = [msg.hash for msg in self.selected]
286+
filenames = [msg.text for msg in files]
287+
await self.app.data.group.download_files(hashes, filenames, download_path)
288+
289+
@staticmethod
290+
def visible_children(scroll_container: VerticalScroll):
291+
return [
292+
child for child in scroll_container.children
293+
if scroll_container.window_region.contains_region(child.virtual_region)
294+
]
295+
296+
186297
class ChatScreen(Screen):
187298
SCREEN_NAME = "Chat"
188299
CSS_PATH = "../../styles/chat.tcss"
189300

190301
LIMIT = 100
191302

192-
_push: reactive[bool] = reactive(True)
303+
BINDINGS = [("ctrl+s", "select_msg", "Select message")]
193304

194305
def __init__(self):
195306
super().__init__()
@@ -233,6 +344,7 @@ def compose(self) -> ComposeResult:
233344
with Right(id="tools"):
234345
yield Button("Send", id="send")
235346
yield Label("", id="status")
347+
yield Footer()
236348

237349
# Events
238350

@@ -286,7 +398,6 @@ async def on_change_group(self, event: ListView.Selected) -> None:
286398
# from_id=0, old_to_new=False means pull the latest messages
287399
gen = self.group.receive_text(from_id=0, old_to_new=False, sync=False, limit=self.LIMIT)
288400
msgs = [x async for x in gen][::-1]
289-
log.logger.error(msgs)
290401
for msg in msgs:
291402
message = db.add_message(
292403
self.app.data.server_db.id, self.group.group_id, msg.type.value,
@@ -359,6 +470,18 @@ async def on_send_by_key(self, event: Key) -> None:
359470
async def on_send_by_btn(self, _event: Event) -> None:
360471
self.do_send()
361472

473+
async def action_select_msg(self):
474+
if not self.group:
475+
self.notify("You need to select a group")
476+
return
477+
scroll = self.query_one("#messages", TopDetectingScroll)
478+
479+
container = MessageSelectContainer(scroll)
480+
self.watch(scroll, "scroll_y", container.callback_scroll)
481+
popup = Popup(container, position="left")
482+
self.mount(popup)
483+
await popup.show_popup()
484+
362485
# Helper functions
363486

364487
# Add a message in the scroll
@@ -372,11 +495,19 @@ async def add_message(self, scroll: VerticalScroll, message: MessageData, bottom
372495
attr = {"after": -1}
373496
else:
374497
attr = {"before": 0}
498+
499+
if message.type == MessageType.File.value:
500+
file_res = await db.get_file_size(self.group, message.hash)
501+
message.size = tools.int2size(int(file_res))
502+
375503
await scroll.mount(
376504
ChatMessage(message, self.app.data.user_db),
377505
**attr
378506
)
379507

508+
async def recall_message(self, scroll: VerticalScroll):
509+
...
510+
380511
@staticmethod
381512
async def get_group_members(group):
382513
res = await group.get_members()
@@ -486,7 +617,11 @@ async def get_messages(self, messages: VerticalScroll) -> None:
486617
)
487618
db.update_group_msgid(group_id, server_id, message.msgid)
488619

489-
await self.add_message(messages, self.build_msg_from_db(msg))
620+
if message.type != MessageType.Recall:
621+
await self.add_message(messages, self.build_msg_from_db(msg))
622+
else:
623+
db.recall_message(server_id, group_id, message.msgid)
624+
await self.recall_message(messages, message.msgid)
490625
except RuntimeError:
491626
pass
492627
except asyncio.CancelledError:

src/screens/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class MessageData:
3030
username: str
3131
hash: str
3232
nickname: Optional[str] = None
33+
size: Optional[str] = None
3334

3435

3536
@dataclasses.dataclass

0 commit comments

Comments
 (0)