|
1 | 1 | import { SessionState, ToolParameterEntry, WithParts } from "../state" |
2 | | -import { countTokens, extractToolContent } from "../strategies/utils" |
3 | | -import { clog, C } from "../compress-logger" |
| 2 | +import { countTokens } from "../strategies/utils" |
4 | 3 | import { isIgnoredUserMessage } from "../messages/utils" |
5 | 4 |
|
6 | 5 | function extractParameterKey(tool: string, parameters: any): string { |
@@ -178,94 +177,6 @@ export function formatSessionMap( |
178 | 177 | return `│${bar.join("")}│` |
179 | 178 | } |
180 | 179 |
|
/** Token totals plus the ordered segments consumed by the compression graph renderer. */
export interface CompressionGraphData {
    /** Tokens attributed to the system prompt. */
    systemTokens: number
    /** Tokens removed by the most recent compression. */
    recentCompressedTokens: number
    /** Tokens removed by earlier compressions. */
    olderCompressedTokens: number
    /** Tokens still present in the context. */
    remainingTokens: number
    /** Denominator used when scaling segments to bar cells. */
    totalSessionTokens: number
    /** Ordered runs to draw, left to right. */
    segments: CompressionGraphSegment[]
}

/** The four categories a stretch of the graph can belong to. */
type CompressionGraphSegmentType = "system" | "recentCompressed" | "olderCompressed" | "inContext"

/** One contiguous run of tokens of a single category. */
export interface CompressionGraphSegment {
    type: CompressionGraphSegmentType
    tokens: number
}

/**
 * Appends `tokens` of category `type` to `segments`, merging into the last
 * segment when it has the same type so adjacent runs never repeat a type.
 *
 * Non-positive and non-finite amounts are ignored: a `NaN` or `Infinity`
 * would otherwise slip past a plain `<= 0` check and corrupt the merged total.
 *
 * @param segments - Segment list to mutate in place.
 * @param type - Category of the tokens being appended.
 * @param tokens - Number of tokens to append.
 */
function appendGraphSegment(
    segments: CompressionGraphSegment[],
    type: CompressionGraphSegmentType,
    tokens: number,
): void {
    if (!Number.isFinite(tokens) || tokens <= 0) {
        return
    }

    const last = segments.at(-1)
    if (last !== undefined && last.type === type) {
        last.tokens += tokens
        return
    }

    segments.push({ type, tokens })
}
214 | | - |
/**
 * Adds `value` to the running total stored under `key`, creating the entry
 * on first use.
 *
 * Non-positive and non-finite values are ignored, so the map never holds
 * zero-valued entries and a stray `NaN` cannot poison an existing total.
 *
 * @param map - Accumulator to mutate in place.
 * @param key - Entry to increment.
 * @param value - Amount to add.
 */
function incrementMapValue(map: Map<string, number>, key: string, value: number): void {
    if (!Number.isFinite(value) || value <= 0) {
        return
    }
    map.set(key, (map.get(key) ?? 0) + value)
}
221 | | - |
/**
 * Counts the tokens of a message's text and tool content, skipping parts
 * flagged `ignored` and tool parts whose call ID is in `state.prune.tools`.
 *
 * @param state - Session state; only `state.prune.tools` is read.
 * @param msg - Message whose parts are inspected. A non-array `parts` is
 *   treated as empty.
 * @returns Token count of the collected strings joined with a single space,
 *   or 0 when nothing was collected.
 */
function countMessageTokensExcludingPrunedTools(state: SessionState, msg: WithParts): number {
    const collected: string[] = []

    if (Array.isArray(msg.parts)) {
        for (const piece of msg.parts) {
            // `ignored` is read through `any`; presumably it is not on the declared part type.
            if ((piece as any).ignored) {
                continue
            }

            if (piece.type === "text") {
                collected.push(piece.text)
            } else if (
                piece.type === "tool" &&
                piece.callID &&
                !state.prune.tools.has(piece.callID)
            ) {
                // Tool parts without a call ID, or already pruned, contribute nothing.
                collected.push(...extractToolContent(piece))
            }
        }
    }

    return collected.length > 0 ? countTokens(collected.join(" ")) : 0
}
252 | | - |
/**
 * Indexes every tool part by call ID so a tool can be traced back to the
 * message that contains it.
 *
 * @param messages - Messages to scan. A non-array `parts` is treated as empty.
 * @returns Map from tool call ID to the ID of its containing message. If a
 *   call ID appears more than once, the last message scanned wins.
 */
function buildToolParentMap(messages: WithParts[]): Map<string, string> {
    const parentByCallId = new Map<string, string>()

    for (const message of messages) {
        if (!Array.isArray(message.parts)) {
            continue
        }
        for (const piece of message.parts) {
            if (piece.type === "tool" && piece.callID) {
                parentByCallId.set(piece.callID, message.info.id)
            }
        }
    }

    return parentByCallId
}
268 | | - |
269 | 180 | export function cacheSystemPromptTokens(state: SessionState, messages: WithParts[]): void { |
270 | 181 | let firstInputTokens = 0 |
271 | 182 | for (const msg of messages) { |
@@ -304,208 +215,6 @@ export function cacheSystemPromptTokens(state: SessionState, messages: WithParts |
304 | 215 | state.systemPromptTokens = estimatedSystemTokens > 0 ? estimatedSystemTokens : undefined |
305 | 216 | } |
306 | 217 |
|
/**
 * Builds the token accounting and ordered segment list for the compression
 * graph.
 *
 * Totals:
 * - compressed tokens = every entry of `state.prune.messages` plus every
 *   "standalone" pruned tool (one whose parent message is not itself pruned);
 * - recent compressed = pruned messages in `newMessageIds` plus standalone
 *   pruned tools in `newToolIds`; older compressed = the rest (floored at 0);
 * - remaining = non-pruned, non-ignored messages (excluding pruned tools)
 *   plus summaries whose anchor message is present in `messages`.
 *
 * Segments follow message order: the system prompt first, then per message
 * its anchored summary, the message itself, and any standalone pruned tools
 * it contains. Pruned tools with no known parent are appended at the end.
 *
 * @param state - Session state; reads `prune`, `compressSummaries` and
 *   `systemPromptTokens`.
 * @param messages - Session messages in display order.
 * @param newMessageIds - Messages pruned by the latest compression.
 * @param newToolIds - Tool call IDs pruned by the latest compression.
 * @returns Totals plus segments; a summary is also written to the compress log.
 */
export function buildCompressionGraphData(
    state: SessionState,
    messages: WithParts[],
    newMessageIds: Set<string>,
    newToolIds: Set<string>,
): CompressionGraphData {
    const toolParentMap = buildToolParentMap(messages)
    const prunedMessageIds = new Set(state.prune.messages.keys())
    const messageIds = new Set(messages.map((m) => m.info.id))

    let compressedMessageTokens = 0
    for (const tokens of state.prune.messages.values()) {
        compressedMessageTokens += tokens
    }

    const recentStandaloneByMessage = new Map<string, number>()
    const olderStandaloneByMessage = new Map<string, number>()

    let unparentedRecentStandaloneTokens = 0
    let unparentedOlderStandaloneTokens = 0
    let compressedStandaloneToolTokens = 0
    let recentStandaloneToolTokens = 0
    for (const [toolId, toolTokens] of state.prune.tools.entries()) {
        const parentMessageId = toolParentMap.get(toolId)
        // Skip tools inside a pruned message; presumably the message's own
        // pruned total already accounts for them.
        if (parentMessageId && prunedMessageIds.has(parentMessageId)) {
            continue
        }

        compressedStandaloneToolTokens += toolTokens

        const isRecent = newToolIds.has(toolId)
        if (isRecent) {
            recentStandaloneToolTokens += toolTokens
        }

        if (parentMessageId) {
            incrementMapValue(
                isRecent ? recentStandaloneByMessage : olderStandaloneByMessage,
                parentMessageId,
                toolTokens,
            )
        } else if (isRecent) {
            unparentedRecentStandaloneTokens += toolTokens
        } else {
            unparentedOlderStandaloneTokens += toolTokens
        }
    }

    const compressedTotalTokens = compressedMessageTokens + compressedStandaloneToolTokens

    let recentMessageTokens = 0
    for (const messageId of newMessageIds) {
        recentMessageTokens += state.prune.messages.get(messageId) || 0
    }

    const recentCompressedTokens = recentMessageTokens + recentStandaloneToolTokens
    const olderCompressedTokens = Math.max(0, compressedTotalTokens - recentCompressedTokens)

    // Only summaries anchored to a message that is actually present are counted.
    const summaryTokensByAnchor = new Map<string, number>()
    let summaryTokensTotal = 0
    for (const summary of state.compressSummaries) {
        if (!messageIds.has(summary.anchorMessageId)) {
            continue
        }

        const tokens = countTokens(summary.summary)
        summaryTokensTotal += tokens
        incrementMapValue(summaryTokensByAnchor, summary.anchorMessageId, tokens)
    }

    const systemTokens = state.systemPromptTokens ?? 0

    const segments: CompressionGraphSegment[] = []
    appendGraphSegment(segments, "system", systemTokens)

    // Single pass: each in-context message is tokenized once and feeds both
    // the remaining-token total and its graph segment.
    let remainingTokens = 0
    for (const msg of messages) {
        const messageId = msg.info.id
        appendGraphSegment(segments, "inContext", summaryTokensByAnchor.get(messageId) || 0)

        if (prunedMessageIds.has(messageId)) {
            appendGraphSegment(
                segments,
                newMessageIds.has(messageId) ? "recentCompressed" : "olderCompressed",
                state.prune.messages.get(messageId) || 0,
            )
        } else if (!(msg.info.role === "user" && isIgnoredUserMessage(msg))) {
            const messageTokens = countMessageTokensExcludingPrunedTools(state, msg)
            remainingTokens += messageTokens
            appendGraphSegment(segments, "inContext", messageTokens)
        }

        appendGraphSegment(
            segments,
            "recentCompressed",
            recentStandaloneByMessage.get(messageId) || 0,
        )
        appendGraphSegment(
            segments,
            "olderCompressed",
            olderStandaloneByMessage.get(messageId) || 0,
        )
    }

    appendGraphSegment(segments, "recentCompressed", unparentedRecentStandaloneTokens)
    appendGraphSegment(segments, "olderCompressed", unparentedOlderStandaloneTokens)

    remainingTokens += summaryTokensTotal

    const totalSessionTokens =
        systemTokens + recentCompressedTokens + olderCompressedTokens + remainingTokens

    clog.info(C.COMPRESS, "Compression graph token accounting", {
        systemTokens,
        recentCompressedTokens,
        olderCompressedTokens,
        remainingTokens,
        totalSessionTokens,
        segments: segments.length,
    })

    return {
        systemTokens,
        recentCompressedTokens,
        olderCompressedTokens,
        remainingTokens,
        totalSessionTokens,
        segments,
    }
}
448 | | - |
/**
 * Splits `width` bar cells across `values` in proportion to each value's
 * share of `total`, using largest-remainder rounding.
 *
 * Guarantees:
 * - the result has one non-negative integer per input value;
 * - the widths never sum to more than `width`, even if the values add up to
 *   more than `total` (the larger of the two is used as the denominator);
 * - a value of zero (or a negative / non-finite value, treated as zero)
 *   always gets zero cells;
 * - when the values add up to less than `total`, the uncovered share is left
 *   unallocated rather than handed to arbitrary segments.
 *
 * @param values - Token amounts, one per segment.
 * @param total - Token amount that corresponds to the full bar.
 * @param width - Number of cells in the bar; fractional widths are floored.
 * @returns Cell count per segment, in input order. All zeros when `total`
 *   or `width` is not a positive number.
 */
function allocateSegmentWidths(values: number[], total: number, width: number): number[] {
    const cells = Number.isFinite(width) ? Math.floor(width) : 0
    if (!(total > 0) || cells <= 0) {
        return new Array<number>(values.length).fill(0)
    }

    const amounts = values.map((v) => (Number.isFinite(v) && v > 0 ? v : 0))
    const amountSum = amounts.reduce((acc, v) => acc + v, 0)
    const denominator = Math.max(total, amountSum)

    const shares = amounts.map((amount, idx) => {
        const exact = (amount / denominator) * cells
        const whole = Math.floor(exact)
        return { idx, exact, whole, frac: exact - whole }
    })

    const widths = shares.map((share) => share.whole)
    const assigned = widths.reduce((acc, v) => acc + v, 0)
    const exactSum = shares.reduce((acc, share) => acc + share.exact, 0)

    // Rounding absorbs float noise; the cap keeps the bar from overflowing.
    let spare = Math.min(cells, Math.round(exactSum)) - assigned

    // Hand out the cells lost to flooring, largest remainder first. The sort
    // is stable, so ties go to the earlier segment.
    const byRemainder = shares
        .filter((share) => share.frac > 0)
        .sort((a, b) => b.frac - a.frac)

    for (const share of byRemainder) {
        if (spare <= 0) {
            break
        }
        widths[share.idx] = share.whole + 1
        spare--
    }

    return widths
}
469 | | - |
/**
 * Renders the compression graph as a single-line bar framed by `│`.
 *
 * Each segment is drawn with the glyph for its type, sized by
 * `allocateSegmentWidths` relative to `data.totalSessionTokens`. When
 * `data.segments` is empty, four coarse segments are used instead, in the
 * order system, recent compressed, older compressed, in context. A bar
 * shorter than `width` is right-padded with spaces.
 *
 * @param data - Totals and segments to draw.
 * @param width - Number of cells between the frame characters (default 50).
 * @returns The framed bar.
 */
export function formatCompressionGraph(data: CompressionGraphData, width: number = 50): string {
    const glyphs: Record<CompressionGraphSegmentType, string> = {
        system: "▌",
        recentCompressed: "⣿",
        olderCompressed: "░",
        inContext: "█",
    }

    let runs: CompressionGraphSegment[] = data.segments
    if (runs.length === 0) {
        runs = [
            { type: "system", tokens: data.systemTokens },
            { type: "recentCompressed", tokens: data.recentCompressedTokens },
            { type: "olderCompressed", tokens: data.olderCompressedTokens },
            { type: "inContext", tokens: data.remainingTokens },
        ]
    }

    const cellCounts = allocateSegmentWidths(
        runs.map((run) => run.tokens),
        data.totalSessionTokens,
        width,
    )

    // Negative counts are clamped to zero before repeating the glyph.
    const bar = runs
        .map((run, i) => glyphs[run.type].repeat(Math.max(0, cellCounts[i])))
        .join("")
        .padEnd(width, " ")

    return `│${bar}│`
}
504 | | - |
/**
 * Returns the one-line legend for the compression graph, pairing each bar
 * glyph with its segment type.
 *
 * @returns The legend text.
 */
export function formatCompressionGraphLegend(): string {
    return "→ Legend: ▌ system | ⣿ recent compress | ░ older compressed | █ in context"
}
508 | | - |
509 | 218 | export function shortenPath(input: string, workingDirectory?: string): string { |
510 | 219 | const inPathMatch = input.match(/^(.+) in (.+)$/) |
511 | 220 | if (inPathMatch) { |
|
0 commit comments