Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion packages/cubejs-backend-native/js/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export interface BaseMeta {
apiType: string,
// Application name, for example Metabase
appName?: string,
// Database name from the client startup message (e.g. psql dbname parameter)
database?: string,
}

export interface LoadRequestMeta extends BaseMeta {
Expand Down Expand Up @@ -52,7 +54,7 @@ export interface CheckAuthPayload {
}

export interface CheckSQLAuthPayload {
request: Request<undefined>,
request: Request<BaseMeta | undefined>,
user: string | null,
password: string | null,
}
Expand Down
14 changes: 13 additions & 1 deletion packages/cubejs-backend-native/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,22 @@ impl SqlAuthService for NodeBridgeAuthService {

let request_id = Uuid::new_v4().to_string();

let meta = if request.database.is_some() {
let mut m = LoadRequestMeta::new(
"postgres".to_string(),
"sql".to_string(),
None,
);
m.set_database(request.database.clone());
Some(m)
} else {
None
};

let extra = serde_json::to_string(&CheckSQLAuthTransportRequest {
request: TransportAuthRequest {
id: format!("{}-span-1", request_id),
meta: None,
meta,
protocol: request.protocol,
method: request.method,
},
Expand Down
12 changes: 10 additions & 2 deletions packages/cubejs-backend-native/test/sql.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ describe('SQLInterface', () => {
expect(checkSqlAuth.mock.calls[0][0]).toEqual({
request: {
id: expect.any(String),
meta: null,
meta: {
protocol: 'postgres',
apiType: 'sql',
database: 'test',
},
method: expect.any(String),
protocol: expect.any(String),
},
Expand Down Expand Up @@ -241,7 +245,11 @@ describe('SQLInterface', () => {
expect(checkSqlAuth.mock.calls[0][0]).toEqual({
request: {
id: expect.any(String),
meta: null,
meta: {
protocol: 'postgres',
apiType: 'sql',
database: 'test',
},
method: expect.any(String),
protocol: expect.any(String),
},
Expand Down
2 changes: 2 additions & 0 deletions rust/cubesql/cubesql/src/compile/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,7 @@ impl QueryRouter {
let sql_auth_request = SqlAuthServiceAuthenticateRequest {
protocol: "postgres".to_string(),
method: "password".to_string(),
database: self.state.database(),
};
let authenticate_response = self
.session_manager
Expand Down Expand Up @@ -647,6 +648,7 @@ impl QueryRouter {
let sql_auth_request = SqlAuthServiceAuthenticateRequest {
protocol: "postgres".to_string(),
method: "password".to_string(),
database: self.state.database(),
};
let authenticate_response = self
.session_manager
Expand Down
2 changes: 2 additions & 0 deletions rust/cubesql/cubesql/src/compile/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub mod test_cube_join_grouped;
#[cfg(test)]
pub mod test_cube_scan;
#[cfg(test)]
pub mod test_database_meta;
#[cfg(test)]
pub mod test_df_execution;
#[cfg(test)]
pub mod test_filters;
Expand Down
115 changes: 115 additions & 0 deletions rust/cubesql/cubesql/src/compile/test/test_database_meta.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
//! Tests that check database name propagation through LoadRequestMeta

use pretty_assertions::assert_eq;

use crate::compile::{
test::{init_testing_logger, TestContext},
DatabaseProtocol, Rewriter,
};
use crate::transport::LoadRequestMeta;

#[tokio::test]
async fn test_database_propagates_through_load_request_meta() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_testing_logger();

let context = TestContext::new(DatabaseProtocol::PostgreSQL).await;

context
.execute_query(
// language=PostgreSQL
r#"
SELECT
COALESCE(customer_gender, 'N/A'),
AVG(avgPrice)
FROM
KibanaSampleDataEcommerce
WHERE
LOWER(customer_gender) = 'test'
GROUP BY 1
;
"#
.to_string(),
)
.await
.expect_err("Test transport does not support load with SQL");

let load_calls = context.load_calls().await;
assert_eq!(load_calls.len(), 1);
assert_eq!(load_calls[0].meta.database(), Some("cubedb".to_string()));
}

#[test]
fn test_load_request_meta_database_serialization() {
let mut meta = LoadRequestMeta::new(
"postgres".to_string(),
"sql".to_string(),
Some("test-app".to_string()),
);

let json = serde_json::to_value(&meta).unwrap();
assert!(json.get("database").is_none());

meta.set_database(Some("mydb".to_string()));
let json = serde_json::to_value(&meta).unwrap();
assert_eq!(json["database"], "mydb");
assert_eq!(meta.database(), Some("mydb".to_string()));
}

Check warning on line 60 in rust/cubesql/cubesql/src/compile/test/test_database_meta.rs

View workflow job for this annotation

GitHub Actions / Check fmt/clippy

Diff in /__w/cube/cube/rust/cubesql/cubesql/src/compile/test/test_database_meta.rs

Check warning on line 60 in rust/cubesql/cubesql/src/compile/test/test_database_meta.rs

View workflow job for this annotation

GitHub Actions / lint

Diff in /home/runner/work/cube/cube/rust/cubesql/cubesql/src/compile/test/test_database_meta.rs
#[test]
fn test_load_request_meta_no_database_by_default() {
let meta = LoadRequestMeta::new(
"postgres".to_string(),
"sql".to_string(),
None,
);

assert_eq!(meta.database(), None);

let json = serde_json::to_value(&meta).unwrap();
assert!(json.get("database").is_none());
}

/// Verifies that database name from session state propagates into the SQL query
/// meta passed to TestConnectionTransport::sql. Follows the same pattern as
/// test_user_change::test_user_change_sql_generation.
#[tokio::test]
async fn test_database_in_sql_query_meta() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_testing_logger();

let context = TestContext::new(DatabaseProtocol::PostgreSQL).await;

context
.execute_query(
// language=PostgreSQL
r#"
SELECT
COALESCE(customer_gender, 'N/A'),
AVG(avgPrice)
FROM
KibanaSampleDataEcommerce
WHERE
LOWER(customer_gender) = 'test'
GROUP BY 1
;
"#
.to_string(),
)
.await
.expect_err("Test transport does not support load with SQL");

let load_calls = context.load_calls().await;
assert_eq!(load_calls.len(), 1);

// Database should appear in the serialized SQL query (set by TestConnectionTransport::sql)
let sql_query = load_calls[0].sql_query.as_ref().unwrap();
assert!(sql_query.sql.contains(r#""database": "cubedb""#));

// And directly on the meta object
assert_eq!(load_calls[0].meta.database(), Some("cubedb".to_string()));
}
2 changes: 2 additions & 0 deletions rust/cubesql/cubesql/src/sql/auth_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ pub struct AuthenticateResponse {
pub struct SqlAuthServiceAuthenticateRequest {
pub protocol: String,
pub method: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub database: Option<String>,
}

#[async_trait]
Expand Down
1 change: 1 addition & 0 deletions rust/cubesql/cubesql/src/sql/postgres/pg_auth_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ impl PostgresAuthService for PostgresAuthServiceDefaultImpl {
let sql_auth_request = SqlAuthServiceAuthenticateRequest {
protocol: "postgres".to_string(),
method: "password".to_string(),
database: parameters.get("database").cloned(),
};
let authenticate_response = service
.authenticate(
Expand Down
6 changes: 4 additions & 2 deletions rust/cubesql/cubesql/src/sql/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,11 +409,13 @@ impl SessionState {
None
};

LoadRequestMeta::new(
let mut meta = LoadRequestMeta::new(
self.protocol.get_name().to_string(),
api_type.to_string(),
application_name,
)
);
meta.set_database(self.database());
meta
}
}

Expand Down
11 changes: 11 additions & 0 deletions rust/cubesql/cubesql/src/transport/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub struct LoadRequestMeta {
// Optional fields
#[serde(rename = "changeUser", skip_serializing_if = "Option::is_none")]
change_user: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
database: Option<String>,
}

impl LoadRequestMeta {
Expand All @@ -61,6 +63,7 @@ impl LoadRequestMeta {
api_type,
app_name,
change_user: None,
database: None,
}
}

Expand All @@ -71,6 +74,14 @@ impl LoadRequestMeta {
pub fn set_change_user(&mut self, change_user: Option<String>) {
self.change_user = change_user;
}

pub fn database(&self) -> Option<String> {
self.database.clone()
}

pub fn set_database(&mut self, database: Option<String>) {
self.database = database;
}
}

#[derive(Debug, Deserialize)]
Expand Down
Loading