Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
## [v0.6.0]

## Added

1. Add a common engine for the app.
2. Add storage modules to keep the services layer lighter.

## Changed

1. Use a session factory for getting db sessions.
2. Reduce the maximum width of message bubbles.


---

## [v0.5.0]

### Added
Expand Down
4 changes: 2 additions & 2 deletions src/memorytext/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
QApplication,
)
from memorytext.windows.main_window import MainWindow
from memorytext.db import init_db
from memorytext.storage.db import startup

APP_NAME = "memory.text"


def main():
# Pass in sys.argv to allow command line arguments for your app.
init_db()
startup()
app = QApplication(sys.argv)
window = MainWindow()
window.show() # IMPORTANT!!!!! Windows are hidden by default.
Expand Down
6 changes: 6 additions & 0 deletions src/memorytext/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from platformdirs import user_data_path

APP_NAME = "memorytext"
USER_DIR = user_data_path(appname=APP_NAME, appauthor=False, ensure_exists=True)
DB_PATH = USER_DIR / f"{APP_NAME}-db.sqlite"
DB_URL = f"sqlite:///{DB_PATH}"
22 changes: 0 additions & 22 deletions src/memorytext/db.py

This file was deleted.

8 changes: 4 additions & 4 deletions src/memorytext/delegates/message_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from PySide6.QtCore import Qt, QRect, QSize
from PySide6.QtGui import QFont, QFontMetrics, QColor, QPixmap
from memorytext.models.message_list import MessageList
from memorytext.db import BASE_DIR
from memorytext.config import USER_DIR

USER_COLOR = "#ABE7FF"
OTHER_COLOR = "#e0e0e0"
Expand Down Expand Up @@ -35,7 +35,7 @@ def paint(self, painter, option, index):
path = None

if attachment and attachment.startswith("STK"):
path = str(BASE_DIR / "media" / attachment)
path = str(USER_DIR / "media" / attachment)
extra_text = attachment + " (file attached)"
text = text.replace(extra_text, "")

Expand All @@ -53,7 +53,7 @@ def paint(self, painter, option, index):
ts_font = QFont(base_font)
ts_font.setPointSize(base_font.pointSize() - 3)

max_width = option.rect.width() * 0.8
max_width = option.rect.width() * 0.5
inner_width = max_width - 2 * self.padding

metrics = QFontMetrics(base_font)
Expand Down Expand Up @@ -166,7 +166,7 @@ def sizeHint(self, option, index):
is_same_sender = index.data(MessageList.IsSameSenderRole)
attachment = index.data(MessageList.AttachmentRole)

max_width = option.rect.width() * 0.8
max_width = option.rect.width() * 0.5
inner_width = max_width - 2 * self.padding
base_font = QFont(option.font)
ts_font = QFont(base_font)
Expand Down
6 changes: 2 additions & 4 deletions src/memorytext/io/import_wa.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from memorytext.db import get_db
from memorytext.storage.db import get_db_session
from memorytext.models.core import Conversation

if TYPE_CHECKING:
Expand All @@ -13,8 +12,7 @@ def import_whatsapp(
path: str | Path, title: str, tz: str = "Etc/UTC", username: str | None = None
):
conversation = Conversation.from_whatsapp(path, title, tz, username)
engine = get_db()
with Session(engine) as session:
with get_db_session() as session:
try:
with session.begin():
session.add(conversation)
Expand Down
10 changes: 7 additions & 3 deletions src/memorytext/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ class Base(DeclarativeBase):
message_tags = Table(
"message_tags",
Base.metadata,
Column("tag_id", ForeignKey("tag.id"), primary_key=True),
Column("message_id", ForeignKey("message.id"), primary_key=True),
Column("tag_id", ForeignKey("tag.id", ondelete="CASCADE"), primary_key=True),
Column(
"message_id", ForeignKey("message.id", ondelete="CASCADE"), primary_key=True
),
)


Expand Down Expand Up @@ -95,7 +97,9 @@ class Message(Base):
tags: Mapped[List["Tag"]] = relationship(
secondary=message_tags, back_populates="messages"
)
conversation_id: Mapped[int] = mapped_column(ForeignKey("conversation.id"))
conversation_id: Mapped[int] = mapped_column(
ForeignKey("conversation.id", ondelete="CASCADE")
)
conversation: Mapped["Conversation"] = relationship(back_populates="messages")

def __str__(self) -> str:
Expand Down
53 changes: 29 additions & 24 deletions src/memorytext/services/chat_service.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,24 @@
from __future__ import annotations
from sqlalchemy.orm import Session, selectinload
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import select, update

from memorytext.models.core import Conversation
from memorytext.db import get_db
from memorytext.storage.db import get_db_session
import memorytext.storage.chats_repo as chrepo


def get_chats(count: int = 10):
engine = get_db()
with Session(engine) as session:
stmt = select(Conversation).options(selectinload(Conversation.messages))
conversations = session.scalars(stmt).fetchmany(count)
chat_list = [(c.id, c.title, c.participants) for c in conversations]
return chat_list
with get_db_session() as session:
return chrepo.get_chats(session, limit=count)


