From 8b583208b0770f734dfd7d059659ef22c18f7466 Mon Sep 17 00:00:00 2001 From: jokemanfire Date: Fri, 20 Mar 2026 17:34:49 +0800 Subject: [PATCH] fix(http): add host check --- .../rmcp/src/transport/common/http_header.rs | 2 + .../transport/streamable_http_server/tower.rs | 111 +++++++++++- crates/rmcp/tests/test_custom_headers.rs | 158 ++++++++++++++++++ 3 files changed, 269 insertions(+), 2 deletions(-) diff --git a/crates/rmcp/src/transport/common/http_header.rs b/crates/rmcp/src/transport/common/http_header.rs index 196d96fff..505c39a30 100644 --- a/crates/rmcp/src/transport/common/http_header.rs +++ b/crates/rmcp/src/transport/common/http_header.rs @@ -7,6 +7,7 @@ pub const JSON_MIME_TYPE: &str = "application/json"; /// Reserved headers that must not be overridden by user-supplied custom headers. /// `MCP-Protocol-Version` is in this list but is allowed through because the worker /// injects it after initialization. +#[allow(dead_code)] pub(crate) const RESERVED_HEADERS: &[&str] = &[ "accept", HEADER_SESSION_ID, @@ -36,6 +37,7 @@ pub(crate) fn validate_custom_header(name: &http::HeaderName) -> Result<(), Stri /// Extracts the `scope=` parameter from a `WWW-Authenticate` header value. /// Handles both quoted (`scope="files:read files:write"`) and unquoted (`scope=read:data`) forms. +#[allow(dead_code)] pub(crate) fn extract_scope_from_header(header: &str) -> Option { let header_lowercase = header.to_ascii_lowercase(); let scope_key = "scope="; diff --git a/crates/rmcp/src/transport/streamable_http_server/tower.rs b/crates/rmcp/src/transport/streamable_http_server/tower.rs index 7f4d888c7..eec3f41b3 100644 --- a/crates/rmcp/src/transport/streamable_http_server/tower.rs +++ b/crates/rmcp/src/transport/streamable_http_server/tower.rs @@ -2,7 +2,7 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration}; use bytes::Bytes; use futures::{StreamExt, future::BoxFuture}; -use http::{Method, Request, Response, header::ALLOW}; +use http::{HeaderMap, Method, Request, Response, header::ALLOW}; use http_body::Body; use http_body_util::{BodyExt, Full, combinators::BoxBody}; use tokio_stream::wrappers::ReceiverStream; @@ -29,8 +29,8 @@ use crate::{ }, }; -#[derive(Debug, Clone)] #[non_exhaustive] +#[derive(Debug, Clone)] pub struct StreamableHttpServerConfig { /// The ping message duration for SSE connections. pub sse_keep_alive: Option, @@ -49,6 +49,16 @@ pub struct StreamableHttpServerConfig { /// When this token is cancelled, all active sessions are terminated and /// the server stops accepting new requests. pub cancellation_token: CancellationToken, + /// Allowed hostnames or `host:port` authorities for inbound `Host` validation. + /// + /// By default, Streamable HTTP servers only accept loopback hosts to + /// prevent DNS rebinding attacks against locally running servers. Public + /// deployments should override this list with their own hostnames. + /// examples: + /// allowed_hosts = ["localhost", "127.0.0.1", "0.0.0.0"] + /// or with ports: + /// allowed_hosts = ["example.com", "example.com:8080"] + pub allowed_hosts: Vec, } impl Default for StreamableHttpServerConfig { @@ -59,11 +69,24 @@ impl Default for StreamableHttpServerConfig { stateful_mode: true, json_response: false, cancellation_token: CancellationToken::new(), + allowed_hosts: vec!["localhost".into(), "127.0.0.1".into(), "::1".into()], } } } impl StreamableHttpServerConfig { + pub fn with_allowed_hosts( + mut self, + allowed_hosts: impl IntoIterator>, + ) -> Self { + self.allowed_hosts = allowed_hosts.into_iter().map(Into::into).collect(); + self + } + /// Disable allowed hosts. This will allow requests with any `Host` header, which is NOT recommended for public deployments. + pub fn disable_allowed_hosts(mut self) -> Self { + self.allowed_hosts.clear(); + self + } pub fn with_sse_keep_alive(mut self, duration: Option) -> Self { self.sse_keep_alive = duration; self @@ -130,6 +153,87 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box Ok(()) } +fn forbidden_response(message: impl Into) -> BoxResponse { + Response::builder() + .status(http::StatusCode::FORBIDDEN) + .body(Full::new(Bytes::from(message.into())).boxed()) + .expect("valid response") +} + +fn normalize_host(host: &str) -> String { + host.trim_matches('[') + .trim_matches(']') + .to_ascii_lowercase() +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct NormalizedAuthority { + host: String, + port: Option, +} + +fn normalize_authority(host: &str, port: Option) -> NormalizedAuthority { + NormalizedAuthority { + host: normalize_host(host), + port, + } +} + +fn parse_allowed_authority(allowed: &str) -> Option { + let allowed = allowed.trim(); + if allowed.is_empty() { + return None; + } + + if let Ok(authority) = http::uri::Authority::try_from(allowed) { + return Some(normalize_authority(authority.host(), authority.port_u16())); + } + + Some(normalize_authority(allowed, None)) +} + +fn host_is_allowed(host: &NormalizedAuthority, allowed_hosts: &[String]) -> bool { + if allowed_hosts.is_empty() { + // If the allowed hosts list is empty, allow all hosts (not recommended). + return true; + } + allowed_hosts + .iter() + .filter_map(|allowed| parse_allowed_authority(allowed)) + .any(|allowed| { + allowed.host == host.host + && match allowed.port { + Some(port) => host.port == Some(port), + None => true, + } + }) +} + +fn parse_host_header(headers: &HeaderMap) -> Result { + let Some(host) = headers.get(http::header::HOST) else { + return Err(forbidden_response("Forbidden:missing_host header")); + }; + + let host = host + .to_str() + .map_err(|_| forbidden_response("Forbidden: Invalid Host header encoding"))?; + let authority = http::uri::Authority::try_from(host) + .map_err(|_| forbidden_response("Forbidden: Invalid Host header"))?; + Ok(normalize_authority(authority.host(), authority.port_u16())) +} + +fn validate_dns_rebinding_headers( + headers: &HeaderMap, + config: &StreamableHttpServerConfig, +) -> Result<(), BoxResponse> { + let host = parse_host_header(headers)?; + if !host_is_allowed(&host, &config.allowed_hosts) { + return Err(forbidden_response("Forbidden: Host header is not allowed")); + } + + Ok(()) +} + /// # Streamable HTTP server /// /// An HTTP service that implements the @@ -279,6 +383,9 @@ where B: Body + Send + 'static, B::Error: Display, { + if let Err(response) = validate_dns_rebinding_headers(request.headers(), &self.config) { + return response; + } let method = request.method().clone(); let allowed_methods = match self.config.stateful_mode { true => "GET, POST, DELETE", diff --git a/crates/rmcp/tests/test_custom_headers.rs b/crates/rmcp/tests/test_custom_headers.rs index 7d4316d3e..558ff623d 100644 --- a/crates/rmcp/tests/test_custom_headers.rs +++ b/crates/rmcp/tests/test_custom_headers.rs @@ -761,6 +761,7 @@ async fn test_server_rejects_unsupported_protocol_version() { .method(Method::POST) .header("Accept", "application/json, text/event-stream") .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:8080") .body(Full::new(Bytes::from(init_body.to_string()))) .unwrap(); @@ -785,6 +786,7 @@ async fn test_server_rejects_unsupported_protocol_version() { .method(Method::POST) .header("Accept", "application/json, text/event-stream") .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:8080") .header("mcp-session-id", &session_id) .header("mcp-protocol-version", "2025-03-26") .body(Full::new(Bytes::from(initialized_body.to_string()))) @@ -802,6 +804,7 @@ async fn test_server_rejects_unsupported_protocol_version() { .method(Method::POST) .header("Accept", "application/json, text/event-stream") .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:8080") .header("mcp-session-id", &session_id) .header("mcp-protocol-version", "2025-03-26") .body(Full::new(Bytes::from(valid_body.to_string()))) @@ -823,6 +826,7 @@ async fn test_server_rejects_unsupported_protocol_version() { .method(Method::POST) .header("Accept", "application/json, text/event-stream") .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:8080") .header("mcp-session-id", &session_id) .header("mcp-protocol-version", "9999-01-01") .body(Full::new(Bytes::from(invalid_body.to_string()))) @@ -844,6 +848,7 @@ async fn test_server_rejects_unsupported_protocol_version() { .method(Method::POST) .header("Accept", "application/json, text/event-stream") .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:8080") .header("mcp-session-id", &session_id) .body(Full::new(Bytes::from(no_version_body.to_string()))) .unwrap(); @@ -870,3 +875,156 @@ fn test_protocol_version_utilities() { assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_03_26)); assert!(ProtocolVersion::KNOWN_VERSIONS.contains(&ProtocolVersion::V_2025_06_18)); } + +/// Integration test: Verify server validates only the Host header for DNS rebinding protection +#[tokio::test] +#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))] +async fn test_server_validates_host_header_for_dns_rebinding_protection() { + use std::sync::Arc; + + use bytes::Bytes; + use http::{Method, Request, header::CONTENT_TYPE}; + use http_body_util::Full; + use rmcp::{ + handler::server::ServerHandler, + model::{ServerCapabilities, ServerInfo}, + transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, + }, + }; + use serde_json::json; + + #[derive(Clone)] + struct TestHandler; + + impl ServerHandler for TestHandler { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().build()) + } + } + + let service = StreamableHttpService::new( + || Ok(TestHandler), + Arc::new(LocalSessionManager::default()), + StreamableHttpServerConfig::default(), + ); + + let init_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0.0" + } + } + }); + + let allowed_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:8080") + .header("Origin", "http://localhost:8080") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + + let response = service.handle(allowed_request).await; + assert_eq!(response.status(), http::StatusCode::OK); + + let bad_host_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("Host", "attacker.example") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + + let response = service.handle(bad_host_request).await; + assert_eq!(response.status(), http::StatusCode::FORBIDDEN); + + let ignored_origin_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:8080") + .header("Origin", "http://attacker.example") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + + let response = service.handle(ignored_origin_request).await; + assert_eq!(response.status(), http::StatusCode::OK); +} + +/// Integration test: Verify server can enforce an allowed Host port when configured +#[tokio::test] +#[cfg(all(feature = "transport-streamable-http-server", feature = "server",))] +async fn test_server_validates_host_header_port_for_dns_rebinding_protection() { + use std::sync::Arc; + + use bytes::Bytes; + use http::{Method, Request, header::CONTENT_TYPE}; + use http_body_util::Full; + use rmcp::{ + handler::server::ServerHandler, + model::{ServerCapabilities, ServerInfo}, + transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::local::LocalSessionManager, + }, + }; + use serde_json::json; + + #[derive(Clone)] + struct TestHandler; + + impl ServerHandler for TestHandler { + fn get_info(&self) -> ServerInfo { + ServerInfo::new(ServerCapabilities::builder().build()) + } + } + + let service = StreamableHttpService::new( + || Ok(TestHandler), + Arc::new(LocalSessionManager::default()), + StreamableHttpServerConfig::default().with_allowed_hosts(["localhost:8080"]), + ); + + let init_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "1.0.0" + } + } + }); + + let allowed_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:8080") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + + let response = service.handle(allowed_request).await; + assert_eq!(response.status(), http::StatusCode::OK); + + let wrong_port_request = Request::builder() + .method(Method::POST) + .header("Accept", "application/json, text/event-stream") + .header(CONTENT_TYPE, "application/json") + .header("Host", "localhost:3000") + .body(Full::new(Bytes::from(init_body.to_string()))) + .unwrap(); + + let response = service.handle(wrong_port_request).await; + assert_eq!(response.status(), http::StatusCode::FORBIDDEN); +}