From 13ed846ee4e60bb5cfbf3374d6ae19f0c7afe131 Mon Sep 17 00:00:00 2001 From: MakiforDevelop Date: Sun, 31 May 2026 14:01:10 +0800 Subject: [PATCH] =?UTF-8?q?chore(types):=207=20=E5=80=8B=E6=A0=B8=E5=BF=83?= =?UTF-8?q?=E6=AA=94=E8=A3=9C=20type=20hints=EF=BC=88=E5=BE=9E=E5=B9=BE?= =?UTF-8?q?=E4=B9=8E=E7=84=A1=E5=9E=8B=E5=88=A5=20=E2=86=92=20=E5=85=A8?= =?UTF-8?q?=E7=B0=BD=E5=90=8D=E6=A8=99=E6=B3=A8=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit config/ingest/embed/enrich/score/search/serve 全部函式簽名補上參數+回傳 型別,加 from __future__ import annotations,精準型別優先於 Any (sqlite3.Connection / Path / list[dict] 等)。純標注、零行為改變。 驗證: ruff clean + pytest 64 passed(不變)。 非嚴格 mypy 抽驗浮現 8 個既有 | None 流問題(serve.py global-None 模式 5 個 + enrich.py 回傳/索引可能 None 3 個)— 這些是 hints 揭露的既有 type 鬆散, 非本次引入;列 follow-up(mypy 全綠是獨立一筆債)。 Constraint: 純加法 annotations,不改 logic/control flow,不重構 Directive: 先補 hints 讓型別可見;mypy strict-clean 另開 task(會改 None 守衛=行為) Rejected: Grok Build 原派此任務 | headless 兩次 exit=1(mem0:8791 MCP 噪音燒 turn + 找不存在的 pyproject + max_turns 爆),circuit breaker 換 Codex Not-tested: mypy strict(刻意,本次只補 hints) Co-Authored-By: Claude Opus 4.8 (1M context) --- config.py | 30 ++++++++++++++++-------------- embed.py | 25 ++++++++++++++++--------- enrich.py | 54 +++++++++++++++++++++++++++++++++++------------------- ingest.py | 20 +++++++++++--------- score.py | 38 ++++++++++++++++++++++---------------- search.py | 47 +++++++++++++++++++++++++++++++++-------------- serve.py | 40 ++++++++++++++++++++++++---------------- 7 files changed, 157 insertions(+), 97 deletions(-) diff --git a/config.py b/config.py index 8ee9d1c..c897f55 100644 --- a/config.py +++ b/config.py @@ -4,12 +4,15 @@ No external dependencies. """ +from __future__ import annotations + import os +import sqlite3 from pathlib import Path # ── Load .env if present ── -_ENV_PATH = Path(__file__).parent / ".env" +_ENV_PATH: Path = Path(__file__).parent / ".env" if _ENV_PATH.exists(): with open(_ENV_PATH) as f: for line in f: @@ -20,35 +23,34 @@ # ── Database ── -DB_PATH = Path(__file__).parent / "knowledge.db" +DB_PATH: Path = Path(__file__).parent / "knowledge.db" # ── LLM Backend (OpenAI-compatible) ── -LLM_BASE_URL = os.environ.get( +LLM_BASE_URL: str = os.environ.get( "LLM_BASE_URL", "http://localhost:11434/v1/chat/completions" ) -LLM_MODEL = os.environ.get("LLM_MODEL", "qwen2.5:7b") -LLM_API_KEY = os.environ.get("LLM_API_KEY", "") -LLM_TIMEOUT = int(os.environ.get("LLM_TIMEOUT", "120")) +LLM_MODEL: str = os.environ.get("LLM_MODEL", "qwen2.5:7b") +LLM_API_KEY: str = os.environ.get("LLM_API_KEY", "") +LLM_TIMEOUT: int = int(os.environ.get("LLM_TIMEOUT", "120")) # ── Embedding ── -EMBED_MODEL = "BAAI/bge-m3" -EMBED_DIM = 1024 -EMBED_REMOTE_URL = os.environ.get("EMBED_REMOTE_URL", "") +EMBED_MODEL: str = "BAAI/bge-m3" +EMBED_DIM: int = 1024 +EMBED_REMOTE_URL: str = os.environ.get("EMBED_REMOTE_URL", "") # ── Server ── -SERVE_PORT = int(os.environ.get("SERVE_PORT", "8780")) +SERVE_PORT: int = int(os.environ.get("SERVE_PORT", "8780")) # ── Scoring ── -SCORING_PROMPT_VERSION = "v1.0" +SCORING_PROMPT_VERSION: str = "v1.0" -def get_db_connection(): +def get_db_connection() -> sqlite3.Connection: """Return a SQLite connection with row_factory and WAL mode.""" - import sqlite3 conn = sqlite3.connect(str(DB_PATH), timeout=10) conn.row_factory = sqlite3.Row conn.execute("PRAGMA journal_mode=WAL") @@ -56,7 +58,7 @@ def get_db_connection(): return conn -def init_db(): +def init_db() -> None: """Create tables if they don't exist.""" schema_path = Path(__file__).parent / "schema.sql" conn = get_db_connection() diff --git a/embed.py b/embed.py index 76ee6b6..5f1279b 100644 --- a/embed.py +++ b/embed.py @@ -12,16 +12,23 @@ python3 embed.py --remote URL # Use a remote embedding server """ +from __future__ import annotations + import argparse import json +import sqlite3 import sys import time from datetime import datetime, timezone +from typing import Any from config import EMBED_DIM, EMBED_MODEL, EMBED_REMOTE_URL, get_db_connection, init_db -def get_embedding_text(row) -> str: +EmbeddingResult = dict[str, list[float] | dict[str, float]] + + +def get_embedding_text(row: sqlite3.Row) -> str: """Build the text to embed from item fields.""" parts = [] if row["core_insight"]: @@ -35,10 +42,10 @@ def get_embedding_text(row) -> str: return " ".join(parts) -_local_model = None +_local_model: Any | None = None -def _get_local_model(): +def _get_local_model() -> Any: """Lazy-load and cache the embedding model (avoid reloading on every call).""" global _local_model if _local_model is None: @@ -47,7 +54,7 @@ def _get_local_model(): return _local_model -def embed_local(texts: list[str]) -> list[dict]: +def embed_local(texts: list[str]) -> list[EmbeddingResult]: """Embed texts using local bge-m3 model. Returns list of {dense, sparse}.""" model = _get_local_model() output = model.encode( @@ -56,7 +63,7 @@ def embed_local(texts: list[str]) -> list[dict]: return_sparse=True, return_colbert_vecs=False, ) - results = [] + results: list[EmbeddingResult] = [] for i in range(len(texts)): dense = output["dense_vecs"][i].tolist() sparse = {str(k): float(v) for k, v in output["lexical_weights"][i].items()} @@ -64,7 +71,7 @@ def embed_local(texts: list[str]) -> list[dict]: return results -def embed_remote(texts: list[str], remote_url: str) -> list[dict]: +def embed_remote(texts: list[str], remote_url: str) -> list[EmbeddingResult]: """Embed texts via a remote HTTP embedding server.""" from urllib.request import Request, urlopen @@ -72,9 +79,9 @@ def embed_remote(texts: list[str], remote_url: str) -> list[dict]: req = Request(remote_url, data=body, method="POST", headers={"Content-Type": "application/json"}) with urlopen(req, timeout=120) as resp: - data = json.loads(resp.read()) + data: dict[str, Any] = json.loads(resp.read()) - results = [] + results: list[EmbeddingResult] = [] for i in range(len(texts)): dense = data["dense"][i] if len(dense) != EMBED_DIM: @@ -88,7 +95,7 @@ def embed_remote(texts: list[str], remote_url: str) -> list[dict]: return results -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Generate embeddings for scored items") parser.add_argument("--rebuild", action="store_true", help="Clear and re-embed all items") parser.add_argument("--remote", type=str, default=EMBED_REMOTE_URL, help="Remote embedding server URL") diff --git a/enrich.py b/enrich.py index 2ef1a16..5b80e99 100644 --- a/enrich.py +++ b/enrich.py @@ -9,15 +9,20 @@ python3 enrich.py --limit 10 # Process up to 10 items """ +from __future__ import annotations + import argparse import html.parser import ipaddress import json import logging import re +import sqlite3 import socket import time from datetime import datetime, timezone +from http.client import HTTPMessage +from typing import Any from urllib.parse import urlparse from urllib.request import HTTPRedirectHandler, Request, build_opener, urlopen @@ -30,17 +35,20 @@ init_db, ) -USER_AGENT = ( +FetchResult = dict[str, str | None] +LLMResult = dict[str, Any] + +USER_AGENT: str = ( "Mozilla/5.0 (compatible; knowledge-pipeline/1.0; " "+https://github.com/makifordevelop/knowledge-pipeline)" ) -SKIP_DOMAINS = {"apps.apple.com", "drive.google.com", "play.google.com"} +SKIP_DOMAINS: set[str] = {"apps.apple.com", "drive.google.com", "play.google.com"} -MAX_CONTENT_BYTES = 5 * 1024 * 1024 # 5MB limit to prevent OOM +MAX_CONTENT_BYTES: int = 5 * 1024 * 1024 # 5MB limit to prevent OOM # Hostnames to always block (SSRF protection) -_BLOCKED_HOSTNAMES = {"localhost", "metadata.google.internal"} +_BLOCKED_HOSTNAMES: set[str] = {"localhost", "metadata.google.internal"} def _is_private_ip(ip_str: str) -> bool: @@ -75,7 +83,15 @@ def _is_private_url(url: str) -> bool: class _SSRFSafeRedirectHandler(HTTPRedirectHandler): """Validate redirect targets against SSRF blocklist.""" - def redirect_request(self, req, fp, code, msg, headers, newurl): + def redirect_request( + self, + req: Request, + fp: Any, + code: int, + msg: str, + headers: HTTPMessage, + newurl: str, + ) -> Request | None: if _is_private_url(newurl): raise ValueError(f"Redirect to private/internal URL blocked: {newurl}") return super().redirect_request(req, fp, code, msg, headers, newurl) @@ -87,28 +103,28 @@ def redirect_request(self, req, fp, code, msg, headers, newurl): # ── HTML text extraction (zero dependencies) ── class _HTMLTextExtractor(html.parser.HTMLParser): - SKIP_TAGS = {"script", "style", "nav", "footer", "header", "aside", "noscript"} + SKIP_TAGS: set[str] = {"script", "style", "nav", "footer", "header", "aside", "noscript"} - def __init__(self): + def __init__(self) -> None: super().__init__() - self.result = [] - self._skip = 0 + self.result: list[str] = [] + self._skip: int = 0 - def handle_starttag(self, tag, attrs): + def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None: if tag in self.SKIP_TAGS: self._skip += 1 - def handle_endtag(self, tag): + def handle_endtag(self, tag: str) -> None: if tag in self.SKIP_TAGS and self._skip > 0: self._skip -= 1 - def handle_data(self, data): + def handle_data(self, data: str) -> None: if self._skip == 0: text = data.strip() if text: self.result.append(text) - def get_text(self): + def get_text(self) -> str: return "\n".join(self.result) @@ -128,7 +144,7 @@ def extract_title_from_html(html_content: str) -> str | None: # ── URL fetching ── -def fetch_url(url: str, timeout: int = 30) -> dict: +def fetch_url(url: str, timeout: int = 30) -> FetchResult: """Fetch a URL and return {html, title, text, status}.""" if _is_private_url(url): return {"status": "skipped", "reason": "blocked: private/internal URL"} @@ -159,7 +175,7 @@ def fetch_url(url: str, timeout: int = 30) -> dict: # ── LLM enrichment ── -_ENRICH_PROMPT = """Analyze this web content and provide a structured summary. +_ENRICH_PROMPT: str = """Analyze this web content and provide a structured summary. Title: {title} URL: {url} @@ -174,7 +190,7 @@ def fetch_url(url: str, timeout: int = 30) -> dict: }}""" -def call_llm(prompt: str) -> dict | None: +def call_llm(prompt: str) -> LLMResult | None: """Call an OpenAI-compatible LLM API. Returns parsed JSON or None.""" body = { "model": LLM_MODEL, @@ -203,7 +219,7 @@ def call_llm(prompt: str) -> dict | None: return None -def enrich_item(item_id: int, url: str, domain: str, conn) -> str: +def enrich_item(item_id: int, url: str, domain: str, conn: sqlite3.Connection) -> str: """Enrich a single item. Returns status string.""" if domain in SKIP_DOMAINS: conn.execute( @@ -249,7 +265,7 @@ def enrich_item(item_id: int, url: str, domain: str, conn) -> str: return "fetched" -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Enrich pending items with full text and LLM summaries") parser.add_argument("--limit", type=int, default=0, help="Max items to process (0 = all)") args = parser.parse_args() @@ -261,7 +277,7 @@ def main(): "SELECT id, url, domain FROM items " "WHERE fetch_status = 'pending' ORDER BY added_at" ) - query_params = [] + query_params: list[int] = [] if args.limit > 0: query += " LIMIT ?" query_params.append(args.limit) diff --git a/ingest.py b/ingest.py index 0c69b81..297344e 100644 --- a/ingest.py +++ b/ingest.py @@ -10,6 +10,8 @@ echo "https://example.com" | python3 ingest.py --stdin """ +from __future__ import annotations + import argparse import hashlib import re @@ -22,13 +24,13 @@ from config import get_db_connection, init_db # Tracking parameters to strip -TRACKING_PARAMS = { +TRACKING_PARAMS: set[str] = { "utm_source", "utm_medium", "utm_campaign", "utm_term", "utm_content", "fbclid", "gclid", "ref", "ref_src", "ref_url", "igsh", "si", "xmt", "slof", "hsLang", } -URL_RE = re.compile(r"https?://[^\s<>\"']+") +URL_RE: re.Pattern[str] = re.compile(r"https?://[^\s<>\"']+") def normalize_url(raw_url: str) -> str: @@ -44,9 +46,9 @@ def normalize_url(raw_url: str) -> str: def extract_urls(text: str) -> list[str]: """Extract and normalize all URLs from text.""" - raw = URL_RE.findall(text) - seen = set() - result = [] + raw: list[str] = URL_RE.findall(text) + seen: set[str] = set() + result: list[str] = [] for url in raw: normalized = normalize_url(url) if normalized not in seen: @@ -55,7 +57,7 @@ def extract_urls(text: str) -> list[str]: return result -def ingest_urls(urls: list[str], source: str = "cli") -> dict: +def ingest_urls(urls: list[str], source: str = "cli") -> dict[str, int]: """Insert URLs into the database. Returns stats.""" init_db() conn = get_db_connection() @@ -84,7 +86,7 @@ def ingest_urls(urls: list[str], source: str = "cli") -> dict: def extract_urls_from_obsidian_vault(vault_path: Path, after_date: datetime | None = None) -> list[str]: """Extract URLs from markdown files in an Obsidian vault.""" - urls = [] + urls: list[str] = [] for md_file in vault_path.rglob("*.md"): try: file_mtime = datetime.fromtimestamp(md_file.stat().st_mtime) @@ -96,7 +98,7 @@ def extract_urls_from_obsidian_vault(vault_path: Path, after_date: datetime | No return urls -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="Ingest URLs into the knowledge pipeline", epilog="Examples:\n" @@ -111,7 +113,7 @@ def main(): parser.add_argument("--after", type=str, help="Date filter (YYYY-MM-DD)") args = parser.parse_args() - urls = [] + urls: list[str] = [] source = "cli" if args.stdin: diff --git a/score.py b/score.py index d3b3a9a..b83830f 100644 --- a/score.py +++ b/score.py @@ -13,12 +13,16 @@ python3 score.py --dry-run # Preview without writing """ +from __future__ import annotations + import argparse import json import logging import re +import sqlite3 import time from datetime import datetime, timezone +from typing import Any from config import ( LLM_API_KEY, @@ -35,7 +39,9 @@ # Each dimension is scored 0-5 by the LLM. # The prompt includes calibration baselines to ensure consistency. -SCORING_PROMPT = """You are a knowledge analyst. Score the following content on multiple dimensions. +ScoreDict = dict[str, Any] + +SCORING_PROMPT: str = """You are a knowledge analyst. Score the following content on multiple dimensions. Respond in strict JSON only (no other text). Title: {title} @@ -64,7 +70,7 @@ - actionability: 1=info only, 2=changes thinking, 3=changes decisions, 4=can build/test now, 5=complete implementation guide - source_credibility: 1=anonymous, 2=social media, 3=tech blog, 4=major publication/official, 5=academic paper/gov report""" -DEFAULT_SCORES = { +DEFAULT_SCORES: ScoreDict = { "knowledge_density": 2, "novelty": 2, "evidence_strength": 2, @@ -76,20 +82,20 @@ "decision_reason": "Unable to score (LLM unavailable)", } -VALID_TIME_HORIZONS = {"short", "mid", "long"} +VALID_TIME_HORIZONS: set[str] = {"short", "mid", "long"} # ── Routes ── # Each item is routed to a destination based on its scores. # Customize these for your workflow. -ROUTE_RESEARCH = "research" # Needs deeper investigation -ROUTE_WRITER = "writer" # Good for writing/publishing -ROUTE_ACTION = "action" # Directly actionable -ROUTE_VALIDATOR = "validator" # Needs fact-checking -ROUTE_ARCHIVE = "archive" # Low priority, file away +ROUTE_RESEARCH: str = "research" # Needs deeper investigation +ROUTE_WRITER: str = "writer" # Good for writing/publishing +ROUTE_ACTION: str = "action" # Directly actionable +ROUTE_VALIDATOR: str = "validator" # Needs fact-checking +ROUTE_ARCHIVE: str = "archive" # Low priority, file away -def compute_route(scores: dict) -> str: +def compute_route(scores: ScoreDict) -> str: """Determine where this item should go based on scores.""" act = scores.get("actionability", 0) rl = scores.get("risk_level", 0) @@ -117,7 +123,7 @@ def compute_route(scores: dict) -> str: return ROUTE_ARCHIVE -def compute_signal_score(scores: dict) -> int: +def compute_signal_score(scores: ScoreDict) -> int: """Composite signal score (0-100) for ranking and thresholds. Higher = more valuable knowledge. Used to decide publishing, @@ -142,7 +148,7 @@ def compute_signal_score(scores: dict) -> int: return max(0, min(100, int(raw * 100 / 95))) -def _call_llm(prompt: str) -> dict | None: +def _call_llm(prompt: str) -> ScoreDict | None: """Call LLM and return parsed JSON scores.""" from urllib.request import Request, urlopen @@ -175,9 +181,9 @@ def _call_llm(prompt: str) -> dict | None: return None -def validate_scores(raw: dict) -> dict: +def validate_scores(raw: ScoreDict) -> ScoreDict: """Clamp and validate LLM output, fallback to defaults.""" - scores = {} + scores: ScoreDict = {} for key in ("knowledge_density", "novelty", "evidence_strength", "actionability", "risk_level", "emotional_noise", "source_credibility"): val = raw.get(key, DEFAULT_SCORES[key]) @@ -189,7 +195,7 @@ def validate_scores(raw: dict) -> dict: return scores -def score_item(item, conn, dry_run: bool = False) -> dict: +def score_item(item: sqlite3.Row, conn: sqlite3.Connection, dry_run: bool = False) -> ScoreDict: """Score a single item. Returns the scores dict.""" prompt = SCORING_PROMPT.format( title=item["title"] or "(no title)", @@ -226,7 +232,7 @@ def score_item(item, conn, dry_run: bool = False) -> dict: return scores -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Score enriched items with multi-dimensional LLM analysis") parser.add_argument("--limit", type=int, default=0, help="Max items to score") parser.add_argument("--rescore", action="store_true", help="Re-score already scored items") @@ -241,7 +247,7 @@ def main(): where += " AND signal_score IS NULL" query = f"SELECT id, url, domain, title, core_insight, full_text FROM items WHERE {where} ORDER BY added_at" - query_params = [] + query_params: list[int] = [] if args.limit > 0: query += " LIMIT ?" query_params.append(args.limit) diff --git a/search.py b/search.py index 7890725..87d8beb 100644 --- a/search.py +++ b/search.py @@ -13,25 +13,36 @@ python3 search.py "Docker" --domain github.com """ +from __future__ import annotations + import argparse import json +import sqlite3 import sys +from typing import Any import numpy as np from FlagEmbedding import BGEM3FlagModel from config import EMBED_DIM, EMBED_MODEL, get_db_connection, init_db -RERANKER_NAME = "BAAI/bge-reranker-v2-m3" -DEFAULT_TOP_K = 10 -DENSE_WEIGHT = 0.7 -SPARSE_WEIGHT = 0.3 +SearchFilters = dict[str, str | int] +SearchResult = dict[str, Any] +SparseWeights = dict[str, float] + +RERANKER_NAME: str = "BAAI/bge-reranker-v2-m3" +DEFAULT_TOP_K: int = 10 +DENSE_WEIGHT: float = 0.7 +SPARSE_WEIGHT: float = 0.3 -def load_embeddings(conn, filters: dict | None = None): +def load_embeddings( + conn: sqlite3.Connection, + filters: SearchFilters | None = None, +) -> tuple[list[sqlite3.Row], np.ndarray, list[SparseWeights]]: """Load embedding matrix and sparse weights from DB.""" where = ["embedding IS NOT NULL", "length(embedding) > 2"] - params = [] + params: list[str | int] = [] if filters: if filters.get("domain"): @@ -48,7 +59,7 @@ def load_embeddings(conn, filters: dict | None = None): return [], np.array([]), [] matrix = np.zeros((len(rows), EMBED_DIM), dtype=np.float32) - sparse_list = [] + sparse_list: list[SparseWeights] = [] for i, row in enumerate(rows): vec = json.loads(row["embedding"]) matrix[i] = np.array(vec, dtype=np.float32) @@ -58,7 +69,15 @@ def load_embeddings(conn, filters: dict | None = None): return rows, matrix, sparse_list -def hybrid_search(query_text: str, model, rows, matrix, sparse_list, top_k: int = 10, dense_only: bool = False): +def hybrid_search( + query_text: str, + model: Any, + rows: list[sqlite3.Row], + matrix: np.ndarray, + sparse_list: list[SparseWeights], + top_k: int = 10, + dense_only: bool = False, +) -> list[SearchResult]: """Perform hybrid search (dense + sparse). Returns scored results.""" q_output = model.encode( [query_text], return_dense=True, return_sparse=True, return_colbert_vecs=False @@ -87,7 +106,7 @@ def hybrid_search(query_text: str, model, rows, matrix, sparse_list, top_k: int top_indices = np.argsort(final_scores)[::-1][:top_k] - results = [] + results: list[SearchResult] = [] for idx in top_indices: idx = int(idx) if final_scores[idx] <= 0: @@ -107,10 +126,10 @@ def hybrid_search(query_text: str, model, rows, matrix, sparse_list, top_k: int return results -_reranker = None +_reranker: Any | None = None -def _get_reranker(): +def _get_reranker() -> Any: """Lazy-load and cache the reranker model (avoid reloading on every call).""" global _reranker if _reranker is None: @@ -119,7 +138,7 @@ def _get_reranker(): return _reranker -def rerank(query: str, results: list[dict], top_k: int = 10) -> list[dict]: +def rerank(query: str, results: list[SearchResult], top_k: int = 10) -> list[SearchResult]: """Rerank results using cross-encoder.""" reranker = _get_reranker() @@ -135,7 +154,7 @@ def rerank(query: str, results: list[dict], top_k: int = 10) -> list[dict]: return results[:top_k] -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="Hybrid semantic search over your knowledge base", formatter_class=argparse.RawDescriptionHelpFormatter, @@ -160,7 +179,7 @@ def main(): print(f"Loading model {EMBED_MODEL}...", file=sys.stderr) model = BGEM3FlagModel(EMBED_MODEL, use_fp16=True) - filters = {} + filters: SearchFilters = {} if args.domain: filters["domain"] = args.domain if args.min_score: diff --git a/serve.py b/serve.py index 6c3103f..38b6976 100644 --- a/serve.py +++ b/serve.py @@ -16,13 +16,18 @@ GET /health """ +from __future__ import annotations + import argparse import json +import sqlite3 import sys import time from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any from urllib.parse import parse_qs, urlparse +import numpy as np from FlagEmbedding import BGEM3FlagModel from config import EMBED_MODEL, SERVE_PORT, get_db_connection, init_db @@ -30,15 +35,18 @@ # ── Global state (loaded at startup) ── -_model = None -_conn = None -_rows = None -_matrix = None -_sparse = None -_use_rerank = False -_stats = {"start_time": 0, "queries": 0, "avg_latency_ms": 0, "_latency_sum": 0} +SparseWeights = dict[str, float] +SearchResult = dict[str, Any] + +_model: Any | None = None +_conn: sqlite3.Connection | None = None +_rows: list[sqlite3.Row] | None = None +_matrix: np.ndarray | None = None +_sparse: list[SparseWeights] | None = None +_use_rerank: bool = False +_stats: dict[str, float] = {"start_time": 0, "queries": 0, "avg_latency_ms": 0, "_latency_sum": 0} -HTML = """ +HTML: str = """ @@ -136,7 +144,7 @@ """ -def _reload(): +def _reload() -> None: """Load/reload embeddings from DB.""" global _rows, _matrix, _sparse _rows, _matrix, _sparse = load_embeddings(_conn) @@ -144,7 +152,7 @@ def _reload(): class SearchHandler(BaseHTTPRequestHandler): - def do_GET(self): + def do_GET(self) -> None: parsed = urlparse(self.path) path = parsed.path params = parse_qs(parsed.query) @@ -168,7 +176,7 @@ def do_GET(self): else: self._json_response({"error": "not found"}, status=404) - def _html_response(self): + def _html_response(self) -> None: body = HTML.encode() self.send_response(200) self.send_header("Content-Type", "text/html; charset=utf-8") @@ -176,7 +184,7 @@ def _html_response(self): self.end_headers() self.wfile.write(body) - def _handle_search(self, params): + def _handle_search(self, params: dict[str, list[str]]) -> None: q = params.get("q", [""])[0] if not q: self._json_response({"error": "missing ?q= parameter"}, status=400) @@ -194,7 +202,7 @@ def _handle_search(self, params): # Filter if needed if domain or min_score: - filters = {} + filters: dict[str, str | int] = {} if domain: filters["domain"] = domain if min_score: @@ -225,7 +233,7 @@ def _handle_search(self, params): "latency_ms": round(elapsed_ms, 1), }) - def _json_response(self, data, status=200): + def _json_response(self, data: dict[str, Any], status: int = 200) -> None: body = json.dumps(data, ensure_ascii=False).encode() self.send_response(status) self.send_header("Content-Type", "application/json") @@ -233,13 +241,13 @@ def _json_response(self, data, status=200): self.end_headers() self.wfile.write(body) - def log_message(self, format, *args): + def log_message(self, format: str, *args: Any) -> None: # Quieter logging if "/health" not in str(args): sys.stderr.write(f"[{self.log_date_time_string()}] {format % args}\n") -def main(): +def main() -> None: global _model, _conn, _rows, _matrix, _sparse, _use_rerank parser = argparse.ArgumentParser(description="Knowledge search HTTP API server")