diff --git a/enrich.py b/enrich.py index 6e0ee9f..2ef1a16 100644 --- a/enrich.py +++ b/enrich.py @@ -13,6 +13,7 @@ import html.parser import ipaddress import json +import logging import re import socket import time @@ -115,8 +116,8 @@ def extract_text_from_html(html_content: str) -> str: extractor = _HTMLTextExtractor() try: extractor.feed(html_content) - except Exception: - pass + except Exception as e: + logging.getLogger(__name__).warning("HTML parse failed: %s", e) return extractor.get_text() diff --git a/score.py b/score.py index 46cca27..d3b3a9a 100644 --- a/score.py +++ b/score.py @@ -15,8 +15,8 @@ import argparse import json +import logging import re -import sys import time from datetime import datetime, timezone @@ -171,7 +171,7 @@ def _call_llm(prompt: str) -> dict | None: text = re.sub(r".*?", "", text, flags=re.DOTALL).strip() return json.loads(text) except Exception as e: - print(f" LLM error: {e}", file=sys.stderr) + logging.getLogger(__name__).warning("LLM scoring failed: %s", e) return None diff --git a/tests/test_embed.py b/tests/test_embed.py new file mode 100644 index 0000000..56b5f95 --- /dev/null +++ b/tests/test_embed.py @@ -0,0 +1,119 @@ +"""Tests for embed.py — text selection and embedding adapters.""" + +import json +import os +import sys + +import numpy as np + + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +import embed + + +def test_get_embedding_text_uses_priority_content_and_truncates_full_text(): + row = { + "core_insight": "core", + "summary": "summary", + "full_text": "x" * 1200, + "url": "https://example.com", + } + + text = embed.get_embedding_text(row) + + assert text == f"core summary {'x' * 1000}" + + +def test_get_embedding_text_falls_back_to_url_when_content_missing(): + row = { + "core_insight": None, + "summary": "", + "full_text": None, + "url": "https://example.com/fallback", + } + + assert embed.get_embedding_text(row) == "https://example.com/fallback" + + +def test_embed_local_converts_dense_and_sparse(monkeypatch): + class FakeModel: + def encode(self, texts, return_dense, return_sparse, return_colbert_vecs): + assert texts == ["one", "two"] + assert return_dense is True + assert return_sparse is True + assert return_colbert_vecs is False + return { + "dense_vecs": [ + np.array([1.0, 2.0], dtype=np.float32), + np.array([3.0, 4.0], dtype=np.float32), + ], + "lexical_weights": [{1: 0.5}, {"token": 2}], + } + + monkeypatch.setattr(embed, "_get_local_model", lambda: FakeModel()) + + assert embed.embed_local(["one", "two"]) == [ + {"dense": [1.0, 2.0], "sparse": {"1": 0.5}}, + {"dense": [3.0, 4.0], "sparse": {"token": 2.0}}, + ] + + +class FakeResponse: + def __init__(self, payload): + self.payload = payload + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def read(self): + return json.dumps(self.payload).encode() + + +def test_embed_remote_parses_dense_and_sparse(monkeypatch): + monkeypatch.setattr(embed, "EMBED_DIM", 3) + + def fake_urlopen(req, timeout): + assert timeout == 120 + assert json.loads(req.data) == { + "texts": ["hello"], + "return_dense": True, + "return_sparse": True, + } + return FakeResponse({"dense": [[1, 2, 3]], "sparse": [{"kw": 0.5}]}) + + monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) + + assert embed.embed_remote(["hello"], "http://embed.test") == [ + {"dense": [1, 2, 3], "sparse": {"kw": 0.5}} + ] + + +def test_embed_remote_rejects_wrong_dense_dimension(monkeypatch): + monkeypatch.setattr(embed, "EMBED_DIM", 3) + monkeypatch.setattr( + "urllib.request.urlopen", + lambda req, timeout: FakeResponse({"dense": [[1, 2]]}), + ) + + try: + embed.embed_remote(["hello"], "http://embed.test") + except ValueError as e: + assert "2-dim vector, expected 3" in str(e) + else: + raise AssertionError("expected ValueError") + + +def test_embed_remote_defaults_missing_sparse_to_empty_dict(monkeypatch): + monkeypatch.setattr(embed, "EMBED_DIM", 3) + monkeypatch.setattr( + "urllib.request.urlopen", + lambda req, timeout: FakeResponse({"dense": [[1, 2, 3]]}), + ) + + assert embed.embed_remote(["hello"], "http://embed.test") == [ + {"dense": [1, 2, 3], "sparse": {}} + ] diff --git a/tests/test_enrich.py b/tests/test_enrich.py index 728557b..48bf668 100644 --- a/tests/test_enrich.py +++ b/tests/test_enrich.py @@ -1,6 +1,7 @@ """Tests for enrich.py — HTML extraction and SSRF protection.""" import os +import socket import sys @@ -26,7 +27,12 @@ def test_blocks_aws_metadata(self): def test_blocks_gcp_metadata(self): assert _is_private_url("http://metadata.google.internal/computeMetadata/v1/") is True - def test_allows_public_urls(self): + def test_allows_public_urls(self, monkeypatch): + monkeypatch.setattr( + socket, + "getaddrinfo", + lambda *args: [(socket.AF_INET, socket.SOCK_STREAM, 0, "", ("93.184.216.34", 443))], + ) assert _is_private_url("https://example.com/page") is False assert _is_private_url("https://github.com/repo") is False diff --git a/tests/test_search.py b/tests/test_search.py new file mode 100644 index 0000000..3768c66 --- /dev/null +++ b/tests/test_search.py @@ -0,0 +1,169 @@ +"""Tests for search.py — embedding loading, hybrid search, and reranking.""" + +import json +import importlib +import os +import sqlite3 +import sys +import types + +import numpy as np + + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +flagembedding = types.ModuleType("FlagEmbedding") +flagembedding.BGEM3FlagModel = object +flagembedding.FlagReranker = object +sys.modules.setdefault("FlagEmbedding", flagembedding) + +search = importlib.import_module("search") + + +def make_conn(): + conn = sqlite3.connect(":memory:") + conn.row_factory = sqlite3.Row + with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), "schema.sql")) as f: + conn.executescript(f.read()) + return conn + + +def insert_item(conn, url, domain, score, dense, sparse=None): + conn.execute( + """ + INSERT INTO items ( + url, domain, title, core_insight, signal_score, route_to, + embedding, sparse_weights + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + url, + domain, + f"title {url}", + f"insight {url}", + score, + "research", + json.dumps(dense), + json.dumps(sparse or {}), + ), + ) + conn.commit() + + +class FakeModel: + def __init__(self, dense, sparse=None): + self.dense = np.array(dense, dtype=np.float32) + self.sparse = sparse or {} + + def encode(self, texts, return_dense, return_sparse, return_colbert_vecs): + assert texts + assert return_dense is True + assert return_sparse is True + assert return_colbert_vecs is False + return {"dense_vecs": [self.dense], "lexical_weights": [self.sparse]} + + +def rows_for_search(): + return [ + { + "id": 1, + "url": "https://a.example/1", + "domain": "a.example", + "title": "A", + "core_insight": "A insight", + "signal_score": 80, + "route_to": "research", + }, + { + "id": 2, + "url": "https://b.example/2", + "domain": "b.example", + "title": "B", + "core_insight": "B insight", + "signal_score": 70, + "route_to": "writer", + }, + ] + + +def test_load_embeddings_filters_and_empty(monkeypatch): + monkeypatch.setattr(search, "EMBED_DIM", 3) + conn = make_conn() + insert_item(conn, "https://a.example/1", "a.example", 80, [1, 0, 0], {"42": 1.5}) + insert_item(conn, "https://b.example/1", "b.example", 40, [0, 1, 0], {"7": 2.0}) + insert_item(conn, "https://a.example/2", "a.example", 20, [0, 0, 1], {"42": 0.5}) + + rows, matrix, sparse = search.load_embeddings(conn, {"domain": "a.example", "min_score": 50}) + + assert [row["url"] for row in rows] == ["https://a.example/1"] + assert matrix.tolist() == [[1.0, 0.0, 0.0]] + assert sparse == [{"42": 1.5}] + + rows, matrix, sparse = search.load_embeddings(conn, {"domain": "missing.example"}) + assert rows == [] + assert matrix.tolist() == [] + assert sparse == [] + + +def test_hybrid_search_dense_cosine_sorting_and_fields(): + rows = rows_for_search() + matrix = np.array([[1, 0], [0.8, 0.6]], dtype=np.float32) + sparse = [{"kw": 0}, {"kw": 10}] + model = FakeModel([1, 0], {"kw": 1}) + + results = search.hybrid_search("query", model, rows, matrix, sparse, top_k=2, dense_only=True) + + assert [result["id"] for result in results] == [1, 2] + assert set(results[0]) == { + "id", + "url", + "domain", + "title", + "core_insight", + "signal_score", + "route_to", + "similarity", + } + + +def test_hybrid_search_combines_dense_and_sparse_weights(): + rows = rows_for_search() + matrix = np.array([[1, 0], [0.8, 0.6]], dtype=np.float32) + sparse = [{"kw": 0}, {"kw": 10}] + model = FakeModel([1, 0], {"kw": 1}) + + results = search.hybrid_search("query", model, rows, matrix, sparse, top_k=2) + + assert [result["id"] for result in results] == [2, 1] + assert results[0]["similarity"] == 0.86 + assert results[1]["similarity"] == 0.7 + + +def test_hybrid_search_top_k_and_non_positive_filtering(): + rows = rows_for_search() + matrix = np.array([[1, 0], [0, 1]], dtype=np.float32) + sparse = [{}, {}] + model = FakeModel([1, 0]) + + results = search.hybrid_search("query", model, rows, matrix, sparse, top_k=1) + + assert [result["id"] for result in results] == [1] + + results = search.hybrid_search("query", model, rows, matrix, sparse, top_k=2) + assert [result["id"] for result in results] == [1] + + +def test_rerank_sorts_adds_scores_and_truncates(monkeypatch): + class FakeReranker: + def compute_score(self, pairs, normalize): + assert pairs == [("query", "A insight"), ("query", "B insight")] + assert normalize is True + return [0.1, 0.9] + + monkeypatch.setattr(search, "_get_reranker", lambda: FakeReranker()) + results = rows_for_search() + + reranked = search.rerank("query", results, top_k=1) + + assert [result["id"] for result in reranked] == [2] + assert reranked[0]["rerank_score"] == 0.9 diff --git a/tests/test_serve.py b/tests/test_serve.py new file mode 100644 index 0000000..2cd7a43 --- /dev/null +++ b/tests/test_serve.py @@ -0,0 +1,103 @@ +"""Tests for serve.py — HTTP handler contract without loading real models.""" + +import json +import importlib +import os +import sys +import time +import types +from io import BytesIO + +import numpy as np + + +sys.path.insert(0, os.path.dirname(os.path.dirname(__file__))) + +flagembedding = types.ModuleType("FlagEmbedding") +flagembedding.BGEM3FlagModel = object +flagembedding.FlagReranker = object +sys.modules.setdefault("FlagEmbedding", flagembedding) + +serve = importlib.import_module("serve") + + +class FakeModel: + def encode(self, texts, return_dense, return_sparse, return_colbert_vecs): + assert texts == ["hello"] + assert return_dense is True + assert return_sparse is True + assert return_colbert_vecs is False + return {"dense_vecs": [np.array([1, 0], dtype=np.float32)], "lexical_weights": [{}]} + + +def configure_globals(): + serve._model = FakeModel() + serve._conn = None + serve._rows = [ + { + "id": 1, + "url": "https://example.com/1", + "domain": "example.com", + "title": "Example", + "core_insight": "Useful result", + "signal_score": 90, + "route_to": "research", + } + ] + serve._matrix = np.array([[1, 0]], dtype=np.float32) + serve._sparse = [{}] + serve._use_rerank = False + serve._stats = {"start_time": time.time(), "queries": 0, "avg_latency_ms": 0, "_latency_sum": 0} + + +def invoke_get(path): + handler = serve.SearchHandler.__new__(serve.SearchHandler) + handler.path = path + handler.wfile = BytesIO() + handler.status = None + handler.headers = [] + + def send_response(status): + handler.status = status + + def send_header(key, value): + handler.headers.append((key, value)) + + handler.send_response = send_response + handler.send_header = send_header + handler.end_headers = lambda: None + + handler.do_GET() + return handler.status, json.loads(handler.wfile.getvalue()) + + +def test_search_handler_contract(): + configure_globals() + + status, data = invoke_get("/search") + assert status == 400 + assert data == {"error": "missing ?q= parameter"} + + status, data = invoke_get("/search?q=hello&k=abc") + assert status == 400 + assert data == {"error": "k and min_score must be integers"} + + status, data = invoke_get("/health") + assert status == 200 + assert data == {"status": "ok", "items": 1} + + status, data = invoke_get("/stats") + assert status == 200 + assert set(data) == {"items_total", "queries_served", "avg_latency_ms", "uptime_seconds"} + assert data["items_total"] == 1 + + status, data = invoke_get("/search?q=hello") + assert status == 200 + assert data["query"] == "hello" + assert data["count"] == 1 + assert data["results"][0]["id"] == 1 + assert "latency_ms" in data + + status, data = invoke_get("/missing") + assert status == 404 + assert data == {"error": "not found"}