-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcompaction.py
More file actions
240 lines (196 loc) · 8.58 KB
/
compaction.py
File metadata and controls
240 lines (196 loc) · 8.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""Context window management: two-layer compression for long conversations."""
from __future__ import annotations
import providers
# ── Token estimation ──────────────────────────────────────────────────────
def estimate_tokens(messages: list) -> int:
    """Roughly estimate token usage as total character count / 3.5.

    Args:
        messages: list of message dicts with "content" field (str or list of dicts)

    Returns:
        approximate token count, int
    """

    def _str_chars(obj) -> int:
        # Total length of every string value inside a dict; 0 for anything else.
        if isinstance(obj, dict):
            return sum(len(v) for v in obj.values() if isinstance(v, str))
        return 0

    chars = 0
    for msg in messages:
        body = msg.get("content", "")
        if isinstance(body, str):
            chars += len(body)
        elif isinstance(body, list):
            chars += sum(_str_chars(block) for block in body)
        # tool_calls carry text too, so they count toward context size.
        chars += sum(_str_chars(tc) for tc in msg.get("tool_calls", []))
    return int(chars / 3.5)
def get_context_limit(model: str) -> int:
    """Look up the context window size for a model.

    Args:
        model: model string (e.g. "claude-opus-4-6", "ollama/llama3.3")

    Returns:
        context limit in tokens (128000 when the provider is unknown)
    """
    provider_name = providers.detect_provider(model)
    # Fall back to a 128k window when the provider has no declared limit.
    return providers.PROVIDERS.get(provider_name, {}).get("context_limit", 128000)
# ── Layer 1: Snip old tool results ────────────────────────────────────────
def snip_old_tool_results(
    messages: list,
    max_chars: int = 2000,
    preserve_last_n_turns: int = 6,
) -> list:
    """Truncate tool-role messages older than preserve_last_n_turns from end.

    For old tool messages whose content exceeds max_chars, keep the first half
    and last quarter, inserting '[... N chars snipped ...]' in between.
    Mutates in place and returns the same list.

    Args:
        messages: list of message dicts (mutated in place)
        max_chars: maximum character length before truncation
        preserve_last_n_turns: number of messages from end to preserve

    Returns:
        the same messages list (mutated)
    """
    cutoff = max(0, len(messages) - preserve_last_n_turns)
    head_len = max_chars // 2
    tail_len = max_chars // 4
    for m in messages[:cutoff]:
        if m.get("role") != "tool":
            continue
        content = m.get("content", "")
        if not isinstance(content, str) or len(content) <= max_chars:
            continue
        first_half = content[:head_len]
        # Bug fix: when max_chars < 4, tail_len is 0 and content[-0:] would
        # return the WHOLE string, making the "snipped" message longer than
        # the original. Use an empty tail in that degenerate case.
        last_quarter = content[-tail_len:] if tail_len > 0 else ""
        snipped = len(content) - len(first_half) - len(last_quarter)
        m["content"] = f"{first_half}\n[... {snipped} chars snipped ...]\n{last_quarter}"
    return messages
# ── Layer 2: Auto-compact ─────────────────────────────────────────────────
def find_split_point(messages: list, keep_ratio: float = 0.3) -> int:
    """Find index that splits messages so ~keep_ratio of tokens are in the recent portion.

    Walks backwards from end, accumulating token estimates, and returns the
    index where the recent portion reaches ~keep_ratio of total tokens.

    Args:
        messages: list of message dicts
        keep_ratio: fraction of tokens to keep in the recent portion

    Returns:
        split index (messages[:idx] = old, messages[idx:] = recent)
    """
    target = int(estimate_tokens(messages) * keep_ratio)
    accumulated = 0
    idx = len(messages)
    # Walk from the newest message toward the oldest until the recent
    # side holds roughly the target share of tokens.
    while idx > 0:
        idx -= 1
        accumulated += estimate_tokens([messages[idx]])
        if accumulated >= target:
            return idx
    return 0
def compact_messages(messages: list, config: dict, focus: str = "") -> list:
    """Compress old messages into a summary via LLM call.

    Splits at find_split_point, summarizes old portion, returns
    [summary_msg, ack_msg, *recent_messages].

    Args:
        messages: full message list
        config: agent config dict (must contain "model")
        focus: optional focus instructions for the summarizer

    Returns:
        new compacted message list
    """
    split = find_split_point(messages)
    if split <= 0:
        # Nothing old enough to fold away — return the list unchanged.
        return messages
    old, recent = messages[:split], messages[split:]

    # Render the old portion as plain text for the summarizer.
    rendered = []
    for m in old:
        role = m.get("role", "?")
        content = m.get("content", "")
        if isinstance(content, str):
            rendered.append(f"[{role}]: {content[:500]}\n")
        elif isinstance(content, list):
            rendered.append(f"[{role}]: (structured content)\n")
    old_text = "".join(rendered)

    summary_prompt = (
        "Summarize the following conversation history concisely. "
        "Preserve key decisions, file paths, tool results, and context "
        "needed to continue the conversation."
    )
    if focus:
        summary_prompt += f"\n\nFocus especially on: {focus}"
    summary_prompt += "\n\n" + old_text

    # Stream the summary from the LLM, collecting text chunks.
    chunks = []
    for event in providers.stream(
        model=config["model"],
        system="You are a concise summarizer.",
        messages=[{"role": "user", "content": summary_prompt}],
        tool_schemas=[],
        config=config,
    ):
        if isinstance(event, providers.TextChunk):
            chunks.append(event.text)
    summary_text = "".join(chunks)

    summary_msg = {
        "role": "user",
        "content": f"[Previous conversation summary]\n{summary_text}",
    }
    ack_msg = {
        "role": "assistant",
        "content": "Understood. I have the context from the previous conversation. Let's continue.",
    }
    return [summary_msg, ack_msg, *recent]
# ── Main entry ────────────────────────────────────────────────────────────
def maybe_compact(state, config: dict) -> bool:
    """Check if context window is getting full and compress if needed.

    Runs snip_old_tool_results first, then auto-compact if still over threshold.

    Args:
        state: AgentState with .messages list
        config: agent config dict (must contain "model")

    Returns:
        True if compaction was performed
    """
    # Compact once usage crosses 70% of the model's context window.
    threshold = get_context_limit(config.get("model", "")) * 0.7
    if estimate_tokens(state.messages) <= threshold:
        return False
    # Layer 1: cheap — snip old tool output in place.
    snip_old_tool_results(state.messages)
    if estimate_tokens(state.messages) > threshold:
        # Layer 2: expensive — summarize the old portion via the LLM.
        state.messages = compact_messages(state.messages, config)
        state.messages.extend(_restore_plan_context(config))
    return True
# ── Plan context restoration ─────────────────────────────────────────────
def _restore_plan_context(config: dict) -> list:
"""If in plan mode, return messages that restore plan file context."""
from pathlib import Path
plan_file = config.get("_plan_file", "")
if not plan_file or config.get("permission_mode") != "plan":
return []
p = Path(plan_file)
if not p.exists():
return []
content = p.read_text(encoding="utf-8").strip()
if not content:
return []
return [
{"role": "user", "content": f"[Plan file restored after compaction: {plan_file}]\n\n{content}"},
{"role": "assistant", "content": "I have the plan context. Let's continue."},
]
# ── Manual compact ───────────────────────────────────────────────────────
def manual_compact(state, config: dict, focus: str = "") -> tuple[bool, str]:
    """User-triggered compaction via /compact. Not gated by threshold.

    Returns (success, info_message).
    """
    if len(state.messages) < 4:
        # Too little history for a summary to be worthwhile.
        return False, "Not enough messages to compact."
    before = estimate_tokens(state.messages)
    snip_old_tool_results(state.messages)
    state.messages = compact_messages(state.messages, config, focus=focus)
    state.messages.extend(_restore_plan_context(config))
    after = estimate_tokens(state.messages)
    saved = before - after
    return True, f"Compacted: ~{before} → ~{after} tokens (~{saved} saved)"