Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions src/lib/knowledge/__tests__/search-vector-shape.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import { describe, it, expect, vi, beforeEach } from "vitest";

// ---------------------------------------------------------------------------
// Phase 0e regression-guard tests for searchKnowledgeBase
// ---------------------------------------------------------------------------
//
// Background: searchKnowledgeBase previously ran two separate Prisma calls —
// await prisma.$executeRaw(`SET LOCAL hnsw.ef_search = …`)
// await prisma.$queryRaw(...)
// — which Prisma could distribute across two pool connections, silently
// reverting the ef_search tuning before the SELECT.
//
// Phase 0e patched this to:
// await prisma.$transaction(async (tx) => {
// await tx.$executeRaw(`SET LOCAL hnsw.ef_search = …`)
// return tx.$queryRaw(...)
// })
//
// These tests verify the SHAPE of the fix:
// - $transaction is used (not direct prisma.$executeRaw)
// - SET LOCAL runs on the tx (not the outer client)
// - the SELECT runs on the same tx
//
// They cannot verify pool-survival semantics — that needs a real Postgres
// integration harness (TODO for Phase 1+).

vi.mock("@/lib/prisma", () => {
  // vitest hoists vi.mock() above the imports, so this factory cannot close
  // over module-level bindings; everything it needs is created right here.
  // The transaction-scoped client is published as `prisma._tx` so tests can
  // reach its spies through the mocked module after importing it.
  const tx = {
    $executeRaw: vi.fn(),
    $queryRaw: vi.fn(),
  };
  type TxCallback = (client: typeof tx) => Promise<unknown>;

  // Stand-in for the callback form of prisma.$transaction: invoke the
  // callback with the tx client and hand its result straight back.
  const $transaction = vi.fn(async (callback: TxCallback) => callback(tx));

  const prismaMock = {
    knowledgeBase: { findUnique: vi.fn() },
    $transaction,
    // Raw helpers on the outer client are the pre-patch code path; the
    // patched implementation is expected to leave these untouched.
    $executeRaw: vi.fn(),
    $queryRaw: vi.fn(),
    _tx: tx,
  };

  return { prisma: prismaMock };
});

// Logger: every level is a bare spy.
vi.mock("@/lib/logger", () => {
  const logger = { info: vi.fn(), warn: vi.fn(), error: vi.fn() };
  return { logger };
});

// Metrics: recordMetric is a bare spy.
vi.mock("@/lib/observability/metrics", () => {
  return { recordMetric: vi.fn() };
});

// Embedding generation resolves to a fixed 1536-element vector of 0.1.
vi.mock("../embeddings", () => {
  const generateEmbedding = vi.fn(async () =>
    Array.from({ length: 1536 }, () => 0.1),
  );
  return { generateEmbedding };
});

// Embedding cache: reads resolve to null, writes resolve to undefined.
// NOTE(review): presumably a null read makes search.ts fall through to
// generateEmbedding — confirm against the implementation.
vi.mock("../embedding-cache", () => {
  const getCachedQueryEmbedding = vi.fn(async () => null);
  const setCachedQueryEmbedding = vi.fn(async () => undefined);
  return { getCachedQueryEmbedding, setCachedQueryEmbedding };
});

import { searchKnowledgeBase } from "../search";
import { prisma } from "@/lib/prisma";

// Typed view over the mocked prisma module so assertions can reach the spies,
// including the transaction client the mock factory publishes as `_tx`.
type Spy = ReturnType<typeof vi.fn>;
type RawQuerySpies = {
  $executeRaw: Spy;
  $queryRaw: Spy;
};
type MockedPrisma = RawQuerySpies & {
  $transaction: Spy;
  _tx: RawQuerySpies;
};

const mockPrisma = prisma as unknown as MockedPrisma;

// Start every test from a known spy state.
beforeEach(() => {
  // Drops recorded calls/results on every spy but keeps implementations, so
  // the $transaction pass-through installed by the mock factory stays intact.
  vi.clearAllMocks();

  // clearAllMocks() does not undo mockImplementation()/mockResolvedValue(),
  // so behaviour one test installs on the tx spies would otherwise leak into
  // the tests that run after it. Both spies are created as bare vi.fn() in
  // the factory, so resetting them simply returns them to "return undefined".
  mockPrisma._tx.$executeRaw.mockReset();
  mockPrisma._tx.$queryRaw.mockReset();

  // Default for every test: the vector query yields no rows.
  mockPrisma._tx.$queryRaw.mockResolvedValue([]);
});

