Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 115 additions & 27 deletions everyrow-mcp/src/everyrow_mcp/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility functions for the everyrow MCP server."""

import asyncio
import ipaddress
import json
import logging
Expand Down Expand Up @@ -37,6 +38,24 @@
}
)

# Restrict outbound fetches to standard HTTP(S) ports.
_ALLOWED_PORTS: frozenset[int] = frozenset({80, 443, 8080, 8443})


def _validate_port(port: int | None) -> None:
"""Reject non-standard ports for outbound URL fetching.

Default ports (omitted from the URL) are always allowed.
Explicit ports must be in the ``_ALLOWED_PORTS`` allowlist.
"""
if port is None:
return # Default port for the scheme — always allowed
if port not in _ALLOWED_PORTS:
raise ValueError(
f"Port {port} is not permitted for URL fetching. "
f"Allowed: {sorted(_ALLOWED_PORTS)}"
)


def _is_blocked_ip(addr: str) -> bool:
"""Check if an IP address falls within a blocked private/internal network."""
Expand All @@ -50,54 +69,78 @@ def _is_blocked_ip(addr: str) -> bool:
return any(ip in net for net in _BLOCKED_NETWORKS)


def _validate_hostname(hostname: str) -> None:
"""Validate that a hostname doesn't resolve to blocked IPs or metadata services.
async def _resolve_and_validate(hostname: str) -> str:
"""Resolve a hostname, validate all IPs, and return the first safe IP.

Called both as a pre-flight check and at transport request time to close
the TOCTOU gap between DNS validation and HTTP connection.
For IP literals, validates directly and returns the canonical form.
For DNS names, resolves via ``getaddrinfo`` (offloaded to a thread pool
to avoid blocking the event loop) and checks every result.

The returned IP is used by ``_SSRFSafeTransport`` to **pin** the TCP
connection, eliminating the TOCTOU gap between DNS validation and the
actual ``connect()`` call.

Raises:
ValueError: If the hostname is blocked, resolves to a blocked IP, or cannot be resolved.
ValueError: If the hostname is blocked, resolves to a blocked IP,
or cannot be resolved.
"""
if hostname.lower() in _BLOCKED_HOSTNAMES:
raise ValueError(f"Hostname is not permitted: {hostname}")

# Direct IP literal — validate without DNS resolution
parsed_ip = None
try:
ip = ipaddress.ip_address(hostname)
parsed_ip = ipaddress.ip_address(hostname)
except ValueError:
pass # Not an IP literal — fall through to DNS

if parsed_ip is not None:
# Unwrap IPv4-mapped IPv6 (e.g. ::ffff:127.0.0.1 → 127.0.0.1)
if isinstance(ip, ipaddress.IPv6Address) and ip.ipv4_mapped:
ip = ip.ipv4_mapped
if any(ip in net for net in _BLOCKED_NETWORKS):
if isinstance(parsed_ip, ipaddress.IPv6Address) and parsed_ip.ipv4_mapped:
parsed_ip = parsed_ip.ipv4_mapped
if any(parsed_ip in net for net in _BLOCKED_NETWORKS):
raise ValueError(f"Connection to blocked IP: {hostname}")
return
except ValueError:
pass # Not an IP literal, resolve via DNS
return str(parsed_ip)

