diff --git a/crates/rmcp/src/transport/auth.rs b/crates/rmcp/src/transport/auth.rs index cd4176f6..db1f3536 100644 --- a/crates/rmcp/src/transport/auth.rs +++ b/crates/rmcp/src/transport/auth.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, future::Future, + net::{IpAddr, Ipv4Addr, Ipv6Addr}, pin::Pin, sync::Arc, time::{Duration, SystemTime, UNIX_EPOCH}, @@ -16,7 +17,7 @@ use oauth2::{ }; use reqwest::{ Client as ReqwestClient, IntoUrl, StatusCode, Url, - header::{AUTHORIZATION, CONTENT_TYPE, WWW_AUTHENTICATE}, + header::{AUTHORIZATION, CONTENT_TYPE, LOCATION, WWW_AUTHENTICATE}, }; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -28,6 +29,12 @@ use crate::transport::common::http_header::HEADER_MCP_PROTOCOL_VERSION; const DEFAULT_HTTP_TIMEOUT: Duration = Duration::from_secs(30); const MAX_OAUTH_HTTP_RESPONSE_BODY_BYTES: usize = 1024 * 1024; +const MAX_OAUTH_DISCOVERY_REDIRECTS: usize = 10; +const CLOUD_METADATA_HOSTS: &[&str] = &[ + "metadata", + "metadata.google.internal", + "metadata.azure.internal", +]; /// Redirect handling requested for an outbound OAuth HTTP operation. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] @@ -821,6 +828,100 @@ fn is_https_url(value: &str) -> bool { } impl AuthorizationManager { + fn is_http_url(url: &Url) -> bool { + matches!(url.scheme(), "http" | "https") && url.host_str().is_some() + } + + fn is_same_origin(base: &Url, candidate: &Url) -> bool { + base.scheme() == candidate.scheme() + && base + .host_str() + .zip(candidate.host_str()) + .is_some_and(|(base, candidate)| base.eq_ignore_ascii_case(candidate)) + && base.port_or_known_default() == candidate.port_or_known_default() + } + + fn is_same_origin_resource_metadata_url(base_url: &Url, candidate: &Url) -> bool { + Self::is_http_url(candidate) && Self::is_same_origin(base_url, candidate) + } + + fn is_disallowed_metadata_ipv4(addr: Ipv4Addr) -> bool { + let octets = addr.octets(); + addr.is_private() + || addr.is_loopback() + || addr.is_link_local() + || addr.is_broadcast() + || addr.is_unspecified() + || addr.is_multicast() + || octets[0] == 0 + || (octets[0] == 100 && (64..=127).contains(&octets[1])) + || (octets[0] == 198 && matches!(octets[1], 18 | 19)) + } + + fn is_disallowed_metadata_ipv6(addr: Ipv6Addr) -> bool { + if let Some(mapped) = addr.to_ipv4_mapped() { + return Self::is_disallowed_metadata_ipv4(mapped); + } + + let segments = addr.segments(); + addr.is_loopback() + || addr.is_unspecified() + || addr.is_multicast() + || (segments[0] & 0xffc0) == 0xfe80 + || (segments[0] & 0xfe00) == 0xfc00 + } + + fn is_disallowed_metadata_hostname(host: &str) -> bool { + matches!(host, "localhost") + || host.ends_with(".localhost") + || CLOUD_METADATA_HOSTS.contains(&host) + } + + fn is_disallowed_metadata_host(host: &str) -> bool { + let host = host.trim_end_matches('.').to_ascii_lowercase(); + if Self::is_disallowed_metadata_hostname(&host) { + return true; + } + + match host.parse::() { + Ok(IpAddr::V4(addr)) => Self::is_disallowed_metadata_ipv4(addr), + Ok(IpAddr::V6(addr)) => Self::is_disallowed_metadata_ipv6(addr), + Err(_) => false, + } + } + + fn is_allowed_authorization_server_metadata_url(url: &Url) -> bool { + Self::is_http_url(url) + && url + .host_str() + .is_some_and(|host| !Self::is_disallowed_metadata_host(host)) + } + + fn resolve_resource_metadata_url(value: &str, base_url: &Url) -> Option { + let value = value.trim(); + if value.is_empty() { + debug!("ignoring empty resource_metadata value"); + return None; + } + + let url = match Url::parse(value).or_else(|_| base_url.join(value)) { + Ok(url) => url, + Err(error) => { + debug!("failed to parse resource metadata value `{value}` as URL: {error}"); + return None; + } + }; + + if Self::is_same_origin_resource_metadata_url(base_url, &url) { + Some(url) + } else { + warn!( + "rejecting resource metadata URL `{url}` because it is not same-origin with `{base_url}`" + ); + None + } + } + fn well_known_paths(base_path: &str, resource: &str) -> Vec { let trimmed = base_path.trim_start_matches('/').trim_end_matches('/'); let mut candidates = Vec::new(); @@ -1771,6 +1872,11 @@ impl AuthorizationManager { }, }; + if !Self::is_allowed_authorization_server_metadata_url(&candidate_url) { + warn!("rejecting authorization server metadata URL `{candidate_url}`"); + continue; + } + if candidate_url.path().contains("/.well-known/") { if let Some(metadata) = self.fetch_authorization_metadata(&candidate_url).await? { return Ok(Some(metadata)); @@ -1889,18 +1995,49 @@ impl AuthorizationManager { } async fn discovery_get(&self, url: &Url) -> Result { - let request = oauth2::http::Request::builder() - .method("GET") - .uri(url.as_str()) - .header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05") - .body(Vec::new()) - .map_err(|error| OAuthHttpClientError::new(error.to_string()))?; - self.http_client - .execute(OAuthHttpRequest::new( - request, - OAuthHttpRedirectPolicy::Follow, - )) - .await + let mut current_url = url.clone(); + for _ in 0..MAX_OAUTH_DISCOVERY_REDIRECTS { + let request = oauth2::http::Request::builder() + .method("GET") + .uri(current_url.as_str()) + .header(HEADER_MCP_PROTOCOL_VERSION, "2024-11-05") + .body(Vec::new()) + .map_err(|error| OAuthHttpClientError::new(error.to_string()))?; + let response = self + .http_client + .execute(OAuthHttpRequest::new( + request, + OAuthHttpRedirectPolicy::Stop, + )) + .await?; + + if !response.status().is_redirection() { + return Ok(response); + } + + let Some(location) = response.headers().get(LOCATION) else { + return Ok(response); + }; + let location = location + .to_str() + .map_err(|error| OAuthHttpClientError::new(error.to_string()))?; + let next_url = current_url + .join(location) + .map_err(|error| OAuthHttpClientError::new(error.to_string()))?; + + if Self::is_http_url(&next_url) && Self::is_same_origin(¤t_url, &next_url) { + current_url = next_url; + continue; + } + + return Err(OAuthHttpClientError::new(format!( + "OAuth discovery redirect to non-same-origin URL rejected: {next_url}" + ))); + } + + Err(OAuthHttpClientError::new(format!( + "OAuth discovery exceeded {MAX_OAUTH_DISCOVERY_REDIRECTS} redirects" + ))) } /// extract parameters from WWW-Authenticate header (resource_metadata and scope) @@ -1915,15 +2052,10 @@ impl AuthorizationManager { let global_pos = search_offset + pos + resource_key.len(); let value_slice = &header[global_pos..]; if let Some((value, consumed)) = Self::parse_next_header_value(value_slice) { - if let Ok(url) = Url::parse(&value) { + if let Some(url) = Self::resolve_resource_metadata_url(&value, base_url) { params.resource_metadata_url = Some(url); break; } - if let Ok(url) = base_url.join(&value) { - params.resource_metadata_url = Some(url); - break; - } - debug!("failed to parse resource metadata value `{value}` as URL"); search_offset = global_pos + consumed; continue; } else { @@ -3035,6 +3167,14 @@ mod tests { .unwrap() } + fn redirect_response(location: &str) -> HttpResponse { + oauth2::http::Response::builder() + .status(302) + .header("location", location) + .body(Vec::new()) + .unwrap() + } + #[tokio::test] async fn custom_http_client_handles_protected_resource_discovery() { let challenge = oauth2::http::Response::builder() @@ -3077,26 +3217,147 @@ mod tests { RecordedOAuthRequest { method: "GET".to_string(), uri: "https://mcp.example.com/mcp".to_string(), - redirect_policy: OAuthHttpRedirectPolicy::Follow, + redirect_policy: OAuthHttpRedirectPolicy::Stop, body: Vec::new(), }, RecordedOAuthRequest { method: "GET".to_string(), uri: "https://mcp.example.com/.well-known/oauth-protected-resource".to_string(), - redirect_policy: OAuthHttpRedirectPolicy::Follow, + redirect_policy: OAuthHttpRedirectPolicy::Stop, body: Vec::new(), }, RecordedOAuthRequest { method: "GET".to_string(), uri: "https://auth.example.com/.well-known/oauth-authorization-server" .to_string(), - redirect_policy: OAuthHttpRedirectPolicy::Follow, + redirect_policy: OAuthHttpRedirectPolicy::Stop, body: Vec::new(), }, ] ); } + #[tokio::test] + async fn discovery_get_follows_same_origin_redirects() { + let client = RecordingOAuthHttpClient::with_responses(vec![ + redirect_response("/redirected"), + http_response(200, serde_json::json!({})), + ]); + let manager = AuthorizationManager::new_with_oauth_http_client( + "https://mcp.example.com/mcp", + Arc::new(client.clone()), + ) + .await + .unwrap(); + + let response = manager + .discovery_get(&Url::parse("https://mcp.example.com/start").unwrap()) + .await + .unwrap(); + let requests = client.requests(); + + assert_eq!( + ( + response.status(), + requests + .iter() + .map(|request| request.uri.as_str()) + .collect::>() + ), + ( + oauth2::http::StatusCode::OK, + vec![ + "https://mcp.example.com/start", + "https://mcp.example.com/redirected" + ] + ) + ); + } + + #[tokio::test] + async fn discovery_get_rejects_cross_origin_redirects() { + let client = RecordingOAuthHttpClient::with_responses(vec![redirect_response( + "http://169.254.169.254/", + )]); + let manager = AuthorizationManager::new_with_oauth_http_client( + "https://mcp.example.com/mcp", + Arc::new(client.clone()), + ) + .await + .unwrap(); + + let err = manager + .discovery_get(&Url::parse("https://mcp.example.com/start").unwrap()) + .await + .unwrap_err(); + + assert_eq!( + ( + err.to_string().contains("non-same-origin"), + client.requests().len() + ), + (true, 1) + ); + } + + #[tokio::test] + async fn protected_resource_metadata_rejects_private_authorization_server_urls() { + let challenge = oauth2::http::Response::builder() + .status(401) + .header( + "www-authenticate", + r#"Bearer resource_metadata="https://mcp.example.com/.well-known/oauth-protected-resource""#, + ) + .body(Vec::new()) + .unwrap(); + let client = RecordingOAuthHttpClient::with_responses(vec![ + challenge, + http_response( + 200, + serde_json::json!({ + "authorization_servers": [ + "http://169.254.169.254/latest/meta-data/", + "https://auth.example.com" + ] + }), + ), + http_response( + 200, + serde_json::json!({ + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token" + }), + ), + ]); + let manager = AuthorizationManager::new_with_oauth_http_client( + "https://mcp.example.com/mcp", + Arc::new(client.clone()), + ) + .await + .unwrap(); + + let metadata = manager.discover_metadata().await.unwrap(); + let requests = client.requests(); + + assert_eq!( + ( + metadata.token_endpoint.as_str(), + requests + .iter() + .map(|request| request.uri.as_str()) + .collect::>() + ), + ( + "https://auth.example.com/token", + vec![ + "https://mcp.example.com/mcp", + "https://mcp.example.com/.well-known/oauth-protected-resource", + "https://auth.example.com/.well-known/oauth-authorization-server" + ] + ) + ); + } + #[tokio::test] async fn custom_http_client_handles_registration_exchange_and_refresh() { let client = RecordingOAuthHttpClient::with_responses(vec![ @@ -3410,6 +3671,25 @@ mod tests { ); } + #[test] + fn rejects_cross_origin_resource_metadata_parameter() { + let header = r#"Bearer error="invalid_request", resource_metadata="http://169.254.169.254/latest/meta-data/", scope="read""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert!(params.resource_metadata_url.is_none()); + assert_eq!(params.scope.unwrap(), "read"); + } + + #[test] + fn rejects_non_http_resource_metadata_parameter() { + let header = r#"Bearer resource_metadata="file:///etc/passwd""#; + let base = Url::parse("https://example.com/api").unwrap(); + let params = AuthorizationManager::extract_www_authenticate_params(header, &base); + + assert!(params.resource_metadata_url.is_none()); + } + #[test] fn extract_www_authenticate_params_with_all_fields() { let header = r#"Bearer error="invalid_token", resource_metadata="https://example.com/.well-known/oauth-protected-resource", scope="read:data write:data", error_description="token expired""#;