Skip to content

Commit 1cc0699

Browse files
Normalize port and ICMP id
1 parent 7ff08f6 commit 1cc0699

1 file changed

Lines changed: 40 additions & 10 deletions

File tree

src/worker.rs

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::cli::{Config, TimeoutAction};
1+
use crate::cli::{Config, SupportedProtocol, TimeoutAction};
22
use crate::net::params::MAX_WIRE_PAYLOAD;
33
use crate::net::payload::{handle_payload_result, send_payload, validate_payload};
44
use crate::net::sock_mgr::{SocketHandles, SocketManager};
@@ -23,6 +23,31 @@ fn as_uninit_mut(buf: &mut [u8]) -> &mut [std::mem::MaybeUninit<u8>] {
2323
}
2424
}
2525

26+
#[inline]
27+
fn normalize_client_sockaddr(
28+
src_sa: &SockAddr,
29+
listen_proto: SupportedProtocol,
30+
listen_port_id: u16,
31+
) -> Option<SocketAddr> {
32+
let src = src_sa.as_socket()?;
33+
if listen_proto == SupportedProtocol::ICMP {
34+
let normalized = match src {
35+
SocketAddr::V4(addr) => {
36+
SocketAddr::V4(std::net::SocketAddrV4::new(*addr.ip(), listen_port_id))
37+
}
38+
SocketAddr::V6(addr) => SocketAddr::V6(std::net::SocketAddrV6::new(
39+
*addr.ip(),
40+
listen_port_id,
41+
addr.flowinfo(),
42+
addr.scope_id(),
43+
)),
44+
};
45+
Some(normalized)
46+
} else {
47+
Some(src)
48+
}
49+
}
50+
2651
// Stack-resident, cacheline-aligned buffers
2752
#[repr(align(64))]
2853
struct AlignedBuf {
@@ -40,7 +65,6 @@ impl AlignedBuf {
4065
struct CachedClientState {
4166
c2u: bool,
4267
worker_id: usize,
43-
client_sa: Option<SockAddr>,
4468
dest_sock_type: Type,
4569
dest_sa: SockAddr,
4670
dest_port_id: u16,
@@ -60,7 +84,6 @@ impl CachedClientState {
6084
Self {
6185
c2u,
6286
worker_id,
63-
client_sa: handles.client_addr.map(|addr| SockAddr::from(addr)),
6487
dest_sock_type: handles.upstream_sock.r#type().unwrap_or(Type::RAW),
6588
dest_sa: SockAddr::from(handles.upstream_addr),
6689
dest_port_id: handles.upstream_addr.port(),
@@ -80,7 +103,6 @@ impl CachedClientState {
80103
Self {
81104
c2u,
82105
worker_id,
83-
client_sa: None,
84106
dest_sock_type: handles.client_sock.r#type().unwrap_or(Type::RAW),
85107
dest_sa,
86108
dest_port_id,
@@ -92,7 +114,6 @@ impl CachedClientState {
92114

93115
fn refresh_from_handles(&mut self, handles: &SocketHandles) {
94116
if self.c2u {
95-
self.client_sa = handles.client_addr.map(|addr| SockAddr::from(addr));
96117
self.dest_sock_type = handles.upstream_sock.r#type().unwrap_or(Type::RAW);
97118
self.dest_sa = SockAddr::from(handles.upstream_addr);
98119
self.dest_port_id = handles.upstream_addr.port();
@@ -359,7 +380,11 @@ pub fn run_client_to_upstream_thread(
359380
// First lock: publish client and connect the socket for fast path
360381
if !locked.load(AtomOrdering::Relaxed) {
361382
let src = option_or_log_continue!(
362-
src_sa.as_socket(),
383+
normalize_client_sockaddr(
384+
&src_sa,
385+
cfg.listen_proto,
386+
cfg.listen_port_id
387+
),
363388
log_warn_dir,
364389
worker_id,
365390
C2U,
@@ -378,13 +403,13 @@ pub fn run_client_to_upstream_thread(
378403

379404
// Signal to other threads that a client is currently being locked
380405
locked.store(true, AtomOrdering::Relaxed);
381-
cache.client_sa = Some(SockAddr::from(src));
406+
let src_sa_clean = SockAddr::from(src);
382407
let addr_opt = Some(src);
383408

384409
handles.client_connected = false;
385410
if cfg.debug_no_connect {
386411
log_info!("Locked to single client {} (not connected)", src);
387-
} else if let Err(e) = handles.client_sock.connect(&src_sa) {
412+
} else if let Err(e) = handles.client_sock.connect(&src_sa_clean) {
388413
log_warn!("connect client_sock to {} failed: {}", src, e);
389414
log_info!("Locked to single client {} (not connected)", src);
390415
} else {
@@ -416,7 +441,7 @@ pub fn run_client_to_upstream_thread(
416441
let _ = mgr.set_client_sock_connected(
417442
addr_opt,
418443
handles.client_connected,
419-
&src_sa,
444+
&src_sa_clean,
420445
0,
421446
);
422447
}
@@ -456,7 +481,12 @@ pub fn run_client_to_upstream_thread(
456481
&cache.dest_sa,
457482
None,
458483
);
459-
} else if Some(src_sa) == cache.client_sa {
484+
} else if normalize_client_sockaddr(
485+
&src_sa,
486+
cfg.listen_proto,
487+
cfg.listen_port_id,
488+
) == handles.client_addr
489+
{
460490
// Only forward packets from the locked client (recv_from may still deliver before connect succeeds)
461491
let validated = result_or_log_continue!(
462492
validate_payload(C2U, cfg, stats, &buf.data[..len], cache.recv_port_id),

0 commit comments

Comments
 (0)