|
| 1 | +""" |
| 2 | +Monkey-patch for Google ADK's MCPSessionManager to fix session invalidation on server restart. |
| 3 | +
|
| 4 | +This module patches both the MCPSessionManager.create_session method and the retry_on_errors |
| 5 | +decorator to properly handle the case where an MCP server restarts and loses session state. |
| 6 | +
|
| 7 | +Root Cause: |
| 8 | +----------- |
| 9 | +When an MCP server restarts: |
| 10 | +1. The server loses all session state |
| 11 | +2. Client-side session streams remain open (not disconnected) |
| 12 | +3. Cached session appears valid because _is_session_disconnected() only checks local streams |
| 13 | +4. Server returns 404 for requests with old session IDs |
| 14 | +5. Tool calls time out waiting for responses |
| 15 | +6. On retry, the same bad cached session is reused |
| 16 | +
|
| 17 | +The Fix: |
| 18 | +-------- |
| 19 | +We patch the retry_on_errors decorator to: |
| 20 | +1. Detect when an error occurs during MCP operations |
| 21 | +2. Force-close the streams of the cached session |
| 22 | +3. This makes _is_session_disconnected() return True |
| 23 | +4. On retry, create_session() sees the session is disconnected and creates a fresh one |
| 24 | +
|
| 25 | +This is a temporary workaround until the fix is merged upstream in Google ADK. |
| 26 | +
|
| 27 | +Issue: https://github.com/agentic-layer/sdk-python/issues/XXX |
| 28 | +""" |
| 29 | + |
| 30 | +import functools |
| 31 | +import logging |
| 32 | +from typing import Any, Callable |
| 33 | + |
| 34 | +from google.adk.tools.mcp_tool import mcp_toolset |
| 35 | +from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager |
| 36 | + |
| 37 | +logger = logging.getLogger(__name__) |
| 38 | + |
| 39 | +# Store the original decorator |
| 40 | +_original_retry_on_errors = None |
| 41 | + |
| 42 | + |
| 43 | +def _patched_retry_on_errors(func: Callable[..., Any]) -> Callable[..., Any]: |
| 44 | + """Patched version of retry_on_errors that invalidates sessions on error. |
| 45 | +
|
| 46 | + This wraps the original decorator and adds logic to close cached session streams |
| 47 | + when an error occurs, ensuring the session is marked as disconnected for retry. |
| 48 | + """ |
| 49 | + # First, apply the original decorator if it exists |
| 50 | + if _original_retry_on_errors: |
| 51 | + func = _original_retry_on_errors(func) |
| 52 | + |
| 53 | + @functools.wraps(func) |
| 54 | + async def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: |
| 55 | + try: |
| 56 | + return await func(self, *args, **kwargs) |
| 57 | + except Exception as e: |
| 58 | + # When an error occurs, try to invalidate any cached MCP sessions |
| 59 | + # by closing their streams, so retry gets a fresh session |
| 60 | + if hasattr(self, "_mcp_session_manager"): |
| 61 | + session_manager: MCPSessionManager = self._mcp_session_manager |
| 62 | + logger.info( |
| 63 | + f"[PATCH] Error in MCP operation ({func.__name__}), invalidating cached sessions: {type(e).__name__}" |
| 64 | + ) |
| 65 | + |
| 66 | + # Access the session cache and close all sessions' streams |
| 67 | + if hasattr(session_manager, "_sessions"): |
| 68 | + try: |
| 69 | + # Use the lock to safely access sessions |
| 70 | + num_sessions = len(session_manager._sessions) # type: ignore |
| 71 | + logger.debug(f"[PATCH] Found {num_sessions} cached sessions to invalidate") |
| 72 | + |
| 73 | + # We can't use the lock here because we're already in an async context |
| 74 | + # and the lock might be held. Instead, just try to close streams. |
| 75 | + for session_key, (session, _, _) in list(session_manager._sessions.items()): # type: ignore |
| 76 | + try: |
| 77 | + logger.debug(f"[PATCH] Invalidating session: {session_key}") |
| 78 | + |
| 79 | + # Force-close the read stream |
| 80 | + if hasattr(session, "_read_stream"): |
| 81 | + stream = session._read_stream |
| 82 | + logger.debug( |
| 83 | + f"[PATCH] Read stream type: {type(stream).__name__}, has aclose: {hasattr(stream, 'aclose')}" |
| 84 | + ) |
| 85 | + if hasattr(stream, "aclose"): |
| 86 | + await stream.aclose() |
| 87 | + logger.debug("[PATCH] Closed read stream via aclose()") |
| 88 | + elif hasattr(stream, "close"): |
| 89 | + stream.close() |
| 90 | + logger.debug("[PATCH] Closed read stream via close()") |
| 91 | + else: |
| 92 | + logger.debug("[PATCH] Session has no _read_stream") |
| 93 | + |
| 94 | + # Force-close the write stream |
| 95 | + if hasattr(session, "_write_stream"): |
| 96 | + stream = session._write_stream |
| 97 | + logger.debug( |
| 98 | + f"[PATCH] Write stream type: {type(stream).__name__}, has aclose: {hasattr(stream, 'aclose')}" |
| 99 | + ) |
| 100 | + if hasattr(stream, "aclose"): |
| 101 | + await stream._write_stream.aclose() |
| 102 | + logger.debug("[PATCH] Closed write stream via aclose()") |
| 103 | + elif hasattr(stream, "close"): |
| 104 | + stream.close() |
| 105 | + logger.debug("[PATCH] Closed write stream via close()") |
| 106 | + else: |
| 107 | + logger.debug("[PATCH] Session has no _write_stream") |
| 108 | + |
| 109 | + logger.info(f"[PATCH] Successfully invalidated session {session_key}") |
| 110 | + except Exception as close_err: |
| 111 | + logger.warning(f"[PATCH] Could not close streams for {session_key}: {close_err}") |
| 112 | + except Exception as invalidate_err: |
| 113 | + logger.error(f"[PATCH] Error invalidating sessions: {invalidate_err}", exc_info=True) |
| 114 | + else: |
| 115 | + logger.debug("[PATCH] Session manager has no _sessions attribute") |
| 116 | + else: |
| 117 | + logger.debug(f"[PATCH] Object {type(self).__name__} has no _mcp_session_manager attribute") |
| 118 | + |
| 119 | + # Re-raise the exception so the original decorator can handle retry |
| 120 | + raise |
| 121 | + |
| 122 | + return wrapper |
| 123 | + |
| 124 | + |
| 125 | +def apply_mcp_session_patch() -> None: |
| 126 | + """Apply the monkey-patch to the retry_on_errors decorator. |
| 127 | +
|
| 128 | + This should be called once during application initialization before |
| 129 | + any MCP tools are created. |
| 130 | + """ |
| 131 | + global _original_retry_on_errors |
| 132 | + |
| 133 | + if _original_retry_on_errors is None: |
| 134 | + logger.info("Applying MCP session manager patch for server restart handling") |
| 135 | + |
| 136 | + # Store the original decorator |
| 137 | + from google.adk.tools.mcp_tool import mcp_session_manager |
| 138 | + from google.adk.tools.mcp_tool.mcp_toolset import McpToolset |
| 139 | + |
| 140 | + _original_retry_on_errors = mcp_session_manager.retry_on_errors |
| 141 | + |
| 142 | + # Replace the decorator in the module |
| 143 | + mcp_session_manager.retry_on_errors = _patched_retry_on_errors |
| 144 | + |
| 145 | + # Re-decorate the methods in McpToolset that use @retry_on_errors |
| 146 | + # Find all methods that were decorated and re-decorate them |
| 147 | + for attr_name in dir(McpToolset): |
| 148 | + if not attr_name.startswith("_"): |
| 149 | + attr = getattr(McpToolset, attr_name) |
| 150 | + if callable(attr) and hasattr(attr, "__wrapped__"): |
| 151 | + # This is likely a decorated method |
| 152 | + # Re-decorate it with our patched decorator |
| 153 | + original_func = attr.__wrapped__ |
| 154 | + setattr(McpToolset, attr_name, _patched_retry_on_errors(original_func)) |
| 155 | + logger.debug(f"Re-decorated McpToolset.{attr_name}") |
| 156 | + |
| 157 | + logger.info("MCP session manager patch applied successfully") |
| 158 | + else: |
| 159 | + logger.warning("MCP session manager patch already applied") |
| 160 | + |
| 161 | + |
| 162 | +def remove_mcp_session_patch() -> None: |
| 163 | + """Remove the monkey-patch and restore original behavior. |
| 164 | +
|
| 165 | + This is primarily for testing purposes. |
| 166 | + """ |
| 167 | + global _original_retry_on_errors |
| 168 | + |
| 169 | + if _original_retry_on_errors is not None: |
| 170 | + logger.info("Removing MCP session manager patch") |
| 171 | + |
| 172 | + from google.adk.tools.mcp_tool import mcp_session_manager |
| 173 | + |
| 174 | + mcp_session_manager.retry_on_errors = _original_retry_on_errors |
| 175 | + |
| 176 | + if hasattr(mcp_toolset, "retry_on_errors"): |
| 177 | + mcp_toolset.retry_on_errors = _original_retry_on_errors |
| 178 | + |
| 179 | + _original_retry_on_errors = None |
| 180 | + logger.info("MCP session manager patch removed") |
0 commit comments