@@ -2,7 +2,7 @@ use std::{convert::Infallible, fmt::Display, sync::Arc, time::Duration};
22
33use bytes:: Bytes ;
44use futures:: { StreamExt , future:: BoxFuture } ;
5- use http:: { Method , Request , Response , header:: ALLOW } ;
5+ use http:: { HeaderMap , Method , Request , Response , header:: ALLOW } ;
66use http_body:: Body ;
77use http_body_util:: { BodyExt , Full , combinators:: BoxBody } ;
88use tokio_stream:: wrappers:: ReceiverStream ;
@@ -29,8 +29,8 @@ use crate::{
2929 } ,
3030} ;
3131
32- #[ derive( Debug , Clone ) ]
3332#[ non_exhaustive]
33+ #[ derive( Debug , Clone ) ]
3434pub struct StreamableHttpServerConfig {
3535 /// The ping message duration for SSE connections.
3636 pub sse_keep_alive : Option < Duration > ,
@@ -49,6 +49,16 @@ pub struct StreamableHttpServerConfig {
4949 /// When this token is cancelled, all active sessions are terminated and
5050 /// the server stops accepting new requests.
5151 pub cancellation_token : CancellationToken ,
52+ /// Allowed hostnames or `host:port` authorities for inbound `Host` validation.
53+ ///
54+ /// By default, Streamable HTTP servers only accept loopback hosts to
55+ /// prevent DNS rebinding attacks against locally running servers. Public
56+ /// deployments should override this list with their own hostnames.
57+ /// examples:
58+ /// allowed_hosts = ["localhost", "127.0.0.1", "0.0.0.0"]
59+ /// or with ports:
60+ /// allowed_hosts = ["example.com", "example.com:8080"]
61+ pub allowed_hosts : Vec < String > ,
5262}
5363
5464impl Default for StreamableHttpServerConfig {
@@ -59,11 +69,24 @@ impl Default for StreamableHttpServerConfig {
5969 stateful_mode : true ,
6070 json_response : false ,
6171 cancellation_token : CancellationToken :: new ( ) ,
72+ allowed_hosts : vec ! [ "localhost" . into( ) , "127.0.0.1" . into( ) , "::1" . into( ) ] ,
6273 }
6374 }
6475}
6576
6677impl StreamableHttpServerConfig {
78+ pub fn with_allowed_hosts (
79+ mut self ,
80+ allowed_hosts : impl IntoIterator < Item = impl Into < String > > ,
81+ ) -> Self {
82+ self . allowed_hosts = allowed_hosts. into_iter ( ) . map ( Into :: into) . collect ( ) ;
83+ self
84+ }
85+ /// Disable allowed hosts. This will allow requests with any `Host` or `Origin` header, which is NOT recommended for public deployments.
86+ pub fn disable_allowed_hosts ( mut self ) -> Self {
87+ self . allowed_hosts . clear ( ) ;
88+ self
89+ }
6790 pub fn with_sse_keep_alive ( mut self , duration : Option < Duration > ) -> Self {
6891 self . sse_keep_alive = duration;
6992 self
@@ -130,6 +153,87 @@ fn validate_protocol_version_header(headers: &http::HeaderMap) -> Result<(), Box
130153 Ok ( ( ) )
131154}
132155
156+ fn forbidden_response ( message : impl Into < String > ) -> BoxResponse {
157+ Response :: builder ( )
158+ . status ( http:: StatusCode :: FORBIDDEN )
159+ . body ( Full :: new ( Bytes :: from ( message. into ( ) ) ) . boxed ( ) )
160+ . expect ( "valid response" )
161+ }
162+
163+ fn normalize_host ( host : & str ) -> String {
164+ host. trim_matches ( '[' )
165+ . trim_matches ( ']' )
166+ . to_ascii_lowercase ( )
167+ }
168+
169+ #[ derive( Debug , Clone , PartialEq , Eq ) ]
170+ struct NormalizedAuthority {
171+ host : String ,
172+ port : Option < u16 > ,
173+ }
174+
175+ fn normalize_authority ( host : & str , port : Option < u16 > ) -> NormalizedAuthority {
176+ NormalizedAuthority {
177+ host : normalize_host ( host) ,
178+ port,
179+ }
180+ }
181+
182+ fn parse_allowed_authority ( allowed : & str ) -> Option < NormalizedAuthority > {
183+ let allowed = allowed. trim ( ) ;
184+ if allowed. is_empty ( ) {
185+ return None ;
186+ }
187+
188+ if let Ok ( authority) = http:: uri:: Authority :: try_from ( allowed) {
189+ return Some ( normalize_authority ( authority. host ( ) , authority. port_u16 ( ) ) ) ;
190+ }
191+
192+ Some ( normalize_authority ( allowed, None ) )
193+ }
194+
195+ fn host_is_allowed ( host : & NormalizedAuthority , allowed_hosts : & [ String ] ) -> bool {
196+ if allowed_hosts. is_empty ( ) {
197+ // If the allowed hosts list is empty, allow all hosts (not recommended).
198+ return true ;
199+ }
200+ allowed_hosts
201+ . iter ( )
202+ . filter_map ( |allowed| parse_allowed_authority ( allowed) )
203+ . any ( |allowed| {
204+ allowed. host == host. host
205+ && match allowed. port {
206+ Some ( port) => host. port == Some ( port) ,
207+ None => true ,
208+ }
209+ } )
210+ }
211+
212+ fn parse_host_header ( headers : & HeaderMap ) -> Result < NormalizedAuthority , BoxResponse > {
213+ let Some ( host) = headers. get ( http:: header:: HOST ) else {
214+ return Err ( forbidden_response ( "Forbidden:missing_host header" ) ) ;
215+ } ;
216+
217+ let host = host
218+ . to_str ( )
219+ . map_err ( |_| forbidden_response ( "Forbidden: Invalid Host header encoding" ) ) ?;
220+ let authority = http:: uri:: Authority :: try_from ( host)
221+ . map_err ( |_| forbidden_response ( "Forbidden: Invalid Host header" ) ) ?;
222+ Ok ( normalize_authority ( authority. host ( ) , authority. port_u16 ( ) ) )
223+ }
224+
225+ fn validate_dns_rebinding_headers (
226+ headers : & HeaderMap ,
227+ config : & StreamableHttpServerConfig ,
228+ ) -> Result < ( ) , BoxResponse > {
229+ let host = parse_host_header ( headers) ?;
230+ if !host_is_allowed ( & host, & config. allowed_hosts ) {
231+ return Err ( forbidden_response ( "Forbidden: Host header is not allowed" ) ) ;
232+ }
233+
234+ Ok ( ( ) )
235+ }
236+
133237/// # Streamable HTTP server
134238///
135239/// An HTTP service that implements the
@@ -279,6 +383,9 @@ where
279383 B : Body + Send + ' static ,
280384 B :: Error : Display ,
281385 {
386+ if let Err ( response) = validate_dns_rebinding_headers ( request. headers ( ) , & self . config ) {
387+ return response;
388+ }
282389 let method = request. method ( ) . clone ( ) ;
283390 let allowed_methods = match self . config . stateful_mode {
284391 true => "GET, POST, DELETE" ,
0 commit comments