Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
No external dependencies.
"""

from __future__ import annotations

import os
import sqlite3
from pathlib import Path

# ── Load .env if present ──

_ENV_PATH = Path(__file__).parent / ".env"
_ENV_PATH: Path = Path(__file__).parent / ".env"
if _ENV_PATH.exists():
with open(_ENV_PATH) as f:
for line in f:
Expand All @@ -20,43 +23,42 @@

# ── Database ──

# SQLite database file, stored next to this module.
DB_PATH: Path = Path(__file__).parent / "knowledge.db"

# ── LLM Backend (OpenAI-compatible) ──

# Each value is read from the environment; the second argument is the
# fallback used when the variable is not set.
LLM_BASE_URL: str = os.environ.get(
    "LLM_BASE_URL", "http://localhost:11434/v1/chat/completions"
)
LLM_MODEL: str = os.environ.get("LLM_MODEL", "qwen2.5:7b")
LLM_API_KEY: str = os.environ.get("LLM_API_KEY", "")
# int() raises ValueError at import time if LLM_TIMEOUT is set to a
# non-numeric string.
LLM_TIMEOUT: int = int(os.environ.get("LLM_TIMEOUT", "120"))

# ── Embedding ──

EMBED_MODEL: str = "BAAI/bge-m3"
EMBED_DIM: int = 1024
# Empty string when EMBED_REMOTE_URL is not set in the environment.
EMBED_REMOTE_URL: str = os.environ.get("EMBED_REMOTE_URL", "")

# ── Server ──

# int() raises ValueError at import time if SERVE_PORT is not numeric.
SERVE_PORT: int = int(os.environ.get("SERVE_PORT", "8780"))

# ── Scoring ──

SCORING_PROMPT_VERSION: str = "v1.0"


def get_db_connection() -> sqlite3.Connection:
    """Return a SQLite connection with row_factory and WAL mode.

    Opens the database at ``DB_PATH``, makes rows addressable by column
    name via ``sqlite3.Row``, switches the journal to WAL and sets a busy
    timeout.

    Returns:
        An open ``sqlite3.Connection``; the caller is responsible for
        closing it.

    Raises:
        sqlite3.Error: If the database cannot be opened or a PRAGMA fails.
            The partially configured connection is closed before re-raising.
    """
    conn = sqlite3.connect(str(DB_PATH), timeout=10)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")
        # NOTE(review): this pragma (milliseconds) appears to supersede the
        # 10 s ``timeout`` given to connect() above — confirm which value is
        # intended.
        conn.execute("PRAGMA busy_timeout=5000")
    except sqlite3.Error:
        # Do not leak the handle when configuration fails.
        conn.close()
        raise
    return conn


def init_db():
def init_db() -> None:
"""Create tables if they don't exist."""
schema_path = Path(__file__).parent / "schema.sql"
conn = get_db_connection()
Expand Down
25 changes: 16 additions & 9 deletions embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,23 @@
python3 embed.py --remote URL # Use a remote embedding server
"""

from __future__ import annotations

import argparse
import json
import sqlite3
import sys
import time
from datetime import datetime, timezone
from typing import Any

from config import EMBED_DIM, EMBED_MODEL, EMBED_REMOTE_URL, get_db_connection, init_db


def get_embedding_text(row) -> str:
EmbeddingResult = dict[str, list[float] | dict[str, float]]


def get_embedding_text(row: sqlite3.Row) -> str:
"""Build the text to embed from item fields."""
parts = []
if row["core_insight"]:
Expand All @@ -35,10 +42,10 @@ def get_embedding_text(row) -> str:
return " ".join(parts)


_local_model = None
_local_model: Any | None = None


def _get_local_model():
def _get_local_model() -> Any:
"""Lazy-load and cache the embedding model (avoid reloading on every call)."""
global _local_model
if _local_model is None:
Expand All @@ -47,7 +54,7 @@ def _get_local_model():
return _local_model


def embed_local(texts: list[str]) -> list[dict]:
def embed_local(texts: list[str]) -> list[EmbeddingResult]:
"""Embed texts using local bge-m3 model. Returns list of {dense, sparse}."""
model = _get_local_model()
output = model.encode(
Expand All @@ -56,25 +63,25 @@ def embed_local(texts: list[str]) -> list[dict]:
return_sparse=True,
return_colbert_vecs=False,
)
results = []
results: list[EmbeddingResult] = []
for i in range(len(texts)):
dense = output["dense_vecs"][i].tolist()
sparse = {str(k): float(v) for k, v in output["lexical_weights"][i].items()}
results.append({"dense": dense, "sparse": sparse})
return results


def embed_remote(texts: list[str], remote_url: str) -> list[dict]:
def embed_remote(texts: list[str], remote_url: str) -> list[EmbeddingResult]:
"""Embed texts via a remote HTTP embedding server."""
from urllib.request import Request, urlopen

body = json.dumps({"texts": texts, "return_dense": True, "return_sparse": True}).encode()
req = Request(remote_url, data=body, method="POST", headers={"Content-Type": "application/json"})

with urlopen(req, timeout=120) as resp:
data = json.loads(resp.read())
data: dict[str, Any] = json.loads(resp.read())

results = []
results: list[EmbeddingResult] = []
for i in range(len(texts)):
dense = data["dense"][i]
if len(dense) != EMBED_DIM:
Expand All @@ -88,7 +95,7 @@ def embed_remote(texts: list[str], remote_url: str) -> list[dict]:
return results


def main():
def main() -> None:
parser = argparse.ArgumentParser(description="Generate embeddings for scored items")
parser.add_argument("--rebuild", action="store_true", help="Clear and re-embed all items")
parser.add_argument("--remote", type=str, default=EMBED_REMOTE_URL, help="Remote embedding server URL")
Expand Down
54 changes: 35 additions & 19 deletions enrich.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,20 @@
python3 enrich.py --limit 10 # Process up to 10 items
"""

