diff --git a/AGENTS.md b/AGENTS.md index d4fb700..d3cf4a9 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -9,6 +9,7 @@ - Do not insert redundant comments. - Avoid anti-patterns when coding. - Observe the DRY principle. +- If you need to create an Alembic upgrade/downgrade script, create the framework by running the command `uv run alembic revision --autogenerate`, and then modify the parts you want to change in the generated file, instead of creating a file from scratch. ## Testing instructions - Run tests only when necessary, as running tests will delete the original database (if SQLite is used as the database engine). diff --git a/src/alembic/env.py b/src/alembic/env.py index 95f43f1..6009419 100644 --- a/src/alembic/env.py +++ b/src/alembic/env.py @@ -33,6 +33,8 @@ Document, DocumentRevision, DocumentAccessRule, + DocumentMetadata, + DocumentMetadataTag, Folder, FolderAccessRule, ) diff --git a/src/alembic/versions/a50674184a2c_document_metadata.py b/src/alembic/versions/a50674184a2c_document_metadata.py new file mode 100644 index 0000000..c7f383f --- /dev/null +++ b/src/alembic/versions/a50674184a2c_document_metadata.py @@ -0,0 +1,270 @@ +"""document metadata + +Revision ID: a50674184a2c +Revises: cb1df1f5c488 +Create Date: 2026-06-16 09:44:17.857205 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'a50674184a2c' +down_revision: Union[str, Sequence[str], None] = 'cb1df1f5c488' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _has_fk( + inspector: sa.engine.reflection.Inspector, table_name: str, column_name: str +) -> bool: + return any( + fk.get("constrained_columns") == [column_name] + for fk in inspector.get_foreign_keys(table_name) + ) + + +def _has_index( + inspector: sa.engine.reflection.Inspector, table_name: str, index_name: str +) -> bool: + return any( + index.get("name") == index_name for index in inspector.get_indexes(table_name) + ) + + +def upgrade() -> None: + """Upgrade schema.""" + conn = op.get_bind() + inspector = sa.inspect(conn) + + if not inspector.has_table("document_metadata"): + op.create_table( + "document_metadata", + sa.Column("document_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("creator_username", sa.VARCHAR(length=64), nullable=True), + sa.Column( + "last_modified_by_username", sa.VARCHAR(length=64), nullable=True + ), + sa.ForeignKeyConstraint( + ["creator_username"], + ["users.username"], + name=op.f("fk_document_metadata_creator_username_users"), + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["document_id"], + ["documents.id"], + name=op.f("fk_document_metadata_document_id_documents"), + ondelete="CASCADE", + ), + sa.ForeignKeyConstraint( + ["last_modified_by_username"], + ["users.username"], + name=op.f("fk_document_metadata_last_modified_by_username_users"), + ondelete="SET NULL", + ), + sa.PrimaryKeyConstraint("document_id", name=op.f("pk_document_metadata")), + ) + op.create_index( + op.f("ix_document_metadata_creator_username"), + "document_metadata", + ["creator_username"], + unique=False, + ) + op.create_index( + op.f("ix_document_metadata_last_modified_by_username"), + "document_metadata", + ["last_modified_by_username"], + unique=False, + ) + else: + column_names = { + column["name"] for column in inspector.get_columns("document_metadata") + } + document_id_column = next( + ( + column + for column in inspector.get_columns("document_metadata") + if column["name"] == "document_id" + ), + None, + ) + with op.batch_alter_table("document_metadata", schema=None) as batch_op: + if document_id_column is not None: + batch_op.alter_column( + "document_id", + existing_type=document_id_column["type"], + type_=sa.VARCHAR(length=255), + existing_nullable=False, + ) + if "creator_username" not in column_names: + batch_op.add_column( + sa.Column( + "creator_username", sa.VARCHAR(length=64), nullable=True + ) + ) + if "last_modified_by_username" not in column_names: + batch_op.add_column( + sa.Column( + "last_modified_by_username", + sa.VARCHAR(length=64), + nullable=True, + ) + ) + if not _has_fk(inspector, "document_metadata", "document_id"): + batch_op.create_foreign_key( + op.f("fk_document_metadata_document_id_documents"), + "documents", + ["document_id"], + ["id"], + ondelete="CASCADE", + ) + if not _has_fk(inspector, "document_metadata", "creator_username"): + batch_op.create_foreign_key( + op.f("fk_document_metadata_creator_username_users"), + "users", + ["creator_username"], + ["username"], + ondelete="SET NULL", + ) + if not _has_fk( + inspector, "document_metadata", "last_modified_by_username" + ): + batch_op.create_foreign_key( + op.f("fk_document_metadata_last_modified_by_username_users"), + "users", + ["last_modified_by_username"], + ["username"], + ondelete="SET NULL", + ) + + inspector = sa.inspect(conn) + if not _has_index( + inspector, + "document_metadata", + op.f("ix_document_metadata_creator_username"), + ): + op.create_index( + op.f("ix_document_metadata_creator_username"), + "document_metadata", + ["creator_username"], + unique=False, + ) + if not _has_index( + inspector, + "document_metadata", + op.f("ix_document_metadata_last_modified_by_username"), + ): + op.create_index( + op.f("ix_document_metadata_last_modified_by_username"), + "document_metadata", + ["last_modified_by_username"], + unique=False, + ) + + inspector = sa.inspect(conn) + if not inspector.has_table("document_metadata_tags"): + op.create_table( + "document_metadata_tags", + sa.Column("document_id", sa.VARCHAR(length=255), nullable=False), + sa.Column("tag", sa.VARCHAR(length=255), nullable=False), + sa.Column("position", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["document_id"], + ["document_metadata.document_id"], + name=op.f("fk_document_metadata_tags_document_id_document_metadata"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint( + "document_id", "tag", name=op.f("pk_document_metadata_tags") + ), + ) + + inspector = sa.inspect(conn) + if not _has_index( + inspector, "document_metadata_tags", op.f("ix_document_metadata_tags_tag") + ): + op.create_index( + op.f("ix_document_metadata_tags_tag"), + "document_metadata_tags", + ["tag"], + unique=False, + ) + + op.execute( + """ + INSERT INTO document_metadata (document_id) + SELECT id + FROM documents + WHERE NOT EXISTS ( + SELECT 1 + FROM document_metadata + WHERE document_metadata.document_id = documents.id + ) + """ + ) + + user_groups = sa.table("user_groups", sa.column("group_name", sa.String())) + group_permissions = sa.table( + "group_permissions", + sa.column("group_name", sa.String()), + sa.column("permission", sa.String()), + sa.column("granted", sa.Boolean()), + sa.column("start_time", sa.Float()), + sa.column("end_time", sa.Float()), + ) + + sysop_exists = conn.execute( + sa.select(user_groups.c.group_name).where(user_groups.c.group_name == "sysop") + ).first() + if sysop_exists: + for permission in ("view_metadata", "set_metadata_tags"): + permission_exists = conn.execute( + sa.select(group_permissions.c.permission).where( + group_permissions.c.group_name == "sysop", + group_permissions.c.permission == permission, + group_permissions.c.granted == True, + ) + ).first() + if not permission_exists: + conn.execute( + group_permissions.insert().values( + group_name="sysop", + permission=permission, + granted=True, + start_time=0.0, + end_time=None, + ) + ) + + +def downgrade() -> None: + """Downgrade schema.""" + conn = op.get_bind() + group_permissions = sa.table( + "group_permissions", + sa.column("group_name", sa.String()), + sa.column("permission", sa.String()), + ) + conn.execute( + group_permissions.delete().where( + group_permissions.c.group_name == "sysop", + group_permissions.c.permission.in_(["view_metadata", "set_metadata_tags"]), + ) + ) + + op.drop_index( + op.f("ix_document_metadata_tags_tag"), table_name="document_metadata_tags" + ) + op.drop_table("document_metadata_tags") + op.drop_index( + op.f("ix_document_metadata_last_modified_by_username"), + table_name="document_metadata", + ) + op.drop_index( + op.f("ix_document_metadata_creator_username"), table_name="document_metadata" + ) + op.drop_table("document_metadata") diff --git a/src/include/classes/enum/permissions.py b/src/include/classes/enum/permissions.py index f727caf..501ebf9 100644 --- a/src/include/classes/enum/permissions.py +++ b/src/include/classes/enum/permissions.py @@ -58,6 +58,8 @@ class Permissions(StrEnum): # 访问控制与锁定 VIEW_ACCESS_RULES = "view_access_rules" SET_ACCESS_RULES = "set_access_rules" + VIEW_METADATA = "view_metadata" + SET_METADATA_TAGS = "set_metadata_tags" MANAGE_ACCESS = "manage_access" VIEW_ACCESS_ENTRIES = "view_access_entries" APPLY_LOCKDOWN = "apply_lockdown" diff --git a/src/include/constants.py b/src/include/constants.py index 6eb056b..572487a 100644 --- a/src/include/constants.py +++ b/src/include/constants.py @@ -21,8 +21,8 @@ from include.classes.version import Version -CORE_VERSION = Version("0.3.0.260530_alpha") -PROTOCOL_VERSION = 13 +CORE_VERSION = Version("0.3.0.260616_alpha") +PROTOCOL_VERSION = 14 ROOT_ABSPATH = Path(__file__).resolve().parent.parent diff --git a/src/include/database/handler.py b/src/include/database/handler.py index 8c719cd..580f43a 100644 --- a/src/include/database/handler.py +++ b/src/include/database/handler.py @@ -34,6 +34,12 @@ if db_type == "sqlite": db_file = global_config["database"]["file"] engine = create_engine(f"sqlite:///{db_file}", echo=debug_enabled) + + @event.listens_for(engine, "connect") + def _enable_sqlite_foreign_keys(dbapi_connection, _connection_record) -> None: + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() else: username = global_config["database"]["username"] password = global_config["database"]["password"] diff --git a/src/include/database/models/entity/__init__.py b/src/include/database/models/entity/__init__.py new file mode 100644 index 0000000..4acfaaa --- /dev/null +++ b/src/include/database/models/entity/__init__.py @@ -0,0 +1,3 @@ +from .base import * +from .metadata import * +from .obj import * diff --git a/src/include/database/models/entity/base.py b/src/include/database/models/entity/base.py new file mode 100644 index 0000000..6840bc0 --- /dev/null +++ b/src/include/database/models/entity/base.py @@ -0,0 +1,239 @@ +__all__ = ["BaseObject"] + +import time +from typing import List, Literal, Optional, cast + +from sqlalchemy import VARCHAR, Integer +from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm.session import object_session + +from include.classes.access_rule import AccessRuleBase +from include.classes.enum.status import EntityStatus +from include.conf_loader import global_config +from include.constants import AVAILABLE_ACCESS_TYPES +from include.database.handler import Base +from include.database.models.classic import User +from include.util.fetch.fetch import batch_prefetch_granted_ids, prefetch_user_blocks + + +class BaseObject(Base): + __abstract__ = True + + id: Mapped[str] + access_rules: Mapped[List] + + # Whether to inherit access rules from parent folders. + # Useful when enabling recursion check. + inherit: Mapped[bool] + + status: Mapped[EntityStatus] = mapped_column( + Integer, nullable=False, default=EntityStatus.OK + ) + status_operation_id: Mapped[Optional[str]] = mapped_column( + VARCHAR(255), nullable=True, index=True + ) + + def check_access_requirements( + self, user: User, access_type: str = "read", _no_recursive_check=False + ) -> bool: + """ + Checks if a given user meets the access requirements for a specific access type based on defined access rules. + Args: + user (User): The user object whose permissions and groups are to be checked. + access_type (int, optional): The type of access to check for. Defaults to `"read"`. + _no_recursive_check (bool, optional): Useful when performing batch queries. Defaults to False. + Returns: + bool: True if the user meets all access requirements for the specified access type, False otherwise. + Raises: + ValueError: If the "match" value in any rule is not "all" or "any". + Access rules are evaluated as follows: + - Each rule may specify required permissions ("rights") and/or groups ("groups"). + - Each requirement can specify a "match" mode: "all" (all required items must be present) or "any" (at least one must be present). + - Rules are grouped and evaluated according to their match modes and requirements. + - If no access rules are defined, access is granted by default. + """ + + _TARGET_TYPE_MAPPING = {"folders": "directory", "documents": "document"} + + def match_rights(sub_rights_group): + if not sub_rights_group: + return True + + sub_match_mode = sub_rights_group.get("match", "all") + sub_rights_require = sub_rights_group.get("require", []) + + if not sub_rights_require: + return True + + if sub_match_mode == "all": + return set(sub_rights_require).issubset(user.all_permissions) + + elif sub_match_mode == "any": + for right in sub_rights_require: + if right in user.all_permissions: + return True + return False + + else: + raise ValueError('the value of "match" must be "all" or "any"') + + def match_groups(sub_groups_group): + if not sub_groups_group: + return True + + sub_match_mode = sub_groups_group.get("match", "all") + sub_groups_require = sub_groups_group.get("require", []) + + if not sub_groups_require: + return True + + if sub_match_mode == "all": + return set(sub_groups_require).issubset(user.all_groups) + + elif sub_match_mode == "any": + for group in sub_groups_require: + if group in user.all_groups: + return True + return False + else: + raise ValueError('the value of "match" must be "all" or "any"') + + def match_sub_group(sub_group): + sub_match_mode = sub_group.get("match", "all") + sub_rights_group = sub_group.get("rights", {}) + sub_groups_group = sub_group.get("groups", {}) + + if not (sub_rights_group.get("require", [])) or ( + not sub_groups_group.get("require", []) + ): + sub_match_mode = "all" + + if sub_match_mode == "any": + return match_rights(sub_rights_group) or match_groups(sub_groups_group) + if sub_match_mode == "all": + return match_rights(sub_rights_group) and match_groups(sub_groups_group) + else: + raise ValueError('the value of "match" must be "all" or "any"') + + def match_primary_sub_group(per_match_group): + match_mode = per_match_group.get("match", "all") + for sub_group in per_match_group["match_groups"]: + if not sub_group: + continue + + state = match_sub_group(sub_group) + + if match_mode == "any": + if state: + return True + elif match_mode == "all": + if not state: + return False + + if match_mode == "any": + return False + elif match_mode == "all": + return True + + # Checks whether the user or the user group to which he belongs + # has special access rights to this object. + + # Get `session` from `User` object + _session = object_session(user) + if not _session: + raise RuntimeError("No active session found for user") + + now = time.time() + + # check user blocks first + is_globally_blocked, blocked_ids = prefetch_user_blocks( + _session, user, access_type, now + ) + if is_globally_blocked or self.id in blocked_ids: + return False + + # then check special access entries + self_type = cast( + Literal["document", "directory"], _TARGET_TYPE_MAPPING[self.__tablename__] + ) + explicitly_granted_ids = batch_prefetch_granted_ids( + _session, user, [self.id], self_type, access_type, now + ) + + if self.id in explicitly_granted_ids: + return True + + if ( + global_config["access"]["enable_access_recursive_check"] + and self.inherit + and not _no_recursive_check + ): + # FIXME: Use lazy import when Python 3.15 is out + from include.database.models.entity.obj import Document, Folder + + # check all parent folders' access rules + parent = None + if isinstance(self, Document): + parent = self.folder + elif isinstance(self, Folder): + parent = self.parent + + visited_folder_ids = set() + while parent is not None: + if parent.id in visited_folder_ids: + # Cycle detected; break to prevent an infinite loop + raise RuntimeError("Cycle detected in folder hierarchy") + visited_folder_ids.add(parent.id) + + if not parent.check_access_requirements(user, access_type=access_type): + return False + + if not parent.inherit: + break # if the parent folder does not inherit, stop checking further up + + parent = parent.parent + + if not self.access_rules: + return True + + for each_rule in self.access_rules: + if not each_rule: + continue + + each_rule: AccessRuleBase + + # access_type 一览: + # read - 读 + # write - 写(删除=清空数据,重命名=写文件元数据,因此都算写) + # move - 移动 + # manage - 管理 + + if access_type not in AVAILABLE_ACCESS_TYPES: + raise ValueError( + f"Invalid access type for {self.__tablename__}: {access_type}" + ) + + match access_type: + case "read": # 如果要检查的是读权限 + if each_rule.access_type != "read": + continue + case "write": # 如果要检查写权限 + if each_rule.access_type not in ["read", "write"]: + continue + case "move": + # 取消了对读权限的要求 + if each_rule.access_type != "move": + continue + case "manage": # 如果要检查管理权限 + if each_rule.access_type not in ["read", "manage"]: + continue + case _: + raise NotImplementedError("Unsupported access type") + + if not each_rule.rule_data: + continue + + if not match_primary_sub_group(each_rule.rule_data): + return False + + return True diff --git a/src/include/database/models/entity/metadata.py b/src/include/database/models/entity/metadata.py new file mode 100644 index 0000000..41e7db1 --- /dev/null +++ b/src/include/database/models/entity/metadata.py @@ -0,0 +1,68 @@ +__all__ = ["DocumentMetadata", "DocumentMetadataTag"] + +from typing import TYPE_CHECKING, List, Optional + +from sqlalchemy import VARCHAR, ForeignKey, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from include.database.handler import Base + +if TYPE_CHECKING: + from include.database.models.classic import User + from include.database.models.entity.obj import Document + + +class DocumentMetadata(Base): + __tablename__ = "document_metadata" + + document_id: Mapped[str] = mapped_column( + VARCHAR(255), ForeignKey("documents.id", ondelete="CASCADE"), primary_key=True + ) + creator_username: Mapped[Optional[str]] = mapped_column( + VARCHAR(64), + ForeignKey("users.username", ondelete="SET NULL"), + nullable=True, + index=True, + ) + last_modified_by_username: Mapped[Optional[str]] = mapped_column( + VARCHAR(64), + ForeignKey("users.username", ondelete="SET NULL"), + nullable=True, + index=True, + ) + + document: Mapped["Document"] = relationship( + "Document", + back_populates="metadata_record", + foreign_keys=[document_id], + ) + creator: Mapped[Optional["User"]] = relationship( + "User", foreign_keys=[creator_username] + ) + last_modified_by: Mapped[Optional["User"]] = relationship( + "User", foreign_keys=[last_modified_by_username] + ) + tags: Mapped[List["DocumentMetadataTag"]] = relationship( + "DocumentMetadataTag", + back_populates="metadata_record", + cascade="all, delete-orphan", + order_by="DocumentMetadataTag.position", + ) + + +class DocumentMetadataTag(Base): + __tablename__ = "document_metadata_tags" + + document_id: Mapped[str] = mapped_column( + VARCHAR(255), + ForeignKey("document_metadata.document_id", ondelete="CASCADE"), + primary_key=True, + ) + tag: Mapped[str] = mapped_column(VARCHAR(255), primary_key=True, index=True) + position: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + + metadata_record: Mapped["DocumentMetadata"] = relationship( + "DocumentMetadata", + back_populates="tags", + foreign_keys=[document_id], + ) diff --git a/src/include/database/models/entity.py b/src/include/database/models/entity/obj.py similarity index 55% rename from src/include/database/models/entity.py rename to src/include/database/models/entity/obj.py index 1191c1b..5549cb7 100644 --- a/src/include/database/models/entity.py +++ b/src/include/database/models/entity/obj.py @@ -1,301 +1,36 @@ +__all__ = [ + "Document", + "DocumentRevision", + "DocumentAccessRule", + "Folder", + "FolderAccessRule", +] + import secrets import time from itertools import batched -from typing import Iterable, List, Literal, Optional, Union, cast +from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, VARCHAR, Boolean, Float, ForeignKey, Integer, func -from sqlalchemy.orm import Mapped, Session, mapped_column, relationship +from sqlalchemy import JSON, VARCHAR, Boolean, Float, ForeignKey, Integer +from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy.orm.session import object_session from include.classes.access_rule import AccessRuleBase from include.classes.enum.status import DocumentRevisionStatus, EntityStatus -from include.conf_loader import global_config -from include.constants import AVAILABLE_ACCESS_TYPES, MAX_PARAM_SIZE, QUERY_CHUNK_SIZE +from include.constants import QUERY_CHUNK_SIZE from include.database.handler import Base -from include.database.models.classic import User +from include.database.models.entity.base import BaseObject from include.database.models.file import ( File, FileTask, _queue_deferred_file_deletion, ) from include.exceptions.misc import NoActiveRevisionsError +from include.util.bulk.count import batch_count_other_revisions from include.util.count import count_file_references -from include.util.fetch.fetch import batch_prefetch_granted_ids, prefetch_user_blocks - - -def _batch_count_other_revisions( - session: Session, - file_ids: Iterable[str], - exclude_doc_ids: Union[str, Iterable[str]], -) -> dict[str, int]: - """ - 计算引用了指定 file_ids 的修订版本数量,但排除属于 exclude_doc_ids 集合的文档。 - - 参数: - file_ids: 待检查的文件 ID 列表。 - exclude_doc_ids: 单个文档 ID 或文档 ID 集合。这些文档对文件的引用将不被计入。 - """ - # Return total references to each file EXCLUDING references that come - # from the provided exclude_doc_ids set. This centralizes counting via - # `count_file_references` so new FK references are automatically handled. - - # Materialize iterables so they can be safely iterated multiple times. - file_ids_list = list(file_ids) - if not file_ids_list: - return {} - - if isinstance(exclude_doc_ids, str): - exclude_doc_ids_list = [exclude_doc_ids] - else: - exclude_doc_ids_list = list(exclude_doc_ids) - - # Get total references across all tables (uses reflected FKs). - total_refs = count_file_references(session, file_ids_list) - - # Count references coming from the excluded documents (DocumentRevision only). - # The query combines `file_id IN (...)` and `document_id IN (...)`, so both - # dimensions must be chunked to keep total bind variables under MAX_PARAM_SIZE. - exclude_chunk_size = max(1, MAX_PARAM_SIZE - QUERY_CHUNK_SIZE) - excluded_counts: dict[str, int] = {} - for f_chunk in batched(file_ids_list, QUERY_CHUNK_SIZE): - for e_chunk in batched(exclude_doc_ids_list, exclude_chunk_size): - rows = ( - session.query(DocumentRevision.file_id, func.count(DocumentRevision.id)) - .filter(DocumentRevision.file_id.in_(list(f_chunk))) - .filter(DocumentRevision.document_id.in_(list(e_chunk))) - .group_by(DocumentRevision.file_id) - .all() - ) - for file_id, count in rows: - excluded_counts[file_id] = excluded_counts.get(file_id, 0) + count - - counts: dict[str, int] = {} - for fid in file_ids_list: - total = total_refs.get(fid, 0) - excluded = excluded_counts.get(fid, 0) - counts[fid] = max(0, total - excluded) - - return counts - - -class BaseObject(Base): - __abstract__ = True - - id: Mapped[str] - access_rules: Mapped[List] - - # Whether to inherit access rules from parent folders. - # Useful when enabling recursion check. - inherit: Mapped[bool] - - status: Mapped[EntityStatus] = mapped_column( - Integer, nullable=False, default=EntityStatus.OK - ) - status_operation_id: Mapped[Optional[str]] = mapped_column( - VARCHAR(255), nullable=True, index=True - ) - - def check_access_requirements( - self, user: User, access_type: str = "read", _no_recursive_check=False - ) -> bool: - """ - Checks if a given user meets the access requirements for a specific access type based on defined access rules. - Args: - user (User): The user object whose permissions and groups are to be checked. - access_type (int, optional): The type of access to check for. Defaults to `"read"`. - _no_recursive_check (bool, optional): Useful when performing batch queries. Defaults to False. - Returns: - bool: True if the user meets all access requirements for the specified access type, False otherwise. - Raises: - ValueError: If the "match" value in any rule is not "all" or "any". - Access rules are evaluated as follows: - - Each rule may specify required permissions ("rights") and/or groups ("groups"). - - Each requirement can specify a "match" mode: "all" (all required items must be present) or "any" (at least one must be present). - - Rules are grouped and evaluated according to their match modes and requirements. - - If no access rules are defined, access is granted by default. - """ - - _TARGET_TYPE_MAPPING = {"folders": "directory", "documents": "document"} - - def match_rights(sub_rights_group): - if not sub_rights_group: - return True - - sub_match_mode = sub_rights_group.get("match", "all") - sub_rights_require = sub_rights_group.get("require", []) - - if not sub_rights_require: - return True - - if sub_match_mode == "all": - return set(sub_rights_require).issubset(user.all_permissions) - - elif sub_match_mode == "any": - for right in sub_rights_require: - if right in user.all_permissions: - return True - return False - - else: - raise ValueError('the value of "match" must be "all" or "any"') - - def match_groups(sub_groups_group): - if not sub_groups_group: - return True - - sub_match_mode = sub_groups_group.get("match", "all") - sub_groups_require = sub_groups_group.get("require", []) - - if not sub_groups_require: - return True - - if sub_match_mode == "all": - return set(sub_groups_require).issubset(user.all_groups) - - elif sub_match_mode == "any": - for group in sub_groups_require: - if group in user.all_groups: - return True - return False - else: - raise ValueError('the value of "match" must be "all" or "any"') - - def match_sub_group(sub_group): - sub_match_mode = sub_group.get("match", "all") - sub_rights_group = sub_group.get("rights", {}) - sub_groups_group = sub_group.get("groups", {}) - - if not (sub_rights_group.get("require", [])) or ( - not sub_groups_group.get("require", []) - ): - sub_match_mode = "all" - - if sub_match_mode == "any": - return match_rights(sub_rights_group) or match_groups(sub_groups_group) - if sub_match_mode == "all": - return match_rights(sub_rights_group) and match_groups(sub_groups_group) - else: - raise ValueError('the value of "match" must be "all" or "any"') - - def match_primary_sub_group(per_match_group): - match_mode = per_match_group.get("match", "all") - for sub_group in per_match_group["match_groups"]: - if not sub_group: - continue - - state = match_sub_group(sub_group) - - if match_mode == "any": - if state: - return True - elif match_mode == "all": - if not state: - return False - - if match_mode == "any": - return False - elif match_mode == "all": - return True - # Checks whether the user or the user group to which he belongs - # has special access rights to this object. - - # Get `session` from `User` object - _session = object_session(user) - if not _session: - raise RuntimeError("No active session found for user") - - now = time.time() - - # check user blocks first - is_globally_blocked, blocked_ids = prefetch_user_blocks( - _session, user, access_type, now - ) - if is_globally_blocked or self.id in blocked_ids: - return False - - # then check special access entries - self_type = cast( - Literal["document", "directory"], _TARGET_TYPE_MAPPING[self.__tablename__] - ) - explicitly_granted_ids = batch_prefetch_granted_ids( - _session, user, [self.id], self_type, access_type, now - ) - - if self.id in explicitly_granted_ids: - return True - - if ( - global_config["access"]["enable_access_recursive_check"] - and self.inherit - and not _no_recursive_check - ): - # check all parent folders' access rules - parent = None - if isinstance(self, Document): - parent = self.folder - elif isinstance(self, Folder): - parent = self.parent - - visited_folder_ids = set() - while parent is not None: - if parent.id in visited_folder_ids: - # Cycle detected; break to prevent an infinite loop - raise RuntimeError("Cycle detected in folder hierarchy") - visited_folder_ids.add(parent.id) - - if not parent.check_access_requirements(user, access_type=access_type): - return False - - if not parent.inherit: - break # if the parent folder does not inherit, stop checking further up - - parent = parent.parent - - if not self.access_rules: - return True - - for each_rule in self.access_rules: - if not each_rule: - continue - - each_rule: AccessRuleBase - - # access_type 一览: - # read - 读 - # write - 写(删除=清空数据,重命名=写文件元数据,因此都算写) - # move - 移动 - # manage - 管理 - - if access_type not in AVAILABLE_ACCESS_TYPES: - raise ValueError( - f"Invalid access type for {self.__tablename__}: {access_type}" - ) - - match access_type: - case "read": # 如果要检查的是读权限 - if each_rule.access_type != "read": - continue - case "write": # 如果要检查写权限 - if each_rule.access_type not in ["read", "write"]: - continue - case "move": - # 取消了对读权限的要求 - if each_rule.access_type != "move": - continue - case "manage": # 如果要检查管理权限 - if each_rule.access_type not in ["read", "manage"]: - continue - case _: - raise NotImplementedError("Unsupported access type") - - if not each_rule.rule_data: - continue - - if not match_primary_sub_group(each_rule.rule_data): - return False - - return True +if TYPE_CHECKING: + from include.database.models.entity.metadata import DocumentMetadata class Folder(BaseObject): # 文档文件夹 @@ -409,6 +144,12 @@ class Document(BaseObject): overlaps="current_revision", # 声明与 current_revision 的重叠 ) inherit: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) + metadata_record: Mapped[Optional["DocumentMetadata"]] = relationship( + "DocumentMetadata", + back_populates="document", + cascade="all, delete-orphan", + uselist=False, + ) @property def active(self): @@ -472,7 +213,7 @@ def delete_all_revisions(self, do_commit: bool = True): all_file_ids = {row[1] for row in revision_tuples if row[1]} # Task 3: Chunked batch reference count queries to avoid variable limit. - other_counts = _batch_count_other_revisions(session, all_file_ids, self.id) + other_counts = batch_count_other_revisions(session, all_file_ids, self.id) # Determine which files are exclusively referenced by this document's revisions. deletable_file_ids = { diff --git a/src/include/handlers/document.py b/src/include/handlers/document.py index 858bccd..1dd03f3 100644 --- a/src/include/handlers/document.py +++ b/src/include/handlers/document.py @@ -8,6 +8,7 @@ "RequestDownloadFileHandler", "RequestUploadFileHandler", "RequestSetDocumentRulesHandler", + "RequestSetDocumentTagsHandler", "RequestMoveDocumentHandler", ] @@ -26,11 +27,13 @@ from include.database.models.classic import User from include.database.models.entity import ( Document, + DocumentMetadata, + DocumentMetadataTag, DocumentRevision, Folder, - NoActiveRevisionsError, ) from include.database.models.file import File, FileTask +from include.exceptions.misc import NoActiveRevisionsError from include.handlers.base import RequestHandler from include.system.messages import Messages as smsg from include.util.check import ( @@ -77,6 +80,32 @@ def create_file_task(file: File, transfer_mode: int = 0): } +def get_or_create_document_metadata(document: Document) -> DocumentMetadata: + if document.metadata_record is None: + document.metadata_record = DocumentMetadata() + return document.metadata_record + + +def mark_document_modified(document: Document, username: str) -> None: + get_or_create_document_metadata(document).last_modified_by_username = username + + +def serialize_document_metadata(document: Document) -> dict: + metadata = document.metadata_record + if metadata is None: + return { + "tags": [], + "creator": None, + "last_modified_by": None, + } + + return { + "tags": [tag.tag for tag in metadata.tags], + "creator": metadata.creator_username, + "last_modified_by": metadata.last_modified_by_username, + } + + class RequestGetDocumentInfoHandler(RequestHandler): """ Handles the "get_document_info" action. @@ -145,6 +174,9 @@ def handle(self, handler: ConnectionHandler): "info_code": info_code, } + if Permissions.VIEW_METADATA in user.all_permissions: + data["metadata"] = serialize_document_metadata(document) + handler.conclude_request(200, data, "Document info retrieved successfully") return 0, document_id, handler.username @@ -308,6 +340,10 @@ def handle(self, handler: ConnectionHandler): title=title, folder_id=folder_id, ) + new_document.metadata_record = DocumentMetadata( + creator_username=user.username, + last_modified_by_username=user.username, + ) new_revision = DocumentRevision(file_id=new_file.id) new_document.revisions.append(new_revision) @@ -417,11 +453,12 @@ def handle(self, handler: ConnectionHandler): file_id=file_id, parent_revision_id=parent_revision_id, ) - document.revisions.append(new_revision) - session.add(new_file) session.add(new_revision) + document.revisions.append(new_revision) + mark_document_modified(document, this_user.username) + document.current_revision = new_revision session.commit() @@ -475,6 +512,7 @@ def handle(self, handler: ConnectionHandler): document.status_operation_id = ( f"OP_DEL_{secrets.token_hex(8)}_{int(time.time())}" ) + mark_document_modified(document, user.username) session.commit() handler.conclude_request(200, {}, "Document successfully deleted") @@ -551,6 +589,7 @@ def handle(self, handler: ConnectionHandler): return err_code, document.folder_id, handler.username document.title = new_title + mark_document_modified(document, this_user.username) session.commit() handler.conclude_request( @@ -702,6 +741,7 @@ def handle(self, handler: ConnectionHandler): if apply_access_rules( document, access_rules_to_apply, user, inherit_parent ): + mark_document_modified(document, user.username) session.commit() handler.conclude_request(200, {}, "Set access rules successfully") return 0, document_id, handler.username @@ -845,6 +885,7 @@ def handle(self, handler: ConnectionHandler): return err_code, document.folder_id, handler.username document.folder = target_folder + mark_document_modified(document, user.username) session.commit() @@ -897,6 +938,7 @@ def handle(self, handler: ConnectionHandler): return document.delete_all_revisions(do_commit=False) + mark_document_modified(document, user.username) session.delete(document) session.commit() @@ -1007,6 +1049,7 @@ def handle(self, handler: ConnectionHandler): document.status_operation_id = None document.title = final_title document.folder_id = db_folder_id + mark_document_modified(document, user.username) session.commit() @@ -1019,3 +1062,78 @@ def handle(self, handler: ConnectionHandler): "Document successfully restored", ) return 0, doc_id, handler.username + + +class RequestSetDocumentTagsHandler(RequestHandler): + schema = { + "type": "object", + "properties": { + "document_id": {"type": "string", "minLength": 1}, + "tags": { + "type": "array", + "maxItems": 128, + "items": {"type": "string", "minLength": 1, "maxLength": 255}, + }, + }, + "required": ["document_id", "tags"], + "additionalProperties": False, + } + + require_auth = True + + def handle(self, handler: ConnectionHandler): + document_id: str = handler.data["document_id"] + + normalized_tags = [] + seen_tags = set() + for raw_tag in handler.data["tags"]: + tag = raw_tag.strip() + if not tag: + handler.conclude_request(400, {}, "Tags cannot be blank") + return 400, document_id, handler.username + if tag not in seen_tags: + normalized_tags.append(tag) + seen_tags.add(tag) + + with Session() as session: + user = User.get_existing(session, handler.username) + document = session.get(Document, document_id) + + if not document: + handler.conclude_request(404, {}, smsg.DOCUMENT_NOT_FOUND) + return 404, document_id, handler.username + + if ( + Permissions.SET_METADATA_TAGS not in user.all_permissions + or not document.check_access_requirements(user, access_type="write") + ): + handler.conclude_access_denial() + return 403, document_id, handler.username + + metadata = get_or_create_document_metadata(document) + existing_by_tag = { + tag_record.tag: tag_record for tag_record in metadata.tags + } + requested_tag_set = set(normalized_tags) + + for tag_record in list(metadata.tags): + if tag_record.tag not in requested_tag_set: + metadata.tags.remove(tag_record) + + for position, tag in enumerate(normalized_tags): + if tag in existing_by_tag: + existing_by_tag[tag].position = position + else: + metadata.tags.append( + DocumentMetadataTag(tag=tag, position=position) + ) + + metadata.last_modified_by_username = user.username + session.commit() + + handler.conclude_request( + 200, + {"tags": normalized_tags}, + "Document metadata tags updated successfully", + ) + return 0, document_id, {"tags": normalized_tags}, handler.username diff --git a/src/include/handlers/revision.py b/src/include/handlers/revision.py index 75e3f18..324dea7 100644 --- a/src/include/handlers/revision.py +++ b/src/include/handlers/revision.py @@ -4,7 +4,7 @@ from include.database.models.classic import User from include.database.models.entity import Document, DocumentRevision from include.handlers.base import RequestHandler -from include.handlers.document import create_file_task +from include.handlers.document import create_file_task, mark_document_modified from include.system.messages import Messages as smsg @@ -128,6 +128,7 @@ def handle(self, handler: ConnectionHandler): return 403, document_id, handler.username document.current_revision_id = revision.id + mark_document_modified(document, user.username) session.commit() handler.conclude_request(200, {}, "Current revision set successfully") @@ -179,6 +180,7 @@ def handle(self, handler: ConnectionHandler): child_rev.parent_revision = revision.parent_revision revision.before_delete() + mark_document_modified(document, user.username) session.delete(revision) session.commit() diff --git a/src/include/handlers/search.py b/src/include/handlers/search.py index 5024427..b67c7c5 100644 --- a/src/include/handlers/search.py +++ b/src/include/handlers/search.py @@ -15,7 +15,7 @@ from include.conf_loader import global_config from include.database.handler import Session from include.database.models.classic import User -from include.database.models.entity import NoActiveRevisionsError +from include.exceptions.misc import NoActiveRevisionsError from include.handlers.base import RequestHandler from include.util.fetch.fetch import batch_prefetch_granted_ids, prefetch_user_blocks from include.util.recursive.ancestors import ( diff --git a/src/include/router.py b/src/include/router.py index cd2a9b1..bf1aa07 100644 --- a/src/include/router.py +++ b/src/include/router.py @@ -42,6 +42,7 @@ RequestRenameDocumentHandler, RequestRestoreDocumentHandler, RequestSetDocumentRulesHandler, + RequestSetDocumentTagsHandler, RequestUploadDocumentHandler, RequestUploadFileHandler, ) @@ -129,6 +130,7 @@ "get_document_info": RequestGetDocumentInfoHandler, "get_document_access_rules": RequestGetDocumentAccessRulesHandler, "set_document_rules": RequestSetDocumentRulesHandler, + "set_document_tags": RequestSetDocumentTagsHandler, # 修订版本类 "list_revisions": RequestListRevisionsHandler, "get_revision": RequestGetRevisionHandler, diff --git a/src/include/util/bulk/count.py b/src/include/util/bulk/count.py new file mode 100644 index 0000000..4f860d8 --- /dev/null +++ b/src/include/util/bulk/count.py @@ -0,0 +1,69 @@ +from itertools import batched +from typing import Iterable, Union + +from sqlalchemy import func +from sqlalchemy.orm import Session + +from include.constants import MAX_PARAM_SIZE, QUERY_CHUNK_SIZE +from include.util.count import count_file_references + + +def batch_count_other_revisions( + session: Session, + file_ids: Iterable[str], + exclude_doc_ids: Union[str, Iterable[str]], +) -> dict[str, int]: + """ + Return total references to each file EXCLUDING references from exclude_doc_ids. + + This centralizes counting via `count_file_references` so new FK references + are automatically handled. + + Args: + session: Database session. + file_ids: File IDs to check. + exclude_doc_ids: Document ID or IDs to exclude from count. + + Returns: + Dict mapping file_id to reference count excluding specified documents. + """ + # FIXME: Use lazy import when Python 3.15 is out + from include.database.models.entity.obj import DocumentRevision + + # Materialize iterables so they can be safely iterated multiple times. + file_ids_list = list(file_ids) + if not file_ids_list: + return {} + + if isinstance(exclude_doc_ids, str): + exclude_doc_ids_list = [exclude_doc_ids] + else: + exclude_doc_ids_list = list(exclude_doc_ids) + + # Get total references across all tables (uses reflected FKs). + total_refs = count_file_references(session, file_ids_list) + + # Count references coming from the excluded documents (DocumentRevision only). + # The query combines `file_id IN (...)` and `document_id IN (...)`, so both + # dimensions must be chunked to keep total bind variables under MAX_PARAM_SIZE. + exclude_chunk_size = max(1, MAX_PARAM_SIZE - QUERY_CHUNK_SIZE) + excluded_counts: dict[str, int] = {} + for f_chunk in batched(file_ids_list, QUERY_CHUNK_SIZE): + for e_chunk in batched(exclude_doc_ids_list, exclude_chunk_size): + rows = ( + session.query(DocumentRevision.file_id, func.count(DocumentRevision.id)) + .filter(DocumentRevision.file_id.in_(list(f_chunk))) + .filter(DocumentRevision.document_id.in_(list(e_chunk))) + .group_by(DocumentRevision.file_id) + .all() + ) + for file_id, count in rows: + excluded_counts[file_id] = excluded_counts.get(file_id, 0) + count + + counts: dict[str, int] = {} + for fid in file_ids_list: + total = total_refs.get(fid, 0) + excluded = excluded_counts.get(fid, 0) + counts[fid] = max(0, total - excluded) + + return counts diff --git a/src/include/util/bulk/purge.py b/src/include/util/bulk/purge.py index 2f0f9c5..537d791 100644 --- a/src/include/util/bulk/purge.py +++ b/src/include/util/bulk/purge.py @@ -7,9 +7,9 @@ from include.database.models.entity import ( Document, DocumentRevision, - _batch_count_other_revisions, ) from include.database.models.file import File, FileTask, _queue_deferred_file_deletion +from include.util.bulk.count import batch_count_other_revisions def purge_documents_bulk(session: Session, document_ids: List[str]): @@ -42,7 +42,7 @@ def purge_documents_bulk(session: Session, document_ids: List[str]): # 2. 批量引用计数检查 # 使用集中计数,排除来自这批文档的引用后若为 0 则可删除 - other_counts = _batch_count_other_revisions(session, list(file_ids), document_ids) + other_counts = batch_count_other_revisions(session, list(file_ids), document_ids) # 找出仅被这批文档引用、可以物理删除的文件 ID deletable_file_ids = [fid for fid in file_ids if other_counts.get(fid, 0) == 0] diff --git a/src/main.py b/src/main.py index 638a938..0b7b518 100644 --- a/src/main.py +++ b/src/main.py @@ -26,7 +26,8 @@ ROOT_DIRECTORY_ID, ) from include.database.handler import Base, Session, engine -from include.database.models.entity import Document, DocumentRevision, Folder +from include.database.models.entity import DocumentMetadata +from include.database.models.entity.obj import Document, DocumentRevision, Folder from include.database.models.file import File from include.handlers.debugging.throw import RequestThrowExceptionHandler from include.providers.manager import ProviderManager @@ -133,6 +134,8 @@ def server_init(): {"permission": Permissions.SUPER_SET_PASSWD}, {"permission": Permissions.VIEW_ACCESS_RULES}, {"permission": Permissions.SET_ACCESS_RULES}, + {"permission": Permissions.VIEW_METADATA}, + {"permission": Permissions.SET_METADATA_TAGS}, {"permission": Permissions.LIST_USERS}, {"permission": Permissions.LIST_GROUPS}, {"permission": Permissions.CREATE_GROUP}, @@ -179,6 +182,7 @@ def server_init(): init_document = Document( id="hello", title="Hello World", folder_id=ROOT_DIRECTORY_ID ) + init_document.metadata_record = DocumentMetadata() init_document_revision = DocumentRevision(file_id=init_file.id) init_document.revisions.append(init_document_revision) init_document.current_revision = init_document_revision diff --git a/tests/test_client.py b/tests/test_client.py index 592456f..acdc36d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -549,6 +549,14 @@ async def get_document_info(self, document_id: str) -> Dict[str, Any]: "get_document_info", {"document_id": document_id} ) + async def set_document_tags( + self, document_id: str, tags: list[str] + ) -> Dict[str, Any]: + return await self.send_request( + "set_document_tags", + {"document_id": document_id, "tags": tags}, + ) + # --- Revisions --- async def list_revisions(self, document_id: str) -> Dict[str, Any]: return await self.send_request("list_revisions", {"document_id": document_id}) diff --git a/tests/test_documents.py b/tests/test_documents.py index 25ee492..d1c47a6 100644 --- a/tests/test_documents.py +++ b/tests/test_documents.py @@ -1,3 +1,6 @@ +import secrets +import time + import pytest from tests.test_client import CFMSTestClient @@ -35,6 +38,112 @@ async def test_get_document_info( data = assert_success(response) assert isinstance(data, dict) + @pytest.mark.asyncio + async def test_document_metadata_tags( + self, authenticated_client: CFMSTestClient, test_document: dict + ): + document_id = test_document["document_id"] + + info_response = await authenticated_client.get_document_info(document_id) + info = assert_success(info_response) + assert info["metadata"] == { + "tags": [], + "creator": "admin", + "last_modified_by": "admin", + } + + set_response = await authenticated_client.set_document_tags( + document_id, + ["secret", "finance", "secret", " topic "], + ) + data = assert_success(set_response) + assert data["tags"] == ["secret", "finance", "topic"] + + info_response = await authenticated_client.get_document_info(document_id) + info = assert_success(info_response) + assert info["metadata"]["tags"] == ["secret", "finance", "topic"] + assert info["metadata"]["creator"] == "admin" + assert info["metadata"]["last_modified_by"] == "admin" + + @pytest.mark.asyncio + async def test_get_document_info_omits_metadata_without_permission( + self, + authenticated_client: CFMSTestClient, + test_document: dict, + user_factory, + ): + document_id = test_document["document_id"] + + set_response = await authenticated_client.set_document_tags( + document_id, ["restricted"] + ) + assert_success(set_response) + + user = await user_factory() + grant_response = await authenticated_client.grant_access( + entity_type="user", + entity_identifier=user["username"], + target_type="document", + target_identifier=document_id, + access_types=["read"], + start_time=time.time(), + ) + assert_success(grant_response) + + client = CFMSTestClient() + await client.connect() + try: + login_response = await client.login(user["username"], user["password"]) + assert_success(login_response) + + info_response = await client.get_document_info(document_id) + info = assert_success(info_response) + assert "metadata" not in info + finally: + await client.disconnect() + + @pytest.mark.asyncio + async def test_set_document_tags_requires_permission( + self, + authenticated_client: CFMSTestClient, + test_document: dict, + server_process, + ): + username = f"metadata_user_{secrets.token_hex(4)}" + password = "TestPassword123!" + document_id = test_document["document_id"] + + response = await authenticated_client.create_user( + username=username, + password=password, + nickname="Metadata User", + ) + assert_success(response) + + try: + grant_response = await authenticated_client.grant_access( + entity_type="user", + entity_identifier=username, + target_type="document", + target_identifier=document_id, + access_types=["write"], + start_time=time.time(), + ) + assert_success(grant_response) + + client = CFMSTestClient() + await client.connect() + try: + login_response = await client.login(username, password) + assert_success(login_response) + + set_response = await client.set_document_tags(document_id, ["blocked"]) + assert_error(set_response, 403) + finally: + await client.disconnect() + finally: + await authenticated_client.delete_user(username) + @pytest.mark.asyncio async def test_rename_document( self, authenticated_client: CFMSTestClient, test_document: dict