From d359b9e2da9308805628106dde92588d371218c3 Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 9 Jun 2026 12:15:39 +0800 Subject: [PATCH 01/12] refactor: Refactor entity models and access rules - Moved `NoActiveRevisionsError` import to `include.exceptions.misc` for better organization. - Updated import paths for `Document`, `DocumentRevision`, and `Folder` to reflect new structure in `obj.py`. - Created `base.py` for shared base functionality among entity models. - Introduced `metadata.py` for document metadata handling. - Added `obj.py` to define core entity models including `Document`, `DocumentRevision`, `Folder`, and their access rules. - Implemented `batch_count_other_revisions` in `count.py` for centralized reference counting. - Refactored `purge.py` to utilize new counting method. - Adjusted `document.py` and `search.py` to align with new import paths. --- .../database/models/entity/__init__.py | 3 + src/include/database/models/entity/base.py | 239 ++++++++++++++ .../database/models/entity/metadata.py | 15 + .../models/{entity.py => entity/obj.py} | 298 +----------------- src/include/handlers/document.py | 2 +- src/include/handlers/search.py | 2 +- src/include/util/bulk/count.py | 69 ++++ src/include/util/bulk/purge.py | 4 +- src/main.py | 2 +- 9 files changed, 346 insertions(+), 288 deletions(-) create mode 100644 src/include/database/models/entity/__init__.py create mode 100644 src/include/database/models/entity/base.py create mode 100644 src/include/database/models/entity/metadata.py rename src/include/database/models/{entity.py => entity/obj.py} (55%) create mode 100644 src/include/util/bulk/count.py diff --git a/src/include/database/models/entity/__init__.py b/src/include/database/models/entity/__init__.py new file mode 100644 index 0000000..4acfaaa --- /dev/null +++ b/src/include/database/models/entity/__init__.py @@ -0,0 +1,3 @@ +from .base import * +from .metadata import * +from .obj import * diff --git a/src/include/database/models/entity/base.py b/src/include/database/models/entity/base.py new file mode 100644 index 0000000..6840bc0 --- /dev/null +++ b/src/include/database/models/entity/base.py @@ -0,0 +1,239 @@ +__all__ = ["BaseObject"] + +import time +from typing import List, Literal, Optional, cast + +from sqlalchemy import VARCHAR, Integer +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm.session import object_session + +from include.classes.access_rule import AccessRuleBase +from include.classes.enum.status import EntityStatus +from include.conf_loader import global_config +from include.constants import AVAILABLE_ACCESS_TYPES +from include.database.handler import Base +from include.database.models.classic import User +from include.util.fetch.fetch import batch_prefetch_granted_ids, prefetch_user_blocks + + +class BaseObject(Base): + __abstract__ = True + + id: Mapped[str] + access_rules: Mapped[List] + + # Whether to inherit access rules from parent folders. + # Useful when enabling recursion check. + inherit: Mapped[bool] + + status: Mapped[EntityStatus] = mapped_column( + Integer, nullable=False, default=EntityStatus.OK + ) + status_operation_id: Mapped[Optional[str]] = mapped_column( + VARCHAR(255), nullable=True, index=True + ) + + def check_access_requirements( + self, user: User, access_type: str = "read", _no_recursive_check=False + ) -> bool: + """ + Checks if a given user meets the access requirements for a specific access type based on defined access rules. + Args: + user (User): The user object whose permissions and groups are to be checked. + access_type (int, optional): The type of access to check for. Defaults to `"read"`. + _no_recursive_check (bool, optional): Useful when performing batch queries. Defaults to False. + Returns: + bool: True if the user meets all access requirements for the specified access type, False otherwise. + Raises: + ValueError: If the "match" value in any rule is not "all" or "any". + Access rules are evaluated as follows: + - Each rule may specify required permissions ("rights") and/or groups ("groups"). + - Each requirement can specify a "match" mode: "all" (all required items must be present) or "any" (at least one must be present). + - Rules are grouped and evaluated according to their match modes and requirements. + - If no access rules are defined, access is granted by default. + """ + + _TARGET_TYPE_MAPPING = {"folders": "directory", "documents": "document"} + + def match_rights(sub_rights_group): + if not sub_rights_group: + return True + + sub_match_mode = sub_rights_group.get("match", "all") + sub_rights_require = sub_rights_group.get("require", []) + + if not sub_rights_require: + return True + + if sub_match_mode == "all": + return set(sub_rights_require).issubset(user.all_permissions) + + elif sub_match_mode == "any": + for right in sub_rights_require: + if right in user.all_permissions: + return True + return False + + else: + raise ValueError('the value of "match" must be "all" or "any"') + + def match_groups(sub_groups_group): + if not sub_groups_group: + return True + + sub_match_mode = sub_groups_group.get("match", "all") + sub_groups_require = sub_groups_group.get("require", []) + + if not sub_groups_require: + return True + + if sub_match_mode == "all": + return set(sub_groups_require).issubset(user.all_groups) + + elif sub_match_mode == "any": + for group in sub_groups_require: + if group in user.all_groups: + return True + return False + else: + raise ValueError('the value of "match" must be "all" or "any"') + + def match_sub_group(sub_group): + sub_match_mode = sub_group.get("match", "all") + sub_rights_group = sub_group.get("rights", {}) + sub_groups_group = sub_group.get("groups", {}) + + if not (sub_rights_group.get("require", [])) or ( + not sub_groups_group.get("require", []) + ): + sub_match_mode = "all" + + if sub_match_mode == "any": + return match_rights(sub_rights_group) or match_groups(sub_groups_group) + if sub_match_mode == "all": + return match_rights(sub_rights_group) and match_groups(sub_groups_group) + else: + raise ValueError('the value of "match" must be "all" or "any"') + + def match_primary_sub_group(per_match_group): + match_mode = per_match_group.get("match", "all") + for sub_group in per_match_group["match_groups"]: + if not sub_group: + continue + + state = match_sub_group(sub_group) + + if match_mode == "any": + if state: + return True + elif match_mode == "all": + if not state: + return False + + if match_mode == "any": + return False + elif match_mode == "all": + return True + + # Checks whether the user or the user group to which he belongs + # has special access rights to this object. + + # Get `session` from `User` object + _session = object_session(user) + if not _session: + raise RuntimeError("No active session found for user") + + now = time.time() + + # check user blocks first + is_globally_blocked, blocked_ids = prefetch_user_blocks( + _session, user, access_type, now + ) + if is_globally_blocked or self.id in blocked_ids: + return False + + # then check special access entries + self_type = cast( + Literal["document", "directory"], _TARGET_TYPE_MAPPING[self.__tablename__] + ) + explicitly_granted_ids = batch_prefetch_granted_ids( + _session, user, [self.id], self_type, access_type, now + ) + + if self.id in explicitly_granted_ids: + return True + + if ( + global_config["access"]["enable_access_recursive_check"] + and self.inherit + and not _no_recursive_check + ): + # FIXME: Use lazy import when Python 3.15 is out + from include.database.models.entity.obj import Document, Folder + + # check all parent folders' access rules + parent = None + if isinstance(self, Document): + parent = self.folder + elif isinstance(self, Folder): + parent = self.parent + + visited_folder_ids = set() + while parent is not None: + if parent.id in visited_folder_ids: + # Cycle detected; break to prevent an infinite loop + raise RuntimeError("Cycle detected in folder hierarchy") + visited_folder_ids.add(parent.id) + + if not parent.check_access_requirements(user, access_type=access_type): + return False + + if not parent.inherit: + break # if the parent folder does not inherit, stop checking further up + + parent = parent.parent + + if not self.access_rules: + return True + + for each_rule in self.access_rules: + if not each_rule: + continue + + each_rule: AccessRuleBase + + # access_type 一览: + # read - 读 + # write - 写(删除=清空数据,重命名=写文件元数据,因此都算写) + # move - 移动 + # manage - 管理 + + if access_type not in AVAILABLE_ACCESS_TYPES: + raise ValueError( + f"Invalid access type for {self.__tablename__}: {access_type}" + ) + + match access_type: + case "read": # 如果要检查的是读权限 + if each_rule.access_type != "read": + continue + case "write": # 如果要检查写权限 + if each_rule.access_type not in ["read", "write"]: + continue + case "move": + # 取消了对读权限的要求 + if each_rule.access_type != "move": + continue + case "manage": # 如果要检查管理权限 + if each_rule.access_type not in ["read", "manage"]: + continue + case _: + raise NotImplementedError("Unsupported access type") + + if not each_rule.rule_data: + continue + + if not match_primary_sub_group(each_rule.rule_data): + return False + + return True diff --git a/src/include/database/models/entity/metadata.py b/src/include/database/models/entity/metadata.py new file mode 100644 index 0000000..5d135ae --- /dev/null +++ b/src/include/database/models/entity/metadata.py @@ -0,0 +1,15 @@ +__all__ = ["DocumentMetadata"] + +from sqlalchemy import VARCHAR, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column + +from include.database.handler import Base + + +class DocumentMetadata(Base): + __tablename__ = "document_metadata" + + document_id: Mapped[str] = mapped_column( + VARCHAR(64), ForeignKey("document.id", ondelete="CASCADE"), primary_key=True + ) + # TODO: Add more metadata fields as needed, e.g. author, creation date, etc. diff --git a/src/include/database/models/entity.py b/src/include/database/models/entity/obj.py similarity index 55% rename from src/include/database/models/entity.py rename to src/include/database/models/entity/obj.py index 1191c1b..2a1a776 100644 --- a/src/include/database/models/entity.py +++ b/src/include/database/models/entity/obj.py @@ -1,301 +1,33 @@ +__all__ = [ + "Document", + "DocumentRevision", + "DocumentAccessRule", + "Folder", + "FolderAccessRule", +] + import secrets import time from itertools import batched -from typing import Iterable, List, Literal, Optional, Union, cast +from typing import List, Optional -from sqlalchemy import JSON, VARCHAR, Boolean, Float, ForeignKey, Integer, func -from sqlalchemy.orm import Mapped, Session, mapped_column, relationship +from sqlalchemy import JSON, VARCHAR, Boolean, Float, ForeignKey, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm.session import object_session from include.classes.access_rule import AccessRuleBase from include.classes.enum.status import DocumentRevisionStatus, EntityStatus -from include.conf_loader import global_config -from include.constants import AVAILABLE_ACCESS_TYPES, MAX_PARAM_SIZE, QUERY_CHUNK_SIZE +from include.constants import QUERY_CHUNK_SIZE from include.database.handler import Base -from include.database.models.classic import User +from include.database.models.entity.base import BaseObject from include.database.models.file import ( File, FileTask, _queue_deferred_file_deletion, ) from include.exceptions.misc import NoActiveRevisionsError +from include.util.bulk.count import batch_count_other_revisions from include.util.count import count_file_references -from include.util.fetch.fetch import batch_prefetch_granted_ids, prefetch_user_blocks - - -def _batch_count_other_revisions( - session: Session, - file_ids: Iterable[str], - exclude_doc_ids: Union[str, Iterable[str]], -) -> dict[str, int]: - """ - 计算引用了指定 file_ids 的修订版本数量,但排除属于 exclude_doc_ids 集合的文档。 - - 参数: - file_ids: 待检查的文件 ID 列表。 - exclude_doc_ids: 单个文档 ID 或文档 ID 集合。这些文档对文件的引用将不被计入。 - """ - # Return total references to each file EXCLUDING references that come - # from the provided exclude_doc_ids set. This centralizes counting via - # `count_file_references` so new FK references are automatically handled. - - # Materialize iterables so they can be safely iterated multiple times. - file_ids_list = list(file_ids) - if not file_ids_list: - return {} - - if isinstance(exclude_doc_ids, str): - exclude_doc_ids_list = [exclude_doc_ids] - else: - exclude_doc_ids_list = list(exclude_doc_ids) - - # Get total references across all tables (uses reflected FKs). - total_refs = count_file_references(session, file_ids_list) - - # Count references coming from the excluded documents (DocumentRevision only). - # The query combines `file_id IN (...)` and `document_id IN (...)`, so both - # dimensions must be chunked to keep total bind variables under MAX_PARAM_SIZE. - exclude_chunk_size = max(1, MAX_PARAM_SIZE - QUERY_CHUNK_SIZE) - excluded_counts: dict[str, int] = {} - for f_chunk in batched(file_ids_list, QUERY_CHUNK_SIZE): - for e_chunk in batched(exclude_doc_ids_list, exclude_chunk_size): - rows = ( - session.query(DocumentRevision.file_id, func.count(DocumentRevision.id)) - .filter(DocumentRevision.file_id.in_(list(f_chunk))) - .filter(DocumentRevision.document_id.in_(list(e_chunk))) - .group_by(DocumentRevision.file_id) - .all() - ) - for file_id, count in rows: - excluded_counts[file_id] = excluded_counts.get(file_id, 0) + count - - counts: dict[str, int] = {} - for fid in file_ids_list: - total = total_refs.get(fid, 0) - excluded = excluded_counts.get(fid, 0) - counts[fid] = max(0, total - excluded) - - return counts - - -class BaseObject(Base): - __abstract__ = True - - id: Mapped[str] - access_rules: Mapped[List] - - # Whether to inherit access rules from parent folders. - # Useful when enabling recursion check. - inherit: Mapped[bool] - - status: Mapped[EntityStatus] = mapped_column( - Integer, nullable=False, default=EntityStatus.OK - ) - status_operation_id: Mapped[Optional[str]] = mapped_column( - VARCHAR(255), nullable=True, index=True - ) - - def check_access_requirements( - self, user: User, access_type: str = "read", _no_recursive_check=False - ) -> bool: - """ - Checks if a given user meets the access requirements for a specific access type based on defined access rules. - Args: - user (User): The user object whose permissions and groups are to be checked. - access_type (int, optional): The type of access to check for. Defaults to `"read"`. - _no_recursive_check (bool, optional): Useful when performing batch queries. Defaults to False. - Returns: - bool: True if the user meets all access requirements for the specified access type, False otherwise. - Raises: - ValueError: If the "match" value in any rule is not "all" or "any". - Access rules are evaluated as follows: - - Each rule may specify required permissions ("rights") and/or groups ("groups"). - - Each requirement can specify a "match" mode: "all" (all required items must be present) or "any" (at least one must be present). - - Rules are grouped and evaluated according to their match modes and requirements. - - If no access rules are defined, access is granted by default. - """ - - _TARGET_TYPE_MAPPING = {"folders": "directory", "documents": "document"} - - def match_rights(sub_rights_group): - if not sub_rights_group: - return True - - sub_match_mode = sub_rights_group.get("match", "all") - sub_rights_require = sub_rights_group.get("require", []) - - if not sub_rights_require: - return True - - if sub_match_mode == "all": - return set(sub_rights_require).issubset(user.all_permissions) - - elif sub_match_mode == "any": - for right in sub_rights_require: - if right in user.all_permissions: - return True - return False - - else: - raise ValueError('the value of "match" must be "all" or "any"') - - def match_groups(sub_groups_group): - if not sub_groups_group: - return True - - sub_match_mode = sub_groups_group.get("match", "all") - sub_groups_require = sub_groups_group.get("require", []) - - if not sub_groups_require: - return True - - if sub_match_mode == "all": - return set(sub_groups_require).issubset(user.all_groups) - - elif sub_match_mode == "any": - for group in sub_groups_require: - if group in user.all_groups: - return True - return False - else: - raise ValueError('the value of "match" must be "all" or "any"') - - def match_sub_group(sub_group): - sub_match_mode = sub_group.get("match", "all") - sub_rights_group = sub_group.get("rights", {}) - sub_groups_group = sub_group.get("groups", {}) - - if not (sub_rights_group.get("require", [])) or ( - not sub_groups_group.get("require", []) - ): - sub_match_mode = "all" - - if sub_match_mode == "any": - return match_rights(sub_rights_group) or match_groups(sub_groups_group) - if sub_match_mode == "all": - return match_rights(sub_rights_group) and match_groups(sub_groups_group) - else: - raise ValueError('the value of "match" must be "all" or "any"') - - def match_primary_sub_group(per_match_group): - match_mode = per_match_group.get("match", "all") - for sub_group in per_match_group["match_groups"]: - if not sub_group: - continue - - state = match_sub_group(sub_group) - - if match_mode == "any": - if state: - return True - elif match_mode == "all": - if not state: - return False - - if match_mode == "any": - return False - elif match_mode == "all": - return True - - # Checks whether the user or the user group to which he belongs - # has special access rights to this object. - - # Get `session` from `User` object - _session = object_session(user) - if not _session: - raise RuntimeError("No active session found for user") - - now = time.time() - - # check user blocks first - is_globally_blocked, blocked_ids = prefetch_user_blocks( - _session, user, access_type, now - ) - if is_globally_blocked or self.id in blocked_ids: - return False - - # then check special access entries - self_type = cast( - Literal["document", "directory"], _TARGET_TYPE_MAPPING[self.__tablename__] - ) - explicitly_granted_ids = batch_prefetch_granted_ids( - _session, user, [self.id], self_type, access_type, now - ) - - if self.id in explicitly_granted_ids: - return True - - if ( - global_config["access"]["enable_access_recursive_check"] - and self.inherit - and not _no_recursive_check - ): - # check all parent folders' access rules - parent = None - if isinstance(self, Document): - parent = self.folder - elif isinstance(self, Folder): - parent = self.parent - - visited_folder_ids = set() - while parent is not None: - if parent.id in visited_folder_ids: - # Cycle detected; break to prevent an infinite loop - raise RuntimeError("Cycle detected in folder hierarchy") - visited_folder_ids.add(parent.id) - - if not parent.check_access_requirements(user, access_type=access_type): - return False - - if not parent.inherit: - break # if the parent folder does not inherit, stop checking further up - - parent = parent.parent - - if not self.access_rules: - return True - - for each_rule in self.access_rules: - if not each_rule: - continue - - each_rule: AccessRuleBase - - # access_type 一览: - # read - 读 - # write - 写(删除=清空数据,重命名=写文件元数据,因此都算写) - # move - 移动 - # manage - 管理 - - if access_type not in AVAILABLE_ACCESS_TYPES: - raise ValueError( - f"Invalid access type for {self.__tablename__}: {access_type}" - ) - - match access_type: - case "read": # 如果要检查的是读权限 - if each_rule.access_type != "read": - continue - case "write": # 如果要检查写权限 - if each_rule.access_type not in ["read", "write"]: - continue - case "move": - # 取消了对读权限的要求 - if each_rule.access_type != "move": - continue - case "manage": # 如果要检查管理权限 - if each_rule.access_type not in ["read", "manage"]: - continue - case _: - raise NotImplementedError("Unsupported access type") - - if not each_rule.rule_data: - continue - - if not match_primary_sub_group(each_rule.rule_data): - return False - - return True class Folder(BaseObject): # 文档文件夹 @@ -472,7 +204,7 @@ def delete_all_revisions(self, do_commit: bool = True): all_file_ids = {row[1] for row in revision_tuples if row[1]} # Task 3: Chunked batch reference count queries to avoid variable limit. - other_counts = _batch_count_other_revisions(session, all_file_ids, self.id) + other_counts = batch_count_other_revisions(session, all_file_ids, self.id) # Determine which files are exclusively referenced by this document's revisions. deletable_file_ids = { diff --git a/src/include/handlers/document.py b/src/include/handlers/document.py index 858bccd..0241a5b 100644 --- a/src/include/handlers/document.py +++ b/src/include/handlers/document.py @@ -28,9 +28,9 @@ Document, DocumentRevision, Folder, - NoActiveRevisionsError, ) from include.database.models.file import File, FileTask +from include.exceptions.misc import NoActiveRevisionsError from include.handlers.base import RequestHandler from include.system.messages import Messages as smsg from include.util.check import ( diff --git a/src/include/handlers/search.py b/src/include/handlers/search.py index 5024427..b67c7c5 100644 --- a/src/include/handlers/search.py +++ b/src/include/handlers/search.py @@ -15,7 +15,7 @@ from include.conf_loader import global_config from include.database.handler import Session from include.database.models.classic import User -from include.database.models.entity import NoActiveRevisionsError +from include.exceptions.misc import NoActiveRevisionsError from include.handlers.base import RequestHandler from include.util.fetch.fetch import batch_prefetch_granted_ids, prefetch_user_blocks from include.util.recursive.ancestors import ( diff --git a/src/include/util/bulk/count.py b/src/include/util/bulk/count.py new file mode 100644 index 0000000..4f860d8 --- /dev/null +++ b/src/include/util/bulk/count.py @@ -0,0 +1,69 @@ +from itertools import batched +from typing import Iterable, Union + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from include.constants import MAX_PARAM_SIZE, QUERY_CHUNK_SIZE +from include.util.count import count_file_references + + +def batch_count_other_revisions( + session: Session, + file_ids: Iterable[str], + exclude_doc_ids: Union[str, Iterable[str]], +) -> dict[str, int]: + """ + Return total references to each file EXCLUDING references from exclude_doc_ids. + + This centralizes counting via `count_file_references` so new FK references + are automatically handled. + + Args: + session: Database session. + file_ids: File IDs to check. + exclude_doc_ids: Document ID or IDs to exclude from count. + + Returns: + Dict mapping file_id to reference count excluding specified documents. + """ + # FIXME: Use lazy import when Python 3.15 is out + from include.database.models.entity.obj import DocumentRevision + + # Materialize iterables so they can be safely iterated multiple times. + file_ids_list = list(file_ids) + if not file_ids_list: + return {} + + if isinstance(exclude_doc_ids, str): + exclude_doc_ids_list = [exclude_doc_ids] + else: + exclude_doc_ids_list = list(exclude_doc_ids) + + # Get total references across all tables (uses reflected FKs). + total_refs = count_file_references(session, file_ids_list) + + # Count references coming from the excluded documents (DocumentRevision only). + # The query combines `file_id IN (...)` and `document_id IN (...)`, so both + # dimensions must be chunked to keep total bind variables under MAX_PARAM_SIZE. + exclude_chunk_size = max(1, MAX_PARAM_SIZE - QUERY_CHUNK_SIZE) + excluded_counts: dict[str, int] = {} + for f_chunk in batched(file_ids_list, QUERY_CHUNK_SIZE): + for e_chunk in batched(exclude_doc_ids_list, exclude_chunk_size): + rows = ( + session.query(DocumentRevision.file_id, func.count(DocumentRevision.id)) + .filter(DocumentRevision.file_id.in_(list(f_chunk))) + .filter(DocumentRevision.document_id.in_(list(e_chunk))) + .group_by(DocumentRevision.file_id) + .all() + ) + for file_id, count in rows: + excluded_counts[file_id] = excluded_counts.get(file_id, 0) + count + + counts: dict[str, int] = {} + for fid in file_ids_list: + total = total_refs.get(fid, 0) + excluded = excluded_counts.get(fid, 0) + counts[fid] = max(0, total - excluded) + + return counts diff --git a/src/include/util/bulk/purge.py b/src/include/util/bulk/purge.py index 2f0f9c5..537d791 100644 --- a/src/include/util/bulk/purge.py +++ b/src/include/util/bulk/purge.py @@ -7,9 +7,9 @@ from include.database.models.entity import ( Document, DocumentRevision, - _batch_count_other_revisions, ) from include.database.models.file import File, FileTask, _queue_deferred_file_deletion +from include.util.bulk.count import batch_count_other_revisions def purge_documents_bulk(session: Session, document_ids: List[str]): @@ -42,7 +42,7 @@ def purge_documents_bulk(session: Session, document_ids: List[str]): # 2. 批量引用计数检查 # 使用集中计数,排除来自这批文档的引用后若为 0 则可删除 - other_counts = _batch_count_other_revisions(session, list(file_ids), document_ids) + other_counts = batch_count_other_revisions(session, list(file_ids), document_ids) # 找出仅被这批文档引用、可以物理删除的文件 ID deletable_file_ids = [fid for fid in file_ids if other_counts.get(fid, 0) == 0] diff --git a/src/main.py b/src/main.py index 638a938..5af7479 100644 --- a/src/main.py +++ b/src/main.py @@ -26,7 +26,7 @@ ROOT_DIRECTORY_ID, ) from include.database.handler import Base, Session, engine -from include.database.models.entity import Document, DocumentRevision, Folder +from include.database.models.entity.obj import Document, DocumentRevision, Folder from include.database.models.file import File from include.handlers.debugging.throw import RequestThrowExceptionHandler from include.providers.manager import ProviderManager From 4eaf03d237a018c1e595bdcc6434922b44fa14f5 Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 9 Jun 2026 12:20:31 +0800 Subject: [PATCH 02/12] feat: add relationship for Document in DocumentMetadata --- src/include/database/models/entity/metadata.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/include/database/models/entity/metadata.py b/src/include/database/models/entity/metadata.py index 5d135ae..8e6e8de 100644 --- a/src/include/database/models/entity/metadata.py +++ b/src/include/database/models/entity/metadata.py @@ -1,9 +1,10 @@ __all__ = ["DocumentMetadata"] from sqlalchemy import VARCHAR, ForeignKey -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, relationship from include.database.handler import Base +from include.database.models.entity.obj import Document class DocumentMetadata(Base): @@ -12,4 +13,9 @@ class DocumentMetadata(Base): document_id: Mapped[str] = mapped_column( VARCHAR(64), ForeignKey("document.id", ondelete="CASCADE"), primary_key=True ) + document: Mapped["Document"] = relationship( + "Document", + back_populates="metadata", + foreign_keys=[document_id], + ) # TODO: Add more metadata fields as needed, e.g. author, creation date, etc. From dfd8e693dd187fd8263d84d5696f95af06e7397c Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 9 Jun 2026 12:22:00 +0800 Subject: [PATCH 03/12] fix: correct foreign key reference in DocumentMetadata --- src/include/database/models/entity/metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/database/models/entity/metadata.py b/src/include/database/models/entity/metadata.py index 8e6e8de..fc9b8ff 100644 --- a/src/include/database/models/entity/metadata.py +++ b/src/include/database/models/entity/metadata.py @@ -11,7 +11,7 @@ class DocumentMetadata(Base): __tablename__ = "document_metadata" document_id: Mapped[str] = mapped_column( - VARCHAR(64), ForeignKey("document.id", ondelete="CASCADE"), primary_key=True + VARCHAR(64), ForeignKey("documents.id", ondelete="CASCADE"), primary_key=True ) document: Mapped["Document"] = relationship( "Document", From 96443ca1027bafb8fe3da947ff6ab4285b136689 Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 16 Jun 2026 09:39:36 +0800 Subject: [PATCH 04/12] feat: add instructions for creating Alembic upgrade/downgrade scripts --- AGENTS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/AGENTS.md b/AGENTS.md index d4fb700..a667efa 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -9,6 +9,7 @@ - Do not insert redundant comments. - Avoid anti-patterns when coding. - Observe the DRY principle. +- If you need to create an Alembic upgrade/downgrade script, create the framework by running the command `uv run alembic`, and then modify the parts you want to change in the generated file, instead of creating a file from scratch. ## Testing instructions - Run tests only when necessary, as running tests will delete the original database (if SQLite is used as the database engine). From 809a42f8c8180f98b9f74496fec1d5d64f8f5166 Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 16 Jun 2026 09:59:57 +0800 Subject: [PATCH 05/12] feat: add document metadata and tags management functionality --- AGENTS.md | 2 +- src/alembic/env.py | 2 + .../a50674184a2c_document_metadata.py | 270 ++++++++++++++++++ src/include/classes/enum/permissions.py | 2 + src/include/database/handler.py | 6 + .../database/models/entity/metadata.py | 59 +++- src/include/database/models/entity/obj.py | 11 +- src/include/handlers/document.py | 118 ++++++++ src/include/handlers/revision.py | 4 +- src/include/router.py | 2 + src/main.py | 30 ++ tests/test_client.py | 8 + tests/test_documents.py | 74 +++++ 13 files changed, 579 insertions(+), 9 deletions(-) create mode 100644 src/alembic/versions/a50674184a2c_document_metadata.py diff --git a/AGENTS.md b/AGENTS.md index a667efa..d3cf4a9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -9,7 +9,7 @@ - Do not insert redundant comments. - Avoid anti-patterns when coding. - Observe the DRY principle. -- If you need to create an Alembic upgrade/downgrade script, create the framework by running the command `uv run alembic`, and then modify the parts you want to change in the generated file, instead of creating a file from scratch. +- If you need to create an Alembic upgrade/downgrade script, create the framework by running the command `uv run alembic revision --autogenerate`, and then modify the parts you want to change in the generated file, instead of creating a file from scratch. ## Testing instructions - Run tests only when necessary, as running tests will delete the original database (if SQLite is used as the database engine). diff --git a/src/alembic/env.py b/src/alembic/env.py index 95f43f1..6009419 100644 --- a/src/alembic/env.py +++ b/src/alembic/env.py @@ -33,6 +33,8 @@ Document, DocumentRevision, DocumentAccessRule, + DocumentMetadata, + DocumentMetadataTag, Folder, FolderAccessRule, ) diff --git a/src/alembic/versions/a50674184a2c_document_metadata.py b/src/alembic/versions/a50674184a2c_document_metadata.py new file mode 100644 index 0000000..c7f383f --- /dev/null +++ b/src/alembic/versions/a50674184a2c_document_metadata.py @@ -0,0 +1,270 @@ +"""document metadata + +Revision ID: a50674184a2c +Revises: cb1df1f5c488 +Create Date: 2026-06-16 09:44:17.857205 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'a50674184a2c' +down_revision: Union[str, Sequence[str], None] = 'cb1df1f5c488' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _has_fk( + inspector: sa.engine.reflection.Inspector, table_name: str, column_name: str +) -> bool: + return any( + fk.get("constrained_columns") == [column_name] + for fk in inspector.get_foreign_keys(table_name) + ) + + +def _has_index( + inspector: sa.engine.reflection.Inspector, table_name: str, index_name: str +) -> bool: + return any( + index.get("name") == index_name for index in inspector.get_indexes(table_name) + ) + + +def upgrade() -> None: + """Upgrade schema.""" + conn = op.get_bind() + inspector = sa.inspect(conn) + + if not inspector.has_table("document_metadata"): + op.create_table( + "document_metadata", + sa.Column("document_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("creator_username", sa.VARCHAR(length=64), nullable=True), + sa.Column( + "last_modified_by_username", sa.VARCHAR(length=64), nullable=True + ), + sa.ForeignKeyConstraint( + ["creator_username"], + ["users.username"], + name=op.f("fk_document_metadata_creator_username_users"), + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["document_id"], + ["documents.id"], + name=op.f("fk_document_metadata_document_id_documents"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["last_modified_by_username"], + ["users.username"], + name=op.f("fk_document_metadata_last_modified_by_username_users"), + ondelete="SET NULL", + ), + sa.PrimaryKeyConstraint("document_id", name=op.f("pk_document_metadata")), + ) + op.create_index( + op.f("ix_document_metadata_creator_username"), + "document_metadata", + ["creator_username"], + unique=False, + ) + op.create_index( + op.f("ix_document_metadata_last_modified_by_username"), + "document_metadata", + ["last_modified_by_username"], + unique=False, + ) + else: + column_names = { + column["name"] for column in inspector.get_columns("document_metadata") + } + document_id_column = next( + ( + column + for column in inspector.get_columns("document_metadata") + if column["name"] == "document_id" + ), + None, + ) + with op.batch_alter_table("document_metadata", schema=None) as batch_op: + if document_id_column is not None: + batch_op.alter_column( + "document_id", + existing_type=document_id_column["type"], + type_=sa.VARCHAR(length=255), + existing_nullable=False, + ) + if "creator_username" not in column_names: + batch_op.add_column( + sa.Column( + "creator_username", sa.VARCHAR(length=64), nullable=True + ) + ) + if "last_modified_by_username" not in column_names: + batch_op.add_column( + sa.Column( + "last_modified_by_username", + sa.VARCHAR(length=64), + nullable=True, + ) + ) + if not _has_fk(inspector, "document_metadata", "document_id"): + batch_op.create_foreign_key( + op.f("fk_document_metadata_document_id_documents"), + "documents", + ["document_id"], + ["id"], + ondelete="CASCADE", + ) + if not _has_fk(inspector, "document_metadata", "creator_username"): + batch_op.create_foreign_key( + op.f("fk_document_metadata_creator_username_users"), + "users", + ["creator_username"], + ["username"], + ondelete="SET NULL", + ) + if not _has_fk( + inspector, "document_metadata", "last_modified_by_username" + ): + batch_op.create_foreign_key( + op.f("fk_document_metadata_last_modified_by_username_users"), + "users", + ["last_modified_by_username"], + ["username"], + ondelete="SET NULL", + ) + + inspector = sa.inspect(conn) + if not _has_index( + inspector, + "document_metadata", + op.f("ix_document_metadata_creator_username"), + ): + op.create_index( + op.f("ix_document_metadata_creator_username"), + "document_metadata", + ["creator_username"], + unique=False, + ) + if not _has_index( + inspector, + "document_metadata", + op.f("ix_document_metadata_last_modified_by_username"), + ): + op.create_index( + op.f("ix_document_metadata_last_modified_by_username"), + "document_metadata", + ["last_modified_by_username"], + unique=False, + ) + + inspector = sa.inspect(conn) + if not inspector.has_table("document_metadata_tags"): + op.create_table( + "document_metadata_tags", + sa.Column("document_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("tag", sa.VARCHAR(length=255), nullable=False), + sa.Column("position", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["document_id"], + ["document_metadata.document_id"], + name=op.f("fk_document_metadata_tags_document_id_document_metadata"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint( + "document_id", "tag", name=op.f("pk_document_metadata_tags") + ), + ) + + inspector = sa.inspect(conn) + if not _has_index( + inspector, "document_metadata_tags", op.f("ix_document_metadata_tags_tag") + ): + op.create_index( + op.f("ix_document_metadata_tags_tag"), + "document_metadata_tags", + ["tag"], + unique=False, + ) + + op.execute( + """ + INSERT INTO document_metadata (document_id) + SELECT id + FROM documents + WHERE NOT EXISTS ( + SELECT 1 + FROM document_metadata + WHERE document_metadata.document_id = documents.id + ) + """ + ) + + user_groups = sa.table("user_groups", sa.column("group_name", sa.String())) + group_permissions = sa.table( + "group_permissions", + sa.column("group_name", sa.String()), + sa.column("permission", sa.String()), + sa.column("granted", sa.Boolean()), + sa.column("start_time", sa.Float()), + sa.column("end_time", sa.Float()), + ) + + sysop_exists = conn.execute( + sa.select(user_groups.c.group_name).where(user_groups.c.group_name == "sysop") + ).first() + if sysop_exists: + for permission in ("view_metadata", "set_metadata_tags"): + permission_exists = conn.execute( + sa.select(group_permissions.c.permission).where( + group_permissions.c.group_name == "sysop", + group_permissions.c.permission == permission, + group_permissions.c.granted == True, + ) + ).first() + if not permission_exists: + conn.execute( + group_permissions.insert().values( + group_name="sysop", + permission=permission, + granted=True, + start_time=0.0, + end_time=None, + ) + ) + + +def downgrade() -> None: + """Downgrade schema.""" + conn = op.get_bind() + group_permissions = sa.table( + "group_permissions", + sa.column("group_name", sa.String()), + sa.column("permission", sa.String()), + ) + conn.execute( + group_permissions.delete().where( + group_permissions.c.group_name == "sysop", + group_permissions.c.permission.in_(["view_metadata", "set_metadata_tags"]), + ) + ) + + op.drop_index( + op.f("ix_document_metadata_tags_tag"), table_name="document_metadata_tags" + ) + op.drop_table("document_metadata_tags") + op.drop_index( + op.f("ix_document_metadata_last_modified_by_username"), + table_name="document_metadata", + ) + op.drop_index( + op.f("ix_document_metadata_creator_username"), table_name="document_metadata" + ) + op.drop_table("document_metadata") diff --git a/src/include/classes/enum/permissions.py b/src/include/classes/enum/permissions.py index f727caf..501ebf9 100644 --- a/src/include/classes/enum/permissions.py +++ b/src/include/classes/enum/permissions.py @@ -58,6 +58,8 @@ class Permissions(StrEnum): # 访问控制与锁定 VIEW_ACCESS_RULES = "view_access_rules" SET_ACCESS_RULES = "set_access_rules" + VIEW_METADATA = "view_metadata" + SET_METADATA_TAGS = "set_metadata_tags" MANAGE_ACCESS = "manage_access" VIEW_ACCESS_ENTRIES = "view_access_entries" APPLY_LOCKDOWN = "apply_lockdown" diff --git a/src/include/database/handler.py b/src/include/database/handler.py index 8c719cd..580f43a 100644 --- a/src/include/database/handler.py +++ b/src/include/database/handler.py @@ -34,6 +34,12 @@ if db_type == "sqlite": db_file = global_config["database"]["file"] engine = create_engine(f"sqlite:///{db_file}", echo=debug_enabled) + + @event.listens_for(engine, "connect") + def _enable_sqlite_foreign_keys(dbapi_connection, _connection_record) -> None: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() else: username = global_config["database"]["username"] password = global_config["database"]["password"] diff --git a/src/include/database/models/entity/metadata.py b/src/include/database/models/entity/metadata.py index fc9b8ff..41e7db1 100644 --- a/src/include/database/models/entity/metadata.py +++ b/src/include/database/models/entity/metadata.py @@ -1,21 +1,68 @@ -__all__ = ["DocumentMetadata"] +__all__ = ["DocumentMetadata", "DocumentMetadataTag"] -from sqlalchemy import VARCHAR, ForeignKey +from typing import TYPE_CHECKING, List, Optional + +from sqlalchemy import VARCHAR, ForeignKey, Integer from sqlalchemy.orm import Mapped, mapped_column, relationship from include.database.handler import Base -from include.database.models.entity.obj import Document + +if TYPE_CHECKING: + from include.database.models.classic import User + from include.database.models.entity.obj import Document class DocumentMetadata(Base): __tablename__ = "document_metadata" document_id: Mapped[str] = mapped_column( - VARCHAR(64), ForeignKey("documents.id", ondelete="CASCADE"), primary_key=True + VARCHAR(255), ForeignKey("documents.id", ondelete="CASCADE"), primary_key=True + ) + creator_username: Mapped[Optional[str]] = mapped_column( + VARCHAR(64), + ForeignKey("users.username", ondelete="SET NULL"), + nullable=True, + index=True, + ) + last_modified_by_username: Mapped[Optional[str]] = mapped_column( + VARCHAR(64), + ForeignKey("users.username", ondelete="SET NULL"), + nullable=True, + index=True, ) + document: Mapped["Document"] = relationship( "Document", - back_populates="metadata", + back_populates="metadata_record", + foreign_keys=[document_id], + ) + creator: Mapped[Optional["User"]] = relationship( + "User", foreign_keys=[creator_username] + ) + last_modified_by: Mapped[Optional["User"]] = relationship( + "User", foreign_keys=[last_modified_by_username] + ) + tags: Mapped[List["DocumentMetadataTag"]] = relationship( + "DocumentMetadataTag", + back_populates="metadata_record", + cascade="all, delete-orphan", + order_by="DocumentMetadataTag.position", + ) + + +class DocumentMetadataTag(Base): + __tablename__ = "document_metadata_tags" + + document_id: Mapped[str] = mapped_column( + VARCHAR(255), + ForeignKey("document_metadata.document_id", ondelete="CASCADE"), + primary_key=True, + ) + tag: Mapped[str] = mapped_column(VARCHAR(255), primary_key=True, index=True) + position: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + metadata_record: Mapped["DocumentMetadata"] = relationship( + "DocumentMetadata", + back_populates="tags", foreign_keys=[document_id], ) - # TODO: Add more metadata fields as needed, e.g. author, creation date, etc. diff --git a/src/include/database/models/entity/obj.py b/src/include/database/models/entity/obj.py index 2a1a776..5549cb7 100644 --- a/src/include/database/models/entity/obj.py +++ b/src/include/database/models/entity/obj.py @@ -9,7 +9,7 @@ import secrets import time from itertools import batched -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from sqlalchemy import JSON, VARCHAR, Boolean, Float, ForeignKey, Integer from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -29,6 +29,9 @@ from include.util.bulk.count import batch_count_other_revisions from include.util.count import count_file_references +if TYPE_CHECKING: + from include.database.models.entity.metadata import DocumentMetadata + class Folder(BaseObject): # 文档文件夹 __tablename__ = "folders" @@ -141,6 +144,12 @@ class Document(BaseObject): overlaps="current_revision", # 声明与 current_revision 的重叠 ) inherit: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + metadata_record: Mapped[Optional["DocumentMetadata"]] = relationship( + "DocumentMetadata", + back_populates="document", + cascade="all, delete-orphan", + uselist=False, + ) @property def active(self): diff --git a/src/include/handlers/document.py b/src/include/handlers/document.py index 0241a5b..e805ea0 100644 --- a/src/include/handlers/document.py +++ b/src/include/handlers/document.py @@ -8,6 +8,7 @@ "RequestDownloadFileHandler", "RequestUploadFileHandler", "RequestSetDocumentRulesHandler", + "RequestSetDocumentMetadataTagsHandler", "RequestMoveDocumentHandler", ] @@ -26,6 +27,8 @@ from include.database.models.classic import User from include.database.models.entity import ( Document, + DocumentMetadata, + DocumentMetadataTag, DocumentRevision, Folder, ) @@ -77,6 +80,32 @@ def create_file_task(file: File, transfer_mode: int = 0): } +def get_or_create_document_metadata(document: Document) -> DocumentMetadata: + if document.metadata_record is None: + document.metadata_record = DocumentMetadata() + return document.metadata_record + + +def mark_document_modified(document: Document, username: str) -> None: + get_or_create_document_metadata(document).last_modified_by_username = username + + +def serialize_document_metadata(document: Document) -> dict: + metadata_record = document.metadata_record + if metadata_record is None: + return { + "tags": [], + "creator": None, + "last_modified_by": None, + } + + return { + "tags": [tag.tag for tag in metadata_record.tags], + "creator": metadata_record.creator_username, + "last_modified_by": metadata_record.last_modified_by_username, + } + + class RequestGetDocumentInfoHandler(RequestHandler): """ Handles the "get_document_info" action. @@ -145,6 +174,9 @@ def handle(self, handler: ConnectionHandler): "info_code": info_code, } + if Permissions.VIEW_METADATA in user.all_permissions: + data["metadata"] = serialize_document_metadata(document) + handler.conclude_request(200, data, "Document info retrieved successfully") return 0, document_id, handler.username @@ -308,6 +340,10 @@ def handle(self, handler: ConnectionHandler): title=title, folder_id=folder_id, ) + new_document.metadata_record = DocumentMetadata( + creator_username=user.username, + last_modified_by_username=user.username, + ) new_revision = DocumentRevision(file_id=new_file.id) new_document.revisions.append(new_revision) @@ -418,6 +454,7 @@ def handle(self, handler: ConnectionHandler): parent_revision_id=parent_revision_id, ) document.revisions.append(new_revision) + mark_document_modified(document, this_user.username) session.add(new_file) session.add(new_revision) @@ -475,6 +512,7 @@ def handle(self, handler: ConnectionHandler): document.status_operation_id = ( f"OP_DEL_{secrets.token_hex(8)}_{int(time.time())}" ) + mark_document_modified(document, user.username) session.commit() handler.conclude_request(200, {}, "Document successfully deleted") @@ -551,6 +589,7 @@ def handle(self, handler: ConnectionHandler): return err_code, document.folder_id, handler.username document.title = new_title + mark_document_modified(document, this_user.username) session.commit() handler.conclude_request( @@ -702,6 +741,7 @@ def handle(self, handler: ConnectionHandler): if apply_access_rules( document, access_rules_to_apply, user, inherit_parent ): + mark_document_modified(document, user.username) session.commit() handler.conclude_request(200, {}, "Set access rules successfully") return 0, document_id, handler.username @@ -845,6 +885,7 @@ def handle(self, handler: ConnectionHandler): return err_code, document.folder_id, handler.username document.folder = target_folder + mark_document_modified(document, user.username) session.commit() @@ -897,6 +938,7 @@ def handle(self, handler: ConnectionHandler): return document.delete_all_revisions(do_commit=False) + mark_document_modified(document, user.username) session.delete(document) session.commit() @@ -1007,6 +1049,7 @@ def handle(self, handler: ConnectionHandler): document.status_operation_id = None document.title = final_title document.folder_id = db_folder_id + mark_document_modified(document, user.username) session.commit() @@ -1019,3 +1062,78 @@ def handle(self, handler: ConnectionHandler): "Document successfully restored", ) return 0, doc_id, handler.username + + +class RequestSetDocumentMetadataTagsHandler(RequestHandler): + schema = { + "type": "object", + "properties": { + "document_id": {"type": "string", "minLength": 1}, + "tags": { + "type": "array", + "maxItems": 128, + "items": {"type": "string", "minLength": 1, "maxLength": 255}, + }, + }, + "required": ["document_id", "tags"], + "additionalProperties": False, + } + + require_auth = True + + def handle(self, handler: ConnectionHandler): + document_id: str = handler.data["document_id"] + + normalized_tags = [] + seen_tags = set() + for raw_tag in handler.data["tags"]: + tag = raw_tag.strip() + if not tag: + handler.conclude_request(400, {}, "Tags cannot be blank") + return 400, document_id, handler.username + if tag not in seen_tags: + normalized_tags.append(tag) + seen_tags.add(tag) + + with Session() as session: + user = User.get_existing(session, handler.username) + document = session.get(Document, document_id) + + if not document: + handler.conclude_request(404, {}, smsg.DOCUMENT_NOT_FOUND) + return 404, document_id, handler.username + + if ( + Permissions.SET_METADATA_TAGS not in user.all_permissions + or not document.check_access_requirements(user, access_type="write") + ): + handler.conclude_access_denial() + return 403, document_id, handler.username + + metadata_record = get_or_create_document_metadata(document) + existing_by_tag = { + tag_record.tag: tag_record for tag_record in metadata_record.tags + } + requested_tag_set = set(normalized_tags) + + for tag_record in list(metadata_record.tags): + if tag_record.tag not in requested_tag_set: + metadata_record.tags.remove(tag_record) + + for position, tag in enumerate(normalized_tags): + if tag in existing_by_tag: + existing_by_tag[tag].position = position + else: + metadata_record.tags.append( + DocumentMetadataTag(tag=tag, position=position) + ) + + metadata_record.last_modified_by_username = user.username + session.commit() + + handler.conclude_request( + 200, + {"tags": normalized_tags}, + "Document metadata tags updated successfully", + ) + return 0, document_id, {"tags": normalized_tags}, handler.username diff --git a/src/include/handlers/revision.py b/src/include/handlers/revision.py index 75e3f18..324dea7 100644 --- a/src/include/handlers/revision.py +++ b/src/include/handlers/revision.py @@ -4,7 +4,7 @@ from include.database.models.classic import User from include.database.models.entity import Document, DocumentRevision from include.handlers.base import RequestHandler -from include.handlers.document import create_file_task +from include.handlers.document import create_file_task, mark_document_modified from include.system.messages import Messages as smsg @@ -128,6 +128,7 @@ def handle(self, handler: ConnectionHandler): return 403, document_id, handler.username document.current_revision_id = revision.id + mark_document_modified(document, user.username) session.commit() handler.conclude_request(200, {}, "Current revision set successfully") @@ -179,6 +180,7 @@ def handle(self, handler: ConnectionHandler): child_rev.parent_revision = revision.parent_revision revision.before_delete() + mark_document_modified(document, user.username) session.delete(revision) session.commit() diff --git a/src/include/router.py b/src/include/router.py index cd2a9b1..0e6fba3 100644 --- a/src/include/router.py +++ b/src/include/router.py @@ -41,6 +41,7 @@ RequestPurgeDocumentHandler, RequestRenameDocumentHandler, RequestRestoreDocumentHandler, + RequestSetDocumentMetadataTagsHandler, RequestSetDocumentRulesHandler, RequestUploadDocumentHandler, RequestUploadFileHandler, @@ -129,6 +130,7 @@ "get_document_info": RequestGetDocumentInfoHandler, "get_document_access_rules": RequestGetDocumentAccessRulesHandler, "set_document_rules": RequestSetDocumentRulesHandler, + "set_document_metadata_tags": RequestSetDocumentMetadataTagsHandler, # 修订版本类 "list_revisions": RequestListRevisionsHandler, "get_revision": RequestGetRevisionHandler, diff --git a/src/main.py b/src/main.py index 5af7479..18886a8 100644 --- a/src/main.py +++ b/src/main.py @@ -14,6 +14,7 @@ import sys from loguru import logger +from sqlalchemy import insert, select from websockets.sync.server import serve from include.classes.enum.permissions import Permissions @@ -26,6 +27,7 @@ ROOT_DIRECTORY_ID, ) from include.database.handler import Base, Session, engine +from include.database.models.entity import DocumentMetadata from include.database.models.entity.obj import Document, DocumentRevision, Folder from include.database.models.file import File from include.handlers.debugging.throw import RequestThrowExceptionHandler @@ -82,6 +84,26 @@ def ensure_root_folder(): session.commit() +def ensure_document_metadata_records(): + document_table = Document.__table__ + metadata_table = DocumentMetadata.__table__ + + missing_document_ids = ( + select(document_table.c.id) + .outerjoin( + metadata_table, + document_table.c.id == metadata_table.c.document_id, + ) + .where(metadata_table.c.document_id.is_(None)) + ) + + with Session() as session: + session.execute( + insert(metadata_table).from_select(["document_id"], missing_document_ids) + ) + session.commit() + + def server_init(): """ Initialize the server by checking whether the database is already set up. @@ -133,6 +155,8 @@ def server_init(): {"permission": Permissions.SUPER_SET_PASSWD}, {"permission": Permissions.VIEW_ACCESS_RULES}, {"permission": Permissions.SET_ACCESS_RULES}, + {"permission": Permissions.VIEW_METADATA}, + {"permission": Permissions.SET_METADATA_TAGS}, {"permission": Permissions.LIST_USERS}, {"permission": Permissions.LIST_GROUPS}, {"permission": Permissions.CREATE_GROUP}, @@ -179,6 +203,7 @@ def server_init(): init_document = Document( id="hello", title="Hello World", folder_id=ROOT_DIRECTORY_ID ) + init_document.metadata_record = DocumentMetadata() init_document_revision = DocumentRevision(file_id=init_file.id) init_document.revisions.append(init_document_revision) init_document.current_revision = init_document_revision @@ -186,6 +211,8 @@ def server_init(): session.add(init_document_revision) session.commit() + ensure_document_metadata_records() + import secrets import string @@ -488,6 +515,9 @@ def main(): # Ensure the root folder record exists (handles upgrades from older versions) ensure_root_folder() + # Ensure every existing document has its one-to-one metadata row. + ensure_document_metadata_records() + # Initialize Providers initialize_providers() diff --git a/tests/test_client.py b/tests/test_client.py index 592456f..3239bf8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -549,6 +549,14 @@ async def get_document_info(self, document_id: str) -> Dict[str, Any]: "get_document_info", {"document_id": document_id} ) + async def set_document_metadata_tags( + self, document_id: str, tags: list[str] + ) -> Dict[str, Any]: + return await self.send_request( + "set_document_metadata_tags", + {"document_id": document_id, "tags": tags}, + ) + # --- Revisions --- async def list_revisions(self, document_id: str) -> Dict[str, Any]: return await self.send_request("list_revisions", {"document_id": document_id}) diff --git a/tests/test_documents.py b/tests/test_documents.py index 25ee492..f1748d0 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -1,3 +1,6 @@ +import secrets +import time + import pytest from tests.test_client import CFMSTestClient @@ -35,6 +38,77 @@ async def test_get_document_info( data = assert_success(response) assert isinstance(data, dict) + @pytest.mark.asyncio + async def test_document_metadata_tags( + self, authenticated_client: CFMSTestClient, test_document: dict + ): + document_id = test_document["document_id"] + + info_response = await authenticated_client.get_document_info(document_id) + info = assert_success(info_response) + assert info["metadata"] == { + "tags": [], + "creator": "admin", + "last_modified_by": "admin", + } + + set_response = await authenticated_client.set_document_metadata_tags( + document_id, + ["secret", "finance", "secret", " topic "], + ) + data = assert_success(set_response) + assert data["tags"] == ["secret", "finance", "topic"] + + info_response = await authenticated_client.get_document_info(document_id) + info = assert_success(info_response) + assert info["metadata"]["tags"] == ["secret", "finance", "topic"] + assert info["metadata"]["creator"] == "admin" + assert info["metadata"]["last_modified_by"] == "admin" + + @pytest.mark.asyncio + async def test_set_document_metadata_tags_requires_permission( + self, + authenticated_client: CFMSTestClient, + test_document: dict, + server_process, + ): + username = f"metadata_user_{secrets.token_hex(4)}" + password = "TestPassword123!" + document_id = test_document["document_id"] + + response = await authenticated_client.create_user( + username=username, + password=password, + nickname="Metadata User", + ) + assert_success(response) + + try: + grant_response = await authenticated_client.grant_access( + entity_type="user", + entity_identifier=username, + target_type="document", + target_identifier=document_id, + access_types=["write"], + start_time=time.time(), + ) + assert_success(grant_response) + + client = CFMSTestClient() + await client.connect() + try: + login_response = await client.login(username, password) + assert_success(login_response) + + set_response = await client.set_document_metadata_tags( + document_id, ["blocked"] + ) + assert_error(set_response, 403) + finally: + await client.disconnect() + finally: + await authenticated_client.delete_user(username) + @pytest.mark.asyncio async def test_rename_document( self, authenticated_client: CFMSTestClient, test_document: dict From 026f2ffa769210474e6e1d6d4fc2531920b55e29 Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 16 Jun 2026 10:00:42 +0800 Subject: [PATCH 06/12] refactor: remove ensure_document_metadata_records function and its calls --- src/main.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/src/main.py b/src/main.py index 18886a8..0b7b518 100644 --- a/src/main.py +++ b/src/main.py @@ -14,7 +14,6 @@ import sys from loguru import logger -from sqlalchemy import insert, select from websockets.sync.server import serve from include.classes.enum.permissions import Permissions @@ -84,26 +83,6 @@ def ensure_root_folder(): session.commit() -def ensure_document_metadata_records(): - document_table = Document.__table__ - metadata_table = DocumentMetadata.__table__ - - missing_document_ids = ( - select(document_table.c.id) - .outerjoin( - metadata_table, - document_table.c.id == metadata_table.c.document_id, - ) - .where(metadata_table.c.document_id.is_(None)) - ) - - with Session() as session: - session.execute( - insert(metadata_table).from_select(["document_id"], missing_document_ids) - ) - session.commit() - - def server_init(): """ Initialize the server by checking whether the database is already set up. @@ -211,8 +190,6 @@ def server_init(): session.add(init_document_revision) session.commit() - ensure_document_metadata_records() - import secrets import string @@ -515,9 +492,6 @@ def main(): # Ensure the root folder record exists (handles upgrades from older versions) ensure_root_folder() - # Ensure every existing document has its one-to-one metadata row. - ensure_document_metadata_records() - # Initialize Providers initialize_providers() From f163eeec77fb252d1e44641c2c8a5dbb0b98d0cb Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 16 Jun 2026 10:05:33 +0800 Subject: [PATCH 07/12] fix: ensure document revisions are updated correctly in RequestUploadDocumentHandler --- src/include/handlers/document.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/include/handlers/document.py b/src/include/handlers/document.py index e805ea0..79459cc 100644 --- a/src/include/handlers/document.py +++ b/src/include/handlers/document.py @@ -453,12 +453,12 @@ def handle(self, handler: ConnectionHandler): file_id=file_id, parent_revision_id=parent_revision_id, ) - document.revisions.append(new_revision) - mark_document_modified(document, this_user.username) - session.add(new_file) session.add(new_revision) + document.revisions.append(new_revision) + mark_document_modified(document, this_user.username) + document.current_revision = new_revision session.commit() From 5179efc4bc701a57ead36a7953684d35e4a99881 Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 16 Jun 2026 10:23:02 +0800 Subject: [PATCH 08/12] fix: update variable names for clarity in serialize_document_metadata and RequestSetDocumentMetadataTagsHandler --- src/include/handlers/document.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/include/handlers/document.py b/src/include/handlers/document.py index 79459cc..b579c0d 100644 --- a/src/include/handlers/document.py +++ b/src/include/handlers/document.py @@ -91,8 +91,8 @@ def mark_document_modified(document: Document, username: str) -> None: def serialize_document_metadata(document: Document) -> dict: - metadata_record = document.metadata_record - if metadata_record is None: + metadata = document.metadata_record + if metadata is None: return { "tags": [], "creator": None, @@ -100,9 +100,9 @@ def serialize_document_metadata(document: Document) -> dict: } return { - "tags": [tag.tag for tag in metadata_record.tags], - "creator": metadata_record.creator_username, - "last_modified_by": metadata_record.last_modified_by_username, + "tags": [tag.tag for tag in metadata.tags], + "creator": metadata.creator_username, + "last_modified_by": metadata.last_modified_by_username, } @@ -1110,25 +1110,25 @@ def handle(self, handler: ConnectionHandler): handler.conclude_access_denial() return 403, document_id, handler.username - metadata_record = get_or_create_document_metadata(document) + metadata = get_or_create_document_metadata(document) existing_by_tag = { - tag_record.tag: tag_record for tag_record in metadata_record.tags + tag_record.tag: tag_record for tag_record in metadata.tags } requested_tag_set = set(normalized_tags) - for tag_record in list(metadata_record.tags): + for tag_record in list(metadata.tags): if tag_record.tag not in requested_tag_set: - metadata_record.tags.remove(tag_record) + metadata.tags.remove(tag_record) for position, tag in enumerate(normalized_tags): if tag in existing_by_tag: existing_by_tag[tag].position = position else: - metadata_record.tags.append( + metadata.tags.append( DocumentMetadataTag(tag=tag, position=position) ) - metadata_record.last_modified_by_username = user.username + metadata.last_modified_by_username = user.username session.commit() handler.conclude_request( From 4678e9eda3c527e4ad5fc806b31371f5b85ce098 Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 16 Jun 2026 10:25:17 +0800 Subject: [PATCH 09/12] refactor: rename RequestSetDocumentMetadataTagsHandler to RequestSetDocumentTagsHandler for consistency --- src/include/handlers/document.py | 4 ++-- src/include/router.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/include/handlers/document.py b/src/include/handlers/document.py index b579c0d..1dd03f3 100644 --- a/src/include/handlers/document.py +++ b/src/include/handlers/document.py @@ -8,7 +8,7 @@ "RequestDownloadFileHandler", "RequestUploadFileHandler", "RequestSetDocumentRulesHandler", - "RequestSetDocumentMetadataTagsHandler", + "RequestSetDocumentTagsHandler", "RequestMoveDocumentHandler", ] @@ -1064,7 +1064,7 @@ def handle(self, handler: ConnectionHandler): return 0, doc_id, handler.username -class RequestSetDocumentMetadataTagsHandler(RequestHandler): +class RequestSetDocumentTagsHandler(RequestHandler): schema = { "type": "object", "properties": { diff --git a/src/include/router.py b/src/include/router.py index 0e6fba3..bf1aa07 100644 --- a/src/include/router.py +++ b/src/include/router.py @@ -41,8 +41,8 @@ RequestPurgeDocumentHandler, RequestRenameDocumentHandler, RequestRestoreDocumentHandler, - RequestSetDocumentMetadataTagsHandler, RequestSetDocumentRulesHandler, + RequestSetDocumentTagsHandler, RequestUploadDocumentHandler, RequestUploadFileHandler, ) @@ -130,7 +130,7 @@ "get_document_info": RequestGetDocumentInfoHandler, "get_document_access_rules": RequestGetDocumentAccessRulesHandler, "set_document_rules": RequestSetDocumentRulesHandler, - "set_document_metadata_tags": RequestSetDocumentMetadataTagsHandler, + "set_document_tags": RequestSetDocumentTagsHandler, # 修订版本类 "list_revisions": RequestListRevisionsHandler, "get_revision": RequestGetRevisionHandler, From ad6e60d31361283f88cf4502524cca6cbab95afa Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 16 Jun 2026 10:32:11 +0800 Subject: [PATCH 10/12] feat: add test for omitting document metadata when user lacks permission --- tests/test_documents.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/test_documents.py b/tests/test_documents.py index f1748d0..e98a2a4 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -65,6 +65,43 @@ async def test_document_metadata_tags( assert info["metadata"]["creator"] == "admin" assert info["metadata"]["last_modified_by"] == "admin" + @pytest.mark.asyncio + async def test_get_document_info_omits_metadata_without_permission( + self, + authenticated_client: CFMSTestClient, + test_document: dict, + user_factory, + ): + document_id = test_document["document_id"] + + set_response = await authenticated_client.set_document_metadata_tags( + document_id, ["restricted"] + ) + assert_success(set_response) + + user = await user_factory() + grant_response = await authenticated_client.grant_access( + entity_type="user", + entity_identifier=user["username"], + target_type="document", + target_identifier=document_id, + access_types=["read"], + start_time=time.time(), + ) + assert_success(grant_response) + + client = CFMSTestClient() + await client.connect() + try: + login_response = await client.login(user["username"], user["password"]) + assert_success(login_response) + + info_response = await client.get_document_info(document_id) + info = assert_success(info_response) + assert "metadata" not in info + finally: + await client.disconnect() + @pytest.mark.asyncio async def test_set_document_metadata_tags_requires_permission( self, From ce3434387e2bc0b50d178d594a8fe355d7bcc8d0 Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 16 Jun 2026 10:33:12 +0800 Subject: [PATCH 11/12] fix: rename set_document_metadata_tags to set_document_tags for consistency --- tests/test_client.py | 4 ++-- tests/test_documents.py | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index 3239bf8..acdc36d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -549,11 +549,11 @@ async def get_document_info(self, document_id: str) -> Dict[str, Any]: "get_document_info", {"document_id": document_id} ) - async def set_document_metadata_tags( + async def set_document_tags( self, document_id: str, tags: list[str] ) -> Dict[str, Any]: return await self.send_request( - "set_document_metadata_tags", + "set_document_tags", {"document_id": document_id, "tags": tags}, ) diff --git a/tests/test_documents.py b/tests/test_documents.py index e98a2a4..d1c47a6 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -52,7 +52,7 @@ async def test_document_metadata_tags( "last_modified_by": "admin", } - set_response = await authenticated_client.set_document_metadata_tags( + set_response = await authenticated_client.set_document_tags( document_id, ["secret", "finance", "secret", " topic "], ) @@ -74,7 +74,7 @@ async def test_get_document_info_omits_metadata_without_permission( ): document_id = test_document["document_id"] - set_response = await authenticated_client.set_document_metadata_tags( + set_response = await authenticated_client.set_document_tags( document_id, ["restricted"] ) assert_success(set_response) @@ -103,7 +103,7 @@ async def test_get_document_info_omits_metadata_without_permission( await client.disconnect() @pytest.mark.asyncio - async def test_set_document_metadata_tags_requires_permission( + async def test_set_document_tags_requires_permission( self, authenticated_client: CFMSTestClient, test_document: dict, @@ -137,9 +137,7 @@ async def test_set_document_metadata_tags_requires_permission( login_response = await client.login(username, password) assert_success(login_response) - set_response = await client.set_document_metadata_tags( - document_id, ["blocked"] - ) + set_response = await client.set_document_tags(document_id, ["blocked"]) assert_error(set_response, 403) finally: await client.disconnect() From dde957fb0d212a953e45b03f95f95198a44ff72c Mon Sep 17 00:00:00 2001 From: Creeper19472 Date: Tue, 16 Jun 2026 10:37:20 +0800 Subject: [PATCH 12/12] chore: bump CORE_VERSION and PROTOCOL_VERSION --- src/include/constants.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/include/constants.py b/src/include/constants.py index 6eb056b..572487a 100644 --- a/src/include/constants.py +++ b/src/include/constants.py @@ -21,8 +21,8 @@ from include.classes.version import Version -CORE_VERSION = Version("0.3.0.260530_alpha") -PROTOCOL_VERSION = 13 +CORE_VERSION = Version("0.3.0.260616_alpha") +PROTOCOL_VERSION = 14 ROOT_ABSPATH = Path(__file__).resolve().parent.parent