Skip to content

Commit ee8a115

Browse files
committed
Refactor WebSocket implementation to separate inner representation into InnerWebSocket
Signed-off-by: Yuki Kishimoto <yukikishimoto@protonmail.com>
1 parent 0a3f1e8 commit ee8a115

3 files changed

Lines changed: 56 additions & 30 deletions

File tree

src/native/mod.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ async fn connect_direct(url: &Url, timeout: Duration) -> Result<WebSocket, Error
5757
))
5858
.await
5959
.map_err(|_| Error::Timeout)??;
60-
Ok(WebSocket::Tokio(Box::new(stream)))
60+
Ok(WebSocket::tokio(Box::new(stream)))
6161
}
6262

6363
#[cfg(feature = "socks")]
@@ -81,7 +81,7 @@ async fn connect_proxy(
8181
))
8282
.await
8383
.map_err(|_| Error::Timeout)??;
84-
Ok(WebSocket::Tokio(Box::new(stream)))
84+
Ok(WebSocket::tokio(Box::new(stream)))
8585
}
8686

8787
#[cfg(feature = "tor")]
@@ -104,7 +104,7 @@ async fn connect_tor(
104104
))
105105
.await
106106
.map_err(|_| Error::Timeout)??;
107-
Ok(WebSocket::Tor(Box::new(stream)))
107+
Ok(WebSocket::tor(Box::new(stream)))
108108
}
109109

110110
#[inline]

src/socket.rs

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// Copyright (c) 2022-2024 Yuki Kishimoto
22
// Distributed under the MIT software license
33

4-
use std::ops::DerefMut;
54
use std::pin::Pin;
65
use std::task::{Context, Poll};
76
use std::time::Duration;
@@ -22,7 +21,7 @@ use crate::{ConnectionMode, Error, Message};
2221
#[cfg(not(target_arch = "wasm32"))]
2322
type WsStream<T> = WebSocketStream<MaybeTlsStream<T>>;
2423

