Skip to content
This repository was archived by the owner on Jun 3, 2026. It is now read-only.
200 changes: 197 additions & 3 deletions src/agents/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,216 @@

from __future__ import annotations

import asyncio
from typing import Any, Dict

from langchain_core.language_models import BaseChatModel

from src.agents.base import BaseAgent
from src.prompts.summarizer import build_system_prompt, pack_summary_query
from src.schemas.summary import SummaryResult
from src.models.registry import get_model_context_window


class SummarizerAgent(BaseAgent):
    """Agent that summarizes text, chunking payloads that exceed its token budget.

    ``MAX_CHUNK_TOKENS`` is not declared at class level; it is assigned on the
    instance by ``_init_dynamic_chunk_tokens`` during construction.
    """

    # Depth at which ``_recursive_summarize`` stops re-chunking and truncates.
    MAX_RECURSION_DEPTH = 3
    # Token budget for the trailing words repeated at the start of the next chunk.
    CHUNK_OVERLAP_TOKENS = 200
    # Fraction of the model's context window used as the per-chunk token budget.
    SAFE_THRESHOLD_RATIO = 0.8


    # Total attempts made by ``_call_model_with_retry`` (first call included).
    MAX_RETRY_ATTEMPTS = 3
    # Delay before the first retry; doubled after each rate-limited attempt.
    INITIAL_BACKOFF_SECONDS = 1.0

def __init__(self, model: BaseChatModel) -> None:
    """Create the summarizer agent and size its chunk budget for ``model``.

    Args:
        model: Chat model used for every summarization call.
    """
    prompt = build_system_prompt()
    super().__init__(model=model, name="summarizer", system_prompt=prompt)

    # Runs after the base initializer because it reads ``self.model`` and
    # ``self.logger``, neither of which this class assigns itself.
    self._init_dynamic_chunk_tokens()

def _init_dynamic_chunk_tokens(self) -> None:
"""
Initialize MAX_CHUNK_TOKENS based on the active model's context window.
This ensures we stay well within the model's limits across all providers.
"""
try:

provider = self._detect_provider()
model_name = getattr(
self.model, "model", getattr(self.model, "model_name", None)
)

context_window = get_model_context_window(provider, model_name)

self.MAX_CHUNK_TOKENS = max(int(context_window * self.SAFE_THRESHOLD_RATIO), 2000)

self.logger.info(
f"Initialized dynamic chunking: context_window={context_window}, "
f"max_chunk_tokens={self.MAX_CHUNK_TOKENS}"
)
except Exception as e:

self.MAX_CHUNK_TOKENS = 3000
self.logger.warning(
f"Failed to initialize dynamic chunk tokens, using fallback (3000): {e}"
)

def _detect_provider(self) -> str:
"""Detect the provider from the model instance, unwrapping RunnableBinding if needed."""

model = self.model
while hasattr(model, "bound"):
model = model.bound

model_type = type(model).__name__


provider_map = {
"ChatAnthropic": "claude",
"ChatOpenAI": "openai",
"ChatGoogleGenerativeAI": "gemini",
"ChatGroq": "groq",
"OllamaLLM": "ollama",
"ChatOllama": "ollama",
"ChatBedrock": "bedrock",
"ChatDeepSeek": "deepseek",
"ChatMimo": "mimo",
}
Comment thread
vakrahul marked this conversation as resolved.

for class_name, provider in provider_map.items():
if class_name in model_type:
return provider


return "openai"
Comment thread
vakrahul marked this conversation as resolved.

async def _call_model_with_retry(self, messages: list) -> str:
"""
Call the model with exponential backoff retry logic for rate limits.

Handles rate limit (429) responses gracefully by retrying with
exponential backoff instead of failing immediately.
"""
backoff_seconds = self.INITIAL_BACKOFF_SECONDS

for attempt in range(self.MAX_RETRY_ATTEMPTS):
try:
return await self._call_model(messages)
except Exception as e:
error_msg = str(e).lower()
is_rate_limit = (
"429" in error_msg
or "rate limit" in error_msg
or "quota" in error_msg
or "too many requests" in error_msg
)