def get_participants(title: str):
engine = get_db()
with Session(engine) as session:
stmt = (
select(Conversation)
.where(Conversation.title == title)
.options(selectinload(Conversation.messages))
)
conversation = session.scalar(stmt)
if conversation:
participants = conversation.participants
return participants
return set()
with get_db_session() as session:
return chrepo.get_participants_by_title(session, title)


def set_username(title: str, username: str):
engine = get_db()
with Session(engine) as session:
with get_db_session() as session:
stmt = (
update(Conversation)
.where(Conversation.title == title)
Expand All @@ -47,8 +33,27 @@ def set_username(title: str, username: str):


def get_username(conversation_id: int | None):
engine = get_db()
with Session(engine) as session:
with get_db_session() as session:
stmt = select(Conversation.username).where(Conversation.id == conversation_id)
username = session.scalar(stmt)
return username


def delete_chat(conversation_id: int):
with get_db_session() as session:
try:
chrepo.delete_chat(session, conversation_id)
session.commit()
except SQLAlchemyError:
session.rollback()
raise


def delete_chat_by_title(title: str):
with get_db_session() as session:
try:
chrepo.delete_chat_by_title(session, title)
session.commit()
except SQLAlchemyError:
session.rollback()
raise
26 changes: 4 additions & 22 deletions src/memorytext/services/message_service.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,10 @@
from __future__ import annotations
from sqlalchemy.orm import Session
from sqlalchemy import select, func

from memorytext.models.core import Message
from memorytext.db import get_db
import memorytext.storage.message_repo as mrepo
from memorytext.storage.db import get_db_session


def get_messages(
conversation_id: int | None, limit: int = 10, offset: int = 0
) -> tuple[list, int]:
engine = get_db()
with Session(engine) as session:
stmt = select(func.count(Message.id)).where(
Message.conversation_id == conversation_id
)
total = session.scalar(stmt) or 0

stmt_2 = (
select(Message)
.where(Message.conversation_id == conversation_id)
.order_by(Message.timestamp.asc(), Message.id.asc())
.limit(limit)
.offset(offset)
)
message_list = list(session.scalars(stmt_2).fetchall())
session.expunge_all()
return message_list, total
with get_db_session() as session:
return mrepo.get_messages(session, conversation_id, limit, offset)
47 changes: 47 additions & 0 deletions src/memorytext/storage/chats_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import delete, select
from sqlalchemy.orm import selectinload

from memorytext.models.core import Conversation

if TYPE_CHECKING:
from sqlalchemy.orm import Session


def get_chats(session: Session, limit: int = 10):
stmt = select(Conversation).limit(limit)
conversations = session.scalars(stmt).fetchall()
chat_list = [(c.id, c.title, c.participants) for c in conversations]
return chat_list


def get_participants_by_id(session: Session, conversation_id: int | None):
stmt = (
select(Conversation)
.where(Conversation.id == conversation_id)
.options(selectinload(Conversation.messages))
)
conversation = session.scalar(stmt)
if conversation:
participants = conversation.participants
return participants
return set()


def get_participants_by_title(session: Session, title: str):
conversation_id = session.scalar(
select(Conversation.id).where(Conversation.title == title)
)
return get_participants_by_id(session, conversation_id)


def delete_chat(session: Session, conversation_id: int | None):
session.execute(delete(Conversation).where(Conversation.id == conversation_id))


def delete_chat_by_title(session: Session, title: str):
conversation_id = session.scalar(
select(Conversation.id).where(Conversation.title == title)
)
delete_chat(session, conversation_id)
49 changes: 49 additions & 0 deletions src/memorytext/storage/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import logging
from contextlib import contextmanager
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlite3 import Connection as SQLite3Connection

from memorytext.config import DB_URL
from memorytext.models.core import Base

if TYPE_CHECKING:
from sqlalchemy import Engine

engine: Engine = create_engine(DB_URL)

logging.basicConfig()
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)


@event.listens_for(engine, "connect")
def _set_sqlite_pragma(dbapi_connection, connection_record):
if isinstance(dbapi_connection, SQLite3Connection):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON;")
cursor.close()


SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)


def init_db():
"Initialize database from models schema."
Base.metadata.create_all(engine)


def startup():
"Application startup"
init_db()


@contextmanager
def get_db_session():
"Get a database session context manager."
Session = SessionLocal()
try:
yield Session
finally:
Session.close()
27 changes: 27 additions & 0 deletions src/memorytext/storage/message_repo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from sqlalchemy import select, func

from memorytext.models.core import Message

if TYPE_CHECKING:
from sqlalchemy.orm import Session


def get_messages(
session: Session, conversation_id: int | None, limit: int = 10, offset: int = 0
) -> tuple[list, int]:
stmt = select(func.count(Message.id)).where(
Message.conversation_id == conversation_id
)
total = session.scalar(stmt) or 0
stmt_2 = (
select(Message)
.where(Message.conversation_id == conversation_id)
.order_by(Message.timestamp.asc(), Message.id.asc())
.limit(limit)
.offset(offset)
)
message_list = list(session.scalars(stmt_2).fetchall())
session.expunge_all()
return message_list, total
Loading