diff --git a/src/agents/summarizer.py b/src/agents/summarizer.py index ea01b31..ce8e6bb 100644 --- a/src/agents/summarizer.py +++ b/src/agents/summarizer.py @@ -7,6 +7,7 @@ from __future__ import annotations +import asyncio from typing import Any, Dict from langchain_core.language_models import BaseChatModel @@ -14,15 +15,208 @@ from src.agents.base import BaseAgent from src.prompts.summarizer import build_system_prompt, pack_summary_query from src.schemas.summary import SummaryResult +from src.models.registry import get_model_context_window class SummarizerAgent(BaseAgent): + + MAX_RECURSION_DEPTH = 3 + CHUNK_OVERLAP_TOKENS = 200 + SAFE_THRESHOLD_RATIO = 0.8 + + + MAX_RETRY_ATTEMPTS = 3 + INITIAL_BACKOFF_SECONDS = 1.0 + def __init__(self, model: BaseChatModel) -> None: super().__init__( model=model, name="summarizer", system_prompt=build_system_prompt(), ) + + self._init_dynamic_chunk_tokens() + + def _init_dynamic_chunk_tokens(self) -> None: + """ + Initialize MAX_CHUNK_TOKENS based on the active model's context window. + This ensures we stay well within the model's limits across all providers. + """ + try: + + provider = self._detect_provider() + model_name = getattr( + self.model, "model", getattr(self.model, "model_name", None) + ) + + context_window = get_model_context_window(provider, model_name) + + self.MAX_CHUNK_TOKENS = max(int(context_window * self.SAFE_THRESHOLD_RATIO), 2000) + + self.logger.info( + f"Initialized dynamic chunking: context_window={context_window}, " + f"max_chunk_tokens={self.MAX_CHUNK_TOKENS}" + ) + except Exception as e: + + self.MAX_CHUNK_TOKENS = 3000 + self.logger.warning( + f"Failed to initialize dynamic chunk tokens, using fallback (3000): {e}" + ) + + def _detect_provider(self) -> str: + """Detect the provider from the model instance, unwrapping RunnableBinding if needed.""" + + model = self.model + while hasattr(model, "bound"): + model = model.bound + + model_type = type(model).__name__ + + + provider_map = { + "ChatAnthropic": "claude", + "ChatOpenAI": "openai", + "ChatGoogleGenerativeAI": "gemini", + "ChatGroq": "groq", + "OllamaLLM": "ollama", + "ChatOllama": "ollama", + "ChatBedrock": "bedrock", + "ChatDeepSeek": "deepseek", + "ChatMimo": "mimo", + } + + for class_name, provider in provider_map.items(): + if class_name in model_type: + return provider + + + return "openai" + + async def _call_model_with_retry(self, messages: list) -> str: + """ + Call the model with exponential backoff retry logic for rate limits. + + Handles rate limit (429) responses gracefully by retrying with + exponential backoff instead of failing immediately. + """ + backoff_seconds = self.INITIAL_BACKOFF_SECONDS + + for attempt in range(self.MAX_RETRY_ATTEMPTS): + try: + return await self._call_model(messages) + except Exception as e: + error_msg = str(e).lower() + is_rate_limit = ( + "429" in error_msg + or "rate limit" in error_msg + or "quota" in error_msg + or "too many requests" in error_msg + ) + + if not is_rate_limit or attempt == self.MAX_RETRY_ATTEMPTS - 1: + + raise + + self.logger.warning( + f"Rate limit hit (attempt {attempt + 1}/{self.MAX_RETRY_ATTEMPTS}). " + f"Retrying in {backoff_seconds:.1f}s..." + ) + await asyncio.sleep(backoff_seconds) + backoff_seconds *= 2 + + def _estimate_tokens(self, text: str) -> int: + """Lightweight token estimation (approx 4 characters per token).""" + return len(text) // 4 + + def _chunk_payload(self, text: str) -> list[str]: + """Splits text into overlapping chunks based on token limits. + + Fixes: Uses text.split() for proper whitespace handling and ensures + overlap calculation doesn't create infinite loops when single words + exceed MAX_CHUNK_TOKENS. + """ + words = text.split() + chunks = [] + current_chunk = [] + current_tokens = 0 + + for word in words: + word_tokens = self._estimate_tokens(word + " ") + if current_tokens + word_tokens > self.MAX_CHUNK_TOKENS and current_chunk: + + chunks.append(" ".join(current_chunk)) + + + overlap_words = [] + overlap_tokens = 0 + for w in reversed(current_chunk): + w_tokens = self._estimate_tokens(w + " ") + if overlap_tokens + w_tokens > self.CHUNK_OVERLAP_TOKENS: + break + overlap_words.insert(0, w) + overlap_tokens += w_tokens + + + + if len(overlap_words) >= len(current_chunk): + overlap_words = current_chunk[1:] if len(current_chunk) > 1 else [] + + current_chunk = overlap_words + [word] + current_tokens = sum( + self._estimate_tokens(w + " ") for w in current_chunk + ) + else: + current_chunk.append(word) + current_tokens += word_tokens + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + return chunks + + async def _recursive_summarize(self, text: str, depth: int = 0) -> str: + """Stateful graph-based loop to chunk, summarize, and map-reduce.""" + if depth >= self.MAX_RECURSION_DEPTH: + self.logger.warning( + f"Max recursion depth ({self.MAX_RECURSION_DEPTH}) reached. Truncating payload." + ) + messages = self._build_messages(text[: self.MAX_CHUNK_TOKENS * 4]) + return await self._call_model_with_retry(messages) + + estimated_tokens = self._estimate_tokens(text) + + # Base Case: Payload fits safely within the context window + if estimated_tokens <= self.MAX_CHUNK_TOKENS: + messages = self._build_messages(text) + return await self._call_model_with_retry(messages) + + + self.logger.info( + f"Payload too large ({estimated_tokens} tokens). Splitting into chunks (Depth: {depth})." + ) + chunks = self._chunk_payload(text) + + + tasks = [] + for i, chunk in enumerate(chunks): + self.logger.debug(f"Queuing chunk {i + 1}/{len(chunks)} for concurrent summarization...") + messages = self._build_messages(chunk) + tasks.append(self._call_model_with_retry(messages)) + + self.logger.debug(f"Processing {len(chunks)} chunks concurrently...") + results = await asyncio.gather(*tasks, return_exceptions=True) + + + exceptions = [r for r in results if isinstance(r, BaseException)] + if exceptions: + raise exceptions[0] + + chunk_summaries = [str(s).strip() for s in results] + + aggregated_text = "\n\n--- PARTIAL SUMMARIES ---\n\n".join(chunk_summaries) + + return await self._recursive_summarize(aggregated_text, depth=depth + 1) async def arun( self, @@ -36,12 +230,12 @@ async def arun( return SummaryResult() user_message = pack_summary_query(user_query, agent_response) - messages = self._build_messages(user_message) - raw_content = await self._call_model(messages) + + raw_content = await self._recursive_summarize(user_message) summary = raw_content.strip() - # Treat empty-like responses as no summary + if summary in ('""', "''", "empty", "(empty)", "(empty string)"): summary = "" diff --git a/src/models/registry.py b/src/models/registry.py index 8c3fd16..af76003 100644 --- a/src/models/registry.py +++ b/src/models/registry.py @@ -23,6 +23,69 @@ logger = logging.getLogger("xmem.models") + + +_CONTEXT_WINDOWS = { + + "claude": { + "claude-3-5-sonnet-20241022": 200000, + "claude-3-5-sonnet": 200000, + "claude-3-sonnet-20240229": 200000, + "claude-3-opus-20240229": 200000, + "claude-3-haiku-20240307": 200000, + "claude-opus": 200000, + "claude-sonnet": 200000, + "claude-haiku": 200000, + "default": 200000, + }, + + "openai": { + "gpt-4o": 128000, + "gpt-4-turbo": 128000, + "gpt-4": 8192, + "gpt-3.5-turbo": 16385, + "default": 128000, + }, + # Gemini models + "gemini": { + "gemini-2.0-flash": 1000000, + "gemini-2.0-pro": 1000000, + "gemini-1.5-pro": 1000000, + "gemini-1.5-flash": 1000000, + "gemini-pro": 32768, + "default": 1000000, + }, + # DeepSeek models + "deepseek": { + "deepseek-chat": 128000, + "deepseek-coder": 128000, + "default": 128000, + }, + + "groq": { + "mixtral-8x7b-32768": 32768, + "llama2-70b-4096": 4096, + "default": 32768, + }, + # OpenRouter (varies by model, use conservative default) + "openrouter": { + "default": 128000, + }, + # Ollama (local, typically depends on model) + "ollama": { + "default": 8000, + }, + # Bedrock (varies by model) + "bedrock": { + "default": 100000, + }, + # Mimo + "mimo": { + "default": 32768, + }, +} + + def _build_from_module(module_name: str, func_name: str, **kwargs) -> BaseChatModel: module = importlib.import_module(f"src.models.{module_name}") factory_fn = getattr(module, func_name) @@ -33,10 +96,14 @@ def _build_from_module(module_name: str, func_name: str, **kwargs) -> BaseChatMo "gemini": lambda **kw: _build_from_module("gemini", "build_gemini_model", **kw), "claude": lambda **kw: _build_from_module("claude", "build_claude_model", **kw), "openai": lambda **kw: _build_from_module("openai", "build_openai_model", **kw), - "deepseek": lambda **kw: _build_from_module("deepseek", "build_deepseek_model", **kw), + "deepseek": lambda **kw: _build_from_module( + "deepseek", "build_deepseek_model", **kw + ), "groq": lambda **kw: _build_from_module("groq", "build_groq_model", **kw), "mimo": lambda **kw: _build_from_module("mimo", "build_mimo_model", **kw), - "openrouter": lambda **kw: _build_from_module("openrouter", "build_openrouter_model", **kw), + "openrouter": lambda **kw: _build_from_module( + "openrouter", "build_openrouter_model", **kw + ), "bedrock": lambda **kw: _build_from_module("bedrock", "build_bedrock_model", **kw), "ollama": lambda **kw: _build_from_module("ollama", "build_ollama_model", **kw), } @@ -55,6 +122,48 @@ def _build_from_module(module_name: str, func_name: str, **kwargs) -> BaseChatMo } +def get_model_context_window( + provider: Provider, model_name: Optional[str] = None +) -> int: + """ + Retrieve the context window (max tokens) for a given provider and model. + + Args: + provider: The provider name (e.g., 'claude', 'openai', 'gemini') + model_name: Specific model name. If None, uses the provider default. + + Returns: + Context window size in tokens. + """ + if provider not in _CONTEXT_WINDOWS: + logger.warning( + f"Provider '{provider}' not found in context window mapping. Using 8192 default." + ) + return 8192 + + provider_windows = _CONTEXT_WINDOWS[provider] + + if model_name: + + if model_name in provider_windows: + return provider_windows[model_name] + + + for key, window in sorted( + ((k, v) for k, v in provider_windows.items() if k != "default"), + key=lambda kv: len(kv[0]), + reverse=True, + ): + if key in model_name: + logger.debug( + f"Matched model '{model_name}' to key '{key}' with context window {window}" + ) + return window + + + return provider_windows.get("default", 8192) + + @lru_cache(maxsize=16) def get_model( provider: Optional[Provider] = None, @@ -76,7 +185,7 @@ def get_model( if provider: return _BUILDERS[provider](**kw) - # Auto-select from fallback order + errors: list[str] = [] for p in settings.fallback_order: key_fn = _KEY_MAP.get(p) @@ -95,9 +204,7 @@ def get_model( ) -# --------------------------------------------------------------------------- -# Vision model (for image analysis) -# --------------------------------------------------------------------------- + _VISION_MODEL_MAP = { "gemini": lambda: settings.gemini_vision_model, @@ -130,16 +237,21 @@ def get_vision_model( """ if provider: vision_name = _VISION_MODEL_MAP[provider]() - return get_model(provider=provider, model_name=vision_name, temperature=temperature) + return get_model( + provider=provider, model_name=vision_name, temperature=temperature + ) - # Auto-select from fallback order + errors: list[str] = [] for p in settings.fallback_order: key_fn = _KEY_MAP.get(p) if key_fn and key_fn(): try: vision_name = _VISION_MODEL_MAP[p]() - model = _BUILDERS[p](model_name=vision_name, **({"temperature": temperature} if temperature is not None else {})) + model = _BUILDERS[p]( + model_name=vision_name, + **({"temperature": temperature} if temperature is not None else {}), + ) logger.info("Using vision provider: %s (model: %s)", p, vision_name) return model except Exception as exc: