diff --git a/Cargo.lock b/Cargo.lock index 6654033d9..c9cf97a7d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1590,6 +1590,7 @@ dependencies = [ "chrono", "defguard_common", "defguard_core", + "ipnetwork", "sqlx", "thiserror 2.0.18", "tokio", diff --git a/crates/defguard_session_manager/Cargo.toml b/crates/defguard_session_manager/Cargo.toml index a28bc0ec1..d3bf8b077 100644 --- a/crates/defguard_session_manager/Cargo.toml +++ b/crates/defguard_session_manager/Cargo.toml @@ -18,3 +18,5 @@ thiserror.workspace = true tokio.workspace = true tracing.workspace = true +[dev-dependencies] +ipnetwork.workspace = true diff --git a/crates/defguard_session_manager/src/error.rs b/crates/defguard_session_manager/src/error.rs index 5242bf2a0..4f065ce6e 100644 --- a/crates/defguard_session_manager/src/error.rs +++ b/crates/defguard_session_manager/src/error.rs @@ -29,6 +29,8 @@ pub enum SessionManagerError { SessionDoesNotExistError(Id), #[error("Received out of order peer stats update")] PeerStatsUpdateOutOfOrderError, + #[error("Peer stats channel closed")] + PeerStatsChannelClosed, #[error("Failed to send session manager event: {0}")] SessionManagerEventError(Box>), #[error("Failed to send gateway manager event: {0}")] diff --git a/crates/defguard_session_manager/src/lib.rs b/crates/defguard_session_manager/src/lib.rs index cfe8505ee..48cce5a1d 100644 --- a/crates/defguard_session_manager/src/lib.rs +++ b/crates/defguard_session_manager/src/lib.rs @@ -19,7 +19,7 @@ use tokio::{ broadcast::Sender, mpsc::{UnboundedReceiver, UnboundedSender}, }, - time::{Duration, interval}, + time::{Duration, Interval, interval}, }; use tracing::{debug, error, info, trace, warn}; @@ -34,7 +34,12 @@ pub mod events; pub mod session_state; const MESSAGE_LIMIT: usize = 100; -const SESSION_UPDATE_INTERVAL: u64 = 60; +pub const SESSION_UPDATE_INTERVAL: u64 = 60; + +pub enum IterationOutcome { + ProcessedBatch(usize), + TickNoMessages, +} pub async fn run_session_manager( pool: PgPool, @@ -49,40 +54,62 @@ pub async fn run_session_manager( let mut session_manager = SessionManager::new(pool, session_manager_event_tx, gateway_tx); loop { - // receive next batch of peer stats messages - // if no message is received within `SESSION_UPDATE_INTERVAL` trigger session status refresh anyway - // to disconnect inactive sessions if necessary - let mut message_buffer: Vec = Vec::with_capacity(MESSAGE_LIMIT); - let _message_count = tokio::select! { - message_count = peer_stats_rx.recv_many(&mut message_buffer, MESSAGE_LIMIT) => message_count, - _ = session_update_timer.tick() => { - warn!("No wireguard peer stats updates received in last {SESSION_UPDATE_INTERVAL}. Triggering session status update to disconnect inactive clients."); - session_manager.update_inactive_session_status().await?; - - // skip to next iteration - continue; - } + run_session_manager_iteration( + &mut session_manager, + &mut peer_stats_rx, + &mut session_update_timer, + ) + .await?; + } +} - }; +pub async fn run_session_manager_iteration( + session_manager: &mut SessionManager, + peer_stats_rx: &mut UnboundedReceiver, + session_update_timer: &mut Interval, +) -> Result { + // receive next batch of peer stats messages + // if no message is received within `SESSION_UPDATE_INTERVAL` trigger session status refresh anyway + // to disconnect inactive sessions if necessary + let mut message_buffer: Vec = Vec::with_capacity(MESSAGE_LIMIT); + let message_count = tokio::select! { + biased; + message_count = peer_stats_rx.recv_many(&mut message_buffer, MESSAGE_LIMIT) => message_count, + _ = session_update_timer.tick() => { + warn!("No wireguard peer stats updates received in last {SESSION_UPDATE_INTERVAL}. Triggering session status update to disconnect inactive clients."); + session_manager.update_inactive_session_status().await?; + + return Ok(IterationOutcome::TickNoMessages); + } - // process received messages to update active sessions - session_manager - .process_message_batch(message_buffer) - .await?; + }; - // update inactive/disconnected sessions - session_manager.update_inactive_session_status().await?; + if message_count == 0 { + return Err(SessionManagerError::PeerStatsChannelClosed); } + + // process received messages to update active sessions + session_manager + .process_message_batch(message_buffer) + .await?; + + // update inactive/disconnected sessions + session_manager.update_inactive_session_status().await?; + + // reset timer to avoid it being immediately ready on next iteration + session_update_timer.reset(); + + Ok(IterationOutcome::ProcessedBatch(message_count)) } -struct SessionManager { +pub struct SessionManager { pool: PgPool, session_manager_event_tx: UnboundedSender, gateway_tx: Sender, } impl SessionManager { - fn new( + pub fn new( pool: PgPool, session_manager_event_tx: UnboundedSender, gateway_tx: Sender, diff --git a/crates/defguard_session_manager/tests/common/mod.rs b/crates/defguard_session_manager/tests/common/mod.rs new file mode 100644 index 000000000..b5ae4a65d --- /dev/null +++ b/crates/defguard_session_manager/tests/common/mod.rs @@ -0,0 +1,121 @@ +use std::net::{IpAddr, Ipv4Addr}; + +use defguard_common::db::Id; +use defguard_common::db::models::gateway::Gateway; +use defguard_common::db::models::{ + Device, DeviceType, User, WireguardNetwork, + device::WireguardNetworkDevice, + wireguard::{LocationMfaMode, ServiceLocationMode}, +}; +use defguard_common::messages::peer_stats_update::PeerStatsUpdate; +use defguard_session_manager::{SessionManager, events::SessionManagerEvent}; +use ipnetwork::IpNetwork; +use tokio::sync::{broadcast, mpsc}; + +pub(crate) struct SessionManagerHarness { + pub(crate) manager: SessionManager, + stats_tx: mpsc::UnboundedSender, + pub(crate) stats_rx: mpsc::UnboundedReceiver, + pub(crate) event_rx: mpsc::UnboundedReceiver, +} + +impl SessionManagerHarness { + pub(crate) fn new(pool: sqlx::PgPool) -> Self { + let (stats_tx, stats_rx) = mpsc::unbounded_channel(); + let (event_tx, event_rx) = mpsc::unbounded_channel(); + let (gateway_tx, _gateway_rx) = broadcast::channel(16); + let manager = SessionManager::new(pool, event_tx, gateway_tx); + + Self { + manager, + stats_tx, + stats_rx, + event_rx, + } + } + + pub(crate) fn send_stats(&self, update: PeerStatsUpdate) { + self.stats_tx + .send(update) + .expect("failed to send peer stats update"); + } +} + +pub(crate) async fn create_network(pool: &sqlx::PgPool) -> WireguardNetwork { + WireguardNetwork::new( + "TestNet".to_string(), + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()], + 51820, + "10.0.0.1".to_string(), + None, + 1420, + 0, + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0).unwrap()], + 25, + 300, + false, + false, + LocationMfaMode::Disabled, + ServiceLocationMode::Disabled, + ) + .save(pool) + .await + .expect("failed to create Wireguard network") +} + +pub(crate) async fn create_user(pool: &sqlx::PgPool) -> User { + User::new( + "session-test", + Some("pass123"), + "Tester", + "Session", + "session-test@example.com", + None, + ) + .save(pool) + .await + .expect("failed to create user") +} + +pub(crate) async fn create_device(pool: &sqlx::PgPool, user_id: Id) -> Device { + Device::new( + "session-test-device".to_string(), + "device-pubkey-test".to_string(), + user_id, + DeviceType::User, + None, + true, + ) + .save(pool) + .await + .expect("failed to create device") +} + +pub(crate) async fn attach_device_to_network(pool: &sqlx::PgPool, network_id: Id, device_id: Id) { + let network_device = WireguardNetworkDevice::new( + network_id, + device_id, + vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 10))], + ); + network_device + .insert(pool) + .await + .expect("failed to attach device to network"); +} + +pub(crate) async fn create_gateway( + pool: &sqlx::PgPool, + network_id: Id, + modified_by: Id, +) -> Gateway { + Gateway::new( + network_id, + "gateway-1".to_string(), + "127.0.0.1".to_string(), + 51820, + modified_by, + ) + .save(pool) + .await + .expect("failed to create gateway") +} diff --git a/crates/defguard_session_manager/tests/mod.rs b/crates/defguard_session_manager/tests/mod.rs new file mode 100644 index 000000000..9bf2fc86c --- /dev/null +++ b/crates/defguard_session_manager/tests/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod common; +pub(crate) mod session_manager; diff --git a/crates/defguard_session_manager/tests/session_manager/event_flow.rs b/crates/defguard_session_manager/tests/session_manager/event_flow.rs new file mode 100644 index 000000000..9dac43116 --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/event_flow.rs @@ -0,0 +1,66 @@ +use std::net::SocketAddr; + +use chrono::{TimeDelta, Utc}; +use defguard_common::db::setup_pool; +use defguard_common::messages::peer_stats_update::PeerStatsUpdate; +use defguard_session_manager::{ + SESSION_UPDATE_INTERVAL, events::SessionManagerEventType, run_session_manager_iteration, +}; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tokio::time::{Duration, interval}; + +use crate::common::{ + SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, + create_user, +}; + +#[sqlx::test] +async fn test_session_manager_emits_connected_event(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.id).await; + + let mut harness = SessionManagerHarness::new(pool); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + let update = PeerStatsUpdate { + location_id: network.id, + gateway_id: gateway.id, + device_pubkey: device.wireguard_pubkey.clone(), + collected_at: base_time, + endpoint, + upload: 100, + download: 200, + latest_handshake: base_time - TimeDelta::seconds(5), + }; + + harness.send_stats(update); + + let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); + let _ = run_session_manager_iteration( + &mut harness.manager, + &mut harness.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed"); + + let event = harness + .event_rx + .recv() + .await + .expect("session manager event channel closed"); + + assert!(matches!( + event.event, + SessionManagerEventType::ClientConnected + )); + assert_eq!(event.context.location.id, network.id); + assert_eq!(event.context.user.id, user.id); + assert_eq!(event.context.device.id, device.id); + assert_eq!(event.context.public_ip, endpoint.ip()); +} diff --git a/crates/defguard_session_manager/tests/session_manager/mod.rs b/crates/defguard_session_manager/tests/session_manager/mod.rs new file mode 100644 index 000000000..a16bea70c --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/mod.rs @@ -0,0 +1,3 @@ +mod event_flow; +mod sessions; +mod stats; diff --git a/crates/defguard_session_manager/tests/session_manager/sessions.rs b/crates/defguard_session_manager/tests/session_manager/sessions.rs new file mode 100644 index 000000000..e9892a76c --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/sessions.rs @@ -0,0 +1,53 @@ +use chrono::{TimeDelta, Utc}; +use defguard_common::db::models::vpn_client_session::{VpnClientSession, VpnClientSessionState}; +use defguard_common::db::setup_pool; +use defguard_common::messages::peer_stats_update::PeerStatsUpdate; +use defguard_session_manager::{SESSION_UPDATE_INTERVAL, run_session_manager_iteration}; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tokio::time::{Duration, interval}; + +use crate::common::{ + SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, + create_user, +}; + +#[sqlx::test] +async fn test_session_manager_creates_active_session(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.id).await; + + let mut harness = SessionManagerHarness::new(pool.clone()); + + let base_time = Utc::now().naive_utc(); + let update = PeerStatsUpdate { + location_id: network.id, + gateway_id: gateway.id, + device_pubkey: device.wireguard_pubkey.clone(), + collected_at: base_time, + endpoint: "203.0.113.10:51820".parse().unwrap(), + upload: 100, + download: 200, + latest_handshake: base_time - TimeDelta::seconds(5), + }; + + harness.send_stats(update); + + let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); + let _ = run_session_manager_iteration( + &mut harness.manager, + &mut harness.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed"); + + let session = VpnClientSession::try_get_active_session(&pool, network.id, device.id) + .await + .expect("failed to query active session") + .expect("expected active session"); + assert_eq!(session.state, VpnClientSessionState::Connected); +} diff --git a/crates/defguard_session_manager/tests/session_manager/stats.rs b/crates/defguard_session_manager/tests/session_manager/stats.rs new file mode 100644 index 000000000..7844ccbb9 --- /dev/null +++ b/crates/defguard_session_manager/tests/session_manager/stats.rs @@ -0,0 +1,85 @@ +use std::net::SocketAddr; + +use chrono::{TimeDelta, Utc}; +use defguard_common::db::models::vpn_session_stats::VpnSessionStats; +use defguard_common::db::setup_pool; +use defguard_common::messages::peer_stats_update::PeerStatsUpdate; +use defguard_session_manager::{SESSION_UPDATE_INTERVAL, run_session_manager_iteration}; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; +use tokio::time::{Duration, interval}; + +use crate::common::{ + SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network, + create_user, +}; + +#[sqlx::test] +async fn test_session_manager_updates_stats(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; + let network = create_network(&pool).await; + let user = create_user(&pool).await; + let device = create_device(&pool, user.id).await; + attach_device_to_network(&pool, network.id, device.id).await; + let gateway = create_gateway(&pool, network.id, user.id).await; + + let mut harness = SessionManagerHarness::new(pool.clone()); + + let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap(); + let base_time = Utc::now().naive_utc(); + let first_update = PeerStatsUpdate { + location_id: network.id, + gateway_id: gateway.id, + device_pubkey: device.wireguard_pubkey.clone(), + collected_at: base_time, + endpoint, + upload: 100, + download: 200, + latest_handshake: base_time - TimeDelta::seconds(5), + }; + + harness.send_stats(first_update); + + let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL)); + let _ = run_session_manager_iteration( + &mut harness.manager, + &mut harness.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed"); + + let first_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query session stats") + .expect("expected session stats"); + assert_eq!(first_stats.upload_diff, 0); + assert_eq!(first_stats.download_diff, 0); + + let second_update = PeerStatsUpdate { + location_id: network.id, + gateway_id: gateway.id, + device_pubkey: device.wireguard_pubkey.clone(), + collected_at: base_time + TimeDelta::seconds(10), + endpoint, + upload: 150, + download: 260, + latest_handshake: base_time + TimeDelta::seconds(10), + }; + + harness.send_stats(second_update); + + let _ = run_session_manager_iteration( + &mut harness.manager, + &mut harness.stats_rx, + &mut session_update_timer, + ) + .await + .expect("session manager iteration failed"); + + let second_stats = VpnSessionStats::fetch_latest_for_device(&pool, device.id, network.id) + .await + .expect("failed to query session stats") + .expect("expected session stats"); + assert_eq!(second_stats.upload_diff, 50); + assert_eq!(second_stats.download_diff, 60); +}