Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/defguard_session_manager/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,5 @@ thiserror.workspace = true
tokio.workspace = true
tracing.workspace = true

[dev-dependencies]
ipnetwork.workspace = true
2 changes: 2 additions & 0 deletions crates/defguard_session_manager/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub enum SessionManagerError {
SessionDoesNotExistError(Id),
#[error("Received out of order peer stats update")]
PeerStatsUpdateOutOfOrderError,
#[error("Peer stats channel closed")]
PeerStatsChannelClosed,
#[error("Failed to send session manager event: {0}")]
SessionManagerEventError(Box<SendError<SessionManagerEvent>>),
#[error("Failed to send gateway manager event: {0}")]
Expand Down
75 changes: 51 additions & 24 deletions crates/defguard_session_manager/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use tokio::{
broadcast::Sender,
mpsc::{UnboundedReceiver, UnboundedSender},
},
time::{Duration, interval},
time::{Duration, Interval, interval},
};
use tracing::{debug, error, info, trace, warn};

Expand All @@ -34,7 +34,12 @@ pub mod events;
pub mod session_state;

const MESSAGE_LIMIT: usize = 100;
const SESSION_UPDATE_INTERVAL: u64 = 60;
pub const SESSION_UPDATE_INTERVAL: u64 = 60;

pub enum IterationOutcome {
ProcessedBatch(usize),
TickNoMessages,
}

