Skip to content

Commit 9fbd221

Browse files
authored
feat(search): iterative deep research
* feat(search): iterative deep research * Harden deep research search loop * fixup
1 parent db542ba commit 9fbd221

2 files changed

Lines changed: 201 additions & 15 deletions

File tree

src/lib/formatting.ts

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ export function formatConversationAsXml(messages: Message[]): string {
2828
}
2929

3030
/** Escape special characters for XML */
31-
function escapeXml(str: string): string {
31+
export function escapeXml(str: string): string {
3232
return str
3333
.replace(/&/g, "&amp;")
3434
.replace(/</g, "&lt;")
@@ -58,8 +58,8 @@ export function formatNodesForPrompt(
5858
.map(
5959
(node) =>
6060
`<node id="${escapeXml(node.tempId)}" type="${escapeXml(node.type)}" timestamp="${node.timestamp}">
61-
<label>${node.label ?? ""}</label>
62-
<description>${node.description || ""}</description>
61+
<label>${escapeXml(node.label ?? "")}</label>
62+
<description>${escapeXml(node.description || "")}</description>
6363
</node>`,
6464
)
6565
.join("\n");
@@ -81,8 +81,8 @@ export function formatLabelDescList(
8181

8282
const xmlItems = items
8383
.map(
84-
(item) => `<item label="${escapeXml(item.label ?? "Unnamed")}"
85-
>${item.description ?? ""}</item>`,
84+
(item) =>
85+
`<item label="${escapeXml(item.label ?? "Unnamed")}">${escapeXml(item.description ?? "")}</item>`,
8686
)
8787
.join("\n");
8888
return `<items>
@@ -105,8 +105,8 @@ export type SearchResults = RerankResult<SearchGroups>;
105105
// Helpers for formatting individual result items
106106
function formatSearchNode(node: NodeSearchResult): string {
107107
return `<node type="${escapeXml(node.type)}" timestamp="${formatISO(node.timestamp)}">
108-
<label>${node.label ?? ""}</label>
109-
<description>${node.description ?? ""}</description>
108+
<label>${escapeXml(node.label ?? "")}</label>
109+
<description>${escapeXml(node.description ?? "")}</description>
110110
</node>`;
111111
}
112112

@@ -122,10 +122,14 @@ function formatSearchConnection(conn: OneHopNode): string {
122122
return `<edge from="${escapeXml(conn.sourceLabel ?? "")}" to="${escapeXml(
123123
conn.targetLabel ?? "",
124124
)}" type="${escapeXml(conn.edgeType)}" timestamp="${formatISO(conn.timestamp)}">
125-
<description>${conn.description ?? ""}</description>
125+
<description>${escapeXml(conn.description ?? "")}</description>
126126
</edge>`;
127127
}
128128

129+
/**
 * Exhaustiveness guard for `switch` statements over a union.
 *
 * The `never`-typed parameter makes the compiler reject the call whenever a
 * union member is left unhandled; at runtime the function always throws.
 *
 * @param value - The value that should be impossible; only used for type checking.
 * @param message - Text of the error that is thrown.
 * @throws Error always, carrying `message`.
 */
function assertNever(value: never, message: string): never {
  const unreachable = new Error(message);
  throw unreachable;
}
132+
129133
/**
130134
* Formats reranked search results as an XML-like structure for LLM prompts.
131135
* Items are ordered by descending relevance and tagged by their group.
@@ -141,9 +145,51 @@ export function formatSearchResultsAsXml(results: SearchResults): string {
141145
return formatSearchEdge(r.item);
142146
case "connections":
143147
return formatSearchConnection(r.item);
148+
default:
149+
return assertNever(
150+
r.group,
151+
`[formatSearchResultsAsXml] Unhandled search result group: ${String(
152+
r.group,
153+
)}`,
154+
);
144155
}
145156
})
146157
.join("\n")
147158
: "";
148159
return body;
149160
}
161+
162+
/**
 * A single entry of `SearchResults` extended with a `tempId` string, so an
 * individual result can be referred to by that id.
 */
export type SearchResultWithId = SearchResults[number] & { tempId: string };
163+
164+
/**
 * Render search results as XML-like text in which every result is wrapped in
 * a `<result id="...">` element carrying its temporary id, so the LLM can
 * refer to individual results.
 *
 * @param results - Results to render, each tagged with a `tempId`.
 * @returns One `<result>` element per entry, joined by newlines; an empty
 *   string when `results` is empty.
 * @throws Error if an entry belongs to a group this function does not handle.
 */
export function formatSearchResultsWithIds(
  results: SearchResultWithId[],
): string {
  // Picks the formatter that matches the entry's group.
  const renderInner = (entry: SearchResultWithId): string => {
    switch (entry.group) {
      case "similarNodes":
        return formatSearchNode(entry.item);
      case "similarEdges":
        return formatSearchEdge(entry.item);
      case "connections":
        return formatSearchConnection(entry.item);
      default:
        return assertNever(
          entry.group,
          `[formatSearchResultsWithIds] Unhandled search result group: ${String(
            entry.group,
          )}`,
        );
    }
  };

  const wrapped: string[] = [];
  for (const entry of results) {
    const inner = renderInner(entry);
    wrapped.push(`<result id="${escapeXml(entry.tempId)}">${inner}</result>`);
  }
  // Joining an empty list yields "", matching the empty-input case.
  return wrapped.join("\n");
}

src/lib/jobs/deep-research.ts

Lines changed: 147 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import { performStructuredAnalysis } from "../ai";
22
import { storeDeepResearchResult } from "../cache/deep-research-cache";
33
import { generateEmbeddings } from "../embeddings";
4+
import {
5+
escapeXml,
6+
formatSearchResultsWithIds,
7+
type SearchResultWithId,
8+
type SearchResults,
9+
} from "../formatting";
410
import {
511
findOneHopNodes,
612
findSimilarEdges,
@@ -14,6 +20,7 @@ import {
1420
DeepResearchJobInput,
1521
DeepResearchResult,
1622
} from "../schemas/deep-research";
23+
import { TemporaryIdMapper } from "../temporary-id-mapper";
1724
import { z } from "zod";
1825
import { DrizzleDB } from "~/db";
1926
import { useDatabase } from "~/utils/db";
@@ -28,6 +35,8 @@ type SearchGroups = {
2835

2936
// Default TTL for deep research results (24 hours)
3037
const DEFAULT_TTL_SECONDS = 24 * 60 * 60;
38+
// Maximum number of search iterations per job (one query each); LLM refinement runs between iterations
39+
const MAX_SEARCH_LOOPS = 4;
3140

3241
/**
3342
* Main job handler for deep research
@@ -42,8 +51,7 @@ export async function performDeepResearch(
4251
console.log(`Starting deep research for conversation ${conversationId}`);
4352

4453
try {
45-
// Get search queries based on recent conversation turns
46-
// Filter to only include user and assistant messages
54+
// Prepare initial queries based on recent conversation turns
4755
const recentMessages = messages
4856
.slice(-lastNMessages)
4957
.filter((m) => m.role === "user" || m.role === "assistant");
@@ -54,11 +62,16 @@ export async function performDeepResearch(
5462
return;
5563
}
5664

57-
// Execute search queries and aggregate results
58-
const searchResults = await executeDeepSearchQueries(db, userId, queries);
65+
// Run iterative search/refine loop
66+
const searchResults = await runIterativeSearch(
67+
db,
68+
userId,
69+
recentMessages,
70+
queries,
71+
);
5972

60-
// Process results and cache them
61-
await cacheDeepResearchResults(userId, conversationId, searchResults);
73+
// Cache the combined results
74+
await cacheDeepResearchResults(userId, conversationId, [searchResults]);
6275

6376
console.log(`Deep research completed for conversation ${conversationId}`);
6477
} catch (error) {
@@ -82,7 +95,7 @@ async function generateSearchQueries(
8295

8396
// Format messages for context
8497
const messageContext = messages
85-
.map((m) => `<message role="${m.role}">${m.content}</message>`)
98+
.map((m) => `<message role="${m.role}">${escapeXml(m.content)}</message>`)
8699
.join("\n");
87100

88101
// Use structured analysis to generate tangential search queries
@@ -110,6 +123,133 @@ Come up with 1-5 search queries that explore adjacent or less obvious connection
110123
}
111124
}
112125

126+
/**
 * Run an iterative search loop with LLM refinement between iterations.
 *
 * Each iteration takes the next query from a FIFO queue seeded with
 * `initialQueries`, embeds it, runs the search and appends every hit that has
 * not been returned by an earlier query. After each iteration except the last
 * permitted one, the LLM is asked which accumulated results to drop, whether
 * to stop, and optionally for a follow-up query. At most `MAX_SEARCH_LOOPS`
 * queries are executed.
 *
 * @param db - Database handle passed through to the search.
 * @param userId - User on whose behalf the search and refinement run.
 * @param messages - Conversation messages supplied to the refinement prompt.
 * @param initialQueries - Queries to start from; the array is not mutated.
 * @returns The accumulated results with their temporary ids stripped, in the
 *   order they were found and without the entries the LLM asked to drop.
 */
async function runIterativeSearch(
  db: DrizzleDB,
  userId: string,
  messages: DeepResearchJobInput["messages"],
  initialQueries: string[],
): Promise<RerankResult<SearchGroups>> {
  const queue = [...initialQueries];
  const history: string[] = [];
  let results: SearchResultWithId[] = [];
  let tempIdCounter = 0;
  // The generator yields "r1", "r2", ...; presumably the mapper uses it to
  // assign each item its tempId — confirm against TemporaryIdMapper.
  const mapper = new TemporaryIdMapper<SearchResults[number], string>(
    () => `r${++tempIdCounter}`,
  );
  // Keys of every hit ever returned, including ones dropped later, so a
  // dropped result is not re-added by a subsequent query.
  const seen = new Set<string>();
  // Normalised form used only to recognise repeated follow-up queries.
  const normalizeQuery = (q: string): string => q.trim().toLowerCase();
  let loops = 0;

  while (loops < MAX_SEARCH_LOOPS) {
    const query = queue.shift();
    if (query === undefined) break; // nothing left to search for
    history.push(query);

    const embResp = await generateEmbeddings({
      model: "jina-embeddings-v3",
      task: "retrieval.query",
      input: [query],
      truncate: true,
    });
    const embedding = embResp.data[0]?.embedding;
    if (embedding) {
      const res = await executeSearchWithEmbedding(
        db,
        userId,
        query,
        embedding,
        20,
      );
      if (res) {
        const dedup = res.filter((r) => {
          const key = `${r.group}:${r.item.id}`;
          if (seen.has(key)) return false;
          seen.add(key);
          return true;
        });
        results.push(...mapper.mapItems(dedup));
      }
    }

    loops++;
    // Skip the refinement call when no further search could follow it.
    if (loops >= MAX_SEARCH_LOOPS) break;

    const refinement = await refineSearchResults(
      userId,
      messages,
      history,
      results,
    );
    if (refinement.dropIds.length) {
      const drop = new Set(refinement.dropIds);
      results = results.filter((r) => !drop.has(r.tempId));
    }
    if (refinement.done) break;

    // A blank or repeated follow-up would spend one of the few iterations
    // (an embedding call plus an LLM call) without producing new results,
    // because all of its hits are already in `seen`.
    const nextQuery = refinement.nextQuery?.trim();
    if (nextQuery) {
      const wanted = normalizeQuery(nextQuery);
      const alreadyPlanned =
        history.some((q) => normalizeQuery(q) === wanted) ||
        queue.some((q) => normalizeQuery(q) === wanted);
      if (!alreadyPlanned) queue.push(nextQuery);
    }
  }

  return results.map(({ tempId, ...rest }) => rest);
}
194+
195+
/** Outcome of one LLM refinement pass over the accumulated search results. */
interface RefinementResult {
  /** True when no further searching is requested. */
  done: boolean;
  /** Temporary ids of the results that should be removed. */
  dropIds: string[];
  /** Follow-up query to run next, if one was proposed. */
  nextQuery?: string;
}
200+
201+
/**
 * Ask the LLM to review the accumulated search results.
 *
 * The prompt contains the conversation, the queries run so far and the
 * results tagged with their temporary ids. The structured reply states which
 * ids to drop, whether searching is finished, and optionally a next query.
 *
 * @param userId - User on whose behalf the analysis runs.
 * @param messages - Conversation messages included in the prompt.
 * @param queries - Queries that have been run so far.
 * @param results - Accumulated results, each carrying a `tempId`.
 * @returns The LLM's refinement; if the analysis call fails, the error is
 *   logged and `{ dropIds: [], done: true }` is returned instead.
 */
async function refineSearchResults(
  userId: string,
  messages: DeepResearchJobInput["messages"],
  queries: string[],
  results: SearchResultWithId[],
): Promise<RefinementResult> {
  const refinementSchema = z
    .object({
      dropIds: z.array(z.string()).default([]),
      done: z.boolean(),
      nextQuery: z.string().optional(),
    })
    .describe("DeepResearchRefinement");

  // Prompt sections are built before the try block, so a formatting error
  // propagates to the caller rather than being reported as an LLM failure.
  const conversationXml = messages
    .map(
      (msg) =>
        `<message role="${msg.role}">${escapeXml(msg.content)}</message>`,
    )
    .join("\n");
  const queryLines = queries.map(
    (text) => `<query>${escapeXml(text)}</query>`,
  );
  const taggedResults = formatSearchResultsWithIds(results);

  const prompt = [
    "<conversation>",
    conversationXml,
    "</conversation>",
    "",
    "<queries>",
    queryLines.join("\n"),
    "</queries>",
    "",
    "<results>",
    taggedResults,
    "</results>",
    "",
    "<system:instruction>",
    "Remove irrelevant results by listing their ids in dropIds. If more searching is needed, set done=false and provide nextQuery. If satisfied, set done=true.",
    "</system:instruction>",
  ].join("\n");

  try {
    const refinement = await performStructuredAnalysis({
      userId,
      systemPrompt: "You refine background search results.",
      prompt,
      schema: refinementSchema,
    });
    return refinement;
  } catch (error) {
    console.error("Failed to refine deep search results:", error);
    // Fall back to "finished, drop nothing" when the LLM call fails.
    return { dropIds: [], done: true };
  }
}
252+
113253
/**
114254
* Execute multiple search queries in parallel with higher limits
115255
* and return combined results

0 commit comments

Comments
 (0)