describe("searchKnowledgeBase — Phase 0e $transaction shape", () => {
  // Every case issues the same search; only the spy setup differs.
  const runSearch = () => searchKnowledgeBase("kb-1", "hello world", 5);

  it("opens a $transaction (does not run SET LOCAL on the outer prisma)", async () => {
    await runSearch();

    const { $transaction, $executeRaw, $queryRaw } = mockPrisma;
    expect($transaction).toHaveBeenCalledTimes(1);

    // Any call on the outer client's raw helpers would be the pre-Phase-0e
    // shape, where SET LOCAL and the SELECT were not pinned together.
    expect($executeRaw).not.toHaveBeenCalled();
    expect($queryRaw).not.toHaveBeenCalled();
  });

  it("runs SET LOCAL hnsw.ef_search on the tx, before the SELECT", async () => {
    const txClient = mockPrisma._tx;
    const sequence: string[] = [];

    // Builds a spy implementation that logs `label` and resolves to `value`.
    const recordThenResolve =
      <T>(label: string, value: T) =>
      async () => {
        sequence.push(label);
        return value;
      };

    txClient.$executeRaw.mockImplementation(recordThenResolve("set_local", 0));
    txClient.$queryRaw.mockImplementation(recordThenResolve("select", []));

    await runSearch();

    expect(sequence).toEqual(["set_local", "select"]);
    expect(txClient.$executeRaw).toHaveBeenCalledTimes(1);
    expect(txClient.$queryRaw).toHaveBeenCalledTimes(1);
  });

  it("returns whatever tx.$queryRaw returns, mapped through the result shape", async () => {
    // One raw row as the vector query would produce it.
    const rawRow = {
      id: "chunk-A",
      content: "result content",
      similarity: 0.91,
      sourceId: "src-1",
      sourceName: "Doc",
      sourceType: "pdf",
      metadata: null,
    };
    mockPrisma._tx.$queryRaw.mockResolvedValue([rawRow]);

    const hits = await runSearch();

    // The row's `id` is expected to surface as `chunkId` on the result.
    expect(hits).toHaveLength(1);
    expect(hits[0].chunkId).toBe("chunk-A");
    expect(hits[0].similarity).toBe(0.91);
  });
});

