From 4533030d86830990a53ad9db77b6c7987f32662d Mon Sep 17 00:00:00 2001 From: Yaroslav Litvinov Date: Fri, 9 Jan 2026 05:50:08 +0100 Subject: [PATCH 1/5] wip --- .../src/helpers.rs | 13 +- .../api-snowflake-rest-sessions/src/layer.rs | 13 +- crates/api-snowflake-rest-sessions/src/lib.rs | 2 +- .../src/session.rs | 72 ++++++++--- crates/api-snowflake-rest/src/models.rs | 1 + .../api-snowflake-rest/src/server/handlers.rs | 8 +- crates/api-snowflake-rest/src/server/layer.rs | 2 +- crates/api-snowflake-rest/src/server/logic.rs | 74 +++++++----- crates/api-snowflake-rest/src/tests/client.rs | 1 + .../set_command_show_variables.snap | 113 ++++++++++++++++++ .../use_command_show_variables-2.snap | 18 +-- .../src/tests/sql_test_macro.rs | 2 + .../src/tests/test_gzip_encoding.rs | 1 + .../src/tests/test_rest_api.rs | 5 + crates/executor/src/lib.rs | 1 + crates/executor/src/models.rs | 42 +++++++ crates/executor/src/query.rs | 14 +-- crates/executor/src/query_task_result.rs | 21 ++-- crates/executor/src/service.rs | 43 +++++-- crates/executor/src/session.rs | 16 +-- .../src/tests/statestore_queries_unittest.rs | 70 ++++++++--- crates/state-store/src/models.rs | 30 +++++ 22 files changed, 430 insertions(+), 132 deletions(-) create mode 100644 crates/api-snowflake-rest/src/tests/snapshots/compatible/set_command_show_variables.snap diff --git a/crates/api-snowflake-rest-sessions/src/helpers.rs b/crates/api-snowflake-rest-sessions/src/helpers.rs index e500ed55..87248a4d 100644 --- a/crates/api-snowflake-rest-sessions/src/helpers.rs +++ b/crates/api-snowflake-rest-sessions/src/helpers.rs @@ -1,8 +1,8 @@ +use super::TokenizedSession; use chrono::offset::Local; use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation, decode, encode}; use serde::{Deserialize, Serialize}; use time::Duration; -use uuid::Uuid; #[derive(Serialize, Deserialize)] #[cfg_attr(test, derive(Debug))] @@ -11,11 +11,16 @@ pub struct Claims { pub aud: String, // validate audience since as it can be deployed on multiple hosts pub iat: i64, // Issued At pub exp: i64, // Expiration Time - pub session_id: String, + pub session: TokenizedSession, } #[must_use] -pub fn jwt_claims(username: &str, audience: &str, expiration: Duration) -> Claims { +pub fn jwt_claims( + username: &str, + audience: &str, + expiration: Duration, + session: TokenizedSession, +) -> Claims { let now = Local::now(); let iat = now.timestamp(); let exp = now.timestamp() + expiration.whole_seconds(); @@ -25,7 +30,7 @@ pub fn jwt_claims(username: &str, audience: &str, expiration: Duration) -> Claim aud: audience.to_string(), iat, exp, - session_id: Uuid::new_v4().to_string(), + session, } } diff --git a/crates/api-snowflake-rest-sessions/src/layer.rs b/crates/api-snowflake-rest-sessions/src/layer.rs index b115d525..0183d1ac 100644 --- a/crates/api-snowflake-rest-sessions/src/layer.rs +++ b/crates/api-snowflake-rest-sessions/src/layer.rs @@ -1,7 +1,7 @@ use crate::error as session_error; use crate::error::{Error, Result}; use crate::session::{ - DFSessionId, SESSION_ID_COOKIE_NAME, SessionStore, extract_token_from_cookie, + SESSION_ID_COOKIE_NAME, SessionStore, TokenizedSession, extract_token_from_cookie, }; use axum::extract::{FromRequestParts, Request, State}; use axum::http::{HeaderMap, HeaderName, request::Parts}; @@ -39,6 +39,7 @@ where } } +// this method loads just session_id, so that won't include any session attrs #[allow(clippy::unwrap_used, clippy::cognitive_complexity)] pub async fn propagate_session_cookie( State(state): State, @@ -54,24 +55,26 @@ pub async fn propagate_session_cookie( let session_id = uuid::Uuid::new_v4().to_string(); //Propagate new session_id to the extractor - req.extensions_mut().insert(DFSessionId(session_id.clone())); + req.extensions_mut() + .insert(TokenizedSession::new(session_id.clone())); let mut res = next.run(req).await; set_headers_in_flight( res.headers_mut(), SET_COOKIE, SESSION_ID_COOKIE_NAME, - session_id.as_str(), + &session_id, )?; return Ok(res); } tracing::debug!("This DF session_id is not expired or deleted."); //Propagate in-use (valid) session_id to the extractor - req.extensions_mut().insert(DFSessionId(token)); + req.extensions_mut().insert(TokenizedSession::new(token)); } else { let session_id = uuid::Uuid::new_v4().to_string(); tracing::debug!(session_id = %session_id, "Created new DF session_id"); //Propagate new session_id to the extractor - req.extensions_mut().insert(DFSessionId(session_id.clone())); + req.extensions_mut() + .insert(TokenizedSession::new(session_id.clone())); let mut res = next.run(req).await; set_headers_in_flight( res.headers_mut(), diff --git a/crates/api-snowflake-rest-sessions/src/lib.rs b/crates/api-snowflake-rest-sessions/src/lib.rs index 078d1bd6..b33c8dc2 100644 --- a/crates/api-snowflake-rest-sessions/src/lib.rs +++ b/crates/api-snowflake-rest-sessions/src/lib.rs @@ -3,4 +3,4 @@ pub mod helpers; pub mod layer; pub mod session; -pub use crate::session::DFSessionId; +pub use crate::session::TokenizedSession; diff --git a/crates/api-snowflake-rest-sessions/src/session.rs b/crates/api-snowflake-rest-sessions/src/session.rs index 7aad5de9..3c459874 100644 --- a/crates/api-snowflake-rest-sessions/src/session.rs +++ b/crates/api-snowflake-rest-sessions/src/session.rs @@ -3,13 +3,16 @@ use crate::error::BadAuthTokenSnafu; use crate::helpers::get_claims_validate_jwt_token; use axum::extract::FromRequestParts; use executor::ExecutionAppState; +use executor::SessionMetadata; use executor::service::ExecutionService; use http::header::COOKIE; use http::request::Parts; use http::{HeaderMap, HeaderName}; use regex::Regex; +use serde::{Deserialize, Serialize}; use snafu::{OptionExt, ResultExt}; use std::{collections::HashMap, sync::Arc}; +use uuid::Uuid; pub const SESSION_ID_COOKIE_NAME: &str = "session_id"; @@ -38,17 +41,50 @@ pub trait JwtSecret { fn jwt_secret(&self) -> &str; } -#[derive(Debug, Clone)] -pub struct DFSessionId(pub String); +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenizedSession(pub String, pub SessionMetadata); -impl FromRequestParts for DFSessionId +impl Default for TokenizedSession { + fn default() -> Self { + Self(Uuid::new_v4().to_string(), SessionMetadata::default()) + } +} + +impl TokenizedSession { + #[must_use] + pub fn new(session_id: String) -> Self { + Self(session_id, SessionMetadata::default()) + } + + #[must_use] + pub fn with_metadata(mut self, metadata: SessionMetadata) -> Self { + self.1 = metadata; + self + } + + #[must_use] + pub fn session_id(&self) -> &str { + &self.0 + } + + #[must_use] + pub const fn metadata(&self) -> &SessionMetadata { + &self.1 + } +} + +impl FromRequestParts for TokenizedSession where S: Send + Sync + ExecutionAppState + JwtSecret, { type Rejection = session_error::Error; #[allow(clippy::unwrap_used)] - #[tracing::instrument(level = "debug", skip(req, state), fields(session_id, located_at))] + #[tracing::instrument( + level = "debug", + skip(req, state), + fields(session_id, located_at, metadata) + )] async fn from_request_parts(req: &mut Parts, state: &S) -> Result { let execution_svc = state.get_execution_svc(); @@ -59,9 +95,9 @@ where // let Extension(Host(host)) = req.extract::>() // .await // .context(session_error::ExtensionRejectionSnafu)?; - // tracing::info!("Host '{host}' extracted from DFSessionId"); + // tracing::info!("Host '{host}' extracted from TokenizedSession"); - let (session_id, located_at) = if let Some(token) = extract_token_from_auth(&req.headers) { + let (session, located_at) = if let Some(token) = extract_token_from_auth(&req.headers) { // host is require to check token audience claim let host = req.headers.get("host"); let host = host.and_then(|host| host.to_str().ok()); @@ -71,40 +107,42 @@ where let jwt_claims = get_claims_validate_jwt_token(&token, host, jwt_secret) .context(BadAuthTokenSnafu)?; - (jwt_claims.session_id, "auth header") + (jwt_claims.session, "auth header") } else { //This is guaranteed by the `propagate_session_cookie`, so we can unwrap - let Self(token) = req.extensions.get::().unwrap(); - (token.clone(), "extensions") + let session = req.extensions.get::().unwrap(); + (session.clone(), "extensions") }; // Record the result as part of the current span. tracing::Span::current() .record("located_at", located_at) - .record("session_id", session_id.clone()); + .record("metadata", format!("{:?}", session.metadata())) + .record("session_id", session.session_id()); - Self::get_or_create_session(execution_svc, session_id).await + Self::get_or_create_session(execution_svc, session).await } } -impl DFSessionId { +impl TokenizedSession { #[tracing::instrument( - name = "DFSessionId::get_or_create_session", + name = "TokenizedSession::get_or_create_session", level = "info", skip(execution_svc), fields(new_session, sessions_count) )] async fn get_or_create_session( execution_svc: Arc, - session_id: String, + session: Self, ) -> Result { + let session_id = session.session_id(); if !execution_svc - .update_session_expiry(&session_id) + .update_session_expiry(session_id) .await .context(session_error::ExecutionSnafu)? { let _ = execution_svc - .create_session(&session_id) + .create_session(session_id) .await .context(session_error::ExecutionSnafu)?; tracing::Span::current().record("new_session", true); @@ -114,7 +152,7 @@ impl DFSessionId { // Record the result as part of the current span. tracing::Span::current().record("sessions_count", sessions_count); - Ok(Self(session_id)) + Ok(session) } } diff --git a/crates/api-snowflake-rest/src/models.rs b/crates/api-snowflake-rest/src/models.rs index b3f1975a..bb98f7b0 100644 --- a/crates/api-snowflake-rest/src/models.rs +++ b/crates/api-snowflake-rest/src/models.rs @@ -56,6 +56,7 @@ pub struct QueryRequest { pub struct QueryRequestBody { pub sql_text: String, pub async_exec: Option, + pub query_submission_time: Option, } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] diff --git a/crates/api-snowflake-rest/src/server/handlers.rs b/crates/api-snowflake-rest/src/server/handlers.rs index 3bb1b203..eb72ef9a 100644 --- a/crates/api-snowflake-rest/src/server/handlers.rs +++ b/crates/api-snowflake-rest/src/server/handlers.rs @@ -5,7 +5,7 @@ use crate::models::{ }; use crate::server::error::Result; use crate::server::logic::{handle_login_request, handle_query_request}; -use api_snowflake_rest_sessions::DFSessionId; +use api_snowflake_rest_sessions::TokenizedSession; use api_snowflake_rest_sessions::layer::Host; use axum::Json; use axum::extract::{ConnectInfo, Query, State}; @@ -52,14 +52,14 @@ pub async fn login( )] pub async fn query( ConnectInfo(addr): ConnectInfo, - DFSessionId(session_id): DFSessionId, + tokenized_session: TokenizedSession, State(state): State, Query(query): Query, Json(query_body): Json, ) -> Result> { let response = handle_query_request( &state, - &session_id, + tokenized_session, query, query_body, Option::from(addr.ip().to_string()), @@ -93,7 +93,7 @@ pub async fn abort( ret(level = tracing::Level::TRACE) )] pub async fn session( - DFSessionId(session_id): DFSessionId, + TokenizedSession(session_id, ..): TokenizedSession, State(state): State, Query(query_params): Query, ) -> Result> { diff --git a/crates/api-snowflake-rest/src/server/layer.rs b/crates/api-snowflake-rest/src/server/layer.rs index 0a8dda89..de7266d2 100644 --- a/crates/api-snowflake-rest/src/server/layer.rs +++ b/crates/api-snowflake-rest/src/server/layer.rs @@ -40,7 +40,7 @@ pub async fn require_auth( get_claims_validate_jwt_token(&token, &host, &jwt_secret).context(BadAuthTokenSnafu)?; // Record the result as part of the current span. - tracing::Span::current().record("session_id", jwt_claims.session_id.as_str()); + tracing::Span::current().record("session_id", jwt_claims.session.session_id()); let response = next.run(req).await; diff --git a/crates/api-snowflake-rest/src/server/logic.rs b/crates/api-snowflake-rest/src/server/logic.rs index 4ec05874..7fddc8d7 100644 --- a/crates/api-snowflake-rest/src/server/logic.rs +++ b/crates/api-snowflake-rest/src/server/logic.rs @@ -4,21 +4,23 @@ use crate::models::{ QueryRequest, QueryRequestBody, }; use crate::server::error::{ - self as api_snowflake_rest_error, CreateJwtSnafu, NoJwtSecretSnafu, Result, SetVariableSnafu, + self as api_snowflake_rest_error, CreateJwtSnafu, NoJwtSecretSnafu, Result, }; use crate::server::helpers::handle_query_ok_result; +use api_snowflake_rest_sessions::TokenizedSession; use api_snowflake_rest_sessions::helpers::{create_jwt, ensure_jwt_secret_is_valid, jwt_claims}; use executor::RunningQueryId; -use executor::models::QueryContext; +use executor::models::{QueryContext, SessionMetadata, SessionMetadataAttr}; use snafu::{OptionExt, ResultExt}; use time::Duration; pub const JWT_TOKEN_EXPIRATION_SECONDS: u32 = 3 * 24 * 60 * 60; #[tracing::instrument( - name = "api_snowflake_rest::logic::login", + name = "api_snowflake_rest::handle_login_request", level = "debug", skip(state, credentials), + fields(session_metadata), err, ret(level = tracing::Level::TRACE) )] @@ -32,6 +34,9 @@ pub async fn handle_login_request( let LoginRequestData { login_name, password, + account_name, + client_app_id, + client_app_version, .. } = credentials; @@ -43,40 +48,40 @@ pub async fn handle_login_request( let jwt_secret = &*state.config.auth.jwt_secret; let _ = ensure_jwt_secret_is_valid(jwt_secret).context(NoJwtSecretSnafu)?; + let mut session_metadata = SessionMetadata::default(); + session_metadata.set_attr(SessionMetadataAttr::UserName, login_name.clone()); + session_metadata.set_attr(SessionMetadataAttr::AccountName, account_name); + session_metadata.set_attr(SessionMetadataAttr::ClientAppId, client_app_id); + session_metadata.set_attr(SessionMetadataAttr::ClientAppVersion, client_app_version); + // set database, schema when provided + if let Some(db) = params.database_name { + session_metadata.set_attr(SessionMetadataAttr::Database, db); + } + if let Some(schema) = params.schema_name { + session_metadata.set_attr(SessionMetadataAttr::Schema, schema); + } + if let Some(warehouse) = params.warehouse { + session_metadata.set_attr(SessionMetadataAttr::Warehouse, warehouse); + } + + tracing::Span::current().record("session_metadata", format!("{session_metadata:?}")); + + let tokenized_session = TokenizedSession::default().with_metadata(session_metadata); + let jwt_claims = jwt_claims( &login_name, &host, Duration::seconds(JWT_TOKEN_EXPIRATION_SECONDS.into()), + tokenized_session, ); tracing::info!("Host '{host}' for token creation"); - let session_id = jwt_claims.session_id.clone(); - let session = state.execution_svc.create_session(&session_id).await?; + let session_id = jwt_claims.session.session_id(); + let _session = state.execution_svc.create_session(session_id).await?; let jwt_token = create_jwt(&jwt_claims, jwt_secret).context(CreateJwtSnafu)?; - // set database, schema when provided - if let Some(db) = params.database_name { - session.set_database(&db).await.context(SetVariableSnafu { - variable: "database", - })?; - } - if let Some(schema) = params.schema_name { - session - .set_schema(&schema) - .await - .context(SetVariableSnafu { variable: "schema" })?; - } - if let Some(warehouse) = params.warehouse { - session - .set_warehouse(&warehouse) - .await - .context(SetVariableSnafu { - variable: "warehouse", - })?; - } - Ok(LoginResponse { data: Option::from(LoginResponseData { token: jwt_token }), success: true, @@ -85,7 +90,7 @@ pub async fn handle_login_request( } #[tracing::instrument( - name = "api_snowflake_rest::logic::query", + name = "api_snowflake_rest::handle_query_request", level = "debug", skip(state, query_body, client_ip), fields(request_id = %query.request_id), @@ -94,7 +99,7 @@ pub async fn handle_login_request( )] pub async fn handle_query_request( state: &AppState, - session_id: &str, + TokenizedSession(session_id, session_metadata): TokenizedSession, query: QueryRequest, query_body: QueryRequestBody, client_ip: Option, @@ -102,6 +107,7 @@ pub async fn handle_query_request( let QueryRequestBody { sql_text, async_exec, + query_submission_time, } = query_body; let async_exec = async_exec.unwrap_or(false); if async_exec { @@ -109,7 +115,13 @@ pub async fn handle_query_request( } let serialization_format = state.config.dbt_serialization_format; - let mut query_context = QueryContext::default().with_request_id(query.request_id); + let mut query_context = QueryContext::new( + session_metadata.attr(SessionMetadataAttr::Database), + session_metadata.attr(SessionMetadataAttr::Schema), + None, + ) + .with_request_id(query.request_id) + .with_query_submission_time(query_submission_time); if let Some(ip) = client_ip { query_context = query_context.with_ip_address(ip); @@ -123,7 +135,7 @@ pub async fn handle_query_request( sql_text.clone(), )); - // if retry-disable feature is enabled we ignory retries regardless of query_id is located or not + // if retry-disable feature is enabled we ignore retries regardless of query_id is located or not #[cfg(feature = "retry-disable")] if query.retry_count.unwrap_or_default() > 0 { return api_snowflake_rest_error::RetryDisabledSnafu.fail(); @@ -138,7 +150,7 @@ pub async fn handle_query_request( let query_id = query_context.query_id; let result = state .execution_svc - .query(session_id, &sql_text, query_context) + .query(&session_id, &sql_text, query_context) .await?; (result, query_id) }; diff --git a/crates/api-snowflake-rest/src/tests/client.rs b/crates/api-snowflake-rest/src/tests/client.rs index f83b36e6..6f093d61 100644 --- a/crates/api-snowflake-rest/src/tests/client.rs +++ b/crates/api-snowflake-rest/src/tests/client.rs @@ -204,6 +204,7 @@ where json!(QueryRequestBody { sql_text: query.to_string(), async_exec: Some(async_exec), + query_submission_time: Some(1_764_161_275_445), }) .to_string(), ) diff --git a/crates/api-snowflake-rest/src/tests/snapshots/compatible/set_command_show_variables.snap b/crates/api-snowflake-rest/src/tests/snapshots/compatible/set_command_show_variables.snap new file mode 100644 index 00000000..9f8c04f5 --- /dev/null +++ b/crates/api-snowflake-rest/src/tests/snapshots/compatible/set_command_show_variables.snap @@ -0,0 +1,113 @@ +--- +source: crates/api-snowflake-rest/src/tests/test_rest_api.rs +description: "SQL #1 [spent: 1028/1653ms]: SHOW VARIABLES\nQuery UUID: 019ba0a2-7c9d-7f31-accb-d20a0622ba57" +expression: snapshot +--- +SHOW VARIABLES +{ + "data": { + "rowtype": [ + { + "name": "session_id", + "database": "", + "schema": "", + "table": "", + "nullable": true, + "type": "text", + "byteLength": 16777216, + "length": 16777216, + "scale": null, + "precision": null, + "collation": null + }, + { + "name": "name", + "database": "", + "schema": "", + "table": "", + "nullable": false, + "type": "text", + "byteLength": 16777216, + "length": 16777216, + "scale": null, + "precision": null, + "collation": null + }, + { + "name": "value", + "database": "", + "schema": "", + "table": "", + "nullable": true, + "type": "text", + "byteLength": 16777216, + "length": 16777216, + "scale": null, + "precision": null, + "collation": null + }, + { + "name": "type", + "database": "", + "schema": "", + "table": "", + "nullable": true, + "type": "text", + "byteLength": 16777216, + "length": 16777216, + "scale": null, + "precision": null, + "collation": null + }, + { + "name": "comment", + "database": "", + "schema": "", + "table": "", + "nullable": true, + "type": "text", + "byteLength": 16777216, + "length": 16777216, + "scale": null, + "precision": null, + "collation": null + }, + { + "name": "created_on", + "database": "", + "schema": "", + "table": "", + "nullable": true, + "type": "text", + "byteLength": 16777216, + "length": 16777216, + "scale": null, + "precision": null, + "collation": null + }, + { + "name": "updated_on", + "database": "", + "schema": "", + "table": "", + "nullable": true, + "type": "text", + "byteLength": 16777216, + "length": 16777216, + "scale": null, + "precision": null, + "collation": null + } + ], + "rowsetBase64": null, + "rowset": [], + "total": 0, + "returned": 0, + "queryResultFormat": "json", + "sqlState": "02000", + "queryId": "UUID" + }, + "success": true, + "message": "successfully executed", + "code": null +} diff --git a/crates/api-snowflake-rest/src/tests/snapshots/compatible/use_command_show_variables-2.snap b/crates/api-snowflake-rest/src/tests/snapshots/compatible/use_command_show_variables-2.snap index 1bc3f32f..ef0a377c 100644 --- a/crates/api-snowflake-rest/src/tests/snapshots/compatible/use_command_show_variables-2.snap +++ b/crates/api-snowflake-rest/src/tests/snapshots/compatible/use_command_show_variables-2.snap @@ -1,6 +1,6 @@ --- source: crates/api-snowflake-rest/src/tests/test_rest_api.rs -description: "SQL #2 [spent: 1947/3917ms]: SHOW VARIABLES\nQuery UUID: 00000000-0000-0000-0000-000000000000" +description: "SQL #2 [spent: 834/2246ms]: SHOW VARIABLES\nQuery UUID: 019ba0a2-81a9-7f01-8831-0ab6e9cef96c" expression: snapshot --- SHOW VARIABLES @@ -100,19 +100,9 @@ SHOW VARIABLES } ], "rowsetBase64": null, - "rowset": [ - [ - "UUID", - "warehouse", - "embucket", - "text", - "", - "UTC_TIME9", - "UTC_TIME9" - ] - ], - "total": 1, - "returned": 1, + "rowset": [], + "total": 0, + "returned": 0, "queryResultFormat": "json", "sqlState": "02000", "queryId": "UUID" diff --git a/crates/api-snowflake-rest/src/tests/sql_test_macro.rs b/crates/api-snowflake-rest/src/tests/sql_test_macro.rs index e1c1fb66..6902f141 100644 --- a/crates/api-snowflake-rest/src/tests/sql_test_macro.rs +++ b/crates/api-snowflake-rest/src/tests/sql_test_macro.rs @@ -5,6 +5,7 @@ use crate::tests::TEST_JWT_SECRET; use crate::tests::snow_sql::{ACCESS_TOKEN_KEY, snow_sql}; use crate::tests::snow_sql::{PASSWORD_KEY, REQUEST_ID_KEY, USER_KEY}; use crate::{models::JsonResponse, server::server_models::RestApiConfig}; +use api_snowflake_rest_sessions::TokenizedSession; use api_snowflake_rest_sessions::helpers::{create_jwt, jwt_claims}; use arrow::record_batch::RecordBatch; use catalog_metastore::metastore_settings_config::MetastoreSettingsConfig; @@ -171,6 +172,7 @@ impl SqlTest { DEMO_USER, &host, time::Duration::seconds(JWT_TOKEN_EXPIRATION_SECONDS.into()), + TokenizedSession::default(), ); create_jwt(&jwt_claims, jwt_secret).expect("Failed to create JWT token") diff --git a/crates/api-snowflake-rest/src/tests/test_gzip_encoding.rs b/crates/api-snowflake-rest/src/tests/test_gzip_encoding.rs index 80be9db7..8c34ead4 100644 --- a/crates/api-snowflake-rest/src/tests/test_gzip_encoding.rs +++ b/crates/api-snowflake-rest/src/tests/test_gzip_encoding.rs @@ -27,6 +27,7 @@ mod tests { let query_request = QueryRequestBody { sql_text: "SELECT 1;".to_string(), async_exec: Some(false), + query_submission_time: Some(1_764_161_275_445), }; let query_compressed_bytes = make_bytes_body(&query_request); diff --git a/crates/api-snowflake-rest/src/tests/test_rest_api.rs b/crates/api-snowflake-rest/src/tests/test_rest_api.rs index 820ebcd1..0608933c 100644 --- a/crates/api-snowflake-rest/src/tests/test_rest_api.rs +++ b/crates/api-snowflake-rest/src/tests/test_rest_api.rs @@ -86,6 +86,11 @@ mod compatible { SqlTest::new(&["use schema test_schema", "SHOW VARIABLES"]) ); + sql_test!( + set_command_show_variables, + SqlTest::new(&["SHOW VARIABLES"]).with_setup_queries(&["set variable_name = 'value'"]) + ); + sql_test!( create_table_missing_schema, SqlTest::new(&[ diff --git a/crates/executor/src/lib.rs b/crates/executor/src/lib.rs index cb62f87d..629d2649 100644 --- a/crates/executor/src/lib.rs +++ b/crates/executor/src/lib.rs @@ -19,6 +19,7 @@ pub mod utils; pub mod tests; pub use error::{Error, Result}; +pub use models::{SessionMetadata, SessionMetadataAttr}; pub use query_types::{ExecutionStatus, QueryId}; pub use running_queries::RunningQueryId; pub use snowflake_error::SnowflakeError; diff --git a/crates/executor/src/models.rs b/crates/executor/src/models.rs index c45cca92..5a9526af 100644 --- a/crates/executor/src/models.rs +++ b/crates/executor/src/models.rs @@ -18,6 +18,8 @@ pub struct QueryContext { pub query_id: QueryId, pub request_id: Option, pub ip_address: Option, + pub query_submission_time: Option, + pub session_metadata: Option, } // Add own Default implementation to avoid getting default (zeroed) Uuid. @@ -31,6 +33,8 @@ impl Default for QueryContext { query_id: Uuid::now_v7(), request_id: None, ip_address: None, + query_submission_time: None, + session_metadata: None, } } } @@ -67,6 +71,18 @@ impl QueryContext { self.ip_address = Some(ip_address); self } + + #[must_use] + pub const fn with_query_submission_time(mut self, query_submission_time: Option) -> Self { + self.query_submission_time = query_submission_time; + self + } + + #[must_use] + pub fn with_session_metadata(mut self, session_metadata: Option) -> Self { + self.session_metadata = session_metadata; + self + } } #[derive(Debug, Clone, PartialEq)] @@ -132,6 +148,32 @@ impl QueryMetric { } } +#[derive(strum::Display, Clone, Copy)] +pub enum SessionMetadataAttr { + UserName, + Warehouse, + Database, + Schema, + AccountName, + ClientAppId, + ClientAppVersion, +} + +#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub struct SessionMetadata(HashMap); + +impl SessionMetadata { + pub fn set_attr(&mut self, attr: SessionMetadataAttr, value: String) { + self.0.insert(attr.to_string(), value); + } + + #[must_use] + pub fn attr(&self, attr: SessionMetadataAttr) -> Option { + self.0.get(&attr.to_string()).cloned() + } +} + // TODO: We should not have serde dependency here // Instead it should be in api-snowflake-rest #[derive(Debug, Serialize, Deserialize, Clone)] diff --git a/crates/executor/src/query.rs b/crates/executor/src/query.rs index 312c15d5..e755163c 100644 --- a/crates/executor/src/query.rs +++ b/crates/executor/src/query.rs @@ -155,19 +155,17 @@ impl UserQuery { #[must_use] pub fn current_database(&self) -> String { - self.query_context - .database - .clone() - .or_else(|| self.session.get_session_variable("database")) + self.session + .get_session_variable("database") + .or_else(|| self.query_context.database.clone()) .unwrap_or_else(|| "embucket".to_string()) } #[must_use] pub fn current_schema(&self) -> String { - self.query_context - .schema - .clone() - .or_else(|| self.session.get_session_variable("schema")) + self.session + .get_session_variable("schema") + .or_else(|| self.query_context.schema.clone()) .unwrap_or_else(|| "public".to_string()) } diff --git a/crates/executor/src/query_task_result.rs b/crates/executor/src/query_task_result.rs index c6008f84..633212b2 100644 --- a/crates/executor/src/query_task_result.rs +++ b/crates/executor/src/query_task_result.rs @@ -86,6 +86,7 @@ impl ExecutionTaskResult { } #[cfg(feature = "state-store-query")] + #[allow(clippy::cast_sign_loss, clippy::as_conversions)] pub fn assign_rows_counts_attributes( &self, query: &mut state_store::Query, @@ -94,23 +95,22 @@ impl ExecutionTaskResult { if let Ok(result) = &self.result && let QueryType::Dml(query_type) = query_type { - if let DmlStType::Select = query_type { + if matches!(query_type, DmlStType::Select) { let rows_count: u64 = result.records.iter().map(|r| r.num_rows() as u64).sum(); query.set_rows_produced(rows_count); - } else if let Some(rows_count) = value_by_row_column(&result, 0, 0) { + } else if let Some(rows_count) = value_by_row_column(result, 0, 0) { match query_type { - DmlStType::Insert => query.set_rows_inserted(rows_count as u64), - DmlStType::Update => query.set_rows_updated(rows_count as u64), - DmlStType::Delete => query.set_rows_deleted(rows_count as u64), - DmlStType::Truncate => query.set_rows_deleted(rows_count as u64), + DmlStType::Insert => query.set_rows_inserted(rows_count), + DmlStType::Update => query.set_rows_updated(rows_count), + DmlStType::Delete | DmlStType::Truncate => query.set_rows_deleted(rows_count), DmlStType::Merge => { // merge has 2 columns, currently map values to insert/select rows counts - query.set_rows_inserted(rows_count as u64); - if let Some(rows_count) = value_by_row_column(&result, 0, 1) { - query.set_rows_produced(rows_count as u64); + query.set_rows_inserted(rows_count); + if let Some(rows_count) = value_by_row_column(result, 0, 1) { + query.set_rows_produced(rows_count); } } - _ => {} + DmlStType::Select => {} } } } @@ -145,6 +145,7 @@ impl ExecutionTaskResult { } #[cfg(feature = "state-store-query")] +#[allow(clippy::cast_sign_loss, clippy::as_conversions)] fn value_by_row_column(result: &QueryResult, row_idx: usize, col_idx: usize) -> Option { result.records[0].columns().get(col_idx).and_then(|col| { if let Some(cols) = col.as_any().downcast_ref::() { diff --git a/crates/executor/src/service.rs b/crates/executor/src/service.rs index b6ad5e23..f824a25d 100644 --- a/crates/executor/src/service.rs +++ b/crates/executor/src/service.rs @@ -25,6 +25,8 @@ use tokio::task; use tokio_util::sync::CancellationToken; use super::error::{self as ex_error, Result}; +#[cfg(feature = "state-store")] +use super::models::SessionMetadataAttr; use super::models::{QueryContext, QueryResult}; use super::running_queries::{RunningQueries, RunningQueriesRegistry, RunningQuery}; use super::session::UserSession; @@ -319,11 +321,11 @@ impl ExecutionService for CoreExecutionService { tracing::trace!("Acquired write lock for df_sessions"); sessions.insert(session_id.to_string(), user_session.clone()); - #[cfg(feature = "state-store")] - self.state_store - .put_new_session(session_id) - .await - .context(ex_error::StateStoreSnafu)?; + // #[cfg(feature = "state-store")] + // self.state_store + // .put_new_session(session_id) + // .await + // .context(ex_error::StateStoreSnafu)?; // Record the result as part of the current span. tracing::Span::current().record("new_sessions_count", sessions.len()); @@ -529,6 +531,27 @@ impl ExecutionService for CoreExecutionService { query.set_execution_status(ExecutionStatus::Running); query.set_warehouse_type(self.config.warehouse_type.clone()); query.set_release_version(self.config.build_version.clone()); + // session context set by user during login + if let Some(database) = &query_context.database { + query.set_user_database(database.clone()); + } + if let Some(schema) = &query_context.schema { + query.set_user_schema(schema.clone()); + } + if let Some(query_submission_time) = &query_context.query_submission_time { + query.set_query_submission_time(*query_submission_time); + } + if let Some(session_metadata) = &query_context.session_metadata { + if let Some(user_name) = session_metadata.attr(SessionMetadataAttr::UserName) { + query.set_user_name(user_name); + } + if let Some(client_app_id) = session_metadata.attr(SessionMetadataAttr::ClientAppId) { + query.set_client_app_id(client_app_id); + } + if let Some(client_app_version) = session_metadata.attr(SessionMetadataAttr::ClientAppVersion) { + query.set_client_app_version(client_app_version); + } + } } } @@ -581,7 +604,7 @@ impl ExecutionService for CoreExecutionService { let mut query_obj = user_session.query(query_text, query_context); #[cfg(feature = "state-store-query")] { - // current database/schema at planning/execution time + // effective database/schema at planning/execution time query.set_database_name(query_obj.current_database()); query.set_schema_name(query_obj.current_schema()); } @@ -630,11 +653,9 @@ impl ExecutionService for CoreExecutionService { cfg_if::cfg_if! { if #[cfg(feature = "state-store-query")] { execution_result.assign_query_attributes(&mut query); - if let Some(stats) = queries_registry.cloned_stats(query_id) { - if let Some(query_type) = stats.query_type { - query.set_query_type(query_type.to_string()); - execution_result.assign_rows_counts_attributes(&mut query, query_type); - } + if let Some(stats) = queries_registry.cloned_stats(query_id) && let Some(query_type) = stats.query_type { + query.set_query_type(query_type.to_string()); + execution_result.assign_rows_counts_attributes(&mut query, query_type); } // just log error and do not raise it from task if let Err(err) = state_store.update_query(&query).await { diff --git a/crates/executor/src/session.rs b/crates/executor/src/session.rs index c26f2e1a..8f065fbf 100644 --- a/crates/executor/src/session.rs +++ b/crates/executor/src/session.rs @@ -19,6 +19,7 @@ use catalog::catalog_list::{DEFAULT_CATALOG, EmbucketCatalogList}; use catalog_metastore::Metastore; #[cfg(feature = "state-store")] use chrono::{TimeZone, Utc}; +use dashmap::DashMap; use datafusion::config::ConfigOptions; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::{SessionStateBuilder, SessionStateDefaults}; @@ -64,6 +65,7 @@ pub struct UserSession { pub session_params: Arc, pub recent_queries: Arc>>, pub session_id: String, + pub attrs: DashMap, } impl UserSession { @@ -160,6 +162,7 @@ impl UserSession { session_params: session_params_arc, recent_queries: Arc::new(RwLock::new(VecDeque::new())), session_id: session_id.to_string(), + attrs: DashMap::new(), }; Ok(session) } @@ -222,24 +225,13 @@ impl UserSession { .context(ex_error::StateStoreSnafu) } - pub async fn set_database(&self, database: &str) -> Result<()> { - self.set_variable("database", database).await - } - - pub async fn set_schema(&self, schema: &str) -> Result<()> { - self.set_variable("schema", schema).await - } - - pub async fn set_warehouse(&self, warehouse: &str) -> Result<()> { - self.set_variable("warehouse", warehouse).await - } - #[tracing::instrument( name = "api_snowflake_rest::session::set_variable", level = "info", skip(self), err )] + #[allow(dead_code)] async fn set_variable(&self, key: &str, value: &str) -> Result<()> { if key.is_empty() || value.is_empty() { return ex_error::OnyUseWithVariablesSnafu.fail(); diff --git a/crates/executor/src/tests/statestore_queries_unittest.rs b/crates/executor/src/tests/statestore_queries_unittest.rs index 2e8f72e3..01806fd6 100644 --- a/crates/executor/src/tests/statestore_queries_unittest.rs +++ b/crates/executor/src/tests/statestore_queries_unittest.rs @@ -15,6 +15,7 @@ use uuid::Uuid; const TEST_SESSION_ID: &str = "test_session_id"; const TEST_DATABASE: &str = "test_database"; const TEST_SCHEMA: &str = "test_schema"; +const TEST_TIMESTAMP: u64 = 1_764_161_275_445; const MOCK_RELATED_TIMEOUT_DURATION: Duration = Duration::from_millis(100); @@ -48,15 +49,37 @@ fn insta_settings(name: &str) -> insta::Settings { #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_ok_query() { - let query_context = QueryContext::default().with_request_id(Uuid::default()); + let query_context = QueryContext::new( + Some("test_database".to_string()), + Some("test_schema".to_string()), + None, + ) + .with_request_id(Uuid::default()); let mut state_store_mock = MockStateStore::new(); - state_store_mock - .expect_put_new_session() - .returning(|_| Ok(())); state_store_mock .expect_get_session() - .returning(|_| Ok(SessionRecord::new(TEST_SESSION_ID))); + .returning(|_| { + let mut session = SessionRecord::new(TEST_SESSION_ID); + session.variables.insert("database".to_string(), state_store::Variable { + name: "database".to_string(), + value: "embucket".to_string(), + value_type: "text".to_string(), + comment: None, + created_at: 1, + updated_at: None, + }); + session.variables.insert("schema".to_string(), state_store::Variable { + name: "schema".to_string(), + value: "public".to_string(), + value_type: "text".to_string(), + comment: None, + created_at: 1, + updated_at: None, + }); + Ok(session) + }); + state_store_mock .expect_put_query() .times(1) @@ -75,7 +98,9 @@ async fn test_query_lifecycle_ok_query() { "start_time": "2026-01-01T01:01:01.000000001Z", "release_version": "test-version", "query_hash": "1717924485430328356", - "query_hash_version": 1 + "query_hash_version": 1, + "user_database_name": "test_database", + "user_schema_name": "test_schema" } "#); }); @@ -93,8 +118,8 @@ async fn test_query_lifecycle_ok_query() { "request_id": "00000000-0000-0000-0000-000000000000", "query_text": "SELECT 1 AS a, 2.0 AS b, '3' AS 'c'", "session_id": "test_session_id", - "database_name": "embucket", - "schema_name": "public", + "database_name": "test_database", + "schema_name": "test_schema", "query_type": "SELECT", "warehouse_type": "DEFAULT", "execution_status": "Success", @@ -105,6 +130,8 @@ async fn test_query_lifecycle_ok_query() { "release_version": "test-version", "query_hash": "1717924485430328356", "query_hash_version": 1, + "user_database_name": "test_database", + "user_schema_name": "test_schema", "query_metrics": "[query_metrics]" } "#); @@ -115,6 +142,10 @@ async fn test_query_lifecycle_ok_query() { let state_store: Arc = Arc::new(state_store_mock); let metastore = Arc::new(InMemoryMetastore::new()); + MetastoreBootstrapConfig::bootstrap() + .apply(metastore.clone()) + .await + .expect("Failed to bootstrap metastore"); let execution_svc = CoreExecutionService::new_test_executor( metastore, state_store, @@ -129,7 +160,7 @@ async fn test_query_lifecycle_ok_query() { ) .await .expect("Create session timed out") - .expect("Failed to create session"); + .expect("Failed to create session"); // See note about timeout above let _ = timeout( @@ -148,7 +179,13 @@ async fn test_query_lifecycle_ok_query() { #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_ok_insert() { - let query_context = QueryContext::default().with_request_id(Uuid::default()); + let query_context = QueryContext::new( + Some(TEST_DATABASE.to_string()), + Some(TEST_SCHEMA.to_string()), + None, + ) + .with_query_submission_time(Some(TEST_TIMESTAMP)) + .with_request_id(Uuid::default()); let mut state_store_mock = MockStateStore::new(); state_store_mock @@ -180,8 +217,8 @@ async fn test_query_lifecycle_ok_insert() { "request_id": "00000000-0000-0000-0000-000000000000", "query_text": "INSERT INTO embucket.public.table VALUES (1)", "session_id": "test_session_id", - "database_name": "embucket", - "schema_name": "public", + "database_name": "test_database", + "schema_name": "test_schema", "query_type": "INSERT", "warehouse_type": "DEFAULT", "execution_status": "Success", @@ -192,7 +229,10 @@ async fn test_query_lifecycle_ok_insert() { "release_version": "test-version", "query_hash": "17856184221539895914", "query_hash_version": 1, - "query_metrics": "[query_metrics]" + "user_database_name": "test_database", + "user_schema_name": "test_schema", + "query_metrics": "[query_metrics]", + "query_submission_time": 1764161275445 } "#); }); @@ -764,7 +804,9 @@ async fn test_query_lifecycle_query_status_incident_limit_exceeded() { "execution_time": "1", "release_version": "test-version", "query_hash": "8436521302113462945", - "query_hash_version": 1 + "query_hash_version": 1, + "user_database_name": "test_database", + "user_schema_name": "test_schema" } "#); }); diff --git a/crates/state-store/src/models.rs b/crates/state-store/src/models.rs index 11032ec3..a6a9db40 100644 --- a/crates/state-store/src/models.rs +++ b/crates/state-store/src/models.rs @@ -284,6 +284,12 @@ pub struct Query { pub query_history_time: Option, #[serde(default, skip_serializing_if = "Option::is_none")] pub query_result_time: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub client_app_id: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub client_app_version: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub query_submission_time: Option, } impl Query { @@ -319,6 +325,18 @@ impl Query { self.schema_name = Some(schema); } + pub fn set_user_name(&mut self, user: String) { + self.user_name = Some(user); + } + + pub fn set_user_database(&mut self, database: String) { + self.user_database_name = Some(database); + } + + pub fn set_user_schema(&mut self, schema: String) { + self.user_schema_name = Some(schema); + } + pub const fn set_execution_status(&mut self, status: ExecutionStatus) { self.execution_status = Some(status); } @@ -363,6 +381,18 @@ impl Query { self.query_type = Some(query_type); } + pub fn set_client_app_id(&mut self, client_app_id: String) { + self.client_app_id = Some(client_app_id); + } + + pub fn set_client_app_version(&mut self, client_app_version: String) { + self.client_app_version = Some(client_app_version); + } + + pub const fn set_query_submission_time(&mut self, query_submission_time: u64) { + self.query_submission_time = Some(query_submission_time); + } + #[allow(clippy::as_conversions, clippy::cast_sign_loss)] pub fn set_end_time(&mut self) { let end_time = Utc::now(); From 7f02f23a7682921938cd263dc804b27aaea58ea5 Mon Sep 17 00:00:00 2001 From: Yaroslav Litvinov Date: Fri, 9 Jan 2026 14:44:29 +0100 Subject: [PATCH 2/5] fix unittests --- crates/executor/src/lib.rs | 2 +- .../src/tests/statestore_queries_unittest.rs | 762 ++++++++---------- 2 files changed, 348 insertions(+), 416 deletions(-) diff --git a/crates/executor/src/lib.rs b/crates/executor/src/lib.rs index 629d2649..55c6bf3b 100644 --- a/crates/executor/src/lib.rs +++ b/crates/executor/src/lib.rs @@ -19,7 +19,7 @@ pub mod utils; pub mod tests; pub use error::{Error, Result}; -pub use models::{SessionMetadata, SessionMetadataAttr}; +pub use models::{QueryResult, SessionMetadata, SessionMetadataAttr}; pub use query_types::{ExecutionStatus, QueryId}; pub use running_queries::RunningQueryId; pub use snowflake_error::SnowflakeError; diff --git a/crates/executor/src/tests/statestore_queries_unittest.rs b/crates/executor/src/tests/statestore_queries_unittest.rs index 01806fd6..954b57a0 100644 --- a/crates/executor/src/tests/statestore_queries_unittest.rs +++ b/crates/executor/src/tests/statestore_queries_unittest.rs @@ -1,5 +1,8 @@ +use crate::QueryResult; +use crate::error::Result; use crate::models::QueryContext; use crate::service::{CoreExecutionService, ExecutionService}; +use crate::session::UserSession; use crate::utils::Config; use catalog_metastore::InMemoryMetastore; use catalog_metastore::metastore_bootstrap_config::MetastoreBootstrapConfig; @@ -46,39 +49,73 @@ fn insta_settings(name: &str) -> insta::Settings { settings } +pub struct Mocker; + +impl Mocker { + pub fn apply_bypass_queries_mock(state_store_mock: &mut MockStateStore, count: usize) { + state_store_mock + .expect_put_query() + .times(count) + .returning(|_| Ok(())); + state_store_mock + .expect_update_query() + .times(count) + .returning(|_| Ok(())); + } + + pub fn apply_bypass_put_queries_only_mock(state_store_mock: &mut MockStateStore, count: usize) { + state_store_mock + .expect_put_query() + .times(count) + .returning(|_| Ok(())); + } + + pub fn apply_create_session_mock( + state_store_mock: &mut MockStateStore, + f: fn(&str) -> state_store::Result, + ) { + state_store_mock + .expect_put_new_session() + .returning(|_| Ok(())); + state_store_mock.expect_put_session().returning(|_| Ok(())); + state_store_mock.expect_get_session().returning(f); + } + + pub async fn create_session( + executor: Arc, + session_id: &str, + ) -> Result> { + timeout( + MOCK_RELATED_TIMEOUT_DURATION, + executor.create_session(session_id), + ) + .await + .expect("Create session timed out") + } + + pub async fn query( + executor: Arc, + session_id: &str, + query_context: QueryContext, + sql: &str, + ) -> Result { + timeout( + MOCK_RELATED_TIMEOUT_DURATION, + executor.query(session_id, sql, query_context.clone()), + ) + .await + .expect("Query timed out") + } +} + #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_ok_query() { - let query_context = QueryContext::new( - Some("test_database".to_string()), - Some("test_schema".to_string()), - None, - ) - .with_request_id(Uuid::default()); - let mut state_store_mock = MockStateStore::new(); - state_store_mock - .expect_get_session() - .returning(|_| { - let mut session = SessionRecord::new(TEST_SESSION_ID); - session.variables.insert("database".to_string(), state_store::Variable { - name: "database".to_string(), - value: "embucket".to_string(), - value_type: "text".to_string(), - comment: None, - created_at: 1, - updated_at: None, - }); - session.variables.insert("schema".to_string(), state_store::Variable { - name: "schema".to_string(), - value: "public".to_string(), - value_type: "text".to_string(), - comment: None, - created_at: 1, - updated_at: None, - }); - Ok(session) - }); + Mocker::apply_create_session_mock(&mut state_store_mock, |_| { + Ok(SessionRecord::new(TEST_SESSION_ID)) + }); + Mocker::apply_bypass_queries_mock(&mut state_store_mock, 2); state_store_mock .expect_put_query() @@ -106,6 +143,7 @@ async fn test_query_lifecycle_ok_query() { }); true }); + state_store_mock .expect_update_query() .times(1) @@ -118,8 +156,8 @@ async fn test_query_lifecycle_ok_query() { "request_id": "00000000-0000-0000-0000-000000000000", "query_text": "SELECT 1 AS a, 2.0 AS b, '3' AS 'c'", "session_id": "test_session_id", - "database_name": "test_database", - "schema_name": "test_schema", + "database_name": "embucket", + "schema_name": "public", "query_type": "SELECT", "warehouse_type": "DEFAULT", "execution_status": "Success", @@ -139,72 +177,71 @@ async fn test_query_lifecycle_ok_query() { true }); - let state_store: Arc = Arc::new(state_store_mock); + let ctx = QueryContext::new( + Some("test_database".to_string()), + Some("test_schema".to_string()), + None, + ) + .with_request_id(Uuid::default()); let metastore = Arc::new(InMemoryMetastore::new()); MetastoreBootstrapConfig::bootstrap() .apply(metastore.clone()) .await - .expect("Failed to bootstrap metastore"); - let execution_svc = CoreExecutionService::new_test_executor( - metastore, - state_store, - Arc::new(Config::default()), + .expect("Failed to bootstrap metastore"); + + let ex: Arc = Arc::new( + CoreExecutionService::new_test_executor( + metastore, + Arc::new(state_store_mock), + Arc::new(Config::default()), + ) + .await + .expect("Failed to create execution service"), + ); + + Mocker::create_session(ex.clone(), TEST_SESSION_ID) + .await + .expect("Failed to create session"); + + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "SET DATABASE = 'embucket'", ) .await - .expect("Failed to create execution service"); + .expect("Query execution failed"); - timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.create_session(TEST_SESSION_ID), + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "SET SCHEMA = 'public'", ) .await - .expect("Create session timed out") - .expect("Failed to create session"); + .expect("Query execution failed"); - // See note about timeout above - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - "SELECT 1 AS a, 2.0 AS b, '3' AS 'c'", - query_context, - ), + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "SELECT 1 AS a, 2.0 AS b, '3' AS 'c'", ) .await - .expect("Query timed out") .expect("Query execution failed"); } #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_ok_insert() { - let query_context = QueryContext::new( - Some(TEST_DATABASE.to_string()), - Some(TEST_SCHEMA.to_string()), - None, - ) - .with_query_submission_time(Some(TEST_TIMESTAMP)) - .with_request_id(Uuid::default()); - let mut state_store_mock = MockStateStore::new(); - state_store_mock - .expect_put_new_session() - .returning(|_| Ok(())); - state_store_mock - .expect_get_session() - .returning(|_| Ok(SessionRecord::new(TEST_SESSION_ID))); - state_store_mock - .expect_put_query() - .times(2) - .returning(|_| Ok(())); + Mocker::apply_create_session_mock(&mut state_store_mock, |_| { + Ok(SessionRecord::new(TEST_SESSION_ID)) + }); + Mocker::apply_bypass_queries_mock(&mut state_store_mock, 1); + Mocker::apply_bypass_put_queries_only_mock(&mut state_store_mock, 1); - // bypass 1st update - state_store_mock - .expect_update_query() - .times(1) - .returning(|_| Ok(())); - // verify 2nd update state_store_mock .expect_update_query() .times(1) @@ -239,79 +276,63 @@ async fn test_query_lifecycle_ok_insert() { true }); - let state_store: Arc = Arc::new(state_store_mock); - let metastore = Arc::new(InMemoryMetastore::new()); MetastoreBootstrapConfig::bootstrap() .apply(metastore.clone()) .await .expect("Failed to bootstrap metastore"); - let execution_svc = CoreExecutionService::new_test_executor( - metastore, - state_store, - Arc::new(Config::default()), - ) - .await - .expect("Failed to create execution service"); + let ex: Arc = Arc::new( + CoreExecutionService::new_test_executor( + metastore, + Arc::new(state_store_mock), + Arc::new(Config::default()), + ) + .await + .expect("Failed to create execution service"), + ); - timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.create_session(TEST_SESSION_ID), + let ctx = QueryContext::new( + Some(TEST_DATABASE.to_string()), + Some(TEST_SCHEMA.to_string()), + None, ) - .await - .expect("Create session timed out") - .expect("Failed to create session"); + .with_query_submission_time(Some(TEST_TIMESTAMP)) + .with_request_id(Uuid::default()); - // prepare table - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - "create table if not exists embucket.public.table (id int)", - query_context.clone(), - ), + Mocker::create_session(ex.clone(), TEST_SESSION_ID) + .await + .expect("Failed to create session"); + + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "create table if not exists embucket.public.table (id int)", ) .await - .expect("Query timed out") .expect("Query execution failed"); - // insert - timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - "INSERT INTO embucket.public.table VALUES (1)", - query_context, - ), + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "INSERT INTO embucket.public.table VALUES (1)", ) .await - .expect("Query timed out") .expect("Query execution failed"); } #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_ok_update() { - let query_context = QueryContext::default().with_request_id(Uuid::default()); - let mut state_store_mock = MockStateStore::new(); - state_store_mock - .expect_put_new_session() - .returning(|_| Ok(())); - state_store_mock - .expect_get_session() - .returning(|_| Ok(SessionRecord::new(TEST_SESSION_ID))); - state_store_mock - .expect_put_query() - .times(2) - .returning(|_| Ok(())); + Mocker::apply_create_session_mock(&mut state_store_mock, |_| { + Ok(SessionRecord::new(TEST_SESSION_ID)) + }); + Mocker::apply_bypass_queries_mock(&mut state_store_mock, 1); + Mocker::apply_bypass_put_queries_only_mock(&mut state_store_mock, 1); - // bypass 1st update - state_store_mock - .expect_update_query() - .times(1) - .returning(|_| Ok(())); // verify 2nd update state_store_mock .expect_update_query() @@ -343,7 +364,7 @@ async fn test_query_lifecycle_ok_update() { true }); - let state_store: Arc = Arc::new(state_store_mock); + let ctx = QueryContext::default().with_request_id(Uuid::default()); let metastore = Arc::new(InMemoryMetastore::new()); MetastoreBootstrapConfig::bootstrap() @@ -351,29 +372,25 @@ async fn test_query_lifecycle_ok_update() { .await .expect("Failed to bootstrap metastore"); - let execution_svc = CoreExecutionService::new_test_executor( - metastore, - state_store, - Arc::new(Config::default()), - ) - .await - .expect("Failed to create execution service"); + let ex: Arc = Arc::new( + CoreExecutionService::new_test_executor( + metastore, + Arc::new(state_store_mock), + Arc::new(Config::default()), + ) + .await + .expect("Failed to create execution service"), + ); - timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.create_session(TEST_SESSION_ID), - ) - .await - .expect("Create session timed out") - .expect("Failed to create session"); + Mocker::create_session(ex.clone(), TEST_SESSION_ID) + .await + .expect("Failed to create session"); - // prepare table - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - " - CREATE TABLE embucket.public.table AS SELECT + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "CREATE TABLE embucket.public.table AS SELECT id, name, RANDOM() AS random_value, @@ -384,50 +401,30 @@ async fn test_query_lifecycle_ok_update() { (3, 'Charlie'), (4, 'David') ) AS t(id, name);", - query_context.clone(), - ), ) .await - .expect("Query timed out") .expect("Query execution failed"); - // update - timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - "UPDATE embucket.public.table SET name = 'John'", - query_context, - ), + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "UPDATE embucket.public.table SET name = 'John'", ) .await - .expect("Query timed out") .expect("Query execution failed"); } #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_delete_failed() { - let query_context = QueryContext::default().with_request_id(Uuid::default()); - let mut state_store_mock = MockStateStore::new(); - state_store_mock - .expect_put_new_session() - .returning(|_| Ok(())); - state_store_mock - .expect_get_session() - .returning(|_| Ok(SessionRecord::new(TEST_SESSION_ID))); - state_store_mock - .expect_put_query() - .times(2) - .returning(|_| Ok(())); + Mocker::apply_create_session_mock(&mut state_store_mock, |_| { + Ok(SessionRecord::new(TEST_SESSION_ID)) + }); + Mocker::apply_bypass_queries_mock(&mut state_store_mock, 1); + Mocker::apply_bypass_put_queries_only_mock(&mut state_store_mock, 1); - // bypass 1st update - state_store_mock - .expect_update_query() - .times(1) - .returning(|_| Ok(())); - // verify 2nd update state_store_mock .expect_update_query() .times(1) @@ -459,7 +456,7 @@ async fn test_query_lifecycle_delete_failed() { true }); - let state_store: Arc = Arc::new(state_store_mock); + let ctx = QueryContext::default().with_request_id(Uuid::default()); let metastore = Arc::new(InMemoryMetastore::new()); MetastoreBootstrapConfig::bootstrap() @@ -467,29 +464,25 @@ async fn test_query_lifecycle_delete_failed() { .await .expect("Failed to bootstrap metastore"); - let execution_svc = CoreExecutionService::new_test_executor( - metastore, - state_store, - Arc::new(Config::default()), - ) - .await - .expect("Failed to create execution service"); + let ex: Arc = Arc::new( + CoreExecutionService::new_test_executor( + metastore, + Arc::new(state_store_mock), + Arc::new(Config::default()), + ) + .await + .expect("Failed to create execution service"), + ); - timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.create_session(TEST_SESSION_ID), - ) - .await - .expect("Create session timed out") - .expect("Failed to create session"); + Mocker::create_session(ex.clone(), TEST_SESSION_ID) + .await + .expect("Failed to create session"); - // prepare table - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - " - CREATE TABLE embucket.public.table AS SELECT + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "CREATE TABLE embucket.public.table AS SELECT id, name, RANDOM() AS random_value, @@ -500,50 +493,30 @@ async fn test_query_lifecycle_delete_failed() { (3, 'Charlie'), (4, 'David') ) AS t(id, name);", - query_context.clone(), - ), ) .await - .expect("Query timed out") .expect("Query execution failed"); - // update - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - "DELETE FROM embucket.public.table", - query_context, - ), + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "DELETE FROM embucket.public.table", ) .await - .expect("Query timed out") .expect_err("Query expected to fail"); } #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_ok_truncate() { - let query_context = QueryContext::default().with_request_id(Uuid::default()); - let mut state_store_mock = MockStateStore::new(); - state_store_mock - .expect_put_new_session() - .returning(|_| Ok(())); - state_store_mock - .expect_get_session() - .returning(|_| Ok(SessionRecord::new(TEST_SESSION_ID))); - state_store_mock - .expect_put_query() - .times(2) - .returning(|_| Ok(())); + Mocker::apply_create_session_mock(&mut state_store_mock, |_| { + Ok(SessionRecord::new(TEST_SESSION_ID)) + }); + Mocker::apply_bypass_queries_mock(&mut state_store_mock, 1); + Mocker::apply_bypass_put_queries_only_mock(&mut state_store_mock, 1); - // bypass 1st update - state_store_mock - .expect_update_query() - .times(1) - .returning(|_| Ok(())); - // verify 2nd update state_store_mock .expect_update_query() .times(1) @@ -575,7 +548,7 @@ async fn test_query_lifecycle_ok_truncate() { true }); - let state_store: Arc = Arc::new(state_store_mock); + let ctx = QueryContext::default().with_request_id(Uuid::default()); let metastore = Arc::new(InMemoryMetastore::new()); MetastoreBootstrapConfig::bootstrap() @@ -583,29 +556,25 @@ async fn test_query_lifecycle_ok_truncate() { .await .expect("Failed to bootstrap metastore"); - let execution_svc = CoreExecutionService::new_test_executor( - metastore, - state_store, - Arc::new(Config::default()), - ) - .await - .expect("Failed to create execution service"); + let ex: Arc = Arc::new( + CoreExecutionService::new_test_executor( + metastore, + Arc::new(state_store_mock), + Arc::new(Config::default()), + ) + .await + .expect("Failed to create execution service"), + ); - timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.create_session(TEST_SESSION_ID), - ) - .await - .expect("Create session timed out") - .expect("Failed to create session"); + Mocker::create_session(ex.clone(), TEST_SESSION_ID) + .await + .expect("Failed to create session"); - // prepare table - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - " - CREATE TABLE embucket.public.table AS SELECT + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "CREATE TABLE embucket.public.table AS SELECT id, name, RANDOM() AS random_value, @@ -616,49 +585,30 @@ async fn test_query_lifecycle_ok_truncate() { (3, 'Charlie'), (4, 'David') ) AS t(id, name);", - query_context.clone(), - ), ) .await - .expect("Query timed out") .expect("Query execution failed"); - // update - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - "TRUNCATE TABLE embucket.public.table", - query_context, - ), + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "TRUNCATE TABLE embucket.public.table", ) .await - .expect("Query timed out") .expect("Query execution failed"); } #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_ok_merge() { - let query_context = QueryContext::default().with_request_id(Uuid::default()); - let mut state_store_mock = MockStateStore::new(); - state_store_mock - .expect_put_new_session() - .returning(|_| Ok(())); - state_store_mock - .expect_get_session() - .returning(|_| Ok(SessionRecord::new(TEST_SESSION_ID))); - state_store_mock - .expect_put_query() - .times(3) - .returning(|_| Ok(())); + Mocker::apply_create_session_mock(&mut state_store_mock, |_| { + Ok(SessionRecord::new(TEST_SESSION_ID)) + }); + Mocker::apply_bypass_queries_mock(&mut state_store_mock, 2); + Mocker::apply_bypass_put_queries_only_mock(&mut state_store_mock, 1); - // bypass first two updates - state_store_mock - .expect_update_query() - .times(2) - .returning(|_| Ok(())); // verify 3rd update state_store_mock .expect_update_query() @@ -670,7 +620,7 @@ async fn test_query_lifecycle_ok_merge() { { "query_id": "00000000-0000-0000-0000-000000000000", "request_id": "00000000-0000-0000-0000-000000000000", - "query_text": "MERGE INTO t1 USING (SELECT * FROM t2) AS t2 ON t1.a = t2.a WHEN MATCHED THEN UPDATE SET t1.c = t2.c WHEN NOT MATCHED THEN INSERT (a,c) VALUES(t2.a,t2.c)", + "query_text": "MERGE INTO t1 USING \n (SELECT * FROM t2) AS t2 \n ON t1.a = t2.a \n WHEN MATCHED THEN UPDATE SET t1.c = t2.c \n WHEN NOT MATCHED THEN INSERT (a,c) VALUES(t2.a,t2.c)", "session_id": "test_session_id", "database_name": "embucket", "schema_name": "public", @@ -683,7 +633,7 @@ async fn test_query_lifecycle_ok_merge() { "rows_inserted": 1, "execution_time": "1", "release_version": "test-version", - "query_hash": "16532873076018472935", + "query_hash": "10180476120311618623", "query_hash_version": 1, "query_metrics": "[query_metrics]" } @@ -692,7 +642,7 @@ async fn test_query_lifecycle_ok_merge() { true }); - let state_store: Arc = Arc::new(state_store_mock); + let ctx = QueryContext::default().with_request_id(Uuid::default()); let metastore = Arc::new(InMemoryMetastore::new()); MetastoreBootstrapConfig::bootstrap() @@ -700,29 +650,25 @@ async fn test_query_lifecycle_ok_merge() { .await .expect("Failed to bootstrap metastore"); - let execution_svc = CoreExecutionService::new_test_executor( - metastore, - state_store, - Arc::new(Config::default()), - ) - .await - .expect("Failed to create execution service"); + let ex: Arc = Arc::new( + CoreExecutionService::new_test_executor( + metastore, + Arc::new(state_store_mock), + Arc::new(Config::default()), + ) + .await + .expect("Failed to create execution service"), + ); - timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.create_session(TEST_SESSION_ID), - ) - .await - .expect("Create session timed out") - .expect("Failed to create session"); + Mocker::create_session(ex.clone(), TEST_SESSION_ID) + .await + .expect("Failed to create session"); - // prepare tables - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - " - CREATE TABLE embucket.public.t1 AS SELECT + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "CREATE TABLE embucket.public.t1 AS SELECT a,b,c FROM (VALUES (1,'b1','c1'), @@ -730,19 +676,15 @@ async fn test_query_lifecycle_ok_merge() { (2,'b3','c3'), (3,'b4','c4') ) AS t(a, b, c);", - query_context.clone(), - ), ) .await - .expect("Query timed out") .expect("Query execution failed"); - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - " - CREATE TABLE embucket.public.t2 AS SELECT + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "CREATE TABLE embucket.public.t2 AS SELECT a,b,c FROM (VALUES (1,'b_5','c_5'), @@ -750,39 +692,32 @@ async fn test_query_lifecycle_ok_merge() { (2,'b_7','c_7'), (4,'b_8','c_8') ) AS t(a, b, c);", - query_context.clone(), - ), ) .await - .expect("Query timed out") .expect("Query execution failed"); - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query(TEST_SESSION_ID, "MERGE INTO t1 USING (SELECT * FROM t2) AS t2 ON t1.a = t2.a WHEN MATCHED THEN UPDATE SET t1.c = t2.c WHEN NOT MATCHED THEN INSERT (a,c) VALUES(t2.a,t2.c)", query_context), + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "MERGE INTO t1 USING + (SELECT * FROM t2) AS t2 + ON t1.a = t2.a + WHEN MATCHED THEN UPDATE SET t1.c = t2.c + WHEN NOT MATCHED THEN INSERT (a,c) VALUES(t2.a,t2.c)", ) .await - .expect("Query timed out") .expect("Query execution failed"); } #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_query_status_incident_limit_exceeded() { - let query_context = QueryContext::new( - Some(TEST_DATABASE.to_string()), - Some(TEST_SCHEMA.to_string()), - None, - ) - .with_request_id(Uuid::default()); - let mut state_store_mock = MockStateStore::new(); - state_store_mock - .expect_put_new_session() - .returning(|_| Ok(())); - state_store_mock - .expect_get_session() - .returning(|_| Ok(SessionRecord::new(TEST_SESSION_ID))); + Mocker::apply_create_session_mock(&mut state_store_mock, |_| { + Ok(SessionRecord::new(TEST_SESSION_ID)) + }); + state_store_mock.expect_put_query() .returning(|_| Ok(()) ) .times(1) @@ -813,42 +748,46 @@ async fn test_query_lifecycle_query_status_incident_limit_exceeded() { true }); - let state_store: Arc = Arc::new(state_store_mock); + let ctx = QueryContext::new( + Some(TEST_DATABASE.to_string()), + Some(TEST_SCHEMA.to_string()), + None, + ) + .with_request_id(Uuid::default()); let metastore = Arc::new(InMemoryMetastore::new()); - let execution_svc = CoreExecutionService::new_test_executor( - metastore, - state_store, - Arc::new(Config::default().with_max_concurrency_level(0)), - ) - .await - .expect("Failed to create execution service"); + MetastoreBootstrapConfig::bootstrap() + .apply(metastore.clone()) + .await + .expect("Failed to bootstrap metastore"); + + let ex: Arc = Arc::new( + CoreExecutionService::new_test_executor( + metastore, + Arc::new(state_store_mock), + Arc::new(Config::default().with_max_concurrency_level(0)), + ) + .await + .expect("Failed to create execution service"), + ); - execution_svc - .create_session(TEST_SESSION_ID) + Mocker::create_session(ex.clone(), TEST_SESSION_ID) .await .expect("Failed to create session"); - // See note about timeout above - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query(TEST_SESSION_ID, "SELECT 1", query_context), - ) - .await - .expect("Query timed out") - .expect_err("Query execution should fail"); + Mocker::query(ex.clone(), TEST_SESSION_ID, ctx.clone(), "SELECT 1") + .await + .expect_err("Query execution should fail"); } #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_query_status_fail() { let mut state_store_mock = MockStateStore::new(); - state_store_mock - .expect_put_new_session() - .returning(|_| Ok(())); - state_store_mock - .expect_get_session() - .returning(|_| Ok(SessionRecord::new(TEST_SESSION_ID))); + Mocker::apply_create_session_mock(&mut state_store_mock, |_| { + Ok(SessionRecord::new(TEST_SESSION_ID)) + }); + state_store_mock .expect_put_query() .times(1) @@ -902,36 +841,35 @@ async fn test_query_lifecycle_query_status_fail() { true }); - let state_store: Arc = Arc::new(state_store_mock); + let ctx = QueryContext::default().with_request_id(Uuid::new_v4()); let metastore = Arc::new(InMemoryMetastore::new()); - let execution_svc = CoreExecutionService::new_test_executor( - metastore, - state_store, - Arc::new(Config::default()), - ) - .await - .expect("Failed to create execution service"); + MetastoreBootstrapConfig::bootstrap() + .apply(metastore.clone()) + .await + .expect("Failed to bootstrap metastore"); - timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.create_session(TEST_SESSION_ID), - ) - .await - .expect("Create session timed out") - .expect("Failed to create session"); + let ex: Arc = Arc::new( + CoreExecutionService::new_test_executor( + metastore, + Arc::new(state_store_mock), + Arc::new(Config::default()), + ) + .await + .expect("Failed to create execution service"), + ); - // See note about timeout above - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.query( - TEST_SESSION_ID, - "SELECT should fail", - QueryContext::default().with_request_id(Uuid::new_v4()), - ), + Mocker::create_session(ex.clone(), TEST_SESSION_ID) + .await + .expect("Failed to create session"); + + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "SELECT should fail", ) .await - .expect("Query timed out") .expect_err("Query execution should fail"); } @@ -939,12 +877,10 @@ async fn test_query_lifecycle_query_status_fail() { #[tokio::test] async fn test_query_lifecycle_query_status_cancelled() { let mut state_store_mock = MockStateStore::new(); - state_store_mock - .expect_put_new_session() - .returning(|_| Ok(())); - state_store_mock - .expect_get_session() - .returning(|_| Ok(SessionRecord::new(TEST_SESSION_ID))); + Mocker::apply_create_session_mock(&mut state_store_mock, |_| { + Ok(SessionRecord::new(TEST_SESSION_ID)) + }); + state_store_mock .expect_put_query() .times(1) @@ -997,43 +933,39 @@ async fn test_query_lifecycle_query_status_cancelled() { true }); - let state_store: Arc = Arc::new(state_store_mock); + let ctx = QueryContext::default().with_request_id(Uuid::default()); let metastore = Arc::new(InMemoryMetastore::new()); - let execution_svc = CoreExecutionService::new_test_executor( - metastore, - state_store, - Arc::new(Config::default()), - ) - .await - .expect("Failed to create execution service"); + MetastoreBootstrapConfig::bootstrap() + .apply(metastore.clone()) + .await + .expect("Failed to bootstrap metastore"); - timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.create_session(TEST_SESSION_ID), - ) - .await - .expect("Create session timed out") - .expect("Failed to create session"); + let ex: Arc = Arc::new( + CoreExecutionService::new_test_executor( + metastore, + Arc::new(state_store_mock), + Arc::new(Config::default()), + ) + .await + .expect("Failed to create execution service"), + ); + + Mocker::create_session(ex.clone(), TEST_SESSION_ID) + .await + .expect("Failed to create session"); // See note about timeout above let query_handle = timeout( MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.submit( - TEST_SESSION_ID, - "SELECT 1", - QueryContext::default().with_request_id(Uuid::new_v4()), - ), + ex.submit(TEST_SESSION_ID, "SELECT 1", ctx), ) .await .expect("Query timed out") .expect("Query submit error"); - let _ = timeout( - MOCK_RELATED_TIMEOUT_DURATION, - execution_svc.abort(query_handle), - ) - .await - .expect("Query timed out") - .expect("Failed to cancel query"); + let _ = timeout(MOCK_RELATED_TIMEOUT_DURATION, ex.abort(query_handle)) + .await + .expect("Query timed out") + .expect("Failed to cancel query"); } From 03a37049486459d7a075279101cf05f3406f060a Mon Sep 17 00:00:00 2001 From: Yaroslav Litvinov Date: Fri, 9 Jan 2026 15:11:34 +0100 Subject: [PATCH 3/5] state-store-persist-session-oncreate --- crates/embucket-lambda/Cargo.toml | 1 + crates/embucketd/Cargo.toml | 2 ++ crates/executor/Cargo.toml | 4 ++++ crates/executor/src/service.rs | 10 +++++----- 4 files changed, 12 insertions(+), 5 deletions(-) diff --git a/crates/embucket-lambda/Cargo.toml b/crates/embucket-lambda/Cargo.toml index e16d6369..ebb02988 100644 --- a/crates/embucket-lambda/Cargo.toml +++ b/crates/embucket-lambda/Cargo.toml @@ -45,6 +45,7 @@ retry-disable = ["api-snowflake-rest/retry-disable"] streaming = [] rest-catalog = ["executor/rest-catalog"] dedicated-executor = ["executor/dedicated-executor"] +state-store-persist-session-oncreate = ["executor/state-store-persist-session-oncreate"] state-store-query = ["executor/state-store-query"] [package.metadata.lambda] diff --git a/crates/embucketd/Cargo.toml b/crates/embucketd/Cargo.toml index 83a74b5a..7be0dba7 100644 --- a/crates/embucketd/Cargo.toml +++ b/crates/embucketd/Cargo.toml @@ -52,6 +52,8 @@ retry-disable = ["api-snowflake-rest/retry-disable"] rest-catalog = ["executor/rest-catalog"] dedicated-executor = ["executor/dedicated-executor"] state-store = ["executor/state-store"] +state-store-persist-session-oncreate = ["executor/state-store-persist-session-oncreate"] state-store-query = ["executor/state-store-query"] state-store-query-test = ["executor/state-store-query-test"] + diff --git a/crates/executor/Cargo.toml b/crates/executor/Cargo.toml index 93aa42b3..357241a2 100644 --- a/crates/executor/Cargo.toml +++ b/crates/executor/Cargo.toml @@ -12,6 +12,10 @@ dedicated-executor = [] # "state-store" feature enables DynamoDB based state-store implementation. state-store = [] +# "state-store-persist-session-oncreate" feature allows to persist empty session record. +# Otherwise it will be persisted only if session has something to persist. +state-store-persist-session-oncreate = ["state-store"] + state-store-query-test = ["state-store-query"] # "state-store-query" feature depends on state-store feature. diff --git a/crates/executor/src/service.rs b/crates/executor/src/service.rs index f824a25d..d9db95ff 100644 --- a/crates/executor/src/service.rs +++ b/crates/executor/src/service.rs @@ -321,11 +321,11 @@ impl ExecutionService for CoreExecutionService { tracing::trace!("Acquired write lock for df_sessions"); sessions.insert(session_id.to_string(), user_session.clone()); - // #[cfg(feature = "state-store")] - // self.state_store - // .put_new_session(session_id) - // .await - // .context(ex_error::StateStoreSnafu)?; + #[cfg(feature = "state-store-persist-session-oncreate")] + self.state_store + .put_new_session(session_id) + .await + .context(ex_error::StateStoreSnafu)?; // Record the result as part of the current span. tracing::Span::current().record("new_sessions_count", sessions.len()); From 5ea4ddfa397e7f42dc7e2d3f0cdcc3bd119ad3e3 Mon Sep 17 00:00:00 2001 From: Yaroslav Litvinov Date: Fri, 9 Jan 2026 18:59:06 +0100 Subject: [PATCH 4/5] add explain/analyze query_type, fix SessionMetadata --- crates/api-snowflake-rest/src/server/logic.rs | 3 +- crates/executor/src/query.rs | 6 + crates/executor/src/query_types.rs | 3 + .../src/tests/statestore_queries_unittest.rs | 114 +++++++++++++++++- 4 files changed, 120 insertions(+), 6 deletions(-) diff --git a/crates/api-snowflake-rest/src/server/logic.rs b/crates/api-snowflake-rest/src/server/logic.rs index 7fddc8d7..8316f21f 100644 --- a/crates/api-snowflake-rest/src/server/logic.rs +++ b/crates/api-snowflake-rest/src/server/logic.rs @@ -121,7 +121,8 @@ pub async fn handle_query_request( None, ) .with_request_id(query.request_id) - .with_query_submission_time(query_submission_time); + .with_query_submission_time(query_submission_time) + .with_session_metadata(Some(session_metadata)); if let Some(ip) = client_ip { query_context = query_context.with_ip_address(ip); diff --git a/crates/executor/src/query.rs b/crates/executor/src/query.rs index e755163c..e789ff73 100644 --- a/crates/executor/src/query.rs +++ b/crates/executor/src/query.rs @@ -312,10 +312,16 @@ impl UserQuery { Statement::ShowTables { .. } => save(QueryType::Misc(MiscStType::ShowTables)), Statement::ShowViews { .. } => save(QueryType::Misc(MiscStType::ShowViews)), Statement::ExplainTable { .. } => save(QueryType::Misc(MiscStType::ExplainTable)), + Statement::Explain { .. } => save(QueryType::Misc(MiscStType::Explain)), + Statement::Analyze { .. } => save(QueryType::Misc(MiscStType::Analyze)), _ => {} } } else if let DFStatement::CreateExternalTable(..) = statement { save(QueryType::Ddl(DdlStType::CreateExternalTable)); + } else if let DFStatement::Explain(..) = statement { + save(QueryType::Misc(MiscStType::Explain)); + } else if let DFStatement::CopyTo(..) = statement { + save(QueryType::Misc(MiscStType::CopyTo)); } } diff --git a/crates/executor/src/query_types.rs b/crates/executor/src/query_types.rs index 658abccb..67f524b7 100644 --- a/crates/executor/src/query_types.rs +++ b/crates/executor/src/query_types.rs @@ -68,6 +68,9 @@ pub enum MiscStType { ShowTables, ShowViews, ExplainTable, + Explain, + Analyze, + CopyTo, } #[derive(Debug, Clone)] diff --git a/crates/executor/src/tests/statestore_queries_unittest.rs b/crates/executor/src/tests/statestore_queries_unittest.rs index 954b57a0..0fc612fe 100644 --- a/crates/executor/src/tests/statestore_queries_unittest.rs +++ b/crates/executor/src/tests/statestore_queries_unittest.rs @@ -1,13 +1,13 @@ -use crate::QueryResult; use crate::error::Result; use crate::models::QueryContext; use crate::service::{CoreExecutionService, ExecutionService}; use crate::session::UserSession; use crate::utils::Config; +use crate::{QueryResult, SessionMetadata}; use catalog_metastore::InMemoryMetastore; use catalog_metastore::metastore_bootstrap_config::MetastoreBootstrapConfig; use insta::assert_json_snapshot; -use state_store::{MockStateStore, Query, SessionRecord, StateStore}; +use state_store::{MockStateStore, Query, SessionRecord}; use std::sync::Arc; use tokio::time::{Duration, timeout}; use uuid::Uuid; @@ -137,7 +137,9 @@ async fn test_query_lifecycle_ok_query() { "query_hash": "1717924485430328356", "query_hash_version": 1, "user_database_name": "test_database", - "user_schema_name": "test_schema" + "user_schema_name": "test_schema", + "client_app_id": "client_app_id", + "client_app_version": "1.0.0" } "#); }); @@ -170,19 +172,31 @@ async fn test_query_lifecycle_ok_query() { "query_hash_version": 1, "user_database_name": "test_database", "user_schema_name": "test_schema", - "query_metrics": "[query_metrics]" + "query_metrics": "[query_metrics]", + "client_app_id": "client_app_id", + "client_app_version": "1.0.0" } "#); }); true }); + let mut session_metadata = SessionMetadata::default(); + session_metadata.set_attr( + crate::SessionMetadataAttr::ClientAppId, + "client_app_id".to_string(), + ); + session_metadata.set_attr( + crate::SessionMetadataAttr::ClientAppVersion, + "1.0.0".to_string(), + ); let ctx = QueryContext::new( Some("test_database".to_string()), Some("test_schema".to_string()), None, ) - .with_request_id(Uuid::default()); + .with_request_id(Uuid::default()) + .with_session_metadata(Some(session_metadata)); let metastore = Arc::new(InMemoryMetastore::new()); MetastoreBootstrapConfig::bootstrap() @@ -232,6 +246,96 @@ async fn test_query_lifecycle_ok_query() { .expect("Query execution failed"); } +#[allow(clippy::expect_used)] +#[tokio::test] +async fn test_query_lifecycle_explain_query() { + let mut state_store_mock = MockStateStore::new(); + Mocker::apply_create_session_mock(&mut state_store_mock, |_| { + Ok(SessionRecord::new(TEST_SESSION_ID)) + }); + Mocker::apply_bypass_put_queries_only_mock(&mut state_store_mock, 1); + + state_store_mock + .expect_update_query() + .times(1) + .returning(|_| Ok(())) + .withf(move |query: &Query| { + insta_settings("explain_query_update").bind(|| { + assert_json_snapshot!(query, @r#" + { + "query_id": "00000000-0000-0000-0000-000000000000", + "request_id": "00000000-0000-0000-0000-000000000000", + "query_text": "EXPLAIN SELECT 1 AS a, 2.0 AS b, '3' AS 'c'", + "session_id": "test_session_id", + "database_name": "test_database", + "schema_name": "test_schema", + "query_type": "EXPLAIN", + "warehouse_type": "DEFAULT", + "execution_status": "Success", + "start_time": "2026-01-01T01:01:01.000000001Z", + "end_time": "2026-01-01T01:01:01.000000001Z", + "execution_time": "1", + "release_version": "test-version", + "query_hash": "1265703338911562377", + "query_hash_version": 1, + "user_database_name": "test_database", + "user_schema_name": "test_schema", + "query_metrics": "[query_metrics]", + "client_app_id": "client_app_id", + "client_app_version": "1.0.0" + } + "#); + }); + true + }); + + let mut session_metadata = SessionMetadata::default(); + session_metadata.set_attr( + crate::SessionMetadataAttr::ClientAppId, + "client_app_id".to_string(), + ); + session_metadata.set_attr( + crate::SessionMetadataAttr::ClientAppVersion, + "1.0.0".to_string(), + ); + let ctx = QueryContext::new( + Some("test_database".to_string()), + Some("test_schema".to_string()), + None, + ) + .with_request_id(Uuid::default()) + .with_session_metadata(Some(session_metadata)); + + let metastore = Arc::new(InMemoryMetastore::new()); + MetastoreBootstrapConfig::bootstrap() + .apply(metastore.clone()) + .await + .expect("Failed to bootstrap metastore"); + + let ex: Arc = Arc::new( + CoreExecutionService::new_test_executor( + metastore, + Arc::new(state_store_mock), + Arc::new(Config::default()), + ) + .await + .expect("Failed to create execution service"), + ); + + Mocker::create_session(ex.clone(), TEST_SESSION_ID) + .await + .expect("Failed to create session"); + + Mocker::query( + ex.clone(), + TEST_SESSION_ID, + ctx.clone(), + "EXPLAIN SELECT 1 AS a, 2.0 AS b, '3' AS 'c'", + ) + .await + .expect("Query execution failed"); +} + #[allow(clippy::expect_used)] #[tokio::test] async fn test_query_lifecycle_ok_insert() { From 661c6b3612b5530b5c416f68818b9ce15d553665 Mon Sep 17 00:00:00 2001 From: Yaroslav Litvinov Date: Fri, 9 Jan 2026 20:13:52 +0100 Subject: [PATCH 5/5] change storage datatype of query_submittion_type to DataTime --- crates/executor/src/tests/statestore_queries_unittest.rs | 6 +++++- crates/state-store/src/models.rs | 7 +++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/crates/executor/src/tests/statestore_queries_unittest.rs b/crates/executor/src/tests/statestore_queries_unittest.rs index 0fc612fe..48187b1f 100644 --- a/crates/executor/src/tests/statestore_queries_unittest.rs +++ b/crates/executor/src/tests/statestore_queries_unittest.rs @@ -46,6 +46,10 @@ fn insta_settings(name: &str) -> insta::Settings { r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{6}Z", "2026-01-01T01:01:01.000001Z", ); + settings.add_filter( + r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2026-01-01T01:01:01.001Z", + ); settings } @@ -373,7 +377,7 @@ async fn test_query_lifecycle_ok_insert() { "user_database_name": "test_database", "user_schema_name": "test_schema", "query_metrics": "[query_metrics]", - "query_submission_time": 1764161275445 + "query_submission_time": "2026-01-01T01:01:01.001Z" } "#); }); diff --git a/crates/state-store/src/models.rs b/crates/state-store/src/models.rs index a6a9db40..c57cc3d5 100644 --- a/crates/state-store/src/models.rs +++ b/crates/state-store/src/models.rs @@ -289,7 +289,7 @@ pub struct Query { #[serde(default, skip_serializing_if = "Option::is_none")] pub client_app_version: Option, #[serde(default, skip_serializing_if = "Option::is_none")] - pub query_submission_time: Option, + pub query_submission_time: Option>, } impl Query { @@ -389,8 +389,11 @@ impl Query { self.client_app_version = Some(client_app_version); } + #[allow(clippy::cast_possible_wrap, clippy::as_conversions)] pub const fn set_query_submission_time(&mut self, query_submission_time: u64) { - self.query_submission_time = Some(query_submission_time); + // Convert u64 timestamp to DateTime + let dt = DateTime::::from_timestamp_millis(query_submission_time as i64); + self.query_submission_time = dt; } #[allow(clippy::as_conversions, clippy::cast_sign_loss)]