|
"""
Context Tracker for maintaining conversation state and document references.

This module provides the ContextTracker class that manages conversation context,
tracking document relationships and maintaining coherent dialogue state.
"""
| 7 | + |
| 8 | +from dataclasses import dataclass |
| 9 | +from dataclasses import field |
| 10 | +from datetime import datetime |
| 11 | +from datetime import timedelta |
| 12 | +import logging |
| 13 | +from typing import Any |
| 14 | + |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
@dataclass
class DocumentReference:
    """A document that has been referenced during the conversation.

    Instances round-trip through :meth:`to_dict` / :meth:`from_dict`; the
    dictionary form contains only JSON-friendly values (strings, ints and
    lists).
    """

    # Identifier used as the key for this document.
    document_id: str
    # Location of the document.
    file_path: str
    # Human-readable title.
    title: str
    # Hash of the document content, as supplied by the caller.
    content_hash: str
    # When the document was last touched.
    last_accessed: datetime
    # Number of recorded accesses.
    access_count: int = 0
    # Section names flagged as relevant.
    relevant_sections: list[str] = field(default_factory=list)
    # Topic names associated with the document.
    topics: set[str] = field(default_factory=set)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the reference to a plain dictionary.

        Returns:
            A new dictionary. ``last_accessed`` is rendered with
            ``datetime.isoformat``; ``relevant_sections`` is a copy of the
            internal list and ``topics`` is a sorted list, so the result is
            deterministic and mutating it never affects this instance.
        """
        return {
            "document_id": self.document_id,
            "file_path": self.file_path,
            "title": self.title,
            "content_hash": self.content_hash,
            "last_accessed": self.last_accessed.isoformat(),
            "access_count": self.access_count,
            # Copy so callers cannot mutate our state through the payload.
            "relevant_sections": list(self.relevant_sections),
            # Sets have no stable order; sort for reproducible output.
            "topics": sorted(self.topics),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DocumentReference":
        """Build a reference from a dictionary produced by :meth:`to_dict`.

        Args:
            data: Mapping with the required keys ``document_id``,
                ``file_path``, ``title``, ``content_hash`` and
                ``last_accessed`` (ISO-8601 string). ``access_count``,
                ``relevant_sections`` and ``topics`` are optional.

        Returns:
            A new ``DocumentReference`` that shares no mutable containers
            with ``data``.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If ``last_accessed`` is not a valid ISO-8601 string.
        """
        return cls(
            document_id=data["document_id"],
            file_path=data["file_path"],
            title=data["title"],
            content_hash=data["content_hash"],
            last_accessed=datetime.fromisoformat(data["last_accessed"]),
            access_count=data.get("access_count", 0),
            # Copy so later edits to ``data`` do not leak into the instance.
            relevant_sections=list(data.get("relevant_sections", [])),
            topics=set(data.get("topics", [])),
        )
| 53 | + |
@dataclass
class ConversationTopic:
    """A subject that has come up in the conversation.

    Only ``name`` is required; every other field has a default. The two
    timestamps are each filled by their own ``datetime.now`` call at
    construction time, and the set fields get a fresh empty set per
    instance.
    """

    # Name identifying the topic.
    name: str
    # Keywords associated with the topic.
    keywords: set[str] = field(default_factory=set)
    # Identifiers of documents linked to the topic.
    documents: set[str] = field(default_factory=set)
    # Timestamp recorded when the topic object is created.
    first_mentioned: datetime = field(default_factory=datetime.now)
    # Timestamp of the most recent mention; starts at creation time.
    last_mentioned: datetime = field(default_factory=datetime.now)
    # Mention counter maintained by whoever owns the topic.
    mention_count: int = 0
| 63 | + |
| 64 | +class ContextTracker: |
| 65 | + """Tracks conversation context and document relationships.""" |
| 66 | + |
| 67 | + def __init__(self, config: dict[str, Any] | None = None): |
| 68 | + self.config = config or {} |
| 69 | + self.documents: dict[str, DocumentReference] = {} |
| 70 | + self.topics: dict[str, ConversationTopic] = {} |
| 71 | + self.current_focus: str | None = None # Current document focus |
| 72 | + self.conversation_thread: list[str] = [] # Topic progression |
| 73 | + |
| 74 | + # Configuration |
| 75 | + self.max_context_documents = self.config.get("max_context_documents", 5) |
| 76 | + self.topic_decay_hours = self.config.get("topic_decay_hours", 24) |
| 77 | + self.similarity_threshold = self.config.get("similarity_threshold", 0.3) |
| 78 | + |
| 79 | + def add_document_reference(self, document_id: str, file_path: str, |
| 80 | + title: str, content_hash: str, |
| 81 | + topics: set[str] | None = None) -> DocumentReference: |
| 82 | + """Add or update a document reference.""" |
| 83 | + if document_id in self.documents: |
| 84 | + doc_ref = self.documents[document_id] |
| 85 | + doc_ref.access_count += 1 |
| 86 | + doc_ref.last_accessed = datetime.now() |
| 87 | + else: |
| 88 | + doc_ref = DocumentReference( |
| 89 | + document_id=document_id, |
| 90 | + file_path=file_path, |
| 91 | + title=title, |
| 92 | + content_hash=content_hash, |
| 93 | + last_accessed=datetime.now(), |
| 94 | + access_count=1, |
| 95 | + topics=topics or set() |
| 96 | + ) |
| 97 | + self.documents[document_id] = doc_ref |
| 98 | + |
| 99 | + # Update topics |
| 100 | + if topics: |
| 101 | + for topic in topics: |
| 102 | + self.add_topic(topic, document_id) |
| 103 | + |
| 104 | + logger.debug(f"Added/updated document reference: {document_id}") |
| 105 | + return doc_ref |
| 106 | + |
| 107 | + def add_topic(self, topic_name: str, document_id: str | None = None): |
| 108 | + """Add or update a conversation topic.""" |
| 109 | + topic_name = topic_name.lower().strip() |
| 110 | + |
| 111 | + if topic_name in self.topics: |
| 112 | + topic = self.topics[topic_name] |
| 113 | + topic.mention_count += 1 |
| 114 | + topic.last_mentioned = datetime.now() |
| 115 | + else: |
| 116 | + topic = ConversationTopic( |
| 117 | + name=topic_name, |
| 118 | + keywords=self._extract_keywords(topic_name) |
| 119 | + ) |
| 120 | + self.topics[topic_name] = topic |
| 121 | + |
| 122 | + if document_id: |
| 123 | + topic.documents.add(document_id) |
| 124 | + |
| 125 | + # Update conversation thread |
| 126 | + if not self.conversation_thread or self.conversation_thread[-1] != topic_name: |
| 127 | + self.conversation_thread.append(topic_name) |
| 128 | + |
| 129 | + logger.debug(f"Added/updated topic: {topic_name}") |
| 130 | + |
| 131 | + def set_document_focus(self, document_id: str): |
| 132 | + """Set the current document focus for conversation.""" |
| 133 | + if document_id in self.documents: |
| 134 | + self.current_focus = document_id |
| 135 | + self.documents[document_id].access_count += 1 |
| 136 | + self.documents[document_id].last_accessed = datetime.now() |
| 137 | + logger.debug(f"Set document focus: {document_id}") |
| 138 | + else: |
| 139 | + logger.warning(f"Cannot set focus to unknown document: {document_id}") |
| 140 | + |
| 141 | + def get_current_context(self) -> dict[str, Any]: |
| 142 | + """Get the current conversation context.""" |
| 143 | + context = { |
| 144 | + "current_focus": self.current_focus, |
| 145 | + "active_documents": [], |
| 146 | + "recent_topics": [], |
| 147 | + "conversation_thread": self.conversation_thread[-10:], # Last 10 topics |
| 148 | + "suggested_documents": [] |
| 149 | + } |
| 150 | + |
| 151 | + # Get active documents (recently accessed) |
| 152 | + cutoff_time = datetime.now() - timedelta(hours=1) |
| 153 | + active_docs = [ |
| 154 | + doc for doc in self.documents.values() |
| 155 | + if doc.last_accessed > cutoff_time |
| 156 | + ] |
| 157 | + active_docs.sort(key=lambda x: x.last_accessed, reverse=True) |
| 158 | + context["active_documents"] = [ |
| 159 | + doc.to_dict() for doc in active_docs[:self.max_context_documents] |
| 160 | + ] |
| 161 | + |
| 162 | + # Get recent topics |
| 163 | + cutoff_time = datetime.now() - timedelta(hours=self.topic_decay_hours) |
| 164 | + recent_topics = [ |
| 165 | + topic for topic in self.topics.values() |
| 166 | + if topic.last_mentioned > cutoff_time |
| 167 | + ] |
| 168 | + recent_topics.sort(key=lambda x: x.last_mentioned, reverse=True) |
| 169 | + context["recent_topics"] = [ |
| 170 | + { |
| 171 | + "name": topic.name, |
| 172 | + "mention_count": topic.mention_count, |
| 173 | + "documents": list(topic.documents) |
| 174 | + } |
| 175 | + for topic in recent_topics[:10] |
| 176 | + ] |
| 177 | + |
| 178 | + # Suggest related documents |
| 179 | + if self.current_focus and self.current_focus in self.documents: |
| 180 | + current_doc = self.documents[self.current_focus] |
| 181 | + suggestions = self._find_related_documents(current_doc) |
| 182 | + context["suggested_documents"] = [ |
| 183 | + {"document_id": doc_id, "relevance_score": score} |
| 184 | + for doc_id, score in suggestions[:3] |
| 185 | + ] |
| 186 | + |
| 187 | + return context |
| 188 | + |
| 189 | + def get_document_context(self, document_id: str) -> dict[str, Any] | None: |
| 190 | + """Get context for a specific document.""" |
| 191 | + if document_id not in self.documents: |
| 192 | + return None |
| 193 | + |
| 194 | + doc_ref = self.documents[document_id] |
| 195 | + return { |
| 196 | + "document": doc_ref.to_dict(), |
| 197 | + "related_topics": [ |
| 198 | + topic.name for topic in self.topics.values() |
| 199 | + if document_id in topic.documents |
| 200 | + ], |
| 201 | + "related_documents": [ |
| 202 | + {"document_id": doc_id, "relevance_score": score} |
| 203 | + for doc_id, score in self._find_related_documents(doc_ref)[:5] |
| 204 | + ], |
| 205 | + "access_history": { |
| 206 | + "total_accesses": doc_ref.access_count, |
| 207 | + "last_accessed": doc_ref.last_accessed.isoformat() |
| 208 | + } |
| 209 | + } |
| 210 | + |
| 211 | + def track_topic_shift(self, new_topic: str, context_clues: list[str] | None = None) -> bool: |
| 212 | + """Track when conversation shifts to a new topic.""" |
| 213 | + new_topic = new_topic.lower().strip() |
| 214 | + |
| 215 | + # Check if this is actually a new topic |
| 216 | + if self.conversation_thread and self.conversation_thread[-1] == new_topic: |
| 217 | + return False |
| 218 | + |
| 219 | + # Add the topic |
| 220 | + self.add_topic(new_topic) |
| 221 | + |
| 222 | + # Add context clues as keywords |
| 223 | + if context_clues and new_topic in self.topics: |
| 224 | + for clue in context_clues: |
| 225 | + self.topics[new_topic].keywords.add(clue.lower()) |
| 226 | + |
| 227 | + logger.info(f"Topic shift detected: {new_topic}") |
| 228 | + return True |
| 229 | + |
| 230 | + def suggest_related_content(self, query: str) -> list[dict[str, Any]]: |
| 231 | + """Suggest related content based on a query.""" |
| 232 | + suggestions = [] |
| 233 | + query_lower = query.lower() |
| 234 | + |
| 235 | + # Find matching documents |
| 236 | + for doc_ref in self.documents.values(): |
| 237 | + score = 0.0 |
| 238 | + |
| 239 | + # Title match |
| 240 | + if any(word in doc_ref.title.lower() for word in query_lower.split()): |
| 241 | + score += 0.5 |
| 242 | + |
| 243 | + # Topic match |
| 244 | + for topic_name in doc_ref.topics: |
| 245 | + if topic_name in query_lower or query_lower in topic_name: |
| 246 | + score += 0.3 |
| 247 | + |
| 248 | + # Keyword match in topics |
| 249 | + for topic_name in doc_ref.topics: |
| 250 | + if topic_name in self.topics: |
| 251 | + topic = self.topics[topic_name] |
| 252 | + for keyword in topic.keywords: |
| 253 | + if keyword in query_lower: |
| 254 | + score += 0.2 |
| 255 | + |
| 256 | + if score > self.similarity_threshold: |
| 257 | + suggestions.append({ |
| 258 | + "type": "document", |
| 259 | + "document_id": doc_ref.document_id, |
| 260 | + "title": doc_ref.title, |
| 261 | + "relevance_score": score, |
| 262 | + "access_count": doc_ref.access_count |
| 263 | + }) |
| 264 | + |
| 265 | + # Find matching topics |
| 266 | + for topic in self.topics.values(): |
| 267 | + score = 0.0 |
| 268 | + |
| 269 | + # Direct topic name match |
| 270 | + if topic.name in query_lower or query_lower in topic.name: |
| 271 | + score += 0.7 |
| 272 | + |
| 273 | + # Keyword match |
| 274 | + for keyword in topic.keywords: |
| 275 | + if keyword in query_lower: |
| 276 | + score += 0.3 |
| 277 | + |
| 278 | + if score > self.similarity_threshold: |
| 279 | + suggestions.append({ |
| 280 | + "type": "topic", |
| 281 | + "name": topic.name, |
| 282 | + "relevance_score": score, |
| 283 | + "mention_count": topic.mention_count, |
| 284 | + "related_documents": list(topic.documents) |
| 285 | + }) |
| 286 | + |
| 287 | + # Sort by relevance score |
| 288 | + suggestions.sort(key=lambda x: x["relevance_score"], reverse=True) |
| 289 | + return suggestions[:10] |
| 290 | + |
| 291 | + def _extract_keywords(self, text: str) -> set[str]: |
| 292 | + """Extract keywords from text.""" |
| 293 | + # Simple keyword extraction |
| 294 | + words = text.lower().split() |
| 295 | + # Filter out common stop words |
| 296 | + stop_words = {"the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for", "of", "with"} |
| 297 | + keywords = {word for word in words if len(word) > 2 and word not in stop_words} |
| 298 | + return keywords |
| 299 | + |
| 300 | + def _find_related_documents(self, doc_ref: DocumentReference) -> list[tuple[str, float]]: |
| 301 | + """Find documents related to the given document.""" |
| 302 | + related = [] |
| 303 | + |
| 304 | + for other_doc_id, other_doc in self.documents.items(): |
| 305 | + if other_doc_id == doc_ref.document_id: |
| 306 | + continue |
| 307 | + |
| 308 | + # Calculate similarity based on shared topics |
| 309 | + shared_topics = doc_ref.topics.intersection(other_doc.topics) |
| 310 | + if shared_topics: |
| 311 | + similarity_score = len(shared_topics) / max(len(doc_ref.topics), len(other_doc.topics), 1) |
| 312 | + related.append((other_doc_id, similarity_score)) |
| 313 | + |
| 314 | + # Sort by similarity score |
| 315 | + related.sort(key=lambda x: x[1], reverse=True) |
| 316 | + return related |
| 317 | + |
| 318 | + def cleanup_old_context(self): |
| 319 | + """Clean up old context data to prevent memory bloat.""" |
| 320 | + cutoff_time = datetime.now() - timedelta(hours=self.topic_decay_hours * 2) |
| 321 | + |
| 322 | + # Remove old topics with low mention count |
| 323 | + topics_to_remove = [ |
| 324 | + topic_name for topic_name, topic in self.topics.items() |
| 325 | + if topic.last_mentioned < cutoff_time and topic.mention_count < 2 |
| 326 | + ] |
| 327 | + |
| 328 | + for topic_name in topics_to_remove: |
| 329 | + del self.topics[topic_name] |
| 330 | + if topic_name in self.conversation_thread: |
| 331 | + self.conversation_thread.remove(topic_name) |
| 332 | + |
| 333 | + # Limit conversation thread length |
| 334 | + if len(self.conversation_thread) > 50: |
| 335 | + self.conversation_thread = self.conversation_thread[-30:] |
| 336 | + |
| 337 | + logger.info(f"Cleaned up {len(topics_to_remove)} old topics") |
| 338 | + |
| 339 | + def get_context_summary(self) -> dict[str, Any]: |
| 340 | + """Get a summary of the current context state.""" |
| 341 | + return { |
| 342 | + "total_documents": len(self.documents), |
| 343 | + "total_topics": len(self.topics), |
| 344 | + "current_focus": self.current_focus, |
| 345 | + "conversation_length": len(self.conversation_thread), |
| 346 | + "most_accessed_documents": [ |
| 347 | + { |
| 348 | + "document_id": doc.document_id, |
| 349 | + "title": doc.title, |
| 350 | + "access_count": doc.access_count |
| 351 | + } |
| 352 | + for doc in sorted(self.documents.values(), |
| 353 | + key=lambda x: x.access_count, reverse=True)[:5] |
| 354 | + ], |
| 355 | + "active_topics": [ |
| 356 | + { |
| 357 | + "name": topic.name, |
| 358 | + "mention_count": topic.mention_count, |
| 359 | + "document_count": len(topic.documents) |
| 360 | + } |
| 361 | + for topic in sorted(self.topics.values(), |
| 362 | + key=lambda x: x.mention_count, reverse=True)[:5] |
| 363 | + ] |
| 364 | + } |
0 commit comments