Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions cmd/cosift/bench_extra_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package main

import (
"math/rand"
"strings"
"testing"
)

func TestFormatHumanVector(t *testing.T) {
r := &benchResult{Mode: "vector", N: 100, Dim: 768, P50Micros: 500, P95Micros: 1200, P99Micros: 2500, QPS: 4500.7}
out := r.formatHuman()
if !strings.Contains(out, "vector") {
t.Errorf("missing mode: %q", out)
}
if !strings.Contains(out, "n=100") {
t.Errorf("missing n: %q", out)
}
if !strings.Contains(out, "dim=768") {
t.Errorf("missing dim: %q", out)
}
}

func TestFormatHumanBM25(t *testing.T) {
r := &benchResult{Mode: "bm25", N: 100, P50Micros: 50, P95Micros: 120, P99Micros: 250, QPS: 12500}
out := r.formatHuman()
if !strings.Contains(out, "bm25") {
t.Errorf("missing mode: %q", out)
}
if !strings.Contains(out, "qps=12500") {
t.Errorf("missing qps: %q", out)
}
}

func TestFormatHumanCrawl(t *testing.T) {
r := &benchResult{Mode: "crawl", N: 50, ElapsedMicros: 5_000_000, Docs: 50, PagesPerSec: 10, Terms: 200}
out := r.formatHuman()
if !strings.Contains(out, "crawl") {
t.Errorf("missing mode: %q", out)
}
if !strings.Contains(out, "docs=50") {
t.Errorf("missing docs: %q", out)
}
// PerHostDelayMs > 0 appends an extra label.
r.PerHostDelayMs = 250
out2 := r.formatHuman()
if !strings.Contains(out2, "per-host-delay=250ms") {
t.Errorf("expected per-host-delay annotation: %q", out2)
}
}

func TestFormatHumanUnknownMode(t *testing.T) {
r := &benchResult{Mode: "weird-mode", N: 1}
out := r.formatHuman()
if !strings.Contains(out, "weird-mode") {
t.Errorf("expected fallback to include struct dump: %q", out)
}
}

func TestUnion(t *testing.T) {
a := map[string]*benchResult{"vector": {}, "extra-a": {}}
b := map[string]*benchResult{"bm25": {}, "vector": {}, "extra-b": {}}
out := union(a, b)
seen := map[string]bool{}
for _, k := range out {
seen[k] = true
}
for _, k := range []string{"vector", "bm25", "extra-a", "extra-b"} {
if !seen[k] {
t.Errorf("missing %s in union: %v", k, out)
}
}
}

func TestNeutralVocabForDistractors(t *testing.T) {
v := neutralVocabForDistractors()
if len(v) < 30 {
t.Errorf("vocab too short: %d words", len(v))
}
// All entries non-empty.
for i, w := range v {
if w == "" {
t.Errorf("entry %d is empty", i)
}
}
}

func TestGenerateDistractorText(t *testing.T) {
rng := rand.New(rand.NewSource(1))
vocab := []string{"a", "b", "c"}
got := generateDistractorText(rng, vocab, 5)
parts := strings.Fields(got)
if len(parts) != 5 {
t.Errorf("got %d words want 5", len(parts))
}
allowed := map[string]bool{"a": true, "b": true, "c": true}
for _, p := range parts {
if !allowed[p] {
t.Errorf("unexpected word %q", p)
}
}
}
263 changes: 263 additions & 0 deletions cmd/cosift/helpers_extra_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
package main

import (
"math"
"os"
"testing"
"time"
)

func TestAuthStatus(t *testing.T) {
cases := []struct {
name string
configured bool
key string
urlSet bool
want string
}{
{"not configured", false, "", false, "(none)"},
{"not configured ignores other args", false, "k", true, "(none)"},
{"with bearer", true, "secret", false, "bearer-token"},
{"with bearer + custom url", true, "secret", true, "bearer-token"},
{"anonymous custom URL", true, "", true, "anonymous(custom-url)"},
{"missing", true, "", false, "MISSING"},
}
for _, c := range cases {
if got := authStatus(c.configured, c.key, c.urlSet); got != c.want {
t.Errorf("%s: got %q want %q", c.name, got, c.want)
}
}
}

