From 131ba355fe7fdd0564e667626808bb351850f0af Mon Sep 17 00:00:00 2001 From: Yixin Luo <18810541851@163.com> Date: Sun, 18 Jan 2026 22:35:16 +0800 Subject: [PATCH 1/2] fix : remove summary Signed-off-by: Yixin Luo <18810541851@163.com> --- .env.example | 4 +- README.md | 4 - config.py | 7 +- docs/advanced-features.md | 1 - docs/configuration.md | 2 - docs/examples.md | 1 - docs/memory-management.md | 32 +------ memory/compressor.py | 53 +++--------- memory/manager.py | 129 +++++++++-------------------- memory/short_term.py | 17 ++++ memory/token_tracker.py | 20 ----- test/memory/test_compressor.py | 37 +++++---- test/memory/test_integration.py | 17 ++-- test/memory/test_memory_manager.py | 76 +++++++---------- test/test_memory.py | 3 - 15 files changed, 138 insertions(+), 265 deletions(-) diff --git a/.env.example b/.env.example index 3351391..b0aaa7b 100644 --- a/.env.example +++ b/.env.example @@ -43,9 +43,7 @@ RETRY_MAX_DELAY=60.0 # Maximum delay in seconds # Memory Management Configuration MEMORY_ENABLED=true # Enable/disable memory compression -MEMORY_MAX_CONTEXT_TOKENS=100000 # Maximum context window size -MEMORY_TARGET_TOKENS=30000 # Target working memory size (soft limit) -MEMORY_COMPRESSION_THRESHOLD=25000 # Hard limit - compress when exceeded +MEMORY_COMPRESSION_THRESHOLD=25000 # Token threshold - compress when exceeded MEMORY_SHORT_TERM_SIZE=100 # Number of recent messages to keep MEMORY_COMPRESSION_RATIO=0.3 # Target compression ratio (0.3 = 30% of original) diff --git a/README.md b/README.md index 2152452..19d1993 100644 --- a/README.md +++ b/README.md @@ -87,8 +87,6 @@ MAX_ITERATIONS=100 # Maximum iteration loops # Memory Management MEMORY_ENABLED=true -MEMORY_MAX_CONTEXT_TOKENS=100000 -MEMORY_TARGET_TOKENS=30000 MEMORY_COMPRESSION_THRESHOLD=25000 MEMORY_SHORT_TERM_SIZE=100 MEMORY_COMPRESSION_RATIO=0.3 @@ -249,8 +247,6 @@ See the full configuration template in `.env.example`. Key options: | `LITELLM_DROP_PARAMS` | Drop unsupported params | `true` | | `LITELLM_TIMEOUT` | Request timeout in seconds | `600` | | `MAX_ITERATIONS` | Maximum agent iterations | `100` | -| `MEMORY_MAX_CONTEXT_TOKENS` | Maximum context window | `100000` | -| `MEMORY_TARGET_TOKENS` | Target working memory size | `30000` | | `MEMORY_COMPRESSION_THRESHOLD` | Compress when exceeded | `25000` | | `MEMORY_SHORT_TERM_SIZE` | Recent messages to keep | `100` | | `RETRY_MAX_ATTEMPTS` | Retry attempts for rate limits | `3` | diff --git a/config.py b/config.py index 6e96309..17a366d 100644 --- a/config.py +++ b/config.py @@ -40,13 +40,10 @@ class Config: # Memory Management Configuration MEMORY_ENABLED = os.getenv("MEMORY_ENABLED", "true").lower() == "true" - MEMORY_MAX_CONTEXT_TOKENS = int(os.getenv("MEMORY_MAX_CONTEXT_TOKENS", "100000")) - MEMORY_TARGET_TOKENS = int(os.getenv("MEMORY_TARGET_TOKENS", "50000")) - MEMORY_COMPRESSION_THRESHOLD = int(os.getenv("MEMORY_COMPRESSION_THRESHOLD", "40000")) + MEMORY_COMPRESSION_THRESHOLD = int(os.getenv("MEMORY_COMPRESSION_THRESHOLD", "60000")) MEMORY_SHORT_TERM_SIZE = int(os.getenv("MEMORY_SHORT_TERM_SIZE", "100")) - MEMORY_SHORT_TERM_MIN_SIZE = int(os.getenv("MEMORY_SHORT_TERM_MIN_SIZE", "5")) + MEMORY_SHORT_TERM_MIN_SIZE = int(os.getenv("MEMORY_SHORT_TERM_MIN_SIZE", "6")) MEMORY_COMPRESSION_RATIO = float(os.getenv("MEMORY_COMPRESSION_RATIO", "0.3")) - MEMORY_PRESERVE_TOOL_CALLS = True MEMORY_PRESERVE_SYSTEM_PROMPTS = True # Tool Result Processing Configuration diff --git a/docs/advanced-features.md b/docs/advanced-features.md index 11e1e4c..ab7f6cb 100644 --- a/docs/advanced-features.md +++ b/docs/advanced-features.md @@ -82,7 +82,6 @@ Enable in `.env`: ```bash MEMORY_ENABLED=true -MEMORY_MAX_CONTEXT_TOKENS=100000 MEMORY_COMPRESSION_THRESHOLD=40000 ``` diff --git a/docs/configuration.md b/docs/configuration.md index cf09236..a191d6a 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -70,8 +70,6 @@ MAX_ITERATIONS=100 ```bash MEMORY_ENABLED=true -MEMORY_MAX_CONTEXT_TOKENS=100000 -MEMORY_TARGET_TOKENS=30000 MEMORY_COMPRESSION_THRESHOLD=25000 MEMORY_SHORT_TERM_SIZE=100 MEMORY_COMPRESSION_RATIO=0.3 diff --git a/docs/examples.md b/docs/examples.md index df35290..9f21ced 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -198,7 +198,6 @@ For long-running tasks with many iterations: ```bash # Enable memory management in .env: MEMORY_ENABLED=true -MEMORY_MAX_CONTEXT_TOKENS=100000 # Run a complex task: python main.py --mode react --task "Analyze all Python files, find patterns, and generate a detailed report" diff --git a/docs/memory-management.md b/docs/memory-management.md index 8b4f1a4..ffbf35c 100644 --- a/docs/memory-management.md +++ b/docs/memory-management.md @@ -36,14 +36,8 @@ In your `.env` file: # Enable memory management MEMORY_ENABLED=true -# Maximum total context size -MEMORY_MAX_CONTEXT_TOKENS=100000 - # Trigger compression at this threshold MEMORY_COMPRESSION_THRESHOLD=40000 - -# Target size after compression -MEMORY_TARGET_TOKENS=50000 ``` ### 2. Run Your Agent @@ -150,14 +144,8 @@ cost = tracker.calculate_cost("claude-3-5-sonnet-20241022") # Enable/disable memory (default: true) MEMORY_ENABLED=true -# Maximum total context tokens (default: 100000) -MEMORY_MAX_CONTEXT_TOKENS=100000 - # Start compression when context exceeds this (default: 40000) MEMORY_COMPRESSION_THRESHOLD=40000 - -# Target size after compression (default: 50000) -MEMORY_TARGET_TOKENS=50000 ``` ### Advanced Settings @@ -176,8 +164,6 @@ MEMORY_COMPRESSION_STRATEGY=sliding_window # Preserve system prompts (default: true) MEMORY_PRESERVE_SYSTEM_PROMPTS=true -# Preserve tool calls and results (default: true) -MEMORY_PRESERVE_TOOL_CALLS=true ``` ### Memory Presets @@ -420,22 +406,6 @@ config = MemoryConfig( agent = ReActAgent(llm=llm, tools=tools, memory_config=config) ``` -### Example 4: Monitor Budget - -```python -memory = MemoryManager(config, llm) - -# ... use memory ... - -# Check budget status -budget = memory.token_tracker.get_budget_status(max_tokens=50000) - -if budget['over_budget']: - print(f"⚠️ Over budget by {budget['total_tokens'] - budget['max_tokens']} tokens") -else: - print(f"✅ {budget['remaining']} tokens remaining ({budget['percentage']:.1f}% used)") -``` - ## How Compression Works ### Step-by-Step Process @@ -526,7 +496,7 @@ MEMORY_COMPRESSION_THRESHOLD=40000 1. Use `selective` strategy instead of `sliding_window` 2. Increase `MEMORY_SHORT_TERM_SIZE` to preserve more recent messages 3. Increase `MEMORY_COMPRESSION_RATIO` to keep more content -4. Set `MEMORY_PRESERVE_TOOL_CALLS=true` + ### Issue: High compression cost diff --git a/memory/compressor.py b/memory/compressor.py index 48a8ca0..a602620 100644 --- a/memory/compressor.py +++ b/memory/compressor.py @@ -1,7 +1,7 @@ """Memory compression using LLM-based summarization.""" import logging -from typing import TYPE_CHECKING, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from config import Config from llm.base import LLMMessage @@ -48,7 +48,6 @@ def compress( messages: List[LLMMessage], strategy: str = CompressionStrategy.SLIDING_WINDOW, target_tokens: Optional[int] = None, - orphaned_tool_use_ids: Optional[Set[str]] = None, ) -> CompressedMemory: """Compress messages using specified strategy. @@ -56,8 +55,6 @@ def compress( messages: List of messages to compress strategy: Compression strategy to use target_tokens: Target token count for compressed output - orphaned_tool_use_ids: Set of tool_use IDs from previous summaries that are - waiting for tool_result in current messages Returns: CompressedMemory object @@ -70,14 +67,11 @@ def compress( original_tokens = self._estimate_tokens(messages) target_tokens = int(original_tokens * Config.MEMORY_COMPRESSION_RATIO) - if orphaned_tool_use_ids is None: - orphaned_tool_use_ids = set() - # Select and apply compression strategy if strategy == CompressionStrategy.SLIDING_WINDOW: return self._compress_sliding_window(messages, target_tokens) elif strategy == CompressionStrategy.SELECTIVE: - return self._compress_selective(messages, target_tokens, orphaned_tool_use_ids) + return self._compress_selective(messages, target_tokens) elif strategy == CompressionStrategy.DELETION: return self._compress_deletion(messages) else: @@ -147,7 +141,7 @@ def _compress_sliding_window( ) def _compress_selective( - self, messages: List[LLMMessage], target_tokens: int, orphaned_tool_use_ids: set = None + self, messages: List[LLMMessage], target_tokens: int ) -> CompressedMemory: """Compress using selective preservation strategy. @@ -157,16 +151,12 @@ def _compress_selective( Args: messages: Messages to compress target_tokens: Target token count - orphaned_tool_use_ids: Set of tool_use IDs from previous summaries Returns: CompressedMemory object """ - if orphaned_tool_use_ids is None: - orphaned_tool_use_ids = set() - # Separate preserved vs compressible messages - preserved, to_compress = self._separate_messages(messages, orphaned_tool_use_ids) + preserved, to_compress = self._separate_messages(messages) if not to_compress: # Nothing to compress @@ -255,29 +245,25 @@ def _compress_deletion(self, messages: List[LLMMessage]) -> CompressedMemory: ) def _separate_messages( - self, messages: List[LLMMessage], orphaned_tool_use_ids_from_summaries: set = None + self, messages: List[LLMMessage] ) -> Tuple[List[LLMMessage], List[LLMMessage]]: """Separate messages into preserved and compressible. Strategy: 1. Preserve system messages (if configured) - 2. Preserve protected tools (todo list, etc.) - NEVER compress these - 3. Use selective strategy for other messages (system decides based on recency, importance) - 4. **Critical rule**: Tool pairs (tool_use + tool_result) must stay together + 2. Preserve orphaned tool_use (waiting for tool_result) + 3. Preserve protected tools (todo list, etc.) - NEVER compress these + 4. Preserve the most recent N messages (MEMORY_SHORT_TERM_MIN_SIZE) + 5. **Critical rule**: Tool pairs (tool_use + tool_result) must stay together - If one is preserved, the other must be preserved too - If one is compressed, the other must be compressed too - 5. **Critical fix**: Preserve tool_result that match orphaned tool_use from previous summaries Args: messages: All messages - orphaned_tool_use_ids_from_summaries: Tool_use IDs from previous summaries waiting for results Returns: Tuple of (preserved, to_compress) """ - if orphaned_tool_use_ids_from_summaries is None: - orphaned_tool_use_ids_from_summaries = set() - preserve_indices = set() # Step 1: Mark system messages for preservation @@ -293,27 +279,13 @@ def _separate_messages( for orphan_idx in orphaned_tool_use_indices: preserve_indices.add(orphan_idx) - # Step 2b: CRITICAL FIX - Preserve tool_result that match orphaned tool_use from previous summaries - # These results finally arrived and must be preserved to match their tool_use - for i, msg in enumerate(messages): - if msg.role == "user" and isinstance(msg.content, list): - for block in msg.content: - if isinstance(block, dict) and block.get("type") == "tool_result": - tool_use_id = block.get("tool_use_id") - if tool_use_id in orphaned_tool_use_ids_from_summaries: - preserve_indices.add(i) - logger.info( - f"Preserving tool_result for orphaned tool_use '{tool_use_id}' from previous summary" - ) - - # Step 2c: Mark protected tools for preservation (CRITICAL for stateful tools) + # Step 2b: Mark protected tools for preservation (CRITICAL for stateful tools) protected_pairs = self._find_protected_tool_pairs(messages, tool_pairs) for assistant_idx, user_idx in protected_pairs: preserve_indices.add(assistant_idx) preserve_indices.add(user_idx) - # Step 3: Apply selective preservation strategy (keep recent N messages) - # Preserve last short_term_min_message_count messages by default (sliding window approach) + # Step 3: Preserve the most recent N messages to maintain conversation continuity preserve_count = min(Config.MEMORY_SHORT_TERM_MIN_SIZE, len(messages)) for i in range(len(messages) - preserve_count, len(messages)): if i >= 0: @@ -339,7 +311,8 @@ def _separate_messages( logger.info( f"Separated: {len(preserved)} preserved, {len(to_compress)} to compress " f"({len(tool_pairs)} tool pairs, {len(protected_pairs)} protected, " - f"{len(orphaned_tool_use_indices)} orphaned tool_use)" + f"{len(orphaned_tool_use_indices)} orphaned tool_use, " + f"{preserve_count} recent)" ) return preserved, to_compress diff --git a/memory/manager.py b/memory/manager.py index c419460..9af3539 100644 --- a/memory/manager.py +++ b/memory/manager.py @@ -71,8 +71,7 @@ def __init__( f"Tool result processing enabled with external storage: {storage_path or 'in-memory'}" ) - # Storage for compressed memories and system messages - self.summaries: List[CompressedMemory] = [] + # Storage for system messages (summaries are now stored as regular messages in short_term) self.system_messages: List[LLMMessage] = [] # State tracking @@ -81,6 +80,9 @@ def __init__( self.last_compression_savings = 0 self.compression_count = 0 + # Summary message prefix for identification + self.SUMMARY_PREFIX = "[Conversation Summary]\n" + @classmethod def from_session( cls, @@ -114,10 +116,9 @@ def from_session( # Restore state manager.system_messages = session_data["system_messages"] - manager.summaries = session_data["summaries"] manager.compression_count = session_data["stats"]["compression_count"] - # Add messages to short-term memory + # Add messages to short-term memory (including any summary messages) for msg in session_data["messages"]: manager.short_term.add_message(msg) @@ -127,7 +128,6 @@ def from_session( logger.info( f"Loaded session {session_id}: " f"{len(session_data['messages'])} messages, " - f"{len(session_data['summaries'])} summaries, " f"{manager.current_tokens} tokens" ) @@ -204,7 +204,6 @@ def add_message(self, message: LLMMessage, actual_tokens: Dict[str, int] = None) logger.debug( f"Compression check: current={self.current_tokens}, " f"threshold={Config.MEMORY_COMPRESSION_THRESHOLD}, " - f"target={Config.MEMORY_TARGET_TOKENS}, " f"short_term_full={self.short_term.is_full()}" ) @@ -217,29 +216,14 @@ def get_context_for_llm(self) -> List[LLMMessage]: """Get optimized context for LLM call. Returns: - List of messages combining summaries and recent messages + List of messages: system messages + short-term messages (which includes summaries) """ context = [] # 1. Add system messages (always included) context.extend(self.system_messages) - # 2. Add summaries - for summary in self.summaries: - # add summary text (if any) - if summary.summary: - context.append( - LLMMessage( - role="user", - content=f"[Previous conversation summary]\n{summary.summary}", - ) - ) - - # 3. Add preserved messages - for summary in self.summaries: - context.extend(summary.preserved_messages) - - # 4. Add short-term memory (recent messages) + # 2. Add short-term memory (includes summary messages and recent messages) context.extend(self.short_term.get_messages()) return context @@ -247,6 +231,9 @@ def get_context_for_llm(self) -> List[LLMMessage]: def compress(self, strategy: str = None) -> Optional[CompressedMemory]: """Compress current short-term memory. + After compression, summary and preserved messages are put back into short_term + as regular messages, so they can participate in future compressions. + Args: strategy: Compression strategy (None = auto-select) @@ -254,6 +241,7 @@ def compress(self, strategy: str = None) -> Optional[CompressedMemory]: CompressedMemory object if compression was performed """ messages = self.short_term.get_messages() + message_count = len(messages) if not messages: logger.warning("No messages to compress") @@ -263,23 +251,17 @@ def compress(self, strategy: str = None) -> Optional[CompressedMemory]: if strategy is None: strategy = self._select_strategy(messages) - logger.info(f"🗜️ Compressing {len(messages)} messages using {strategy} strategy") + logger.info(f"🗜️ Compressing {message_count} messages using {strategy} strategy") try: - # CRITICAL: Find orphaned tool_use IDs from previous summaries - # These tool_use are waiting for tool_result that might be in current short_term - orphaned_tool_use_ids = self._get_orphaned_tool_use_ids_from_summaries() - - # Perform compression (pass orphaned IDs so compressor can protect matching tool_results) + # Perform compression compressed = self.compressor.compress( messages, strategy=strategy, target_tokens=self._calculate_target_tokens(), - orphaned_tool_use_ids=orphaned_tool_use_ids, ) # Track compression results - self.summaries.append(compressed) self.compression_count += 1 self.was_compressed_last_iteration = True self.last_compression_savings = compressed.token_savings @@ -291,9 +273,30 @@ def compress(self, strategy: str = None) -> Optional[CompressedMemory]: compression_cost = compressed.compressed_tokens self.token_tracker.add_compression_cost(compression_cost) - # Clear short-term memory + # Remove compressed messages from short-term memory + self.short_term.remove_first(message_count) + + # Rebuild short_term with: summary + preserved messages (in order) + # Get any remaining messages (added after compression started) + remaining_messages = self.short_term.get_messages() self.short_term.clear() + # 1. Add summary first (represents older context) + if compressed.summary: + summary_message = LLMMessage( + role="user", + content=f"{self.SUMMARY_PREFIX}{compressed.summary}", + ) + self.short_term.add_message(summary_message) + + # 2. Add preserved messages in order + for msg in compressed.preserved_messages: + self.short_term.add_message(msg) + + # 3. Add any remaining messages + for msg in remaining_messages: + self.short_term.add_message(msg) + # Update current token count old_tokens = self.current_tokens self.current_tokens = self._recalculate_current_tokens() @@ -302,7 +305,8 @@ def compress(self, strategy: str = None) -> Optional[CompressedMemory]: logger.info( f"✅ Compression complete: {compressed.original_tokens} → {compressed.compressed_tokens} tokens " f"({compressed.savings_percentage:.1f}% saved, ratio: {compressed.compression_ratio:.2f}), " - f"context: {old_tokens} → {self.current_tokens} tokens" + f"context: {old_tokens} → {self.current_tokens} tokens, " + f"short_term now has {self.short_term.count()} messages" ) return compressed @@ -337,13 +341,6 @@ def _should_compress(self) -> tuple[bool, Optional[str]]: f"current tokens: {self.current_tokens})", ) - # Soft limit: compress if over target token count - if self.current_tokens > Config.MEMORY_TARGET_TOKENS: - return ( - True, - f"soft_limit ({self.current_tokens} > {Config.MEMORY_TARGET_TOKENS})", - ) - return False, None def _select_strategy(self, messages: List[LLMMessage]) -> str: @@ -396,42 +393,6 @@ def _calculate_target_tokens(self) -> int: target = int(original_tokens * Config.MEMORY_COMPRESSION_RATIO) return max(target, 500) # Minimum 500 tokens for summary - def _get_orphaned_tool_use_ids_from_summaries(self) -> set: - """Get tool_use IDs from summaries that don't have matching tool_result yet. - - These are tool_use that were preserved in previous compressions but their - tool_result might arrive in later messages (in current short_term). - - Returns: - Set of tool_use IDs that are waiting for results - """ - orphaned_ids = set() - - for summary in self.summaries: - # Collect tool_use IDs from preserved messages - tool_use_ids = set() - tool_result_ids = set() - - for msg in summary.preserved_messages: - if isinstance(msg.content, list): - for block in msg.content: - if isinstance(block, dict): - if block.get("type") == "tool_use": - tool_use_ids.add(block.get("id")) - elif block.get("type") == "tool_result": - tool_result_ids.add(block.get("tool_use_id")) - - # Orphaned = tool_use without result in the same summary - summary_orphaned = tool_use_ids - tool_result_ids - orphaned_ids.update(summary_orphaned) - - if orphaned_ids: - logger.debug( - f"Found {len(orphaned_ids)} orphaned tool_use IDs in summaries: {orphaned_ids}" - ) - - return orphaned_ids - def process_tool_result( self, tool_name: str, tool_call_id: str, result: str, context: str = "" ) -> str: @@ -509,11 +470,7 @@ def _recalculate_current_tokens(self) -> int: for msg in self.system_messages: total += self.token_tracker.count_message_tokens(msg, provider, model) - # Count summaries - for summary in self.summaries: - total += summary.compressed_tokens - - # Count short-term messages + # Count short-term messages (includes summary messages) for msg in self.short_term.get_messages(): total += self.token_tracker.count_message_tokens(msg, provider, model) @@ -535,9 +492,7 @@ def get_stats(self) -> Dict[str, Any]: "net_savings": self.token_tracker.compression_savings - self.token_tracker.compression_cost, "short_term_count": self.short_term.count(), - "summary_count": len(self.summaries), "total_cost": self.token_tracker.get_total_cost(self.llm.model), - "budget_status": self.token_tracker.get_budget_status(Config.MEMORY_MAX_CONTEXT_TOKENS), } def save_memory(self): @@ -545,8 +500,7 @@ def save_memory(self): This saves the complete memory state including: - System messages - - Short-term messages - - Summaries + - Short-term messages (which includes summary messages after compression) Call this method after completing a task or at key checkpoints. """ @@ -558,7 +512,7 @@ def save_memory(self): messages = self.short_term.get_messages() # Skip saving if there are no messages (empty conversation) - if not messages and not self.system_messages and not self.summaries: + if not messages and not self.system_messages: logger.debug(f"Skipping save_memory: no messages to save for session {self.session_id}") return @@ -566,14 +520,13 @@ def save_memory(self): session_id=self.session_id, system_messages=self.system_messages, messages=messages, - summaries=self.summaries, + summaries=[], # Summaries are now part of messages ) logger.info(f"Saved memory state for session {self.session_id}") def reset(self): """Reset memory manager state.""" self.short_term.clear() - self.summaries.clear() self.system_messages.clear() self.token_tracker.reset() self.current_tokens = 0 diff --git a/memory/short_term.py b/memory/short_term.py index 8a956ba..7ef180a 100644 --- a/memory/short_term.py +++ b/memory/short_term.py @@ -46,6 +46,23 @@ def clear(self) -> List[LLMMessage]: self.messages.clear() return messages + def remove_first(self, count: int) -> List[LLMMessage]: + """Remove the first N messages (oldest) from memory. + + This is useful after compression to remove only the compressed messages + while preserving any new messages that arrived during compression. + + Args: + count: Number of messages to remove from the front + + Returns: + List of removed messages + """ + removed = [] + for _ in range(min(count, len(self.messages))): + removed.append(self.messages.popleft()) + return removed + def is_full(self) -> bool: """Check if short-term memory is at capacity. diff --git a/memory/token_tracker.py b/memory/token_tracker.py index 3886c53..af4356c 100644 --- a/memory/token_tracker.py +++ b/memory/token_tracker.py @@ -204,26 +204,6 @@ def get_net_savings(self, model: str) -> Dict[str, float]: "compression_overhead_tokens": self.compression_cost, } - def get_budget_status(self, max_tokens: int) -> Dict: - """Get current usage vs budget. - - Args: - max_tokens: Maximum token budget - - Returns: - Dict with usage statistics - """ - total_tokens = self.total_input_tokens + self.total_output_tokens - percentage = (total_tokens / max_tokens * 100) if max_tokens > 0 else 0 - - return { - "total_tokens": total_tokens, - "max_tokens": max_tokens, - "percentage": percentage, - "remaining": max(0, max_tokens - total_tokens), - "over_budget": total_tokens > max_tokens, - } - def reset(self): """Reset all counters.""" self.total_input_tokens = 0 diff --git a/test/memory/test_compressor.py b/test/memory/test_compressor.py index fb3d1fd..4320593 100644 --- a/test/memory/test_compressor.py +++ b/test/memory/test_compressor.py @@ -74,8 +74,9 @@ def test_selective_strategy_with_tools(self, set_memory_config, mock_llm, tool_u assert result is not None assert result.metadata["strategy"] == "selective" - # Tool pairs should be preserved - assert len(result.preserved_messages) > 0 + # Regular tool pairs are compressed (not preserved) unless they are protected tools + # Only system messages, protected tools, and orphaned tool pairs are preserved + assert result.summary != "" # Should have a summary for compressed content def test_selective_strategy_preserves_system_messages(self, set_memory_config, mock_llm): """Test that selective strategy preserves system messages.""" @@ -249,28 +250,36 @@ class TestMessageSeparation: """Test message separation logic.""" def test_separate_messages_basic(self, set_memory_config, mock_llm, simple_messages): - """Test basic message separation.""" - set_memory_config(MEMORY_SHORT_TERM_MIN_SIZE=2) + """Test basic message separation - recent messages are preserved, others compressed.""" + set_memory_config(MEMORY_SHORT_TERM_MIN_SIZE=0) # Don't preserve recent messages for this test compressor = WorkingMemoryCompressor(mock_llm) preserved, to_compress = compressor._separate_messages(simple_messages) - # Should preserve at least short_term_min_message_count messages - assert len(preserved) >= 2 + # With MIN_SIZE=0, simple messages (no system, no protected tools) should all be compressed + assert len(to_compress) == len(simple_messages) + assert len(preserved) == 0 # Total should equal original assert len(preserved) + len(to_compress) == len(simple_messages) - def test_separate_preserves_recent_messages(self, set_memory_config, mock_llm, simple_messages): - """Test that most recent messages are preserved.""" - set_memory_config(MEMORY_SHORT_TERM_MIN_SIZE=2) + def test_separate_preserves_system_messages(self, set_memory_config, mock_llm): + """Test that system messages are preserved.""" + set_memory_config(MEMORY_PRESERVE_SYSTEM_PROMPTS=True, MEMORY_SHORT_TERM_MIN_SIZE=0) compressor = WorkingMemoryCompressor(mock_llm) - preserved, to_compress = compressor._separate_messages(simple_messages) + messages = [ + LLMMessage(role="system", content="System prompt"), + LLMMessage(role="user", content="Hello"), + LLMMessage(role="assistant", content="Hi there!"), + ] - # Last N messages should be in preserved - last_n_messages = simple_messages[-2:] - for msg in last_n_messages: - assert msg in preserved + preserved, to_compress = compressor._separate_messages(messages) + + # System message should be preserved + assert len(preserved) == 1 + assert preserved[0].role == "system" + # Other messages should be compressed + assert len(to_compress) == 2 def test_tool_pair_preservation_rule(self, set_memory_config, mock_llm, tool_use_messages): """Test that tool pairs are preserved together (critical rule).""" diff --git a/test/memory/test_integration.py b/test/memory/test_integration.py index e2f97e5..8279a89 100644 --- a/test/memory/test_integration.py +++ b/test/memory/test_integration.py @@ -243,7 +243,7 @@ def test_full_conversation_lifecycle(self, set_memory_config, mock_llm): """Test a complete conversation lifecycle with multiple compressions.""" set_memory_config( MEMORY_SHORT_TERM_SIZE=8, - MEMORY_TARGET_TOKENS=200, + MEMORY_COMPRESSION_THRESHOLD=200, ) manager = MemoryManager(mock_llm) @@ -359,7 +359,7 @@ def test_rapid_compression_cycles(self, set_memory_config, mock_llm): """Test many rapid compression cycles.""" set_memory_config( MEMORY_SHORT_TERM_SIZE=2, - MEMORY_TARGET_TOKENS=50, + MEMORY_COMPRESSION_THRESHOLD=50, ) manager = MemoryManager(mock_llm) @@ -405,7 +405,13 @@ def test_alternating_compression_strategies(self, set_memory_config, mock_llm): # Should have multiple compressions with different strategies assert manager.compression_count == 2 - assert len(manager.summaries) == 2 + # Summaries are now stored as messages in short_term, check context has summary messages + context = manager.get_context_for_llm() + summary_count = sum( + 1 for msg in context + if isinstance(msg.content, str) and msg.content.startswith("[Conversation Summary]") + ) + assert summary_count >= 1 # At least one summary should exist def test_empty_content_blocks(self, set_memory_config, mock_llm): """Test handling of empty content blocks.""" @@ -433,7 +439,7 @@ def test_very_long_single_message(self, set_memory_config, mock_llm): """Test handling of a very long single message.""" set_memory_config( MEMORY_SHORT_TERM_SIZE=5, - MEMORY_TARGET_TOKENS=100, + MEMORY_COMPRESSION_THRESHOLD=100, ) manager = MemoryManager(mock_llm) @@ -463,14 +469,13 @@ def test_reset_after_compression(self, set_memory_config, mock_llm, simple_messa # Everything should be cleared assert manager.current_tokens == 0 assert manager.compression_count == 0 - assert len(manager.summaries) == 0 assert manager.short_term.count() == 0 def test_reuse_after_reset(self, set_memory_config, mock_llm): """Test that manager can be reused after reset.""" set_memory_config( MEMORY_SHORT_TERM_SIZE=10, # Large enough to avoid compression - MEMORY_TARGET_TOKENS=100000, + MEMORY_COMPRESSION_THRESHOLD=100000, ) manager = MemoryManager(mock_llm) diff --git a/test/memory/test_memory_manager.py b/test/memory/test_memory_manager.py index acffd4a..3664f6f 100644 --- a/test/memory/test_memory_manager.py +++ b/test/memory/test_memory_manager.py @@ -15,8 +15,8 @@ def test_initialization(self, mock_llm): assert manager.llm == mock_llm assert manager.current_tokens == 0 assert manager.compression_count == 0 - assert len(manager.summaries) == 0 assert len(manager.system_messages) == 0 + assert manager.short_term.count() == 0 def test_add_system_message(self, mock_llm): """Test that system messages are stored separately.""" @@ -81,7 +81,6 @@ def test_reset(self, mock_llm, simple_messages): assert manager.current_tokens == 0 assert manager.compression_count == 0 - assert len(manager.summaries) == 0 assert len(manager.system_messages) == 0 assert manager.short_term.count() == 0 @@ -93,7 +92,6 @@ def test_compression_on_short_term_full(self, set_memory_config, mock_llm): """Test compression triggers when short-term memory is full.""" set_memory_config( MEMORY_SHORT_TERM_SIZE=5, - MEMORY_TARGET_TOKENS=100000, # Very high to avoid soft limit MEMORY_COMPRESSION_THRESHOLD=200000, # Very high to avoid hard limit ) manager = MemoryManager(mock_llm) @@ -108,26 +106,9 @@ def test_compression_on_short_term_full(self, set_memory_config, mock_llm): # After compression, short-term is cleared so it's not full assert not manager.short_term.is_full() - def test_compression_on_soft_limit(self, set_memory_config, mock_llm): - """Test compression triggers on soft limit (target tokens).""" - set_memory_config( - MEMORY_TARGET_TOKENS=50, # Very low to trigger easily - MEMORY_COMPRESSION_THRESHOLD=10000, - MEMORY_SHORT_TERM_SIZE=100, # Large enough to not trigger on count - ) - manager = MemoryManager(mock_llm) - - # Add messages until we exceed target tokens - long_message = "This is a long message. " * 50 - manager.add_message(LLMMessage(role="user", content=long_message)) - - # Should trigger compression - assert manager.compression_count >= 1 - def test_compression_on_hard_limit(self, set_memory_config, mock_llm): """Test compression triggers on hard limit (compression threshold).""" set_memory_config( - MEMORY_TARGET_TOKENS=10000, MEMORY_COMPRESSION_THRESHOLD=100, # Very low to trigger easily MEMORY_SHORT_TERM_SIZE=100, ) @@ -140,25 +121,29 @@ def test_compression_on_hard_limit(self, set_memory_config, mock_llm): assert manager.compression_count >= 1 def test_compression_creates_summary(self, set_memory_config, mock_llm, simple_messages): - """Test that compression creates a summary.""" + """Test that compression creates a summary message in short_term.""" set_memory_config( - MEMORY_SHORT_TERM_SIZE=3, - MEMORY_TARGET_TOKENS=100000, + MEMORY_SHORT_TERM_SIZE=10, # Large enough to not auto-trigger MEMORY_COMPRESSION_THRESHOLD=200000, ) manager = MemoryManager(mock_llm) - # Add messages to trigger compression + # Add messages for msg in simple_messages: manager.add_message(msg) - # Trigger compression - assert manager.compression_count >= 1 - assert len(manager.summaries) >= 1 + # Manually trigger compression with sliding_window strategy (which creates summary) + result = manager.compress(strategy=CompressionStrategy.SLIDING_WINDOW) + assert result is not None + assert manager.compression_count == 1 - # Check that short-term was cleared after compression - # Note: may have new messages added after compression - assert manager.short_term.count() < len(simple_messages) + # Check that summary message exists in short_term (at the front) + context = manager.get_context_for_llm() + has_summary = any( + isinstance(msg.content, str) and msg.content.startswith("[Conversation Summary]") + for msg in context + ) + assert has_summary, "Summary message should be present after compression" def test_get_stats(self, mock_llm, simple_messages): """Test getting memory statistics.""" @@ -177,7 +162,6 @@ def test_get_stats(self, mock_llm, simple_messages): assert "compression_cost" in stats assert "net_savings" in stats assert "short_term_count" in stats - assert "summary_count" in stats class TestToolCallMatching: @@ -188,7 +172,6 @@ def test_tool_pairs_preserved_together(self, set_memory_config, mock_llm, tool_u set_memory_config( MEMORY_SHORT_TERM_SIZE=3, MEMORY_SHORT_TERM_MIN_SIZE=2, - MEMORY_TARGET_TOKENS=100000, MEMORY_COMPRESSION_THRESHOLD=200000, ) manager = MemoryManager(mock_llm) @@ -283,23 +266,22 @@ def test_protected_tool_always_preserved( # Check that todo list tool was preserved assert compressed is not None - # Verify the protected tool is in preserved messages or summaries + # Verify the protected tool is in context (now stored in short_term) found_protected = False + context = manager.get_context_for_llm() - # Check in all summaries (including the one just created) - for summary in manager.summaries: - for msg in summary.preserved_messages: - if isinstance(msg.content, list): - for block in msg.content: - if isinstance(block, dict): - if ( - block.get("type") == "tool_use" - and block.get("name") == "manage_todo_list" - ): - found_protected = True - break - - assert found_protected, "Protected tool 'manage_todo_list' should be preserved" + for msg in context: + if isinstance(msg.content, list): + for block in msg.content: + if isinstance(block, dict): + if ( + block.get("type") == "tool_use" + and block.get("name") == "manage_todo_list" + ): + found_protected = True + break + + assert found_protected, "Protected tool 'manage_todo_list' should be preserved in context" def test_multiple_tool_pairs_in_sequence(self, set_memory_config, mock_llm): """Test multiple consecutive tool_use/tool_result pairs.""" diff --git a/test/test_memory.py b/test/test_memory.py index 307d91d..ef6f798 100644 --- a/test/test_memory.py +++ b/test/test_memory.py @@ -44,8 +44,6 @@ def main(): # Configure memory settings directly via Config class # (In production, these would be set via environment variables) - Config.MEMORY_MAX_CONTEXT_TOKENS = 10000 - Config.MEMORY_TARGET_TOKENS = 500 # Low threshold for demo Config.MEMORY_COMPRESSION_THRESHOLD = 400 # Trigger compression quickly Config.MEMORY_SHORT_TERM_SIZE = 5 Config.MEMORY_COMPRESSION_RATIO = 0.3 @@ -54,7 +52,6 @@ def main(): memory = MemoryManager(mock_llm) print("\nConfiguration:") - print(f" Target tokens: {Config.MEMORY_TARGET_TOKENS}") print(f" Compression threshold: {Config.MEMORY_COMPRESSION_THRESHOLD}") print(f" Short-term size: {Config.MEMORY_SHORT_TERM_SIZE}") From 4b0dcfa5b111e4327b8b30dae340d0790cd41e5c Mon Sep 17 00:00:00 2001 From: Yixin Luo <18810541851@163.com> Date: Sun, 18 Jan 2026 22:44:37 +0800 Subject: [PATCH 2/2] fix Signed-off-by: Yixin Luo <18810541851@163.com> --- memory/code_extractor.py | 6 +++--- memory/tool_result_processor.py | 7 ++++--- memory/tool_result_store.py | 3 +-- test/memory/test_compressor.py | 4 +++- test/memory/test_integration.py | 3 ++- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/memory/code_extractor.py b/memory/code_extractor.py index 4a22d8f..ef9fd11 100644 --- a/memory/code_extractor.py +++ b/memory/code_extractor.py @@ -1,7 +1,7 @@ """Code structure extraction using tree-sitter for multiple languages.""" import logging -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) @@ -167,8 +167,8 @@ class CodeExtractor: def __init__(self): """Initialize code extractor.""" - self.parsers: Dict[str, any] = {} - self.languages: Dict[str, any] = {} + self.parsers: Dict[str, Any] = {} + self.languages: Dict[str, Any] = {} def detect_language(self, filename: str, content: str) -> Optional[str]: """Detect programming language from filename or content. diff --git a/memory/tool_result_processor.py b/memory/tool_result_processor.py index eefb94c..a562ce9 100644 --- a/memory/tool_result_processor.py +++ b/memory/tool_result_processor.py @@ -2,6 +2,7 @@ import logging import re +from typing import Dict, Union from memory.code_extractor import CodeExtractor @@ -19,7 +20,7 @@ class ToolResultProcessor: """ # Tool-specific processing strategies - TOOL_STRATEGIES = { + TOOL_STRATEGIES: Dict[str, Dict[str, Union[int, str]]] = { "read_file": { "max_tokens": 1000, "strategy": "extract_key_sections", @@ -89,8 +90,8 @@ def process_result( strategy_config = self.TOOL_STRATEGIES.get( tool_name, {"max_tokens": 1000, "strategy": "smart_truncate"} ) - max_tokens = strategy_config["max_tokens"] - strategy = strategy_config["strategy"] + max_tokens: int = int(strategy_config["max_tokens"]) + strategy: str = str(strategy_config["strategy"]) # Estimate tokens (rough: 3.5 chars per token) estimated_tokens = len(result) / 3.5 diff --git a/memory/tool_result_store.py b/memory/tool_result_store.py index 319ce72..2e1b210 100644 --- a/memory/tool_result_store.py +++ b/memory/tool_result_store.py @@ -23,12 +23,11 @@ def __init__(self, db_path: Optional[str] = None): db_path: Path to SQLite database file. If None, uses in-memory database. """ self.db_path = db_path or ":memory:" - self.conn = None + self.conn: sqlite3.Connection = sqlite3.connect(self.db_path, check_same_thread=False) self._init_db() def _init_db(self): """Initialize database schema.""" - self.conn = sqlite3.connect(self.db_path, check_same_thread=False) self.conn.row_factory = sqlite3.Row self.conn.execute( diff --git a/test/memory/test_compressor.py b/test/memory/test_compressor.py index 4320593..50733b5 100644 --- a/test/memory/test_compressor.py +++ b/test/memory/test_compressor.py @@ -251,7 +251,9 @@ class TestMessageSeparation: def test_separate_messages_basic(self, set_memory_config, mock_llm, simple_messages): """Test basic message separation - recent messages are preserved, others compressed.""" - set_memory_config(MEMORY_SHORT_TERM_MIN_SIZE=0) # Don't preserve recent messages for this test + set_memory_config( + MEMORY_SHORT_TERM_MIN_SIZE=0 + ) # Don't preserve recent messages for this test compressor = WorkingMemoryCompressor(mock_llm) preserved, to_compress = compressor._separate_messages(simple_messages) diff --git a/test/memory/test_integration.py b/test/memory/test_integration.py index 8279a89..9c87863 100644 --- a/test/memory/test_integration.py +++ b/test/memory/test_integration.py @@ -408,7 +408,8 @@ def test_alternating_compression_strategies(self, set_memory_config, mock_llm): # Summaries are now stored as messages in short_term, check context has summary messages context = manager.get_context_for_llm() summary_count = sum( - 1 for msg in context + 1 + for msg in context if isinstance(msg.content, str) and msg.content.startswith("[Conversation Summary]") ) assert summary_count >= 1 # At least one summary should exist