from __future__ import annotations

import argparse
import html.parser
import ipaddress
import json
import logging
import re
import sqlite3
import socket
import time
from datetime import datetime, timezone
from http.client import HTTPMessage
from typing import Any
from urllib.parse import urlparse
from urllib.request import HTTPRedirectHandler, Request, build_opener, urlopen

Expand All @@ -30,17 +35,20 @@
init_db,
)

USER_AGENT = (
FetchResult = dict[str, str | None]
LLMResult = dict[str, Any]

USER_AGENT: str = (
"Mozilla/5.0 (compatible; knowledge-pipeline/1.0; "
"+https://github.com/makifordevelop/knowledge-pipeline)"
)

SKIP_DOMAINS = {"apps.apple.com", "drive.google.com", "play.google.com"}
SKIP_DOMAINS: set[str] = {"apps.apple.com", "drive.google.com", "play.google.com"}

MAX_CONTENT_BYTES = 5 * 1024 * 1024 # 5MB limit to prevent OOM
MAX_CONTENT_BYTES: int = 5 * 1024 * 1024 # 5MB limit to prevent OOM

# Hostnames to always block (SSRF protection)
_BLOCKED_HOSTNAMES = {"localhost", "metadata.google.internal"}
_BLOCKED_HOSTNAMES: set[str] = {"localhost", "metadata.google.internal"}


def _is_private_ip(ip_str: str) -> bool:
Expand Down Expand Up @@ -75,7 +83,15 @@ def _is_private_url(url: str) -> bool:
class _SSRFSafeRedirectHandler(HTTPRedirectHandler):
"""Validate redirect targets against SSRF blocklist."""

def redirect_request(self, req, fp, code, msg, headers, newurl):
def redirect_request(
self,
req: Request,
fp: Any,
code: int,
msg: str,
headers: HTTPMessage,
newurl: str,
) -> Request | None:
if _is_private_url(newurl):
raise ValueError(f"Redirect to private/internal URL blocked: {newurl}")
return super().redirect_request(req, fp, code, msg, headers, newurl)
Expand All @@ -87,28 +103,28 @@ def redirect_request(self, req, fp, code, msg, headers, newurl):
# ── HTML text extraction (zero dependencies) ──

class _HTMLTextExtractor(html.parser.HTMLParser):
SKIP_TAGS = {"script", "style", "nav", "footer", "header", "aside", "noscript"}
SKIP_TAGS: set[str] = {"script", "style", "nav", "footer", "header", "aside", "noscript"}

def __init__(self):
def __init__(self) -> None:
super().__init__()
self.result = []
self._skip = 0
self.result: list[str] = []
self._skip: int = 0

def handle_starttag(self, tag, attrs):
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag in self.SKIP_TAGS:
self._skip += 1

def handle_endtag(self, tag):
def handle_endtag(self, tag: str) -> None:
if tag in self.SKIP_TAGS and self._skip > 0:
self._skip -= 1

def handle_data(self, data):
def handle_data(self, data: str) -> None:
if self._skip == 0:
text = data.strip()
if text:
self.result.append(text)

def get_text(self):
def get_text(self) -> str:
return "\n".join(self.result)


Expand All @@ -128,7 +144,7 @@ def extract_title_from_html(html_content: str) -> str | None:

# ── URL fetching ──

def fetch_url(url: str, timeout: int = 30) -> dict:
def fetch_url(url: str, timeout: int = 30) -> FetchResult:
"""Fetch a URL and return {html, title, text, status}."""
if _is_private_url(url):
return {"status": "skipped", "reason": "blocked: private/internal URL"}
Expand Down Expand Up @@ -159,7 +175,7 @@ def fetch_url(url: str, timeout: int = 30) -> dict:

# ── LLM enrichment ──

_ENRICH_PROMPT = """Analyze this web content and provide a structured summary.
_ENRICH_PROMPT: str = """Analyze this web content and provide a structured summary.

Title: {title}
URL: {url}
Expand All @@ -174,7 +190,7 @@ def fetch_url(url: str, timeout: int = 30) -> dict:
}}"""


def call_llm(prompt: str) -> dict | None:
def call_llm(prompt: str) -> LLMResult | None:
"""Call an OpenAI-compatible LLM API. Returns parsed JSON or None."""
body = {
"model": LLM_MODEL,
Expand Down Expand Up @@ -203,7 +219,7 @@ def call_llm(prompt: str) -> dict | None:
return None


def enrich_item(item_id: int, url: str, domain: str, conn) -> str:
def enrich_item(item_id: int, url: str, domain: str, conn: sqlite3.Connection) -> str:
"""Enrich a single item. Returns status string."""
if domain in SKIP_DOMAINS:
conn.execute(
Expand Down Expand Up @@ -249,7 +265,7 @@ def enrich_item(item_id: int, url: str, domain: str, conn) -> str:
return "fetched"


def main():
def main() -> None:
parser = argparse.ArgumentParser(description="Enrich pending items with full text and LLM summaries")
parser.add_argument("--limit", type=int, default=0, help="Max items to process (0 = all)")
args = parser.parse_args()
Expand All @@ -261,7 +277,7 @@ def main():
"SELECT id, url, domain FROM items "
"WHERE fetch_status = 'pending' ORDER BY added_at"
)
query_params = []
query_params: list[int] = []
if args.limit > 0:
query += " LIMIT ?"
query_params.append(args.limit)
Expand Down
20 changes: 11 additions & 9 deletions ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
echo "https://example.com" | python3 ingest.py --stdin
"""

