Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions crates/api-snowflake-rest-sessions/src/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::TokenizedSession;
use chrono::offset::Local;
use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use time::Duration;
use uuid::Uuid;

#[derive(Serialize, Deserialize)]
#[cfg_attr(test, derive(Debug))]
Expand All @@ -11,11 +11,16 @@ pub struct Claims {
pub aud: String, // validate audience since as it can be deployed on multiple hosts
pub iat: i64, // Issued At
pub exp: i64, // Expiration Time
pub session_id: String,
pub session: TokenizedSession,
}

#[must_use]
pub fn jwt_claims(username: &str, audience: &str, expiration: Duration) -> Claims {
pub fn jwt_claims(
username: &str,
audience: &str,
expiration: Duration,
session: TokenizedSession,
) -> Claims {
let now = Local::now();
let iat = now.timestamp();
let exp = now.timestamp() + expiration.whole_seconds();
Expand All @@ -25,7 +30,7 @@ pub fn jwt_claims(username: &str, audience: &str, expiration: Duration) -> Claim
aud: audience.to_string(),
iat,
exp,
session_id: Uuid::new_v4().to_string(),
session,
}
}

Expand Down
13 changes: 8 additions & 5 deletions crates/api-snowflake-rest-sessions/src/layer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::error as session_error;
use crate::error::{Error, Result};
use crate::session::{
DFSessionId, SESSION_ID_COOKIE_NAME, SessionStore, extract_token_from_cookie,
SESSION_ID_COOKIE_NAME, SessionStore, TokenizedSession, extract_token_from_cookie,
};
use axum::extract::{FromRequestParts, Request, State};
use axum::http::{HeaderMap, HeaderName, request::Parts};
Expand Down Expand Up @@ -39,6 +39,7 @@ where
}
}

