diff --git a/everyrow-mcp/src/everyrow_mcp/utils.py b/everyrow-mcp/src/everyrow_mcp/utils.py index d0a5573e..019aefec 100644 --- a/everyrow-mcp/src/everyrow_mcp/utils.py +++ b/everyrow-mcp/src/everyrow_mcp/utils.py @@ -1,5 +1,6 @@ """Utility functions for the everyrow MCP server.""" +import asyncio import ipaddress import json import logging @@ -37,6 +38,24 @@ } ) +# Restrict outbound fetches to standard HTTP(S) ports. +_ALLOWED_PORTS: frozenset[int] = frozenset({80, 443, 8080, 8443}) + + +def _validate_port(port: int | None) -> None: + """Reject non-standard ports for outbound URL fetching. + + Default ports (omitted from the URL) are always allowed. + Explicit ports must be in the ``_ALLOWED_PORTS`` allowlist. + """ + if port is None: + return # Default port for the scheme — always allowed + if port not in _ALLOWED_PORTS: + raise ValueError( + f"Port {port} is not permitted for URL fetching. " + f"Allowed: {sorted(_ALLOWED_PORTS)}" + ) + def _is_blocked_ip(addr: str) -> bool: """Check if an IP address falls within a blocked private/internal network.""" @@ -50,54 +69,78 @@ def _is_blocked_ip(addr: str) -> bool: return any(ip in net for net in _BLOCKED_NETWORKS) -def _validate_hostname(hostname: str) -> None: - """Validate that a hostname doesn't resolve to blocked IPs or metadata services. +async def _resolve_and_validate(hostname: str) -> str: + """Resolve a hostname, validate all IPs, and return the first safe IP. - Called both as a pre-flight check and at transport request time to close - the TOCTOU gap between DNS validation and HTTP connection. + For IP literals, validates directly and returns the canonical form. + For DNS names, resolves via ``getaddrinfo`` (offloaded to a thread pool + to avoid blocking the event loop) and checks every result. + + The returned IP is used by ``_SSRFSafeTransport`` to **pin** the TCP + connection, eliminating the TOCTOU gap between DNS validation and the + actual ``connect()`` call. Raises: - ValueError: If the hostname is blocked, resolves to a blocked IP, or cannot be resolved. + ValueError: If the hostname is blocked, resolves to a blocked IP, + or cannot be resolved. """ if hostname.lower() in _BLOCKED_HOSTNAMES: raise ValueError(f"Hostname is not permitted: {hostname}") # Direct IP literal — validate without DNS resolution + parsed_ip = None try: - ip = ipaddress.ip_address(hostname) + parsed_ip = ipaddress.ip_address(hostname) + except ValueError: + pass # Not an IP literal — fall through to DNS + + if parsed_ip is not None: # Unwrap IPv4-mapped IPv6 (e.g. ::ffff:127.0.0.1 → 127.0.0.1) - if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped: - ip = ip.ipv4_mapped - if any(ip in net for net in _BLOCKED_NETWORKS): + if isinstance(parsed_ip, ipaddress.IPv6Address) and parsed_ip.ipv4_mapped: + parsed_ip = parsed_ip.ipv4_mapped + if any(parsed_ip in net for net in _BLOCKED_NETWORKS): raise ValueError(f"Connection to blocked IP: {hostname}") - return - except ValueError: - pass # Not an IP literal, resolve via DNS + return str(parsed_ip) + # DNS name — resolve in a thread pool to avoid blocking the event loop + loop = asyncio.get_running_loop() try: - addrinfos = socket.getaddrinfo( - hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM + addrinfos = await loop.run_in_executor( + None, + socket.getaddrinfo, + hostname, + None, + socket.AF_UNSPEC, + socket.SOCK_STREAM, ) except socket.gaierror: raise ValueError(f"Could not resolve hostname: {hostname}") + if not addrinfos: + raise ValueError(f"Could not resolve hostname: {hostname}") + for _, _, _, _, sockaddr in addrinfos: if _is_blocked_ip(sockaddr[0]): logger.warning("SSRF blocked: %s resolved to %s", hostname, sockaddr[0]) raise ValueError(f"URL target is not permitted: {hostname}") + # All addresses safe — return the first resolved IP for connection pinning + return addrinfos[0][4][0] -def _validate_url_target(url: str) -> None: - """Resolve a URL's hostname and reject if any resolved IP is internal. + +async def _validate_url_target(url: str) -> None: + """Resolve a URL's hostname and reject if any resolved IP is internal or port is blocked. Raises: - ValueError: If the hostname resolves to a blocked network or cannot be resolved. + ValueError: If the hostname resolves to a blocked network, port is not + in the allowlist, or hostname cannot be resolved. """ parsed = urlparse(url) hostname = parsed.hostname if not hostname: raise ValueError(f"URL has no hostname: {url}") - _validate_hostname(hostname) + _validate_port(parsed.port) + await _resolve_and_validate(hostname) def is_url(value: str) -> bool: @@ -155,7 +198,7 @@ async def _check_redirect(response: httpx.Response) -> None: location = response.headers.get("location", "") if location: try: - _validate_url_target(location) + await _validate_url_target(location) except ValueError: # TooManyRedirects aborts the redirect chain — httpx # has no "redirect rejected" error type. @@ -166,20 +209,65 @@ async def _check_redirect(response: httpx.Response) -> None: class _SSRFSafeTransport(httpx.AsyncBaseTransport): - """Transport that re-validates hostnames at request time. + """Transport that resolves DNS, validates IPs, and pins connections to safe IPs. + + Eliminates the TOCTOU gap between DNS validation and TCP connection + by: - Narrows the TOCTOU window between DNS validation and connection to - near-zero by re-checking every hostname immediately before the inner - transport opens a TCP connection. + 1. Resolving the hostname ourselves via ``getaddrinfo`` + 2. Validating every resolved IP against the blocklist + 3. Rewriting the request URL to connect directly to the validated IP + 4. Preserving the original hostname in the ``Host`` header and TLS SNI + extension so the remote server sees the correct virtual host + + Also enforces the port allowlist at transport time as a + second check complementing the pre-flight validation. """ def __init__(self) -> None: self._transport = httpx.AsyncHTTPTransport(retries=0) async def handle_async_request(self, request: httpx.Request) -> httpx.Response: - if request.url.host: - _validate_hostname(request.url.host) - return await self._transport.handle_async_request(request) + hostname = request.url.host + if not hostname: + return await self._transport.handle_async_request(request) + + # Validate port (defence-in-depth — also checked pre-flight) + _validate_port(request.url.port) + + # Resolve DNS and validate — returns the first safe IP + resolved_ip = await _resolve_and_validate(hostname) + + # Pin the URL to the validated IP so the inner transport connects + # directly without a second (unvalidated) DNS lookup. + pinned_url = request.url.copy_with(host=resolved_ip) + + # Preserve the original hostname in the Host header. + # IPv6 addresses must be wrapped in brackets per RFC 7230 §5.4. + host_header = f"[{hostname}]" if ":" in hostname else hostname + if request.url.port and request.url.port not in (80, 443): + host_header = f"{host_header}:{request.url.port}" + headers = [ + (name, value) + for name, value in request.headers.items() + if name.lower() != "host" + ] + headers.insert(0, ("host", host_header)) + + # Preserve the original hostname for TLS SNI so the server + # presents the right certificate. + extensions = dict(request.extensions) + if request.url.scheme == "https": + extensions["sni_hostname"] = hostname.encode("idna") + + pinned_request = httpx.Request( + method=request.method, + url=pinned_url, + headers=headers, + stream=request.stream, + extensions=extensions, + ) + return await self._transport.handle_async_request(pinned_request) async def aclose(self) -> None: await self._transport.aclose() @@ -196,7 +284,7 @@ async def fetch_csv_from_url(url: str) -> pd.DataFrame: httpx.HTTPStatusError: On non-2xx responses. """ url = _normalise_google_sheets_url(url) - _validate_url_target(url) + await _validate_url_target(url) async with httpx.AsyncClient( transport=_SSRFSafeTransport(), diff --git a/everyrow-mcp/tests/test_utils.py b/everyrow-mcp/tests/test_utils.py index 47cc3e21..67237a4e 100644 --- a/everyrow-mcp/tests/test_utils.py +++ b/everyrow-mcp/tests/test_utils.py @@ -2,14 +2,19 @@ import socket from pathlib import Path -from unittest.mock import patch +from unittest.mock import AsyncMock, patch +import httpx import pandas as pd import pytest from everyrow_mcp.utils import ( + _ALLOWED_PORTS, _is_blocked_ip, _normalise_google_sheets_url, + _resolve_and_validate, + _SSRFSafeTransport, + _validate_port, _validate_url_target, is_url, resolve_output_path, @@ -254,42 +259,244 @@ def test_allows_public_ip(self): def test_allows_public_ip_2(self): assert _is_blocked_ip("93.184.216.34") is False - def test_validate_url_target_blocks_localhost(self): + @pytest.mark.asyncio + async def test_validate_url_target_blocks_localhost(self): with patch( "everyrow_mcp.utils.socket.getaddrinfo", return_value=_mock_resolve("localhost", "127.0.0.1"), ): with pytest.raises(ValueError, match="not permitted"): - _validate_url_target("http://localhost/secret") + await _validate_url_target("http://localhost/secret") - def test_validate_url_target_blocks_10_x(self): + @pytest.mark.asyncio + async def test_validate_url_target_blocks_10_x(self): with patch( "everyrow_mcp.utils.socket.getaddrinfo", return_value=_mock_resolve("internal.corp", "10.0.0.5"), ): with pytest.raises(ValueError, match="not permitted"): - _validate_url_target("http://internal.corp/data") + await _validate_url_target("http://internal.corp/data") - def test_validate_url_target_blocks_metadata_endpoint(self): + @pytest.mark.asyncio + async def test_validate_url_target_blocks_metadata_endpoint(self): with patch( "everyrow_mcp.utils.socket.getaddrinfo", return_value=_mock_resolve("metadata", "169.254.169.254"), ): with pytest.raises(ValueError, match="not permitted"): - _validate_url_target("http://metadata/latest/api-token") + await _validate_url_target("http://metadata/latest/api-token") - def test_validate_url_target_allows_public(self): + @pytest.mark.asyncio + async def test_validate_url_target_allows_public(self): with patch( "everyrow_mcp.utils.socket.getaddrinfo", return_value=_mock_resolve("example.com", "93.184.216.34"), ): # Should not raise - _validate_url_target("https://example.com/data.csv") + await _validate_url_target("https://example.com/data.csv") - def test_validate_url_target_blocks_unresolvable(self): + @pytest.mark.asyncio + async def test_validate_url_target_blocks_unresolvable(self): with patch( "everyrow_mcp.utils.socket.getaddrinfo", side_effect=socket.gaierror("Name resolution failed"), ): with pytest.raises(ValueError, match="Could not resolve"): - _validate_url_target("http://nonexistent.invalid/data") + await _validate_url_target("http://nonexistent.invalid/data") + + +# ── Port restriction tests ──────────────────────────────────── + + +class TestPortRestriction: + """Tests for port allowlist in URL validation.""" + + def test_default_port_allowed(self): + """None (default port) is always allowed.""" + _validate_port(None) + + def test_standard_ports_allowed(self): + for port in sorted(_ALLOWED_PORTS): + _validate_port(port) + + def test_redis_port_blocked(self): + with pytest.raises(ValueError, match="not permitted"): + _validate_port(6379) + + def test_postgres_port_blocked(self): + with pytest.raises(ValueError, match="not permitted"): + _validate_port(5432) + + def test_smtp_port_blocked(self): + with pytest.raises(ValueError, match="not permitted"): + _validate_port(25) + + def test_arbitrary_high_port_blocked(self): + with pytest.raises(ValueError, match="not permitted"): + _validate_port(9090) + + @pytest.mark.asyncio + async def test_validate_url_target_blocks_redis_port(self): + with patch( + "everyrow_mcp.utils.socket.getaddrinfo", + return_value=_mock_resolve("example.com", "93.184.216.34"), + ): + with pytest.raises(ValueError, match="not permitted"): + await _validate_url_target("http://example.com:6379/data") + + @pytest.mark.asyncio + async def test_validate_url_target_allows_port_443(self): + with patch( + "everyrow_mcp.utils.socket.getaddrinfo", + return_value=_mock_resolve("example.com", "93.184.216.34"), + ): + await _validate_url_target("https://example.com:443/data.csv") + + @pytest.mark.asyncio + async def test_validate_url_target_allows_port_8080(self): + with patch( + "everyrow_mcp.utils.socket.getaddrinfo", + return_value=_mock_resolve("example.com", "93.184.216.34"), + ): + await _validate_url_target("http://example.com:8080/data.csv") + + +# ── DNS-pinning tests ──────────────────────────────────────── + + +class TestResolveAndValidate: + """Tests for _resolve_and_validate (IP pinning).""" + + @pytest.mark.asyncio + async def test_returns_ip_for_public_hostname(self): + with patch( + "everyrow_mcp.utils.socket.getaddrinfo", + return_value=_mock_resolve("example.com", "93.184.216.34"), + ): + ip = await _resolve_and_validate("example.com") + assert ip == "93.184.216.34" + + @pytest.mark.asyncio + async def test_returns_ip_literal_directly(self): + ip = await _resolve_and_validate("8.8.8.8") + assert ip == "8.8.8.8" + + @pytest.mark.asyncio + async def test_blocks_private_ip_literal(self): + with pytest.raises(ValueError, match="blocked IP"): + await _resolve_and_validate("127.0.0.1") + + @pytest.mark.asyncio + async def test_blocks_metadata_hostname(self): + with pytest.raises(ValueError, match="not permitted"): + await _resolve_and_validate("metadata.google.internal") + + @pytest.mark.asyncio + async def test_blocks_hostname_resolving_to_private(self): + with patch( + "everyrow_mcp.utils.socket.getaddrinfo", + return_value=_mock_resolve("evil.com", "10.0.0.1"), + ): + with pytest.raises(ValueError, match="not permitted"): + await _resolve_and_validate("evil.com") + + @pytest.mark.asyncio + async def test_blocks_unresolvable(self): + with patch( + "everyrow_mcp.utils.socket.getaddrinfo", + side_effect=socket.gaierror("Name resolution failed"), + ): + with pytest.raises(ValueError, match="Could not resolve"): + await _resolve_and_validate("nonexistent.invalid") + + @pytest.mark.asyncio + async def test_unwraps_ipv4_mapped_ipv6(self): + with pytest.raises(ValueError, match="blocked IP"): + await _resolve_and_validate("::ffff:127.0.0.1") + + @pytest.mark.asyncio + async def test_allows_public_ipv6_literal(self): + ip = await _resolve_and_validate("2001:db8::1") + assert ip == "2001:db8::1" + + @pytest.mark.asyncio + async def test_blocks_ipv6_ula(self): + with pytest.raises(ValueError, match="blocked IP"): + await _resolve_and_validate("fd12:3456:789a::1") + + @pytest.mark.asyncio + async def test_blocks_ipv6_link_local(self): + with pytest.raises(ValueError, match="blocked IP"): + await _resolve_and_validate("fe80::1") + + +# ── IPv6 Host header tests ─────────────────────────────────── + + +class TestSSRFSafeTransportIPv6Host: + """Tests for IPv6 Host header bracket wrapping in _SSRFSafeTransport.""" + + @pytest.mark.asyncio + async def test_ipv6_host_header_no_port(self): + """IPv6 hostname gets brackets in Host header even without explicit port.""" + transport = _SSRFSafeTransport() + request = httpx.Request("GET", "http://[2001:db8::1]/data") + + with patch.object( + transport._transport, + "handle_async_request", + return_value=httpx.Response(200), + ) as mock: + with patch( + "everyrow_mcp.utils._resolve_and_validate", + new_callable=AsyncMock, + return_value="2001:db8::1", + ): + await transport.handle_async_request(request) + + pinned = mock.call_args[0][0] + assert pinned.headers["host"] == "[2001:db8::1]" + + @pytest.mark.asyncio + async def test_ipv6_host_header_with_non_standard_port(self): + """IPv6 hostname + non-standard port gets [addr]:port format.""" + transport = _SSRFSafeTransport() + request = httpx.Request("GET", "http://[2001:db8::1]:8080/data") + + with patch.object( + transport._transport, + "handle_async_request", + return_value=httpx.Response(200), + ) as mock: + with patch( + "everyrow_mcp.utils._resolve_and_validate", + new_callable=AsyncMock, + return_value="2001:db8::1", + ): + with patch("everyrow_mcp.utils._validate_port"): + await transport.handle_async_request(request) + + pinned = mock.call_args[0][0] + assert pinned.headers["host"] == "[2001:db8::1]:8080" + + @pytest.mark.asyncio + async def test_ipv4_host_header_no_brackets(self): + """IPv4 hostname does NOT get brackets.""" + transport = _SSRFSafeTransport() + request = httpx.Request("GET", "http://example.com:8080/data") + + with patch.object( + transport._transport, + "handle_async_request", + return_value=httpx.Response(200), + ) as mock: + with patch( + "everyrow_mcp.utils._resolve_and_validate", + new_callable=AsyncMock, + return_value="93.184.216.34", + ): + with patch("everyrow_mcp.utils._validate_port"): + await transport.handle_async_request(request) + + pinned = mock.call_args[0][0] + assert pinned.headers["host"] == "example.com:8080"