From fd1ac68dfd2b16582f80db6f4896d14823340c60 Mon Sep 17 00:00:00 2001 From: Carson Date: Thu, 5 Feb 2026 14:40:05 -0600 Subject: [PATCH] feat(rest): add SQL endpoint Add /api/v1/sql endpoint for executing plain SQL queries (no VISUALISE clause). Returns rows and columns as JSON, enabling direct database access for data exploration. Features: - Execute arbitrary SQL queries against the database - Returns JSON with rows, columns, rowCount, and truncated flag - Configurable row limit via --sql-max-rows (default: 10000) - Results truncated (not errored) when limit exceeded - Proper date/datetime serialization matching VegaLite writer Co-Authored-By: Claude Opus 4.6 --- Cargo.toml | 1 + src/Cargo.toml | 1 + src/rest.rs | 585 ++++++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 583 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 14aeac86..79591321 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,7 @@ uuid = { version = "1.0", features = ["v4"] } # Web server axum = "0.7" tokio = { version = "1.35", features = ["full"] } +tower = "0.5" tower-http = { version = "0.5", features = ["cors", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/src/Cargo.toml b/src/Cargo.toml index 41010871..e64156a4 100644 --- a/src/Cargo.toml +++ b/src/Cargo.toml @@ -68,6 +68,7 @@ pyo3 = { workspace = true, optional = true } [dev-dependencies] proptest.workspace = true +tower = { workspace = true } [features] default = ["duckdb", "sqlite", "vegalite"] diff --git a/src/rest.rs b/src/rest.rs index 8f2338c4..cb335607 100644 --- a/src/rest.rs +++ b/src/rest.rs @@ -11,7 +11,8 @@ ggsql-rest --host 127.0.0.1 --port 3000 ## Endpoints -- `POST /api/v1/query` - Execute a ggsql query +- `POST /api/v1/query` - Execute a ggsql query with VISUALISE (returns Vega-Lite spec) +- `POST /api/v1/sql` - Execute plain SQL query (returns rows and columns) - `POST /api/v1/parse` - Parse a ggsql query (debugging) - `GET /api/v1/health` - Health check - `GET /api/v1/version` - Version information @@ -66,6 +67,10 @@ struct Cli { /// Example: --load-data data.csv --load-data other.parquet #[arg(long = "load-data")] load_data_files: Vec, + + /// Maximum rows returned by /api/v1/sql endpoint (0 = unlimited) + #[arg(long, default_value = "10000")] + sql_max_rows: usize, } /// Shared application state @@ -75,6 +80,9 @@ struct AppState { /// Wrapped in Arc since DuckDB Connection is not Sync #[cfg(feature = "duckdb")] reader: Option>>, + + /// Maximum rows returned by SQL endpoint (0 = unlimited) + sql_max_rows: usize, } // ============================================================================ @@ -109,6 +117,13 @@ struct ParseRequest { query: String, } +/// Request body for /api/v1/sql endpoint +#[derive(Debug, Deserialize)] +struct SqlRequest { + /// SQL query to execute + query: String, +} + /// Successful API response #[derive(Debug, Serialize)] struct ApiSuccess { @@ -155,6 +170,20 @@ struct ParseResult { specs: Vec, } +/// SQL execution result data +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct SqlResult { + /// Array of row objects + rows: Vec, + /// Column names + columns: Vec, + /// Total row count before truncation + row_count: usize, + /// Whether results were truncated due to row limit + truncated: bool, +} + /// Health check response #[derive(Debug, Serialize)] struct HealthResponse { @@ -426,6 +455,51 @@ fn load_sample_data(reader: &DuckDBReader) -> Result<(), GgsqlError> { Ok(()) } +/// Convert a single value from a Polars Column to JSON +#[cfg(feature = "duckdb")] +fn column_value_to_json(column: &polars::prelude::Column, idx: usize) -> serde_json::Value { + use polars::prelude::AnyValue; + + let any_value = match column.get(idx) { + Ok(v) => v, + Err(_) => return serde_json::Value::Null, + }; + + match any_value { + AnyValue::Null => serde_json::Value::Null, + AnyValue::Boolean(b) => serde_json::Value::Bool(b), + AnyValue::Int8(v) => serde_json::Value::Number(v.into()), + AnyValue::Int16(v) => serde_json::Value::Number(v.into()), + AnyValue::Int32(v) => serde_json::Value::Number(v.into()), + AnyValue::Int64(v) => serde_json::Value::Number(v.into()), + AnyValue::UInt8(v) => serde_json::Value::Number(v.into()), + AnyValue::UInt16(v) => serde_json::Value::Number(v.into()), + AnyValue::UInt32(v) => serde_json::Value::Number(v.into()), + AnyValue::UInt64(v) => serde_json::Value::Number(v.into()), + AnyValue::Float32(v) => serde_json::Number::from_f64(v as f64) + .map(serde_json::Value::Number) + .unwrap_or(serde_json::Value::Null), + AnyValue::Float64(v) => serde_json::Number::from_f64(v) + .map(serde_json::Value::Number) + .unwrap_or(serde_json::Value::Null), + AnyValue::String(s) => serde_json::Value::String(s.to_string()), + AnyValue::StringOwned(s) => serde_json::Value::String(s.to_string()), + AnyValue::Date(days) => { + let unix_epoch = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + let date = unix_epoch + chrono::Duration::days(days as i64); + serde_json::Value::String(date.format("%Y-%m-%d").to_string()) + } + AnyValue::Datetime(us, _, _) => { + let dt = chrono::DateTime::from_timestamp_micros(us).unwrap_or_default(); + serde_json::Value::String(dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()) + } + other => { + tracing::debug!("Converting unsupported Polars type to string: {:?}", other); + serde_json::Value::String(format!("{}", other)) + } + } +} + // ============================================================================ // Handler Functions // ============================================================================ @@ -561,6 +635,76 @@ async fn parse_handler( })) } +/// POST /api/v1/sql - Execute plain SQL query (no visualization) +#[cfg(feature = "duckdb")] +async fn sql_handler( + State(state): State, + Json(request): Json, +) -> Result>, ApiErrorResponse> { + info!("Executing SQL: {} chars", request.query.len()); + + let df = if let Some(ref reader_mutex) = state.reader { + let reader = reader_mutex.lock().map_err(|e| { + GgsqlError::InternalError(format!( + "Database connection unavailable (mutex poisoned): {}", + e + )) + })?; + reader.execute_sql(&request.query)? + } else { + let reader = DuckDBReader::from_connection_string("duckdb://memory")?; + reader.execute_sql(&request.query)? + }; + + let columns: Vec = df + .get_column_names() + .iter() + .map(|s| s.to_string()) + .collect(); + + let (total_rows, _) = df.shape(); + let (rows_to_process, truncated) = if state.sql_max_rows > 0 && total_rows > state.sql_max_rows + { + info!( + "Truncating SQL results from {} to {} rows", + total_rows, + state.sql_max_rows + ); + (state.sql_max_rows, true) + } else { + (total_rows, false) + }; + + let col_refs: Vec<_> = columns + .iter() + .map(|name| df.column(name)) + .collect::, _>>() + .map_err(|e| GgsqlError::InternalError(format!("Failed to get columns: {}", e)))?; + + let mut rows: Vec = Vec::with_capacity(rows_to_process); + + for i in 0..rows_to_process { + let mut row_obj = serde_json::Map::new(); + for (col_name, column) in columns.iter().zip(&col_refs) { + let value = column_value_to_json(column, i); + row_obj.insert(col_name.clone(), value); + } + rows.push(serde_json::Value::Object(row_obj)); + } + + let result = SqlResult { + rows, + columns, + row_count: total_rows, + truncated, + }; + + Ok(Json(ApiSuccess { + status: "success".to_string(), + data: result, + })) +} + /// GET /api/v1/health - Health check async fn health_handler() -> Json { Json(HealthResponse { @@ -645,6 +789,7 @@ async fn main() -> anyhow::Result<()> { let state = AppState { #[cfg(feature = "duckdb")] reader, + sql_max_rows: cli.sql_max_rows, }; // Configure CORS @@ -666,12 +811,19 @@ async fn main() -> anyhow::Result<()> { }; // Build router - let app = Router::new() + let mut app = Router::new() .route("/", get(root_handler)) .route("/api/v1/query", post(query_handler)) .route("/api/v1/parse", post(parse_handler)) .route("/api/v1/health", get(health_handler)) - .route("/api/v1/version", get(version_handler)) + .route("/api/v1/version", get(version_handler)); + + #[cfg(feature = "duckdb")] + { + app = app.route("/api/v1/sql", post(sql_handler)); + } + + let app = app .layer(cors) .layer(tower_http::trace::TraceLayer::new_for_http()) .with_state(state); @@ -683,7 +835,8 @@ async fn main() -> anyhow::Result<()> { info!("Starting ggsql REST API server on {}", addr); info!("API documentation:"); - info!(" POST /api/v1/query - Execute ggsql query"); + info!(" POST /api/v1/query - Execute ggsql query (with VISUALISE)"); + info!(" POST /api/v1/sql - Execute plain SQL query"); info!(" POST /api/v1/parse - Parse ggsql query"); info!(" GET /api/v1/health - Health check"); info!(" GET /api/v1/version - Version info"); @@ -694,3 +847,427 @@ async fn main() -> anyhow::Result<()> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use axum::body::Body; + use axum::http::{Request, StatusCode}; + use tower::util::ServiceExt; + + fn create_test_app() -> Router { + create_test_app_with_max_rows(10000) + } + + fn create_test_app_with_max_rows(sql_max_rows: usize) -> Router { + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + + // Load some test data + let conn = reader.connection(); + conn.execute( + "CREATE TABLE test_table (id INTEGER, name VARCHAR)", + duckdb::params![], + ).unwrap(); + conn.execute( + "INSERT INTO test_table VALUES (1, 'Alice'), (2, 'Bob')", + duckdb::params![], + ).unwrap(); + + let state = AppState { + reader: Some(std::sync::Arc::new(std::sync::Mutex::new(reader))), + sql_max_rows, + }; + + Router::new() + .route("/", get(root_handler)) + .route("/api/v1/health", get(health_handler)) + .route("/api/v1/version", get(version_handler)) + .route("/api/v1/query", post(query_handler)) + .route("/api/v1/parse", post(parse_handler)) + .route("/api/v1/sql", post(sql_handler)) + .with_state(state) + } + + // ======================================================================== + // SQL Endpoint Tests + // ======================================================================== + + #[tokio::test] + async fn test_sql_endpoint_select() { + let app = create_test_app(); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/sql") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "SELECT * FROM test_table ORDER BY id"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["status"], "success"); + assert_eq!(json["data"]["rows"].as_array().unwrap().len(), 2); + assert_eq!(json["data"]["columns"], serde_json::json!(["id", "name"])); + assert_eq!(json["data"]["rows"][0]["id"], 1); + assert_eq!(json["data"]["rows"][0]["name"], "Alice"); + assert_eq!(json["data"]["rowCount"], 2); + assert_eq!(json["data"]["truncated"], false); + } + + #[tokio::test] + async fn test_sql_endpoint_invalid_query() { + let app = create_test_app(); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/sql") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "SELECT * FROM nonexistent_table"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + // Returns 400 Bad Request for SQL errors + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + } + + #[tokio::test] + async fn test_sql_endpoint_create_and_query() { + let app = create_test_app(); + + // Create a new table + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/sql") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "CREATE TABLE new_table AS SELECT 1 as x, 2 as y"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + // Query the new table + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/sql") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "SELECT * FROM new_table"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + assert_eq!(json["data"]["rows"][0]["x"], 1); + assert_eq!(json["data"]["rows"][0]["y"], 2); + } + + #[tokio::test] + async fn test_sql_endpoint_empty_result() { + let app = create_test_app(); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/sql") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "SELECT * FROM test_table WHERE 1=0"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["status"], "success"); + assert_eq!(json["data"]["rows"].as_array().unwrap().len(), 0); + assert_eq!(json["data"]["rowCount"], 0); + assert_eq!(json["data"]["truncated"], false); + } + + #[tokio::test] + async fn test_sql_endpoint_null_handling() { + let app = create_test_app(); + + let response = app + .clone() + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/sql") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "SELECT 1 as a, NULL as b, 'text' as c"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["data"]["rows"][0]["a"], 1); + assert!(json["data"]["rows"][0]["b"].is_null()); + assert_eq!(json["data"]["rows"][0]["c"], "text"); + } + + #[tokio::test] + async fn test_sql_endpoint_date_types() { + let app = create_test_app(); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/sql") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "SELECT DATE '2024-03-15' as d"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + // Date should be serialized as ISO format string + assert_eq!(json["data"]["rows"][0]["d"], "2024-03-15"); + } + + #[tokio::test] + async fn test_sql_endpoint_truncation() { + let app = create_test_app_with_max_rows(1); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/sql") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "SELECT * FROM test_table ORDER BY id"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + // Should return only 1 row but report total of 2 + assert_eq!(json["data"]["rows"].as_array().unwrap().len(), 1); + assert_eq!(json["data"]["rowCount"], 2); + assert_eq!(json["data"]["truncated"], true); + // First row should be Alice (ordered by id) + assert_eq!(json["data"]["rows"][0]["name"], "Alice"); + } + + // ======================================================================== + // Query Endpoint Tests + // ======================================================================== + + #[tokio::test] + async fn test_global_query_endpoint() { + let app = create_test_app(); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/query") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "SELECT * FROM test_table VISUALISE DRAW point MAPPING id AS x, id AS y"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["status"], "success"); + assert!(json["data"]["spec"].is_object()); + assert!(json["data"]["spec"]["$schema"].as_str().unwrap().contains("vega-lite")); + } + + #[tokio::test] + async fn test_global_query_invalid_syntax() { + let app = create_test_app(); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/query") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "SELECT * FROM test_table VISUALISE INVALID SYNTAX"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + // Should return an error for invalid ggsql syntax + assert!(response.status() == StatusCode::BAD_REQUEST || response.status() == StatusCode::INTERNAL_SERVER_ERROR); + } + + // ======================================================================== + // Parse Endpoint Tests + // ======================================================================== + + #[tokio::test] + async fn test_parse_endpoint() { + let app = create_test_app(); + + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/parse") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "SELECT * FROM t VISUALISE DRAW point MAPPING x AS x, y AS y"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert_eq!(json["status"], "success"); + // Should return parse information with sql_portion, viz_portion, and specs + assert!(json["data"]["sql_portion"].is_string()); + assert!(json["data"]["viz_portion"].is_string()); + assert!(json["data"]["specs"].is_array()); + } + + #[tokio::test] + async fn test_parse_endpoint_invalid() { + let app = create_test_app(); + + // Use completely invalid syntax that should fail to parse + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/api/v1/parse") + .header("Content-Type", "application/json") + .body(Body::from(r#"{"query": "NOT VALID SQL OR GGSQL AT ALL @@@@"}"#)) + .unwrap(), + ) + .await + .unwrap(); + + let status = response.status(); + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + // Either returns error status or has no specs + let is_error = status != StatusCode::OK || json["status"] == "error"; + // For now, accept that some invalid queries might still parse (just with empty results) + assert!(is_error || json["data"]["specs"].as_array().map(|a| a.is_empty()).unwrap_or(true)); + } + + // ======================================================================== + // Utility Endpoint Tests + // ======================================================================== + + #[tokio::test] + async fn test_root_endpoint() { + let app = create_test_app(); + + let response = app + .oneshot( + Request::builder() + .method("GET") + .uri("/") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let body_str = String::from_utf8_lossy(&body); + + // Root endpoint returns a plain text message + assert!(body_str.contains("ggsql")); + } + + #[tokio::test] + async fn test_health_endpoint() { + let app = create_test_app(); + + let response = app + .oneshot( + Request::builder() + .method("GET") + .uri("/api/v1/health") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + // Health endpoint returns "healthy" status + assert_eq!(json["status"], "healthy"); + assert!(json["version"].is_string()); + } + + #[tokio::test] + async fn test_version_endpoint() { + let app = create_test_app(); + + let response = app + .oneshot( + Request::builder() + .method("GET") + .uri("/api/v1/version") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap(); + let json: serde_json::Value = serde_json::from_slice(&body).unwrap(); + + assert!(json["version"].is_string()); + assert!(json["features"].is_array()); + } +}