pub async fn run_session_manager(
pool: PgPool,
Expand All @@ -49,40 +54,62 @@ pub async fn run_session_manager(
let mut session_manager = SessionManager::new(pool, session_manager_event_tx, gateway_tx);

loop {
// receive next batch of peer stats messages
// if no message is received within `SESSION_UPDATE_INTERVAL` trigger session status refresh anyway
// to disconnect inactive sessions if necessary
let mut message_buffer: Vec<PeerStatsUpdate> = Vec::with_capacity(MESSAGE_LIMIT);
let _message_count = tokio::select! {
message_count = peer_stats_rx.recv_many(&mut message_buffer, MESSAGE_LIMIT) => message_count,
_ = session_update_timer.tick() => {
warn!("No wireguard peer stats updates received in last {SESSION_UPDATE_INTERVAL}. Triggering session status update to disconnect inactive clients.");
session_manager.update_inactive_session_status().await?;

// skip to next iteration
continue;
}
run_session_manager_iteration(
&mut session_manager,
&mut peer_stats_rx,
&mut session_update_timer,
)
.await?;
}
}

};
pub async fn run_session_manager_iteration(
session_manager: &mut SessionManager,
peer_stats_rx: &mut UnboundedReceiver<PeerStatsUpdate>,
session_update_timer: &mut Interval,
) -> Result<IterationOutcome, SessionManagerError> {
// receive next batch of peer stats messages
// if no message is received within `SESSION_UPDATE_INTERVAL` trigger session status refresh anyway
// to disconnect inactive sessions if necessary
let mut message_buffer: Vec<PeerStatsUpdate> = Vec::with_capacity(MESSAGE_LIMIT);
let message_count = tokio::select! {
biased;
message_count = peer_stats_rx.recv_many(&mut message_buffer, MESSAGE_LIMIT) => message_count,
_ = session_update_timer.tick() => {
warn!("No wireguard peer stats updates received in last {SESSION_UPDATE_INTERVAL}. Triggering session status update to disconnect inactive clients.");
session_manager.update_inactive_session_status().await?;

return Ok(IterationOutcome::TickNoMessages);
}

// process received messages to update active sessions
session_manager
.process_message_batch(message_buffer)
.await?;
};

// update inactive/disconnected sessions
session_manager.update_inactive_session_status().await?;
if message_count == 0 {
return Err(SessionManagerError::PeerStatsChannelClosed);
}

// process received messages to update active sessions
session_manager
.process_message_batch(message_buffer)
.await?;

// update inactive/disconnected sessions
session_manager.update_inactive_session_status().await?;

// reset timer to avoid it being immediately ready on next iteration
session_update_timer.reset();

Ok(IterationOutcome::ProcessedBatch(message_count))
}

struct SessionManager {
pub struct SessionManager {
pool: PgPool,
session_manager_event_tx: UnboundedSender<SessionManagerEvent>,
gateway_tx: Sender<GatewayEvent>,
}

impl SessionManager {
fn new(
pub fn new(
pool: PgPool,
session_manager_event_tx: UnboundedSender<SessionManagerEvent>,
gateway_tx: Sender<GatewayEvent>,
Expand Down
121 changes: 121 additions & 0 deletions crates/defguard_session_manager/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use std::net::{IpAddr, Ipv4Addr};

use defguard_common::db::Id;
use defguard_common::db::models::gateway::Gateway;
use defguard_common::db::models::{
Device, DeviceType, User, WireguardNetwork,
device::WireguardNetworkDevice,
wireguard::{LocationMfaMode, ServiceLocationMode},
};
use defguard_common::messages::peer_stats_update::PeerStatsUpdate;
use defguard_session_manager::{SessionManager, events::SessionManagerEvent};
use ipnetwork::IpNetwork;
use tokio::sync::{broadcast, mpsc};

pub(crate) struct SessionManagerHarness {
pub(crate) manager: SessionManager,
stats_tx: mpsc::UnboundedSender<PeerStatsUpdate>,
pub(crate) stats_rx: mpsc::UnboundedReceiver<PeerStatsUpdate>,
pub(crate) event_rx: mpsc::UnboundedReceiver<SessionManagerEvent>,
}

impl SessionManagerHarness {
pub(crate) fn new(pool: sqlx::PgPool) -> Self {
let (stats_tx, stats_rx) = mpsc::unbounded_channel();
let (event_tx, event_rx) = mpsc::unbounded_channel();
let (gateway_tx, _gateway_rx) = broadcast::channel(16);
let manager = SessionManager::new(pool, event_tx, gateway_tx);

Self {
manager,
stats_tx,
stats_rx,
event_rx,
}
}

pub(crate) fn send_stats(&self, update: PeerStatsUpdate) {
self.stats_tx
.send(update)
.expect("failed to send peer stats update");
}
}

pub(crate) async fn create_network(pool: &sqlx::PgPool) -> WireguardNetwork<Id> {
WireguardNetwork::new(
"TestNet".to_string(),
vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 0)), 24).unwrap()],
51820,
"10.0.0.1".to_string(),
None,
1420,
0,
vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 0).unwrap()],
25,
300,
false,
false,
LocationMfaMode::Disabled,
ServiceLocationMode::Disabled,
)
.save(pool)
.await
.expect("failed to create Wireguard network")
}

pub(crate) async fn create_user(pool: &sqlx::PgPool) -> User<Id> {
User::new(
"session-test",
Some("pass123"),
"Tester",
"Session",
"session-test@example.com",
None,
)
.save(pool)
.await
.expect("failed to create user")
}

pub(crate) async fn create_device(pool: &sqlx::PgPool, user_id: Id) -> Device<Id> {
Device::new(
"session-test-device".to_string(),
"device-pubkey-test".to_string(),
user_id,
DeviceType::User,
None,
true,
)
.save(pool)
.await
.expect("failed to create device")
}

pub(crate) async fn attach_device_to_network(pool: &sqlx::PgPool, network_id: Id, device_id: Id) {
let network_device = WireguardNetworkDevice::new(
network_id,
device_id,
vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 10))],
);
network_device
.insert(pool)
.await
.expect("failed to attach device to network");
}

pub(crate) async fn create_gateway(
pool: &sqlx::PgPool,
network_id: Id,
modified_by: Id,
) -> Gateway<Id> {
Gateway::new(
network_id,
"gateway-1".to_string(),
"127.0.0.1".to_string(),
51820,
modified_by,
)
.save(pool)
.await
.expect("failed to create gateway")
}
2 changes: 2 additions & 0 deletions crates/defguard_session_manager/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub(crate) mod common;
pub(crate) mod session_manager;
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use std::net::SocketAddr;

