Skip to content

Commit 4375beb

Browse files
committed
Socks4とDNSキャッシュを実装
1 parent a3ee11d commit 4375beb

4 files changed

Lines changed: 123 additions & 12 deletions

File tree

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ futures-util = "0.3"
2121
bytes = "1"
2222
idna = "0.4"
2323
percent-encoding = "2"
24+
lru_time_cache = "0.11"

src/main.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,17 @@ use hyper::{
1313
service::{make_service_fn, service_fn},
1414
Body, Method, Request, Response, Server, StatusCode,
1515
};
16-
use std::{io::Write, time::Duration};
17-
use tokio::io::{AsyncRead, AsyncWrite};
16+
use lru_time_cache::LruCache;
17+
use std::{
18+
hash::Hash,
19+
io::Write,
20+
net::{Ipv4Addr, Ipv6Addr},
21+
time::Duration,
22+
};
23+
use tokio::{
24+
io::{AsyncRead, AsyncWrite},
25+
sync::Mutex,
26+
};
1827

1928
use once_cell::sync::OnceCell;
2029

@@ -81,15 +90,24 @@ async fn main() {
8190
let proxy_protocol_main = proxy_protocol[proxy_protocol.len() - 1];
8291
if proxy_protocol_main == "http" {
8392
proxy_stack.push(Box::new(outbound::HttpProxy::new(proxy).unwrap()));
93+
} else if proxy_protocol_main == "socks4" {
94+
proxy_stack.push(Box::new(outbound::Socks4Proxy::new(proxy).unwrap()));
8495
} else {
8596
panic!("This protocol can not use: {}", proxy_protocol_main);
8697
}
8798
}
8899
}
89100

101+
let dns_cache = if config.doh_endpoint.is_some() {
102+
LruCache::with_expiry_duration_and_capacity(Duration::from_secs(7200), 65535)
103+
} else {
104+
LruCache::with_capacity(0)
105+
};
106+
90107
if PROXY
91108
.set(ProxyState {
92109
config,
110+
dns_cache: Mutex::new(dns_cache),
93111
proxy_stack,
94112
})
95113
.is_err()
@@ -121,10 +139,18 @@ async fn handle(request: Request<Body>) -> Result<Response<Body>, Error> {
121139
}
122140
}
123141

142+
#[allow(clippy::type_complexity)]
124143
struct ProxyState {
125144
config: Config,
145+
dns_cache: Mutex<LruCache<String, (DnsCacheState<Ipv4Addr>, DnsCacheState<Ipv6Addr>)>>,
126146
proxy_stack: Vec<Box<dyn ProxyOutBound>>,
127147
}
128148

149+
enum DnsCacheState<T: Hash> {
150+
Some(T),
151+
Fail,
152+
None,
153+
}
154+
129155
pub trait Stream: AsyncRead + AsyncWrite {}
130156
impl<RW> Stream for RW where RW: AsyncRead + AsyncWrite {}

src/outbound/socks4.rs

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
use super::ProxyOutBound;
2-
use crate::{config::ProxyConfig, utils::SocketAddr, Connection, Error};
2+
use crate::{
3+
config::ProxyConfig,
4+
utils::{HostName, SocketAddr},
5+
Connection, Error,
6+
};
37

48
use std::str::FromStr;
59

610
use async_trait::async_trait;
7-
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11+
use tokio::io::{AsyncReadExt, AsyncWriteExt};
812

