|
2 | 2 |
|
3 | 3 | import json |
4 | 4 | from collections.abc import AsyncGenerator |
5 | | -from typing import Any |
| 5 | +from typing import Any, cast |
6 | 6 | from unittest.mock import AsyncMock, MagicMock, Mock, patch |
7 | 7 | from urllib.parse import urlparse |
8 | 8 |
|
|
19 | 19 | import mcp.client.sse |
20 | 20 | from mcp import types |
21 | 21 | from mcp.client.session import ClientSession |
22 | | -from mcp.client.sse import _extract_session_id_from_endpoint, sse_client |
| 22 | +from mcp.client.sse import _extract_session_id_from_endpoint, _resolve_endpoint_url, sse_client |
23 | 23 | from mcp.server import Server, ServerRequestContext |
24 | 24 | from mcp.server.sse import SseServerTransport |
25 | 25 | from mcp.server.transport_security import TransportSecuritySettings |
26 | 26 | from mcp.shared._httpx_utils import McpHttpClientFactory |
27 | 27 | from mcp.shared.exceptions import MCPError |
| 28 | +from mcp.shared.message import SessionMessage |
28 | 29 | from mcp.types import ( |
29 | 30 | CallToolRequestParams, |
30 | 31 | CallToolResult, |
@@ -173,6 +174,100 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non |
173 | 174 | assert _extract_session_id_from_endpoint(endpoint_url) == expected |
174 | 175 |
|
175 | 176 |
|
| 177 | +@pytest.mark.parametrize( |
| 178 | + ("sse_url", "endpoint", "expected_url"), |
| 179 | + [ |
| 180 | + ( |
| 181 | + "https://example.com/gateway/deployment/v1/sse", |
| 182 | + "/v1/messages/?session_id=abc123", |
| 183 | + "https://example.com/gateway/deployment/v1/messages/?session_id=abc123", |
| 184 | + ), |
| 185 | + ( |
| 186 | + "https://example.com/gateway/deployment/v1/sse", |
| 187 | + "/gateway/deployment/v1/messages/?session_id=abc123", |
| 188 | + "https://example.com/gateway/deployment/v1/messages/?session_id=abc123", |
| 189 | + ), |
| 190 | + ( |
| 191 | + "https://example.com/gateway/deployment/sse", |
| 192 | + "/messages/?session_id=abc123", |
| 193 | + "https://example.com/messages/?session_id=abc123", |
| 194 | + ), |
| 195 | + ( |
| 196 | + "https://example.com/sse", |
| 197 | + "/messages/?session_id=abc123", |
| 198 | + "https://example.com/messages/?session_id=abc123", |
| 199 | + ), |
| 200 | + ( |
| 201 | + "https://example.com/gateway/sse", |
| 202 | + "/", |
| 203 | + "https://example.com/", |
| 204 | + ), |
| 205 | + ( |
| 206 | + "https://example.com/gateway/deployment/v1/sse", |
| 207 | + "messages/?session_id=abc123", |
| 208 | + "https://example.com/gateway/deployment/v1/messages/?session_id=abc123", |
| 209 | + ), |
| 210 | + ( |
| 211 | + "https://example.com/gateway/deployment/v1/sse", |
| 212 | + "https://example.com/messages/?session_id=abc123", |
| 213 | + "https://example.com/messages/?session_id=abc123", |
| 214 | + ), |
| 215 | + ], |
| 216 | +) |
| 217 | +def test_resolve_endpoint_url_preserves_gateway_path_prefix(sse_url: str, endpoint: str, expected_url: str) -> None: |
| 218 | + assert _resolve_endpoint_url(sse_url, endpoint) == expected_url |
| 219 | + |
| 220 | + |
| 221 | +@pytest.mark.anyio |
| 222 | +async def test_sse_client_posts_to_endpoint_with_gateway_path_prefix() -> None: |
| 223 | + """A gateway prefix on the public SSE URL is preserved for absolute-path endpoint events.""" |
| 224 | + posted = anyio.Event() |
| 225 | + posted_urls: list[str] = [] |
| 226 | + |
| 227 | + async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: |
| 228 | + yield ServerSentEvent(event="endpoint", data="/v1/messages/?session_id=abc123") |
| 229 | + await anyio.sleep_forever() |
| 230 | + |
| 231 | + mock_event_source = MagicMock() |
| 232 | + mock_event_source.aiter_sse.return_value = mock_aiter_sse() |
| 233 | + mock_event_source.response = MagicMock() |
| 234 | + mock_event_source.response.raise_for_status = MagicMock() |
| 235 | + |
| 236 | + mock_aconnect_sse = MagicMock() |
| 237 | + mock_aconnect_sse.__aenter__ = AsyncMock(return_value=mock_event_source) |
| 238 | + mock_aconnect_sse.__aexit__ = AsyncMock(return_value=None) |
| 239 | + |
| 240 | + async def mock_post(url: str, **kwargs: Any) -> MagicMock: |
| 241 | + posted_urls.append(url) |
| 242 | + posted.set() |
| 243 | + return MagicMock(status_code=200, raise_for_status=MagicMock()) |
| 244 | + |
| 245 | + mock_client = MagicMock() |
| 246 | + mock_client.__aenter__ = AsyncMock(return_value=mock_client) |
| 247 | + mock_client.__aexit__ = AsyncMock(return_value=None) |
| 248 | + mock_client.post = AsyncMock(side_effect=mock_post) |
| 249 | + |
| 250 | + def mock_factory( |
| 251 | + headers: dict[str, str] | None = None, |
| 252 | + timeout: httpx.Timeout | None = None, |
| 253 | + auth: httpx.Auth | None = None, |
| 254 | + ) -> httpx.AsyncClient: |
| 255 | + return cast(httpx.AsyncClient, mock_client) |
| 256 | + |
| 257 | + with patch("mcp.client.sse.aconnect_sse", return_value=mock_aconnect_sse): |
| 258 | + async with sse_client( |
| 259 | + "http://test/gateway/deployment/v1/sse", |
| 260 | + httpx_client_factory=mock_factory, |
| 261 | + ) as (_, write_stream): |
| 262 | + request = types.JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") |
| 263 | + await write_stream.send(SessionMessage(request)) |
| 264 | + |
| 265 | + with anyio.fail_after(5): |
| 266 | + await posted.wait() |
| 267 | + |
| 268 | + assert posted_urls == ["http://test/gateway/deployment/v1/messages/?session_id=abc123"] |
| 269 | + |
| 270 | + |
176 | 271 | @pytest.mark.anyio |
177 | 272 | async def test_sse_client_on_session_created_not_called_when_no_session_id(monkeypatch: pytest.MonkeyPatch) -> None: |
178 | 273 | """No session-created callback fires when the endpoint URL carries no session ID.""" |
|
0 commit comments