// TODO(rls-phase-1): integration test against a real Postgres harness to
// verify ef_search actually survives onto the SELECT query — mocked tests
// only verify call shape, not the pool-survival semantics being fixed.
41 changes: 24 additions & 17 deletions src/lib/knowledge/search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,31 @@ export async function searchKnowledgeBase(
// ($1 syntax is rejected with "syntax error at or near $1").
// efSearch is always 40 | 60 | 100 — a computed integer, never user input — so
// Prisma.raw() injection is safe here.
await prisma.$executeRaw(Prisma.sql`SET LOCAL hnsw.ef_search = ${Prisma.raw(String(efSearch))}`);

//
// $transaction wrapper: SET LOCAL only persists for the lifetime of the
// current transaction. Without $transaction, Prisma's pool may run the
// SET LOCAL on one connection and the SELECT on another — the ef_search
// tuning would silently revert to the server default. Pin both on the
// same connection by wrapping in a single transaction.
const searchStart = performance.now();
const results = await prisma.$queryRaw<VectorSearchRow[]>(
Prisma.sql`
SELECT
c."id", c."content",
1 - (c."embedding" <=> ${vectorStr}::vector) as similarity,
c."sourceId", s."name" as "sourceName", s."type" as "sourceType", c."metadata"
FROM "KBChunk" c
INNER JOIN "KBSource" s ON c."sourceId" = s."id"
WHERE s."knowledgeBaseId" = ${knowledgeBaseId}
AND s."status" = 'READY'
AND c."embedding" IS NOT NULL
ORDER BY c."embedding" <=> ${vectorStr}::vector
LIMIT ${topK}
`
);
const results = await prisma.$transaction(async (tx) => {
await tx.$executeRaw(Prisma.sql`SET LOCAL hnsw.ef_search = ${Prisma.raw(String(efSearch))}`);
return tx.$queryRaw<VectorSearchRow[]>(
Prisma.sql`
SELECT
c."id", c."content",
1 - (c."embedding" <=> ${vectorStr}::vector) as similarity,
c."sourceId", s."name" as "sourceName", s."type" as "sourceType", c."metadata"
FROM "KBChunk" c
INNER JOIN "KBSource" s ON c."sourceId" = s."id"
WHERE s."knowledgeBaseId" = ${knowledgeBaseId}
AND s."status" = 'READY'
AND c."embedding" IS NOT NULL
ORDER BY c."embedding" <=> ${vectorStr}::vector
LIMIT ${topK}
`
);
});
const searchDurationMs = performance.now() - searchStart;
recordMetric("kb.search.vector_query_ms", searchDurationMs, "ms", {
knowledgeBaseId,
Expand Down
39 changes: 32 additions & 7 deletions src/lib/memory/__tests__/hot-cold-tier.test.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,34 @@
import { describe, it, expect, vi, beforeEach } from "vitest";

// Mock prisma before importing module
vi.mock("@/lib/prisma", () => ({
prisma: {
agentMemory: {
findMany: vi.fn(),
},
vi.mock("@/lib/prisma", () => {
// Tx mock — receives the same shape we expect Prisma's TransactionClient
// to expose for the SET LOCAL + raw query pair in getColdMemories.
const tx = {
$executeRawUnsafe: vi.fn(),
$queryRawUnsafe: vi.fn(),
},
}));
};
return {
prisma: {
agentMemory: {
findMany: vi.fn(),
},
// Legacy $executeRawUnsafe / $queryRawUnsafe references on the outer
// client are kept so tests that assert on them keep working. The new
// $transaction route is used by getColdMemories under the hot-cold-tier
// patch (Phase 0e) — the tx mock is what callbacks actually receive.
$executeRawUnsafe: tx.$executeRawUnsafe,
$queryRawUnsafe: tx.$queryRawUnsafe,
$transaction: vi.fn(async (fn: (tx: typeof tx_) => Promise<unknown>) =>
fn(tx as never),
),
_tx: tx,
},
};
});

// Helper alias for the tx callback signature above. Each member is a mock
// instance (what vi.fn() returns), not the vi.fn factory itself — hence
// ReturnType<typeof vi.fn> rather than typeof vi.fn.
type tx_ = {
  $executeRawUnsafe: ReturnType<typeof vi.fn>;
  $queryRawUnsafe: ReturnType<typeof vi.fn>;
};

vi.mock("@/lib/logger", () => ({
logger: { info: vi.fn(), warn: vi.fn(), error: vi.fn() },
Expand All @@ -35,8 +54,14 @@ import {
import type { RuntimeContext } from "@/lib/runtime/types";

// Spy handles used by the assertions in this file.
//
// Under the Phase 0e patch, getColdMemories issues SET LOCAL and the raw
// SELECT through the client passed to the prisma.$transaction(...) callback
// rather than through the outer prisma. The mock factory assigns the same
// vi.fn() instances to both clients, so mockQueryRaw / mockExecRaw record
// calls made through either route.
const mockFindMany = prisma.agentMemory.findMany as ReturnType<typeof vi.fn>;
const mockQueryRaw = prisma.$queryRawUnsafe as ReturnType<typeof vi.fn>;
const mockExecRaw = prisma.$executeRawUnsafe as ReturnType<typeof vi.fn>;
const { $transaction: mockTransaction } = prisma as unknown as {
  $transaction: ReturnType<typeof vi.fn>;
};

function makeMemory(overrides: Record<string, unknown> = {}) {
return {
Expand Down
58 changes: 32 additions & 26 deletions src/lib/memory/hot-cold-tier.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,32 +110,38 @@ export async function getColdMemories(

const vectorStr = `[${embedding.join(",")}]`;

await prisma.$executeRawUnsafe("SET LOCAL hnsw.ef_search = 40");

const results = await prisma.$queryRawUnsafe<
Array<{
id: string;
key: string;
value: unknown;
category: string;
importance: number;
accessCount: number;
accessedAt: Date;
createdAt: Date;
similarity: number;
}>
>(
`SELECT id, key, value, category, importance, "accessCount", "accessedAt", "createdAt",
1 - (embedding <=> $1::vector) as similarity
FROM "AgentMemory"
WHERE "agentId" = $2
AND embedding IS NOT NULL
ORDER BY embedding <=> $1::vector
LIMIT $3`,
vectorStr,
agentId,
topK,
);
// $transaction wrapper: SET LOCAL only persists for the lifetime of the
// current transaction. Without $transaction, Prisma's pool may run the
// SET LOCAL on one connection and the SELECT on another — the ef_search
// tuning would silently revert to the server default. Pin both on the
// same connection by wrapping in a single transaction.
const results = await prisma.$transaction(async (tx) => {
await tx.$executeRawUnsafe("SET LOCAL hnsw.ef_search = 40");
return tx.$queryRawUnsafe<
Array<{
id: string;
key: string;
value: unknown;
category: string;
importance: number;
accessCount: number;
accessedAt: Date;
createdAt: Date;
similarity: number;
}>
>(
`SELECT id, key, value, category, importance, "accessCount", "accessedAt", "createdAt",
1 - (embedding <=> $1::vector) as similarity
FROM "AgentMemory"
WHERE "agentId" = $2
AND embedding IS NOT NULL
ORDER BY embedding <=> $1::vector
LIMIT $3`,
vectorStr,
agentId,
topK,
);
});

return results.filter((r) => r.similarity > 0.3);
} catch (error) {
Expand Down
Loading
Loading