// this method loads just session_id, so that won't include any session attrs
#[allow(clippy::unwrap_used, clippy::cognitive_complexity)]
pub async fn propagate_session_cookie(
State(state): State<SessionStore>,
Expand All @@ -54,24 +55,26 @@ pub async fn propagate_session_cookie(

let session_id = uuid::Uuid::new_v4().to_string();
//Propagate new session_id to the extractor
req.extensions_mut().insert(DFSessionId(session_id.clone()));
req.extensions_mut()
.insert(TokenizedSession::new(session_id.clone()));
let mut res = next.run(req).await;
set_headers_in_flight(
res.headers_mut(),
SET_COOKIE,
SESSION_ID_COOKIE_NAME,
session_id.as_str(),
&session_id,
)?;
return Ok(res);
}
tracing::debug!("This DF session_id is not expired or deleted.");
//Propagate in-use (valid) session_id to the extractor
req.extensions_mut().insert(DFSessionId(token));
req.extensions_mut().insert(TokenizedSession::new(token));
} else {
let session_id = uuid::Uuid::new_v4().to_string();
tracing::debug!(session_id = %session_id, "Created new DF session_id");
//Propagate new session_id to the extractor
req.extensions_mut().insert(DFSessionId(session_id.clone()));
req.extensions_mut()
.insert(TokenizedSession::new(session_id.clone()));
let mut res = next.run(req).await;
set_headers_in_flight(
res.headers_mut(),
Expand Down
2 changes: 1 addition & 1 deletion crates/api-snowflake-rest-sessions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ pub mod helpers;
pub mod layer;
pub mod session;

pub use crate::session::DFSessionId;
pub use crate::session::TokenizedSession;
72 changes: 55 additions & 17 deletions crates/api-snowflake-rest-sessions/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@ use crate::error::BadAuthTokenSnafu;
use crate::helpers::get_claims_validate_jwt_token;
use axum::extract::FromRequestParts;
use executor::ExecutionAppState;
use executor::SessionMetadata;
use executor::service::ExecutionService;
use http::header::COOKIE;
use http::request::Parts;
use http::{HeaderMap, HeaderName};
use regex::Regex;
use serde::{Deserialize, Serialize};
use snafu::{OptionExt, ResultExt};
use std::{collections::HashMap, sync::Arc};
use uuid::Uuid;

pub const SESSION_ID_COOKIE_NAME: &str = "session_id";

Expand Down Expand Up @@ -38,17 +41,50 @@ pub trait JwtSecret {
fn jwt_secret(&self) -> &str;
}

#[derive(Debug, Clone)]
pub struct DFSessionId(pub String);
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizedSession(pub String, pub SessionMetadata);

impl<S> FromRequestParts<S> for DFSessionId
impl Default for TokenizedSession {
fn default() -> Self {
Self(Uuid::new_v4().to_string(), SessionMetadata::default())
}
}

impl TokenizedSession {
#[must_use]
pub fn new(session_id: String) -> Self {
Self(session_id, SessionMetadata::default())
}

#[must_use]
pub fn with_metadata(mut self, metadata: SessionMetadata) -> Self {
self.1 = metadata;
self
}

#[must_use]
pub fn session_id(&self) -> &str {
&self.0
}

#[must_use]
pub const fn metadata(&self) -> &SessionMetadata {
&self.1
}
}

impl<S> FromRequestParts<S> for TokenizedSession
where
S: Send + Sync + ExecutionAppState + JwtSecret,
{
type Rejection = session_error::Error;

#[allow(clippy::unwrap_used)]
#[tracing::instrument(level = "debug", skip(req, state), fields(session_id, located_at))]
#[tracing::instrument(
level = "debug",
skip(req, state),
fields(session_id, located_at, metadata)
)]
async fn from_request_parts(req: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let execution_svc = state.get_execution_svc();

Expand All @@ -59,9 +95,9 @@ where
// let Extension(Host(host)) = req.extract::<Extension<Host>>()
// .await
// .context(session_error::ExtensionRejectionSnafu)?;
// tracing::info!("Host '{host}' extracted from DFSessionId");
// tracing::info!("Host '{host}' extracted from TokenizedSession");

let (session_id, located_at) = if let Some(token) = extract_token_from_auth(&req.headers) {
let (session, located_at) = if let Some(token) = extract_token_from_auth(&req.headers) {
// host is require to check token audience claim
let host = req.headers.get("host");
let host = host.and_then(|host| host.to_str().ok());
Expand All @@ -71,40 +107,42 @@ where
let jwt_claims = get_claims_validate_jwt_token(&token, host, jwt_secret)
.context(BadAuthTokenSnafu)?;

(jwt_claims.session_id, "auth header")
(jwt_claims.session, "auth header")
} else {
//This is guaranteed by the `propagate_session_cookie`, so we can unwrap
let Self(token) = req.extensions.get::<Self>().unwrap();
(token.clone(), "extensions")
let session = req.extensions.get::<Self>().unwrap();
(session.clone(), "extensions")
};

// Record the result as part of the current span.
tracing::Span::current()
.record("located_at", located_at)
.record("session_id", session_id.clone());
.record("metadata", format!("{:?}", session.metadata()))
.record("session_id", session.session_id());

Self::get_or_create_session(execution_svc, session_id).await
Self::get_or_create_session(execution_svc, session).await
}
}

impl DFSessionId {
impl TokenizedSession {
#[tracing::instrument(
name = "DFSessionId::get_or_create_session",
name = "TokenizedSession::get_or_create_session",
level = "info",
skip(execution_svc),
fields(new_session, sessions_count)
)]
async fn get_or_create_session(
execution_svc: Arc<dyn ExecutionService>,
session_id: String,
session: Self,
) -> Result<Self, session_error::Error> {
let session_id = session.session_id();
if !execution_svc
.update_session_expiry(&session_id)
.update_session_expiry(session_id)
.await
.context(session_error::ExecutionSnafu)?
{
let _ = execution_svc
.create_session(&session_id)
.create_session(session_id)
.await
.context(session_error::ExecutionSnafu)?;
tracing::Span::current().record("new_session", true);
Expand All @@ -114,7 +152,7 @@ impl DFSessionId {
// Record the result as part of the current span.
tracing::Span::current().record("sessions_count", sessions_count);

Ok(Self(session_id))
Ok(session)
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/api-snowflake-rest/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ pub struct QueryRequest {
pub struct QueryRequestBody {
pub sql_text: String,
pub async_exec: Option<bool>,
pub query_submission_time: Option<u64>,
}

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
Expand Down
8 changes: 4 additions & 4 deletions crates/api-snowflake-rest/src/server/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::models::{
};
use crate::server::error::Result;
use crate::server::logic::{handle_login_request, handle_query_request};
use api_snowflake_rest_sessions::DFSessionId;
use api_snowflake_rest_sessions::TokenizedSession;
use api_snowflake_rest_sessions::layer::Host;
use axum::Json;
use axum::extract::{ConnectInfo, Query, State};
Expand Down Expand Up @@ -52,14 +52,14 @@ pub async fn login(
)]
pub async fn query(
ConnectInfo(addr): ConnectInfo<SocketAddr>,
DFSessionId(session_id): DFSessionId,
tokenized_session: TokenizedSession,
State(state): State<AppState>,
Query(query): Query<QueryRequest>,
Json(query_body): Json<QueryRequestBody>,
) -> Result<Json<JsonResponse>> {
let response = handle_query_request(
&state,
&session_id,
tokenized_session,
query,
query_body,
Option::from(addr.ip().to_string()),
Expand Down Expand Up @@ -93,7 +93,7 @@ pub async fn abort(
ret(level = tracing::Level::TRACE)
)]
pub async fn session(
DFSessionId(session_id): DFSessionId,
TokenizedSession(session_id, ..): TokenizedSession,
State(state): State<AppState>,
Query(query_params): Query<SessionQueryParams>,
) -> Result<Json<serde_json::value::Value>> {
Expand Down
2 changes: 1 addition & 1 deletion crates/api-snowflake-rest/src/server/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub async fn require_auth(
get_claims_validate_jwt_token(&token, &host, &jwt_secret).context(BadAuthTokenSnafu)?;

// Record the result as part of the current span.
tracing::Span::current().record("session_id", jwt_claims.session_id.as_str());
tracing::Span::current().record("session_id", jwt_claims.session.session_id());

let response = next.run(req).await;

Expand Down
Loading