@@ -21,45 +21,52 @@ use tokio_tungstenite::{
2121 } ,
2222} ;
2323use tracing;
24+ use url;
25+
26+ /// Capture config to server/connect
27+ enum WsTransportConfig {
28+ Server { host : String , port : u16 } ,
29+ Client { url : String } ,
30+ }
2431
2532pub struct WebSocketTransport {
26- host : String ,
27- port : u16 ,
28- client_mode : bool ,
29- use_tls : bool ,
33+ config : WsTransportConfig ,
3034 buffer_size : usize ,
3135 auth_header : Option < String > ,
3236}
3337
3438impl WebSocketTransport {
3539 pub fn new_server ( host : String , port : u16 , buffer_size : usize ) -> Self {
3640 Self {
37- host,
38- port,
39- client_mode : false ,
40- use_tls : false ,
41+ config : WsTransportConfig :: Server { host, port } ,
4142 buffer_size,
4243 auth_header : None ,
4344 }
4445 }
4546
4647 pub fn new_client ( host : String , port : u16 , buffer_size : usize ) -> Self {
4748 Self {
48- host,
49- port,
50- client_mode : true ,
51- use_tls : false ,
49+ config : WsTransportConfig :: Client {
50+ url : format ! ( "ws://{}:{}/ws" , host, port) ,
51+ } ,
5252 buffer_size,
5353 auth_header : None ,
5454 }
5555 }
5656
5757 pub fn new_wss_client ( host : String , port : u16 , buffer_size : usize ) -> Self {
5858 Self {
59- host,
60- port,
61- client_mode : true ,
62- use_tls : true ,
59+ config : WsTransportConfig :: Client {
60+ url : format ! ( "wss://{}:{}/ws" , host, port) ,
61+ } ,
62+ buffer_size,
63+ auth_header : None ,
64+ }
65+ }
66+
67+ pub fn new_client_with_url ( url : String , buffer_size : usize ) -> Self {
68+ Self {
69+ config : WsTransportConfig :: Client { url } ,
6370 buffer_size,
6471 auth_header : None ,
6572 }
@@ -279,6 +286,7 @@ impl WebSocketTransport {
279286 Ok ( addr) => addr,
280287 Err ( e) => {
281288 tracing:: error!( "Failed to parse host address: {:?}" , e) ;
289+ message_task. abort ( ) ;
282290 return ;
283291 }
284292 } ;
@@ -294,27 +302,40 @@ impl WebSocketTransport {
294302 }
295303
296304 async fn run_client (
297- host : String ,
298- port : u16 ,
299- use_tls : bool ,
305+ url : String ,
300306 auth_header : Option < String > ,
301307 mut cmd_rx : mpsc:: Receiver < TransportCommand > ,
302308 event_tx : mpsc:: Sender < TransportEvent > ,
303309 ) {
304- let protocol = if use_tls { "wss" } else { "ws" } ;
305- let ws_url = format ! ( "{}://{}:{}/ws" , protocol, host, port) ;
306- tracing:: debug!( "Connecting to WebSocket endpoint: {}" , ws_url) ;
310+ tracing:: debug!( "Connecting to WebSocket endpoint: {}" , url) ;
307311
308312 // Connect to the WebSocket server
309313 let ws_stream_result = if let Some ( auth) = & auth_header {
310314 // Create a custom connector with auth header
311315 let request = http:: Request :: builder ( )
312- . uri ( & ws_url )
316+ . uri ( & url )
313317 . header ( "User-Agent" , "mcp-rs-client" )
314318 . header ( "Authorization" , auth)
315319 . header ( "Connection" , "Upgrade" )
316320 . header ( "Upgrade" , "websocket" )
317- . header ( "Host" , format ! ( "{}:{}" , host, port) )
321+ . header ( "Host" , {
322+ // Extract host:port from URL for the Host header
323+ if let Ok ( parsed_url) = url:: Url :: parse ( & url) {
324+ format ! (
325+ "{}:{}" ,
326+ parsed_url. host_str( ) . unwrap_or( "localhost" ) ,
327+ parsed_url
328+ . port( )
329+ . unwrap_or( if parsed_url. scheme( ) == "wss" {
330+ 443
331+ } else {
332+ 80
333+ } )
334+ )
335+ } else {
336+ "localhost:80" . to_string ( )
337+ }
338+ } )
318339 . header ( "Sec-WebSocket-Version" , "13" )
319340 . header (
320341 "Sec-WebSocket-Key" ,
@@ -344,7 +365,7 @@ impl WebSocketTransport {
344365 }
345366 } else {
346367 // Use standard connection without auth
347- connect_async ( & ws_url ) . await
368+ connect_async ( & url ) . await
348369 } ;
349370
350371 // Connect to the WebSocket server
@@ -506,22 +527,18 @@ impl Transport for WebSocketTransport {
506527 let ( cmd_tx, cmd_rx) = mpsc:: channel ( self . buffer_size ) ;
507528 let ( event_tx, event_rx) = mpsc:: channel ( self . buffer_size ) ;
508529
509- if self . client_mode {
510- tokio:: spawn ( Self :: run_client (
511- self . host . clone ( ) ,
512- self . port ,
513- self . use_tls ,
514- self . auth_header . clone ( ) ,
515- cmd_rx,
516- event_tx,
517- ) ) ;
518- } else {
519- tokio:: spawn ( Self :: run_server (
520- self . host . clone ( ) ,
521- self . port ,
522- cmd_rx,
523- event_tx,
524- ) ) ;
530+ match & self . config {
531+ WsTransportConfig :: Client { url } => {
532+ tokio:: spawn ( Self :: run_client (
533+ url. clone ( ) ,
534+ self . auth_header . clone ( ) ,
535+ cmd_rx,
536+ event_tx,
537+ ) ) ;
538+ }
539+ WsTransportConfig :: Server { host, port } => {
540+ tokio:: spawn ( Self :: run_server ( host. clone ( ) , * port, cmd_rx, event_tx) ) ;
541+ }
525542 }
526543
527544 let event_rx = Arc :: new ( tokio:: sync:: Mutex :: new ( event_rx) ) ;
@@ -554,6 +571,7 @@ mod tests {
554571 let host = "127.0.0.1" . to_string ( ) ;
555572 let port = PORT_COUNTER . fetch_add ( 1 , AtomicOrdering :: SeqCst ) ; // Unique port to avoid conflicts
556573 let mut transport = WebSocketTransport :: new_server ( host. clone ( ) , port, 32 ) ;
574+ let ws_url = format ! ( "ws://{}:{}/ws" , host, port) ;
557575
558576 // Start the transport
559577 let TransportChannels { cmd_tx, event_rx } = transport. start ( ) . await ?;
@@ -562,7 +580,6 @@ mod tests {
562580 sleep ( Duration :: from_millis ( 300 ) ) . await ;
563581
564582 // Connect a client to the server
565- let ws_url = format ! ( "ws://{}:{}/ws" , host, port) ;
566583 let ( ws_stream, _) = connect_async ( & ws_url) . await . expect ( "Failed to connect" ) ;
567584 let ( mut write, mut read) = ws_stream. split ( ) ;
568585
@@ -673,6 +690,7 @@ mod tests {
673690 // Start a WebSocket server using warp for the client to connect to
674691 let host = "127.0.0.1" . to_string ( ) ;
675692 let port = PORT_COUNTER . fetch_add ( 1 , AtomicOrdering :: SeqCst ) ; // Unique port to avoid conflicts
693+ let ws_url = format ! ( "ws://{}:{}/ws" , host, port) ;
676694
677695 // Create a channel to receive messages from the test server
678696 let ( server_tx, mut server_rx) = mpsc:: channel :: < JsonRpcMessage > ( 32 ) ;
@@ -711,7 +729,7 @@ mod tests {
711729 sleep ( Duration :: from_millis ( 100 ) ) . await ;
712730
713731 // Create and start the WebSocket client transport
714- let mut transport = WebSocketTransport :: new_client ( host . clone ( ) , port , 32 ) ;
732+ let mut transport = WebSocketTransport :: new_client_with_url ( ws_url , 32 ) ;
715733 let TransportChannels { cmd_tx, event_rx } = transport. start ( ) . await ?;
716734
717735 // Give the client time to connect
@@ -784,6 +802,7 @@ mod tests {
784802 // Start a WebSocket server using warp for the client to connect to
785803 let host = "127.0.0.1" . to_string ( ) ;
786804 let port = PORT_COUNTER . fetch_add ( 1 , AtomicOrdering :: SeqCst ) ; // Unique port to avoid conflicts
805+ let ws_url = format ! ( "ws://{}:{}/ws" , host, port) ;
787806
788807 // Create a channel to receive messages from the test server
789808 let ( server_tx, mut server_rx) = mpsc:: channel :: < JsonRpcMessage > ( 32 ) ;
@@ -835,7 +854,7 @@ mod tests {
835854
836855 // Create and start the WebSocket client transport with auth header
837856 let auth_header = "Bearer test-token-123" . to_string ( ) ;
838- let mut transport = WebSocketTransport :: new_client ( host . clone ( ) , port , 32 )
857+ let mut transport = WebSocketTransport :: new_client_with_url ( ws_url , 32 )
839858 . with_auth_header ( auth_header. clone ( ) ) ;
840859
841860 let TransportChannels {
0 commit comments