diff --git a/apps/api/.env.example b/apps/api/.env.example index 820e849a..c16deb1b 100644 --- a/apps/api/.env.example +++ b/apps/api/.env.example @@ -95,8 +95,19 @@ MAX_FILE_SIZE=314572800 MAX_PDF_PAGE_LIMIT=200 OVERSIZED_PDF_SHARD_ENABLED=true OVERSIZED_PDF_SOFT_LIMIT=1500 -PDF_PROFILE_TOC_ENABLED=false +PDF_PAGE_TOC_ENABLED=true +RETRIEVAL_PAGE_MEMORY_ENABLED=false MINERU_SHARD_CONCURRENCY=3 +PARSE_AGENT_PLAN_BUDGET=50000 +PARSE_AGENT_VISUAL_BUDGET=80000 +PARSE_AGENT_TOC_CONFIRM_MIN_BUDGET=8000 +PARSE_AGENT_TOC_CONFIRM_CAP=24000 +PARSE_AGENT_COARSE_PLANNER_MIN_BUDGET=12000 +PARSE_AGENT_COARSE_PLANNER_CAP=36000 +PARSE_AGENT_STRUCTURAL_REACT_MIN_BUDGET=24000 +PARSE_AGENT_STRUCTURAL_REACT_CAP=64000 +PARSE_AGENT_PAGE_TAGGING_MIN_BUDGET=0 +PARSE_AGENT_PAGE_TAGGING_CAP=0 # Required for specific features: webhooks and callbacks WEBHOOK_MASTER_KEY= diff --git a/apps/api/alembic/versions/f9a0b1c2d3e4_add_doc_profile_to_document_page_plan.py b/apps/api/alembic/versions/f9a0b1c2d3e4_add_doc_profile_to_document_page_plan.py new file mode 100644 index 00000000..35d0543f --- /dev/null +++ b/apps/api/alembic/versions/f9a0b1c2d3e4_add_doc_profile_to_document_page_plan.py @@ -0,0 +1,29 @@ +"""add doc profile to document page plan + +Revision ID: f9a0b1c2d3e4 +Revises: f8a9b0c1d2e3 +Create Date: 2026-06-11 09:50:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = "f9a0b1c2d3e4" +down_revision: Union[str, Sequence[str], None] = "f8a9b0c1d2e3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "document_page_plan", + sa.Column("doc_profile", sa.JSON(), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("document_page_plan", "doc_profile") diff --git a/apps/api/alembic/versions/f9b0c1d2e3f4_add_parse_track_to_documents.py b/apps/api/alembic/versions/f9b0c1d2e3f4_add_parse_track_to_documents.py new file mode 100644 index 00000000..11193be0 --- /dev/null +++ b/apps/api/alembic/versions/f9b0c1d2e3f4_add_parse_track_to_documents.py @@ -0,0 +1,35 @@ +"""add parse track to documents + +Revision ID: f9b0c1d2e3f4 +Revises: f9a0b1c2d3e4 +Create Date: 2026-06-11 10:05:00.000000 + +""" + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +revision: str = "f9b0c1d2e3f4" +down_revision: Union[str, Sequence[str], None] = "f9a0b1c2d3e4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "documents", + sa.Column( + "parse_track", + sa.String(length=32), + nullable=False, + server_default="chunk", + ), + ) + op.alter_column("documents", "parse_track", server_default=None) + + +def downgrade() -> None: + op.drop_column("documents", "parse_track") diff --git a/apps/api/app/api/v1/routes/documents.py b/apps/api/app/api/v1/routes/documents.py index a77c5d12..a99e849e 100644 --- a/apps/api/app/api/v1/routes/documents.py +++ b/apps/api/app/api/v1/routes/documents.py @@ -17,7 +17,7 @@ router = APIRouter(tags=["Documents"]) _document_service = DocumentService() -DocumentChunkType = Literal["text", "image", "table"] +DocumentChunkType = Literal["text", "image", "table", "page"] async def _archive_document_response( diff --git a/apps/api/app/services/document_ingestion/service.py b/apps/api/app/services/document_ingestion/service.py index 7313c07f..0bc17bd3 100644 --- a/apps/api/app/services/document_ingestion/service.py +++ b/apps/api/app/services/document_ingestion/service.py @@ -196,6 +196,15 @@ async def _validate_create_payload(self, payload: JobCreate) -> None: } ], ) + _validate_parse_track_for_extension( + parse_track=payload.parse_track, + file_extension=file_extension, + ) + elif payload.file_name: + _validate_parse_track_for_extension( + parse_track=payload.parse_track, + file_extension=os.path.splitext(payload.file_name)[1].lower(), + ) async def _resolve_scope( self, @@ -261,3 +270,35 @@ def _is_supported_file_name(file_name: str) -> bool: return False file_extension = os.path.splitext(file_name)[1].lower() return file_extension in settings.get_supported_extensions() + + +def _validate_parse_track_for_extension(*, parse_track: str, file_extension: str) -> None: + if parse_track == "chunk": + return + if parse_track != "page_memory": + raise ValidationException( + user_message="Unsupported parse_track", + violations=[ + {"field": "parse_track", "description": "Must be chunk or page_memory"} + ], + ) + if not settings.RETRIEVAL_PAGE_MEMORY_ENABLED: + raise ValidationException( + user_message="page_memory parse track is not enabled", + violations=[ + { + "field": "parse_track", + "description": "page_memory is disabled by configuration", + } + ], + ) + if file_extension.lower() not in {".pdf", ".pptx"}: + raise ValidationException( + user_message="page_memory parse track only supports PDF and PPTX", + violations=[ + { + "field": "parse_track", + "description": "Allowed file types in this build: .pdf, .pptx", + } + ], + ) diff --git a/apps/api/tests/contract/test_documents_contract.py b/apps/api/tests/contract/test_documents_contract.py index 6016aaeb..124feba9 100644 --- a/apps/api/tests/contract/test_documents_contract.py +++ b/apps/api/tests/contract/test_documents_contract.py @@ -42,6 +42,7 @@ async def _insert_document( status, current_job_result_id, source_file_name, + parse_track, created_at, updated_at, archived_at @@ -52,6 +53,7 @@ async def _insert_document( :status, :current_job_result_id, :source_file_name, + :parse_track, :created_at, :updated_at, :archived_at @@ -64,6 +66,7 @@ async def _insert_document( "status": status, "current_job_result_id": None, "source_file_name": source_file_name or f"{document_id}.pdf", + "parse_track": "chunk", "created_at": timestamp, "updated_at": effective_updated_at, "archived_at": ( @@ -268,6 +271,7 @@ async def _insert_document_revision_with_chunks( status, current_job_result_id, source_file_name, + parse_track, created_at, updated_at, archived_at @@ -278,6 +282,7 @@ async def _insert_document_revision_with_chunks( 'active', NULL, :source_file_name, + :parse_track, :created_at, :updated_at, NULL @@ -288,6 +293,7 @@ async def _insert_document_revision_with_chunks( "user_id": user_id, "namespace": namespace, "source_file_name": source_file_name, + "parse_track": "chunk", "created_at": timestamp, "updated_at": timestamp, }, diff --git a/apps/api/tests/contract/test_job_creation_contract.py b/apps/api/tests/contract/test_job_creation_contract.py index 57248835..34ee54ac 100644 --- a/apps/api/tests/contract/test_job_creation_contract.py +++ b/apps/api/tests/contract/test_job_creation_contract.py @@ -80,6 +80,7 @@ async def _insert_document( namespace, status, source_file_name, + parse_track, created_at, updated_at, archived_at @@ -89,6 +90,7 @@ async def _insert_document( :namespace, :status, :source_file_name, + :parse_track, :created_at, :updated_at, :archived_at @@ -101,6 +103,7 @@ async def _insert_document( "namespace": namespace, "status": status, "source_file_name": f"{document_id}.pdf", + "parse_track": "chunk", "created_at": timestamp, "updated_at": timestamp, "archived_at": timestamp if status == "archived" else None, diff --git a/apps/api/tests/contract/test_page_memory_parse_track_contract.py b/apps/api/tests/contract/test_page_memory_parse_track_contract.py new file mode 100644 index 00000000..c670e6dd --- /dev/null +++ b/apps/api/tests/contract/test_page_memory_parse_track_contract.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import os + +import pytest + +os.environ.setdefault("DATABASE_URL", "postgresql+asyncpg://test:test@localhost/test") +os.environ.setdefault("TMP_PATH", "/tmp/knowhere-test") +os.environ.setdefault("S3_BUCKET_NAME", "test-uploads") +os.environ.setdefault("S3_ACCESS_KEY_ID", "test") +os.environ.setdefault("S3_SECRET_ACCESS_KEY", "test") +os.environ.setdefault("S3_TEMP_PATH", "/tmp") + +from shared.core.exceptions.domain_exceptions import ValidationException + + +def test_page_memory_parse_track_rejects_when_flag_disabled(monkeypatch) -> None: + from app.services.document_ingestion.service import ( + _validate_parse_track_for_extension, + ) + from shared.core.config import settings + + monkeypatch.setattr( + settings, + "RETRIEVAL_PAGE_MEMORY_ENABLED", + False, + ) + + with pytest.raises(ValidationException): + _validate_parse_track_for_extension( + parse_track="page_memory", + file_extension=".pdf", + ) + + +def test_page_memory_parse_track_allows_only_pdf_and_pptx(monkeypatch) -> None: + from app.services.document_ingestion.service import ( + _validate_parse_track_for_extension, + ) + from shared.core.config import settings + + monkeypatch.setattr( + settings, + "RETRIEVAL_PAGE_MEMORY_ENABLED", + True, + ) + + _validate_parse_track_for_extension( + parse_track="page_memory", + file_extension=".pdf", + ) + _validate_parse_track_for_extension( + parse_track="page_memory", + file_extension=".pptx", + ) + with pytest.raises(ValidationException): + _validate_parse_track_for_extension( + parse_track="page_memory", + file_extension=".docx", + ) diff --git a/apps/api/tests/support/contract_database.py b/apps/api/tests/support/contract_database.py index 2944cf14..3d4c5354 100644 --- a/apps/api/tests/support/contract_database.py +++ b/apps/api/tests/support/contract_database.py @@ -290,6 +290,7 @@ async def insert_document( status: str = "active", current_job_result_id: str | None = None, source_file_name: str | None = None, + parse_track: str = "chunk", created_at: datetime | None = None, updated_at: datetime | None = None, archived_at: datetime | None = None, @@ -304,6 +305,7 @@ async def insert_document( status, current_job_result_id, source_file_name, + parse_track, created_at, updated_at, archived_at @@ -314,6 +316,7 @@ async def insert_document( :status, :current_job_result_id, :source_file_name, + :parse_track, :created_at, :updated_at, :archived_at @@ -326,6 +329,7 @@ async def insert_document( "status": status, "current_job_result_id": current_job_result_id, "source_file_name": source_file_name or f"{document_id}.pdf", + "parse_track": parse_track, "created_at": timestamp, "updated_at": updated_at or timestamp, "archived_at": archived_at, diff --git a/apps/worker/.env.example b/apps/worker/.env.example index c9029a58..aad0d15c 100644 --- a/apps/worker/.env.example +++ b/apps/worker/.env.example @@ -120,8 +120,19 @@ MAX_FILE_SIZE=314572800 MAX_PDF_PAGE_LIMIT=200 OVERSIZED_PDF_SHARD_ENABLED=true OVERSIZED_PDF_SOFT_LIMIT=1500 -PDF_PROFILE_TOC_ENABLED=false +PDF_PAGE_TOC_ENABLED=true +RETRIEVAL_PAGE_MEMORY_ENABLED=false MINERU_SHARD_CONCURRENCY=3 +PARSE_AGENT_PLAN_BUDGET=50000 +PARSE_AGENT_VISUAL_BUDGET=80000 +PARSE_AGENT_TOC_CONFIRM_MIN_BUDGET=8000 +PARSE_AGENT_TOC_CONFIRM_CAP=24000 +PARSE_AGENT_COARSE_PLANNER_MIN_BUDGET=12000 +PARSE_AGENT_COARSE_PLANNER_CAP=36000 +PARSE_AGENT_STRUCTURAL_REACT_MIN_BUDGET=24000 +PARSE_AGENT_STRUCTURAL_REACT_CAP=64000 +PARSE_AGENT_PAGE_TAGGING_MIN_BUDGET=0 +PARSE_AGENT_PAGE_TAGGING_CAP=0 # Legacy parser compatibility fields. # ALL_DF_COLS=content,path,type,length,keywords,summary,know_id,tokens,connectto,addtime,page_nums diff --git a/apps/worker/app/services/connect_builder/summary_builder.py b/apps/worker/app/services/connect_builder/summary_builder.py index ae020e42..156cde97 100644 --- a/apps/worker/app/services/connect_builder/summary_builder.py +++ b/apps/worker/app/services/connect_builder/summary_builder.py @@ -62,6 +62,7 @@ def _llm_summarize(snippets_text: str, node_name: str, max_tokens: int = 100) -> messages=messages, timeout=60, max_tokens=max_tokens, + usage_task="finalization.doc_nav_summary", ) if resp is None: return "" diff --git a/apps/worker/app/services/document_agent/budget.py b/apps/worker/app/services/document_agent/budget.py index 6edbc130..0e732027 100644 --- a/apps/worker/app/services/document_agent/budget.py +++ b/apps/worker/app/services/document_agent/budget.py @@ -3,6 +3,15 @@ from __future__ import annotations from dataclasses import dataclass +from typing import Literal + + +BudgetStage = Literal[ + "toc_confirm", + "coarse_planner", + "structural_react", + "page_tagging", +] @dataclass @@ -16,6 +25,22 @@ def remaining(self) -> int: return max(self.capacity - self.used - self.reserved, 0) +@dataclass(frozen=True) +class StageEnvelope: + min_guarantee: int = 0 + cap: int | None = None + + +@dataclass +class StageUsage: + used: int = 0 + reserved: int = 0 + + @property + def committed(self) -> int: + return self.used + self.reserved + + class BudgetTracker: """A minimal synchronous ledger with plan and visual pools.""" @@ -24,21 +49,37 @@ def __init__( *, plan_budget: int = 5000, visual_budget: int = 8000, + visual_stage_envelopes: dict[str, StageEnvelope] | None = None, ) -> None: self._plan = BudgetPool(capacity=max(int(plan_budget), 0)) self._visual = BudgetPool(capacity=max(int(visual_budget), 0)) + self._visual_stage_envelopes = visual_stage_envelopes or {} + self._visual_stage_usage: dict[str, StageUsage] = { + stage: StageUsage() for stage in self._visual_stage_envelopes + } - def try_reserve(self, pool: str, est: int) -> bool: + def try_reserve(self, pool: str, est: int, *, stage: str | None = None) -> bool: if pool not in {"plan", "visual"}: return True est = max(int(est), 0) budget_pool = self._pool(pool) if budget_pool.remaining < est: return False + if pool == "visual" and stage and not self._can_reserve_visual_stage(stage, est): + return False budget_pool.reserved += est + if pool == "visual" and stage: + self._stage_usage(stage).reserved += est return True - def commit(self, pool: str, *, actual: int, est: int) -> None: + def commit( + self, + pool: str, + *, + actual: int, + est: int, + stage: str | None = None, + ) -> None: if pool not in {"plan", "visual"}: return est = max(int(est), 0) @@ -46,16 +87,50 @@ def commit(self, pool: str, *, actual: int, est: int) -> None: budget_pool = self._pool(pool) budget_pool.reserved = max(budget_pool.reserved - est, 0) budget_pool.used = min(budget_pool.capacity, budget_pool.used + actual) + if pool == "visual" and stage: + stage_usage = self._stage_usage(stage) + stage_usage.reserved = max(stage_usage.reserved - est, 0) + stage_usage.used += actual - def refund(self, pool: str, *, est: int) -> None: + def refund(self, pool: str, *, est: int, stage: str | None = None) -> None: if pool not in {"plan", "visual"}: return budget_pool = self._pool(pool) - budget_pool.reserved = max(budget_pool.reserved - max(int(est), 0), 0) + est = max(int(est), 0) + budget_pool.reserved = max(budget_pool.reserved - est, 0) + if pool == "visual" and stage: + stage_usage = self._stage_usage(stage) + stage_usage.reserved = max(stage_usage.reserved - est, 0) def _pool(self, pool: str) -> BudgetPool: return self._visual if pool == "visual" else self._plan + def _stage_usage(self, stage: str) -> StageUsage: + if stage not in self._visual_stage_usage: + self._visual_stage_usage[stage] = StageUsage() + return self._visual_stage_usage[stage] + + def _can_reserve_visual_stage(self, stage: str, est: int) -> bool: + envelope = self._visual_stage_envelopes.get(stage) + if envelope is None: + return True + + stage_usage = self._stage_usage(stage) + stage_committed_after_reserve = stage_usage.committed + est + if envelope.cap is not None and stage_committed_after_reserve > envelope.cap: + return False + + reserved_by_other_stages = 0 + for other_stage, other_envelope in self._visual_stage_envelopes.items(): + if other_stage == stage: + continue + other_usage = self._stage_usage(other_stage) + if other_usage.committed >= other_envelope.min_guarantee: + continue + reserved_by_other_stages += other_envelope.min_guarantee - other_usage.committed + + return self._visual.remaining - est >= reserved_by_other_stages + def _pool_snapshot(self, pool: BudgetPool) -> dict[str, int]: return { "capacity": pool.capacity, @@ -68,4 +143,17 @@ def snapshot(self) -> dict[str, object]: return { "plan": self._pool_snapshot(self._plan), "visual": self._pool_snapshot(self._visual), + "visual_stages": { + stage: { + "used": usage.used, + "reserved": usage.reserved, + "min_guarantee": self._visual_stage_envelopes.get( + stage, StageEnvelope() + ).min_guarantee, + "cap": self._visual_stage_envelopes.get( + stage, StageEnvelope() + ).cap, + } + for stage, usage in sorted(self._visual_stage_usage.items()) + }, } diff --git a/apps/worker/app/services/document_agent/coordinator.py b/apps/worker/app/services/document_agent/coordinator.py index bcb4edb9..9d871866 100644 --- a/apps/worker/app/services/document_agent/coordinator.py +++ b/apps/worker/app/services/document_agent/coordinator.py @@ -12,7 +12,7 @@ classify_page_kinds, probe_page_features, ) -from app.services.document_agent.budget import BudgetTracker +from app.services.document_agent.budget import BudgetTracker, StageEnvelope from app.services.document_agent.executor import ReActExecutor from app.services.document_agent.manifest import ( DocumentProfile, @@ -45,6 +45,32 @@ def __init__( self.budget = BudgetTracker( plan_budget=int(os.environ.get("PARSE_AGENT_PLAN_BUDGET", "50000")), visual_budget=int(os.environ.get("PARSE_AGENT_VISUAL_BUDGET", "80000")), + visual_stage_envelopes={ + "toc_confirm": StageEnvelope( + min_guarantee=int( + os.environ.get("PARSE_AGENT_TOC_CONFIRM_MIN_BUDGET", "8000") + ), + cap=int(os.environ.get("PARSE_AGENT_TOC_CONFIRM_CAP", "24000")), + ), + "coarse_planner": StageEnvelope( + min_guarantee=int( + os.environ.get("PARSE_AGENT_COARSE_PLANNER_MIN_BUDGET", "12000") + ), + cap=int(os.environ.get("PARSE_AGENT_COARSE_PLANNER_CAP", "36000")), + ), + "structural_react": StageEnvelope( + min_guarantee=int( + os.environ.get("PARSE_AGENT_STRUCTURAL_REACT_MIN_BUDGET", "24000") + ), + cap=int(os.environ.get("PARSE_AGENT_STRUCTURAL_REACT_CAP", "64000")), + ), + "page_tagging": StageEnvelope( + min_guarantee=int( + os.environ.get("PARSE_AGENT_PAGE_TAGGING_MIN_BUDGET", "0") + ), + cap=int(os.environ.get("PARSE_AGENT_PAGE_TAGGING_CAP", "0")) or None, + ), + }, ) effective_settings = settings or {} if model: @@ -122,7 +148,10 @@ def _run_structural(self) -> PageAnatomyMap: self.state = DocumentAgentState.RUNNING if not self.blackboard.page_features: self._run_bootstrap() - self._ensure_toc_profile(strict=True) + if self._toc_profile_enabled(): + self._ensure_toc_profile(strict=True) + else: + self._ensure_disabled_toc_placeholder() profile, initial_decision, _planner_result = self._propose_profile( actor="planner" ) @@ -143,6 +172,12 @@ def _run_toc(self) -> TocResult: self.state = DocumentAgentState.RUNNING if not self.blackboard.page_features: self._run_bootstrap() + if not self._toc_profile_enabled(): + self._ensure_disabled_toc_placeholder() + toc_result = self.blackboard.toc_result + if toc_result is None: + raise RuntimeError("TOC placeholder was not initialized") + return toc_result self._ensure_toc_profile(strict=False) if self.blackboard.toc_result is None: self.blackboard.toc_result = TocResult( @@ -156,10 +191,13 @@ def _run_lightweight_anatomy(self) -> PageAnatomyMap: if not self.blackboard.page_features: self._run_bootstrap() if self.blackboard.toc_result is None: - self.blackboard.toc_result = TocResult( - method="none", - notes="TOC profiling disabled or not attempted", - ) + if self._toc_profile_enabled(): + self.blackboard.toc_result = TocResult( + method="none", + notes="TOC profiling not attempted", + ) + else: + self._ensure_disabled_toc_placeholder() self._run_h1_boundary_pipeline() result = REGISTRY.dispatch("propose.shard_plan", self.ctx, {}) self.trace.record_step( @@ -236,13 +274,20 @@ def _toc_result_requires_strict_retry(self) -> bool: ) def _should_run_toc_before_coarse(self) -> bool: - if self.ctx.settings.get("toc_before_coarse"): - return True - try: - page_limit = int(self.ctx.settings.get("toc_before_coarse_page_limit", 0)) - except (TypeError, ValueError): - page_limit = 0 - return page_limit > 0 and self.blackboard.page_count > page_limit + return self._toc_profile_enabled() and bool( + self.ctx.settings.get("toc_before_coarse") + ) + + def _toc_profile_enabled(self) -> bool: + return bool(self.ctx.settings.get("toc_profile_enabled", True)) + + def _ensure_disabled_toc_placeholder(self) -> None: + self.blackboard.toc_result = TocResult( + method="none", + notes="TOC profiling disabled by PDF_PAGE_TOC_ENABLED", + ) + self.blackboard.toc_hierarchies = None + self.blackboard.global_signals["toc_profile_attempted"] = False def _ensure_toc_profile(self, *, strict: bool) -> None: should_run = self.blackboard.toc_result is None @@ -255,6 +300,7 @@ def _ensure_toc_profile(self, *, strict: bool) -> None: return self._planner_cache = None + self.blackboard.global_signals["toc_profile_attempted"] = True try: self._run_toc_extraction_pipeline() except Exception as exc: diff --git a/apps/worker/app/services/document_agent/executor/react_loop.py b/apps/worker/app/services/document_agent/executor/react_loop.py index ea2daf32..d391c764 100644 --- a/apps/worker/app/services/document_agent/executor/react_loop.py +++ b/apps/worker/app/services/document_agent/executor/react_loop.py @@ -244,6 +244,7 @@ def _next_decision(self, round_index: int) -> tuple[ReflexionDecision, ToolResul temperature=0.0, max_tokens=1200, response_format={"type": "json_object"}, + usage_task="document_agent.react_loop", ) self.ctx.budget.commit( "plan", diff --git a/apps/worker/app/services/document_agent/planner/planner.py b/apps/worker/app/services/document_agent/planner/planner.py index 7c796586..2efae784 100644 --- a/apps/worker/app/services/document_agent/planner/planner.py +++ b/apps/worker/app/services/document_agent/planner/planner.py @@ -264,7 +264,8 @@ def propose(self) -> tuple[DocumentProfile, ReflexionDecision, ToolResult]: ensure_ascii=False, ) prompt_tokens_est = estimate_tokens(prompt_text) + len(pngs) * 800 - if not self.ctx.budget.try_reserve("visual", prompt_tokens_est): + stage = "coarse_planner" + if not self.ctx.budget.try_reserve("visual", prompt_tokens_est, stage=stage): raise RuntimeError("Insufficient visual budget for profile planning.") content_parts: list[dict[str, Any]] = [{"type": "text", "text": prompt_text}] @@ -294,11 +295,13 @@ def propose(self) -> tuple[DocumentProfile, ReflexionDecision, ToolResult]: temperature=0.0, max_tokens=1800, response_format={"type": "json_object"}, + usage_task="document_agent.coarse_profile", ) self.ctx.budget.commit( "visual", actual=usage.get("total_tokens", prompt_tokens_est), est=prompt_tokens_est, + stage=stage, ) profile, decision = _parse_profile_and_decision(raw) return profile, decision, ToolResult( @@ -319,5 +322,5 @@ def propose(self) -> tuple[DocumentProfile, ReflexionDecision, ToolResult]: }, ) except Exception: - self.ctx.budget.refund("visual", est=prompt_tokens_est) + self.ctx.budget.refund("visual", est=prompt_tokens_est, stage=stage) raise diff --git a/apps/worker/app/services/document_agent/tools/extract_toc_with_boundaries.py b/apps/worker/app/services/document_agent/tools/extract_toc_with_boundaries.py index ddb876e4..33293210 100644 --- a/apps/worker/app/services/document_agent/tools/extract_toc_with_boundaries.py +++ b/apps/worker/app/services/document_agent/tools/extract_toc_with_boundaries.py @@ -119,7 +119,8 @@ def _vlm_confirm_anchors( messages = cast(Any, [{"role": "user", "content": content_parts}]) est = estimate_tokens(str(content_parts[0]["text"])) + len(anchor_pages) * 800 - if budget and not budget.try_reserve("visual", est): + stage = "toc_confirm" + if budget and not budget.try_reserve("visual", est, stage=stage): logger.warning("[extract.toc] insufficient visual budget for anchor confirmation") return [], True, [] @@ -131,9 +132,15 @@ def _vlm_confirm_anchors( temperature=0.1, max_tokens=500, response_format={"type": "json_object"}, + usage_task="document_agent.toc_anchor_confirm", ) if budget: - budget.commit("visual", actual=usage.get("total_tokens", est), est=est) + budget.commit( + "visual", + actual=usage.get("total_tokens", est), + est=est, + stage=stage, + ) data = json.loads(raw) if isinstance(data, dict): items = data.get("pages") or data.get("results") or data.get("data") or [] @@ -191,7 +198,7 @@ def _vlm_confirm_anchors( return confirmed, False, evidence except Exception as exc: if budget: - budget.refund("visual", est=est) + budget.refund("visual", est=est, stage=stage) logger.warning( "[extract.toc] VLM anchor confirmation failed: {}, " "falling back to no confirmed anchors (safe degradation)", diff --git a/apps/worker/app/services/document_agent/tools/find_toc_anchor_pages.py b/apps/worker/app/services/document_agent/tools/find_toc_anchor_pages.py index db03c90e..e2acb4e5 100644 --- a/apps/worker/app/services/document_agent/tools/find_toc_anchor_pages.py +++ b/apps/worker/app/services/document_agent/tools/find_toc_anchor_pages.py @@ -1,4 +1,4 @@ -"""Scan for TOC anchor pages and render their PNGs for VLM inspection.""" +"""Scan PDF text for TOC anchor pages and render their PNGs for VLM inspection.""" from __future__ import annotations @@ -16,8 +16,9 @@ ) from loguru import logger -# CJK and English TOC keywords used for anchor detection. -TOC_KEYWORDS = {"目录", "目次", "contents", "tableofcontents", "table of contents"} +# CJK and English TOC keywords used for first-pass anchor detection. +TOC_KEYWORDS = {"目录", "目次", "contents", "tableofcontents"} +TOC_CROSS_LINE_WINDOW = 6 # If a TOC keyword fingerprint appears on more than this fraction of total # pages, it is treated as a recurring navigation element (header/footer link) @@ -34,6 +35,100 @@ def _normalize_for_toc(text: str) -> str: return text.replace(" ", "").replace("\u3000", "").lower() +def _meaningful_text_lines(text: str) -> list[str]: + return [" ".join(line.split()) for line in text.splitlines() if line.split()] + + +def _match_toc_keyword(normalized_text: str) -> str | None: + for keyword in sorted(TOC_KEYWORDS, key=len, reverse=True): + if keyword in normalized_text: + return keyword + return None + + +def _find_toc_text_matches( + lines: list[str], + *, + cross_line_window: int, +) -> list[dict[str, Any]]: + matches: list[dict[str, Any]] = [] + direct_hit_lines: set[int] = set() + for line_idx, raw_line in enumerate(lines): + keyword = _match_toc_keyword(_normalize_for_toc(raw_line)) + if keyword is None: + continue + direct_hit_lines.add(line_idx) + matches.append( + { + "raw_line": raw_line.strip(), + "line_index": line_idx, + "line_end_index": line_idx, + "match_kind": f"keyword:{keyword}", + } + ) + + if matches: + return matches + + # TODO: Add second-layer table-of-contents shape scoring if first-pass + # keywords produce too many candidates. Keep this layer as pure anchor + # discovery; VLM confirmation remains the semantic judge. + window = max(cross_line_window, 1) + for start_idx in range(len(lines)): + joined_parts: list[str] = [] + end_limit = min(len(lines), start_idx + window) + for end_idx in range(start_idx, end_limit): + if end_idx in direct_hit_lines: + continue + joined_parts.append(lines[end_idx].strip()) + if end_idx == start_idx: + continue + keyword = _match_toc_keyword(_normalize_for_toc("".join(joined_parts))) + if keyword is None: + continue + matches.append( + { + "raw_line": " / ".join(joined_parts), + "line_index": start_idx, + "line_end_index": end_idx, + "match_kind": f"cross_line:{keyword}", + } + ) + return matches + return matches + + +@worker +def _scan_toc_text_worker( + queue, pdf_path: str, cross_line_window: int +) -> None: + import pymupdf # type: ignore[import] + + matches: list[dict[str, Any]] = [] + page_count = 0 + doc = None + try: + doc = pymupdf.open(pdf_path) + page_count = doc.page_count + for page_idx in range(doc.page_count): + page_num = page_idx + 1 + text = str(doc[page_idx].get_text("text") or "") + lines = _meaningful_text_lines(text) + for match in _find_toc_text_matches( + lines, + cross_line_window=cross_line_window, + ): + matches.append({"page": page_num, **match}) + finally: + if doc is not None: + try: + doc.close() + except Exception: + pass + gc.collect() + queue.put({"ok": True, "matches": matches, "page_count": page_count}) + + @worker def _render_pages_worker( queue, pdf_path: str, pages: list[int], output_dir: str, dpi: int @@ -63,14 +158,15 @@ def _render_pages_worker( def _filter_recurring_elements( - matches: list[tuple[int, str, int]], + matches: list[dict[str, Any]], total_pages: int, ) -> set[int]: """Remove pages whose TOC keyword pattern is a recurring navigation element. - Each *match* is ``(page, raw_line, line_index)``. We build a **composite + Each match contains ``page``, ``raw_line``, ``line_index``, and optionally + ``line_end_index``. We build a **composite fingerprint** per page by joining all its matches as - ``"raw_line@line_idx\\n..."``. If the same composite fingerprint appears + ``"raw_line@line_idx-line_end\\n..."``. If the same composite fingerprint appears on more than ``RECURRING_ELEMENT_THRESHOLD`` of all pages, those pages are header/footer false-positives. @@ -80,16 +176,21 @@ def _filter_recurring_elements( - page 401 → ``"Table of Contents@0"`` (~31 pages → VLM decides) """ # Collect all matches per page - page_matches: dict[int, list[tuple[str, int]]] = {} - for page, raw_line, line_idx in matches: - page_matches.setdefault(page, []).append((raw_line, line_idx)) + page_matches: dict[int, list[tuple[str, int, int]]] = {} + for match in matches: + page = int(match["page"]) + raw_line = str(match.get("raw_line") or "").strip() + line_idx = int(match.get("line_index") or 0) + line_end_idx = int(match.get("line_end_index") or line_idx) + page_matches.setdefault(page, []).append((raw_line, line_idx, line_end_idx)) # Build composite fingerprint per page (sorted by line_idx for stability) page_fingerprints: dict[int, str] = {} for page, hits in page_matches.items(): hits_sorted = sorted(hits, key=lambda h: h[1]) page_fingerprints[page] = "\n".join( - f"{raw}@{idx}" for raw, idx in hits_sorted + f"{raw}@{start_idx}-{end_idx}" + for raw, start_idx, end_idx in hits_sorted ) # Group pages by composite fingerprint @@ -118,7 +219,7 @@ def _filter_recurring_elements( @register_tool( name="find.toc_anchor_pages", description=( - "Scan page text previews for TOC keywords, filter recurring " + "Scan full PDF page text for TOC keywords, filter recurring " "navigation elements, then render candidate PNGs for VLM confirmation." ), preconditions=(has_page_labels,), @@ -127,44 +228,20 @@ def find_toc_anchor_pages(ctx: ToolContext, _args: dict[str, Any]) -> ToolResult start = time.monotonic() total_pages = ctx.blackboard.page_count - # Scan text previews for TOC keywords and record per-line matches. - # Each entry: (page, raw_line_text, line_index) - # We use raw (original) text for fingerprinting so that casing - # differences (e.g. "Table of Contents" vs "TABLE OF CONTENTS") - # naturally produce distinct composite fingerprints. - keyword_matches: list[tuple[int, str, int]] = [] - raw_hit_pages: set[int] = set() - - for feature in ctx.blackboard.page_features: - page_matched = False - for line_idx, raw_line in enumerate(feature.text_lines_preview): - norm_line = _normalize_for_toc(raw_line) - for keyword in TOC_KEYWORDS: - if keyword in norm_line: - keyword_matches.append((feature.page, raw_line.strip(), line_idx)) - raw_hit_pages.add(feature.page) - page_matched = True - break # one match per line is enough - - # Fallback: check if a TOC keyword spans across adjacent lines. - # PyMuPDF sometimes splits large headings across lines, e.g. - # "目" + "录" or "Table of" + "Contents". Join the first few - # preview lines (where a page title would appear) and re-check - # with the same keywords and normalisation. - if not page_matched and feature.text_lines_preview: - head = feature.text_lines_preview[:10] - joined_head = _normalize_for_toc("".join(head)) - for keyword in TOC_KEYWORDS: - if keyword in joined_head: - keyword_matches.append((feature.page, keyword, 0)) - raw_hit_pages.add(feature.page) - logger.debug( - "[find.toc_anchor_pages] cross-line keyword '{}' " - "detected on page {} (head lines joined)", - keyword, - feature.page, - ) - break + scan_timeout = int(ctx.settings.get("toc_text_scan_timeout", "180")) + cross_line_window = int( + ctx.settings.get("toc_cross_line_window", TOC_CROSS_LINE_WINDOW) + ) + scan_result = run_in_child_process( + _scan_toc_text_worker, + ctx.pdf_path, + cross_line_window, + timeout=scan_timeout, + ) + keyword_matches = list(scan_result.get("matches") or []) + raw_hit_pages = {int(match["page"]) for match in keyword_matches} + if scan_result.get("page_count"): + total_pages = int(scan_result["page_count"]) # Apply recurring element fingerprint filter if keyword_matches: @@ -234,4 +311,3 @@ def find_toc_anchor_pages(ctx: ToolContext, _args: dict[str, Any]) -> ToolResult "pages": [a.to_dict() for a in anchors], }, ) - diff --git a/apps/worker/app/services/document_agent/tools/inspect_pages.py b/apps/worker/app/services/document_agent/tools/inspect_pages.py index 3b31d97e..8881a9b7 100644 --- a/apps/worker/app/services/document_agent/tools/inspect_pages.py +++ b/apps/worker/app/services/document_agent/tools/inspect_pages.py @@ -61,7 +61,8 @@ def inspect_pages(ctx: ToolContext, args: dict[str, Any]) -> ToolResult: latency_ms=int((time.monotonic() - start) * 1000), warnings=["No VLM model configured; returned rendered page paths only."], ) - if not ctx.budget.try_reserve("visual", est): + stage = "structural_react" + if not ctx.budget.try_reserve("visual", est, stage=stage): return ToolResult( status="error", error="insufficient visual budget", @@ -86,8 +87,14 @@ def inspect_pages(ctx: ToolContext, args: dict[str, Any]) -> ToolResult: temperature=0.0, max_tokens=1200, response_format={"type": "json_object"}, + usage_task="document_agent.inspect_pages", + ) + ctx.budget.commit( + "visual", + actual=usage.get("total_tokens", est), + est=est, + stage=stage, ) - ctx.budget.commit("visual", actual=usage.get("total_tokens", est), est=est) try: payload: dict[str, Any] = json.loads(raw) except json.JSONDecodeError: @@ -104,5 +111,5 @@ def inspect_pages(ctx: ToolContext, args: dict[str, Any]) -> ToolResult: tokens_used=usage.get("total_tokens", 0), ) except Exception: - ctx.budget.refund("visual", est=est) + ctx.budget.refund("visual", est=est, stage=stage) raise diff --git a/apps/worker/app/services/document_agent/tools/match_h1_pages.py b/apps/worker/app/services/document_agent/tools/match_h1_pages.py index 7cae13bf..3e4d9ed0 100644 --- a/apps/worker/app/services/document_agent/tools/match_h1_pages.py +++ b/apps/worker/app/services/document_agent/tools/match_h1_pages.py @@ -148,7 +148,8 @@ def verify_section_start( 'Return JSON: {"is_section_start": true/false, "reason": "brief"}' ) est = 800 # ~800 tokens for 1 image - if not ctx.budget.try_reserve("visual", est): + stage = "structural_react" + if not ctx.budget.try_reserve("visual", est, stage=stage): return True # Budget exhausted → trust GREP try: @@ -175,9 +176,13 @@ def verify_section_start( temperature=0.0, max_tokens=256, response_format={"type": "json_object"}, + usage_task="document_agent.match_h1_pages", ) ctx.budget.commit( - "visual", actual=usage.get("total_tokens", est), est=est, + "visual", + actual=usage.get("total_tokens", est), + est=est, + stage=stage, ) data = json.loads(raw) result = bool(data.get("is_section_start", True)) @@ -187,7 +192,7 @@ def verify_section_start( ) return result except Exception as exc: - ctx.budget.refund("visual", est=est) + ctx.budget.refund("visual", est=est, stage=stage) logger.warning("[verify_section_start] VLM failed for page {}: {}", page, exc) return True # VLM failure → trust GREP diff --git a/apps/worker/app/services/document_agent/tools/propose_shard_plan.py b/apps/worker/app/services/document_agent/tools/propose_shard_plan.py index a52eeb05..7b50c98d 100644 --- a/apps/worker/app/services/document_agent/tools/propose_shard_plan.py +++ b/apps/worker/app/services/document_agent/tools/propose_shard_plan.py @@ -451,6 +451,7 @@ def propose_shard_plan(ctx: ToolContext, _args: dict[str, Any]) -> ToolResult: temperature=0.0, max_tokens=1600, response_format={"type": "json_object"}, + usage_task="document_agent.propose_shard_plan", ) ctx.budget.commit("plan", actual=usage.get("total_tokens", prompt_tokens_est), est=prompt_tokens_est) enabled, cuts, reason, rationale = _parse_llm_plan(raw_response, page_count, min_pages, max_pages) diff --git a/apps/worker/app/services/document_agent/tools/vlm_toc_extractor.py b/apps/worker/app/services/document_agent/tools/vlm_toc_extractor.py index 7c1d6247..bd5c7c04 100644 --- a/apps/worker/app/services/document_agent/tools/vlm_toc_extractor.py +++ b/apps/worker/app/services/document_agent/tools/vlm_toc_extractor.py @@ -197,6 +197,7 @@ def vlm_extract_toc_batch( temperature=0.1, max_tokens=8192, response_format={"type": "json_object"}, + usage_task="document_agent.vlm_toc_batch", ) elapsed_ms = int((time.monotonic() - start) * 1000) diff --git a/apps/worker/app/services/document_agent/trace.py b/apps/worker/app/services/document_agent/trace.py index b455b63e..5771e602 100644 --- a/apps/worker/app/services/document_agent/trace.py +++ b/apps/worker/app/services/document_agent/trace.py @@ -21,7 +21,9 @@ def __init__(self, *, job_id: str, db: Any | None = None) -> None: self._started = time.monotonic() self._steps: list[dict[str, Any]] = [] self._anatomy: PageAnatomyMap | None = None + self._doc_profile_source: Any | None = None self._artifact_path: str | None = None + self._profile_plan_row: Any | None = None def record_step( self, @@ -61,6 +63,73 @@ def set_anatomy_map(self, anatomy: PageAnatomyMap, artifact_path: str) -> None: self._artifact_path = artifact_path self.write_trace_json(str(Path(artifact_path).with_name("trace.json"))) + def set_doc_profile(self, profile: Any) -> None: + self._doc_profile_source = profile + + def persist_doc_profile(self, profile: Any | None = None) -> None: + """Persist the coarse PDF profile independently from anatomy-map creation.""" + if profile is not None: + self.set_doc_profile(profile) + if self._db is None: + return + try: + from shared.models.database.document_page_plan import DocumentPagePlan + + doc_profile = self._doc_profile_for_plan() + if doc_profile is None: + return + if self._profile_plan_row is None: + self._profile_plan_row = DocumentPagePlan( + page_plan_id=f"dpp_{uuid4().hex[:12]}", + job_id=self.job_id, + page_count=int(doc_profile.get("page_count") or 0), + shard_plan=None, + doc_profile=doc_profile, + global_signals=None, + ) + self._db.add(self._profile_plan_row) + else: + self._profile_plan_row.page_count = int( + doc_profile.get("page_count") or self._profile_plan_row.page_count or 0 + ) + self._profile_plan_row.doc_profile = doc_profile + self._db.flush() + except Exception as exc: + logger.debug(f"parse agent doc profile persist failed: {exc}") + try: + if self._profile_plan_row is not None: + self._db.expunge(self._profile_plan_row) + self._profile_plan_row = None + except Exception: + pass + + def _doc_profile_for_plan(self) -> dict[str, Any] | None: + if self._doc_profile_source is None and self._anatomy is None: + return None + if self._doc_profile_source is None: + profile: dict[str, Any] = {} + elif hasattr(self._doc_profile_source, "to_dict"): + profile = self._doc_profile_source.to_dict() + else: + profile = dict(self._doc_profile_source) + + if self._anatomy is not None: + toc_result = self._anatomy.toc_result + profile.setdefault("file_type", "pdf") + profile["page_count"] = self._anatomy.page_count + profile["toc"] = { + "toc_pages": list(toc_result.toc_pages), + "hierarchies": self._anatomy.toc_hierarchies, + "evidence": [item.to_dict() for item in toc_result.evidence], + "source": "pdf_vlm" if toc_result.method != "none" else "none", + "method": toc_result.method, + "notes": toc_result.notes, + "attempted": bool( + self._anatomy.global_signals.get("toc_profile_attempted", True) + ), + } + return profile + def write_trace_artifact( self, output_dir: str | None, @@ -155,15 +224,22 @@ def flush(self, *, final_status: str, summary: dict[str, Any] | None = None) -> ) ) if self._anatomy is not None: - self._db.add( - DocumentPagePlan( + doc_profile = self._doc_profile_for_plan() + if self._profile_plan_row is None: + self._profile_plan_row = DocumentPagePlan( page_plan_id=f"dpp_{uuid4().hex[:12]}", job_id=self.job_id, page_count=self._anatomy.page_count, shard_plan=self._anatomy.shard_plan.to_dict(), + doc_profile=doc_profile, global_signals=self._anatomy.global_signals, ) - ) + self._db.add(self._profile_plan_row) + else: + self._profile_plan_row.page_count = self._anatomy.page_count + self._profile_plan_row.shard_plan = self._anatomy.shard_plan.to_dict() + self._profile_plan_row.doc_profile = doc_profile + self._profile_plan_row.global_signals = self._anatomy.global_signals self._db.flush() except Exception as exc: logger.debug(f"parse agent trace flush failed: {exc}") diff --git a/apps/worker/app/services/document_ingestion/parse_execution.py b/apps/worker/app/services/document_ingestion/parse_execution.py index 665194b6..0286467b 100644 --- a/apps/worker/app/services/document_ingestion/parse_execution.py +++ b/apps/worker/app/services/document_ingestion/parse_execution.py @@ -8,6 +8,7 @@ stage_timer, init_stage_tracker, cleanup_stage_tracker, + get_current_stage_tracker, ) from loguru import logger @@ -15,6 +16,7 @@ from shared.services.ai.token_tracking import ( init_token_tracker, cleanup_token_tracker, + get_current_token_tracker, ) @@ -31,14 +33,23 @@ def execute_document_parse( "doc_type", "auto", ) + parse_track = JobMetadataHelper.get_parse_track(job_context.job_metadata) logger.info( f"Start parse: job_id={job_id}, " f"filename={prepared_source.source_file_name}, " - f"internal_filename={prepared_source.internal_parse_name}, type={doc_type}" + f"internal_filename={prepared_source.internal_parse_name}, " + f"type={doc_type}, parse_track={parse_track}" ) - token_usage_dict = init_token_tracker() - stage_timing_dict = init_stage_tracker() + token_usage_dict = get_current_token_tracker() + owns_token_tracker = token_usage_dict is None + if token_usage_dict is None: + token_usage_dict = init_token_tracker() + + stage_timing_dict = get_current_stage_tracker() + owns_stage_tracker = stage_timing_dict is None + if stage_timing_dict is None: + stage_timing_dict = init_stage_tracker() try: with stage_timer( @@ -46,41 +57,49 @@ def execute_document_parse( job_id=job_id, filename=prepared_source.source_file_name, doc_type=doc_type, + parse_track=parse_track, ): - parse_output = parse_service.checkerboard_parse_output( - file_full_path=prepared_source.local_file_path, - filename=prepared_source.source_file_name, - output_dir=output_dir, - job_id=job_id, - internal_output_filename=prepared_source.internal_parse_name, - doc_type=doc_type, - smart_title_parse=JobMetadataHelper.get_parsing_param( - job_context.job_metadata, - "smart_title_parse", - True, - ), - summary_image=JobMetadataHelper.get_parsing_param( - job_context.job_metadata, - "summary_image", - True, - ), - summary_table=JobMetadataHelper.get_parsing_param( - job_context.job_metadata, - "summary_table", - True, - ), - summary_txt=JobMetadataHelper.get_parsing_param( - job_context.job_metadata, - "summary_txt", - True, - ), - add_frag_desc=JobMetadataHelper.get_parsing_param( - job_context.job_metadata, - "add_frag_desc", - "", - ), - s3_key=job_context.s3_key, - ) + if parse_track == "page_memory": + parse_output = _execute_page_memory_parse( + job_id=job_id, + prepared_source=prepared_source, + output_dir=output_dir, + ) + else: + parse_output = parse_service.checkerboard_parse_output( + file_full_path=prepared_source.local_file_path, + filename=prepared_source.source_file_name, + output_dir=output_dir, + job_id=job_id, + internal_output_filename=prepared_source.internal_parse_name, + doc_type=doc_type, + smart_title_parse=JobMetadataHelper.get_parsing_param( + job_context.job_metadata, + "smart_title_parse", + True, + ), + summary_image=JobMetadataHelper.get_parsing_param( + job_context.job_metadata, + "summary_image", + True, + ), + summary_table=JobMetadataHelper.get_parsing_param( + job_context.job_metadata, + "summary_table", + True, + ), + summary_txt=JobMetadataHelper.get_parsing_param( + job_context.job_metadata, + "summary_txt", + True, + ), + add_frag_desc=JobMetadataHelper.get_parsing_param( + job_context.job_metadata, + "add_frag_desc", + "", + ), + s3_key=job_context.s3_key, + ) logger.info( "File parsing completed: " @@ -93,7 +112,29 @@ def execute_document_parse( "token_usage": dict(token_usage_dict), } finally: - cleanup_token_tracker() - cleanup_stage_tracker() + if owns_token_tracker: + cleanup_token_tracker() + if owns_stage_tracker: + cleanup_stage_tracker() return parse_output + + +def _execute_page_memory_parse( + *, + job_id: str, + prepared_source: PreparedSourceFile, + output_dir: str, +) -> ParseOutput: + from app.services.page_memory.memory_service import PageMemoryInput, run + + page_output_dir, parsed_df = run( + PageMemoryInput( + file_path=prepared_source.local_file_path, + filename=prepared_source.source_file_name, + internal_output_filename=prepared_source.internal_parse_name, + output_dir=output_dir, + job_id=job_id, + ) + ) + return ParseOutput(output_dir=page_output_dir, parsed_df=parsed_df) diff --git a/apps/worker/app/services/document_ingestion/processing_run.py b/apps/worker/app/services/document_ingestion/processing_run.py index 975c46cf..ee03f1f8 100644 --- a/apps/worker/app/services/document_ingestion/processing_run.py +++ b/apps/worker/app/services/document_ingestion/processing_run.py @@ -28,6 +28,11 @@ ) from loguru import logger +from app.services.document_parser.support.stage_profiler import ( + cleanup_stage_tracker, + init_stage_tracker, +) +from shared.services.ai.token_tracking import cleanup_token_tracker, init_token_tracker from shared.services.jobs.lifecycle.service import get_sync_job_lifecycle_service from shared.services.redis.distributed_lock import RedisJobLock from shared.services.redis.redis_sync_service import ( @@ -78,30 +83,49 @@ def _run_parse_job( task_workspace: TemporaryParseWorkspace, ) -> dict[str, object]: lifecycle_service.update_progress(job_id, progress=10, message="Parsing document...") + token_usage_dict = init_token_tracker() + stage_timing_dict = init_stage_tracker() + + try: + prepared_source = prepare_source_file( + job_id=job_id, + job_context=job_context, + input_dir=task_workspace.input_dir, + ) + + workload_estimate = PageEstimator.estimate_workload(prepared_source.local_file_path) + page_count = workload_estimate.page_count + logger.info( + "Workload estimation: " + f"job_id={job_id}, page_count={page_count}, " + f"method={workload_estimate.method}, " + f"fallback_reason={workload_estimate.fallback_reason}" + ) - prepared_source = prepare_source_file( - job_id=job_id, - job_context=job_context, - input_dir=task_workspace.input_dir, - ) - - workload_estimate = PageEstimator.estimate_workload(prepared_source.local_file_path) - page_count = workload_estimate.page_count - logger.info( - "Workload estimation: " - f"job_id={job_id}, page_count={page_count}, " - f"method={workload_estimate.method}, " - f"fallback_reason={workload_estimate.fallback_reason}" - ) - - processing_started_at = datetime.now(timezone.utc) - oversized_pdf_rejection = build_oversized_pdf_rejection( - file_extension=prepared_source.file_extension, - page_count=page_count, - ) - if oversized_pdf_rejection is not None: - billing_snapshot = record_skipped_parse_job_billing( + processing_started_at = datetime.now(timezone.utc) + oversized_pdf_rejection = build_oversized_pdf_rejection( + file_extension=prepared_source.file_extension, + page_count=page_count, + ) + if oversized_pdf_rejection is not None: + billing_snapshot = record_skipped_parse_job_billing( + job_id=job_id, + workload_estimate=workload_estimate, + ) + record_processing_start( + job_id=job_id, + job_context=job_context, + billing_snapshot=billing_snapshot, + processing_started_at=processing_started_at, + workload_estimate=workload_estimate, + extra_metadata={"error_details": oversized_pdf_rejection.details}, + ) + raise oversized_pdf_rejection + + billing_snapshot = charge_parse_job_pages( job_id=job_id, + filename=prepared_source.source_file_name, + job_user_id=job_context.job_user_id, workload_estimate=workload_estimate, ) record_processing_start( @@ -110,57 +134,48 @@ def _run_parse_job( billing_snapshot=billing_snapshot, processing_started_at=processing_started_at, workload_estimate=workload_estimate, - extra_metadata={"error_details": oversized_pdf_rejection.details}, ) - raise oversized_pdf_rejection - - billing_snapshot = charge_parse_job_pages( - job_id=job_id, - filename=prepared_source.source_file_name, - job_user_id=job_context.job_user_id, - workload_estimate=workload_estimate, - ) - record_processing_start( - job_id=job_id, - job_context=job_context, - billing_snapshot=billing_snapshot, - processing_started_at=processing_started_at, - workload_estimate=workload_estimate, - ) - - parse_output = execute_document_parse( - job_id=job_id, - job_context=job_context, - prepared_source=prepared_source, - output_dir=task_workspace.output_dir, - ) - - lifecycle_service.update_progress( - job_id, - progress=30, - message="Parse completed, preparing chunks...", - ) - result_package = build_parse_result_package( - job_id=job_id, - filename=prepared_source.source_file_name, - parse_output=parse_output, - ) - - lifecycle_service.update_progress( - job_id, - progress=70, - message="Chunks ready, generating zip...", - ) - logger.info( - f"Chunks prepared: job_id={job_id}, count={len(result_package.chunks)}" - ) - - return finalize_parse_success( - result_package=result_package, - job_context=job_context, - job_id=job_id, - lifecycle_service=lifecycle_service, - processing_started_at=processing_started_at, - task_workspace_dir=task_workspace.root_dir, - result_storage_factory=get_result_storage, - ) + + parse_output = execute_document_parse( + job_id=job_id, + job_context=job_context, + prepared_source=prepared_source, + output_dir=task_workspace.output_dir, + ) + + lifecycle_service.update_progress( + job_id, + progress=30, + message="Parse completed, preparing chunks...", + ) + result_package = build_parse_result_package( + job_id=job_id, + filename=prepared_source.source_file_name, + parse_output=parse_output, + ) + + lifecycle_service.update_progress( + job_id, + progress=70, + message="Chunks ready, generating zip...", + ) + logger.info( + f"Chunks prepared: job_id={job_id}, count={len(result_package.chunks)}" + ) + + return finalize_parse_success( + result_package=result_package, + job_context=job_context, + job_id=job_id, + lifecycle_service=lifecycle_service, + processing_started_at=processing_started_at, + task_workspace_dir=task_workspace.root_dir, + result_storage_factory=get_result_storage, + ) + finally: + job_context.job_metadata["stages"] = { + "timing_ms": dict(stage_timing_dict), + "token_usage": dict(token_usage_dict), + } + cleanup_token_tracker() + cleanup_stage_tracker() diff --git a/apps/worker/app/services/document_ingestion/success_finalization.py b/apps/worker/app/services/document_ingestion/success_finalization.py index 62397d56..93c5c301 100644 --- a/apps/worker/app/services/document_ingestion/success_finalization.py +++ b/apps/worker/app/services/document_ingestion/success_finalization.py @@ -21,6 +21,8 @@ from loguru import logger from shared.models.schemas.job_metadata import JobMetadataHelper +from shared.services.ai.token_tracking import get_current_token_tracker +from app.services.document_parser.support.stage_profiler import get_current_stage_tracker from shared.services.storage.result_storage import ResultStorage, get_result_storage from shared.services.storage.zip_result_service import ZipResultService @@ -46,6 +48,7 @@ def finalize_parse_success( source_file_name=source_file_name, ) _attach_document_top_summary(result_package.chunks, document_top_summary) + _refresh_processing_stages(job_context) lifecycle_service.update_progress( job_id, @@ -57,6 +60,7 @@ def finalize_parse_success( job_context=job_context, processing_started_at=processing_started_at, ) + _refresh_processing_stages(job_context) generated_package = _generate_result_package( result_package=result_package, job_context=job_context, @@ -160,6 +164,21 @@ def _enrich_document_navigation( return document_top_summary, section_summaries +def _refresh_processing_stages(job_context: ParseJobContext) -> None: + token_usage = get_current_token_tracker() + timing_ms = get_current_stage_tracker() + if token_usage is None and timing_ms is None: + return + + current_stages = job_context.job_metadata.get("stages") + stages = dict(current_stages) if isinstance(current_stages, dict) else {} + if token_usage is not None: + stages["token_usage"] = dict(token_usage) + if timing_ms is not None: + stages["timing_ms"] = dict(timing_ms) + job_context.job_metadata["stages"] = stages + + def _attach_document_top_summary( chunks: list[dict[str, Any]], document_top_summary: str, @@ -189,6 +208,9 @@ def _record_processing_completion( int((processing_completed_at - processing_started_at).total_seconds() * 1000), ), } + _refresh_processing_stages(job_context) + if "stages" in job_context.job_metadata: + processing_timing_updates["stages"] = job_context.job_metadata["stages"] job_context.metadata_service.update_metadata(job_id, processing_timing_updates) job_context.job_metadata.update(processing_timing_updates) diff --git a/apps/worker/app/services/document_parser/formats/atlas/parser.py b/apps/worker/app/services/document_parser/formats/atlas/parser.py index efca015e..90af7764 100644 --- a/apps/worker/app/services/document_parser/formats/atlas/parser.py +++ b/apps/worker/app/services/document_parser/formats/atlas/parser.py @@ -23,7 +23,6 @@ from app.services.document_parser.support.identifiers import gen_str_codes, get_str_time from app.services.document_parser.support.parser_rows import ParsedRow, ParsedRowsBuilder from app.services.document_parser.formats.pdf.pymupdf_subprocess import run_in_child_process, worker -from app.services.document_parser.structure.toc_parser import detect_tocs_in_texts from loguru import logger from shared.core.config import settings @@ -204,63 +203,6 @@ def _atlas_render_pages_worker( queue.put({"ok": True, "page_data": page_data}) -def _detect_toc_pages_from_texts( - page_texts: list[str], - model_name: str, - hierarchy_model_name: str | None = None, -) -> tuple: - """Detect TOC pages from pre-extracted page texts (no PyMuPDF needed). - - Args: - page_texts: list of text strings, one per page (0-indexed). - model_name: LLM model for TOC range detection. - hierarchy_model_name: Optional dedicated model for TOC hierarchy parsing. - - Returns: - (toc_page_set, toc_hierarchies) - """ - md_lines = [] - for page_idx, text in enumerate(page_texts): - if text: - md_lines.append(f"") - for line in text.split("\n"): - stripped = line.strip() - if stripped: - md_lines.append(stripped) - - if not md_lines: - return set(), None - - toc_hierarchies, _ = detect_tocs_in_texts( - md_lines, - model_name=model_name, - hierarchy_model_name=hierarchy_model_name, - ) - - if not toc_hierarchies: - return set(), None - - page_marker_re = re.compile(r"", re.IGNORECASE) - line_to_page = {} - current_page = 0 - for i, line in enumerate(md_lines): - m = page_marker_re.search(line) - if m: - current_page = int(m.group(1)) - line_to_page[i] = current_page - - toc_pages = set() - for toc_info in toc_hierarchies: - toc_start = toc_info.get("toc_range", (0, 0))[0] - toc_end = toc_info.get("toc_range", (0, 0))[1] - for line_idx in range(toc_start, toc_end + 1): - pg = line_to_page.get(line_idx, 0) - if pg > 0: - toc_pages.add(pg) - - return toc_pages, toc_hierarchies - - # ─── Main entry point ──────────────────────────────────────────────── @@ -309,14 +251,22 @@ def parse_atlas( scan_label = "scanned" if is_scanned else "non-scanned" logger.info(f"📐 Atlas: {scan_label} document, VLM enabled for info extraction") - # ── TOC detection (runs in parent — uses LLM, no PyMuPDF) ── - model_name = base_llm_paras.get("model_name", settings.NORMOL_MODEL) - hierarchy_model_name = base_llm_paras.get("hierarchy_model_name") or model_name - toc_page_set, toc_hierarchies = _detect_toc_pages_from_texts( - page_texts, - model_name, - hierarchy_model_name=hierarchy_model_name, - ) + # ── TOC detection is owned by PDF profiling, not atlas text parsing. ── + profile_toc = getattr(profile, "toc", None) + if profile_toc is not None and getattr(profile_toc, "attempted", False): + toc_page_set = set(profile_toc.toc_pages) + toc_hierarchies = profile_toc.hierarchies + logger.info( + f"📐 Atlas: consuming profile.toc (source={profile_toc.source}, " + f"method={profile_toc.method}, pages={sorted(toc_page_set)}); " + "skipping atlas text TOC detection" + ) + else: + toc_page_set = set() + toc_hierarchies = None + logger.info( + "📐 Atlas: no attempted profile.toc available; treating document as no-TOC" + ) if toc_page_set: logger.info(f"📐 Atlas: TOC pages detected: {sorted(toc_page_set)}") diff --git a/apps/worker/app/services/document_parser/formats/fragment/parser.py b/apps/worker/app/services/document_parser/formats/fragment/parser.py index 66a20069..1409724b 100644 --- a/apps/worker/app/services/document_parser/formats/fragment/parser.py +++ b/apps/worker/app/services/document_parser/formats/fragment/parser.py @@ -43,6 +43,7 @@ def generate_fragment_title(content: str, max_tokens: int = 30) -> Optional[str] messages=messages, max_tokens=max_tokens, timeout=30, + usage_task="parser.fragment_title", ) if generated_title: return str(generated_title).strip()[:50] diff --git a/apps/worker/app/services/document_parser/formats/image/parser.py b/apps/worker/app/services/document_parser/formats/image/parser.py index df7555bb..fe4f29b0 100755 --- a/apps/worker/app/services/document_parser/formats/image/parser.py +++ b/apps/worker/app/services/document_parser/formats/image/parser.py @@ -166,6 +166,7 @@ def ask_image( temperature=temperature, max_tokens=max_tokens, top_p=top_p, + usage_task=f"parser.image.{task}", ) logger.debug(f"Image understanding response: {resp}") # Only parse as JSON for tasks that return structured data diff --git a/apps/worker/app/services/document_parser/formats/markdown/parser.py b/apps/worker/app/services/document_parser/formats/markdown/parser.py index 925bd524..8c887899 100755 --- a/apps/worker/app/services/document_parser/formats/markdown/parser.py +++ b/apps/worker/app/services/document_parser/formats/markdown/parser.py @@ -228,6 +228,7 @@ def parse_md( toc_hierarchies=None, lines_with_heading=None, is_first_shard=True, + skip_toc_detection=False, ): if lines_with_heading is not None: # ── Phase A bypass ── @@ -277,6 +278,8 @@ def parse_md( f"({len(toc_hierarchies)} regions), " f"skipping detect_tocs_in_texts" ) + elif skip_toc_detection: + logger.info("📌 Skipping TOC detection by upstream parser decision") else: with stage_timer( "md.detect_toc", line_count=len(md_lines), model_name=toc_model_name diff --git a/apps/worker/app/services/document_parser/formats/pdf/parser.py b/apps/worker/app/services/document_parser/formats/pdf/parser.py index 791cc40e..3d28e7e5 100755 --- a/apps/worker/app/services/document_parser/formats/pdf/parser.py +++ b/apps/worker/app/services/document_parser/formats/pdf/parser.py @@ -9,9 +9,7 @@ ) from app.services.document_parser.providers.mineru.pdf_service import parse_via_full from app.services.document_parser.profiling.taxonomy import PdfRoutingCategory -from app.services.document_parser.structure.toc_parser import detect_tocs_in_texts from app.services.document_parser.support.stage_profiler import stage_timer -from app.services.document_parser.support.text_helpers import normalize_md from loguru import logger from shared.core.config import settings @@ -84,6 +82,7 @@ def parse_pdfs( file_path=os.path.join(output_dir, "full.md"), base_llm_paras=base_llm_paras, relative_root=relative_root, + skip_toc_detection=True, ) @@ -241,7 +240,6 @@ class ShardHeadingResult: heading_count: int smart_parse = base_llm_paras.get("smart_title_parse", True) - toc_model_name = base_llm_paras.get("model_name", settings.NORMOL_MODEL) hierarchy_model_name = ( base_llm_paras.get("hierarchy_model_name") or base_llm_paras.get("model_name", settings.NORMOL_MODEL) @@ -263,16 +261,6 @@ def _predict_shard_headings( is_first_shard = shard_idx == 0 shard_toc = toc_hierarchies - if shard_toc is None and is_first_shard and _md_has_toc_keyword(md_lines): - logger.info( - f"📌 shard_{shard_idx}: TOC keyword found without profile TOC; " - "reusing markdown TOC detector" - ) - shard_toc, md_lines = detect_tocs_in_texts( - md_lines, - model_name=toc_model_name, - hierarchy_model_name=hierarchy_model_name, - ) lines_with_heading = eval_md_headings( md_lines, @@ -366,12 +354,6 @@ def _predict_shard_headings( _cleanup_temp_shard_s3_assets(temp_shard_s3_keys) _cleanup_local_shard_workspace(work_dir) - -def _md_has_toc_keyword(md_lines: list[str]) -> bool: - toc_keywords = {"目录", "目次", "tableofcontents", "contents"} - return any(normalize_md(line) in toc_keywords for line in md_lines) - - def _build_temp_shard_s3_key( *, source_s3_key: str | None, diff --git a/apps/worker/app/services/document_parser/formats/text/parser.py b/apps/worker/app/services/document_parser/formats/text/parser.py index 161e0986..b13c9651 100755 --- a/apps/worker/app/services/document_parser/formats/text/parser.py +++ b/apps/worker/app/services/document_parser/formats/text/parser.py @@ -136,7 +136,10 @@ def extract_title_keywords_summary(texts, max_keywords=3, summary_len=None): redis_service.set(f"task:{ctx_task_id}:status", "processing", ttl=7200) resp = get_openai_client().chat_completion( - messages=messages, timeout=90, max_tokens=max_tokens + messages=messages, + timeout=90, + max_tokens=max_tokens, + usage_task="parser.text_summary", ) # Handle null/none response diff --git a/apps/worker/app/services/document_parser/profiling/doc_profiler.py b/apps/worker/app/services/document_parser/profiling/doc_profiler.py index 8944c7ab..ce9fa1d0 100644 --- a/apps/worker/app/services/document_parser/profiling/doc_profiler.py +++ b/apps/worker/app/services/document_parser/profiling/doc_profiler.py @@ -3,7 +3,10 @@ from __future__ import annotations import os +from contextlib import contextmanager +from typing import Any, Iterator +from loguru import logger from app.services.document_agent.coordinator import ProfileCoordinator from app.services.document_parser.orchestration.oversized_pdf_policy import ( build_oversized_pdf_processing_failed_exception, @@ -17,6 +20,7 @@ from app.services.document_parser.profiling.taxonomy import PdfRoutingCategory from shared.core.config import settings +from shared.core.database_sync import get_sync_session_factory def profile_document( @@ -59,20 +63,40 @@ def _profile_pdf( *, job_id: str | None, output_dir: str | None, +) -> ParserDocumentProfile: + with _profile_db_context(enabled=bool(job_id)) as db: + return _profile_pdf_with_db( + file_path=file_path, + filename=filename, + job_id=job_id, + output_dir=output_dir, + db=db, + ) + + +def _profile_pdf_with_db( + *, + file_path: str, + filename: str, + job_id: str | None, + output_dir: str | None, + db: Any | None, ) -> ParserDocumentProfile: profile_job_id = job_id or filename agent_output_dir = os.path.join(output_dir, "_doc_agent") if output_dir else None + page_toc_enabled = settings.PDF_PAGE_TOC_ENABLED coordinator = ProfileCoordinator( pdf_path=file_path, job_id=profile_job_id, output_dir=agent_output_dir, + db=db, model=settings.IMAGE_MODEL, settings={ "planner_model": settings.IMAGE_MODEL, "vlm_model": settings.IMAGE_MODEL, "model": settings.HIERARCHY_LLM_MODEL or settings.NORMOL_MODEL, - "toc_before_coarse": settings.PDF_PROFILE_TOC_ENABLED, - "toc_before_coarse_page_limit": settings.MAX_PDF_PAGE_LIMIT, + "toc_profile_enabled": page_toc_enabled, + "toc_before_coarse": page_toc_enabled, }, ) agent_profile = coordinator.run_coarse() @@ -95,7 +119,6 @@ def _profile_pdf( ), }, ) - if profile.page_count > settings.MAX_PDF_PAGE_LIMIT: raise_if_oversized_pdf_not_supported(page_count=profile.page_count) if not profile.is_atlas: @@ -107,11 +130,16 @@ def _profile_pdf( page_count=profile.page_count, original_exception=exc, ) from exc - elif settings.PDF_PROFILE_TOC_ENABLED: + else: + profile.toc = _map_toc_profile(coordinator) + else: if not profile.is_atlas: profile.anatomy = coordinator.run_lightweight_anatomy() profile.toc = _map_toc_profile(coordinator) + if trace := getattr(coordinator, "trace", None): + trace.persist_doc_profile(profile) + return profile @@ -119,6 +147,10 @@ def _map_toc_profile(coordinator: ProfileCoordinator) -> ParserTocProfile: toc_result = coordinator.blackboard.toc_result if toc_result is None: return ParserTocProfile() + attempted_signal = coordinator.blackboard.global_signals.get( + "toc_profile_attempted" + ) + attempted = bool(attempted_signal) if attempted_signal is not None else True evidence = [ TocEvidence( page_index=item.page_index, @@ -136,7 +168,32 @@ def _map_toc_profile(coordinator: ProfileCoordinator) -> ParserTocProfile: source=source, method=toc_result.method, notes=toc_result.notes, + attempted=attempted, ) +@contextmanager +def _profile_db_context(*, enabled: bool) -> Iterator[Any | None]: + if not enabled: + yield None + return + session = None + try: + session = get_sync_session_factory()() + except Exception as exc: + logger.debug(f"parse profile db session unavailable: {exc}") + yield None + return + + try: + yield session + try: + session.commit() + except Exception as exc: + logger.debug(f"parse profile db commit failed: {exc}") + session.rollback() + finally: + session.close() + + __all__ = ["profile_document"] diff --git a/apps/worker/app/services/document_parser/profiling/profile_model.py b/apps/worker/app/services/document_parser/profiling/profile_model.py index 9334ae77..79df6146 100644 --- a/apps/worker/app/services/document_parser/profiling/profile_model.py +++ b/apps/worker/app/services/document_parser/profiling/profile_model.py @@ -22,6 +22,7 @@ class ParserTocProfile: source: str = "none" method: str = "none" notes: str = "" + attempted: bool = False @property def has_toc(self) -> bool: diff --git a/apps/worker/app/services/document_parser/structure/heading_llm_executor.py b/apps/worker/app/services/document_parser/structure/heading_llm_executor.py index 9db85c0c..1587f2e2 100644 --- a/apps/worker/app/services/document_parser/structure/heading_llm_executor.py +++ b/apps/worker/app/services/document_parser/structure/heading_llm_executor.py @@ -194,6 +194,7 @@ def run_merge_pre_pass( model=model_name, max_tokens=max_tokens, temperature=temperature, + usage_task="parser.heading_merge_pre_pass", ) result = eval_response(answer) except Exception as exc: diff --git a/apps/worker/app/services/document_parser/structure/layout_parser.py b/apps/worker/app/services/document_parser/structure/layout_parser.py index 6c49910c..5b39a08e 100755 --- a/apps/worker/app/services/document_parser/structure/layout_parser.py +++ b/apps/worker/app/services/document_parser/structure/layout_parser.py @@ -210,6 +210,7 @@ def _is_candidate_id(val): model=model_name, max_tokens=max_tokens, temperature=temperature, + usage_task="parser.heading_hierarchy", ) layout_res = eval_response(answer) diff --git a/apps/worker/app/services/document_parser/structure/toc_parser.py b/apps/worker/app/services/document_parser/structure/toc_parser.py index 9380884f..84795aab 100644 --- a/apps/worker/app/services/document_parser/structure/toc_parser.py +++ b/apps/worker/app/services/document_parser/structure/toc_parser.py @@ -235,6 +235,7 @@ def llm_judge_toc_range( model=model_name, max_tokens=max_tokens, temperature=temperature, + usage_task="parser.toc_detect_range", ) result = eval_response(answer) diff --git a/apps/worker/app/services/document_parser/support/stage_profiler.py b/apps/worker/app/services/document_parser/support/stage_profiler.py index 7ec11c34..2bfe9899 100644 --- a/apps/worker/app/services/document_parser/support/stage_profiler.py +++ b/apps/worker/app/services/document_parser/support/stage_profiler.py @@ -60,6 +60,14 @@ def init_stage_tracker() -> dict[str, int]: return tracker +def get_current_stage_tracker() -> dict[str, int] | None: + """Return the active stage timing accumulator for this task, if any.""" + root = _find_root_id() + if root is None: + return None + return _trackers.get(root) + + def cleanup_stage_tracker() -> None: """Remove the stage tracker for the current greenlet.""" gid = _current_greenlet_id() diff --git a/apps/worker/app/services/document_parser/tables/table_frame_parser.py b/apps/worker/app/services/document_parser/tables/table_frame_parser.py index 76d7965d..aa9c4167 100644 --- a/apps/worker/app/services/document_parser/tables/table_frame_parser.py +++ b/apps/worker/app/services/document_parser/tables/table_frame_parser.py @@ -71,6 +71,7 @@ def parse_headers_nonsmart(candidate_frame: pd.DataFrame) -> list[int]: header_response = get_openai_client().chat_completion( messages=messages, timeout=60, + usage_task="parser.table_detect_headers", ) parsed_response = eval_response(header_response) if isinstance(parsed_response, dict): diff --git a/apps/worker/app/services/page_memory/__init__.py b/apps/worker/app/services/page_memory/__init__.py new file mode 100644 index 00000000..b2d39159 --- /dev/null +++ b/apps/worker/app/services/page_memory/__init__.py @@ -0,0 +1,2 @@ +"""Experimental page-memory parser track.""" + diff --git a/apps/worker/app/services/page_memory/memory_service.py b/apps/worker/app/services/page_memory/memory_service.py new file mode 100644 index 00000000..816ea70b --- /dev/null +++ b/apps/worker/app/services/page_memory/memory_service.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pandas as pd + +from app.services.document_agent.pdf_text import read_page_texts +from app.services.document_agent.visual import render_pages +from app.services.document_agent.manifest import ToolContext +from app.services.document_agent.state import AgentBlackboard +from app.services.document_parser.profiling.doc_profiler import profile_document +from app.services.document_parser.support.identifiers import gen_str_codes, get_str_time +from app.services.document_parser.support.parser_rows import PARSER_ROW_COLUMNS +from app.services.page_memory.normalizer import normalize_to_pdf + +from shared.core.exceptions.domain_exceptions import ValidationException + + +_SUPPORTED_GRANULARITY = "whole_doc" +_UNSUPPORTED_GRANULARITY_REASON = "PAGE_MEMORY_GRANULARITY_NOT_IMPLEMENTED" + + +@dataclass(frozen=True) +class PageMemoryInput: + file_path: str + filename: str + output_dir: str + job_id: str | None = None + internal_output_filename: str | None = None + base_url: str = "" + + +def run(request: PageMemoryInput) -> tuple[str, pd.DataFrame]: + """Run the page-memory track. + + PR3 intentionally implements only the whole_doc skeleton. Full page mode, + tagging, and section mapping land in PR4. + """ + full_output_dir = _resolve_output_dir(request) + os.makedirs(full_output_dir, exist_ok=True) + pdf_path, pdf_filename = normalize_to_pdf( + file_path=request.file_path, + filename=request.filename, + output_dir=full_output_dir, + base_url=request.base_url, + ) + profile = profile_document( + pdf_path, + pdf_filename, + job_id=request.job_id, + output_dir=full_output_dir, + ) + verdict = _decide_granularity(profile) + if verdict != _SUPPORTED_GRANULARITY: + _raise_unsupported_granularity(verdict) + return full_output_dir, _build_whole_doc_dataframe( + pdf_path=pdf_path, + filename=request.filename, + output_dir=full_output_dir, + page_count=max(int(profile.page_count or 0), 0), + verdict=verdict, + ) + + +def _resolve_output_dir(request: PageMemoryInput) -> str: + output_name = request.internal_output_filename or request.filename + return os.path.join(request.output_dir, Path(output_name).stem) + + +def _decide_granularity(profile: Any) -> str: + page_count = int(getattr(profile, "page_count", 0) or 0) + toc = getattr(profile, "toc", None) + has_toc = bool(getattr(toc, "has_toc", False)) + if page_count > 200: + return "shard_page" + if page_count <= 6 and not has_toc: + return "whole_doc" + return "page" + + +def _raise_unsupported_granularity(verdict: str) -> None: + raise ValidationException( + user_message=( + "page_memory is enabled, but this PR only supports whole-document " + "page memory. Per-page and shard-page modes are intentionally gated " + "until the page renderer, tagger, and section mapper land." + ), + violations=[ + { + "field": "parse_track", + "description": ( + f"{_UNSUPPORTED_GRANULARITY_REASON}: " + f"granularity={verdict}; supported={_SUPPORTED_GRANULARITY}" + ), + } + ], + internal_message=( + f"{_UNSUPPORTED_GRANULARITY_REASON}: granularity={verdict}; " + f"supported={_SUPPORTED_GRANULARITY}" + ), + ) + + +def _build_whole_doc_dataframe( + *, + pdf_path: str, + filename: str, + output_dir: str, + page_count: int, + verdict: str, +) -> pd.DataFrame: + pages = list(range(1, page_count + 1)) if page_count > 0 else [1] + page_texts = read_page_texts(pdf_path, pages) + raw_text = "\n\n".join(page_texts.get(page, "") for page in pages).strip() + summary = _build_summary(filename=filename, page_count=page_count, raw_text=raw_text) + page_image_uris = _render_page_images( + pdf_path=pdf_path, + output_dir=output_dir, + page_count=page_count, + pages=pages, + ) + content = f"[SUMMARY]\n{summary}\n\n[RAW]\n{raw_text}".strip() + know_id = gen_str_codes(f"wholedoc::{filename}::{content}") + row = { + "content": content, + "path": f"{filename}/Root", + "type": "page", + "length": len(content), + "keywords": "", + "summary": summary, + "know_id": know_id, + "tokens": "", + "connectto": "", + "addtime": get_str_time(), + "page_nums": ",".join(str(page) for page in pages), + "extra_metadata": { + "granularity": "whole_doc", + "strategy_used": "whole_doc" if verdict == "whole_doc" else "whole_doc_fallback", + "source_verdict": verdict, + "page_index": None, + "page_image_uris": page_image_uris, + "status": "clear", + }, + } + return pd.DataFrame([row], columns=pd.Index([*PARSER_ROW_COLUMNS, "extra_metadata"])) + + +def _build_summary(*, filename: str, page_count: int, raw_text: str) -> str: + prefix = f"{filename} whole-document memory ({page_count} pages)" + preview = " ".join(raw_text.split())[:500] + return f"{prefix}: {preview}" if preview else prefix + + +def _render_page_images( + *, + pdf_path: str, + output_dir: str, + page_count: int, + pages: list[int], +) -> list[str]: + if page_count <= 0: + return [] + blackboard = AgentBlackboard() + blackboard.page_count = page_count + ctx = ToolContext( + pdf_path=pdf_path, + job_id="page_memory_render", + blackboard=blackboard, + budget=None, + trace=None, + output_dir=output_dir, + settings={}, + ) + rendered = render_pages(ctx, pages, folder_name="pages", prefix="page", timeout=180) + return [ + str(Path(item["png_path"]).relative_to(output_dir)) + for item in rendered + if item.get("png_path") + ] diff --git a/apps/worker/app/services/page_memory/normalizer.py b/apps/worker/app/services/page_memory/normalizer.py new file mode 100644 index 00000000..04db3f00 --- /dev/null +++ b/apps/worker/app/services/page_memory/normalizer.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import os + +from app.services.document_parser.formats.pptx.parser import ( + _pptx_bytes_to_pdf_bytes, + pptx_to_pdf_libreoffice, +) +from app.services.common.file_loading import load_file_bytes +from loguru import logger + +from shared.core.exceptions.domain_exceptions import ValidationException + + +def normalize_to_pdf( + *, + file_path: str, + filename: str, + output_dir: str, + base_url: str = "", +) -> tuple[str, str]: + """Return a local PDF path and filename for page-memory processing.""" + extension = os.path.splitext(filename)[1].lower() + if extension == ".pdf": + return file_path, filename + if extension == ".pptx": + return _normalize_pptx_to_pdf( + file_path=file_path, + filename=filename, + output_dir=output_dir, + base_url=base_url, + ) + raise ValidationException( + user_message="page_memory parse track only supports PDF and PPTX", + violations=[ + { + "field": "parse_track", + "description": "Allowed file types in this build: .pdf, .pptx", + } + ], + ) + + +def _normalize_pptx_to_pdf( + *, + file_path: str, + filename: str, + output_dir: str, + base_url: str, +) -> tuple[str, str]: + pptx_data = load_file_bytes(file_path, file_url=base_url) + pdf_filename = f"{os.path.splitext(filename)[0]}.pdf" + pdf_path = os.path.join(output_dir, pdf_filename) + try: + pdf_bytes = _pptx_bytes_to_pdf_bytes(pptx_data, filename) + with open(pdf_path, "wb") as f: + f.write(pdf_bytes) + return pdf_path, pdf_filename + except Exception as exc: + logger.warning( + "[page_memory] iLoveAPI PPTX normalization failed for {}; " + "falling back to LibreOffice: {}", + filename, + exc, + ) + local_pptx_path = os.path.join(output_dir, filename) + with open(local_pptx_path, "wb") as f: + f.write(pptx_data) + return pptx_to_pdf_libreoffice(local_pptx_path, output_dir) diff --git a/apps/worker/tests/contract/test_doc_profile_anatomy_contract.py b/apps/worker/tests/contract/test_doc_profile_anatomy_contract.py index 448e1630..19c5242d 100644 --- a/apps/worker/tests/contract/test_doc_profile_anatomy_contract.py +++ b/apps/worker/tests/contract/test_doc_profile_anatomy_contract.py @@ -1,6 +1,7 @@ from __future__ import annotations import importlib +import json import os from pathlib import Path from types import SimpleNamespace @@ -14,6 +15,7 @@ from app.services.document_agent import coordinator as coordinator_module from app.services.document_agent.coordinator import ProfileCoordinator +from app.services.document_agent.trace import ParseRunRecorder from app.services.document_agent.manifest import ( DocumentProfile, H1BoundaryResult, @@ -27,11 +29,17 @@ TocResult, ToolResult, ) +from app.services.document_agent.tools import find_toc_anchor_pages as toc_anchor_tool from app.services.document_agent.validators import validate_shard_plan +from app.services.document_parser.formats.atlas import parser as atlas_parser from app.services.document_parser.formats.pdf import parser as pdf_parser from app.services.document_parser.formats.pdf import shard_splitter from app.services.document_parser.profiling import doc_profiler from app.services.document_parser.profiling.doc_profiler import profile_document +from app.services.document_parser.profiling.profile_model import ( + ParserDocumentProfile, + ParserTocProfile, +) from app.services.document_parser.profiling.taxonomy import PdfRoutingCategory from app.services.document_parser.structure.layout_parser import pred_titles @@ -53,6 +61,24 @@ def _page_feature(page: int = 1) -> PageFeature: ) +def test_toc_anchor_text_scan_matches_full_page_and_cross_line_keywords() -> None: + late_lines = [f"body line {idx}" for idx in range(60)] + ["目录"] + split_lines = ["Table of", "Con", "tents"] + + late_matches = toc_anchor_tool._find_toc_text_matches( # noqa: SLF001 + late_lines, + cross_line_window=6, + ) + split_matches = toc_anchor_tool._find_toc_text_matches( # noqa: SLF001 + split_lines, + cross_line_window=6, + ) + + assert late_matches[0]["line_index"] == 60 + assert late_matches[0]["match_kind"] == "keyword:目录" + assert split_matches[0]["match_kind"] == "cross_line:tableofcontents" + + def test_run_toc_degrades_to_empty_result_on_standard_failure(tmp_path: Path) -> None: coordinator = ProfileCoordinator( pdf_path=str(tmp_path / "standard.pdf"), @@ -109,6 +135,93 @@ def test_run_lightweight_anatomy_builds_single_shard_without_planner_llm( assert anatomy.shard_plan.shards[0].page_end == 2 assert anatomy.toc_result.method == "none" assert (output_dir / "anatomy_map.json").exists() + trace_data = json.loads((output_dir / "trace.json").read_text(encoding="utf-8")) + assert "visual_stages" in trace_data["summary"]["budget"] + + +def test_parse_run_recorder_doc_profile_uses_final_anatomy_toc() -> None: + recorder = ParseRunRecorder(job_id="job-doc-profile") + recorder.set_doc_profile( + ParserDocumentProfile( + file_type="pdf", + category="Financial Prospectus", + routing_category=PdfRoutingCategory.GENERIC, + page_count=0, + ) + ) + recorder._anatomy = PageAnatomyMap( # noqa: SLF001 + job_id="job-doc-profile", + file_path="/tmp/doc.pdf", + page_count=2, + page_features=[_page_feature(1), _page_feature(2)], + page_labels=[ + PageLabel(page=1, kind="normal", confidence=1.0), + PageLabel(page=2, kind="normal", confidence=1.0), + ], + toc_result=TocResult(toc_pages=[2], method="vlm_batch", notes="ok"), + h1_result=H1BoundaryResult(method="toc_grep"), + shard_plan=ShardPlan(enabled=False, reason="not_needed"), + toc_hierarchies=[{"toc_range": [2, 2], "toc_tree": {}}], + global_signals={"toc_profile_attempted": True}, + ) + + doc_profile = recorder._doc_profile_for_plan() # noqa: SLF001 + + assert doc_profile is not None + assert doc_profile["category"] == "Financial Prospectus" + assert doc_profile["routing_category"] == PdfRoutingCategory.GENERIC.value + assert doc_profile["page_count"] == 2 + assert doc_profile["toc"]["method"] == "vlm_batch" + assert doc_profile["toc"]["toc_pages"] == [2] + + +def test_parse_run_recorder_profile_only_row_is_updated_by_anatomy_flush() -> None: + class FakeDb: + def __init__(self) -> None: + self.rows = [] + self.flushes = 0 + + def add(self, row) -> None: + self.rows.append(row) + + def flush(self) -> None: + self.flushes += 1 + + def rollback(self) -> None: + raise AssertionError("rollback should not be called") + + db = FakeDb() + recorder = ParseRunRecorder(job_id="job-profile-only", db=db) + recorder.persist_doc_profile( + ParserDocumentProfile( + file_type="pdf", + category="Atlas", + routing_category=PdfRoutingCategory.ATLAS, + page_count=12, + ) + ) + + assert len(db.rows) == 1 + assert db.rows[0].doc_profile["category"] == "Atlas" + assert db.rows[0].shard_plan is None + + recorder._anatomy = PageAnatomyMap( # noqa: SLF001 + job_id="job-profile-only", + file_path="/tmp/doc.pdf", + page_count=12, + page_features=[_page_feature(1)], + page_labels=[PageLabel(page=1, kind="normal", confidence=1.0)], + toc_result=TocResult(toc_pages=[2], method="vlm_batch", notes="ok"), + h1_result=H1BoundaryResult(method="toc_grep"), + shard_plan=ShardPlan(enabled=False, reason="not_needed"), + toc_hierarchies=[{"toc_range": [2, 2], "toc_tree": {}}], + global_signals={"toc_profile_attempted": True}, + ) + recorder.flush(final_status="ready") + + assert len(db.rows) == 2 + assert db.rows[0].shard_plan is not None + assert db.rows[0].doc_profile["toc"]["method"] == "vlm_batch" def test_run_structural_retries_transient_confirm_failed_toc_result( @@ -320,7 +433,7 @@ def test_run_coarse_runs_toc_before_planner_for_oversized_and_reuses_planner( pdf_path=str(tmp_path / "oversized.pdf"), job_id="job-toc-before-coarse", output_dir=str(tmp_path / "profile"), - settings={"toc_before_coarse_page_limit": 2}, + settings={"toc_before_coarse": True}, ) (tmp_path / "profile").mkdir() coordinator.blackboard.page_count = 3 @@ -449,15 +562,18 @@ def test_oversized_single_shard_plan_is_invalid() -> None: assert report.errors == ["shard 0 exceeds max_pages=200"] -def test_standard_pdf_profile_toc_flag_off_preserves_current_behavior( +def test_standard_pdf_profile_builds_page_toc_and_lightweight_anatomy_by_default( monkeypatch, tmp_path: Path, ) -> None: fake_instances: list[object] = [] + fake_anatomy = object() + init_settings: list[dict[str, object]] = [] class FakeCoordinator: - def __init__(self, **_kwargs) -> None: + def __init__(self, **kwargs) -> None: self.calls: list[str] = [] + init_settings.append(kwargs["settings"]) self.blackboard = SimpleNamespace( page_count=2, doc_stats={"page_count": 2}, @@ -467,6 +583,56 @@ def __init__(self, **_kwargs) -> None: ) fake_instances.append(self) + def run_coarse(self) -> DocumentProfile: + self.calls.append("run_coarse") + self.blackboard.toc_result = TocResult(method="none") + return DocumentProfile( + is_scanned=False, + category="Research Report", + routing_category=PdfRoutingCategory.GENERIC.value, + ) + + def run_lightweight_anatomy(self): + self.calls.append("run_lightweight_anatomy") + return fake_anatomy + + monkeypatch.setattr(doc_profiler, "ProfileCoordinator", FakeCoordinator) + monkeypatch.setattr(doc_profiler.settings, "MAX_PDF_PAGE_LIMIT", 200) + + profile = profile_document( + str(tmp_path / "standard.pdf"), + "standard.pdf", + job_id="job-page-toc", + output_dir=str(tmp_path), + ) + + assert profile.toc.has_toc is False + assert profile.toc.attempted is True + assert profile.anatomy is fake_anatomy + assert fake_instances[0].calls == ["run_coarse", "run_lightweight_anatomy"] + assert init_settings[0]["toc_profile_enabled"] is True + assert init_settings[0]["toc_before_coarse"] is True + + +def test_standard_pdf_page_toc_kill_switch_builds_no_toc_anatomy( + monkeypatch, + tmp_path: Path, +) -> None: + fake_anatomy = object() + init_settings: list[dict[str, object]] = [] + + class FakeCoordinator: + def __init__(self, **kwargs) -> None: + self.calls: list[str] = [] + init_settings.append(kwargs["settings"]) + self.blackboard = SimpleNamespace( + page_count=2, + doc_stats={"page_count": 2}, + global_signals={}, + toc_result=None, + toc_hierarchies=None, + ) + def run_coarse(self) -> DocumentProfile: self.calls.append("run_coarse") return DocumentProfile( @@ -476,30 +642,37 @@ def run_coarse(self) -> DocumentProfile: ) def run_toc(self) -> TocResult: - self.calls.append("run_toc") - raise AssertionError("run_toc should be flag-gated for standard PDFs") + raise AssertionError("kill switch should not call TOC profiling") def run_lightweight_anatomy(self): self.calls.append("run_lightweight_anatomy") - raise AssertionError("lightweight anatomy should be flag-gated") + self.blackboard.toc_result = TocResult( + method="none", + notes="TOC profiling disabled by PDF_PAGE_TOC_ENABLED", + ) + self.blackboard.global_signals["toc_profile_attempted"] = False + return fake_anatomy monkeypatch.setattr(doc_profiler, "ProfileCoordinator", FakeCoordinator) - monkeypatch.setattr(doc_profiler.settings, "PDF_PROFILE_TOC_ENABLED", False) + monkeypatch.setattr(doc_profiler.settings, "PDF_PAGE_TOC_ENABLED", False) monkeypatch.setattr(doc_profiler.settings, "MAX_PDF_PAGE_LIMIT", 200) profile = profile_document( str(tmp_path / "standard.pdf"), "standard.pdf", - job_id="job-flag-off", + job_id="job-page-toc-disabled", output_dir=str(tmp_path), ) + assert init_settings[0]["toc_profile_enabled"] is False + assert init_settings[0]["toc_before_coarse"] is False + assert profile.toc.attempted is False assert profile.toc.has_toc is False - assert profile.anatomy is None - assert fake_instances[0].calls == ["run_coarse"] + assert profile.toc.notes == "TOC profiling disabled by PDF_PAGE_TOC_ENABLED" + assert profile.anatomy is fake_anatomy -def test_standard_pdf_profile_toc_flag_on_builds_toc_and_lightweight_anatomy( +def test_standard_pdf_profile_maps_page_toc_evidence( monkeypatch, tmp_path: Path, ) -> None: @@ -555,13 +728,12 @@ def __init__(self, **kwargs) -> None: fake_instances.append(self) monkeypatch.setattr(doc_profiler, "ProfileCoordinator", CapturingCoordinator) - monkeypatch.setattr(doc_profiler.settings, "PDF_PROFILE_TOC_ENABLED", True) monkeypatch.setattr(doc_profiler.settings, "MAX_PDF_PAGE_LIMIT", 200) profile = profile_document( str(tmp_path / "standard.pdf"), "standard.pdf", - job_id="job-flag-on", + job_id="job-page-toc-evidence", output_dir=str(tmp_path), ) @@ -570,11 +742,215 @@ def __init__(self, **kwargs) -> None: "run_lightweight_anatomy", ] assert profile.toc.has_toc is True + assert profile.toc.attempted is True assert profile.toc.method == "vlm_batch" assert profile.toc.evidence[0].confidence == 0.95 assert profile.anatomy is fake_anatomy +def test_oversized_atlas_surfaces_profile_toc_without_structural_anatomy( + monkeypatch, + tmp_path: Path, +) -> None: + calls: list[str] = [] + + class FakeCoordinator: + def __init__(self, **_kwargs) -> None: + self.blackboard = SimpleNamespace( + page_count=250, + doc_stats={"page_count": 250}, + global_signals={}, + toc_result=None, + toc_hierarchies=None, + ) + + def run_coarse(self) -> DocumentProfile: + calls.append("run_coarse") + self.blackboard.toc_result = TocResult( + toc_pages=[4], + evidence=[ + AgentTocEvidence( + page_index=4, + source="vlm", + confidence=0.9, + reason="table of contents", + ) + ], + method="vlm_batch", + ) + self.blackboard.toc_hierarchies = [ + {"toc_range": [4, 4], "toc_range_unit": "page", "toc_tree": {}} + ] + return DocumentProfile( + is_scanned=False, + category="Engineering Atlas", + routing_category=PdfRoutingCategory.ATLAS.value, + ) + + def run_structural(self): + raise AssertionError("oversized atlas should not run structural anatomy") + + def run_lightweight_anatomy(self): + raise AssertionError("oversized atlas should not run lightweight anatomy") + + monkeypatch.setattr(doc_profiler, "ProfileCoordinator", FakeCoordinator) + monkeypatch.setattr(doc_profiler.settings, "MAX_PDF_PAGE_LIMIT", 200) + monkeypatch.setattr(doc_profiler.settings, "OVERSIZED_PDF_SHARD_ENABLED", True) + monkeypatch.setattr(doc_profiler.settings, "OVERSIZED_PDF_SOFT_LIMIT", 500) + + profile = profile_document( + str(tmp_path / "oversized-atlas.pdf"), + "oversized-atlas.pdf", + job_id="job-oversized-atlas", + output_dir=str(tmp_path), + ) + + assert calls == ["run_coarse"] + assert profile.is_atlas is True + assert profile.anatomy is None + assert profile.toc.has_toc is True + assert profile.toc.attempted is True + assert profile.toc.toc_pages == [4] + + +def _patch_atlas_page_pipeline( + monkeypatch, + page_texts: list[str], +) -> list[set[int]]: + render_skip_pages: list[set[int]] = [] + + def fake_run_in_child_process(func, *_args, **_kwargs): + if func is atlas_parser._atlas_extract_texts_worker: + return {"total_pages": len(page_texts), "page_texts": page_texts} + + if func is atlas_parser._atlas_render_pages_worker: + _pdf_path, img_dir, skip_pages_list, _dpi, page_texts_list = _args + skip_pages = set(skip_pages_list) + render_skip_pages.append(skip_pages) + page_data = [] + for page_idx, page_text in enumerate(page_texts_list): + page_num = page_idx + 1 + if page_num in skip_pages: + continue + img_name = f"page-{page_num}.jpg" + Path(img_dir, img_name).write_bytes(f"image-{page_num}".encode()) + page_data.append((page_num, page_text, img_name)) + return {"page_data": page_data} + + raise AssertionError(f"unexpected atlas worker: {func}") + + monkeypatch.setattr(atlas_parser, "run_in_child_process", fake_run_in_child_process) + monkeypatch.setattr( + atlas_parser, + "_vlm_extract_page_info", + lambda _output_dir, img_name: f"Drawing info for {img_name}", + ) + return render_skip_pages + + +def test_atlas_consumes_profile_toc_without_internal_detector( + monkeypatch, + tmp_path: Path, +) -> None: + render_skip_pages = _patch_atlas_page_pipeline( + monkeypatch, + ["Cover", "Contents", "Drawing page"], + ) + + profile = SimpleNamespace( + is_scanned=False, + toc=ParserTocProfile( + toc_pages=[2], + hierarchies=[ + {"toc_range": [2, 2], "toc_range_unit": "page", "toc_tree": {}} + ], + source="pdf_vlm", + method="vlm_batch", + attempted=True, + ), + ) + + df = atlas_parser.parse_atlas( + str(tmp_path / "atlas.pdf"), + str(tmp_path / "out"), + {"stopwords": [], "model_name": "test-model"}, + relative_root="kb/atlas.pdf", + profile=profile, + ) + + assert render_skip_pages == [{2}] + assert len(df) == 2 + assert not any("[page-2]" in content for content in df["content"]) + toc_json = tmp_path / "out" / "toc_hierarchies.json" + assert toc_json.exists() + + +def test_atlas_does_not_fallback_when_profile_toc_not_attempted( + monkeypatch, + tmp_path: Path, +) -> None: + render_skip_pages = _patch_atlas_page_pipeline( + monkeypatch, + ["Contents", "Drawing page"], + ) + profile = SimpleNamespace( + is_scanned=False, + toc=ParserTocProfile(toc_pages=[2], attempted=False), + ) + + df = atlas_parser.parse_atlas( + str(tmp_path / "atlas.pdf"), + str(tmp_path / "out"), + { + "stopwords": [], + "model_name": "content-model", + "hierarchy_model_name": "hierarchy-model", + }, + relative_root="kb/atlas.pdf", + profile=profile, + ) + + assert render_skip_pages == [set()] + assert len(df) == 2 + assert any("[page-1]" in content for content in df["content"]) + + +def test_pdf_standard_single_pass_skips_markdown_toc_detection( + monkeypatch, + tmp_path: Path, +) -> None: + output_dir = tmp_path / "out" + output_dir.mkdir() + calls: dict[str, object] = {} + + def fake_parse_via_full(_pdf_path, _filename, out_dir, s3_key=None): + calls["s3_key"] = s3_key + Path(out_dir, "full.md").write_text("Contents\nBody\n", encoding="utf-8") + + def fake_parse_md(*_args, **kwargs): + calls["skip_toc_detection"] = kwargs.get("skip_toc_detection") + return {"ok": True} + + monkeypatch.setattr(pdf_parser, "parse_via_full", fake_parse_via_full) + monkeypatch.setattr(pdf_parser, "parse_md", fake_parse_md) + + result = pdf_parser.parse_pdfs( + str(tmp_path / "standard.pdf"), + "standard.pdf", + str(output_dir), + {"smart_title_parse": False}, + profile=SimpleNamespace( + routing_category=PdfRoutingCategory.GENERIC, + anatomy=None, + page_count=2, + ), + s3_key="uploads/source.pdf", + ) + + assert result == {"ok": True} + assert calls == {"s3_key": "uploads/source.pdf", "skip_toc_detection": True} + + def test_pdf_shard_pipeline_accepts_single_shard_fast_path( monkeypatch, tmp_path: Path, @@ -652,13 +1028,12 @@ def fake_parse_md(*_args, **kwargs): assert result["lines"] == ["# 1. Introduction", "Body"] -def test_pdf_first_shard_reuses_markdown_toc_detector_when_profile_misses_toc( +def test_pdf_shard_pipeline_does_not_use_markdown_toc_detector( monkeypatch, tmp_path: Path, ) -> None: output_dir = tmp_path / "out" output_dir.mkdir() - detector_calls: list[list[str]] = [] heading_contexts: list[object] = [] def fake_parse_via_full(_pdf_path, _filename, out_dir, s3_key=None): @@ -667,18 +1042,8 @@ def fake_parse_via_full(_pdf_path, _filename, out_dir, s3_key=None): encoding="utf-8", ) - def fake_detect_tocs_in_texts(md_lines, **_kwargs): - detector_calls.append(list(md_lines)) - return ( - [ - { - "toc_range": [0, 1], - "toc_range_unit": "line", - "toc_tree": {"Introduction": {}}, - } - ], - ["1 Introduction", "Body"], - ) + def fail_detect_tocs_in_texts(*_args, **_kwargs): + raise AssertionError("PDF shard pipeline must not call markdown TOC detector") def fake_eval_md_headings(md_lines, *_args, **kwargs): heading_contexts.append(kwargs.get("toc_hierarchies")) @@ -687,8 +1052,15 @@ def fake_eval_md_headings(md_lines, *_args, **kwargs): active_markdown_parser = importlib.import_module( "app.services.document_parser.formats.markdown.parser" ) + active_toc_parser = importlib.import_module( + "app.services.document_parser.structure.toc_parser" + ) monkeypatch.setattr(pdf_parser, "parse_via_full", fake_parse_via_full) - monkeypatch.setattr(pdf_parser, "detect_tocs_in_texts", fake_detect_tocs_in_texts) + monkeypatch.setattr( + active_toc_parser, + "detect_tocs_in_texts", + fail_detect_tocs_in_texts, + ) monkeypatch.setattr( active_markdown_parser, "eval_md_headings", @@ -739,9 +1111,13 @@ def fake_eval_md_headings(md_lines, *_args, **kwargs): profile=profile, ) - assert len(detector_calls) == 1 - assert heading_contexts[0][0]["toc_range_unit"] == "line" - assert result["lines"] == ["# 1 Introduction", "Body"] + assert heading_contexts == [None] + assert result["lines"] == [ + "Contents", + "# 1 Introduction .... 2", + "# 1 Introduction", + "Body", + ] def test_page_based_toc_demotes_front_matter_only_on_first_shard() -> None: diff --git a/apps/worker/tests/contract/test_document_agent_budget_contract.py b/apps/worker/tests/contract/test_document_agent_budget_contract.py new file mode 100644 index 00000000..3fb37961 --- /dev/null +++ b/apps/worker/tests/contract/test_document_agent_budget_contract.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import os + +os.environ.setdefault("DATABASE_URL", "postgresql+asyncpg://test:test@localhost/test") +os.environ.setdefault("TMP_PATH", "/tmp/knowhere-test") +os.environ.setdefault("S3_BUCKET_NAME", "test-uploads") +os.environ.setdefault("S3_ACCESS_KEY_ID", "test") +os.environ.setdefault("S3_SECRET_ACCESS_KEY", "test") +os.environ.setdefault("S3_TEMP_PATH", "/tmp") + +from app.services.document_agent.budget import BudgetTracker, StageEnvelope +from app.services.page_memory import memory_service +from shared.core.exceptions.domain_exceptions import ValidationException +from shared.services.chunks.dataframe_chunk_converter import dataframe_to_chunks + +import pandas as pd +import pytest + + +def test_visual_stage_envelope_preserves_other_stage_guarantee() -> None: + budget = BudgetTracker( + plan_budget=100, + visual_budget=100, + visual_stage_envelopes={ + "toc_confirm": StageEnvelope(min_guarantee=30, cap=60), + "coarse_planner": StageEnvelope(min_guarantee=40, cap=70), + }, + ) + + assert budget.try_reserve("visual", 30, stage="toc_confirm") is True + budget.commit("visual", actual=25, est=30, stage="toc_confirm") + assert budget.try_reserve("visual", 36, stage="toc_confirm") is False + + snapshot = budget.snapshot() + assert snapshot["visual"]["used"] == 25 + assert snapshot["visual_stages"]["toc_confirm"]["used"] == 25 + + assert budget.try_reserve("visual", 40, stage="coarse_planner") is True + budget.refund("visual", est=40, stage="coarse_planner") + assert budget.snapshot()["visual_stages"]["coarse_planner"]["reserved"] == 0 + + +def test_visual_stage_cap_rejects_overage_while_legacy_calls_remain_supported() -> None: + budget = BudgetTracker( + plan_budget=100, + visual_budget=100, + visual_stage_envelopes={ + "toc_confirm": StageEnvelope(min_guarantee=0, cap=20), + }, + ) + + assert budget.try_reserve("visual", 21, stage="toc_confirm") is False + assert budget.try_reserve("visual", 90) is True + budget.commit("visual", actual=80, est=90) + + snapshot = budget.snapshot() + assert snapshot["visual"]["used"] == 80 + assert snapshot["visual_stages"]["toc_confirm"]["used"] == 0 + + +def test_dataframe_converter_accepts_page_chunks_with_extra_metadata() -> None: + df = pd.DataFrame( + [ + { + "content": "[SUMMARY]\nshort\n\n[RAW]\nbody", + "path": "demo.pdf/Root", + "type": "page", + "length": 26, + "keywords": "", + "summary": "short", + "know_id": "page-1", + "tokens": "", + "connectto": "", + "addtime": "2026-06-11 00:00:00", + "page_nums": "1,2", + "extra_metadata": { + "granularity": "whole_doc", + "page_image_uris": ["pages/page_page_1.png"], + "page_nums": [99], + }, + } + ] + ) + + chunks = dataframe_to_chunks(df) + + assert chunks[0]["type"] == "page" + assert chunks[0]["metadata"]["granularity"] == "whole_doc" + assert chunks[0]["metadata"]["page_nums"] == [1, 2] + + +def test_page_memory_unsupported_granularity_is_explicit() -> None: + with pytest.raises(ValidationException) as exc_info: + memory_service._raise_unsupported_granularity("page") # noqa: SLF001 + + error = exc_info.value + assert "whole-document page memory" in error.user_message + assert "PAGE_MEMORY_GRANULARITY_NOT_IMPLEMENTED" in error.internal_message + assert error.violations[0]["field"] == "parse_track" + assert "granularity=page" in error.violations[0]["description"] diff --git a/packages/shared-python/shared/core/config/ai.py b/packages/shared-python/shared/core/config/ai.py index 559e53f3..7b9246c7 100644 --- a/packages/shared-python/shared/core/config/ai.py +++ b/packages/shared-python/shared/core/config/ai.py @@ -78,6 +78,13 @@ class AIConfig(BaseModel): default=8, description="Max concurrent gevent greenlets for parallel post-heading summary LLM calls -- image/table/text (Dashscope).", ) + TOKEN_PRICING_TABLE_JSON: str = Field( + default="", + description=( + "JSON model pricing table for internal token cost estimates. " + "Rates are USD per 1M tokens, keyed by model name." + ), + ) # Compatibility fields retained during migration. ARK_API_KEY: str = Field( diff --git a/packages/shared-python/shared/core/config/storage.py b/packages/shared-python/shared/core/config/storage.py index 75046091..8af1d0d1 100644 --- a/packages/shared-python/shared/core/config/storage.py +++ b/packages/shared-python/shared/core/config/storage.py @@ -79,12 +79,19 @@ class StorageConfig(BaseModel): description="Soft page limit for oversized PDF shard pipeline. " "Documents exceeding this are rejected with a contact-support message.", ) - PDF_PROFILE_TOC_ENABLED: bool = Field( + PDF_PAGE_TOC_ENABLED: bool = Field( + default=True, + description=( + "Enable page-owned PDF TOC profiling during parser-entry DOC_PROFILE. " + "When disabled, PDF parsing treats documents as no-TOC and does not " + "fall back to Markdown TOC detection." + ), + ) + RETRIEVAL_PAGE_MEMORY_ENABLED: bool = Field( default=False, description=( - "Enable PDF TOC extraction during parser-entry DOC_PROFILE for " - "standard and atlas PDFs. Oversized PDFs still run TOC profiling as " - "part of the shard pipeline." + "Enable the experimental page_memory parse track. The API rejects " + "page_memory requests while this is false." ), ) MINERU_SHARD_CONCURRENCY: int = Field( diff --git a/packages/shared-python/shared/models/database/document.py b/packages/shared-python/shared/models/database/document.py index 41bed45b..dcf3a982 100644 --- a/packages/shared-python/shared/models/database/document.py +++ b/packages/shared-python/shared/models/database/document.py @@ -46,6 +46,7 @@ class Document(Base): nullable=True, ) source_file_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + parse_track: Mapped[str] = mapped_column(String(32), nullable=False, default="chunk") created_at: Mapped[datetime] = mapped_column( DateTime, default=utc_now_naive, nullable=False ) diff --git a/packages/shared-python/shared/models/database/document_page_plan.py b/packages/shared-python/shared/models/database/document_page_plan.py index ee2fb228..691ad5a8 100644 --- a/packages/shared-python/shared/models/database/document_page_plan.py +++ b/packages/shared-python/shared/models/database/document_page_plan.py @@ -21,6 +21,7 @@ class DocumentPagePlan(Base): ) page_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) shard_plan: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) + doc_profile: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) global_signals: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSON, nullable=True) created_at: Mapped[datetime] = mapped_column( DateTime, default=utc_now_naive, nullable=False diff --git a/packages/shared-python/shared/models/schemas/job.py b/packages/shared-python/shared/models/schemas/job.py index a2b68655..e4ad54be 100644 --- a/packages/shared-python/shared/models/schemas/job.py +++ b/packages/shared-python/shared/models/schemas/job.py @@ -58,6 +58,13 @@ class JobCreate(BaseModel): description="File name; required when source_type=file and must include the extension", ) data_id: Optional[str] = Field(None, max_length=128, description="User-defined ID") + parse_track: Literal["chunk", "page_memory"] = Field( + "chunk", + description=( + "Parser track. page_memory is feature-gated and currently only " + "supported for PDF/PPT/PPTX." + ), + ) parsing_params: Optional[ParsingParams] = Field( None, description="Parsing parameters" ) diff --git a/packages/shared-python/shared/models/schemas/job_metadata.py b/packages/shared-python/shared/models/schemas/job_metadata.py index 71aa4d2d..109845fa 100644 --- a/packages/shared-python/shared/models/schemas/job_metadata.py +++ b/packages/shared-python/shared/models/schemas/job_metadata.py @@ -46,6 +46,7 @@ def create_from_request(request, **kwargs) -> Dict[str, Any]: "original_request": request.model_dump(), "namespace": namespace, "document_id": request.document_id, + "parse_track": request.parse_track, "parsing_params": ( request.parsing_params.model_dump() if request.parsing_params else None ), @@ -126,6 +127,11 @@ def get_document_id(metadata: Optional[Dict[str, Any]]) -> str | None: """Return the retrieval document id stored in metadata.""" return JobMetadataHelper.get_string_field(metadata, "document_id") + @staticmethod + def get_parse_track(metadata: Optional[Dict[str, Any]]) -> str: + """Return the parser track stored in metadata.""" + return JobMetadataHelper.get_string_field(metadata, "parse_track", "chunk") or "chunk" + @staticmethod def get_data_id(metadata: Optional[Dict[str, Any]]) -> str | None: """Return the user-defined data id stored in metadata.""" diff --git a/packages/shared-python/shared/services/ai/openai_compatible_client_sync.py b/packages/shared-python/shared/services/ai/openai_compatible_client_sync.py index 9ce586eb..9e3d9667 100644 --- a/packages/shared-python/shared/services/ai/openai_compatible_client_sync.py +++ b/packages/shared-python/shared/services/ai/openai_compatible_client_sync.py @@ -184,6 +184,7 @@ def _make_ali_pool_raw_call( temperature: float, max_tokens: int, api_kwargs: Dict[str, Any], + usage_task: str | None = None, ) -> tuple[Any, LLMUsage]: """Acquire a token, make the call, and retry inline on 429.""" from shared.services.ai.ali_quota_manager import get_ali_quota_manager @@ -215,7 +216,7 @@ def _make_ali_pool_raw_call( provider=self.default_model, ) usage = _extract_usage(response) - record_tokens(usage) + record_tokens(usage, model=model, task=usage_task) return response, usage except openai.RateLimitError as exc: retry_after = _parse_retry_after(exc) @@ -262,6 +263,7 @@ def _make_ali_pool_call( temperature: float, max_tokens: int, api_kwargs: Dict[str, Any], + usage_task: str | None = None, ) -> tuple[str, LLMUsage]: response, usage = self._make_ali_pool_raw_call( model=model, @@ -269,6 +271,7 @@ def _make_ali_pool_call( temperature=temperature, max_tokens=max_tokens, api_kwargs=api_kwargs, + usage_task=usage_task, ) return response.choices[0].message.content or "", usage @@ -290,6 +293,7 @@ def chat_completion_raw_with_usage( else: all_messages = [{"role": "user", "content": str(messages)}] + usage_task = str(kwargs.pop("usage_task", "") or "") or None api_kwargs: Dict[str, Any] = {} if top_p is not None: api_kwargs["top_p"] = top_p @@ -326,6 +330,7 @@ def chat_completion_raw_with_usage( temperature=temperature, max_tokens=max_tokens, api_kwargs=api_kwargs, + usage_task=usage_task, ) client = self._client @@ -348,7 +353,7 @@ def chat_completion_raw_with_usage( provider=self.default_model, ) usage = _extract_usage(response) - record_tokens(usage) + record_tokens(usage, model=effective_model, task=usage_task) return response, usage except LLMServiceException: raise @@ -381,6 +386,7 @@ def chat_completion_with_usage( else: all_messages = [{"role": "user", "content": str(messages)}] + usage_task = str(kwargs.pop("usage_task", "") or "") or None api_kwargs: Dict[str, Any] = {} if top_p is not None: api_kwargs["top_p"] = top_p @@ -423,6 +429,7 @@ def chat_completion_with_usage( temperature=temperature, max_tokens=max_tokens, api_kwargs=api_kwargs, + usage_task=usage_task, ) except LLMServiceException: raise @@ -460,7 +467,7 @@ def chat_completion_with_usage( content = choices[0].message.content or "" usage = _extract_usage(response) - record_tokens(usage) + record_tokens(usage, model=effective_model, task=usage_task) return content, usage except LLMServiceException: raise diff --git a/packages/shared-python/shared/services/ai/token_costing.py b/packages/shared-python/shared/services/ai/token_costing.py new file mode 100644 index 00000000..36e7e136 --- /dev/null +++ b/packages/shared-python/shared/services/ai/token_costing.py @@ -0,0 +1,251 @@ +"""Token cost estimation for AI model usage snapshots.""" + +from __future__ import annotations + +import json +import hashlib +from copy import deepcopy +from datetime import date +from typing import Any + +from loguru import logger + +from shared.core.config import settings + +DEFAULT_TOKEN_PRICING_TABLE: dict[str, dict[str, Any]] = { + "deepseek-v4-flash": { + "currency": "USD", + "unit": "per_1m_tokens", + "input_per_1m": 0.14, + "cached_input_per_1m": 0.0028, + "output_per_1m": 0.28, + "source": "DeepSeek official pricing", + "effective_date": "2026-06-11", + }, + "deepseek-chat": { + "alias_of": "deepseek-v4-flash", + }, + "deepseek-reasoner": { + "alias_of": "deepseek-v4-flash", + }, + "qwen3.6-flash": { + "currency": "USD", + "unit": "per_1m_tokens", + "input_per_1m": 0.25, + "output_per_1m": 1.50, + "source": "Qwen Cloud official pricing, <=256K input tier", + "effective_date": "2026-06-11", + }, +} + + +def build_token_cost_estimate(token_usage: dict[str, Any] | None) -> dict[str, Any]: + """Build a cost estimate from token usage grouped by model and task.""" + if not isinstance(token_usage, dict): + return {} + + pricing_table = load_token_pricing_table() + by_model_usage = token_usage.get("by_model") + if not isinstance(by_model_usage, dict): + by_model_usage = {} + + by_model_cost: dict[str, Any] = {} + missing_models: list[str] = [] + total_input_cost = 0.0 + total_output_cost = 0.0 + + for model_name, usage in by_model_usage.items(): + if not isinstance(usage, dict): + continue + model_cost = _estimate_usage_cost( + model_name=str(model_name), + usage=usage, + pricing_table=pricing_table, + ) + by_model_cost[str(model_name)] = model_cost + if model_cost.get("pricing_missing"): + missing_models.append(str(model_name)) + continue + total_input_cost += float(model_cost["input_cost"]) + total_output_cost += float(model_cost["output_cost"]) + + by_task_cost = _estimate_task_costs( + token_usage.get("by_task"), + pricing_table=pricing_table, + ) + + total_cost = total_input_cost + total_output_cost + return { + "currency": "USD", + "pricing_version": _pricing_version(pricing_table), + "total_input_cost": _round_cost(total_input_cost), + "total_output_cost": _round_cost(total_output_cost), + "total_cost": _round_cost(total_cost), + "by_model": by_model_cost, + "by_task": by_task_cost, + "missing_pricing_models": sorted(set(missing_models)), + } + + +def load_token_pricing_table() -> dict[str, dict[str, Any]]: + pricing_table = deepcopy(DEFAULT_TOKEN_PRICING_TABLE) + override_raw = getattr(settings, "TOKEN_PRICING_TABLE_JSON", "") + if not override_raw: + return pricing_table + + try: + override = json.loads(override_raw) + except json.JSONDecodeError as exc: + logger.warning("Invalid TOKEN_PRICING_TABLE_JSON, using defaults: {}", exc) + return pricing_table + + if not isinstance(override, dict): + logger.warning("TOKEN_PRICING_TABLE_JSON must be a JSON object, using defaults") + return pricing_table + + for model_name, entry in override.items(): + if isinstance(entry, dict): + pricing_table[str(model_name)] = dict(entry) + return pricing_table + + +def _estimate_task_costs( + by_task_usage: Any, + *, + pricing_table: dict[str, dict[str, Any]], +) -> dict[str, Any]: + if not isinstance(by_task_usage, dict): + return {} + + by_task_cost: dict[str, Any] = {} + for task_name, task_usage in by_task_usage.items(): + if not isinstance(task_usage, dict): + continue + models = task_usage.get("models") + if not isinstance(models, dict): + by_task_cost[str(task_name)] = _estimate_usage_cost( + model_name="unknown", + usage=task_usage, + pricing_table=pricing_table, + ) + continue + + task_input_cost = 0.0 + task_output_cost = 0.0 + model_costs: dict[str, Any] = {} + missing_models: list[str] = [] + for model_name, model_usage in models.items(): + if not isinstance(model_usage, dict): + continue + model_cost = _estimate_usage_cost( + model_name=str(model_name), + usage=model_usage, + pricing_table=pricing_table, + ) + model_costs[str(model_name)] = model_cost + if model_cost.get("pricing_missing"): + missing_models.append(str(model_name)) + continue + task_input_cost += float(model_cost["input_cost"]) + task_output_cost += float(model_cost["output_cost"]) + + by_task_cost[str(task_name)] = { + "input_cost": _round_cost(task_input_cost), + "output_cost": _round_cost(task_output_cost), + "total_cost": _round_cost(task_input_cost + task_output_cost), + "models": model_costs, + "missing_pricing_models": sorted(set(missing_models)), + } + return by_task_cost + + +def _estimate_usage_cost( + *, + model_name: str, + usage: dict[str, Any], + pricing_table: dict[str, dict[str, Any]], +) -> dict[str, Any]: + pricing = _resolve_pricing(model_name, pricing_table) + input_tokens = int(usage.get("prompt_tokens") or 0) + output_tokens = int(usage.get("completion_tokens") or 0) + calls = int(usage.get("calls") or 0) + + if pricing is None: + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "calls": calls, + "input_cost": 0.0, + "output_cost": 0.0, + "total_cost": 0.0, + "pricing_missing": True, + } + + input_rate = float(pricing.get("input_per_1m") or 0) + output_rate = float(pricing.get("output_per_1m") or 0) + input_cost = input_tokens / 1_000_000 * input_rate + output_cost = output_tokens / 1_000_000 * output_rate + return { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "calls": calls, + "input_rate_per_1m": input_rate, + "output_rate_per_1m": output_rate, + "input_cost": _round_cost(input_cost), + "output_cost": _round_cost(output_cost), + "total_cost": _round_cost(input_cost + output_cost), + "pricing": _public_pricing_entry(pricing), + } + + +def _resolve_pricing( + model_name: str, + pricing_table: dict[str, dict[str, Any]], +) -> dict[str, Any] | None: + seen: set[str] = set() + current = model_name + while current and current not in seen: + seen.add(current) + entry = pricing_table.get(current) + if not isinstance(entry, dict): + return None + alias = entry.get("alias_of") + if alias: + current = str(alias) + continue + return entry + return None + + +def _public_pricing_entry(pricing: dict[str, Any]) -> dict[str, Any]: + return { + key: pricing[key] + for key in ( + "currency", + "unit", + "input_per_1m", + "output_per_1m", + "source", + "effective_date", + ) + if key in pricing + } + + +def _pricing_version(pricing_table: dict[str, dict[str, Any]]) -> str: + version = { + str(model): { + key: entry.get(key) + for key in ("input_per_1m", "output_per_1m", "effective_date", "alias_of") + if key in entry + } + for model, entry in pricing_table.items() + if isinstance(entry, dict) + } + encoded = json.dumps(version, sort_keys=True, ensure_ascii=False) + digest = hashlib.sha256(encoded.encode("utf-8")).hexdigest()[:12] + return f"{date.today().isoformat()}:{digest}" + + +def _round_cost(value: float) -> float: + return round(float(value), 8) diff --git a/packages/shared-python/shared/services/ai/token_tracking.py b/packages/shared-python/shared/services/ai/token_tracking.py index 63b84dda..99da9108 100644 --- a/packages/shared-python/shared/services/ai/token_tracking.py +++ b/packages/shared-python/shared/services/ai/token_tracking.py @@ -14,8 +14,9 @@ from __future__ import annotations import threading +from typing import Any -_trackers: dict[int, dict[str, int]] = {} +_trackers: dict[int, dict[str, Any]] = {} _lock = threading.Lock() # The root greenlet id for the current parse task. Stored so that child @@ -58,7 +59,7 @@ def _find_root_id() -> int | None: return None -def init_token_tracker() -> dict[str, int]: +def init_token_tracker() -> dict[str, Any]: """Create a new token accumulator for the current parse task. Must be called from the root greenlet of the task (i.e. from @@ -66,16 +67,27 @@ def init_token_tracker() -> dict[str, int]: accumulate all token usage for the lifetime of this task. """ gid = _current_greenlet_id() - tracker: dict[str, int] = { + tracker: dict[str, Any] = { "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, + "calls": 0, + "by_model": {}, + "by_task": {}, } with _lock: _trackers[gid] = tracker return tracker +def get_current_token_tracker() -> dict[str, Any] | None: + """Return the active token accumulator for this task, if any.""" + root = _find_root_id() + if root is None: + return None + return _trackers.get(root) + + def cleanup_token_tracker() -> None: """Remove the tracker for the current greenlet. Call after parsing.""" gid = _current_greenlet_id() @@ -87,7 +99,24 @@ def cleanup_token_tracker() -> None: del _root_ids[k] -def record_tokens(usage: dict[str, int]) -> None: +def _normalize_bucket_key(value: str | None, fallback: str) -> str: + normalized = str(value or "").strip() + return normalized or fallback + + +def _add_usage(bucket: dict[str, Any], usage: dict[str, int]) -> None: + bucket["prompt_tokens"] = bucket.get("prompt_tokens", 0) + usage.get("prompt_tokens", 0) + bucket["completion_tokens"] = bucket.get("completion_tokens", 0) + usage.get("completion_tokens", 0) + bucket["total_tokens"] = bucket.get("total_tokens", 0) + usage.get("total_tokens", 0) + bucket["calls"] = bucket.get("calls", 0) + 1 + + +def record_tokens( + usage: dict[str, int], + *, + model: str | None = None, + task: str | None = None, +) -> None: """Accumulate token usage into the current task's tracker. Safe to call from any greenlet (root or child). If no tracker is @@ -100,7 +129,22 @@ def record_tokens(usage: dict[str, int]) -> None: tracker = _trackers.get(root) if tracker is None: return + model_key = _normalize_bucket_key(model, "unknown") + task_key = _normalize_bucket_key(task, "unknown") with _lock: - tracker["prompt_tokens"] += usage.get("prompt_tokens", 0) - tracker["completion_tokens"] += usage.get("completion_tokens", 0) - tracker["total_tokens"] += usage.get("total_tokens", 0) + _add_usage(tracker, usage) + by_model = tracker.setdefault("by_model", {}) + if isinstance(by_model, dict): + model_bucket = by_model.setdefault(model_key, {}) + if isinstance(model_bucket, dict): + _add_usage(model_bucket, usage) + by_task = tracker.setdefault("by_task", {}) + if isinstance(by_task, dict): + task_bucket = by_task.setdefault(task_key, {}) + if isinstance(task_bucket, dict): + _add_usage(task_bucket, usage) + task_models = task_bucket.setdefault("models", {}) + if isinstance(task_models, dict): + task_model_bucket = task_models.setdefault(model_key, {}) + if isinstance(task_model_bucket, dict): + _add_usage(task_model_bucket, usage) diff --git a/packages/shared-python/shared/services/chunks/dataframe_chunk_converter.py b/packages/shared-python/shared/services/chunks/dataframe_chunk_converter.py index c57301e7..e837b85d 100644 --- a/packages/shared-python/shared/services/chunks/dataframe_chunk_converter.py +++ b/packages/shared-python/shared/services/chunks/dataframe_chunk_converter.py @@ -38,7 +38,7 @@ def iterrows(self) -> Iterable[tuple[object, _ParserRow]]: ... list["JsonValue"], dict[str, "JsonValue"], ] -ChunkType: TypeAlias = Literal["text", "image", "table"] +ChunkType: TypeAlias = Literal["text", "image", "table", "page"] class ChunkMetadata(TypedDict, total=False): @@ -202,9 +202,50 @@ def _get_chunk_type(value: object) -> ChunkType: return "image" if normalized_type == "table": return "table" + if normalized_type == "page": + return "page" return "text" +def _parse_extra_metadata(value: object) -> dict[str, JsonValue]: + if not value or _is_missing(value): + return {} + if isinstance(value, dict): + return cast(dict[str, JsonValue], value) + if isinstance(value, str): + try: + parsed = json.loads(value) + except json.JSONDecodeError: + return {} + if isinstance(parsed, dict): + return cast(dict[str, JsonValue], parsed) + return {} + + +_RESERVED_METADATA_KEYS = { + "keywords", + "summary", + "length", + "tokens", + "connect_to", + "_relationship_refs", + "page_nums", + "file_path", + "original_name", +} + + +def _merge_extra_metadata( + metadata: ChunkMetadata, + extra_metadata: dict[str, JsonValue], +) -> None: + for key, value in extra_metadata.items(): + if key in _RESERVED_METADATA_KEYS: + logger.warning("Ignoring extra_metadata reserved key: {}", key) + continue + metadata[key] = value + + def _get_relationship_refs(type_value: object, content: str) -> list[RelationshipRef]: return parse_relationship_refs(type_value, content) @@ -245,6 +286,7 @@ def dataframe_to_chunks(df: _ParserDataFrame | None) -> list[Dict[str, JsonValue "_relationship_refs": relationship_refs, "page_nums": _parse_page_numbers(row.get("page_nums", "")), } + _merge_extra_metadata(metadata, _parse_extra_metadata(row.get("extra_metadata"))) if chunk_type == "image": embedded_image_path = _find_embedded_resource_path( diff --git a/packages/shared-python/shared/services/retrieval/publication_service.py b/packages/shared-python/shared/services/retrieval/publication_service.py index aa252ef8..2e180cbc 100644 --- a/packages/shared-python/shared/services/retrieval/publication_service.py +++ b/packages/shared-python/shared/services/retrieval/publication_service.py @@ -96,6 +96,7 @@ def _publish_document_state_for_job( job_metadata = job.job_metadata or {} namespace = normalize_retrieval_namespace(job_metadata.get("namespace")) document_id = job_metadata.get("document_id") + parse_track = str(job_metadata.get("parse_track") or "chunk") source_file_name = job_metadata.get("source_file_name") or job_metadata.get( "file_name" ) @@ -121,6 +122,7 @@ def _publish_document_state_for_job( job_result_id=job_result_id, document_id=str(document_id) if document_id else None, namespace=namespace, + parse_track=parse_track, source_file_name=str(source_file_name) if source_file_name else None, ) if document is None: @@ -160,6 +162,7 @@ def _upsert_document_revision( job_result_id: str, document_id: str | None, namespace: str, + parse_track: str, source_file_name: str | None, ) -> Document | None: document = None @@ -181,6 +184,7 @@ def _upsert_document_revision( status="active", current_job_result_id=job_result_id, source_file_name=source_file_name, + parse_track=parse_track, ) db.add(document) else: @@ -198,6 +202,7 @@ def _upsert_document_revision( document.archived_at = None document.current_job_result_id = job_result_id document.source_file_name = source_file_name or document.source_file_name + document.parse_track = parse_track or document.parse_track document.updated_at = utc_now_naive() db.flush() diff --git a/packages/shared-python/shared/services/storage/zip_manifest_schema.py b/packages/shared-python/shared/services/storage/zip_manifest_schema.py index ce2570a4..4190fe4e 100644 --- a/packages/shared-python/shared/services/storage/zip_manifest_schema.py +++ b/packages/shared-python/shared/services/storage/zip_manifest_schema.py @@ -4,6 +4,7 @@ from typing import Any +from shared.services.ai.token_costing import build_token_cost_estimate from shared.utils.utc_now import utc_now_naive @@ -18,6 +19,8 @@ def generate_manifest( job_metadata: dict[str, Any], hierarchy: dict[str, Any] | None = None, ) -> dict[str, Any]: + stages = job_metadata.get("stages", {}) + token_usage = stages.get("token_usage") if isinstance(stages, dict) else {} return { "version": "2.0", "job_id": job_id, @@ -31,12 +34,13 @@ def generate_manifest( "micro_dollars": job_metadata.get("billing_amount_micro_dollars"), "credits": job_metadata.get("billing_credits"), }, + "cost_estimate": build_token_cost_estimate(token_usage), "timing": { "started_at": job_metadata.get("processing_started_at"), "completed_at": job_metadata.get("processing_completed_at"), "duration_ms": job_metadata.get("processing_duration_ms"), }, - "stages": job_metadata.get("stages", {}), + "stages": stages, }, "statistics": statistics, "HIERARCHY": hierarchy or {},