# DNS name — resolve in a thread pool to avoid blocking the event loop
loop = asyncio.get_running_loop()
try:
addrinfos = socket.getaddrinfo(
hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
addrinfos = await loop.run_in_executor(
None,
socket.getaddrinfo,
hostname,
None,
socket.AF_UNSPEC,
socket.SOCK_STREAM,
)
except socket.gaierror:
raise ValueError(f"Could not resolve hostname: {hostname}")

if not addrinfos:
raise ValueError(f"Could not resolve hostname: {hostname}")

for _, _, _, _, sockaddr in addrinfos:
if _is_blocked_ip(sockaddr[0]):
logger.warning("SSRF blocked: %s resolved to %s", hostname, sockaddr[0])
raise ValueError(f"URL target is not permitted: {hostname}")

# All addresses safe — return the first resolved IP for connection pinning
return addrinfos[0][4][0]

def _validate_url_target(url: str) -> None:
"""Resolve a URL's hostname and reject if any resolved IP is internal.

async def _validate_url_target(url: str) -> None:
"""Resolve a URL's hostname and reject if any resolved IP is internal or port is blocked.

Raises:
ValueError: If the hostname resolves to a blocked network or cannot be resolved.
ValueError: If the hostname resolves to a blocked network, port is not
in the allowlist, or hostname cannot be resolved.
"""
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
raise ValueError(f"URL has no hostname: {url}")
_validate_hostname(hostname)
_validate_port(parsed.port)
await _resolve_and_validate(hostname)


def is_url(value: str) -> bool:
Expand Down Expand Up @@ -155,7 +198,7 @@ async def _check_redirect(response: httpx.Response) -> None:
location = response.headers.get("location", "")
if location:
try:
_validate_url_target(location)
await _validate_url_target(location)
except ValueError:
# TooManyRedirects aborts the redirect chain — httpx
# has no "redirect rejected" error type.
Expand All @@ -166,20 +209,65 @@ async def _check_redirect(response: httpx.Response) -> None:


class _SSRFSafeTransport(httpx.AsyncBaseTransport):
"""Transport that re-validates hostnames at request time.
"""Transport that resolves DNS, validates IPs, and pins connections to safe IPs.

Eliminates the TOCTOU gap between DNS validation and TCP connection
by:

Narrows the TOCTOU window between DNS validation and connection to
near-zero by re-checking every hostname immediately before the inner
transport opens a TCP connection.
1. Resolving the hostname ourselves via ``getaddrinfo``
2. Validating every resolved IP against the blocklist
3. Rewriting the request URL to connect directly to the validated IP
4. Preserving the original hostname in the ``Host`` header and TLS SNI
extension so the remote server sees the correct virtual host

Also enforces the port allowlist at transport time as a
second check complementing the pre-flight validation.
"""

def __init__(self) -> None:
self._transport = httpx.AsyncHTTPTransport(retries=0)

async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
if request.url.host:
_validate_hostname(request.url.host)
return await self._transport.handle_async_request(request)
hostname = request.url.host
if not hostname:
return await self._transport.handle_async_request(request)

# Validate port (defence-in-depth — also checked pre-flight)
_validate_port(request.url.port)

# Resolve DNS and validate — returns the first safe IP
resolved_ip = await _resolve_and_validate(hostname)

# Pin the URL to the validated IP so the inner transport connects
# directly without a second (unvalidated) DNS lookup.
pinned_url = request.url.copy_with(host=resolved_ip)

# Preserve the original hostname in the Host header.
# IPv6 addresses must be wrapped in brackets per RFC 7230 §5.4.
host_header = f"[{hostname}]" if ":" in hostname else hostname
if request.url.port and request.url.port not in (80, 443):
host_header = f"{host_header}:{request.url.port}"
headers = [
(name, value)
for name, value in request.headers.items()
if name.lower() != "host"
]
headers.insert(0, ("host", host_header))

# Preserve the original hostname for TLS SNI so the server
# presents the right certificate.
extensions = dict(request.extensions)
if request.url.scheme == "https":
extensions["sni_hostname"] = hostname.encode("idna")

pinned_request = httpx.Request(
method=request.method,
url=pinned_url,
headers=headers,
stream=request.stream,
extensions=extensions,
)
return await self._transport.handle_async_request(pinned_request)

async def aclose(self) -> None:
await self._transport.aclose()
Expand All @@ -196,7 +284,7 @@ async def fetch_csv_from_url(url: str) -> pd.DataFrame:
httpx.HTTPStatusError: On non-2xx responses.
"""
url = _normalise_google_sheets_url(url)
_validate_url_target(url)
await _validate_url_target(url)

async with httpx.AsyncClient(
transport=_SSRFSafeTransport(),
Expand Down
Loading