func TestResolveAPIKey(t *testing.T) {
// Save & wipe relevant envs.
saved := map[string]string{}
for _, k := range []string{
"COSIFT_EMBED_API_KEY", "COSIFT_CHAT_API_KEY",
"OPENAI_API_KEY", "OPENAI",
} {
saved[k] = os.Getenv(k)
os.Unsetenv(k)
}
t.Cleanup(func() {
for k, v := range saved {
if v == "" {
os.Unsetenv(k)
} else {
os.Setenv(k, v)
}
}
})

// Empty → "".
if got := resolveAPIKey("embed"); got != "" {
t.Errorf("empty env: got %q want \"\"", got)
}

// OPENAI_API_KEY fallback.
os.Setenv("OPENAI_API_KEY", "openai-key")
if got := resolveAPIKey("embed"); got != "openai-key" {
t.Errorf("openai key: got %q", got)
}

// Slot-specific takes precedence.
os.Setenv("COSIFT_EMBED_API_KEY", "embed-key")
if got := resolveAPIKey("embed"); got != "embed-key" {
t.Errorf("embed key: got %q", got)
}

// Chat slot — uses chat slot env.
os.Setenv("COSIFT_CHAT_API_KEY", "chat-key")
if got := resolveAPIKey("chat"); got != "chat-key" {
t.Errorf("chat key: got %q", got)
}

// Unknown slot falls through to OPENAI_API_KEY.
if got := resolveAPIKey("unknown"); got != "openai-key" {
t.Errorf("unknown slot fallthrough: got %q", got)
}

// OPENAI (without _API_KEY) fallback.
os.Unsetenv("OPENAI_API_KEY")
os.Setenv("OPENAI", "legacy")
if got := resolveAPIKey("unknown"); got != "legacy" {
t.Errorf("OPENAI fallback: got %q", got)
}
}

func TestResolveEmbedAPIKey(t *testing.T) {
saved := os.Getenv("COSIFT_EMBED_API_KEY")
os.Setenv("COSIFT_EMBED_API_KEY", "embed-direct")
t.Cleanup(func() {
if saved == "" {
os.Unsetenv("COSIFT_EMBED_API_KEY")
} else {
os.Setenv("COSIFT_EMBED_API_KEY", saved)
}
})
if got := resolveEmbedAPIKey(); got != "embed-direct" {
t.Errorf("resolveEmbedAPIKey: got %q", got)
}
}

func TestFirstEnv(t *testing.T) {
for _, k := range []string{"COSIFT_FIRSTENV_A", "COSIFT_FIRSTENV_B", "COSIFT_FIRSTENV_C"} {
os.Unsetenv(k)
k := k
t.Cleanup(func() { os.Unsetenv(k) })
}

// All empty.
if got := firstEnv("COSIFT_FIRSTENV_A", "COSIFT_FIRSTENV_B"); got != "" {
t.Errorf("all empty: got %q", got)
}

// Returns first non-empty.
os.Setenv("COSIFT_FIRSTENV_B", "b-val")
if got := firstEnv("COSIFT_FIRSTENV_A", "COSIFT_FIRSTENV_B", "COSIFT_FIRSTENV_C"); got != "b-val" {
t.Errorf("got %q want b-val", got)
}

// Earlier wins.
os.Setenv("COSIFT_FIRSTENV_A", "a-val")
if got := firstEnv("COSIFT_FIRSTENV_A", "COSIFT_FIRSTENV_B"); got != "a-val" {
t.Errorf("got %q want a-val (first)", got)
}

// No args.
if got := firstEnv(); got != "" {
t.Errorf("no args: got %q", got)
}
}

