diff --git a/packages/cubejs-backend-native/js/index.ts b/packages/cubejs-backend-native/js/index.ts index b5dbf3cdc4ad2..24e2c26cdd532 100644 --- a/packages/cubejs-backend-native/js/index.ts +++ b/packages/cubejs-backend-native/js/index.ts @@ -15,6 +15,8 @@ export interface BaseMeta { apiType: string, // Application name, for example Metabase appName?: string, + // Database name from the client startup message (e.g. psql dbname parameter) + database?: string, } export interface LoadRequestMeta extends BaseMeta { @@ -52,7 +54,7 @@ export interface CheckAuthPayload { } export interface CheckSQLAuthPayload { - request: Request, + request: Request, user: string | null, password: string | null, } diff --git a/packages/cubejs-backend-native/src/auth.rs b/packages/cubejs-backend-native/src/auth.rs index 73b2131b2ce2d..4b77a2225033a 100644 --- a/packages/cubejs-backend-native/src/auth.rs +++ b/packages/cubejs-backend-native/src/auth.rs @@ -120,10 +120,22 @@ impl SqlAuthService for NodeBridgeAuthService { let request_id = Uuid::new_v4().to_string(); + let meta = if request.database.is_some() { + let mut m = LoadRequestMeta::new( + "postgres".to_string(), + "sql".to_string(), + None, + ); + m.set_database(request.database.clone()); + Some(m) + } else { + None + }; + let extra = serde_json::to_string(&CheckSQLAuthTransportRequest { request: TransportAuthRequest { id: format!("{}-span-1", request_id), - meta: None, + meta, protocol: request.protocol, method: request.method, }, diff --git a/packages/cubejs-backend-native/test/sql.test.ts b/packages/cubejs-backend-native/test/sql.test.ts index 943383ecb5cca..656e6eed0de4e 100644 --- a/packages/cubejs-backend-native/test/sql.test.ts +++ b/packages/cubejs-backend-native/test/sql.test.ts @@ -182,7 +182,11 @@ describe('SQLInterface', () => { expect(checkSqlAuth.mock.calls[0][0]).toEqual({ request: { id: expect.any(String), - meta: null, + meta: { + protocol: 'postgres', + apiType: 'sql', + database: 'test', + }, method: expect.any(String), protocol: expect.any(String), }, @@ -241,7 +245,11 @@ describe('SQLInterface', () => { expect(checkSqlAuth.mock.calls[0][0]).toEqual({ request: { id: expect.any(String), - meta: null, + meta: { + protocol: 'postgres', + apiType: 'sql', + database: 'test', + }, method: expect.any(String), protocol: expect.any(String), }, diff --git a/rust/cubesql/cubesql/src/compile/router.rs b/rust/cubesql/cubesql/src/compile/router.rs index 1761a900b3b30..1d841aa7fcb67 100644 --- a/rust/cubesql/cubesql/src/compile/router.rs +++ b/rust/cubesql/cubesql/src/compile/router.rs @@ -503,6 +503,7 @@ impl QueryRouter { let sql_auth_request = SqlAuthServiceAuthenticateRequest { protocol: "postgres".to_string(), method: "password".to_string(), + database: self.state.database(), }; let authenticate_response = self .session_manager @@ -647,6 +648,7 @@ impl QueryRouter { let sql_auth_request = SqlAuthServiceAuthenticateRequest { protocol: "postgres".to_string(), method: "password".to_string(), + database: self.state.database(), }; let authenticate_response = self .session_manager diff --git a/rust/cubesql/cubesql/src/compile/test/mod.rs b/rust/cubesql/cubesql/src/compile/test/mod.rs index 20e63e584b5f8..ba5d8dc09639e 100644 --- a/rust/cubesql/cubesql/src/compile/test/mod.rs +++ b/rust/cubesql/cubesql/src/compile/test/mod.rs @@ -35,6 +35,8 @@ pub mod test_cube_join_grouped; #[cfg(test)] pub mod test_cube_scan; #[cfg(test)] +pub mod test_database_meta; +#[cfg(test)] pub mod test_df_execution; #[cfg(test)] pub mod test_filters; diff --git a/rust/cubesql/cubesql/src/compile/test/test_database_meta.rs b/rust/cubesql/cubesql/src/compile/test/test_database_meta.rs new file mode 100644 index 0000000000000..d22d79c15eb89 --- /dev/null +++ b/rust/cubesql/cubesql/src/compile/test/test_database_meta.rs @@ -0,0 +1,115 @@ +//! Tests that check database name propagation through LoadRequestMeta + +use pretty_assertions::assert_eq; + +use crate::compile::{ + test::{init_testing_logger, TestContext}, + DatabaseProtocol, Rewriter, +}; +use crate::transport::LoadRequestMeta; + +#[tokio::test] +async fn test_database_propagates_through_load_request_meta() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let context = TestContext::new(DatabaseProtocol::PostgreSQL).await; + + context + .execute_query( + // language=PostgreSQL + r#" +SELECT + COALESCE(customer_gender, 'N/A'), + AVG(avgPrice) +FROM + KibanaSampleDataEcommerce +WHERE + LOWER(customer_gender) = 'test' +GROUP BY 1 +; + "# + .to_string(), + ) + .await + .expect_err("Test transport does not support load with SQL"); + + let load_calls = context.load_calls().await; + assert_eq!(load_calls.len(), 1); + assert_eq!(load_calls[0].meta.database(), Some("cubedb".to_string())); +} + +#[test] +fn test_load_request_meta_database_serialization() { + let mut meta = LoadRequestMeta::new( + "postgres".to_string(), + "sql".to_string(), + Some("test-app".to_string()), + ); + + let json = serde_json::to_value(&meta).unwrap(); + assert!(json.get("database").is_none()); + + meta.set_database(Some("mydb".to_string())); + let json = serde_json::to_value(&meta).unwrap(); + assert_eq!(json["database"], "mydb"); + assert_eq!(meta.database(), Some("mydb".to_string())); +} + +#[test] +fn test_load_request_meta_no_database_by_default() { + let meta = LoadRequestMeta::new( + "postgres".to_string(), + "sql".to_string(), + None, + ); + + assert_eq!(meta.database(), None); + + let json = serde_json::to_value(&meta).unwrap(); + assert!(json.get("database").is_none()); +} + +/// Verifies that database name from session state propagates into the SQL query +/// meta passed to TestConnectionTransport::sql. Follows the same pattern as +/// test_user_change::test_user_change_sql_generation. +#[tokio::test] +async fn test_database_in_sql_query_meta() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let context = TestContext::new(DatabaseProtocol::PostgreSQL).await; + + context + .execute_query( + // language=PostgreSQL + r#" +SELECT + COALESCE(customer_gender, 'N/A'), + AVG(avgPrice) +FROM + KibanaSampleDataEcommerce +WHERE + LOWER(customer_gender) = 'test' +GROUP BY 1 +; + "# + .to_string(), + ) + .await + .expect_err("Test transport does not support load with SQL"); + + let load_calls = context.load_calls().await; + assert_eq!(load_calls.len(), 1); + + // Database should appear in the serialized SQL query (set by TestConnectionTransport::sql) + let sql_query = load_calls[0].sql_query.as_ref().unwrap(); + assert!(sql_query.sql.contains(r#""database": "cubedb""#)); + + // And directly on the meta object + assert_eq!(load_calls[0].meta.database(), Some("cubedb".to_string())); +} diff --git a/rust/cubesql/cubesql/src/sql/auth_service.rs b/rust/cubesql/cubesql/src/sql/auth_service.rs index 3550716db9759..aa919098ffea6 100644 --- a/rust/cubesql/cubesql/src/sql/auth_service.rs +++ b/rust/cubesql/cubesql/src/sql/auth_service.rs @@ -48,6 +48,8 @@ pub struct AuthenticateResponse { pub struct SqlAuthServiceAuthenticateRequest { pub protocol: String, pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub database: Option, } #[async_trait] diff --git a/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs b/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs index 240046fd50da1..3a955f50a2ff7 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs @@ -78,6 +78,7 @@ impl PostgresAuthService for PostgresAuthServiceDefaultImpl { let sql_auth_request = SqlAuthServiceAuthenticateRequest { protocol: "postgres".to_string(), method: "password".to_string(), + database: parameters.get("database").cloned(), }; let authenticate_response = service .authenticate( diff --git a/rust/cubesql/cubesql/src/sql/session.rs b/rust/cubesql/cubesql/src/sql/session.rs index 610d4c77c5f19..eb2c3167bc70d 100644 --- a/rust/cubesql/cubesql/src/sql/session.rs +++ b/rust/cubesql/cubesql/src/sql/session.rs @@ -409,11 +409,13 @@ impl SessionState { None }; - LoadRequestMeta::new( + let mut meta = LoadRequestMeta::new( self.protocol.get_name().to_string(), api_type.to_string(), application_name, - ) + ); + meta.set_database(self.database()); + meta } } diff --git a/rust/cubesql/cubesql/src/transport/service.rs b/rust/cubesql/cubesql/src/transport/service.rs index 23ef71ce3e7b2..532410b527ae8 100644 --- a/rust/cubesql/cubesql/src/transport/service.rs +++ b/rust/cubesql/cubesql/src/transport/service.rs @@ -51,6 +51,8 @@ pub struct LoadRequestMeta { // Optional fields #[serde(rename = "changeUser", skip_serializing_if = "Option::is_none")] change_user: Option, + #[serde(skip_serializing_if = "Option::is_none")] + database: Option, } impl LoadRequestMeta { @@ -61,6 +63,7 @@ impl LoadRequestMeta { api_type, app_name, change_user: None, + database: None, } } @@ -71,6 +74,14 @@ impl LoadRequestMeta { pub fn set_change_user(&mut self, change_user: Option) { self.change_user = change_user; } + + pub fn database(&self) -> Option { + self.database.clone() + } + + pub fn set_database(&mut self, database: Option) { + self.database = database; + } } #[derive(Debug, Deserialize)]