Skip to content

Commit 2c54971

Browse files
authored
Merge pull request #15 from gitarcode/ws_url
Add support to send in url directly
2 parents 880caf3 + bb5caf9 commit 2c54971

File tree

1 file changed

+63
-44
lines changed

1 file changed

+63
-44
lines changed

src/transport/ws.rs

Lines changed: 63 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -21,45 +21,52 @@ use tokio_tungstenite::{
2121
},
2222
};
2323
use tracing;
24+
use url;
25+
26+
/// Capture config to server/connect
27+
enum WsTransportConfig {
28+
Server { host: String, port: u16 },
29+
Client { url: String },
30+
}
2431

2532
pub struct WebSocketTransport {
26-
host: String,
27-
port: u16,
28-
client_mode: bool,
29-
use_tls: bool,
33+
config: WsTransportConfig,
3034
buffer_size: usize,
3135
auth_header: Option<String>,
3236
}
3337

3438
impl WebSocketTransport {
3539
pub fn new_server(host: String, port: u16, buffer_size: usize) -> Self {
3640
Self {
37-
host,
38-
port,
39-
client_mode: false,
40-
use_tls: false,
41+
config: WsTransportConfig::Server { host, port },
4142
buffer_size,
4243
auth_header: None,
4344
}
4445
}
4546

4647
pub fn new_client(host: String, port: u16, buffer_size: usize) -> Self {
4748
Self {
48-
host,
49-
port,
50-
client_mode: true,
51-
use_tls: false,
49+
config: WsTransportConfig::Client {
50+
url: format!("ws://{}:{}/ws", host, port),
51+
},
5252
buffer_size,
5353
auth_header: None,
5454
}
5555
}
5656

5757
pub fn new_wss_client(host: String, port: u16, buffer_size: usize) -> Self {
5858
Self {
59-
host,
60-
port,
61-
client_mode: true,
62-
use_tls: true,
59+
config: WsTransportConfig::Client {
60+
url: format!("wss://{}:{}/ws", host, port),
61+
},
62+
buffer_size,
63+
auth_header: None,
64+
}
65+
}
66+
67+
pub fn new_client_with_url(url: String, buffer_size: usize) -> Self {
68+
Self {
69+
config: WsTransportConfig::Client { url },
6370
buffer_size,
6471
auth_header: None,
6572
}
@@ -279,6 +286,7 @@ impl WebSocketTransport {
279286
Ok(addr) => addr,
280287
Err(e) => {
281288
tracing::error!("Failed to parse host address: {:?}", e);
289+
message_task.abort();
282290
return;
283291
}
284292
};
@@ -294,27 +302,40 @@ impl WebSocketTransport {
294302
}
295303

296304
async fn run_client(
297-
host: String,
298-
port: u16,
299-
use_tls: bool,
305+
url: String,
300306
auth_header: Option<String>,
301307
mut cmd_rx: mpsc::Receiver<TransportCommand>,
302308
event_tx: mpsc::Sender<TransportEvent>,
303309
) {
304-
let protocol = if use_tls { "wss" } else { "ws" };
305-
let ws_url = format!("{}://{}:{}/ws", protocol, host, port);
306-
tracing::debug!("Connecting to WebSocket endpoint: {}", ws_url);
310+
tracing::debug!("Connecting to WebSocket endpoint: {}", url);
307311

308312
// Connect to the WebSocket server
309313
let ws_stream_result = if let Some(auth) = &auth_header {
310314
// Create a custom connector with auth header
311315
let request = http::Request::builder()
312-
.uri(&ws_url)
316+
.uri(&url)
313317
.header("User-Agent", "mcp-rs-client")
314318
.header("Authorization", auth)
315319
.header("Connection", "Upgrade")
316320
.header("Upgrade", "websocket")
317-
.header("Host", format!("{}:{}", host, port))
321+
.header("Host", {
322+
// Extract host:port from URL for the Host header
323+
if let Ok(parsed_url) = url::Url::parse(&url) {
324+
format!(
325+
"{}:{}",
326+
parsed_url.host_str().unwrap_or("localhost"),
327+
parsed_url
328+
.port()
329+
.unwrap_or(if parsed_url.scheme() == "wss" {
330+
443
331+
} else {
332+
80
333+
})
334+
)
335+
} else {
336+
"localhost:80".to_string()
337+
}
338+
})
318339
.header("Sec-WebSocket-Version", "13")
319340
.header(
320341
"Sec-WebSocket-Key",
@@ -344,7 +365,7 @@ impl WebSocketTransport {
344365
}
345366
} else {
346367
// Use standard connection without auth
347-
connect_async(&ws_url).await
368+
connect_async(&url).await
348369
};
349370

350371
// Connect to the WebSocket server
@@ -506,22 +527,18 @@ impl Transport for WebSocketTransport {
506527
let (cmd_tx, cmd_rx) = mpsc::channel(self.buffer_size);
507528
let (event_tx, event_rx) = mpsc::channel(self.buffer_size);
508529

509-
if self.client_mode {
510-
tokio::spawn(Self::run_client(
511-
self.host.clone(),
512-
self.port,
513-
self.use_tls,
514-
self.auth_header.clone(),
515-
cmd_rx,
516-
event_tx,
517-
));
518-
} else {
519-
tokio::spawn(Self::run_server(
520-
self.host.clone(),
521-
self.port,
522-
cmd_rx,
523-
event_tx,
524-
));
530+
match &self.config {
531+
WsTransportConfig::Client { url } => {
532+
tokio::spawn(Self::run_client(
533+
url.clone(),
534+
self.auth_header.clone(),
535+
cmd_rx,
536+
event_tx,
537+
));
538+
}
539+
WsTransportConfig::Server { host, port } => {
540+
tokio::spawn(Self::run_server(host.clone(), *port, cmd_rx, event_tx));
541+
}
525542
}
526543

527544
let event_rx = Arc::new(tokio::sync::Mutex::new(event_rx));
@@ -554,6 +571,7 @@ mod tests {
554571
let host = "127.0.0.1".to_string();
555572
let port = PORT_COUNTER.fetch_add(1, AtomicOrdering::SeqCst); // Unique port to avoid conflicts
556573
let mut transport = WebSocketTransport::new_server(host.clone(), port, 32);
574+
let ws_url = format!("ws://{}:{}/ws", host, port);
557575

558576
// Start the transport
559577
let TransportChannels { cmd_tx, event_rx } = transport.start().await?;
@@ -562,7 +580,6 @@ mod tests {
562580
sleep(Duration::from_millis(300)).await;
563581

564582
// Connect a client to the server
565-
let ws_url = format!("ws://{}:{}/ws", host, port);
566583
let (ws_stream, _) = connect_async(&ws_url).await.expect("Failed to connect");
567584
let (mut write, mut read) = ws_stream.split();
568585

@@ -673,6 +690,7 @@ mod tests {
673690
// Start a WebSocket server using warp for the client to connect to
674691
let host = "127.0.0.1".to_string();
675692
let port = PORT_COUNTER.fetch_add(1, AtomicOrdering::SeqCst); // Unique port to avoid conflicts
693+
let ws_url = format!("ws://{}:{}/ws", host, port);
676694

677695
// Create a channel to receive messages from the test server
678696
let (server_tx, mut server_rx) = mpsc::channel::<JsonRpcMessage>(32);
@@ -711,7 +729,7 @@ mod tests {
711729
sleep(Duration::from_millis(100)).await;
712730

713731
// Create and start the WebSocket client transport
714-
let mut transport = WebSocketTransport::new_client(host.clone(), port, 32);
732+
let mut transport = WebSocketTransport::new_client_with_url(ws_url, 32);
715733
let TransportChannels { cmd_tx, event_rx } = transport.start().await?;
716734

717735
// Give the client time to connect
@@ -784,6 +802,7 @@ mod tests {
784802
// Start a WebSocket server using warp for the client to connect to
785803
let host = "127.0.0.1".to_string();
786804
let port = PORT_COUNTER.fetch_add(1, AtomicOrdering::SeqCst); // Unique port to avoid conflicts
805+
let ws_url = format!("ws://{}:{}/ws", host, port);
787806

788807
// Create a channel to receive messages from the test server
789808
let (server_tx, mut server_rx) = mpsc::channel::<JsonRpcMessage>(32);
@@ -835,7 +854,7 @@ mod tests {
835854

836855
// Create and start the WebSocket client transport with auth header
837856
let auth_header = "Bearer test-token-123".to_string();
838-
let mut transport = WebSocketTransport::new_client(host.clone(), port, 32)
857+
let mut transport = WebSocketTransport::new_client_with_url(ws_url, 32)
839858
.with_auth_header(auth_header.clone());
840859

841860
let TransportChannels {

0 commit comments

Comments
 (0)