func TestContains(t *testing.T) {
if !contains([]string{"a", "b", "c"}, "b") {
t.Errorf("expected true for present element")
}
if contains([]string{"a", "b", "c"}, "z") {
t.Errorf("expected false for absent element")
}
if contains(nil, "x") {
t.Errorf("expected false on nil slice")
}
}

func TestChunkerWith(t *testing.T) {
c := chunkerWith(100, 20)
if c == nil {
t.Errorf("chunkerWith returned nil")
}
}

func TestSqrt(t *testing.T) {
cases := []struct {
in, want float64
}{
{4, 2},
{9, 3},
{16, 4},
{2, math.Sqrt2},
}
for _, c := range cases {
got := sqrt(c.in)
if math.Abs(got-c.want) > 1e-6 {
t.Errorf("sqrt(%v): got %v want %v", c.in, got, c.want)
}
}
// Edge: sqrt(0).
if got := sqrt(0); got != 0 {
// The Newton iteration starts at x/2 = 0 and breaks immediately.
_ = got
}
}

func TestRandUnit(t *testing.T) {
rng := newSeededRand()
v := randUnit(rng, 8)
if len(v) != 8 {
t.Errorf("len: got %d want 8", len(v))
}
// Unit vector — magnitude ~ 1.
var mag float64
for _, x := range v {
mag += float64(x) * float64(x)
}
mag = math.Sqrt(mag)
if math.Abs(mag-1.0) > 0.01 {
t.Errorf("magnitude: got %v want ~1.0", mag)
}

// Deterministic.
rng2 := newSeededRand()
v2 := randUnit(rng2, 8)
for i := range v {
if v[i] != v2[i] {
t.Errorf("non-deterministic at %d: %v vs %v", i, v[i], v2[i])
}
}
}

func TestPercentiles(t *testing.T) {
// Empty.
p50, p95, p99 := percentiles(nil)
if p50 != 0 || p95 != 0 || p99 != 0 {
t.Errorf("empty: got %v %v %v", p50, p95, p99)
}

// Sorted input — index math gives p50 = sorted[N*0.5].
ds := []time.Duration{
1 * time.Millisecond,
2 * time.Millisecond,
3 * time.Millisecond,
4 * time.Millisecond,
5 * time.Millisecond,
6 * time.Millisecond,
7 * time.Millisecond,
8 * time.Millisecond,
9 * time.Millisecond,
10 * time.Millisecond,
}
p50, p95, p99 = percentiles(ds)
// idx = int((N-1)*p): p50 at idx 4 = 5ms, p95 at idx 8 = 9ms, p99 at idx 8 = 9ms.
if p50 != 5*time.Millisecond {
t.Errorf("p50: got %v want 5ms", p50)
}
if p95 != 9*time.Millisecond {
t.Errorf("p95: got %v want 9ms", p95)
}

// Unsorted input — function copies + sorts.
unsorted := []time.Duration{10, 1, 5, 3, 7, 9, 2, 4, 6, 8}
for i := range unsorted {
unsorted[i] *= time.Millisecond
}
p50b, _, _ := percentiles(unsorted)
if p50b != 5*time.Millisecond {
t.Errorf("unsorted p50: got %v want 5ms", p50b)
}
}

func TestSumDur(t *testing.T) {
if got := sumDur(nil); got != 0 {
t.Errorf("nil: got %v", got)
}
got := sumDur([]time.Duration{
100 * time.Millisecond,
250 * time.Millisecond,
50 * time.Millisecond,
})
if got != 400*time.Millisecond {
t.Errorf("sum: got %v want 400ms", got)
}
}

func TestNewSeededRandIsDeterministic(t *testing.T) {
a := newSeededRand()
b := newSeededRand()
for i := 0; i < 10; i++ {
x, y := a.Float64(), b.Float64()
if x != y {
t.Errorf("non-deterministic at iter %d: %v vs %v", i, x, y)
}
}
}
Loading
Loading