913
pub struct Socks4Proxy {
1014
addr: SocketAddr,
@@ -40,14 +44,57 @@ impl ProxyOutBound for Socks4Proxy {
4044
mut proxies: Box<dyn Iterator<Item = &Box<dyn ProxyOutBound>> + Send>,
4145
addr: &SocketAddr,
4246
) -> Result<Connection, Error> {
43-
let server = proxies
47+
let ip;
48+
let mut hostname = None;
49+
match &addr.hostname {
50+
HostName::V4(v4) => ip = *v4,
51+
HostName::Domain(domain) => {
52+
ip = "0.0.0.1".parse()?;
53+
if domain.contains('\0') {
54+
return Err("".into());
55+
}
56+
hostname = Some(domain);
57+
}
58+
HostName::V6(_) => return Err("".into()),
59+
}
60+
61+
let ip_octet = ip.octets();
62+
if hostname.is_none()
63+
&& ip_octet[0] == 0
64+
&& ip_octet[1] == 0
65+
&& ip_octet[2] == 0
66+
&& ip_octet[3] != 0
67+
{
68+
return Err("".into());
69+
}
70+
71+
let mut server = proxies
4472
.next()
4573
.ok_or("")?
4674
.connect(proxies, &self.addr)
4775
.await?;
48-
let mut server = BufReader::new(server);
4976

5077
server.write_all(&[4, 1]).await?;
78+
server.write_all(&addr.port.to_be_bytes()).await?;
79+
server.write_all(&ip_octet).await?;
80+
if let Some(auth) = &self.auth {
81+
server.write_all(auth.as_bytes()).await?
82+
}
83+
server.write_all(b"\0").await?;
84+
if let Some(hostname) = hostname {
85+
server.write_all(hostname.as_bytes()).await?;
86+
server.write_all(b"\0").await?;
87+
}
88+
server.flush().await?;
89+
90+
if server.read_u8().await? != 0 {
91+
return Err("".into());
92+
}
93+
if server.read_u8().await? != 90 {
94+
return Err("".into());
95+
}
96+
let mut buf = [0; 6];
97+
server.read_exact(&mut buf).await?;
5198

5299
Ok(Box::new(server))
53100
}

src/utils/addr.rs

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::ParsedUri;
2-
use crate::{http_proxy, Connection, Error, PROXY};
2+
use crate::{http_proxy, Connection, DnsCacheState, Error, PROXY};
33

44
use std::{
55
fmt::{Display, Write},
@@ -150,16 +150,37 @@ impl HostName {
150150
_ => return Err("".into()),
151151
};
152152

153+
let proxy = PROXY.get().ok_or("")?;
154+
let mut uri: ParsedUri =
155+
Uri::from_str(proxy.config.doh_endpoint.as_ref().ok_or("")?)?.try_into()?;
156+
157+
let mut dns_cache = proxy.dns_cache.lock().await;
158+
if let Some(cache_content) = dns_cache.get(domain) {
159+
if qtype == QueryType::A {
160+
match cache_content.0 {
161+
DnsCacheState::Some(s) => return Ok(Self::V4(s)),
162+
DnsCacheState::Fail => return Err("".into()),
163+
DnsCacheState::None => (),
164+
}
165+
} else if qtype == QueryType::AAAA {
166+
match cache_content.1 {
167+
DnsCacheState::Some(s) => return Ok(Self::V6(s)),
168+
DnsCacheState::Fail => return Err("".into()),
169+
DnsCacheState::None => (),
170+
}
171+
} else {
172+
return Err("".into());
173+
}
174+
}
175+
drop(dns_cache);
176+
153177
let mut query = dns_parser::Builder::new_query(0xabcd, true);
154178
query.add_question(domain, false, qtype, QueryClass::IN);
155179
let query = query.build().map_err(|_| "")?;
156180

157181
let base64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
158182
let query = base64.encode(query);
159183

160-
let proxy = PROXY.get().ok_or("")?;
161-
let mut uri: ParsedUri =
162-
Uri::from_str(proxy.config.doh_endpoint.as_ref().ok_or("")?)?.try_into()?;
163184
if let Some(s) = uri.query.as_mut() {
164185
s.push_str(&format!("&dns={}", query));
165186
} else {
@@ -184,17 +205,33 @@ impl HostName {
184205
}
185206
let response_body = dns_parser::Packet::parse(&response_body)?;
186207

208+
let mut dns_cache = proxy.dns_cache.lock().await;
209+
let cache_content = dns_cache
210+
.entry(domain.to_string())
211+
.or_insert((DnsCacheState::None, DnsCacheState::None));
212+
187213
for answer in response_body.answers {
188214
if answer.cls != dns_parser::Class::IN {
189215
continue;
190216
}
191217
match answer.data {
192-
RData::A(addr) => return Ok(Self::V4(addr.0)),
193-
RData::AAAA(addr) => return Ok(Self::V6(addr.0)),
218+
RData::A(addr) => {
219+
cache_content.0 = DnsCacheState::Some(addr.0);
220+
return Ok(Self::V4(addr.0));
221+
}
222+
RData::AAAA(addr) => {
223+
cache_content.1 = DnsCacheState::Some(addr.0);
224+
return Ok(Self::V6(addr.0));
225+
}
194226
_ => continue,
195227
}
196228
}
197229

230+
if qtype == QueryType::A {
231+
cache_content.0 = DnsCacheState::Fail;
232+
} else if qtype == QueryType::AAAA {
233+
cache_content.1 = DnsCacheState::Fail;
234+
}
198235
Err("".into())
199236
}
200237
}

0 commit comments

Comments
 (0)