From 158b496ee174cd0a71a182807b82501ce1393a4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 24 Feb 2026 23:07:03 +0100 Subject: [PATCH 1/4] setup session manager testing harness and add basic test cases --- Cargo.lock | 1 + crates/defguard_session_manager/Cargo.toml | 2 + .../tests/common/mod.rs | 146 ++++++++++++++++++ crates/defguard_session_manager/tests/mod.rs | 2 + .../tests/session_manager/event_flow.rs | 50 ++++++ .../tests/session_manager/mod.rs | 3 + .../tests/session_manager/sessions.rs | 61 ++++++++ .../tests/session_manager/stats.rs | 86 +++++++++++ 8 files changed, 351 insertions(+) create mode 100644 crates/defguard_session_manager/tests/common/mod.rs create mode 100644 crates/defguard_session_manager/tests/mod.rs create mode 100644 crates/defguard_session_manager/tests/session_manager/event_flow.rs create mode 100644 crates/defguard_session_manager/tests/session_manager/mod.rs create mode 100644 crates/defguard_session_manager/tests/session_manager/sessions.rs create mode 100644 crates/defguard_session_manager/tests/session_manager/stats.rs diff --git a/Cargo.lock b/Cargo.lock index 6654033d9..c9cf97a7d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1590,6 +1590,7 @@ dependencies = [ "chrono", "defguard_common", "defguard_core", + "ipnetwork", "sqlx", "thiserror 2.0.18", "tokio", diff --git a/crates/defguard_session_manager/Cargo.toml b/crates/defguard_session_manager/Cargo.toml index a28bc0ec1..d3bf8b077 100644 --- a/crates/defguard_session_manager/Cargo.toml +++ b/crates/defguard_session_manager/Cargo.toml @@ -18,3 +18,5 @@ thiserror.workspace = true tokio.workspace = true tracing.workspace = true +[dev-dependencies] +ipnetwork.workspace = true diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs new file mode 100644 index 000000000..0783404f9 --- /dev/null +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -0,0 +1,146 @@ +use std::net::{IpAddr, Ipv4Addr}; + +use defguard_common::db::models::gateway::Gateway; +use defguard_common::db::models::{ + device::WireguardNetworkDevice, wireguard::{LocationMfaMode, ServiceLocationMode}, Device, + DeviceType, User, WireguardNetwork, +}; +use defguard_common::db::Id; +use defguard_common::messages::peer_stats_update::PeerStatsUpdate; +use defguard_core::grpc::GatewayEvent; +use defguard_session_manager::{events::SessionManagerEvent, run_session_manager}; +use ipnetwork::IpNetwork; +use tokio::{ + sync::{broadcast, mpsc}, + task::JoinHandle, + time::{Duration, timeout}, +}; + +const EVENT_TIMEOUT: Duration = Duration::from_secs(2); + +pub(crate) struct SessionManagerHarness { + stats_tx: mpsc::UnboundedSender, + event_rx: mpsc::UnboundedReceiver, + #[allow(dead_code)] + gateway_rx: broadcast::Receiver, + handle: JoinHandle>, +} + +impl Drop for SessionManagerHarness { + fn drop(&mut self) { + self.handle.abort(); + } +} + +impl SessionManagerHarness { + pub(crate) fn send_stats(&self, update: PeerStatsUpdate) { + self.stats_tx + .send(update) + .expect("failed to send peer stats update"); + } + + pub(crate) async fn recv_event(&mut self) -> SessionManagerEvent { + timeout(EVENT_TIMEOUT, self.event_rx.recv()) + .await + .expect("timed out waiting for session manager event") + .expect("session manager event channel closed") + } +} + +pub(crate) fn start_session_manager(pool: sqlx::PgPool) -> SessionManagerHarness { + let (stats_tx, stats_rx) = mpsc::unbounded_channel(); + let (event_tx, event_rx) = mpsc::unbounded_channel(); + let (gateway_tx, gateway_rx) = broadcast::channel(16); + + let handle = tokio::spawn(run_session_manager(pool, stats_rx, event_tx, gateway_tx)); + + SessionManagerHarness { + stats_tx, + event_rx, + gateway_rx, + handle, + } +} + +pub(crate) async fn create_network(pool: &sqlx::PgPool) -> WireguardNetwork { + WireguardNetwork::new( + "TestNet".to_string(), + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()], + 51820, + "10.0.0.1".to_string(), + None, + 1420, + 0, + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0).unwrap()], + 25, + 300, + false, + false, + LocationMfaMode::Disabled, + ServiceLocationMode::Disabled, + ) + .save(pool) + .await + .expect("failed to create Wireguard network") +} + +pub(crate) async fn create_user(pool: &sqlx::PgPool) -> User { + User::new( + "session-test", + Some("pass123"), + "Tester", + "Session", + "session-test@example.com", + None, + ) + .save(pool) + .await + .expect("failed to create user") +} + +pub(crate) async fn create_device(pool: &sqlx::PgPool, user_id: Id) -> Device { + Device::new( + "session-test-device".to_string(), + "device-pubkey-test".to_string(), + user_id, + DeviceType::User, + None, + true, + ) + .save(pool) + .await + .expect("failed to create device") +} + +pub(crate) async fn attach_device_to_network( + pool: &sqlx::PgPool, + network_id: Id, + device_id: Id, +) { + let network_device = WireguardNetworkDevice::new( + network_id, + device_id, + vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 10))], + ); + network_device + .insert(pool) + .await + .expect("failed to attach device to network"); +} + +pub(crate) async fn create_gateway( + pool: &sqlx::PgPool, + network_id: Id, + modified_by: Id, +) -> Gateway { + Gateway::new( + network_id, + "gateway-1".to_string(), + "127.0.0.1".to_string(), + 51820, + modified_by, + ) + .save(pool) + .await + .expect("failed to create gateway") +} diff --git a/crates/defguard_session_manager/tests/mod.rs b/crates/defguard_session_manager/tests/mod.rs new file mode 100644 index 000000000..9bf2fc86c --- /dev/null +++ b/crates/defguard_session_manager/tests/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod common; +pub(crate) mod session_manager; diff --git a/crates/defguard_session_manager/tests/session_manager/event_flow.rs b/crates/defguard_session_manager/tests/session_manager/event_flow.rs new file mode 100644 index 000000000..9ad914561 --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/event_flow.rs @@ -0,0 +1,50 @@ +use std::net::SocketAddr; + +use chrono::{TimeDelta, Utc}; +use defguard_common::db::setup_pool; +use defguard_common::messages::peer_stats_update::PeerStatsUpdate; +use defguard_session_manager::events::SessionManagerEventType; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; + +use crate::common::{ + attach_device_to_network, create_device, create_gateway, create_network, create_user, + start_session_manager, +}; + +#[sqlx::test] +async fn test_session_manager_emits_connected_event(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.id).await; + + let mut manager = start_session_manager(pool); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + let update = PeerStatsUpdate { + location_id: network.id, + gateway_id: gateway.id, + device_pubkey: device.wireguard_pubkey.clone(), + collected_at: base_time, + endpoint, + upload: 100, + download: 200, + latest_handshake: base_time - TimeDelta::seconds(5), + }; + + manager.send_stats(update); + + let event = manager.recv_event().await; + + assert!(matches!( + event.event, + SessionManagerEventType::ClientConnected + )); + assert_eq!(event.context.location.id, network.id); + assert_eq!(event.context.user.id, user.id); + assert_eq!(event.context.device.id, device.id); + assert_eq!(event.context.public_ip, endpoint.ip()); +} diff --git a/crates/defguard_session_manager/tests/session_manager/mod.rs b/crates/defguard_session_manager/tests/session_manager/mod.rs new file mode 100644 index 000000000..a16bea70c --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/mod.rs @@ -0,0 +1,3 @@ +mod event_flow; +mod sessions; +mod stats; diff --git a/crates/defguard_session_manager/tests/session_manager/sessions.rs b/crates/defguard_session_manager/tests/session_manager/sessions.rs new file mode 100644 index 000000000..7561aa4db --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/sessions.rs @@ -0,0 +1,61 @@ +use chrono::{TimeDelta, Utc}; +use defguard_common::db::models::vpn_client_session::{VpnClientSession, VpnClientSessionState}; +use defguard_common::db::setup_pool; +use defguard_common::messages::peer_stats_update::PeerStatsUpdate; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tokio::time::{Duration, timeout}; + +use crate::common::{ + attach_device_to_network, create_device, create_gateway, create_network, create_user, + start_session_manager, +}; + +const DB_WAIT_TIMEOUT: Duration = Duration::from_secs(2); + +async fn wait_for_active_session( + pool: &sqlx::PgPool, + location_id: defguard_common::db::Id, + device_id: defguard_common::db::Id, +) -> VpnClientSession { + timeout(DB_WAIT_TIMEOUT, async { + loop { + if let Ok(Some(session)) = + VpnClientSession::try_get_active_session(pool, location_id, device_id).await + { + return session; + } + tokio::time::sleep(Duration::from_millis(25)).await; + } + }) + .await + .expect("timed out waiting for active session") +} + +#[sqlx::test] +async fn test_session_manager_creates_active_session(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.id).await; + + let manager = start_session_manager(pool.clone()); + + let base_time = Utc::now().naive_utc(); + let update = PeerStatsUpdate { + location_id: network.id, + gateway_id: gateway.id, + device_pubkey: device.wireguard_pubkey.clone(), + collected_at: base_time, + endpoint: "203.0.113.10:51820".parse().unwrap(), + upload: 100, + download: 200, + latest_handshake: base_time - TimeDelta::seconds(5), + }; + + manager.send_stats(update); + + let session = wait_for_active_session(&pool, network.id, device.id).await; + assert_eq!(session.state, VpnClientSessionState::Connected); +} diff --git a/crates/defguard_session_manager/tests/session_manager/stats.rs b/crates/defguard_session_manager/tests/session_manager/stats.rs new file mode 100644 index 000000000..f601d09fb --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/stats.rs @@ -0,0 +1,86 @@ +use chrono::{TimeDelta, Utc}; +use defguard_common::db::models::vpn_session_stats::VpnSessionStats; +use defguard_common::db::setup_pool; +use defguard_common::messages::peer_stats_update::PeerStatsUpdate; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tokio::time::{Duration, timeout}; + +use crate::common::{ + attach_device_to_network, create_device, create_gateway, create_network, create_user, + start_session_manager, +}; + +const DB_WAIT_TIMEOUT: Duration = Duration::from_secs(2); + +async fn wait_for_latest_stats( + pool: &sqlx::PgPool, + device_id: defguard_common::db::Id, + location_id: defguard_common::db::Id, + expected_upload: i64, + expected_download: i64, +) -> VpnSessionStats { + timeout(DB_WAIT_TIMEOUT, async { + loop { + if let Ok(Some(stats)) = + VpnSessionStats::fetch_latest_for_device(pool, device_id, location_id).await + { + if stats.total_upload == expected_upload + && stats.total_download == expected_download + { + return stats; + } + } + tokio::time::sleep(Duration::from_millis(25)).await; + } + }) + .await + .expect("timed out waiting for latest stats") +} + +#[sqlx::test] +async fn test_session_manager_updates_stats(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.id).await; + + let manager = start_session_manager(pool.clone()); + + let endpoint: std::net::SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + let first_update = PeerStatsUpdate { + location_id: network.id, + gateway_id: gateway.id, + device_pubkey: device.wireguard_pubkey.clone(), + collected_at: base_time, + endpoint, + upload: 100, + download: 200, + latest_handshake: base_time - TimeDelta::seconds(5), + }; + + manager.send_stats(first_update); + + let first_stats = wait_for_latest_stats(&pool, device.id, network.id, 100, 200).await; + assert_eq!(first_stats.upload_diff, 0); + assert_eq!(first_stats.download_diff, 0); + + let second_update = PeerStatsUpdate { + location_id: network.id, + gateway_id: gateway.id, + device_pubkey: device.wireguard_pubkey.clone(), + collected_at: base_time + TimeDelta::seconds(10), + endpoint, + upload: 150, + download: 260, + latest_handshake: base_time + TimeDelta::seconds(10), + }; + + manager.send_stats(second_update); + + let second_stats = wait_for_latest_stats(&pool, device.id, network.id, 150, 260).await; + assert_eq!(second_stats.upload_diff, 50); + assert_eq!(second_stats.download_diff, 60); +} From b8252eff737171c2d10cd4274ea327a5bc8cd681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Tue, 24 Feb 2026 23:45:49 +0100 Subject: [PATCH 2/4] refactor test approach --- crates/defguard_session_manager/src/lib.rs | 67 ++++++++++------- .../tests/common/mod.rs | 58 +++++---------- .../tests/session_manager/event_flow.rs | 28 ++++++-- .../tests/session_manager/sessions.rs | 46 +++++------- .../tests/session_manager/stats.rs | 71 +++++++++---------- 5 files changed, 136 insertions(+), 134 deletions(-) diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index cfe8505ee..fef06e9e8 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -19,7 +19,7 @@ use tokio::{ broadcast::Sender, mpsc::{UnboundedReceiver, UnboundedSender}, }, - time::{Duration, interval}, + time::{Duration, Interval, interval}, }; use tracing::{debug, error, info, trace, warn}; @@ -34,7 +34,12 @@ pub mod events; pub mod session_state; const MESSAGE_LIMIT: usize = 100; -const SESSION_UPDATE_INTERVAL: u64 = 60; +pub const SESSION_UPDATE_INTERVAL: u64 = 60; + +pub enum IterationOutcome { + ProcessedBatch(usize), + TickNoMessages, +} pub async fn run_session_manager( pool: PgPool, @@ -49,40 +54,52 @@ pub async fn run_session_manager( let mut session_manager = SessionManager::new(pool, session_manager_event_tx, gateway_tx); loop { - // receive next batch of peer stats messages - // if no message is received within `SESSION_UPDATE_INTERVAL` trigger session status refresh anyway - // to disconnect inactive sessions if necessary - let mut message_buffer: Vec = Vec::with_capacity(MESSAGE_LIMIT); - let _message_count = tokio::select! { - message_count = peer_stats_rx.recv_many(&mut message_buffer, MESSAGE_LIMIT) => message_count, - _ = session_update_timer.tick() => { - warn!("No wireguard peer stats updates received in last {SESSION_UPDATE_INTERVAL}. Triggering session status update to disconnect inactive clients."); - session_manager.update_inactive_session_status().await?; - - // skip to next iteration - continue; - } + run_session_manager_iteration( + &mut session_manager, + &mut peer_stats_rx, + &mut session_update_timer, + ) + .await?; + } +} - }; +pub async fn run_session_manager_iteration( + session_manager: &mut SessionManager, + peer_stats_rx: &mut UnboundedReceiver, + session_update_timer: &mut Interval, +) -> Result { + // receive next batch of peer stats messages + // if no message is received within `SESSION_UPDATE_INTERVAL` trigger session status refresh anyway + // to disconnect inactive sessions if necessary + let mut message_buffer: Vec = Vec::with_capacity(MESSAGE_LIMIT); + let message_count = tokio::select! { + message_count = peer_stats_rx.recv_many(&mut message_buffer, MESSAGE_LIMIT) => message_count, + _ = session_update_timer.tick() => { + warn!("No wireguard peer stats updates received in last {SESSION_UPDATE_INTERVAL}. Triggering session status update to disconnect inactive clients."); + session_manager.update_inactive_session_status().await?; + + return Ok(IterationOutcome::TickNoMessages); + } - // process received messages to update active sessions - session_manager - .process_message_batch(message_buffer) - .await?; + }; - // update inactive/disconnected sessions - session_manager.update_inactive_session_status().await?; - } + // process received messages to update active sessions + session_manager.process_message_batch(message_buffer).await?; + + // update inactive/disconnected sessions + session_manager.update_inactive_session_status().await?; + + Ok(IterationOutcome::ProcessedBatch(message_count)) } -struct SessionManager { +pub struct SessionManager { pool: PgPool, session_manager_event_tx: UnboundedSender, gateway_tx: Sender, } impl SessionManager { - fn new( + pub fn new( pool: PgPool, session_manager_event_tx: UnboundedSender, gateway_tx: Sender, diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index 0783404f9..691e7778a 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -7,59 +7,37 @@ use defguard_common::db::models::{ }; use defguard_common::db::Id; use defguard_common::messages::peer_stats_update::PeerStatsUpdate; -use defguard_core::grpc::GatewayEvent; -use defguard_session_manager::{events::SessionManagerEvent, run_session_manager}; +use defguard_session_manager::{SessionManager, events::SessionManagerEvent}; use ipnetwork::IpNetwork; -use tokio::{ - sync::{broadcast, mpsc}, - task::JoinHandle, - time::{Duration, timeout}, -}; - -const EVENT_TIMEOUT: Duration = Duration::from_secs(2); +use tokio::sync::{broadcast, mpsc}; pub(crate) struct SessionManagerHarness { + pub(crate) manager: SessionManager, stats_tx: mpsc::UnboundedSender, - event_rx: mpsc::UnboundedReceiver, - #[allow(dead_code)] - gateway_rx: broadcast::Receiver, - handle: JoinHandle>, + pub(crate) stats_rx: mpsc::UnboundedReceiver, + pub(crate) event_rx: mpsc::UnboundedReceiver, } -impl Drop for SessionManagerHarness { - fn drop(&mut self) { - self.handle.abort(); +impl SessionManagerHarness { + pub(crate) fn new(pool: sqlx::PgPool) -> Self { + let (stats_tx, stats_rx) = mpsc::unbounded_channel(); + let (event_tx, event_rx) = mpsc::unbounded_channel(); + let (gateway_tx, _gateway_rx) = broadcast::channel(16); + let manager = SessionManager::new(pool, event_tx, gateway_tx); + + Self { + manager, + stats_tx, + stats_rx, + event_rx, + } } -} -impl SessionManagerHarness { pub(crate) fn send_stats(&self, update: PeerStatsUpdate) { self.stats_tx .send(update) .expect("failed to send peer stats update"); } - - pub(crate) async fn recv_event(&mut self) -> SessionManagerEvent { - timeout(EVENT_TIMEOUT, self.event_rx.recv()) - .await - .expect("timed out waiting for session manager event") - .expect("session manager event channel closed") - } -} - -pub(crate) fn start_session_manager(pool: sqlx::PgPool) -> SessionManagerHarness { - let (stats_tx, stats_rx) = mpsc::unbounded_channel(); - let (event_tx, event_rx) = mpsc::unbounded_channel(); - let (gateway_tx, gateway_rx) = broadcast::channel(16); - - let handle = tokio::spawn(run_session_manager(pool, stats_rx, event_tx, gateway_tx)); - - SessionManagerHarness { - stats_tx, - event_rx, - gateway_rx, - handle, - } } pub(crate) async fn create_network(pool: &sqlx::PgPool) -> WireguardNetwork { diff --git a/crates/defguard_session_manager/tests/session_manager/event_flow.rs b/crates/defguard_session_manager/tests/session_manager/event_flow.rs index 9ad914561..9dac43116 100644 --- a/crates/defguard_session_manager/tests/session_manager/event_flow.rs +++ b/crates/defguard_session_manager/tests/session_manager/event_flow.rs @@ -3,12 +3,15 @@ use std::net::SocketAddr; use chrono::{TimeDelta, Utc}; use defguard_common::db::setup_pool; use defguard_common::messages::peer_stats_update::PeerStatsUpdate; -use defguard_session_manager::events::SessionManagerEventType; +use defguard_session_manager::{ + SESSION_UPDATE_INTERVAL, events::SessionManagerEventType, run_session_manager_iteration, +}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tokio::time::{Duration, interval}; use crate::common::{ - attach_device_to_network, create_device, create_gateway, create_network, create_user, - start_session_manager, + SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, + create_user, }; #[sqlx::test] @@ -20,7 +23,7 @@ async fn test_session_manager_emits_connected_event(_: PgPoolOptions, options: P attach_device_to_network(&pool, network.id, device.id).await; let gateway = create_gateway(&pool, network.id, user.id).await; - let mut manager = start_session_manager(pool); + let mut harness = SessionManagerHarness::new(pool); let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); @@ -35,9 +38,22 @@ async fn test_session_manager_emits_connected_event(_: PgPoolOptions, options: P latest_handshake: base_time - TimeDelta::seconds(5), }; - manager.send_stats(update); + harness.send_stats(update); + + let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); + let _ = run_session_manager_iteration( + &mut harness.manager, + &mut harness.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed"); - let event = manager.recv_event().await; + let event = harness + .event_rx + .recv() + .await + .expect("session manager event channel closed"); assert!(matches!( event.event, diff --git a/crates/defguard_session_manager/tests/session_manager/sessions.rs b/crates/defguard_session_manager/tests/session_manager/sessions.rs index 7561aa4db..e9892a76c 100644 --- a/crates/defguard_session_manager/tests/session_manager/sessions.rs +++ b/crates/defguard_session_manager/tests/session_manager/sessions.rs @@ -2,35 +2,15 @@ use chrono::{TimeDelta, Utc}; use defguard_common::db::models::vpn_client_session::{VpnClientSession, VpnClientSessionState}; use defguard_common::db::setup_pool; use defguard_common::messages::peer_stats_update::PeerStatsUpdate; +use defguard_session_manager::{SESSION_UPDATE_INTERVAL, run_session_manager_iteration}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::time::{Duration, timeout}; +use tokio::time::{Duration, interval}; use crate::common::{ - attach_device_to_network, create_device, create_gateway, create_network, create_user, - start_session_manager, + SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, + create_user, }; -const DB_WAIT_TIMEOUT: Duration = Duration::from_secs(2); - -async fn wait_for_active_session( - pool: &sqlx::PgPool, - location_id: defguard_common::db::Id, - device_id: defguard_common::db::Id, -) -> VpnClientSession { - timeout(DB_WAIT_TIMEOUT, async { - loop { - if let Ok(Some(session)) = - VpnClientSession::try_get_active_session(pool, location_id, device_id).await - { - return session; - } - tokio::time::sleep(Duration::from_millis(25)).await; - } - }) - .await - .expect("timed out waiting for active session") -} - #[sqlx::test] async fn test_session_manager_creates_active_session(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; @@ -40,7 +20,7 @@ async fn test_session_manager_creates_active_session(_: PgPoolOptions, options: attach_device_to_network(&pool, network.id, device.id).await; let gateway = create_gateway(&pool, network.id, user.id).await; - let manager = start_session_manager(pool.clone()); + let mut harness = SessionManagerHarness::new(pool.clone()); let base_time = Utc::now().naive_utc(); let update = PeerStatsUpdate { @@ -54,8 +34,20 @@ async fn test_session_manager_creates_active_session(_: PgPoolOptions, options: latest_handshake: base_time - TimeDelta::seconds(5), }; - manager.send_stats(update); + harness.send_stats(update); + + let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); + let _ = run_session_manager_iteration( + &mut harness.manager, + &mut harness.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed"); - let session = wait_for_active_session(&pool, network.id, device.id).await; + let session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); assert_eq!(session.state, VpnClientSessionState::Connected); } diff --git a/crates/defguard_session_manager/tests/session_manager/stats.rs b/crates/defguard_session_manager/tests/session_manager/stats.rs index f601d09fb..7844ccbb9 100644 --- a/crates/defguard_session_manager/tests/session_manager/stats.rs +++ b/crates/defguard_session_manager/tests/session_manager/stats.rs @@ -1,42 +1,18 @@ +use std::net::SocketAddr; + use chrono::{TimeDelta, Utc}; use defguard_common::db::models::vpn_session_stats::VpnSessionStats; use defguard_common::db::setup_pool; use defguard_common::messages::peer_stats_update::PeerStatsUpdate; +use defguard_session_manager::{SESSION_UPDATE_INTERVAL, run_session_manager_iteration}; use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; -use tokio::time::{Duration, timeout}; +use tokio::time::{Duration, interval}; use crate::common::{ - attach_device_to_network, create_device, create_gateway, create_network, create_user, - start_session_manager, + SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, + create_user, }; -const DB_WAIT_TIMEOUT: Duration = Duration::from_secs(2); - -async fn wait_for_latest_stats( - pool: &sqlx::PgPool, - device_id: defguard_common::db::Id, - location_id: defguard_common::db::Id, - expected_upload: i64, - expected_download: i64, -) -> VpnSessionStats { - timeout(DB_WAIT_TIMEOUT, async { - loop { - if let Ok(Some(stats)) = - VpnSessionStats::fetch_latest_for_device(pool, device_id, location_id).await - { - if stats.total_upload == expected_upload - && stats.total_download == expected_download - { - return stats; - } - } - tokio::time::sleep(Duration::from_millis(25)).await; - } - }) - .await - .expect("timed out waiting for latest stats") -} - #[sqlx::test] async fn test_session_manager_updates_stats(_: PgPoolOptions, options: PgConnectOptions) { let pool = setup_pool(options).await; @@ -46,9 +22,9 @@ async fn test_session_manager_updates_stats(_: PgPoolOptions, options: PgConnect attach_device_to_network(&pool, network.id, device.id).await; let gateway = create_gateway(&pool, network.id, user.id).await; - let manager = start_session_manager(pool.clone()); + let mut harness = SessionManagerHarness::new(pool.clone()); - let endpoint: std::net::SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); let base_time = Utc::now().naive_utc(); let first_update = PeerStatsUpdate { location_id: network.id, @@ -61,9 +37,21 @@ async fn test_session_manager_updates_stats(_: PgPoolOptions, options: PgConnect latest_handshake: base_time - TimeDelta::seconds(5), }; - manager.send_stats(first_update); + harness.send_stats(first_update); - let first_stats = wait_for_latest_stats(&pool, device.id, network.id, 100, 200).await; + let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); + let _ = run_session_manager_iteration( + &mut harness.manager, + &mut harness.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed"); + + let first_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query session stats") + .expect("expected session stats"); assert_eq!(first_stats.upload_diff, 0); assert_eq!(first_stats.download_diff, 0); @@ -78,9 +66,20 @@ async fn test_session_manager_updates_stats(_: PgPoolOptions, options: PgConnect latest_handshake: base_time + TimeDelta::seconds(10), }; - manager.send_stats(second_update); + harness.send_stats(second_update); + + let _ = run_session_manager_iteration( + &mut harness.manager, + &mut harness.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed"); - let second_stats = wait_for_latest_stats(&pool, device.id, network.id, 150, 260).await; + let second_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query session stats") + .expect("expected session stats"); assert_eq!(second_stats.upload_diff, 50); assert_eq!(second_stats.download_diff, 60); } From f9a045c05f5f5ad2472167faf36967bc81a79491 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20W=C3=B3jcik?= Date: Wed, 25 Feb 2026 10:32:42 +0100 Subject: [PATCH 3/4] avoid race condition with timer --- crates/defguard_session_manager/src/error.rs | 2 ++ crates/defguard_session_manager/src/lib.rs | 11 ++++++++++- crates/defguard_session_manager/tests/common/mod.rs | 13 +++++-------- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 5242bf2a0..4f065ce6e 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -29,6 +29,8 @@ pub enum SessionManagerError { SessionDoesNotExistError(Id), #[error("Received out of order peer stats update")] PeerStatsUpdateOutOfOrderError, + #[error("Peer stats channel closed")] + PeerStatsChannelClosed, #[error("Failed to send session manager event: {0}")] SessionManagerEventError(Box>), #[error("Failed to send gateway manager event: {0}")] diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index fef06e9e8..92f6d4365 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -73,6 +73,7 @@ pub async fn run_session_manager_iteration( // to disconnect inactive sessions if necessary let mut message_buffer: Vec = Vec::with_capacity(MESSAGE_LIMIT); let message_count = tokio::select! { + biased; message_count = peer_stats_rx.recv_many(&mut message_buffer, MESSAGE_LIMIT) => message_count, _ = session_update_timer.tick() => { warn!("No wireguard peer stats updates received in last {SESSION_UPDATE_INTERVAL}. Triggering session status update to disconnect inactive clients."); @@ -83,12 +84,20 @@ pub async fn run_session_manager_iteration( }; + if message_count == 0 { + return Err(SessionManagerError::PeerStatsChannelClosed); + } + // process received messages to update active sessions - session_manager.process_message_batch(message_buffer).await?; + session_manager + .process_message_batch(message_buffer) + .await?; // update inactive/disconnected sessions session_manager.update_inactive_session_status().await?; + session_update_timer.reset(); + Ok(IterationOutcome::ProcessedBatch(message_count)) } diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs index 691e7778a..b5ae4a65d 100644 --- a/crates/defguard_session_manager/tests/common/mod.rs +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -1,11 +1,12 @@ use std::net::{IpAddr, Ipv4Addr}; +use defguard_common::db::Id; use defguard_common::db::models::gateway::Gateway; use defguard_common::db::models::{ - device::WireguardNetworkDevice, wireguard::{LocationMfaMode, ServiceLocationMode}, Device, - DeviceType, User, WireguardNetwork, + Device, DeviceType, User, WireguardNetwork, + device::WireguardNetworkDevice, + wireguard::{LocationMfaMode, ServiceLocationMode}, }; -use defguard_common::db::Id; use defguard_common::messages::peer_stats_update::PeerStatsUpdate; use defguard_session_manager::{SessionManager, events::SessionManagerEvent}; use ipnetwork::IpNetwork; @@ -90,11 +91,7 @@ pub(crate) async fn create_device(pool: &sqlx::PgPool, user_id: Id) -> Device Date: Wed, 25 Feb 2026 10:37:57 +0100 Subject: [PATCH 4/4] add comment --- crates/defguard_session_manager/src/lib.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index 92f6d4365..48cce5a1d 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -96,6 +96,7 @@ pub async fn run_session_manager_iteration( // update inactive/disconnected sessions session_manager.update_inactive_session_status().await?; + // reset timer to avoid it being immediately ready on next iteration session_update_timer.reset(); Ok(IterationOutcome::ProcessedBatch(message_count))