if not is_rate_limit or attempt == self.MAX_RETRY_ATTEMPTS - 1:

raise

self.logger.warning(
f"Rate limit hit (attempt {attempt + 1}/{self.MAX_RETRY_ATTEMPTS}). "
f"Retrying in {backoff_seconds:.1f}s..."
)
await asyncio.sleep(backoff_seconds)
backoff_seconds *= 2

def _estimate_tokens(self, text: str) -> int:
"""Lightweight token estimation (approx 4 characters per token)."""
return len(text) // 4

def _chunk_payload(self, text: str) -> list[str]:
"""Splits text into overlapping chunks based on token limits.

Fixes: Uses text.split() for proper whitespace handling and ensures
overlap calculation doesn't create infinite loops when single words
exceed MAX_CHUNK_TOKENS.
"""
words = text.split()
chunks = []
current_chunk = []
current_tokens = 0

for word in words:
word_tokens = self._estimate_tokens(word + " ")
if current_tokens + word_tokens > self.MAX_CHUNK_TOKENS and current_chunk:

chunks.append(" ".join(current_chunk))


overlap_words = []
overlap_tokens = 0
for w in reversed(current_chunk):
w_tokens = self._estimate_tokens(w + " ")
if overlap_tokens + w_tokens > self.CHUNK_OVERLAP_TOKENS:
break
overlap_words.insert(0, w)
overlap_tokens += w_tokens



if len(overlap_words) >= len(current_chunk):
overlap_words = current_chunk[1:] if len(current_chunk) > 1 else []

current_chunk = overlap_words + [word]
current_tokens = sum(
self._estimate_tokens(w + " ") for w in current_chunk
)
else:
current_chunk.append(word)
current_tokens += word_tokens

if current_chunk:
chunks.append(" ".join(current_chunk))

return chunks
Comment thread
vakrahul marked this conversation as resolved.

async def _recursive_summarize(self, text: str, depth: int = 0) -> str:
"""Stateful graph-based loop to chunk, summarize, and map-reduce."""
if depth >= self.MAX_RECURSION_DEPTH:
self.logger.warning(
f"Max recursion depth ({self.MAX_RECURSION_DEPTH}) reached. Truncating payload."
)
messages = self._build_messages(text[: self.MAX_CHUNK_TOKENS * 4])
return await self._call_model_with_retry(messages)

estimated_tokens = self._estimate_tokens(text)

# Base Case: Payload fits safely within the context window
if estimated_tokens <= self.MAX_CHUNK_TOKENS:
messages = self._build_messages(text)
return await self._call_model_with_retry(messages)


self.logger.info(
f"Payload too large ({estimated_tokens} tokens). Splitting into chunks (Depth: {depth})."
)
chunks = self._chunk_payload(text)


tasks = []
for i, chunk in enumerate(chunks):
self.logger.debug(f"Queuing chunk {i + 1}/{len(chunks)} for concurrent summarization...")
messages = self._build_messages(chunk)
tasks.append(self._call_model_with_retry(messages))

self.logger.debug(f"Processing {len(chunks)} chunks concurrently...")
results = await asyncio.gather(*tasks, return_exceptions=True)


exceptions = [r for r in results if isinstance(r, BaseException)]
if exceptions:
raise exceptions[0]

chunk_summaries = [str(s).strip() for s in results]

aggregated_text = "\n\n--- PARTIAL SUMMARIES ---\n\n".join(chunk_summaries)

return await self._recursive_summarize(aggregated_text, depth=depth + 1)

async def arun(
self,
Expand All @@ -36,12 +230,12 @@ async def arun(
return SummaryResult()

user_message = pack_summary_query(user_query, agent_response)
messages = self._build_messages(user_message)
raw_content = await self._call_model(messages)


raw_content = await self._recursive_summarize(user_message)
summary = raw_content.strip()

# Treat empty-like responses as no summary

if summary in ('""', "''", "empty", "(empty)", "(empty string)"):
summary = ""

Expand Down
Loading
Loading