25-
pub enum WebSocket {
24+
enum InnerWebSocket {
2625
#[cfg(not(target_arch = "wasm32"))]
2726
Tokio(Box<WsStream<TcpStream>>),
2827
#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
@@ -31,7 +30,34 @@ pub enum WebSocket {
3130
Wasm(WsStream),
3231
}
3332

33+
pub struct WebSocket {
34+
inner: InnerWebSocket,
35+
}
36+
3437
impl WebSocket {
38+
#[inline]
39+
fn new(inner: InnerWebSocket) -> Self {
40+
Self { inner }
41+
}
42+
43+
#[inline]
44+
#[cfg(not(target_arch = "wasm32"))]
45+
pub(crate) fn tokio(inner: Box<WsStream<TcpStream>>) -> Self {
46+
Self::new(InnerWebSocket::Tokio(inner))
47+
}
48+
49+
#[inline]
50+
#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
51+
pub(crate) fn tor(inner: Box<WsStream<DataStream>>) -> Self {
52+
Self::new(InnerWebSocket::Tor(inner))
53+
}
54+
55+
#[inline]
56+
#[cfg(target_arch = "wasm32")]
57+
pub(crate) fn wasm(inner: WsStream) -> Self {
58+
Self::new(InnerWebSocket::Wasm(inner))
59+
}
60+
3561
pub async fn connect(
3662
url: &Url,
3763
_mode: &ConnectionMode,
@@ -51,50 +77,50 @@ impl Sink<Message> for WebSocket {
5177
type Error = Error;
5278

5379
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54-
match self.deref_mut() {
80+
match &mut self.inner {
5581
#[cfg(not(target_arch = "wasm32"))]
56-
Self::Tokio(s) => Pin::new(s.as_mut()).poll_ready(cx).map_err(Into::into),
82+
InnerWebSocket::Tokio(s) => Pin::new(s.as_mut()).poll_ready(cx).map_err(Into::into),
5783
#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
58-
Self::Tor(s) => Pin::new(s.as_mut()).poll_ready(cx).map_err(Into::into),
84+
InnerWebSocket::Tor(s) => Pin::new(s.as_mut()).poll_ready(cx).map_err(Into::into),
5985
#[cfg(target_arch = "wasm32")]
60-
Self::Wasm(s) => Pin::new(s).poll_ready(cx),
86+
InnerWebSocket::Wasm(s) => Pin::new(s).poll_ready(cx),
6187
}
6288
}
6389

6490
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
65-
match self.deref_mut() {
91+
match &mut self.inner {
6692
#[cfg(not(target_arch = "wasm32"))]
67-
Self::Tokio(s) => Pin::new(s.as_mut())
93+
InnerWebSocket::Tokio(s) => Pin::new(s.as_mut())
6894
.start_send(item.into())
6995
.map_err(Into::into),
7096
#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
71-
Self::Tor(s) => Pin::new(s.as_mut())
97+
InnerWebSocket::Tor(s) => Pin::new(s.as_mut())
7298
.start_send(item.into())
7399
.map_err(Into::into),
74100
#[cfg(target_arch = "wasm32")]
75-
Self::Wasm(s) => Pin::new(s).start_send(item),
101+
InnerWebSocket::Wasm(s) => Pin::new(s).start_send(item),
76102
}
77103
}
78104

79105
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80-
match self.deref_mut() {
106+
match &mut self.inner {
81107
#[cfg(not(target_arch = "wasm32"))]
82-
Self::Tokio(s) => Pin::new(s.as_mut()).poll_flush(cx).map_err(Into::into),
108+
InnerWebSocket::Tokio(s) => Pin::new(s.as_mut()).poll_flush(cx).map_err(Into::into),
83109
#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
84-
Self::Tor(s) => Pin::new(s.as_mut()).poll_flush(cx).map_err(Into::into),
110+
InnerWebSocket::Tor(s) => Pin::new(s.as_mut()).poll_flush(cx).map_err(Into::into),
85111
#[cfg(target_arch = "wasm32")]
86-
Self::Wasm(s) => Pin::new(s).poll_flush(cx),
112+
InnerWebSocket::Wasm(s) => Pin::new(s).poll_flush(cx),
87113
}
88114
}
89115

90116
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
91-
match self.deref_mut() {
117+
match &mut self.inner {
92118
#[cfg(not(target_arch = "wasm32"))]
93-
Self::Tokio(s) => Pin::new(s.as_mut()).poll_close(cx).map_err(Into::into),
119+
InnerWebSocket::Tokio(s) => Pin::new(s.as_mut()).poll_close(cx).map_err(Into::into),
94120
#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
95-
Self::Tor(s) => Pin::new(s.as_mut()).poll_close(cx).map_err(Into::into),
121+
InnerWebSocket::Tor(s) => Pin::new(s.as_mut()).poll_close(cx).map_err(Into::into),
96122
#[cfg(target_arch = "wasm32")]
97-
Self::Wasm(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
123+
InnerWebSocket::Wasm(s) => Pin::new(s).poll_close(cx).map_err(Into::into),
98124
}
99125
}
100126
}
@@ -103,30 +129,30 @@ impl Stream for WebSocket {
103129
type Item = Result<Message, Error>;
104130

105131
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
106-
match self.deref_mut() {
132+
match &mut self.inner {
107133
#[cfg(not(target_arch = "wasm32"))]
108-
Self::Tokio(s) => Pin::new(s)
134+
InnerWebSocket::Tokio(s) => Pin::new(s)
109135
.poll_next(cx)
110136
.map(|i| i.map(|res| res.map(Message::from_native)))
111137
.map_err(Into::into),
112138
#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
113-
Self::Tor(s) => Pin::new(s)
139+
InnerWebSocket::Tor(s) => Pin::new(s)
114140
.poll_next(cx)
115141
.map(|i| i.map(|res| res.map(Message::from_native)))
116142
.map_err(Into::into),
117143
#[cfg(target_arch = "wasm32")]
118-
Self::Wasm(s) => Pin::new(s).poll_next(cx).map_err(Into::into),
144+
InnerWebSocket::Wasm(s) => Pin::new(s).poll_next(cx).map_err(Into::into),
119145
}
120146
}
121147

122148
fn size_hint(&self) -> (usize, Option<usize>) {
123-
match self {
149+
match &self.inner {
124150
#[cfg(not(target_arch = "wasm32"))]
125-
Self::Tokio(s) => s.size_hint(),
151+
InnerWebSocket::Tokio(s) => s.size_hint(),
126152
#[cfg(all(feature = "tor", not(target_arch = "wasm32")))]
127-
Self::Tor(s) => s.size_hint(),
153+
InnerWebSocket::Tor(s) => s.size_hint(),
128154
#[cfg(target_arch = "wasm32")]
129-
Self::Wasm(s) => s.size_hint(),
155+
InnerWebSocket::Wasm(s) => s.size_hint(),
130156
}
131157
}
132158
}

src/wasm/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub async fn connect(url: &Url, timeout: Duration) -> Result<WebSocket, Error> {
3131
let (_ws, stream) = time::timeout(Some(timeout), WasmWebSocket::connect(url))
3232
.await
3333
.ok_or(Error::Timeout)??;
34-
Ok(WebSocket::Wasm(stream))
34+
Ok(WebSocket::wasm(stream))
3535
}
3636

3737
/// Helper function to reduce code bloat

0 commit comments

Comments
 (0)