from __future__ import annotations

import argparse
import hashlib
import re
Expand All @@ -22,13 +24,13 @@
from config import get_db_connection, init_db

# Tracking parameters to strip
TRACKING_PARAMS: set[str] = {
    "utm_source", "utm_medium", "utm_campaign", "utm_term", "utm_content",
    "fbclid", "gclid", "ref", "ref_src", "ref_url",
    "igsh", "si", "xmt", "slof", "hsLang",
}

# Matches an http(s) URL up to the first whitespace, angle bracket or quote.
URL_RE: re.Pattern[str] = re.compile(r"https?://[^\s<>\"']+")


def normalize_url(raw_url: str) -> str:
Expand All @@ -44,9 +46,9 @@ def normalize_url(raw_url: str) -> str:

def extract_urls(text: str) -> list[str]:
"""Extract and normalize all URLs from text."""
raw = URL_RE.findall(text)
seen = set()
result = []
raw: list[str] = URL_RE.findall(text)
seen: set[str] = set()
result: list[str] = []
for url in raw:
normalized = normalize_url(url)
if normalized not in seen:
Expand All @@ -55,7 +57,7 @@ def extract_urls(text: str) -> list[str]:
return result


def ingest_urls(urls: list[str], source: str = "cli") -> dict:
def ingest_urls(urls: list[str], source: str = "cli") -> dict[str, int]:
"""Insert URLs into the database. Returns stats."""
init_db()
conn = get_db_connection()
Expand Down Expand Up @@ -84,7 +86,7 @@ def ingest_urls(urls: list[str], source: str = "cli") -> dict:

def extract_urls_from_obsidian_vault(vault_path: Path, after_date: datetime | None = None) -> list[str]:
"""Extract URLs from markdown files in an Obsidian vault."""
urls = []
urls: list[str] = []
for md_file in vault_path.rglob("*.md"):
try:
file_mtime = datetime.fromtimestamp(md_file.stat().st_mtime)
Expand All @@ -96,7 +98,7 @@ def extract_urls_from_obsidian_vault(vault_path: Path, after_date: datetime | No
return urls


def main():
def main() -> None:
parser = argparse.ArgumentParser(
description="Ingest URLs into the knowledge pipeline",
epilog="Examples:\n"
Expand All @@ -111,7 +113,7 @@ def main():
parser.add_argument("--after", type=str, help="Date filter (YYYY-MM-DD)")
args = parser.parse_args()

urls = []
urls: list[str] = []
source = "cli"

if args.stdin:
Expand Down
Loading
Loading