use chrono::{TimeDelta, Utc};
use defguard_common::db::setup_pool;
use defguard_common::messages::peer_stats_update::PeerStatsUpdate;
use defguard_session_manager::{
SESSION_UPDATE_INTERVAL, events::SessionManagerEventType, run_session_manager_iteration,
};
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use tokio::time::{Duration, interval};

use crate::common::{
SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network,
create_user,
};

#[sqlx::test]
async fn test_session_manager_emits_connected_event(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;
let network = create_network(&pool).await;
let user = create_user(&pool).await;
let device = create_device(&pool, user.id).await;
attach_device_to_network(&pool, network.id, device.id).await;
let gateway = create_gateway(&pool, network.id, user.id).await;

let mut harness = SessionManagerHarness::new(pool);

let endpoint: SocketAddr = "203.0.113.10:51820".parse().unwrap();
let base_time = Utc::now().naive_utc();
let update = PeerStatsUpdate {
location_id: network.id,
gateway_id: gateway.id,
device_pubkey: device.wireguard_pubkey.clone(),
collected_at: base_time,
endpoint,
upload: 100,
download: 200,
latest_handshake: base_time - TimeDelta::seconds(5),
};

harness.send_stats(update);

let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL));
let _ = run_session_manager_iteration(
&mut harness.manager,
&mut harness.stats_rx,
&mut session_update_timer,
)
.await
.expect("session manager iteration failed");

let event = harness
.event_rx
.recv()
.await
.expect("session manager event channel closed");

assert!(matches!(
event.event,
SessionManagerEventType::ClientConnected
));
assert_eq!(event.context.location.id, network.id);
assert_eq!(event.context.user.id, user.id);
assert_eq!(event.context.device.id, device.id);
assert_eq!(event.context.public_ip, endpoint.ip());
}
3 changes: 3 additions & 0 deletions crates/defguard_session_manager/tests/session_manager/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod event_flow;
mod sessions;
mod stats;
53 changes: 53 additions & 0 deletions crates/defguard_session_manager/tests/session_manager/sessions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use chrono::{TimeDelta, Utc};
use defguard_common::db::models::vpn_client_session::{VpnClientSession, VpnClientSessionState};
use defguard_common::db::setup_pool;
use defguard_common::messages::peer_stats_update::PeerStatsUpdate;
use defguard_session_manager::{SESSION_UPDATE_INTERVAL, run_session_manager_iteration};
use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
use tokio::time::{Duration, interval};

use crate::common::{
SessionManagerHarness, attach_device_to_network, create_device, create_gateway, create_network,
create_user,
};

#[sqlx::test]
async fn test_session_manager_creates_active_session(_: PgPoolOptions, options: PgConnectOptions) {
let pool = setup_pool(options).await;
let network = create_network(&pool).await;
let user = create_user(&pool).await;
let device = create_device(&pool, user.id).await;
attach_device_to_network(&pool, network.id, device.id).await;
let gateway = create_gateway(&pool, network.id, user.id).await;

let mut harness = SessionManagerHarness::new(pool.clone());

let base_time = Utc::now().naive_utc();
let update = PeerStatsUpdate {
location_id: network.id,
gateway_id: gateway.id,
device_pubkey: device.wireguard_pubkey.clone(),
collected_at: base_time,
endpoint: "203.0.113.10:51820".parse().unwrap(),
upload: 100,
download: 200,
latest_handshake: base_time - TimeDelta::seconds(5),
};

harness.send_stats(update);

let mut session_update_timer = interval(Duration::from_secs(SESSION_UPDATE_INTERVAL));
let _ = run_session_manager_iteration(
&mut harness.manager,
&mut harness.stats_rx,
&mut session_update_timer,
)
.await
.expect("session manager iteration failed");

let session = VpnClientSession::try_get_active_session(&pool, network.id, device.id)
.await
.expect("failed to query active session")
.expect("expected active session");
assert_eq!(session.state, VpnClientSessionState::Connected);
}
Loading
Loading