diff --git a/cmd/cosift/bench_extra_test.go b/cmd/cosift/bench_extra_test.go new file mode 100644 index 0000000..4ac242e --- /dev/null +++ b/cmd/cosift/bench_extra_test.go @@ -0,0 +1,101 @@ +package main + +import ( + "math/rand" + "strings" + "testing" +) + +func TestFormatHumanVector(t *testing.T) { + r := &benchResult{Mode: "vector", N: 100, Dim: 768, P50Micros: 500, P95Micros: 1200, P99Micros: 2500, QPS: 4500.7} + out := r.formatHuman() + if !strings.Contains(out, "vector") { + t.Errorf("missing mode: %q", out) + } + if !strings.Contains(out, "n=100") { + t.Errorf("missing n: %q", out) + } + if !strings.Contains(out, "dim=768") { + t.Errorf("missing dim: %q", out) + } +} + +func TestFormatHumanBM25(t *testing.T) { + r := &benchResult{Mode: "bm25", N: 100, P50Micros: 50, P95Micros: 120, P99Micros: 250, QPS: 12500} + out := r.formatHuman() + if !strings.Contains(out, "bm25") { + t.Errorf("missing mode: %q", out) + } + if !strings.Contains(out, "qps=12500") { + t.Errorf("missing qps: %q", out) + } +} + +func TestFormatHumanCrawl(t *testing.T) { + r := &benchResult{Mode: "crawl", N: 50, ElapsedMicros: 5_000_000, Docs: 50, PagesPerSec: 10, Terms: 200} + out := r.formatHuman() + if !strings.Contains(out, "crawl") { + t.Errorf("missing mode: %q", out) + } + if !strings.Contains(out, "docs=50") { + t.Errorf("missing docs: %q", out) + } + // PerHostDelayMs > 0 appends an extra label. 
+ r.PerHostDelayMs = 250 + out2 := r.formatHuman() + if !strings.Contains(out2, "per-host-delay=250ms") { + t.Errorf("expected per-host-delay annotation: %q", out2) + } +} + +func TestFormatHumanUnknownMode(t *testing.T) { + r := &benchResult{Mode: "weird-mode", N: 1} + out := r.formatHuman() + if !strings.Contains(out, "weird-mode") { + t.Errorf("expected fallback to include struct dump: %q", out) + } +} + +func TestUnion(t *testing.T) { + a := map[string]*benchResult{"vector": {}, "extra-a": {}} + b := map[string]*benchResult{"bm25": {}, "vector": {}, "extra-b": {}} + out := union(a, b) + seen := map[string]bool{} + for _, k := range out { + seen[k] = true + } + for _, k := range []string{"vector", "bm25", "extra-a", "extra-b"} { + if !seen[k] { + t.Errorf("missing %s in union: %v", k, out) + } + } +} + +func TestNeutralVocabForDistractors(t *testing.T) { + v := neutralVocabForDistractors() + if len(v) < 30 { + t.Errorf("vocab too short: %d words", len(v)) + } + // All entries non-empty. 
+ for i, w := range v { + if w == "" { + t.Errorf("entry %d is empty", i) + } + } +} + +func TestGenerateDistractorText(t *testing.T) { + rng := rand.New(rand.NewSource(1)) + vocab := []string{"a", "b", "c"} + got := generateDistractorText(rng, vocab, 5) + parts := strings.Fields(got) + if len(parts) != 5 { + t.Errorf("got %d words want 5", len(parts)) + } + allowed := map[string]bool{"a": true, "b": true, "c": true} + for _, p := range parts { + if !allowed[p] { + t.Errorf("unexpected word %q", p) + } + } +} diff --git a/cmd/cosift/helpers_extra_test.go b/cmd/cosift/helpers_extra_test.go new file mode 100644 index 0000000..961b5ee --- /dev/null +++ b/cmd/cosift/helpers_extra_test.go @@ -0,0 +1,263 @@ +package main + +import ( + "math" + "os" + "testing" + "time" +) + +func TestAuthStatus(t *testing.T) { + cases := []struct { + name string + configured bool + key string + urlSet bool + want string + }{ + {"not configured", false, "", false, "(none)"}, + {"not configured ignores other args", false, "k", true, "(none)"}, + {"with bearer", true, "secret", false, "bearer-token"}, + {"with bearer + custom url", true, "secret", true, "bearer-token"}, + {"anonymous custom URL", true, "", true, "anonymous(custom-url)"}, + {"missing", true, "", false, "MISSING"}, + } + for _, c := range cases { + if got := authStatus(c.configured, c.key, c.urlSet); got != c.want { + t.Errorf("%s: got %q want %q", c.name, got, c.want) + } + } +} + +func TestResolveAPIKey(t *testing.T) { + // Save & wipe relevant envs. + saved := map[string]string{} + for _, k := range []string{ + "COSIFT_EMBED_API_KEY", "COSIFT_CHAT_API_KEY", + "OPENAI_API_KEY", "OPENAI", + } { + saved[k] = os.Getenv(k) + os.Unsetenv(k) + } + t.Cleanup(func() { + for k, v := range saved { + if v == "" { + os.Unsetenv(k) + } else { + os.Setenv(k, v) + } + } + }) + + // Empty → "". + if got := resolveAPIKey("embed"); got != "" { + t.Errorf("empty env: got %q want \"\"", got) + } + + // OPENAI_API_KEY fallback. 
+ os.Setenv("OPENAI_API_KEY", "openai-key") + if got := resolveAPIKey("embed"); got != "openai-key" { + t.Errorf("openai key: got %q", got) + } + + // Slot-specific takes precedence. + os.Setenv("COSIFT_EMBED_API_KEY", "embed-key") + if got := resolveAPIKey("embed"); got != "embed-key" { + t.Errorf("embed key: got %q", got) + } + + // Chat slot — uses chat slot env. + os.Setenv("COSIFT_CHAT_API_KEY", "chat-key") + if got := resolveAPIKey("chat"); got != "chat-key" { + t.Errorf("chat key: got %q", got) + } + + // Unknown slot falls through to OPENAI_API_KEY. + if got := resolveAPIKey("unknown"); got != "openai-key" { + t.Errorf("unknown slot fallthrough: got %q", got) + } + + // OPENAI (without _API_KEY) fallback. + os.Unsetenv("OPENAI_API_KEY") + os.Setenv("OPENAI", "legacy") + if got := resolveAPIKey("unknown"); got != "legacy" { + t.Errorf("OPENAI fallback: got %q", got) + } +} + +func TestResolveEmbedAPIKey(t *testing.T) { + saved := os.Getenv("COSIFT_EMBED_API_KEY") + os.Setenv("COSIFT_EMBED_API_KEY", "embed-direct") + t.Cleanup(func() { + if saved == "" { + os.Unsetenv("COSIFT_EMBED_API_KEY") + } else { + os.Setenv("COSIFT_EMBED_API_KEY", saved) + } + }) + if got := resolveEmbedAPIKey(); got != "embed-direct" { + t.Errorf("resolveEmbedAPIKey: got %q", got) + } +} + +func TestFirstEnv(t *testing.T) { + for _, k := range []string{"COSIFT_FIRSTENV_A", "COSIFT_FIRSTENV_B", "COSIFT_FIRSTENV_C"} { + os.Unsetenv(k) + k := k + t.Cleanup(func() { os.Unsetenv(k) }) + } + + // All empty. + if got := firstEnv("COSIFT_FIRSTENV_A", "COSIFT_FIRSTENV_B"); got != "" { + t.Errorf("all empty: got %q", got) + } + + // Returns first non-empty. + os.Setenv("COSIFT_FIRSTENV_B", "b-val") + if got := firstEnv("COSIFT_FIRSTENV_A", "COSIFT_FIRSTENV_B", "COSIFT_FIRSTENV_C"); got != "b-val" { + t.Errorf("got %q want b-val", got) + } + + // Earlier wins. 
+ os.Setenv("COSIFT_FIRSTENV_A", "a-val") + if got := firstEnv("COSIFT_FIRSTENV_A", "COSIFT_FIRSTENV_B"); got != "a-val" { + t.Errorf("got %q want a-val (first)", got) + } + + // No args. + if got := firstEnv(); got != "" { + t.Errorf("no args: got %q", got) + } +} + +func TestContains(t *testing.T) { + if !contains([]string{"a", "b", "c"}, "b") { + t.Errorf("expected true for present element") + } + if contains([]string{"a", "b", "c"}, "z") { + t.Errorf("expected false for absent element") + } + if contains(nil, "x") { + t.Errorf("expected false on nil slice") + } +} + +func TestChunkerWith(t *testing.T) { + c := chunkerWith(100, 20) + if c == nil { + t.Errorf("chunkerWith returned nil") + } +} + +func TestSqrt(t *testing.T) { + cases := []struct { + in, want float64 + }{ + {4, 2}, + {9, 3}, + {16, 4}, + {2, math.Sqrt2}, + } + for _, c := range cases { + got := sqrt(c.in) + if math.Abs(got-c.want) > 1e-6 { + t.Errorf("sqrt(%v): got %v want %v", c.in, got, c.want) + } + } + // Edge: sqrt(0). + if got := sqrt(0); got != 0 { + // The Newton iteration starts at x/2 = 0 and breaks immediately. + _ = got + } +} + +func TestRandUnit(t *testing.T) { + rng := newSeededRand() + v := randUnit(rng, 8) + if len(v) != 8 { + t.Errorf("len: got %d want 8", len(v)) + } + // Unit vector — magnitude ~ 1. + var mag float64 + for _, x := range v { + mag += float64(x) * float64(x) + } + mag = math.Sqrt(mag) + if math.Abs(mag-1.0) > 0.01 { + t.Errorf("magnitude: got %v want ~1.0", mag) + } + + // Deterministic. + rng2 := newSeededRand() + v2 := randUnit(rng2, 8) + for i := range v { + if v[i] != v2[i] { + t.Errorf("non-deterministic at %d: %v vs %v", i, v[i], v2[i]) + } + } +} + +func TestPercentiles(t *testing.T) { + // Empty. + p50, p95, p99 := percentiles(nil) + if p50 != 0 || p95 != 0 || p99 != 0 { + t.Errorf("empty: got %v %v %v", p50, p95, p99) + } + + // Sorted input — index math gives p50 = sorted[int((N-1)*0.5)]. 
+ ds := []time.Duration{ + 1 * time.Millisecond, + 2 * time.Millisecond, + 3 * time.Millisecond, + 4 * time.Millisecond, + 5 * time.Millisecond, + 6 * time.Millisecond, + 7 * time.Millisecond, + 8 * time.Millisecond, + 9 * time.Millisecond, + 10 * time.Millisecond, + } + p50, p95, p99 = percentiles(ds) + // idx = int((N-1)*p): p50 at idx 4 = 5ms, p95 at idx 8 = 9ms, p99 at idx 8 = 9ms. + if p50 != 5*time.Millisecond { + t.Errorf("p50: got %v want 5ms", p50) + } + if p95 != 9*time.Millisecond { + t.Errorf("p95: got %v want 9ms", p95) + } + + // Unsorted input — function copies + sorts. + unsorted := []time.Duration{10, 1, 5, 3, 7, 9, 2, 4, 6, 8} + for i := range unsorted { + unsorted[i] *= time.Millisecond + } + p50b, _, _ := percentiles(unsorted) + if p50b != 5*time.Millisecond { + t.Errorf("unsorted p50: got %v want 5ms", p50b) + } +} + +func TestSumDur(t *testing.T) { + if got := sumDur(nil); got != 0 { + t.Errorf("nil: got %v", got) + } + got := sumDur([]time.Duration{ + 100 * time.Millisecond, + 250 * time.Millisecond, + 50 * time.Millisecond, + }) + if got != 400*time.Millisecond { + t.Errorf("sum: got %v want 400ms", got) + } +} + +func TestNewSeededRandIsDeterministic(t *testing.T) { + a := newSeededRand() + b := newSeededRand() + for i := 0; i < 10; i++ { + x, y := a.Float64(), b.Float64() + if x != y { + t.Errorf("non-deterministic at iter %d: %v vs %v", i, x, y) + } + } +} diff --git a/internal/config/config_load_test.go b/internal/config/config_load_test.go new file mode 100644 index 0000000..62017e0 --- /dev/null +++ b/internal/config/config_load_test.go @@ -0,0 +1,288 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDefaultHasSensibleValues(t *testing.T) { + c := Default() + if c == nil { + t.Fatal("Default returned nil") + } + if c.DataDir == "" { + t.Errorf("DataDir empty") + } + if c.Server.Addr == "" { + t.Errorf("Server.Addr empty") + } + if c.Crawler.UserAgent == "" { + t.Errorf("Crawler.UserAgent empty") + } + 
if c.Crawler.MaxConcurrent <= 0 { + t.Errorf("MaxConcurrent should be > 0, got %d", c.Crawler.MaxConcurrent) + } + if c.Crawler.MaxBodyBytes <= 0 { + t.Errorf("MaxBodyBytes should be > 0, got %d", c.Crawler.MaxBodyBytes) + } + if !c.Crawler.RespectRobots { + t.Errorf("RespectRobots should default true") + } +} + +func TestLoadMissingFileReturnsDefaults(t *testing.T) { + // Save & restore env vars we touch. + for _, k := range []string{"PORT", "COSIFT_LISTEN", "COSIFT_DATA_DIR"} { + orig, had := os.LookupEnv(k) + os.Unsetenv(k) + defer func(k, v string, had bool) { + if had { + os.Setenv(k, v) + } else { + os.Unsetenv(k) + } + }(k, orig, had) + } + + path := filepath.Join(t.TempDir(), "does-not-exist.json") + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load missing file: %v", err) + } + if cfg.DataDir != Default().DataDir { + t.Errorf("DataDir not default: %q", cfg.DataDir) + } + if cfg.Server.Addr != Default().Server.Addr { + t.Errorf("Server.Addr not default: %q", cfg.Server.Addr) + } +} + +func TestLoadFromFile(t *testing.T) { + for _, k := range []string{"PORT", "COSIFT_LISTEN", "COSIFT_DATA_DIR"} { + orig, had := os.LookupEnv(k) + os.Unsetenv(k) + defer func(k, v string, had bool) { + if had { + os.Setenv(k, v) + } else { + os.Unsetenv(k) + } + }(k, orig, had) + } + + body := []byte(`{ + "data_dir": "/tmp/cosift-test", + "server": {"addr": "127.0.0.1:9999"}, + "crawler": {"max_concurrent": 16} + }`) + dir := t.TempDir() + path := filepath.Join(dir, "cosift.json") + if err := os.WriteFile(path, body, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg.DataDir != "/tmp/cosift-test" { + t.Errorf("DataDir: got %q", cfg.DataDir) + } + if cfg.Server.Addr != "127.0.0.1:9999" { + t.Errorf("Addr: got %q", cfg.Server.Addr) + } + if cfg.Crawler.MaxConcurrent != 16 { + t.Errorf("MaxConcurrent: got %d", cfg.Crawler.MaxConcurrent) + } +} + +func TestLoadBadJSONErrors(t *testing.T) 
{ + dir := t.TempDir() + path := filepath.Join(dir, "bad.json") + if err := os.WriteFile(path, []byte("{not json"), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if _, err := Load(path); err == nil { + t.Errorf("expected error for malformed JSON, got nil") + } +} + +func TestLoadEmptyDataDirFallsBackToDefault(t *testing.T) { + for _, k := range []string{"PORT", "COSIFT_LISTEN", "COSIFT_DATA_DIR"} { + orig, had := os.LookupEnv(k) + os.Unsetenv(k) + defer func(k, v string, had bool) { + if had { + os.Setenv(k, v) + } else { + os.Unsetenv(k) + } + }(k, orig, had) + } + + body := []byte(`{"data_dir": "", "server": {"addr": "x:1"}}`) + dir := t.TempDir() + path := filepath.Join(dir, "c.json") + _ = os.WriteFile(path, body, 0o644) + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg.DataDir == "" { + t.Errorf("empty DataDir should fall back to Default") + } + if cfg.DataDir != Default().DataDir { + t.Errorf("DataDir: got %q want %q", cfg.DataDir, Default().DataDir) + } +} + +func TestApplyEnvOverridesPort(t *testing.T) { + orig, had := os.LookupEnv("PORT") + os.Setenv("PORT", "8765") + t.Cleanup(func() { + if had { + os.Setenv("PORT", orig) + } else { + os.Unsetenv("PORT") + } + }) + + cfg := Default() + applyEnvOverrides(cfg) + if cfg.Server.Addr != "0.0.0.0:8765" { + t.Errorf("PORT override: got %q want 0.0.0.0:8765", cfg.Server.Addr) + } +} + +func TestApplyEnvOverridesCosiftListen(t *testing.T) { + orig, had := os.LookupEnv("COSIFT_LISTEN") + os.Setenv("COSIFT_LISTEN", "10.0.0.1:1234") + t.Cleanup(func() { + if had { + os.Setenv("COSIFT_LISTEN", orig) + } else { + os.Unsetenv("COSIFT_LISTEN") + } + }) + + cfg := Default() + applyEnvOverrides(cfg) + if cfg.Server.Addr != "10.0.0.1:1234" { + t.Errorf("COSIFT_LISTEN override: got %q", cfg.Server.Addr) + } +} + +func TestApplyEnvOverridesDataDir(t *testing.T) { + orig, had := os.LookupEnv("COSIFT_DATA_DIR") + os.Setenv("COSIFT_DATA_DIR", "/custom/path") + t.Cleanup(func() { + if 
had { + os.Setenv("COSIFT_DATA_DIR", orig) + } else { + os.Unsetenv("COSIFT_DATA_DIR") + } + }) + + cfg := Default() + applyEnvOverrides(cfg) + if cfg.DataDir != "/custom/path" { + t.Errorf("COSIFT_DATA_DIR override: got %q", cfg.DataDir) + } +} + +func TestApplyEnvOverridesEmptyIgnored(t *testing.T) { + orig, had := os.LookupEnv("PORT") + os.Setenv("PORT", " ") + t.Cleanup(func() { + if had { + os.Setenv("PORT", orig) + } else { + os.Unsetenv("PORT") + } + }) + + cfg := Default() + addr := cfg.Server.Addr + applyEnvOverrides(cfg) + if cfg.Server.Addr != addr { + t.Errorf("whitespace-only PORT should be ignored, addr changed from %q to %q", addr, cfg.Server.Addr) + } +} + +func TestLoadDotEnvMissingFileNoError(t *testing.T) { + path := filepath.Join(t.TempDir(), "nope.env") + if err := LoadDotEnv(path); err != nil { + t.Errorf("missing .env: %v", err) + } +} + +func TestLoadDotEnvBasic(t *testing.T) { + const key = "COSIFT_DOTENV_TEST_KEY_1" + const val = "hello-world" + os.Unsetenv(key) + t.Cleanup(func() { os.Unsetenv(key) }) + + body := []byte(key + "=" + val + "\n") + path := filepath.Join(t.TempDir(), ".env") + if err := os.WriteFile(path, body, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if err := LoadDotEnv(path); err != nil { + t.Fatalf("LoadDotEnv: %v", err) + } + if got := os.Getenv(key); got != val { + t.Errorf("env value: got %q want %q", got, val) + } +} + +func TestLoadDotEnvQuotedAndComments(t *testing.T) { + keys := map[string]string{ + "COSIFT_DOTENV_DOUBLE": "double-quoted-value", + "COSIFT_DOTENV_SINGLE": "single-quoted-value", + "COSIFT_DOTENV_BARE": "bare", + } + for k := range keys { + os.Unsetenv(k) + k := k + t.Cleanup(func() { os.Unsetenv(k) }) + } + body := []byte(`# a comment line +COSIFT_DOTENV_DOUBLE="double-quoted-value" +COSIFT_DOTENV_SINGLE='single-quoted-value' +COSIFT_DOTENV_BARE=bare + +# blank line above and below comment + +malformed_line_no_equals +`) + path := filepath.Join(t.TempDir(), ".env") + if err := 
os.WriteFile(path, body, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + if err := LoadDotEnv(path); err != nil { + t.Fatalf("LoadDotEnv: %v", err) + } + for k, want := range keys { + if got := os.Getenv(k); got != want { + t.Errorf("%s: got %q want %q", k, got, want) + } + } +} + +func TestLoadDotEnvDoesNotOverrideExisting(t *testing.T) { + const key = "COSIFT_DOTENV_EXISTS" + const preset = "preset-wins" + os.Setenv(key, preset) + t.Cleanup(func() { os.Unsetenv(key) }) + + body := []byte(key + "=from-dotenv\n") + path := filepath.Join(t.TempDir(), ".env") + _ = os.WriteFile(path, body, 0o644) + + if err := LoadDotEnv(path); err != nil { + t.Fatalf("LoadDotEnv: %v", err) + } + if got := os.Getenv(key); got != preset { + t.Errorf("existing env was overridden: got %q want %q", got, preset) + } +} diff --git a/internal/crawler/parse_proxies_test.go b/internal/crawler/parse_proxies_test.go new file mode 100644 index 0000000..d9c2008 --- /dev/null +++ b/internal/crawler/parse_proxies_test.go @@ -0,0 +1,158 @@ +package crawler + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestParseProxies(t *testing.T) { + in := []string{ + "http://proxy.example.com:8080", + "http://user:pass@proxy.example.com:8080", + "socks5://socks.local:1080", + "", // empty — skipped + " ", // whitespace-only — skipped + "not-a-url", // no scheme/host — skipped + "http://", // empty host — skipped + "https://valid:443", // ok + } + out := parseProxies(in) + if len(out) != 4 { + t.Errorf("parseProxies: got %d want 4 (%+v)", len(out), out) + } + if out[0].Host != "proxy.example.com:8080" { + t.Errorf("first proxy host: %q", out[0].Host) + } +} + +func TestParseProxiesEmpty(t *testing.T) { + out := parseProxies(nil) + if len(out) != 0 { + t.Errorf("nil input: got len %d", len(out)) + } + out = parseProxies([]string{}) + if len(out) != 0 { + t.Errorf("empty input: got len %d", len(out)) + } +} + +// --- Robots --- + +func 
TestRobotsAllowedHTTPServer(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/robots.txt" { + http.NotFound(w, r) + return + } + _, _ = w.Write([]byte("User-agent: *\nDisallow: /private\nAllow: /private/public\nCrawl-delay: 2\nSitemap: https://example/sitemap.xml\n")) + })) + defer ts.Close() + + r := NewRobots(&http.Client{Timeout: 5 * time.Second}, "TestUA/1.0") + + // Disallowed. + ok, delay, err := r.Allowed(t.Context(), ts.URL+"/private/secret") + if err != nil { + t.Fatalf("Allowed: %v", err) + } + if ok { + t.Errorf("/private/secret should be disallowed") + } + if delay <= 0 { + t.Errorf("expected crawl-delay > 0, got %v", delay) + } + + // Allowed (more specific allow overrides disallow). + ok, _, _ = r.Allowed(t.Context(), ts.URL+"/private/public/foo") + if !ok { + t.Errorf("/private/public/foo should be allowed (longer allow rule)") + } + + // Sitemaps endpoint. + sitemaps := r.Sitemaps(t.Context(), ts.URL) + if len(sitemaps) != 1 || sitemaps[0] != "https://example/sitemap.xml" { + t.Errorf("Sitemaps: %v", sitemaps) + } + + // Bad URL. 
+ _, _, err = r.Allowed(t.Context(), "::not a url") + if err == nil { + t.Errorf("Allowed should error on malformed URL") + } +} + +func TestRobotsSitemapsBadURL(t *testing.T) { + r := NewRobots(&http.Client{Timeout: 1 * time.Second}, "T/1") + if out := r.Sitemaps(t.Context(), "::bad"); out != nil { + t.Errorf("bad URL: got %v", out) + } +} + +func TestRobotsCacheHit(t *testing.T) { + calls := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/robots.txt" { + calls++ + } + _, _ = w.Write([]byte("User-agent: *\nDisallow:\n")) + })) + defer ts.Close() + + r := NewRobots(&http.Client{Timeout: 2 * time.Second}, "U/1") + for i := 0; i < 3; i++ { + if _, _, err := r.Allowed(t.Context(), ts.URL+"/p"); err != nil { + t.Fatalf("Allowed: %v", err) + } + } + if calls != 1 { + t.Errorf("robots.txt fetched %d times, want 1 (cache miss path)", calls) + } +} + +func TestRobotsFetch404TreatedAsEmpty(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer ts.Close() + r := NewRobots(&http.Client{Timeout: 2 * time.Second}, "U/1") + ok, _, err := r.Allowed(t.Context(), ts.URL+"/anything") + if err != nil { + t.Fatalf("Allowed: %v", err) + } + if !ok { + t.Errorf("404 robots.txt should leave host fully allowed") + } +} + +// Quick sanity: parseRobots tolerates a body with comments + blank lines +// + malformed lines (already covered in robots_test.go but exercises a +// different combination here). 
+func TestParseRobotsCornerCases(t *testing.T) { + body := strings.Join([]string{ + "# leading comment", + "", + "User-agent: SpecificBot", + "User-agent: AnotherBot", + "Disallow: /no", + "Allow: /no/yes", + "Crawl-delay: 5", + "", + "User-agent: *", + "Disallow: /", + "", + "Sitemap: https://x/sm1.xml", + "Sitemap: https://x/sm2.xml", + "unknown-directive: ignored", + "no-colon-line", + }, "\n") + rules := parseRobots(body) + if len(rules.sitemaps) != 2 { + t.Errorf("sitemaps: got %d want 2", len(rules.sitemaps)) + } + if len(rules.groups) == 0 { + t.Errorf("expected groups parsed") + } +} diff --git a/internal/embed/wrappers_test.go b/internal/embed/wrappers_test.go new file mode 100644 index 0000000..d81e83f --- /dev/null +++ b/internal/embed/wrappers_test.go @@ -0,0 +1,339 @@ +package embed + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" +) + +// fakeEmbedder is a deterministic in-memory Embedder for testing wrappers. +type fakeEmbedder struct { + model string + dim int + calls atomic.Int64 + totalIn atomic.Int64 // total input texts seen across calls + err error + delay time.Duration + maxConc int32 // tracks max concurrency + curConc int32 // current in-flight calls + concMu sync.Mutex +} + +func (f *fakeEmbedder) Model() string { return f.model } +func (f *fakeEmbedder) Dim() int { return f.dim } + +func (f *fakeEmbedder) Embed(ctx context.Context, texts []string) ([][]float32, error) { + cur := atomic.AddInt32(&f.curConc, 1) + defer atomic.AddInt32(&f.curConc, -1) + f.concMu.Lock() + if cur > f.maxConc { + f.maxConc = cur + } + f.concMu.Unlock() + + f.calls.Add(1) + f.totalIn.Add(int64(len(texts))) + if f.delay > 0 { + select { + case <-time.After(f.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + if f.err != nil { + return nil, f.err + } + out := make([][]float32, len(texts)) + for i, t := range texts { + // Deterministic per-text vector: first byte of text as float, rest zero. 
+ v := make([]float32, f.dim) + if len(t) > 0 { + v[0] = float32(t[0]) + } + out[i] = v + } + return out, nil +} + +// --- OpenAIClient pure-method tests --- + +func TestOpenAIClientModelAndDim(t *testing.T) { + c := NewOpenAIClient("k", "http://x", "test-model", 768) + if c.Model() != "test-model" { + t.Errorf("Model: %q", c.Model()) + } + if c.Dim() != 768 { + t.Errorf("Dim: %d", c.Dim()) + } +} + +func TestNewOpenAIChatURLNormalization(t *testing.T) { + cases := []struct { + in, want string + }{ + {"", "https://api.openai.com/v1/chat/completions"}, + {"https://x/v1", "https://x/v1/chat/completions"}, + {"https://x/v1/", "https://x/v1/chat/completions"}, + {"https://x/v1/chat/completions", "https://x/v1/chat/completions"}, + } + for _, c := range cases { + got := NewOpenAIChat("k", c.in, "model-id") + if got.URL != c.want { + t.Errorf("NewOpenAIChat(%q).URL: got %q want %q", c.in, got.URL, c.want) + } + if got.Model() != "model-id" { + t.Errorf("Model: %q", got.Model()) + } + } +} + +// --- RoundRobinEmbedder --- + +func TestNewRoundRobinEmbedderEmpty(t *testing.T) { + if got := NewRoundRobinEmbedder(nil); got != nil { + t.Errorf("empty: got %v want nil", got) + } + if got := NewRoundRobinEmbedder([]Embedder{}); got != nil { + t.Errorf("empty slice: got %v want nil", got) + } +} + +func TestNewRoundRobinEmbedderSingle(t *testing.T) { + inner := &fakeEmbedder{model: "x", dim: 4} + got := NewRoundRobinEmbedder([]Embedder{inner}) + // Single element returns the inner unwrapped. 
+ if _, isRR := got.(*RoundRobinEmbedder); isRR { + t.Errorf("single-element should NOT wrap as RoundRobinEmbedder") + } + if got.Model() != "x" { + t.Errorf("Model: %q", got.Model()) + } +} + +func TestRoundRobinEmbedderDistributes(t *testing.T) { + a := &fakeEmbedder{model: "m", dim: 4} + b := &fakeEmbedder{model: "m", dim: 4} + c := &fakeEmbedder{model: "m", dim: 4} + + rr := NewRoundRobinEmbedder([]Embedder{a, b, c}) + if rr.Model() != "m" || rr.Dim() != 4 { + t.Errorf("Model/Dim from inners[0]: %q/%d", rr.Model(), rr.Dim()) + } + + ctx := context.Background() + for i := 0; i < 6; i++ { + if _, err := rr.Embed(ctx, []string{"t"}); err != nil { + t.Fatalf("Embed: %v", err) + } + } + // 6 calls across 3 backends → each gets exactly 2. + if a.calls.Load() != 2 || b.calls.Load() != 2 || c.calls.Load() != 2 { + t.Errorf("distribution: a=%d b=%d c=%d", a.calls.Load(), b.calls.Load(), c.calls.Load()) + } +} + +// --- ThrottledEmbedder --- + +func TestNewThrottledEmbedderMaxZero(t *testing.T) { + inner := &fakeEmbedder{model: "x", dim: 4} + got := NewThrottledEmbedder(inner, 0) + if _, ok := got.(*ThrottledEmbedder); ok { + t.Errorf("max<=0 should return inner unwrapped") + } + got = NewThrottledEmbedder(inner, -5) + if _, ok := got.(*ThrottledEmbedder); ok { + t.Errorf("max<0 should return inner unwrapped") + } +} + +func TestThrottledEmbedderModelDim(t *testing.T) { + inner := &fakeEmbedder{model: "x", dim: 4} + w := NewThrottledEmbedder(inner, 2) + if w.Model() != "x" || w.Dim() != 4 { + t.Errorf("delegation broken: %q/%d", w.Model(), w.Dim()) + } +} + +func TestThrottledEmbedderCapsConcurrency(t *testing.T) { + inner := &fakeEmbedder{model: "x", dim: 4, delay: 30 * time.Millisecond} + w := NewThrottledEmbedder(inner, 2) + ctx := context.Background() + + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = w.Embed(ctx, []string{"hi"}) + }() + } + wg.Wait() + inner.concMu.Lock() + got := inner.maxConc + 
inner.concMu.Unlock() + if got > 2 { + t.Errorf("concurrency cap broken: maxConc=%d want <=2", got) + } + if inner.calls.Load() != 8 { + t.Errorf("all calls should complete: got %d want 8", inner.calls.Load()) + } +} + +func TestThrottledEmbedderCancelDuringQueue(t *testing.T) { + // Saturate with 1 in-flight, then attempt a 2nd that will hang on the sem. + inner := &fakeEmbedder{model: "x", dim: 4, delay: 100 * time.Millisecond} + w := NewThrottledEmbedder(inner, 1) + ctx := context.Background() + go func() { + _, _ = w.Embed(ctx, []string{"first"}) + }() + // Wait a beat so the goroutine actually grabs the slot. + time.Sleep(10 * time.Millisecond) + + ctx2, cancel := context.WithCancel(context.Background()) + cancel() // immediately cancel + _, err := w.Embed(ctx2, []string{"second"}) + if err == nil { + t.Errorf("expected ctx.Err from canceled context") + } +} + +// --- BatchingEmbedder --- + +func TestNewBatchingEmbedderDefaults(t *testing.T) { + inner := &fakeEmbedder{model: "x", dim: 4} + b := NewBatchingEmbedder(inner, 0, 0) + defer b.Close() + if b.maxBatch != 64 { + t.Errorf("maxBatch default: got %d", b.maxBatch) + } + if b.maxWait <= 0 { + t.Errorf("maxWait default: got %v", b.maxWait) + } + if b.Model() != "x" || b.Dim() != 4 { + t.Errorf("delegation: %q/%d", b.Model(), b.Dim()) + } +} + +func TestBatchingEmbedderSingleCall(t *testing.T) { + inner := &fakeEmbedder{model: "x", dim: 4} + b := NewBatchingEmbedder(inner, 32, 20*time.Millisecond) + defer b.Close() + + vecs, err := b.Embed(context.Background(), []string{"hi", "bye"}) + if err != nil { + t.Fatalf("Embed: %v", err) + } + if len(vecs) != 2 { + t.Errorf("vecs len: got %d want 2", len(vecs)) + } + if vecs[0][0] != float32('h') || vecs[1][0] != float32('b') { + t.Errorf("wrong content order: %v %v", vecs[0][0], vecs[1][0]) + } +} + +func TestBatchingEmbedderEmpty(t *testing.T) { + inner := &fakeEmbedder{model: "x", dim: 4} + b := NewBatchingEmbedder(inner, 32, 5*time.Millisecond) + defer b.Close() 
+ vecs, err := b.Embed(context.Background(), nil) + if err != nil { + t.Errorf("empty: %v", err) + } + if vecs != nil { + t.Errorf("nil input should return nil vecs, got %v", vecs) + } +} + +func TestBatchingEmbedderCoalesces(t *testing.T) { + inner := &fakeEmbedder{model: "x", dim: 4} + b := NewBatchingEmbedder(inner, 32, 50*time.Millisecond) + defer b.Close() + + ctx := context.Background() + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + i := i + go func() { + defer wg.Done() + _, _ = b.Embed(ctx, []string{string(rune('a' + i))}) + }() + } + wg.Wait() + // All 5 callers issued 1 text each. Coalesced into 1-2 batches. + calls := inner.calls.Load() + if calls > 3 { + t.Errorf("expected coalesced (<=3 calls), got %d", calls) + } + if inner.totalIn.Load() != 5 { + t.Errorf("total texts: got %d want 5", inner.totalIn.Load()) + } +} + +func TestBatchingEmbedderInnerError(t *testing.T) { + inner := &fakeEmbedder{model: "x", dim: 4, err: errors.New("boom")} + b := NewBatchingEmbedder(inner, 32, 5*time.Millisecond) + defer b.Close() + _, err := b.Embed(context.Background(), []string{"x"}) + if err == nil { + t.Errorf("expected inner error to propagate") + } +} + +// --- CachedEmbedder --- + +func TestCachedEmbedderModelDim(t *testing.T) { + inner := &fakeEmbedder{model: "m", dim: 8} + c := NewCachedEmbedder(inner, "") + if c.Model() != "m" || c.Dim() != 8 { + t.Errorf("delegation: %q/%d", c.Model(), c.Dim()) + } +} + +func TestCachedEmbedderHitsMisses(t *testing.T) { + inner := &fakeEmbedder{model: "m", dim: 4} + c := NewCachedEmbedder(inner, t.TempDir()) + + // Initial counters. + if c.Hits() != 0 || c.Misses() != 0 { + t.Errorf("initial counters non-zero: hits=%d misses=%d", c.Hits(), c.Misses()) + } + + ctx := context.Background() + // First call: 2 misses. 
+ if _, err := c.Embed(ctx, []string{"alpha", "beta"}); err != nil { + t.Fatalf("Embed: %v", err) + } + if c.Misses() != 2 { + t.Errorf("after first: misses=%d want 2", c.Misses()) + } + + // Second call with same texts: 2 hits. + if _, err := c.Embed(ctx, []string{"alpha", "beta"}); err != nil { + t.Fatalf("Embed: %v", err) + } + if c.Hits() != 2 { + t.Errorf("after second: hits=%d want 2", c.Hits()) + } + + // Mixed: 1 hit, 1 miss. + if _, err := c.Embed(ctx, []string{"alpha", "gamma"}); err != nil { + t.Fatalf("Embed: %v", err) + } + if c.Hits() != 3 || c.Misses() != 3 { + t.Errorf("after mixed: hits=%d misses=%d", c.Hits(), c.Misses()) + } +} + +func TestCachedEmbedderNoDir(t *testing.T) { + // Empty dir disables persistence; calls still work. + inner := &fakeEmbedder{model: "m", dim: 4} + c := NewCachedEmbedder(inner, "") + if _, err := c.Embed(context.Background(), []string{"x"}); err != nil { + t.Errorf("empty dir Embed: %v", err) + } +} diff --git a/internal/eval/eval_io_test.go b/internal/eval/eval_io_test.go new file mode 100644 index 0000000..4133637 --- /dev/null +++ b/internal/eval/eval_io_test.go @@ -0,0 +1,303 @@ +package eval + +import ( + "context" + "encoding/json" + "errors" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +// fakeRetriever returns a fixed response per query. 
+type fakeRetriever struct { + bestOf map[string][]string + fail map[string]error +} + +func (f *fakeRetriever) Search(_ context.Context, q string, _ int) ([]string, error) { + if err, ok := f.fail[q]; ok { + return nil, err + } + return f.bestOf[q], nil +} + +func TestLoadQuerySet(t *testing.T) { + qs := QuerySet{ + Name: "fixture", + Queries: []Query{ + {Text: "alpha", Relevant: []string{"u1", "u2"}}, + }, + } + b, _ := json.Marshal(qs) + dir := t.TempDir() + p := filepath.Join(dir, "qs.json") + if err := os.WriteFile(p, b, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + got, err := LoadQuerySet(p) + if err != nil { + t.Fatalf("Load: %v", err) + } + if got.Name != "fixture" || len(got.Queries) != 1 || got.Queries[0].Text != "alpha" { + t.Errorf("bad roundtrip: %+v", got) + } +} + +func TestLoadQuerySetMissingFile(t *testing.T) { + if _, err := LoadQuerySet(filepath.Join(t.TempDir(), "nope.json")); err == nil { + t.Errorf("missing file: expected error") + } +} + +func TestLoadQuerySetBadJSON(t *testing.T) { + p := filepath.Join(t.TempDir(), "bad.json") + _ = os.WriteFile(p, []byte("{garbage"), 0o644) + if _, err := LoadQuerySet(p); err == nil { + t.Errorf("bad JSON: expected parse error") + } +} + +func TestLoadCorpus(t *testing.T) { + c := Corpus{Docs: []CorpusDoc{{URL: "u", Title: "t", Text: "body"}}} + b, _ := json.Marshal(c) + p := filepath.Join(t.TempDir(), "c.json") + if err := os.WriteFile(p, b, 0o644); err != nil { + t.Fatalf("write: %v", err) + } + got, err := LoadCorpus(p) + if err != nil { + t.Fatalf("Load: %v", err) + } + if len(got.Docs) != 1 || got.Docs[0].URL != "u" { + t.Errorf("bad roundtrip: %+v", got) + } +} + +func TestLoadCorpusMissingFile(t *testing.T) { + if _, err := LoadCorpus(filepath.Join(t.TempDir(), "nope.json")); err == nil { + t.Errorf("missing file: expected error") + } +} + +func TestLoadCorpusBadJSON(t *testing.T) { + p := filepath.Join(t.TempDir(), "bad.json") + _ = os.WriteFile(p, []byte("not json"), 0o644) + if _, err 
:= LoadCorpus(p); err == nil { + t.Errorf("bad JSON: expected parse error") + } +} + +func TestRunNoParaphrases(t *testing.T) { + qs := &QuerySet{ + Name: "basic", + Queries: []Query{ + {Text: "q1", Relevant: []string{"u1"}}, + {Text: "q2", Relevant: []string{"u2", "u3"}}, + }, + } + r := &fakeRetriever{bestOf: map[string][]string{ + "q1": {"u1", "noise"}, + "q2": {"u2", "noise", "u3"}, + }} + sum, err := Run(context.Background(), qs, r) + if err != nil { + t.Fatalf("Run: %v", err) + } + if sum.NumQueries != 2 || len(sum.PerQuery) != 2 { + t.Fatalf("bad summary: %+v", sum) + } + // q1: u1 at pos 0 → R@1 = 1.0 + if sum.PerQuery[0].Metrics.Recall1 != 1.0 { + t.Errorf("q1 R@1: got %v want 1.0", sum.PerQuery[0].Metrics.Recall1) + } + if sum.Name != "basic" { + t.Errorf("name: %s", sum.Name) + } +} + +func TestRunRetrieverError(t *testing.T) { + qs := &QuerySet{ + Queries: []Query{{Text: "q1", Relevant: []string{"u"}}}, + } + r := &fakeRetriever{fail: map[string]error{"q1": errors.New("boom")}} + if _, err := Run(context.Background(), qs, r); err == nil { + t.Errorf("expected error") + } +} + +func TestRunWithParaphrases(t *testing.T) { + qs := &QuerySet{ + Queries: []Query{ + { + Text: "main", + Paraphrases: []string{"p1", "p2"}, + Relevant: []string{"target"}, + }, + }, + } + // Main misses but paraphrase finds target — fusion should rescue. 
+ r := &fakeRetriever{bestOf: map[string][]string{ + "main": {"noise1", "noise2"}, + "p1": {"target", "noise3"}, + "p2": {"noise4", "target"}, + }} + sum, err := Run(context.Background(), qs, r) + if err != nil { + t.Fatalf("Run: %v", err) + } + if sum.PerQuery[0].Metrics.Recall10 != 1.0 { + t.Errorf("paraphrase fusion R@10: got %v want 1.0", sum.PerQuery[0].Metrics.Recall10) + } +} + +func TestRunMainErrorAborts(t *testing.T) { + qs := &QuerySet{ + Queries: []Query{ + {Text: "main", Paraphrases: []string{"p1"}, Relevant: []string{"x"}}, + }, + } + r := &fakeRetriever{fail: map[string]error{"main": errors.New("nope")}} + if _, err := Run(context.Background(), qs, r); err == nil { + t.Errorf("expected main error to propagate") + } +} + +func TestRunParaphraseErrorIsTolerated(t *testing.T) { + qs := &QuerySet{ + Queries: []Query{ + {Text: "main", Paraphrases: []string{"p1"}, Relevant: []string{"u"}}, + }, + } + r := &fakeRetriever{ + bestOf: map[string][]string{"main": {"u"}}, + fail: map[string]error{"p1": errors.New("transient")}, + } + sum, err := Run(context.Background(), qs, r) + if err != nil { + t.Fatalf("paraphrase error should not abort: %v", err) + } + if sum.PerQuery[0].Metrics.Recall10 == 0 { + t.Errorf("main-only RRF should still hit: %+v", sum.PerQuery[0].Metrics) + } +} + +func TestRRFFuse(t *testing.T) { + lists := [][]string{ + {"a", "b", "c"}, + {"b", "a", "d"}, + } + out := rrfFuse(lists, 3, 60) + if len(out) != 3 { + t.Fatalf("len: got %d want 3", len(out)) + } + // Both a and b score 1/61 + 1/62 ≈ 0.0325; their order is implementation + // dependent because of map iteration. Just assert they're both in top-2. + top2 := map[string]bool{out[0]: true, out[1]: true} + if !top2["a"] || !top2["b"] { + t.Errorf("expected a and b in top-2, got %v", out) + } +} + +func TestRRFFuseDefaultK(t *testing.T) { + // rrfK<=0 should fall back to 60 internally. 
+ out := rrfFuse([][]string{{"x", "y"}}, 0, 0) + if len(out) != 2 { + t.Errorf("k<=0 should default to len(pairs); got %d", len(out)) + } +} + +func TestScore(t *testing.T) { + m := score([]string{"a", "b", "c"}, []string{"a", "c"}) + if m.Recall1 != 0.5 { + t.Errorf("R@1: got %v want 0.5", m.Recall1) + } + if m.MRR10 != 1.0 { + t.Errorf("MRR: got %v want 1.0", m.MRR10) + } +} + +func TestPrintTable(t *testing.T) { + sum := &Summary{ + Name: "set", + When: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC), + PerQuery: []PerQuery{ + {Query: "short", Metrics: Metrics{Recall10: 1.0, NDCG10: 1.0}}, + {Query: "this is a fairly long query string that should be truncated when printed in the table", Metrics: Metrics{Recall10: 0.5, NDCG10: 0.25}}, + }, + Mean: Metrics{Recall10: 0.75}, + } + out := PrintTable(sum) + if !strings.Contains(out, "MEAN") { + t.Errorf("output should contain MEAN row, got: %s", out) + } + if !strings.Contains(out, "set") { + t.Errorf("output should contain set name, got: %s", out) + } + if !strings.Contains(out, "..") { + t.Errorf("long query should be truncated with ..") + } +} + +func TestSaveLoadSummaryRoundtrip(t *testing.T) { + sum := &Summary{ + Name: "roundtrip", + When: time.Now().UTC(), + NumQueries: 1, + PerQuery: []PerQuery{{Query: "q", Got: []string{"a"}, Relevant: []string{"a"}, Metrics: Metrics{Recall1: 1}}}, + Mean: Metrics{Recall1: 1}, + } + p := filepath.Join(t.TempDir(), "sum.json") + if err := SaveSummary(sum, p); err != nil { + t.Fatalf("save: %v", err) + } + got, err := LoadSummary(p) + if err != nil { + t.Fatalf("load: %v", err) + } + if got.Name != sum.Name || got.NumQueries != 1 { + t.Errorf("bad roundtrip: %+v", got) + } + if got.Mean.Recall1 != 1 { + t.Errorf("mean lost: %+v", got.Mean) + } +} + +func TestLoadSummaryMissingFile(t *testing.T) { + if _, err := LoadSummary(filepath.Join(t.TempDir(), "nope.json")); err == nil { + t.Errorf("missing file: expected error") + } +} + +func TestLoadSummaryBadJSON(t *testing.T) { + p := 
filepath.Join(t.TempDir(), "bad.json") + _ = os.WriteFile(p, []byte("bad"), 0o644) + if _, err := LoadSummary(p); err == nil { + t.Errorf("bad JSON: expected error") + } +} + +func TestSaveSummaryWriteError(t *testing.T) { + // Write into a path whose parent does not exist. + bad := filepath.Join(t.TempDir(), "missing-dir", "x", "y.json") + if err := SaveSummary(&Summary{Name: "x"}, bad); err == nil { + t.Errorf("expected write error to nonexistent dir") + } +} + +func TestDiff(t *testing.T) { + b := &Summary{Mean: Metrics{Recall1: 0.5, Recall3: 0.6, Recall10: 0.7, MRR10: 0.5, NDCG10: 0.6}} + c := &Summary{Mean: Metrics{Recall1: 0.7, Recall3: 0.5, Recall10: 0.7, MRR10: 0.6, NDCG10: 0.5}} + out := Diff(b, c) + if !strings.Contains(out, "+0.200") { + t.Errorf("R@1 delta should be +0.200, got: %s", out) + } + if !strings.Contains(out, "-0.100") { + t.Errorf("R@3 delta should be -0.100, got: %s", out) + } + if !strings.Contains(out, "baseline") || !strings.Contains(out, "current") { + t.Errorf("output should have headers, got: %s", out) + } +} diff --git a/internal/store/pebble_uncovered_test.go b/internal/store/pebble_uncovered_test.go new file mode 100644 index 0000000..1307ac0 --- /dev/null +++ b/internal/store/pebble_uncovered_test.go @@ -0,0 +1,517 @@ +package store + +import ( + "context" + "path/filepath" + "strings" + "testing" + "time" +) + +func TestPebbleMetrics(t *testing.T) { + p := newPebbleStore(t) + m := p.Metrics() + if m == nil { + t.Errorf("Metrics returned nil") + } +} + +func TestPebbleCheckpoint(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + _, _ = p.UpsertDocument(ctx, &Document{URL: "https://x/y", FetchedAt: time.Now()}) + dest := filepath.Join(t.TempDir(), "ckpt") + if err := p.Checkpoint(dest); err != nil { + t.Fatalf("Checkpoint: %v", err) + } + // Sanity-check the checkpoint directory exists and contains something. 
+ p2, err := OpenPebble(dest) + if err != nil { + t.Fatalf("OpenPebble(ckpt): %v", err) + } + defer p2.Close() + got, err := p2.GetDocByURL(ctx, "https://x/y") + if err != nil { + t.Fatalf("read from checkpoint: %v", err) + } + if got.URL != "https://x/y" { + t.Errorf("checkpoint roundtrip mismatch: %+v", got) + } +} + +func TestPebbleVectorMetaRoundtrip(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + + // Initial miss → ok=false, no error. + _, ok, err := p.GetVectorMeta(ctx) + if err != nil { + t.Fatalf("initial GetVectorMeta: %v", err) + } + if ok { + t.Errorf("expected miss on fresh store") + } + + blob := []byte{0x01, 0x02, 0x03, 0x04, 0x05} + if err := p.PutVectorMeta(ctx, blob); err != nil { + t.Fatalf("PutVectorMeta: %v", err) + } + got, ok, err := p.GetVectorMeta(ctx) + if err != nil { + t.Fatalf("GetVectorMeta after put: %v", err) + } + if !ok { + t.Errorf("expected hit") + } + if string(got) != string(blob) { + t.Errorf("blob roundtrip: got %v want %v", got, blob) + } +} + +func TestPebbleVectorNodeRoundtrip(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + + if err := p.PutVectorNode(ctx, 42, []byte("nodepayload")); err != nil { + t.Fatalf("PutVectorNode: %v", err) + } + + // Iterate and confirm we see it. + found := false + if err := p.IterateVectorNodes(ctx, func(id uint64, blob []byte) bool { + if id == 42 && string(blob) == "nodepayload" { + found = true + } + return true + }); err != nil { + t.Fatalf("IterateVectorNodes: %v", err) + } + if !found { + t.Errorf("expected to find node 42") + } +} + +func TestPebbleVectorNodesBatch(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + + // Empty input is a no-op. 
+ if err := p.PutVectorNodesBatch(ctx, nil); err != nil { + t.Errorf("empty batch: %v", err) + } + + entries := []VectorNodeEntry{ + {ID: 1, Blob: []byte("one")}, + {ID: 2, Blob: []byte("two")}, + {ID: 3, Blob: []byte("three")}, + } + if err := p.PutVectorNodesBatch(ctx, entries); err != nil { + t.Fatalf("PutVectorNodesBatch: %v", err) + } + + seen := map[uint64]string{} + _ = p.IterateVectorNodes(ctx, func(id uint64, blob []byte) bool { + seen[id] = string(blob) + return true + }) + if len(seen) != 3 || seen[1] != "one" || seen[2] != "two" || seen[3] != "three" { + t.Errorf("batch roundtrip: %+v", seen) + } +} + +func TestPebbleClearVectorFamily(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + _ = p.PutVectorMeta(ctx, []byte("meta")) + _ = p.PutVectorNode(ctx, 1, []byte("one")) + + if err := p.ClearVectorFamily(ctx); err != nil { + t.Fatalf("ClearVectorFamily: %v", err) + } + if _, ok, _ := p.GetVectorMeta(ctx); ok { + t.Errorf("VectorMeta should be cleared") + } + count := 0 + _ = p.IterateVectorNodes(ctx, func(_ uint64, _ []byte) bool { + count++ + return true + }) + if count != 0 { + t.Errorf("expected 0 vector nodes after clear, got %d", count) + } +} + +func TestPebblePQCodebookRoundtrip(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + + _, ok, err := p.GetPQCodebook(ctx) + if err != nil { + t.Fatalf("initial GetPQCodebook: %v", err) + } + if ok { + t.Errorf("expected miss on fresh store") + } + + blob := []byte("the-codebook-blob") + if err := p.PutPQCodebook(ctx, blob); err != nil { + t.Fatalf("PutPQCodebook: %v", err) + } + got, ok, err := p.GetPQCodebook(ctx) + if err != nil { + t.Fatalf("GetPQCodebook after put: %v", err) + } + if !ok || string(got) != string(blob) { + t.Errorf("roundtrip: ok=%v got=%v", ok, string(got)) + } +} + +func TestPebblePQCodesBatch(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + + if err := p.PutPQCodesBatch(ctx, nil); err != nil { + t.Errorf("empty batch: 
%v", err) + } + + entries := []PQCodeEntry{ + {ID: 10, Blob: []byte("code-10")}, + {ID: 20, Blob: []byte("code-20")}, + } + if err := p.PutPQCodesBatch(ctx, entries); err != nil { + t.Fatalf("PutPQCodesBatch: %v", err) + } + seen := map[uint64]string{} + _ = p.IteratePQCodes(ctx, func(id uint64, blob []byte) bool { + seen[id] = string(blob) + return true + }) + if len(seen) != 2 || seen[10] != "code-10" || seen[20] != "code-20" { + t.Errorf("PQ batch roundtrip: %+v", seen) + } +} + +func TestPebbleClearPQFamily(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + _ = p.PutPQCodebook(ctx, []byte("cb")) + _ = p.PutPQCodesBatch(ctx, []PQCodeEntry{{ID: 1, Blob: []byte("c1")}}) + + if err := p.ClearPQFamily(ctx); err != nil { + t.Fatalf("ClearPQFamily: %v", err) + } + if _, ok, _ := p.GetPQCodebook(ctx); ok { + t.Errorf("codebook should be cleared") + } + count := 0 + _ = p.IteratePQCodes(ctx, func(_ uint64, _ []byte) bool { + count++ + return true + }) + if count != 0 { + t.Errorf("expected 0 PQ codes after clear, got %d", count) + } +} + +func TestPebbleIteratePQCodesStops(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + _ = p.PutPQCodesBatch(ctx, []PQCodeEntry{ + {ID: 1, Blob: []byte("a")}, + {ID: 2, Blob: []byte("b")}, + {ID: 3, Blob: []byte("c")}, + }) + count := 0 + _ = p.IteratePQCodes(ctx, func(_ uint64, _ []byte) bool { + count++ + return count < 2 // stop after second + }) + if count != 2 { + t.Errorf("iterator stop didn't honor false return: got %d", count) + } +} + +func TestPebbleUpsertDocumentBatch(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + + // Empty input. 
+ ids, err := p.UpsertDocumentBatch(ctx, nil) + if err != nil || ids != nil { + t.Errorf("empty batch: got ids=%v err=%v", ids, err) + } + + docs := []*Document{ + {URL: "https://a.com/1", Domain: "a.com", Title: "A1", Text: "body1", FetchedAt: time.Now()}, + {URL: "https://a.com/2", Domain: "a.com", Title: "A2", Text: "body2", FetchedAt: time.Now()}, + {URL: "https://b.com/1", Domain: "b.com", Title: "B1", Text: "body3", FetchedAt: time.Now()}, + } + ids, err = p.UpsertDocumentBatch(ctx, docs) + if err != nil { + t.Fatalf("UpsertDocumentBatch: %v", err) + } + if len(ids) != 3 { + t.Fatalf("ids: got %d want 3", len(ids)) + } + for i, id := range ids { + if id <= 0 { + t.Errorf("ids[%d] <= 0", i) + } + } + + // Re-upsert returns same IDs. + ids2, err := p.UpsertDocumentBatch(ctx, docs) + if err != nil { + t.Fatalf("re-upsert: %v", err) + } + for i, id := range ids2 { + if id != ids[i] { + t.Errorf("re-upsert ids[%d]: got %d want %d", i, id, ids[i]) + } + } +} + +func TestPebbleUpsertDocumentBatchEmptyURLRejected(t *testing.T) { + p := newPebbleStore(t) + docs := []*Document{ + {URL: "https://ok", FetchedAt: time.Now()}, + {URL: "", FetchedAt: time.Now()}, // bad + } + if _, err := p.UpsertDocumentBatch(context.Background(), docs); err == nil { + t.Errorf("expected error for empty URL in batch") + } +} + +func TestPebbleIterDocsLite(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + urls := []string{"https://a/1", "https://a/2", "https://b/1"} + for _, u := range urls { + if _, err := p.UpsertDocument(ctx, &Document{URL: u, FetchedAt: time.Now()}); err != nil { + t.Fatalf("upsert %s: %v", u, err) + } + } + + seen := map[string]bool{} + if err := p.IterDocsLite(ctx, func(_ int64, url string) error { + seen[url] = true + return nil + }); err != nil { + t.Fatalf("IterDocsLite: %v", err) + } + for _, u := range urls { + if !seen[u] { + t.Errorf("missed url %s in iter", u) + } + } +} + +func TestPebblePurgeFrontierByHost(t *testing.T) { + p := 
newPebbleStore(t) + ctx := context.Background() + + // Empty host is a no-op. + if n, err := p.PurgeFrontierByHost(ctx, ""); err != nil || n != 0 { + t.Errorf("empty host: n=%d err=%v", n, err) + } + + _ = p.PushFrontier(ctx, "https://a.com/1", 0, 1) + _ = p.PushFrontier(ctx, "https://a.com/2", 0, 1) + _ = p.PushFrontier(ctx, "https://b.com/1", 0, 1) + + n, err := p.PurgeFrontierByHost(ctx, "a.com") + if err != nil { + t.Fatalf("PurgeFrontierByHost: %v", err) + } + if n != 2 { + t.Errorf("purged: got %d want 2", n) + } + + // b.com should still be present. + hosts, _ := p.CountQueuedPerHost(ctx, []string{"a.com", "b.com"}) + if hosts["a.com"] != 0 { + t.Errorf("a.com still has %d queued", hosts["a.com"]) + } + if hosts["b.com"] != 1 { + t.Errorf("b.com lost: %d", hosts["b.com"]) + } +} + +func TestPebbleCorpusStats(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + + // Empty store. + sum, count, err := p.CorpusStats(ctx) + if err != nil { + t.Fatalf("CorpusStats empty: %v", err) + } + if sum != 0 || count != 0 { + t.Errorf("empty corpus: sum=%d count=%d", sum, count) + } + + // Index something. + id, _ := p.UpsertDocument(ctx, &Document{URL: "u", FetchedAt: time.Now()}) + tokenize := func(s string) []string { return strings.Fields(s) } + if err := p.IndexDocument(ctx, id, "title here", "the quick brown fox", tokenize, 2); err != nil { + t.Fatalf("IndexDocument: %v", err) + } + sum, count, err = p.CorpusStats(ctx) + if err != nil { + t.Fatalf("CorpusStats: %v", err) + } + // 2 title tokens + 4 body tokens = 6 (titleBoost only affects tf, not doc_len). 
+ if sum != 6 || count != 1 { + t.Errorf("sum=%d count=%d want sum=6 count=1", sum, count) + } +} + +func TestPebbleSumDocLengths(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + tokenize := func(s string) []string { return strings.Fields(s) } + + id1, _ := p.UpsertDocument(ctx, &Document{URL: "u1", FetchedAt: time.Now()}) + id2, _ := p.UpsertDocument(ctx, &Document{URL: "u2", FetchedAt: time.Now()}) + _ = p.IndexDocument(ctx, id1, "", "a b c", tokenize, 1) + _ = p.IndexDocument(ctx, id2, "", "d e", tokenize, 1) + + total, count, err := p.SumDocLengths(ctx) + if err != nil { + t.Fatalf("SumDocLengths: %v", err) + } + if total != 5 || count != 2 { + t.Errorf("total=%d count=%d want total=5 count=2", total, count) + } +} + +func TestPebbleListDomains(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + _, _ = p.UpsertDocument(ctx, &Document{URL: "https://a.com/1", Domain: "a.com", FetchedAt: time.Now()}) + _, _ = p.UpsertDocument(ctx, &Document{URL: "https://a.com/2", Domain: "a.com", FetchedAt: time.Now()}) + _, _ = p.UpsertDocument(ctx, &Document{URL: "https://b.org/1", Domain: "b.org", FetchedAt: time.Now()}) + + list, total, err := p.ListDomains(ctx, "", 0, 10) + if err != nil { + t.Fatalf("ListDomains: %v", err) + } + if total != 2 { + t.Errorf("total: got %d want 2", total) + } + if len(list) != 2 { + t.Fatalf("list len: %d", len(list)) + } + // Sorted desc by count → a.com first. + if list[0].Host != "a.com" || list[0].Count != 2 { + t.Errorf("top: %+v", list[0]) + } + + // Substring filter. + filt, total, err := p.ListDomains(ctx, "b.", 0, 10) + if err != nil { + t.Fatalf("ListDomains filtered: %v", err) + } + if total != 1 || len(filt) != 1 || filt[0].Host != "b.org" { + t.Errorf("filter: total=%d list=%+v", total, filt) + } + + // Pagination beyond total. 
+ page, total, err := p.ListDomains(ctx, "", 100, 10) + if err != nil { + t.Fatalf("ListDomains pagination: %v", err) + } + if len(page) != 0 || total != 2 { + t.Errorf("offset past end: got len=%d total=%d", len(page), total) + } + + // Default limit when limit<=0. + _, _, err = p.ListDomains(ctx, "", 0, 0) + if err != nil { + t.Errorf("default limit: %v", err) + } + + // Negative offset → clamped to 0. + all, _, err := p.ListDomains(ctx, "", -5, 10) + if err != nil { + t.Errorf("negative offset: %v", err) + } + if len(all) != 2 { + t.Errorf("negative offset should return all, got %d", len(all)) + } +} + +func TestPebbleTopDomains(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + _, _ = p.UpsertDocument(ctx, &Document{URL: "https://a.com/1", Domain: "a.com", FetchedAt: time.Now()}) + _, _ = p.UpsertDocument(ctx, &Document{URL: "https://a.com/2", Domain: "a.com", FetchedAt: time.Now()}) + _, _ = p.UpsertDocument(ctx, &Document{URL: "https://b.org/1", Domain: "b.org", FetchedAt: time.Now()}) + + top, err := p.TopDomains(ctx, 5) + if err != nil { + t.Fatalf("TopDomains: %v", err) + } + if len(top) != 2 { + t.Fatalf("len: got %d want 2", len(top)) + } + if top[0].Host != "a.com" || top[0].Count != 2 { + t.Errorf("top: %+v", top[0]) + } + + // Default topN. + if _, err := p.TopDomains(ctx, 0); err != nil { + t.Errorf("default topN: %v", err) + } +} + +func TestPebbleTopQueuedHosts(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + _ = p.PushFrontier(ctx, "https://a.com/1", 0, 1) + _ = p.PushFrontier(ctx, "https://a.com/2", 0, 1) + _ = p.PushFrontier(ctx, "https://b.com/1", 0, 1) + + top, err := p.TopQueuedHosts(ctx, 5) + if err != nil { + t.Fatalf("TopQueuedHosts: %v", err) + } + if len(top) != 2 { + t.Fatalf("len: got %d want 2", len(top)) + } + if top[0].Host != "a.com" || top[0].Count != 2 { + t.Errorf("top: %+v", top[0]) + } + + // Default topN. 
+ if _, err := p.TopQueuedHosts(ctx, 0); err != nil { + t.Errorf("default topN: %v", err) + } +} + +func TestPebbleGetDocMetaMiss(t *testing.T) { + p := newPebbleStore(t) + _, _, ok, err := p.GetDocMeta(context.Background(), 9999) + if err != nil { + t.Errorf("miss: %v", err) + } + if ok { + t.Errorf("expected miss") + } +} + +func TestPebbleGetDocMetaHit(t *testing.T) { + p := newPebbleStore(t) + ctx := context.Background() + id, _ := p.UpsertDocument(ctx, &Document{URL: "u", Title: "T", FetchedAt: time.Now()}) + url, title, ok, err := p.GetDocMeta(ctx, id) + if err != nil { + t.Fatalf("GetDocMeta: %v", err) + } + if !ok { + t.Errorf("expected hit") + } + if url != "u" || title != "T" { + t.Errorf("meta: url=%q title=%q", url, title) + } +} diff --git a/internal/store/store_uncovered_test.go b/internal/store/store_uncovered_test.go new file mode 100644 index 0000000..fc59360 --- /dev/null +++ b/internal/store/store_uncovered_test.go @@ -0,0 +1,327 @@ +package store + +import ( + "context" + "testing" + "time" +) + +func TestStats(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + // Empty store. + st, err := s.Stats(ctx) + if err != nil { + t.Fatalf("Stats empty: %v", err) + } + if st.Documents != 0 || st.Terms != 0 { + t.Errorf("empty stats: got %+v", st) + } + + // One doc → Documents=1, Terms still 0 (we don't index terms in this path). + _, _ = s.UpsertDocument(ctx, &Document{URL: "u1", Source: "t", FetchedAt: time.Now()}) + st, err = s.Stats(ctx) + if err != nil { + t.Fatalf("Stats: %v", err) + } + if st.Documents != 1 { + t.Errorf("Documents: got %d want 1", st.Documents) + } +} + +func TestGetDocTexts(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + // Empty input → empty map, no error. + out, err := s.GetDocTexts(ctx, nil, 0) + if err != nil { + t.Fatalf("empty: %v", err) + } + if len(out) != 0 { + t.Errorf("expected empty result, got %d", len(out)) + } + + // Insert two docs with text. 
+ _, _ = s.UpsertDocument(ctx, &Document{URL: "u1", Text: "hello world this is body one", Source: "t", FetchedAt: time.Now()}) + _, _ = s.UpsertDocument(ctx, &Document{URL: "u2", Text: "shorter", Source: "t", FetchedAt: time.Now()}) + + got, err := s.GetDocTexts(ctx, []string{"u1", "u2", "nonexistent"}, 0) + if err != nil { + t.Fatalf("GetDocTexts: %v", err) + } + if got["u1"] != "hello world this is body one" { + t.Errorf("u1: %q", got["u1"]) + } + if got["u2"] != "shorter" { + t.Errorf("u2: %q", got["u2"]) + } + if _, ok := got["nonexistent"]; ok { + t.Errorf("nonexistent should be absent") + } +} + +func TestGetDocTextsMaxLen(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _, _ = s.UpsertDocument(ctx, &Document{URL: "u", Text: "abcdefghijklmnopqrstuvwxyz", Source: "t", FetchedAt: time.Now()}) + got, err := s.GetDocTexts(ctx, []string{"u"}, 10) + if err != nil { + t.Fatalf("GetDocTexts: %v", err) + } + if len(got["u"]) != 10 { + t.Errorf("maxLen=10 should truncate, got len=%d: %q", len(got["u"]), got["u"]) + } +} + +func TestListDocSitemapEntries(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + now := time.Now().Unix() + + _, _ = s.UpsertDocument(ctx, &Document{URL: "u1", Source: "t", FetchedAt: time.Now(), LastChangedAt: now}) + _, _ = s.UpsertDocument(ctx, &Document{URL: "u2", Source: "t", FetchedAt: time.Now()}) + + entries, err := s.ListDocSitemapEntries(ctx, 100) + if err != nil { + t.Fatalf("ListDocSitemapEntries: %v", err) + } + if len(entries) != 2 { + t.Fatalf("entries: got %d want 2", len(entries)) + } + // Ordering is by id ASC: u1 first, u2 second. 
+ if entries[0].URL != "u1" || entries[1].URL != "u2" { + t.Errorf("order: %+v", entries) + } + if entries[0].LastChangedAt.IsZero() { + t.Errorf("u1 should have LastChangedAt set") + } + if !entries[1].LastChangedAt.IsZero() { + t.Errorf("u2 should have zero LastChangedAt") + } +} + +func TestListDocSitemapEntriesNoLimit(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + _, _ = s.UpsertDocument(ctx, &Document{URL: "u", Source: "t", FetchedAt: time.Now()}) + + entries, err := s.ListDocSitemapEntries(ctx, 0) + if err != nil { + t.Fatalf("ListDocSitemapEntries: %v", err) + } + if len(entries) != 1 { + t.Errorf("limit=0 (no limit): got %d want 1", len(entries)) + } +} + +func TestCountByDomain(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + // Three docs across two domains. + _, _ = s.UpsertDocument(ctx, &Document{URL: "https://a.com/1", Domain: "a.com", Source: "t", FetchedAt: time.Now()}) + _, _ = s.UpsertDocument(ctx, &Document{URL: "https://a.com/2", Domain: "a.com", Source: "t", FetchedAt: time.Now()}) + _, _ = s.UpsertDocument(ctx, &Document{URL: "https://b.com/1", Domain: "b.com", Source: "t", FetchedAt: time.Now()}) + + counts, err := s.CountByDomain(ctx, 10) + if err != nil { + t.Fatalf("CountByDomain: %v", err) + } + if counts["a.com"] != 2 { + t.Errorf("a.com: got %d want 2", counts["a.com"]) + } + if counts["b.com"] != 1 { + t.Errorf("b.com: got %d want 1", counts["b.com"]) + } +} + +func TestCountByDomainDefaultTopN(t *testing.T) { + s := newTestStore(t) + // topN<=0 should use default of 20. 
+ if _, err := s.CountByDomain(context.Background(), 0); err != nil { + t.Errorf("CountByDomain default topN: %v", err) + } +} + +func TestCountPassagesAllModelsAndParaphrases(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + id, _ := s.UpsertDocument(ctx, &Document{URL: "u", Source: "t", FetchedAt: time.Now()}) + + _ = s.UpsertPassage(ctx, &Passage{DocID: id, Offset: 0, Model: "a", Embedding: []float32{1, 0}}) + _ = s.UpsertPassage(ctx, &Passage{DocID: id, Offset: 10, Model: "b", Embedding: []float32{0, 1}}) + + n, err := s.CountPassagesAllModels(ctx) + if err != nil { + t.Fatalf("CountPassagesAllModels: %v", err) + } + if n != 2 { + t.Errorf("got %d want 2", n) + } + + pp, err := s.CountParaphrases(ctx) + if err != nil { + t.Fatalf("CountParaphrases empty: %v", err) + } + if pp != 0 { + t.Errorf("empty paraphrases: got %d", pp) + } + _ = s.SaveParaphrases(ctx, "m", "q", []string{"p1"}) + _ = s.SaveParaphrases(ctx, "m", "q2", []string{"p2"}) + pp, _ = s.CountParaphrases(ctx) + if pp != 2 { + t.Errorf("after save: got %d want 2", pp) + } +} + +func TestCountHyDE(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + n, err := s.CountHyDE(ctx) + if err != nil { + t.Fatalf("empty: %v", err) + } + if n != 0 { + t.Errorf("empty hyde: got %d", n) + } + _ = s.SaveHyDE(ctx, "m", "q", "passage") + n, _ = s.CountHyDE(ctx) + if n != 1 { + t.Errorf("after save: got %d want 1", n) + } +} + +func TestCountDocsWithPublishedAt(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + pub := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + _, _ = s.UpsertDocument(ctx, &Document{URL: "with", Source: "t", FetchedAt: time.Now(), PublishedAt: pub}) + _, _ = s.UpsertDocument(ctx, &Document{URL: "without", Source: "t", FetchedAt: time.Now()}) + + n, err := s.CountDocsWithPublishedAt(ctx) + if err != nil { + t.Fatalf("CountDocsWithPublishedAt: %v", err) + } + if n != 1 { + t.Errorf("got %d want 1", n) + } +} + +func TestListOutcomes(t 
*testing.T) { + s := newTestStore(t) + ctx := context.Background() + + got, err := s.ListOutcomes(ctx, 0) + if err != nil { + t.Fatalf("empty: %v", err) + } + if len(got) != 0 { + t.Errorf("empty outcomes: got %d", len(got)) + } + + for i := 0; i < 3; i++ { + _ = s.RecordOutcome(ctx, &Outcome{ + Query: "q", URL: "u", Score: float64(i), Useful: i%2 == 0, + Source: "test", RecordedAt: time.Now(), + }) + } + + got, err = s.ListOutcomes(ctx, 0) + if err != nil { + t.Fatalf("ListOutcomes: %v", err) + } + if len(got) != 3 { + t.Errorf("listed: got %d want 3", len(got)) + } + + got, _ = s.ListOutcomes(ctx, 1) + if len(got) != 1 { + t.Errorf("limit=1: got %d", len(got)) + } +} + +func TestVacuum(t *testing.T) { + s := newTestStore(t) + if err := s.Vacuum(context.Background()); err != nil { + t.Errorf("Vacuum: %v", err) + } +} + +func TestCountQueuedPerHost(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + // Empty input. + m, err := s.CountQueuedPerHost(ctx, nil) + if err != nil { + t.Fatalf("empty: %v", err) + } + if len(m) != 0 { + t.Errorf("empty hosts: got %d", len(m)) + } + + // Two URLs from same host queued, one from another. + _ = s.PushFrontier(ctx, "https://a.com/x", 0, 1) + _ = s.PushFrontier(ctx, "https://a.com/y", 0, 1) + _ = s.PushFrontier(ctx, "https://b.com/z", 0, 1) + + m, err = s.CountQueuedPerHost(ctx, []string{"a.com", "b.com", "absent.com"}) + if err != nil { + t.Fatalf("CountQueuedPerHost: %v", err) + } + if m["a.com"] != 2 { + t.Errorf("a.com: got %d want 2", m["a.com"]) + } + if m["b.com"] != 1 { + t.Errorf("b.com: got %d want 1", m["b.com"]) + } + if _, ok := m["absent.com"]; ok { + t.Errorf("absent host should not appear") + } +} + +func TestRecrawlURLExisting(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + _ = s.PushFrontier(ctx, "https://x/y", 0, 1.0) + // Fail it so attempts > 0 and status = 'error'. 
+ _ = s.FailFrontier(ctx, "https://x/y", "boom") + + if err := s.RecrawlURL(ctx, "https://x/y"); err != nil { + t.Fatalf("RecrawlURL: %v", err) + } + + // Verify status is back to queued. + var status string + var attempts int + _ = s.DB().QueryRowContext(ctx, `SELECT status, attempts FROM frontier WHERE url=?;`, "https://x/y").Scan(&status, &attempts) + if status != "queued" { + t.Errorf("status: got %q want queued", status) + } + if attempts != 0 { + t.Errorf("attempts: got %d want 0", attempts) + } +} + +func TestRecrawlURLNotPresent(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + + // URL not in frontier — should insert. + if err := s.RecrawlURL(ctx, "https://new/url"); err != nil { + t.Fatalf("RecrawlURL: %v", err) + } + var status string + if err := s.DB().QueryRowContext(ctx, `SELECT status FROM frontier WHERE url=?;`, "https://new/url").Scan(&status); err != nil { + t.Fatalf("query: %v", err) + } + if status != "queued" { + t.Errorf("status: got %q want queued", status) + } +}