From 675765af24ba7447eb47f872258b38d0f75819cd Mon Sep 17 00:00:00 2001 From: Anant Vindal Date: Fri, 5 Dec 2025 16:18:21 +0530 Subject: [PATCH 1/2] Possible fix for the oauth issue This PR changes session management by storing the refresh token returned by the provider in memory. It is used to renew the session once it expires. Expiry is set based on the expires_in value returned by the provider --- src/cli.rs | 4 +- src/handlers/http/middleware.rs | 89 +++++++++++++++++- src/handlers/http/modal/ingest_server.rs | 4 +- src/handlers/http/modal/mod.rs | 13 ++- src/handlers/http/modal/query_server.rs | 6 +- src/handlers/http/modal/server.rs | 17 ++-- src/handlers/http/oidc.rs | 109 +++++++++++++++++------ src/rbac/map.rs | 21 +++++ src/rbac/mod.rs | 10 ++- src/rbac/user.rs | 28 ++++-- src/utils/mod.rs | 2 +- 11 files changed, 243 insertions(+), 60 deletions(-) diff --git a/src/cli.rs b/src/cli.rs index 4f4d4d219..0c8c84995 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -455,9 +455,9 @@ pub struct Options { long = "oidc-scope", name = "oidc-scope", env = "P_OIDC_SCOPE", - default_value = "openid profile email", + default_value = "openid profile email offline_access", required = false, - help = "OIDC scope to request (default: openid profile email)" + help = "OIDC scope to request (default: openid profile email offline_access)" )] pub scope: String, diff --git a/src/handlers/http/middleware.rs b/src/handlers/http/middleware.rs index dee3933e8..ac79d1795 100644 --- a/src/handlers/http/middleware.rs +++ b/src/handlers/http/middleware.rs @@ -24,16 +24,25 @@ use actix_web::{ dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready}, error::{ErrorBadRequest, ErrorForbidden, ErrorUnauthorized}, http::header::{self, HeaderName}, + web::Data, }; +use chrono::{Duration, Utc}; use futures_util::future::LocalBoxFuture; use crate::{ handlers::{ AUTHORIZATION_KEY, KINESIS_COMMON_ATTRIBUTES_KEY, LOG_SOURCE_KEY, LOG_SOURCE_KINESIS, - STREAM_NAME_HEADER_KEY, + STREAM_NAME_HEADER_KEY, http::rbac::RBACError, }, + oidc::DiscoveredClient, option::Mode, parseable::PARSEABLE, + rbac::{ + EXPIRY_DURATION, + map::{SessionKey, mut_sessions, mut_users, sessions, users}, + roles_to_permission, user, + }, + utils::get_user_from_request, }; use crate::{ rbac::Users, @@ -160,8 +169,86 @@ where let auth_result: Result<_, Error> = (self.auth_method)(&mut req, self.action); + let http_req = req.request().clone(); + let key: Result = extract_session_key(&mut req); + let userid: Result = get_user_from_request(&http_req); + let fut = self.service.call(req); Box::pin(async move { + let Ok(key) = key else { + return Err(ErrorUnauthorized( + "Your session has expired or is no longer valid. Please re-authenticate to access this resource.", + )); + }; + + // if session is expired, refresh token + if sessions().is_session_expired(&key) { + // request using oidc client + let oidc_client = match http_req.app_data::>>() { + Some(client) => { + let c = client.clone().into_inner(); + c.as_ref().clone() + } + None => None, + }; + + if let Some(client) = oidc_client + && let Ok(userid) = userid + && users().get(&userid).is_some() + { + // get the bearer token + let user = users().get(&userid).unwrap().clone(); + match &user.ty { + user::UserType::OAuth(oauth) => { + if oauth.bearer.as_ref().is_some() { + let Ok(refreshed_token) = client + .refresh_token(oauth, Some(PARSEABLE.options.scope.as_str())) + .await + else { + return Err(ErrorUnauthorized( + "Your session has expired or is no longer valid. Please re-authenticate to access this resource.", + )); + }; + let expires_in = + if let Some(expires_in) = refreshed_token.expires_in.as_ref() { + // need an i64 somehow + if *expires_in > u32::MAX.into() { + EXPIRY_DURATION + } else { + let v = i64::from(*expires_in as u32); + Duration::seconds(v) + } + } else { + EXPIRY_DURATION + }; + + // set the new oauth bearer value + if let Some(user) = mut_users().get_mut(&userid) + && let user::UserType::OAuth(oauth) = &mut user.ty + { + oauth.bearer = Some(refreshed_token) + } + + mut_sessions().track_new( + userid.clone(), + key.clone(), + Utc::now() + expires_in, + roles_to_permission(user.roles()), + ); + } + } + _ => { + mut_sessions().track_new( + userid.clone(), + key.clone(), + Utc::now() + EXPIRY_DURATION, + roles_to_permission(user.roles()), + ); + } + } + } + } + match auth_result? { rbac::Response::UnAuthorized => { return Err(ErrorForbidden( diff --git a/src/handlers/http/modal/ingest_server.rs b/src/handlers/http/modal/ingest_server.rs index 628bd9f0f..0440e857c 100644 --- a/src/handlers/http/modal/ingest_server.rs +++ b/src/handlers/http/modal/ingest_server.rs @@ -51,7 +51,7 @@ use crate::{ use super::IngestorMetadata; use super::{ - OpenIdClient, ParseableServer, + ParseableServer, ingest::{ingestor_logstream, ingestor_rbac, ingestor_role}, }; @@ -62,7 +62,7 @@ pub struct IngestServer; #[async_trait] impl ParseableServer for IngestServer { // configure the api routes - fn configure_routes(config: &mut web::ServiceConfig, _oidc_client: Option) { + fn configure_routes(config: &mut web::ServiceConfig) { config .service( // Base path "{url}/api/v1" diff --git a/src/handlers/http/modal/mod.rs b/src/handlers/http/modal/mod.rs index 844975e5f..c8be6c89a 100644 --- a/src/handlers/http/modal/mod.rs +++ b/src/handlers/http/modal/mod.rs @@ -18,7 +18,11 @@ use std::{fmt, path::Path, sync::Arc}; -use actix_web::{App, HttpServer, middleware::from_fn, web::ServiceConfig}; +use actix_web::{ + App, HttpServer, + middleware::from_fn, + web::{self, ServiceConfig}, +}; use actix_web_prometheus::PrometheusMetrics; use anyhow::Context; use async_trait::async_trait; @@ -67,7 +71,7 @@ include!(concat!(env!("OUT_DIR"), "/generated.rs")); #[async_trait] pub trait ParseableServer { /// configure the router - fn configure_routes(config: &mut ServiceConfig, oidc_client: Option) + fn configure_routes(config: &mut ServiceConfig) where Self: Sized; @@ -96,7 +100,7 @@ pub trait ParseableServer { let client = config .connect(&format!("{API_BASE_PATH}/{API_VERSION}/o/code")) .await?; - Some(Arc::new(client)) + Some(client) } None => None, @@ -116,8 +120,9 @@ pub trait ParseableServer { // fn that creates the app let create_app_fn = move || { App::new() + .app_data(web::Data::new(oidc_client.clone())) .wrap(prometheus.clone()) - .configure(|config| Self::configure_routes(config, oidc_client.clone())) + .configure(|config| Self::configure_routes(config)) .wrap(from_fn(health_check::check_shutdown_middleware)) .wrap(actix_web::middleware::Logger::default()) .wrap(actix_web::middleware::Compress::default()) diff --git a/src/handlers/http/modal/query_server.rs b/src/handlers/http/modal/query_server.rs index f1a4249c7..c345d3112 100644 --- a/src/handlers/http/modal/query_server.rs +++ b/src/handlers/http/modal/query_server.rs @@ -42,14 +42,14 @@ use crate::Server; use crate::parseable::PARSEABLE; use super::query::{querier_ingest, querier_logstream, querier_rbac, querier_role}; -use super::{NodeType, OpenIdClient, ParseableServer, QuerierMetadata, load_on_init}; +use super::{NodeType, ParseableServer, QuerierMetadata, load_on_init}; pub struct QueryServer; pub static QUERIER_META: OnceCell> = OnceCell::const_new(); #[async_trait] impl ParseableServer for QueryServer { // configure the api routes - fn configure_routes(config: &mut ServiceConfig, oidc_client: Option) { + fn configure_routes(config: &mut ServiceConfig) { config .service( web::scope(&base_path()) @@ -66,7 +66,7 @@ impl ParseableServer for QueryServer { .service(Server::get_dashboards_webscope()) .service(Server::get_filters_webscope()) .service(Server::get_llm_webscope()) - .service(Server::get_oauth_webscope(oidc_client)) + .service(Server::get_oauth_webscope()) .service(Self::get_user_role_webscope()) .service(Server::get_roles_webscope()) .service(Server::get_counts_webscope().wrap(from_fn( diff --git a/src/handlers/http/modal/server.rs b/src/handlers/http/modal/server.rs index 6e3ba9ea7..7b145ebb1 100644 --- a/src/handlers/http/modal/server.rs +++ b/src/handlers/http/modal/server.rs @@ -61,7 +61,6 @@ use crate::{ }; // use super::generate; -use super::OpenIdClient; use super::ParseableServer; use super::generate; use super::load_on_init; @@ -70,7 +69,7 @@ pub struct Server; #[async_trait] impl ParseableServer for Server { - fn configure_routes(config: &mut web::ServiceConfig, oidc_client: Option) { + fn configure_routes(config: &mut web::ServiceConfig) { // there might be a bug in the configure routes method config .service( @@ -91,7 +90,7 @@ impl ParseableServer for Server { .service(Self::get_dashboards_webscope()) .service(Self::get_filters_webscope()) .service(Self::get_llm_webscope()) - .service(Self::get_oauth_webscope(oidc_client)) + .service(Self::get_oauth_webscope()) .service(Self::get_user_role_webscope()) .service(Self::get_roles_webscope()) .service(Self::get_counts_webscope().wrap(from_fn( @@ -570,17 +569,11 @@ impl Server { } // get the oauth webscope - pub fn get_oauth_webscope(oidc_client: Option) -> Scope { - let oauth = web::scope("/o") + pub fn get_oauth_webscope() -> Scope { + web::scope("/o") .service(resource("/login").route(web::get().to(oidc::login))) .service(resource("/logout").route(web::get().to(oidc::logout))) - .service(resource("/code").route(web::get().to(oidc::reply_login))); - - if let Some(client) = oidc_client { - oauth.app_data(web::Data::from(client)) - } else { - oauth - } + .service(resource("/code").route(web::get().to(oidc::reply_login))) } // get list of roles diff --git a/src/handlers/http/oidc.rs b/src/handlers/http/oidc.rs index 5f3506d42..ed780f189 100644 --- a/src/handlers/http/oidc.rs +++ b/src/handlers/http/oidc.rs @@ -16,7 +16,7 @@ * */ -use std::{collections::HashSet, sync::Arc}; +use std::collections::HashSet; use actix_web::{ HttpRequest, HttpResponse, @@ -24,10 +24,12 @@ use actix_web::{ http::header::{self, ContentType}, web::{self, Data}, }; +use chrono::{Duration, TimeDelta}; use http::StatusCode; -use openid::{Options, Token, Userinfo}; +use openid::{Bearer, Options, Token, Userinfo}; use regex::Regex; use serde::Deserialize; +use tracing::error; use ulid::Ulid; use url::Url; @@ -36,7 +38,7 @@ use crate::{ oidc::{Claims, DiscoveredClient}, parseable::PARSEABLE, rbac::{ - self, Users, + self, EXPIRY_DURATION, Users, map::{DEFAULT_ROLE, SessionKey}, user::{self, GroupUser, User, UserType}, }, @@ -72,14 +74,20 @@ pub async fn login( )); } - let oidc_client = req.app_data::>(); + let oidc_client = match req.app_data::>>() { + Some(client) => { + let c = client.clone().into_inner(); + c.as_ref().clone() + } + None => None, + }; let session_key = extract_session_key_from_req(&req).ok(); let (session_key, oidc_client) = match (session_key, oidc_client) { (None, None) => return Ok(redirect_no_oauth_setup(query.redirect.clone())), (None, Some(client)) => { return Ok(redirect_to_oidc( query, - client, + &client, PARSEABLE.options.scope.to_string().as_str(), )); } @@ -103,8 +111,11 @@ pub async fn login( ) if basic.verify_password(&password) => { let user_cookie = cookie_username(&username); let user_id_cookie = cookie_userid(&username); - let session_cookie = - exchange_basic_for_cookie(user, SessionKey::BasicAuth { username, password }); + let session_cookie = exchange_basic_for_cookie( + user, + SessionKey::BasicAuth { username, password }, + EXPIRY_DURATION, + ); Ok(redirect_to_client( query.redirect.as_str(), [user_cookie, user_id_cookie, session_cookie], @@ -121,7 +132,7 @@ pub async fn login( if let Some(oidc_client) = oidc_client { redirect_to_oidc( query, - oidc_client, + &oidc_client, PARSEABLE.options.scope.to_string().as_str(), ) } else { @@ -134,7 +145,13 @@ pub async fn login( } pub async fn logout(req: HttpRequest, query: web::Query) -> HttpResponse { - let oidc_client = req.app_data::>(); + let oidc_client = match req.app_data::>>() { + Some(client) => { + let c = client.clone().into_inner(); + c.as_ref().clone() + } + None => None, + }; let Some(session) = extract_session_key_from_req(&req).ok() else { return redirect_to_client(query.redirect.as_str(), None); }; @@ -155,11 +172,12 @@ pub async fn logout(req: HttpRequest, query: web::Query) -> /// Handler for code callback /// User should be redirected to page they were trying to access with cookie pub async fn reply_login( - oidc_client: Data, + req: HttpRequest, login_query: web::Query, ) -> Result { - let oidc_client = Data::into_inner(oidc_client); - let Ok((mut claims, user_info)): Result<(Claims, Userinfo), anyhow::Error> = + let oidc_client = req.app_data::>>().unwrap(); + let oidc_client = oidc_client.clone().into_inner().as_ref().clone().unwrap(); + let Ok((mut claims, user_info, bearer)): Result<(Claims, Userinfo, Bearer), anyhow::Error> = request_token(oidc_client, &login_query).await else { return Ok(HttpResponse::Unauthorized().finish()); @@ -178,6 +196,10 @@ pub async fn reply_login( } }; let user_info: user::UserInfo = user_info.into(); + + // if provider has group A, and parseable as has role A + // then user will automatically get assigned role A + // else, the default oidc role (inside parseable) will get assigned let group: HashSet = claims .other .remove("groups") @@ -223,12 +245,25 @@ pub async fn reply_login( final_roles.clone_from(&default_role); } + let expires_in = if let Some(expires_in) = bearer.expires_in.as_ref() { + // need an i64 somehow + if *expires_in > u32::MAX.into() { + EXPIRY_DURATION + } else { + let v = i64::from(*expires_in as u32); + Duration::seconds(v) + } + } else { + EXPIRY_DURATION + }; + let user = match (existing_user, final_roles) { - (Some(user), roles) => update_user_if_changed(user, roles, user_info).await?, - (None, roles) => put_user(&user_id, roles, user_info).await?, + (Some(user), roles) => update_user_if_changed(user, roles, user_info, bearer).await?, + (None, roles) => put_user(&user_id, roles, user_info, bearer).await?, }; let id = Ulid::new(); - Users.new_session(&user, SessionKey::SessionId(id)); + + Users.new_session(&user, SessionKey::SessionId(id), expires_in); let redirect_url = login_query .state @@ -270,10 +305,14 @@ fn find_existing_user(user_info: &user::UserInfo) -> Option { None } -fn exchange_basic_for_cookie(user: &User, key: SessionKey) -> Cookie<'static> { +fn exchange_basic_for_cookie( + user: &User, + key: SessionKey, + expires_in: TimeDelta, +) -> Cookie<'static> { let id = Ulid::new(); Users.remove_session(&key); - Users.new_session(user, SessionKey::SessionId(id)); + Users.new_session(user, SessionKey::SessionId(id), expires_in); cookie_session(id) } @@ -288,7 +327,8 @@ fn redirect_to_oidc( state: Some(redirect), ..Default::default() }); - let url: String = auth_url.into(); + let mut url: String = auth_url.into(); + url.push_str("&access_type=offline&prompt=consent"); HttpResponse::TemporaryRedirect() .insert_header((header::LOCATION, url)) .finish() @@ -348,9 +388,9 @@ pub fn cookie_userid(user_id: &str) -> Cookie<'static> { } pub async fn request_token( - oidc_client: Arc, + oidc_client: DiscoveredClient, login_query: &Login, -) -> anyhow::Result<(Claims, Userinfo)> { +) -> anyhow::Result<(Claims, Userinfo, Bearer)> { let mut token: Token = oidc_client.request_token(&login_query.code).await?.into(); let Some(id_token) = token.id_token.as_mut() else { return Err(anyhow::anyhow!("No id_token provided")); @@ -361,7 +401,8 @@ pub async fn request_token( let claims = id_token.payload().expect("payload is decoded").clone(); let userinfo = oidc_client.request_userinfo(&token).await?; - Ok((claims, userinfo)) + let bearer = token.bearer; + Ok((claims, userinfo, bearer)) } // put new user in metadata if does not exits @@ -370,21 +411,28 @@ pub async fn put_user( userid: &str, group: HashSet, user_info: user::UserInfo, + bearer: Bearer, ) -> Result { let mut metadata = get_metadata().await?; - let user = metadata + let mut user = metadata .users .iter() .find(|user| user.userid() == userid) .cloned() .unwrap_or_else(|| { - let user = User::new_oauth(userid.to_owned(), group, user_info); + let user = User::new_oauth(userid.to_owned(), group, user_info, None); metadata.users.push(user.clone()); user }); put_metadata(&metadata).await?; + + // modify before storing + match &mut user.ty { + UserType::Native(_) => {} + UserType::OAuth(oauth) => oauth.bearer = Some(bearer), + } Users.put_user(user.clone()); Ok(user) } @@ -393,6 +441,7 @@ pub async fn update_user_if_changed( mut user: User, group: HashSet, user_info: user::UserInfo, + bearer: Bearer, ) -> Result { // Store the old username before modifying the user object let old_username = user.userid().to_string(); @@ -408,8 +457,12 @@ pub async fn update_user_if_changed( false }; - // update user only if roles, userinfo has changed, or userid needs migration - if roles == &group && oauth_user.user_info == user_info && !needs_userid_migration { + // update user only if roles, userinfo has changed, or userid needs migration, or bearer is updated + if roles == &group + && oauth_user.user_info == user_info + && !needs_userid_migration + && oauth_user.bearer.as_ref() == Some(&bearer) + { return Ok(user); } @@ -438,6 +491,12 @@ pub async fn update_user_if_changed( } put_metadata(&metadata).await?; Users.delete_user(&old_username); + // update oauth bearer + // modify before storing + match &mut user.ty { + UserType::Native(_) => {} + UserType::OAuth(oauth) => oauth.bearer = Some(bearer), + } Users.put_user(user.clone()); Ok(user) } diff --git a/src/rbac/map.rs b/src/rbac/map.rs index 5377d10d7..e8836c824 100644 --- a/src/rbac/map.rs +++ b/src/rbac/map.rs @@ -29,6 +29,7 @@ use super::{ }; use chrono::{DateTime, Utc}; use once_cell::sync::{Lazy, OnceCell}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; pub type Roles = HashMap>; @@ -167,6 +168,26 @@ pub struct Sessions { } impl Sessions { + // only checks if the session is expired or not + pub fn is_session_expired(&self, key: &SessionKey) -> bool { + // fetch userid from session key + let userid = if let Some((user, _)) = self.active_sessions.get(key) { + user + } else { + return false; + }; + + // check against user sessions if this session is still valid + let Some(session) = self.user_sessions.get(userid) else { + return false; + }; + + session + .par_iter() + .find_first(|(sessionid, expiry)| sessionid.eq(key) && expiry < &Utc::now()) + .is_some() + } + // track new session key pub fn track_new( &mut self, diff --git a/src/rbac/mod.rs b/src/rbac/mod.rs index 4eb115778..64cea51af 100644 --- a/src/rbac/mod.rs +++ b/src/rbac/mod.rs @@ -23,7 +23,7 @@ pub mod utils; use std::collections::{HashMap, HashSet}; -use chrono::{DateTime, Days, Utc}; +use chrono::{DateTime, Duration, TimeDelta, Utc}; use itertools::Itertools; use role::model::DefaultPrivilege; use serde::Serialize; @@ -37,6 +37,8 @@ use self::map::SessionKey; use self::role::{Permission, RoleBuilder}; use self::user::UserType; +pub const EXPIRY_DURATION: Duration = Duration::hours(1); + #[derive(PartialEq)] pub enum Response { Authorized, @@ -147,11 +149,11 @@ impl Users { mut_sessions().remove_session(session) } - pub fn new_session(&self, user: &User, session: SessionKey) { + pub fn new_session(&self, user: &User, session: SessionKey, expires_in: TimeDelta) { mut_sessions().track_new( user.userid().to_owned(), session, - Utc::now() + Days::new(7), + Utc::now() + expires_in, roles_to_permission(user.roles()), ) } @@ -228,7 +230,7 @@ pub struct UsersPrism { pub user_groups: HashSet, } -fn roles_to_permission(roles: Vec) -> Vec { +pub fn roles_to_permission(roles: Vec) -> Vec { let mut perms = HashSet::new(); for role in &roles { let role_map = &map::roles(); diff --git a/src/rbac/user.rs b/src/rbac/user.rs index 300bf90d9..ed4266cba 100644 --- a/src/rbac/user.rs +++ b/src/rbac/user.rs @@ -23,6 +23,7 @@ use argon2::{ password_hash::{PasswordHasher, SaltString, rand_core::OsRng}, }; +use openid::Bearer; use rand::distributions::{Alphanumeric, DistString}; use crate::{ @@ -38,7 +39,7 @@ use crate::{ #[serde(untagged)] pub enum UserType { Native(Basic), - OAuth(OAuth), + OAuth(Box), } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -66,12 +67,18 @@ impl User { ) } - pub fn new_oauth(userid: String, roles: HashSet, user_info: UserInfo) -> Self { + pub fn new_oauth( + userid: String, + roles: HashSet, + user_info: UserInfo, + bearer: Option, + ) -> Self { Self { - ty: UserType::OAuth(OAuth { + ty: UserType::OAuth(Box::new(OAuth { userid: user_info.sub.clone().unwrap_or(userid), user_info, - }), + bearer, + })), roles, user_groups: HashSet::new(), } @@ -80,7 +87,7 @@ impl User { pub fn userid(&self) -> &str { match self.ty { UserType::Native(Basic { ref username, .. }) => username, - UserType::OAuth(OAuth { ref userid, .. }) => userid, + UserType::OAuth(ref oauth) => &oauth.userid, } } @@ -175,6 +182,13 @@ pub fn get_admin_user() -> User { pub struct OAuth { pub userid: String, pub user_info: UserInfo, + pub bearer: Option, +} + +impl AsRef for Box { + fn as_ref(&self) -> &Bearer { + self.bearer.as_ref().unwrap() + } } #[derive(Debug, Default, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -255,8 +269,10 @@ impl GroupUser { username: username.clone(), method: "native".to_string(), }, - UserType::OAuth(OAuth { userid, user_info }) => { + UserType::OAuth(oauth) => { // For OAuth users, derive the display username from user_info + let user_info = &oauth.user_info; + let userid = &oauth.userid; let display_username = user_info .name .clone() diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 1d84558ab..4ef8063a6 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -59,7 +59,7 @@ pub fn extract_datetime(path: &str) -> Option { } pub fn get_user_from_request(req: &HttpRequest) -> Result { - let session_key = extract_session_key_from_req(req).unwrap(); + let session_key = extract_session_key_from_req(req).map_err(|_| RBACError::UserDoesNotExist)?; let user_id = Users.get_userid_from_session(&session_key); if user_id.is_none() { return Err(RBACError::UserDoesNotExist); From 82e9a19b4d4ac763c47503130bae950ca6eeff67 Mon Sep 17 00:00:00 2001 From: Anant Vindal Date: Fri, 5 Dec 2025 17:06:38 +0530 Subject: [PATCH 2/2] coderabbit suggestions --- src/handlers/http/middleware.rs | 107 ++++++++++++++++++-------------- src/handlers/http/oidc.rs | 11 ++-- src/rbac/user.rs | 6 ++ 3 files changed, 69 insertions(+), 55 deletions(-) diff --git a/src/handlers/http/middleware.rs b/src/handlers/http/middleware.rs index ac79d1795..7b7d6652a 100644 --- a/src/handlers/http/middleware.rs +++ b/src/handlers/http/middleware.rs @@ -183,7 +183,6 @@ where // if session is expired, refresh token if sessions().is_session_expired(&key) { - // request using oidc client let oidc_client = match http_req.app_data::>>() { Some(client) => { let c = client.clone().into_inner(); @@ -194,57 +193,69 @@ where if let Some(client) = oidc_client && let Ok(userid) = userid - && users().get(&userid).is_some() { - // get the bearer token - let user = users().get(&userid).unwrap().clone(); - match &user.ty { - user::UserType::OAuth(oauth) => { - if oauth.bearer.as_ref().is_some() { - let Ok(refreshed_token) = client - .refresh_token(oauth, Some(PARSEABLE.options.scope.as_str())) - .await - else { - return Err(ErrorUnauthorized( - "Your session has expired or is no longer valid. Please re-authenticate to access this resource.", - )); - }; - let expires_in = - if let Some(expires_in) = refreshed_token.expires_in.as_ref() { - // need an i64 somehow - if *expires_in > u32::MAX.into() { - EXPIRY_DURATION - } else { - let v = i64::from(*expires_in as u32); - Duration::seconds(v) - } - } else { - EXPIRY_DURATION - }; - - // set the new oauth bearer value - if let Some(user) = mut_users().get_mut(&userid) - && let user::UserType::OAuth(oauth) = &mut user.ty - { - oauth.bearer = Some(refreshed_token) + let bearer_to_refresh = { + if let Some(user) = users().get(&userid) { + match &user.ty { + user::UserType::OAuth(oauth) if oauth.bearer.is_some() => { + Some(oauth.clone()) } - - mut_sessions().track_new( - userid.clone(), - key.clone(), - Utc::now() + expires_in, - roles_to_permission(user.roles()), - ); + _ => None, } + } else { + None } - _ => { - mut_sessions().track_new( - userid.clone(), - key.clone(), - Utc::now() + EXPIRY_DURATION, - roles_to_permission(user.roles()), - ); - } + }; + + if let Some(oauth_data) = bearer_to_refresh { + let Ok(refreshed_token) = client + .refresh_token(&oauth_data, Some(PARSEABLE.options.scope.as_str())) + .await + else { + return Err(ErrorUnauthorized( + "Your session has expired or is no longer valid. Please re-authenticate to access this resource.", + )); + }; + + let expires_in = + if let Some(expires_in) = refreshed_token.expires_in.as_ref() { + if *expires_in > u32::MAX.into() { + EXPIRY_DURATION + } else { + let v = i64::from(*expires_in as u32); + Duration::seconds(v) + } + } else { + EXPIRY_DURATION + }; + + let user_roles = { + let mut users_guard = mut_users(); + if let Some(user) = users_guard.get_mut(&userid) { + if let user::UserType::OAuth(oauth) = &mut user.ty { + oauth.bearer = Some(refreshed_token); + } + user.roles().to_vec() + } else { + return Err(ErrorUnauthorized( + "Your session has expired or is no longer valid. Please re-authenticate to access this resource.", + )); + } + }; + + mut_sessions().track_new( + userid.clone(), + key.clone(), + Utc::now() + expires_in, + roles_to_permission(user_roles), + ); + } else if let Some(user) = users().get(&userid) { + mut_sessions().track_new( + userid.clone(), + key.clone(), + Utc::now() + EXPIRY_DURATION, + roles_to_permission(user.roles()), + ); } } } diff --git a/src/handlers/http/oidc.rs b/src/handlers/http/oidc.rs index ed780f189..ad8523467 100644 --- a/src/handlers/http/oidc.rs +++ b/src/handlers/http/oidc.rs @@ -429,9 +429,8 @@ pub async fn put_user( put_metadata(&metadata).await?; // modify before storing - match &mut user.ty { - UserType::Native(_) => {} - UserType::OAuth(oauth) => oauth.bearer = Some(bearer), + if let user::UserType::OAuth(oauth) = &mut user.ty { + oauth.bearer = Some(bearer); } Users.put_user(user.clone()); Ok(user) @@ -492,10 +491,8 @@ pub async fn update_user_if_changed( put_metadata(&metadata).await?; Users.delete_user(&old_username); // update oauth bearer - // modify before storing - match &mut user.ty { - UserType::Native(_) => {} - UserType::OAuth(oauth) => oauth.bearer = Some(bearer), + if let user::UserType::OAuth(oauth) = &mut user.ty { + oauth.bearer = Some(bearer); } Users.put_user(user.clone()); Ok(user) diff --git a/src/rbac/user.rs b/src/rbac/user.rs index ed4266cba..8e8b62ab8 100644 --- a/src/rbac/user.rs +++ b/src/rbac/user.rs @@ -186,6 +186,12 @@ pub struct OAuth { } impl AsRef for Box { + /// Returns a reference to the bearer token. + /// + /// # Panics + /// Panics if bearer is None. This should never happen in practice as + /// bearer is always set to Some when OIDC is configured and this trait + /// is only called by refresh_token after verifying bearer.is_some(). fn as_ref(&self) -> &Bearer { self.bearer.as_ref().unwrap() }