diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..2d359891 --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,4 @@ +[target.wasm32-unknown-unknown] +rustflags = [ + "-C", "link-args=-z stack-size=16777216", +] diff --git a/.github/workflows/wasm.yaml b/.github/workflows/wasm.yaml new file mode 100644 index 00000000..2059020d --- /dev/null +++ b/.github/workflows/wasm.yaml @@ -0,0 +1,56 @@ +name: wasm.yaml + +on: + push: + branches: [main] + pull_request: + branches: [main] + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - name: Check out repository + uses: actions/checkout@v4 + + - name: Increase disk space + run: | + sudo rm -rf /usr/share/dotnet /usr/local/lib/android /opt/ghc /opt/hostedtoolcache/CodeQL + sudo docker image prune --all --force + sudo docker builder prune -a + + - name: Setup Node.js + uses: actions/setup-node@v4 + with: + node-version: "20" + + - name: Install tree-sitter-cli + run: npm install -g tree-sitter-cli + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Install LLVM + run: sudo apt-get install -y llvm + + - name: Caching + uses: Swatinem/rust-cache@v2 + with: + shared-key: ${{ runner.os }}-build + cache-on-failure: true + + - name: Install wasm-opt + run: cargo install wasm-opt + + - name: Install wasm-pack + run: cargo install wasm-pack + + - name: Build WASM package + working-directory: ggsql-wasm + run: wasm-pack build --target web --profile wasm --no-opt + + - name: Optimise WASM binary + working-directory: ggsql-wasm + run: wasm-opt pkg/ggsql_wasm_bg.wasm -o pkg/ggsql_wasm_bg.wasm -Oz --all-features diff --git a/.gitignore b/.gitignore index 2e911c7d..96efbc92 100644 --- a/.gitignore +++ b/.gitignore @@ -104,6 +104,7 @@ docs/_build/ .env.local .env.production config.toml +!.cargo/config.toml secrets.toml # Generated documentation diff --git a/CLAUDE.md b/CLAUDE.md index 38703ecf..883775c2 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -994,10 +994,12 @@ spec = ggsql.execute( ) ``` +Required methods for custom readers (in addition to `execute_sql`): + +- `register(name: str, df: polars.DataFrame, replace: bool = False) -> None` - Register a DataFrame as a table + Optional methods for custom readers: -- `supports_register() -> bool` - Return `True` if registration is supported -- `register(name: str, df: polars.DataFrame) -> None` - Register a DataFrame as a table - `unregister(name: str) -> None` - Unregister a previously registered table Native readers (e.g., `DuckDBReader`) use an optimized fast path, while custom Python readers are automatically bridged via IPC serialization. 
@@ -1091,15 +1093,15 @@ Where `` can be: ### Clause Types -| Clause | Repeatable | Purpose | Example | -| -------------- | ---------- | ------------------ | ------------------------------------ | -| `VISUALISE` | ✅ Yes | Entry point | `VISUALISE date AS x, revenue AS y` | -| `DRAW` | ✅ Yes | Define layers | `DRAW line MAPPING date AS x, value AS y` | -| `SCALE` | ✅ Yes | Configure scales | `SCALE x VIA date` | -| `FACET` | ❌ No | Small multiples | `FACET WRAP region` | -| `COORD` | ❌ No | Coordinate system | `COORD cartesian SETTING xlim => [0,100]` | -| `LABEL` | ❌ No | Text labels | `LABEL title => 'My Chart', x => 'Date'` | -| `THEME` | ❌ No | Visual styling | `THEME minimal` | +| Clause | Repeatable | Purpose | Example | +| ----------- | ---------- | ----------------- | ----------------------------------------- | +| `VISUALISE` | ✅ Yes | Entry point | `VISUALISE date AS x, revenue AS y` | +| `DRAW` | ✅ Yes | Define layers | `DRAW line MAPPING date AS x, value AS y` | +| `SCALE` | ✅ Yes | Configure scales | `SCALE x VIA date` | +| `FACET` | ❌ No | Small multiples | `FACET WRAP region` | +| `COORD` | ❌ No | Coordinate system | `COORD cartesian SETTING xlim => [0,100]` | +| `LABEL` | ❌ No | Text labels | `LABEL title => 'My Chart', x => 'Date'` | +| `THEME` | ❌ No | Visual styling | `THEME minimal` | ### DRAW Clause (Layers) diff --git a/Cargo.toml b/Cargo.toml index 4098e358..ab823748 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,8 @@ members = [ "tree-sitter-ggsql", "src", "ggsql-jupyter", - "ggsql-python" + "ggsql-python", + "ggsql-wasm" ] # ggsql-python is excluded from default builds because it's a PyO3 extension # that requires Python dev headers. Build it separately with maturin. @@ -25,18 +26,18 @@ description = "SQL extension for declarative data visualization" [workspace.dependencies] # Parsing -tree-sitter = "0.25" +tree-sitter = "0.26" csscolorparser = "0.8.1" # Data processing -polars = { version = "0.52", features = ["lazy", "sql", "ipc"] } +polars = { version = "0.52", default-features = false } polars-ops = { version = "0.52", features = ["pivot"] } # Readers duckdb = { version = "1.4", features = ["bundled", "vtab-arrow"] } arrow = { version = "56", default-features = false, features = ["ipc"] } postgres = "0.19" -sqlx = { version = "0.8", features = ["postgres", "runtime-tokio-rustls"] } +sqlx = { version = "0.8", features = ["postgres"] } rusqlite = "0.32" # Writers @@ -68,7 +69,15 @@ uuid = { version = "1.0", features = ["v4"] } # Web server axum = "0.7" -tokio = { version = "1.35", features = ["full"] } +tokio = { version = "1.35", default-features = false } tower-http = { version = "0.5", features = ["cors", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[profile.wasm] +inherits = "release" +opt-level = "z" +lto = true +codegen-units = 1 +strip = true +panic = "abort" diff --git a/ggsql-python/README.md b/ggsql-python/README.md index b03b4174..a35eaae1 100644 --- a/ggsql-python/README.md +++ b/ggsql-python/README.md @@ -125,10 +125,9 @@ reader = ggsql.DuckDBReader("duckdb:///path/to/file.db") # File database **Methods:** -- `register(name: str, df: polars.DataFrame)` - Register a DataFrame as a queryable table +- `register(name: str, df: polars.DataFrame, replace: bool = False)` - Register a DataFrame as a queryable table - `unregister(name: str)` - Unregister a previously registered table - `execute_sql(sql: str) -> polars.DataFrame` - Execute SQL and return results -- `supports_register() -> bool` - Check if 
registration is supported #### `VegaLiteWriter()` @@ -262,11 +261,10 @@ writer = ggsql.VegaLiteWriter() json_output = writer.render(spec) ``` -**Optional methods** for custom readers: +**Additional methods** for custom readers: -- `supports_register() -> bool` - Return `True` if your reader supports DataFrame registration -- `register(name: str, df: polars.DataFrame) -> None` - Register a DataFrame as a queryable table -- `unregister(name: str) -> None` - Unregister a previously registered table +- `register(name: str, df: polars.DataFrame, replace: bool = False) -> None` - Register a DataFrame as a queryable table (required) +- `unregister(name: str) -> None` - Unregister a previously registered table (optional) ```python class AdvancedReader: @@ -279,10 +277,7 @@ class AdvancedReader: # Your SQL execution logic here ... - def supports_register(self) -> bool: - return True - - def register(self, name: str, df: pl.DataFrame) -> None: + def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None: self.tables[name] = df def unregister(self, name: str) -> None: @@ -313,11 +308,8 @@ class IbisReader: def execute_sql(self, sql: str) -> pl.DataFrame: return self.con.con.execute(sql).pl() - def supports_register(self) -> bool: - return True - - def register(self, name: str, df: pl.DataFrame) -> None: - self.con.create_table(name, df.to_arrow(), overwrite=True) + def register(self, name: str, df: pl.DataFrame, replace: bool = False) -> None: + self.con.create_table(name, df.to_arrow(), overwrite=replace) def unregister(self, name: str) -> None: self.con.drop_table(name) diff --git a/ggsql-python/src/lib.rs b/ggsql-python/src/lib.rs index ffe38a01..27d26b68 100644 --- a/ggsql-python/src/lib.rs +++ b/ggsql-python/src/lib.rs @@ -138,29 +138,19 @@ impl Reader for PyReaderBridge { }) } - fn supports_register(&self) -> bool { - Python::attach(|py| { - self.obj - .bind(py) - .call_method0("supports_register") - .and_then(|r| r.extract::()) - .unwrap_or(false) - }) - } - - fn register(&mut self, name: &str, df: DataFrame) -> ggsql::Result<()> { + fn register(&self, name: &str, df: DataFrame, replace: bool) -> ggsql::Result<()> { Python::attach(|py| { let py_df = polars_to_py(py, &df).map_err(|e| GgsqlError::ReaderError(e.to_string()))?; self.obj .bind(py) - .call_method1("register", (name, py_df)) + .call_method1("register", (name, py_df, replace)) .map_err(|e| GgsqlError::ReaderError(format!("Reader.register() failed: {}", e)))?; Ok(()) }) } - fn unregister(&mut self, name: &str) -> ggsql::Result<()> { + fn unregister(&self, name: &str) -> ggsql::Result<()> { Python::attach(|py| { self.obj .bind(py) @@ -254,10 +244,17 @@ impl PyDuckDBReader { /// ------ /// ValueError /// If registration fails or the table name is invalid. - fn register(&mut self, py: Python<'_>, name: &str, df: &Bound<'_, PyAny>) -> PyResult<()> { + #[pyo3(signature = (name, df, replace=false))] + fn register( + &self, + py: Python<'_>, + name: &str, + df: &Bound<'_, PyAny>, + replace: bool, + ) -> PyResult<()> { let rust_df = py_to_polars(py, df)?; self.inner - .register(name, rust_df) + .register(name, rust_df, replace) .map_err(|e| PyErr::new::(e.to_string())) } @@ -272,7 +269,7 @@ impl PyDuckDBReader { /// ------ /// ValueError /// If the table wasn't registered via this reader or unregistration fails. 
- fn unregister(&mut self, name: &str) -> PyResult<()> { + fn unregister(&self, name: &str) -> PyResult<()> { self.inner .unregister(name) .map_err(|e| PyErr::new::(e.to_string())) @@ -302,16 +299,6 @@ impl PyDuckDBReader { polars_to_py(py, &df) } - /// Check if this reader supports DataFrame registration. - /// - /// Returns - /// ------- - /// bool - /// True if register() is supported, False otherwise. - fn supports_register(&self) -> bool { - self.inner.supports_register() - } - /// Execute a ggsql query and return the visualization specification. /// /// This is the main entry point for creating visualizations. It parses diff --git a/ggsql-python/tests/test_ggsql.py b/ggsql-python/tests/test_ggsql.py index bf01e140..6e9183fc 100644 --- a/ggsql-python/tests/test_ggsql.py +++ b/ggsql-python/tests/test_ggsql.py @@ -85,10 +85,6 @@ def test_register_and_query(self): assert isinstance(result, pl.DataFrame) assert result.shape == (2, 2) - def test_supports_register(self): - reader = ggsql.DuckDBReader("duckdb://memory") - assert reader.supports_register() is True - def test_invalid_connection_string(self): with pytest.raises(ValueError): ggsql.DuckDBReader("invalid://connection") @@ -396,25 +392,6 @@ def test_can_introspect_spec(self): class TestCustomReader: """Tests for custom Python reader support.""" - def test_simple_custom_reader(self): - """Custom reader with execute_sql() method works.""" - - class SimpleReader: - def __init__(self): - self.conn = duckdb.connect() - self.conn.execute( - "CREATE TABLE data AS SELECT * FROM (" - "VALUES (1, 10), (2, 20), (3, 30)" - ") AS t(x, y)" - ) - - def execute_sql(self, sql: str) -> pl.DataFrame: - return self.conn.execute(sql).pl() - - reader = SimpleReader() - spec = ggsql.execute("SELECT * FROM data VISUALISE x, y DRAW point", reader) - assert spec.metadata()["rows"] == 3 - def test_custom_reader_with_register(self): """Custom reader with register() support.""" @@ -425,10 +402,7 @@ def __init__(self): def execute_sql(self, sql: str) -> pl.DataFrame: return self.conn.execute(sql).pl() - def supports_register(self) -> bool: - return True - - def register(self, name: str, df: pl.DataFrame) -> None: + def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None: self.conn.register(name, df) reader = RegisterReader() @@ -479,6 +453,9 @@ def __init__(self): def execute_sql(self, sql: str) -> pl.DataFrame: return self.conn.execute(sql).pl() + def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None: + self.conn.register(name, df) + reader = DuckDBBackedReader() spec = ggsql.execute( "SELECT * FROM data VISUALISE x, y, category AS color DRAW point", @@ -499,8 +476,7 @@ class RecordingReader: def __init__(self): self.conn = duckdb.connect() self.conn.execute( - "CREATE TABLE data AS SELECT * FROM (" - "VALUES (1, 2)) AS t(x, y)" + "CREATE TABLE data AS SELECT * FROM (VALUES (1, 2)) AS t(x, y)" ) self.execute_calls = [] @@ -508,6 +484,9 @@ def execute_sql(self, sql: str) -> pl.DataFrame: self.execute_calls.append(sql) return self.conn.execute(sql).pl() + def register(self, name: str, df: pl.DataFrame, _replace: bool) -> None: + self.conn.register(name, df) + reader = RecordingReader() ggsql.execute( "SELECT * FROM data VISUALISE x, y DRAW point", @@ -530,11 +509,10 @@ def __init__(self): def execute_sql(self, sql: str) -> pl.DataFrame: return self.con.con.execute(sql).pl() - def supports_register(self) -> bool: - return True - - def register(self, name: str, df: pl.DataFrame) -> None: - self.con.create_table(name, 
df.to_arrow(), overwrite=True) + def register( + self, name: str, df: pl.DataFrame, replace: bool = True + ) -> None: + self.con.create_table(name, df.to_arrow(), overwrite=replace) def unregister(self, name: str) -> None: self.con.drop_table(name) diff --git a/ggsql-wasm/Cargo.toml b/ggsql-wasm/Cargo.toml new file mode 100644 index 00000000..cc7bcf5b --- /dev/null +++ b/ggsql-wasm/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "ggsql-wasm" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +repository.workspace = true +homepage.workspace = true +description.workspace = true + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wasm-bindgen = "0.2" +js-sys = "0.3" +csv = "1" +polars = { version = "0.52", default-features = false, features = ["sql", "dtype-full"] } +ggsql = { path = "../src", default-features = false, features = ["polars-sql", "vegalite"] } + +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +tokio = { version = "1.35", features = ["full"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +tokio = { version = "1.35", default-features = false } + diff --git a/ggsql-wasm/LICENSE b/ggsql-wasm/LICENSE new file mode 100644 index 00000000..555e1c3d --- /dev/null +++ b/ggsql-wasm/LICENSE @@ -0,0 +1,7 @@ +Copyright 2026 ggsql authors + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/ggsql-wasm/src/lib.rs b/ggsql-wasm/src/lib.rs new file mode 100644 index 00000000..3cb1d1bc --- /dev/null +++ b/ggsql-wasm/src/lib.rs @@ -0,0 +1,75 @@ +use ggsql::reader::{PolarsReader, Reader}; +use ggsql::writer::{VegaLiteWriter, Writer}; +use std::cell::RefCell; + +use wasm_bindgen::prelude::*; + +/// Persistent ggsql context for WASM +/// +/// Create once and reuse for multiple queries to avoid memory issues. +/// Uses interior mutability to avoid wasm_bindgen's &mut self aliasing issues. 
+#[wasm_bindgen] +pub struct GgsqlContext { + reader: RefCell, + writer: VegaLiteWriter, +} + +#[wasm_bindgen] +impl GgsqlContext { + /// Create a new ggsql context + #[wasm_bindgen(constructor)] + pub fn new() -> Result { + let reader = PolarsReader::from_connection_string("polars://memory") + .map_err(|e| JsValue::from_str(&format!("Reader error: {:?}", e)))?; + let writer = VegaLiteWriter::new(); + Ok(GgsqlContext { + reader: RefCell::new(reader), + writer, + }) + } + + /// Execute a ggsql query and return Vega-Lite JSON + pub fn execute(&self, query: &str) -> Result { + // Scope the mutable borrow to avoid aliasing issues + let spec = { + let reader = self.reader.borrow_mut(); + reader + .execute(query) + .map_err(|e| JsValue::from_str(&format!("Execute error: {:?}", e)))? + }; + + let result = self + .writer + .render(&spec) + .map_err(|e| JsValue::from_str(&format!("Render error: {:?}", e)))?; + + Ok(result) + } + + // TODO: Register a table from binary data (e.g. CSV, Parquet) + pub fn register(&self, _name: &str) -> Result<(), JsValue> { + Err(JsValue::from_str("Registration not yet implemented.")) + } + + /// Unregister a table + pub fn unregister(&self, name: &str) -> Result<(), JsValue> { + let reader = self.reader.borrow(); + reader + .unregister(name) + .map_err(|e| JsValue::from_str(&format!("Unregister error: {:?}", e)))?; + + Ok(()) + } + + /// List all registered tables + pub fn list_tables(&self) -> JsValue { + let reader = self.reader.borrow(); + let tables = reader.list_tables(false); + + let array = js_sys::Array::new(); + for table in tables { + array.push(&JsValue::from_str(&table)); + } + array.into() + } +} diff --git a/src/Cargo.toml b/src/Cargo.toml index 2e6fb1e6..53102166 100644 --- a/src/Cargo.toml +++ b/src/Cargo.toml @@ -31,7 +31,7 @@ csscolorparser.workspace = true palette.workspace = true # Data processing -polars.workspace = true +polars = { workspace = true, features = ["lazy", "sql"] } polars-ops.workspace = true # Readers @@ -62,7 +62,7 @@ uuid.workspace = true # Web server (optional) axum = { workspace = true, optional = true } -tokio = { workspace = true, optional = true } +tokio = { workspace = true, optional = true, features = ["full"] } tower-http = { workspace = true, optional = true } tracing = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true } @@ -74,13 +74,16 @@ pyo3 = { workspace = true, optional = true } proptest.workspace = true [features] -default = ["duckdb", "sqlite", "vegalite"] +default = ["duckdb", "sqlite", "vegalite", "ipc", "builtin-data"] +ipc = ["polars/ipc"] duckdb = ["dep:duckdb", "dep:arrow"] +polars-sql = ["polars/sql"] +builtin-data = ["polars/parquet"] postgres = ["dep:postgres"] sqlite = ["dep:rusqlite"] vegalite = [] ggplot2 = [] python = ["dep:pyo3"] rest-api = ["dep:axum", "dep:tokio", "dep:tower-http", "dep:tracing", "dep:tracing-subscriber", "duckdb", "vegalite"] -all-readers = ["duckdb", "postgres", "sqlite"] +all-readers = ["duckdb", "postgres", "sqlite", "polars-sql"] all-writers = ["vegalite", "ggplot2", "plotters"] diff --git a/src/doc/API.md b/src/doc/API.md index 5ac9ddae..8cc962e6 100644 --- a/src/doc/API.md +++ b/src/doc/API.md @@ -374,13 +374,10 @@ pub trait Reader { fn execute_sql(&self, sql: &str) -> Result; /// Register a DataFrame as a queryable table - fn register(&mut self, name: &str, df: DataFrame) -> Result<()>; + fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()>; /// Unregister a previously registered table - fn unregister(&mut 
self, name: &str) -> Result<()>; - - /// Check if this reader supports DataFrame registration - fn supports_register(&self) -> bool; + fn unregister(&self, name: &str) -> Result<()>; } ``` @@ -434,9 +431,6 @@ class DuckDBReader: def execute_sql(self, sql: str) -> polars.DataFrame: """Execute SQL and return a Polars DataFrame.""" - - def supports_register(self) -> bool: - """Check if registration is supported.""" ``` #### `VegaLiteWriter` diff --git a/src/execute/cte.rs b/src/execute/cte.rs index 2883bc5d..b83abc33 100644 --- a/src/execute/cte.rs +++ b/src/execute/cte.rs @@ -4,7 +4,8 @@ //! materializing them as temporary tables, and transforming CTE references //! in SQL queries. -use crate::{naming, parser::SourceTree, DataFrame, GgsqlError, Result}; +use crate::reader::Reader; +use crate::{naming, parser::SourceTree, GgsqlError, Result}; use std::collections::HashSet; use tree_sitter::Node; @@ -125,10 +126,7 @@ pub fn transform_cte_references(sql: &str, cte_names: &HashSet) -> Strin /// temp table name. /// /// Returns the set of CTE names that were materialized. -pub fn materialize_ctes(ctes: &[CteDefinition], execute_sql: &F) -> Result> -where - F: Fn(&str) -> Result, -{ +pub fn materialize_ctes(ctes: &[CteDefinition], reader: &dyn Reader) -> Result> { let mut materialized = HashSet::new(); for cte in ctes { @@ -136,14 +134,14 @@ where let transformed_body = transform_cte_references(&cte.body, &materialized); let temp_table_name = naming::cte_table(&cte.name); - let create_sql = format!( - "CREATE OR REPLACE TEMP TABLE {} AS {}", - temp_table_name, transformed_body - ); - execute_sql(&create_sql).map_err(|e| { + // Execute the CTE body SQL to get a DataFrame, then register it + let df = reader.execute_sql(&transformed_body).map_err(|e| { GgsqlError::ReaderError(format!("Failed to materialize CTE '{}': {}", cte.name, e)) })?; + reader.register(&temp_table_name, df, true).map_err(|e| { + GgsqlError::ReaderError(format!("Failed to register CTE '{}': {}", cte.name, e)) + })?; materialized.insert(cte.name.clone()); } @@ -151,30 +149,40 @@ where Ok(materialized) } -/// Extract the trailing SELECT statement from a WITH clause using existing tree +/// Split a WITH...SELECT query into its CTE prefix and trailing SELECT. +/// +/// Given SQL like `WITH a AS (...), b AS (...) SELECT * FROM a`, returns: +/// - CTE prefix: `"WITH a AS (...), b AS (...)"` +/// - Trailing SELECT: `"SELECT * FROM a"` /// -/// Given SQL like `WITH a AS (...), b AS (...) SELECT * FROM a`, extracts -/// just the `SELECT * FROM a` part. Returns None if there's no trailing SELECT. -pub fn extract_trailing_select(source_tree: &SourceTree) -> Option { +/// Returns `None` if the query is not a WITH statement, has no trailing SELECT, +/// or parsing fails. 
+pub fn split_with_query(source_tree: &SourceTree) -> Option<(String, String)> { let root = source_tree.root(); + let with_node = source_tree.find_node(&root, "(with_statement) @with")?; + + let mut cursor = with_node.walk(); + let mut last_cte_end: Option = None; + let mut select_node = None; + let mut seen_cte = false; - // Try to find WITH statement first - if let Some(with_node) = source_tree.find_node(&root, "(with_statement) @with") { - // Look for the trailing SELECT that comes AFTER cte_definition nodes - let mut cursor = with_node.walk(); - let mut seen_cte = false; - for child in with_node.children(&mut cursor) { - if child.kind() == "cte_definition" { + for child in with_node.children(&mut cursor) { + match child.kind() { + "cte_definition" => { seen_cte = true; - } else if child.kind() == "select_statement" && seen_cte { - // This is the trailing SELECT after CTEs - return Some(source_tree.get_text(&child)); + last_cte_end = Some(child.end_byte()); + } + "select_statement" if seen_cte => { + select_node = Some(child); + break; } + _ => {} } } - // Otherwise, look for direct SELECT statement (no WITH clause) - source_tree.find_text(&root, "(sql_statement (select_statement) @select)") + let cte_prefix = source_tree.source[with_node.start_byte()..last_cte_end?].to_string(); + let trailing_select = source_tree.get_text(&select_node?); + Some((cte_prefix, trailing_select)) } /// Transform global SQL for execution with temp tables @@ -186,9 +194,16 @@ pub fn transform_global_sql( source_tree: &SourceTree, materialized_ctes: &HashSet, ) -> Option { - // Try to extract SELECT (handles both WITH...SELECT and direct SELECT) - if let Some(select_sql) = extract_trailing_select(source_tree) { - // Transform CTE references in the SELECT + // Try to extract trailing SELECT (WITH...SELECT or direct SELECT) + let select_sql = split_with_query(source_tree) + .map(|(_, select)| select) + .or_else(|| { + // Fallback: direct SELECT statement (no WITH clause) + let root = source_tree.root(); + source_tree.find_text(&root, "(sql_statement (select_statement) @select)") + }); + + if let Some(select_sql) = select_sql { Some(transform_cte_references(&select_sql, materialized_ctes)) } else if has_executable_sql(source_tree) { // Non-SELECT executable SQL (CREATE, INSERT, UPDATE, DELETE) @@ -226,11 +241,8 @@ pub fn has_executable_sql(source_tree: &SourceTree) -> bool { } // Check for WITH statements that have trailing SELECT - let with_statements = source_tree.find_nodes(&root, "(with_statement) @with"); - for with_node in with_statements { - if with_has_trailing_select(&with_node) { - return true; - } + if split_with_query(source_tree).is_some() { + return true; } // Check for VISUALISE FROM (which injects SELECT * FROM ) @@ -245,22 +257,6 @@ pub fn has_executable_sql(source_tree: &SourceTree) -> bool { false } -/// Check if a with_statement node has a trailing SELECT (after CTEs) -fn with_has_trailing_select(with_node: &Node) -> bool { - let mut cursor = with_node.walk(); - let mut seen_cte = false; - - for child in with_node.children(&mut cursor) { - if child.kind() == "cte_definition" { - seen_cte = true; - } else if child.kind() == "select_statement" && seen_cte { - return true; - } - } - - false -} - #[cfg(test)] mod tests { use super::*; @@ -384,4 +380,84 @@ mod tests { } } } + + #[test] + fn test_split_with_query_basic() { + let sql = "WITH cte AS (SELECT * FROM x) SELECT * FROM cte"; + let source_tree = SourceTree::new(sql).unwrap(); + let (prefix, select) = 
split_with_query(&source_tree).unwrap(); + + assert_eq!(prefix, "WITH cte AS (SELECT * FROM x)"); + assert_eq!(select, "SELECT * FROM cte"); + } + + #[test] + fn test_split_with_query_multiple_ctes() { + let sql = "WITH a AS (SELECT 1), b AS (SELECT 2) SELECT * FROM a JOIN b"; + let source_tree = SourceTree::new(sql).unwrap(); + let (prefix, select) = split_with_query(&source_tree).unwrap(); + + assert_eq!(prefix, "WITH a AS (SELECT 1), b AS (SELECT 2)"); + assert_eq!(select, "SELECT * FROM a JOIN b"); + } + + #[test] + fn test_split_with_query_nested_subquery() { + let sql = "WITH cte AS (SELECT * FROM (SELECT 1)) SELECT * FROM cte"; + let source_tree = SourceTree::new(sql).unwrap(); + let (prefix, select) = split_with_query(&source_tree).unwrap(); + + assert_eq!(prefix, "WITH cte AS (SELECT * FROM (SELECT 1))"); + assert_eq!(select, "SELECT * FROM cte"); + } + + #[test] + fn test_split_with_query_string_with_select_keyword() { + let sql = "WITH cte AS (SELECT 'SELECT' AS col) SELECT * FROM cte"; + let source_tree = SourceTree::new(sql).unwrap(); + let (prefix, select) = split_with_query(&source_tree).unwrap(); + + assert_eq!(prefix, "WITH cte AS (SELECT 'SELECT' AS col)"); + assert_eq!(select, "SELECT * FROM cte"); + } + + #[test] + fn test_split_with_query_string_with_parens() { + let sql = "WITH cte AS (SELECT '()' AS col) SELECT * FROM cte"; + let source_tree = SourceTree::new(sql).unwrap(); + let (prefix, select) = split_with_query(&source_tree).unwrap(); + + assert_eq!(prefix, "WITH cte AS (SELECT '()' AS col)"); + assert_eq!(select, "SELECT * FROM cte"); + } + + #[test] + fn test_split_with_query_not_a_with() { + let sql = "SELECT * FROM x"; + let source_tree = SourceTree::new(sql).unwrap(); + assert!(split_with_query(&source_tree).is_none()); + } + + #[test] + fn test_split_with_query_no_trailing_select() { + let sql = "WITH cte AS (SELECT 1) VISUALISE DRAW point"; + let source_tree = SourceTree::new(sql).unwrap(); + assert!(split_with_query(&source_tree).is_none()); + } + + #[test] + fn test_split_with_query_stat_transform_output() { + // Realistic stat transform output (histogram pattern) + let sql = "WITH __stat_src__ AS (SELECT x FROM data), \ + __binned__ AS (SELECT x, COUNT(*) AS count FROM __stat_src__ GROUP BY x) \ + SELECT *, count * 1.0 / SUM(count) OVER () AS density FROM __binned__"; + let source_tree = SourceTree::new(sql).unwrap(); + let (prefix, select) = split_with_query(&source_tree).unwrap(); + + assert!(prefix.starts_with("WITH __stat_src__")); + assert!(prefix.contains("__binned__")); + assert!(prefix.ends_with(")")); + assert!(select.starts_with("SELECT *")); + assert!(select.contains("density")); + } } diff --git a/src/execute/layer.rs b/src/execute/layer.rs index 2d88fabc..cf4e2647 100644 --- a/src/execute/layer.rs +++ b/src/execute/layer.rs @@ -187,7 +187,8 @@ pub fn literal_to_series(name: &str, lit: &ParameterValue, len: usize) -> polars pub fn apply_pre_stat_transform( query: &str, layer: &Layer, - schema: &Schema, + full_schema: &Schema, + aesthetic_schema: &Schema, scales: &[Scale], type_names: &SqlTypeNames, ) -> String { @@ -212,8 +213,8 @@ pub fn apply_pre_stat_transform( continue; } - // Find column dtype from schema using aesthetic column name - let col_dtype = schema + // Find column dtype from aesthetic schema using aesthetic column name + let col_dtype = aesthetic_schema .iter() .find(|c| c.name == aes_col_name) .map(|c| c.dtype.clone()) @@ -238,41 +239,29 @@ pub fn apply_pre_stat_transform( return query.to_string(); } - // Build wrapper: 
SELECT {transformed_cols}, other_cols FROM ({query}) - // For each transformed column, use the SQL expression; for others, keep as-is - let transformed_col_names: HashSet<&str> = - transform_exprs.iter().map(|(c, _)| c.as_str()).collect(); - - // Build column list: all columns, with transformed ones replaced by their expressions - let col_exprs: Vec = transform_exprs - .iter() - .map(|(col, sql)| format!("{} AS {}", sql, col)) + // Build explicit column list from full_schema (original columns) and + // aesthetic_schema (aesthetic columns added by build_layer_base_query). + // The base query produces SELECT *, col AS __ggsql_aes_x__, ... so the + // actual SQL output has both, but they come from different schema sources. + // This avoids SELECT * EXCLUDE which has portability issues + // (Polars SQL silently drops re-added columns with the same name). + let mut seen: HashSet<&str> = HashSet::new(); + let combined_cols = full_schema.iter().chain(aesthetic_schema.iter()); + + let select_exprs: Vec = combined_cols + .filter(|col| seen.insert(&col.name)) + .map(|col| { + if let Some((_, sql)) = transform_exprs.iter().find(|(c, _)| c == &col.name) { + format!("{} AS \"{}\"", sql, col.name) + } else { + format!("\"{}\"", col.name) + } + }) .collect(); - // Build the excluded columns list for the * expansion - // We need to select *, but exclude the columns we're replacing - if col_exprs.is_empty() { - return query.to_string(); - } - - // Use EXCLUDE to remove the original columns, then add the transformed versions - let exclude_clause = if transformed_col_names.len() == 1 { - format!("EXCLUDE ({})", transformed_col_names.iter().next().unwrap()) - } else { - format!( - "EXCLUDE ({})", - transformed_col_names - .iter() - .cloned() - .collect::>() - .join(", ") - ) - }; - format!( - "SELECT * {}, {} FROM ({}) AS __ggsql_pre__", - exclude_clause, - col_exprs.join(", "), + "SELECT {} FROM ({}) AS __ggsql_pre__", + select_exprs.join(", "), query ) } @@ -394,7 +383,14 @@ where // Apply pre-stat transforms (e.g., binning, discrete censoring) // Uses aesthetic names since columns are now renamed and mappings updated - let query = apply_pre_stat_transform(base_query, layer, &aesthetic_schema, scales, type_names); + let query = apply_pre_stat_transform( + base_query, + layer, + schema, + &aesthetic_schema, + scales, + type_names, + ); // Build group_by columns from partition_by and facet variables let mut group_by: Vec = Vec::new(); @@ -533,12 +529,33 @@ where .map(|s| naming::stat_column(s)) .collect(); let exclude_clause = format!("EXCLUDE ({})", stat_col_names.join(", ")); - format!( - "SELECT * {}, {} FROM ({}) AS __ggsql_stat__", - exclude_clause, - stat_rename_exprs.join(", "), - transformed_query - ) + + // If the transformed query uses CTEs (WITH ... SELECT ...), + // we can't wrap it in a subquery because Polars SQL doesn't + // support CTEs inside subqueries. Instead, split into CTE + // prefix + trailing SELECT, then append the trailing SELECT + // as another CTE and add the rename SELECT on top. 
+ if let Some((cte_prefix, trailing_select)) = + crate::parser::SourceTree::new(&transformed_query) + .ok() + .as_ref() + .and_then(super::cte::split_with_query) + { + format!( + "{}, __ggsql_stat__ AS ({}) SELECT * {}, {} FROM __ggsql_stat__", + cte_prefix, + trailing_select, + exclude_clause, + stat_rename_exprs.join(", ") + ) + } else { + format!( + "SELECT * {}, {} FROM ({}) AS __ggsql_stat__", + exclude_clause, + stat_rename_exprs.join(", "), + transformed_query + ) + } } } StatResult::Identity => query, diff --git a/src/execute/mod.rs b/src/execute/mod.rs index a7fce754..a7f12714 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -486,10 +486,7 @@ pub struct PreparedData { /// # Arguments /// * `query` - The full ggsql query string /// * `reader` - A Reader implementation for executing SQL -pub fn prepare_data_with_reader( - query: &str, - reader: &R, -) -> Result { +pub fn prepare_data_with_reader(query: &str, reader: &R) -> Result { let execute_query = |sql: &str| reader.execute_sql(sql); let type_names = reader.sql_type_names(); @@ -520,9 +517,8 @@ pub fn prepare_data_with_reader( // Extract CTE definitions from the source tree (in declaration order) let ctes = cte::extract_ctes(&source_tree); - // Materialize CTEs as temporary tables - // This creates __ggsql_cte___ tables that persist for the session - let materialized_ctes = cte::materialize_ctes(&ctes, &execute_query)?; + // Materialize CTEs as registered tables via reader.register() + let materialized_ctes = cte::materialize_ctes(&ctes, reader)?; // Build data map for multi-source support let mut data_map: HashMap = HashMap::new(); @@ -537,13 +533,9 @@ pub fn prepare_data_with_reader( let mut has_global_table = false; if sql_part.is_some() { if let Some(transformed_sql) = cte::transform_global_sql(&source_tree, &materialized_ctes) { - // Create temp table for global result - let create_global = format!( - "CREATE OR REPLACE TEMP TABLE {} AS {}", - naming::global_table(), - transformed_sql - ); - execute_query(&create_global)?; + // Execute global result SQL and register result as a temp table + let df = execute_query(&transformed_sql)?; + reader.register(&naming::global_table(), df, true)?; // NOTE: Don't read into data_map yet - defer until after casting is determined // The temp table exists and can be used for schema fetching diff --git a/src/lib.rs b/src/lib.rs index 888970af..5e41874a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,13 +41,11 @@ pub mod naming; pub mod parser; pub mod plot; -#[cfg(any(feature = "duckdb", feature = "postgres", feature = "sqlite"))] pub mod reader; #[cfg(any(feature = "vegalite", feature = "ggplot2", feature = "plotters"))] pub mod writer; -#[cfg(feature = "duckdb")] pub mod execute; pub mod validate; diff --git a/src/naming.rs b/src/naming.rs index be0fe505..96feccdc 100644 --- a/src/naming.rs +++ b/src/naming.rs @@ -69,6 +69,9 @@ const LAYER_PREFIX: &str = concatcp!(GGSQL_PREFIX, "layer_"); /// Full prefix for aesthetic columns: `__ggsql_aes_` const AES_PREFIX: &str = concatcp!(GGSQL_PREFIX, "aes_"); +/// Full prefix for builtin data tables: `__ggsql_data_` +const DATA_PREFIX: &str = concatcp!(GGSQL_PREFIX, "data_"); + /// Key for global data in the layer data HashMap. /// Used as the key in PreparedData.data to store global data that applies to all layers. /// This is NOT a SQL table name - use `global_table()` for SQL statements. @@ -129,6 +132,21 @@ pub fn cte_table(cte_name: &str) -> String { ) } +/// Generate table name for a builtin dataset. 
+/// +/// Used when rewriting `ggsql:penguins` to the internal table name. +/// Format: `__ggsql_data___` +/// +/// # Example +/// ``` +/// use ggsql::naming; +/// assert_eq!(naming::builtin_data_table("penguins"), "__ggsql_data_penguins__"); +/// assert_eq!(naming::builtin_data_table("airquality"), "__ggsql_data_airquality__"); +/// ``` +pub fn builtin_data_table(name: &str) -> String { + format!("{}{}{}", DATA_PREFIX, name, GGSQL_SUFFIX) +} + /// Generate column name for a constant aesthetic value. /// /// Used when a single layer has a literal aesthetic value that needs @@ -439,6 +457,15 @@ mod tests { assert_eq!(bin_end_column("__ggsql_aes_y__"), "__ggsql_aes_y2__"); } + #[test] + fn test_builtin_data_table() { + assert_eq!(builtin_data_table("penguins"), "__ggsql_data_penguins__"); + assert_eq!( + builtin_data_table("airquality"), + "__ggsql_data_airquality__" + ); + } + #[test] fn test_prefixes_built_from_components() { // Verify prefixes are correctly composed from building blocks @@ -447,6 +474,7 @@ mod tests { assert_eq!(CTE_PREFIX, "__ggsql_cte_"); assert_eq!(LAYER_PREFIX, "__ggsql_layer_"); assert_eq!(AES_PREFIX, "__ggsql_aes_"); + assert_eq!(DATA_PREFIX, "__ggsql_data_"); } #[test] diff --git a/src/plot/layer/geom/boxplot.rs b/src/plot/layer/geom/boxplot.rs index 877a67c4..e1fc0669 100644 --- a/src/plot/layer/geom/boxplot.rs +++ b/src/plot/layer/geom/boxplot.rs @@ -157,55 +157,8 @@ fn stat_boxplot( }) } -fn boxplot_sql_assign_quartiles(from: &str, groups: &[String], value: &str) -> String { - // Selects all relevant columns and adds a quartile column. - // NTILE(4) may create uneven groups - format!( - "SELECT - {value}, - {groups}, - NTILE(4) OVER (PARTITION BY {groups} ORDER BY {value} ASC) AS _Q - FROM ({from}) - WHERE {value} IS NOT NULL", - value = value, - groups = groups.join(", "), - from = from - ) -} - -fn boxplot_sql_quartile_minmax(from: &str, groups: &[String], value: &str) -> String { - // Compute the min and max for every quartile. - // The verbosity here is to pivot the table to a wide format. - // The output is a table with 1 row per groups annotated with quartile metrics - format!( - "SELECT - MIN(CASE WHEN _Q = 1 THEN {value} END) AS Q1_min, - MAX(CASE WHEN _Q = 1 THEN {value} END) AS Q1_max, - MIN(CASE WHEN _Q = 2 THEN {value} END) AS Q2_min, - MAX(CASE WHEN _Q = 2 THEN {value} END) AS Q2_max, - MIN(CASE WHEN _Q = 3 THEN {value} END) AS Q3_min, - MAX(CASE WHEN _Q = 3 THEN {value} END) AS Q3_max, - MIN(CASE WHEN _Q = 4 THEN {value} END) AS Q4_min, - MAX(CASE WHEN _Q = 4 THEN {value} END) AS Q4_max, - {groups} - FROM ({from}) - GROUP BY {groups}", - groups = groups.join(", "), - value = value, - from = from - ) -} - -fn boxplot_sql_compute_fivenum(from: &str, groups: &[String], coef: &f64) -> String { - // Here we compute the 5 statistics: - // * lower: lower whisker - // * upper: upper whisker - // * q1: box start - // * q3: box end - // * median - // We're assuming equally sized quartiles here, but we may have 1-member - // differences. For large datasets this shouldn't be a problem, but in smaller - // datasets one might notice. 
+fn boxplot_sql_compute_summary(from: &str, groups: &[String], value: &str, coef: &f64) -> String { + let groups_str = groups.join(", "); format!( "SELECT *, @@ -213,26 +166,23 @@ fn boxplot_sql_compute_fivenum(from: &str, groups: &[String], coef: &f64) -> Str LEAST( q3 + {coef} * (q3 - q1), max) AS upper FROM ( SELECT - Q1_min AS min, - Q4_max AS max, - (Q2_max + Q3_min) / 2.0 AS median, - (Q1_max + Q2_min) / 2.0 AS q1, - (Q3_max + Q4_min) / 2.0 AS q3, - {groups} - FROM ({from}) - )", + {groups}, + MIN({value}) AS min, + MAX({value}) AS max, + QUANTILE_CONT({value}, 0.25) AS q1, + QUANTILE_CONT({value}, 0.50) AS median, + QUANTILE_CONT({value}, 0.75) AS q3 + FROM ({from}) AS __ggsql_qt__ + WHERE {value} IS NOT NULL + GROUP BY {groups} + ) AS __ggsql_fn__", coef = coef, - groups = groups.join(", "), + groups = groups_str, + value = value, from = from ) } -fn boxplot_sql_compute_summary(from: &str, groups: &[String], value: &str, coef: &f64) -> String { - let query = boxplot_sql_assign_quartiles(from, groups, value); - let query = boxplot_sql_quartile_minmax(&query, groups, value); - boxplot_sql_compute_fivenum(&query, groups, coef) -} - fn boxplot_sql_filter_outliers(groups: &[String], value: &str, from: &str) -> String { let mut join_pairs = Vec::new(); let mut keep_columns = Vec::new(); @@ -350,41 +300,35 @@ mod tests { // ==================== SQL Generation Tests (Compact) ==================== #[test] - fn test_sql_assign_quartiles_basic() { + fn test_sql_compute_summary_basic() { let groups = vec!["category".to_string()]; - let result = boxplot_sql_assign_quartiles("data", &groups, "value"); - assert!(result.contains("NTILE(4)")); - assert!(result.contains("PARTITION BY category")); + let result = boxplot_sql_compute_summary("data", &groups, "value", &1.5); + assert!(result.contains("QUANTILE_CONT(value, 0.25)")); + assert!(result.contains("QUANTILE_CONT(value, 0.50)")); + assert!(result.contains("QUANTILE_CONT(value, 0.75)")); + assert!(result.contains("MIN(value) AS min")); + assert!(result.contains("MAX(value) AS max")); assert!(result.contains("WHERE value IS NOT NULL")); + assert!(result.contains("GROUP BY category")); + assert!(result.contains("GREATEST")); + assert!(result.contains("LEAST")); } #[test] - fn test_sql_assign_quartiles_multiple_groups() { + fn test_sql_compute_summary_multiple_groups() { let groups = vec!["cat".to_string(), "region".to_string()]; - let result = boxplot_sql_assign_quartiles("tbl", &groups, "val"); - assert!(result.contains("PARTITION BY cat, region")); - } - - #[test] - fn test_sql_quartile_minmax_structure() { - let groups = vec!["grp".to_string()]; - let result = boxplot_sql_quartile_minmax("query", &groups, "v"); - assert!(result.contains("Q1_min")); - assert!(result.contains("Q4_max")); - assert!(result.contains("CASE WHEN _Q = 1")); - assert!(result.contains("GROUP BY grp")); + let result = boxplot_sql_compute_summary("tbl", &groups, "val", &1.5); + assert!(result.contains("GROUP BY cat, region")); + assert!(result.contains("QUANTILE_CONT(val, 0.25)")); } #[test] - fn test_sql_compute_fivenum_coef() { + fn test_sql_compute_summary_custom_coef() { let groups = vec!["x".to_string()]; - let result = boxplot_sql_compute_fivenum("q", &groups, &2.5); + let result = boxplot_sql_compute_summary("q", &groups, "y", &2.5); assert!(result.contains("2.5")); - assert!(result.contains("AS lower")); - assert!(result.contains("AS upper")); - assert!(result.contains("AS median")); - assert!(result.contains("GREATEST")); - assert!(result.contains("LEAST")); + 
assert!(result.contains("GREATEST(q1 - 2.5 * (q3 - q1), min)")); + assert!(result.contains("LEAST( q3 + 2.5 * (q3 - q1), max)")); } #[test] @@ -411,30 +355,16 @@ mod tests { LEAST( q3 + 1.5 * (q3 - q1), max) AS upper FROM ( SELECT - Q1_min AS min, - Q4_max AS max, - (Q2_max + Q3_min) / 2.0 AS median, - (Q1_max + Q2_min) / 2.0 AS q1, - (Q3_max + Q4_min) / 2.0 AS q3, - category - FROM (SELECT - MIN(CASE WHEN _Q = 1 THEN price END) AS Q1_min, - MAX(CASE WHEN _Q = 1 THEN price END) AS Q1_max, - MIN(CASE WHEN _Q = 2 THEN price END) AS Q2_min, - MAX(CASE WHEN _Q = 2 THEN price END) AS Q2_max, - MIN(CASE WHEN _Q = 3 THEN price END) AS Q3_min, - MAX(CASE WHEN _Q = 3 THEN price END) AS Q3_max, - MIN(CASE WHEN _Q = 4 THEN price END) AS Q4_min, - MAX(CASE WHEN _Q = 4 THEN price END) AS Q4_max, - category - FROM (SELECT - price, - category, - NTILE(4) OVER (PARTITION BY category ORDER BY price ASC) AS _Q - FROM (SELECT * FROM sales) - WHERE price IS NOT NULL) - GROUP BY category) - )"#; + category, + MIN(price) AS min, + MAX(price) AS max, + QUANTILE_CONT(price, 0.25) AS q1, + QUANTILE_CONT(price, 0.50) AS median, + QUANTILE_CONT(price, 0.75) AS q3 + FROM (SELECT * FROM sales) AS __ggsql_qt__ + WHERE price IS NOT NULL + GROUP BY category + ) AS __ggsql_fn__"#; assert_eq!(result, expected); } @@ -450,45 +380,20 @@ mod tests { LEAST( q3 + 1.5 * (q3 - q1), max) AS upper FROM ( SELECT - Q1_min AS min, - Q4_max AS max, - (Q2_max + Q3_min) / 2.0 AS median, - (Q1_max + Q2_min) / 2.0 AS q1, - (Q3_max + Q4_min) / 2.0 AS q3, - region, product - FROM (SELECT - MIN(CASE WHEN _Q = 1 THEN revenue END) AS Q1_min, - MAX(CASE WHEN _Q = 1 THEN revenue END) AS Q1_max, - MIN(CASE WHEN _Q = 2 THEN revenue END) AS Q2_min, - MAX(CASE WHEN _Q = 2 THEN revenue END) AS Q2_max, - MIN(CASE WHEN _Q = 3 THEN revenue END) AS Q3_min, - MAX(CASE WHEN _Q = 3 THEN revenue END) AS Q3_max, - MIN(CASE WHEN _Q = 4 THEN revenue END) AS Q4_min, - MAX(CASE WHEN _Q = 4 THEN revenue END) AS Q4_max, - region, product - FROM (SELECT - revenue, - region, product, - NTILE(4) OVER (PARTITION BY region, product ORDER BY revenue ASC) AS _Q - FROM (SELECT * FROM data) - WHERE revenue IS NOT NULL) - GROUP BY region, product) - )"#; + region, product, + MIN(revenue) AS min, + MAX(revenue) AS max, + QUANTILE_CONT(revenue, 0.25) AS q1, + QUANTILE_CONT(revenue, 0.50) AS median, + QUANTILE_CONT(revenue, 0.75) AS q3 + FROM (SELECT * FROM data) AS __ggsql_qt__ + WHERE revenue IS NOT NULL + GROUP BY region, product + ) AS __ggsql_fn__"#; assert_eq!(result, expected); } - #[test] - fn test_boxplot_sql_compute_summary_custom_coef() { - let groups = vec!["x".to_string()]; - let result = boxplot_sql_compute_summary("source_query", &groups, "y", &3.0); - - // Verify coef parameter is properly interpolated - assert!(result.contains("3 *")); - assert!(result.contains("GREATEST(q1 - 3 * (q3 - q1), min)")); - assert!(result.contains("LEAST( q3 + 3 * (q3 - q1), max)")); - } - #[test] fn test_boxplot_sql_append_outliers_with_outliers() { let groups = vec!["category".to_string()]; diff --git a/src/plot/layer/geom/histogram.rs b/src/plot/layer/geom/histogram.rs index 5e3f1300..6a7c66ca 100644 --- a/src/plot/layer/geom/histogram.rs +++ b/src/plot/layer/geom/histogram.rs @@ -124,7 +124,7 @@ fn stat_histogram( // Query min/max to compute bin width let stats_query = format!( - "SELECT MIN({x}) as min_val, MAX({x}) as max_val FROM ({query})", + "SELECT MIN({x}) as min_val, MAX({x}) as max_val FROM ({query}) AS __ggsql_stats__", x = x_col, query = query ); diff --git 
a/src/reader/connection.rs b/src/reader/connection.rs index 63f90cf7..e592d5cc 100644 --- a/src/reader/connection.rs +++ b/src/reader/connection.rs @@ -11,6 +11,8 @@ pub enum ConnectionInfo { DuckDBMemory, /// DuckDB file-based database DuckDBFile(String), + /// Polars in-memory SQL context + PolarsMemory, /// PostgreSQL connection #[allow(dead_code)] PostgreSQL(String), @@ -56,6 +58,18 @@ pub fn parse_connection_string(uri: &str) -> Result { return Ok(ConnectionInfo::DuckDBFile(cleaned_path.to_string())); } + if uri == "polars://" || uri == "polars://memory" { + return Ok(ConnectionInfo::PolarsMemory); + } + + if uri.starts_with("polars://") { + // Polars only supports in-memory mode + return Err(GgsqlError::ReaderError( + "Polars reader only supports in-memory mode. Use 'polars://memory' or 'polars://'" + .to_string(), + )); + } + if uri.starts_with("postgres://") || uri.starts_with("postgresql://") { return Ok(ConnectionInfo::PostgreSQL(uri.to_string())); } @@ -71,7 +85,7 @@ pub fn parse_connection_string(uri: &str) -> Result { } Err(GgsqlError::ReaderError(format!( - "Unsupported connection string format: {}. Supported: duckdb://, postgres://, sqlite://", + "Unsupported connection string format: {}. Supported: duckdb://, polars://, postgres://, sqlite://", uri ))) } @@ -127,6 +141,28 @@ mod tests { assert_eq!(info, ConnectionInfo::SQLite("data.db".to_string())); } + #[test] + fn test_polars_memory() { + let info = parse_connection_string("polars://memory").unwrap(); + assert_eq!(info, ConnectionInfo::PolarsMemory); + } + + #[test] + fn test_polars_empty() { + let info = parse_connection_string("polars://").unwrap(); + assert_eq!(info, ConnectionInfo::PolarsMemory); + } + + #[test] + fn test_polars_file_not_supported() { + let result = parse_connection_string("polars://data.db"); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("only supports in-memory")); + } + #[test] fn test_empty_duckdb_path() { let result = parse_connection_string("duckdb://"); diff --git a/src/reader/data.rs b/src/reader/data.rs index cff1aff1..8bea4fbc 100644 --- a/src/reader/data.rs +++ b/src/reader/data.rs @@ -1,78 +1,214 @@ -use std::{env, fs}; use tree_sitter::{Parser, Query, StreamingIterator}; -use crate::GgsqlError; +use crate::{naming, GgsqlError}; +// ============================================================================= +// Embedded dataset bytes +// ============================================================================= // To add new built-in datasets follow these steps: // // 1. Add a parquet file of your dataset to the /data/ folder // 2. Include the bytes of that parquet file in the binary, like is done // beneath this block. -// 3. Make a `prep_{dataset}_query()` convenience function. -// 4. In the `init_builtin_data()` function below, include a `match`-arm for your -// dataset using the "ggsql:" pattern. +// 3. Add a match arm in `builtin_parquet_bytes()` for your dataset. +// 4. Add the dataset name to `KNOWN_DATASETS`. +// ============================================================================= +#[cfg(feature = "builtin-data")] static PENGUINS: &[u8] = include_bytes!(concat!( env!("CARGO_MANIFEST_DIR"), "/../data/penguins.parquet" )); +#[cfg(feature = "builtin-data")] static AIRQUALITY: &[u8] = include_bytes!(concat!( env!("CARGO_MANIFEST_DIR"), "/../data/airquality.parquet" )); -pub fn prep_penguins_query() -> String { - prep_builtin_dataset_query("penguins", PENGUINS) +/// Get the embedded parquet bytes for a known builtin dataset. 
+#[cfg(feature = "builtin-data")] +fn builtin_parquet_bytes(name: &str) -> Option<&'static [u8]> { + match name { + "penguins" => Some(PENGUINS), + "airquality" => Some(AIRQUALITY), + _ => None, + } } -pub fn prep_airquality_query() -> String { - prep_builtin_dataset_query("airquality", AIRQUALITY) +// ============================================================================= +// DuckDB builtin data registration (requires duckdb + builtin-data) +// ============================================================================= + +/// Register any builtin datasets referenced in the SQL with a DuckDB connection. +/// +/// Finds `ggsql:X` patterns in the SQL, writes the embedded parquet data to +/// a temp file, and creates a table named `__ggsql_data_X__` in DuckDB. +#[cfg(all(feature = "duckdb", feature = "builtin-data"))] +pub fn register_builtin_datasets_duckdb( + sql: &str, + conn: &duckdb::Connection, +) -> Result<(), GgsqlError> { + use std::{env, fs}; + + let dataset_names = extract_builtin_dataset_names(sql)?; + for name in dataset_names { + let Some(parquet_bytes) = builtin_parquet_bytes(&name) else { + continue; + }; + + let table_name = naming::builtin_data_table(&name); + + // Write parquet to temp file for DuckDB's read_parquet + let mut tmp_path = env::temp_dir(); + tmp_path.push(format!("{}.parquet", name)); + if !tmp_path.exists() { + fs::write(&tmp_path, parquet_bytes).expect("Failed to write dataset"); + } + + let create_sql = format!( + "CREATE TABLE IF NOT EXISTS \"{}\" AS SELECT * FROM read_parquet('{}')", + table_name, + tmp_path.display() + ); + + conn.execute(&create_sql, duckdb::params![]).map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to register builtin dataset '{}': {}", + name, e + )) + })?; + } + Ok(()) } -fn prep_builtin_dataset_query(name: &str, data: &[u8]) -> String { - let mut tmp_path = env::temp_dir(); - let mut filename = name.to_string(); - filename.push_str(".parquet"); - tmp_path.push(filename); - if !tmp_path.exists() { - fs::write(&tmp_path, data).expect("Failed to write dataset"); - } - format!( - "CREATE TABLE IF NOT EXISTS {} AS SELECT * FROM read_parquet('{}')", - name, - tmp_path.display() - ) +// ============================================================================= +// Polars-based builtin data loading +// ============================================================================= + +#[cfg(feature = "builtin-data")] +pub fn load_builtin_dataframe(name: &str) -> Result { + use polars::prelude::*; + use std::io::Cursor; + + let parquet_bytes = match name { + "penguins" => PENGUINS, + "airquality" => AIRQUALITY, + _ => { + return Err(GgsqlError::ReaderError(format!( + "Unknown builtin dataset: '{}'", + name + ))) + } + }; + + let cursor = Cursor::new(parquet_bytes); + ParquetReader::new(cursor).finish().map_err(|e| { + GgsqlError::ReaderError(format!("Failed to load builtin dataset '{}': {}", name, e)) + }) } -pub fn init_builtin_data(sql: &str) -> Result, GgsqlError> { - // This definition pulls out namespaced identifiers (e.g., ggsql:penguins) by - // @select'ing the string/identifier token. 
+/// Known builtin dataset names in the ggsql namespace +const KNOWN_DATASETS: &[&str] = &["penguins", "airquality"]; + +/// Check if a dataset name is a known builtin +pub fn is_known_builtin(name: &str) -> bool { + KNOWN_DATASETS.contains(&name) +} + +// ============================================================================= +// SQL namespace rewriting (always available, including WASM) +// ============================================================================= + +/// Extract builtin dataset names from SQL containing namespaced identifiers. +/// +/// Finds `ggsql:X` patterns via tree-sitter and returns the dataset names +/// (without the `ggsql:` prefix), deduplicated. +pub fn extract_builtin_dataset_names(sql: &str) -> Result, GgsqlError> { let token_def = r#"(namespaced_identifier) @select"#; let mut tokens = tokens_from_tree(sql, token_def, "select")?; - let mut result = Vec::new(); + if tokens.is_empty() { - return Ok(result); + return Ok(Vec::new()); } - // Deduplicate tokens tokens.sort_unstable(); tokens.dedup(); - for dataset in tokens { - // Only process ggsql namespace datasets - let materialize_query = match dataset.as_str() { - "ggsql:penguins" => Some(prep_penguins_query()), - "ggsql:airquality" => Some(prep_airquality_query()), - _ => None, // Unknown namespace - ignored - }; - if let Some(query) = materialize_query { - result.push(query); + let datasets: Vec = tokens + .iter() + .filter_map(|token| token.strip_prefix("ggsql:").map(|s| s.to_string())) + .collect(); + + Ok(datasets) +} + +/// Rewrite SQL to replace namespaced identifiers with internal table names. +/// +/// e.g., `SELECT * FROM ggsql:penguins` -> `SELECT * FROM __ggsql_data_penguins__` +/// +/// Uses tree-sitter to find the exact byte positions of namespaced identifiers, +/// then replaces them in reverse order to preserve offsets. 
+pub fn rewrite_namespaced_sql(sql: &str) -> Result { + let token_def = r#"(namespaced_identifier) @select"#; + + // Parse to get byte positions + let mut parser = Parser::new(); + parser + .set_language(&tree_sitter_ggsql::language()) + .map_err(|e| GgsqlError::ParseError(format!("Failed to initialise parser: {}", e)))?; + + let tree = parser + .parse(sql, None) + .ok_or_else(|| GgsqlError::ParseError(format!("Failed to parse query: {}", sql)))?; + + let query = Query::new(&tree.language(), token_def) + .map_err(|e| GgsqlError::ParseError(format!("Failed to initialise tree_query: {}", e)))?; + + let index = query + .capture_index_for_name("select") + .ok_or_else(|| GgsqlError::ParseError("Failed to capture index".to_string()))?; + + let mut cursor = tree_sitter::QueryCursor::new(); + let mut matches = cursor.matches(&query, tree.root_node(), sql.as_bytes()); + + // Collect (start_byte, end_byte, replacement) tuples + let mut replacements: Vec<(usize, usize, String)> = Vec::new(); + while let Some(matching) = matches.next() { + for item in matching.captures { + if item.index != index { + continue; + } + let node = item.node; + let full_text = &sql[node.start_byte()..node.end_byte()]; + if let Some(name) = full_text.strip_prefix("ggsql:") { + replacements.push(( + node.start_byte(), + node.end_byte(), + naming::builtin_data_table(name), + )); + } } } + + if replacements.is_empty() { + return Ok(sql.to_string()); + } + + // Apply replacements in reverse byte order to preserve earlier offsets + let mut result = sql.to_string(); + replacements.sort_by_key(|(start, _, _)| std::cmp::Reverse(*start)); + for (start, end, replacement) in replacements { + result.replace_range(start..end, &replacement); + } + Ok(result) } +// ============================================================================= +// Shared tree-sitter helpers +// ============================================================================= + fn tokens_from_tree( sql_query: &str, tree_query: &str, @@ -136,27 +272,125 @@ fn tokens_from_tree( Ok(result) } -#[cfg(feature = "duckdb")] -#[test] -fn test_builtin_data_is_available() { - use crate::naming; - - let reader = crate::reader::DuckDBReader::from_connection_string("duckdb://memory").unwrap(); - - // Test penguins builtin dataset with a DRAW clause - let query = - "SELECT * FROM ggsql:penguins VISUALISE DRAW point MAPPING bill_len AS x, bill_dep AS y"; - let result = crate::execute::prepare_data_with_reader(query, &reader).unwrap(); - let dataframe = result.data.get(&naming::layer_key(0)).unwrap(); - // Check that the aesthetic columns are present (other columns preserved via SELECT *) - assert!(dataframe.column("__ggsql_aes_x__").is_ok()); - assert!(dataframe.column("__ggsql_aes_y__").is_ok()); - - // Test airquality builtin dataset with VISUALISE FROM - let query = "VISUALISE FROM ggsql:airquality DRAW point MAPPING Temp AS x, Ozone AS y"; - let result = crate::execute::prepare_data_with_reader(query, &reader).unwrap(); - let dataframe = result.data.get(&naming::layer_key(0)).unwrap(); - // Check that the aesthetic columns are present - assert!(dataframe.column("__ggsql_aes_x__").is_ok()); - assert!(dataframe.column("__ggsql_aes_y__").is_ok()); +// ============================================================================= +// Tests +// ============================================================================= + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_builtin_dataset_names_single() { + let sql = "SELECT * FROM ggsql:penguins VISUALISE 
DRAW point MAPPING x AS x"; + let names = extract_builtin_dataset_names(sql).unwrap(); + assert_eq!(names, vec!["penguins"]); + } + + #[test] + fn test_extract_builtin_dataset_names_multiple() { + let sql = + "SELECT * FROM ggsql:penguins, ggsql:airquality VISUALISE DRAW point MAPPING x AS x"; + let names = extract_builtin_dataset_names(sql).unwrap(); + assert_eq!(names.len(), 2); + assert!(names.contains(&"airquality".to_string())); + assert!(names.contains(&"penguins".to_string())); + } + + #[test] + fn test_extract_builtin_dataset_names_dedup() { + let sql = "SELECT * FROM ggsql:penguins p1, ggsql:penguins p2 VISUALISE DRAW point MAPPING x AS x"; + let names = extract_builtin_dataset_names(sql).unwrap(); + assert_eq!(names, vec!["penguins"]); + } + + #[test] + fn test_extract_builtin_dataset_names_none() { + let sql = "SELECT * FROM regular_table VISUALISE DRAW point MAPPING x AS x"; + let names = extract_builtin_dataset_names(sql).unwrap(); + assert!(names.is_empty()); + } + + #[test] + fn test_rewrite_namespaced_sql_simple() { + let sql = "SELECT * FROM ggsql:penguins"; + let rewritten = rewrite_namespaced_sql(sql).unwrap(); + assert_eq!(rewritten, "SELECT * FROM __ggsql_data_penguins__"); + } + + #[test] + fn test_rewrite_namespaced_sql_multiple() { + let sql = "SELECT * FROM ggsql:penguins p, ggsql:airquality a WHERE p.id = a.id"; + let rewritten = rewrite_namespaced_sql(sql).unwrap(); + assert_eq!( + rewritten, + "SELECT * FROM __ggsql_data_penguins__ p, __ggsql_data_airquality__ a WHERE p.id = a.id" + ); + } + + #[test] + fn test_rewrite_namespaced_sql_no_change() { + let sql = "SELECT * FROM regular_table WHERE x > 5"; + let rewritten = rewrite_namespaced_sql(sql).unwrap(); + assert_eq!(rewritten, sql); + } + + #[test] + fn test_rewrite_namespaced_sql_with_visualise() { + let sql = "SELECT * FROM ggsql:penguins VISUALISE DRAW point MAPPING bill_len AS x, bill_dep AS y"; + let rewritten = rewrite_namespaced_sql(sql).unwrap(); + assert!(rewritten.starts_with("SELECT * FROM __ggsql_data_penguins__")); + assert!(!rewritten.contains("ggsql:")); + } +} + +#[cfg(all(feature = "duckdb", feature = "builtin-data"))] +#[cfg(test)] +mod duckdb_tests { + #[test] + fn test_builtin_data_is_available() { + use crate::naming; + + let reader = + crate::reader::DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + + let query = + "SELECT * FROM ggsql:penguins VISUALISE DRAW point MAPPING bill_len AS x, bill_dep AS y"; + let result = crate::execute::prepare_data_with_reader(query, &reader).unwrap(); + let dataframe = result.data.get(&naming::layer_key(0)).unwrap(); + assert!(dataframe.column("__ggsql_aes_x__").is_ok()); + assert!(dataframe.column("__ggsql_aes_y__").is_ok()); + + let query = "VISUALISE FROM ggsql:airquality DRAW point MAPPING Temp AS x, Ozone AS y"; + let result = crate::execute::prepare_data_with_reader(query, &reader).unwrap(); + let dataframe = result.data.get(&naming::layer_key(0)).unwrap(); + assert!(dataframe.column("__ggsql_aes_x__").is_ok()); + assert!(dataframe.column("__ggsql_aes_y__").is_ok()); + } +} + +#[cfg(feature = "builtin-data")] +#[cfg(test)] +mod builtin_data_tests { + use super::*; + + #[test] + fn test_load_builtin_parquet_penguins() { + let df = load_builtin_dataframe("penguins").unwrap(); + assert!(df.height() > 0); + assert!(df.width() > 0); + } + + #[test] + fn test_load_builtin_parquet_airquality() { + let df = load_builtin_dataframe("airquality").unwrap(); + assert!(df.height() > 0); + assert!(df.width() > 0); + } + + #[test] + fn 
test_load_builtin_parquet_unknown() { + let result = load_builtin_dataframe("nonexistent"); + assert!(result.is_err()); + } } diff --git a/src/reader/duckdb.rs b/src/reader/duckdb.rs index 655e2d94..26d74c56 100644 --- a/src/reader/duckdb.rs +++ b/src/reader/duckdb.rs @@ -2,7 +2,6 @@ //! //! Provides a reader for DuckDB databases with direct Polars DataFrame integration. -use crate::reader::data::init_builtin_data; use crate::reader::{connection::ConnectionInfo, Reader}; use crate::{DataFrame, GgsqlError, Result}; use arrow::ipc::reader::FileReader; @@ -10,6 +9,7 @@ use duckdb::vtab::arrow::{arrow_recordbatch_to_query_params, ArrowVTab}; use duckdb::{params, Connection}; use polars::io::SerWriter; use polars::prelude::*; +use std::cell::RefCell; use std::collections::HashSet; use std::io::Cursor; @@ -33,7 +33,7 @@ use std::io::Cursor; /// ``` pub struct DuckDBReader { conn: Connection, - registered_tables: HashSet<String>, + registered_tables: RefCell<HashSet<String>>, } impl DuckDBReader { @@ -79,7 +79,7 @@ impl DuckDBReader { Ok(Self { conn, - registered_tables: HashSet::new(), + registered_tables: RefCell::new(HashSet::new()), }) } @@ -388,6 +388,13 @@ impl Reader for DuckDBReader { fn execute_sql(&self, sql: &str) -> Result<DataFrame> { use polars::prelude::*; + // Register builtin datasets if referenced + #[cfg(feature = "builtin-data")] + super::data::register_builtin_datasets_duckdb(sql, &self.conn)?; + + // Rewrite ggsql:name → __ggsql_data_name__ in SQL + let sql = super::data::rewrite_namespaced_sql(sql)?; + // Check if this is a DDL statement (CREATE, DROP, INSERT, UPDATE, DELETE, ALTER) // DDL statements don't return rows, so we handle them specially let trimmed = sql.trim().to_uppercase(); @@ -398,21 +405,10 @@ impl Reader for DuckDBReader { || trimmed.starts_with("DELETE ") || trimmed.starts_with("ALTER "); - // Initialise built-in datasets - let inits = init_builtin_data(sql)?; - for init in inits { - if let Err(e) = self.conn.execute(&init, params![]) { - return Err(GgsqlError::ReaderError(format!( - "Failed to initialise built-in dataset: {}", - e - ))); - } - } - if is_ddl { // For DDL, just execute and return an empty DataFrame self.conn - .execute(sql, params![]) + .execute(&sql, params![]) .map_err(|e| GgsqlError::ReaderError(format!("Failed to execute DDL: {}", e)))?; // Return empty DataFrame for DDL statements @@ -424,7 +420,7 @@ impl Reader for DuckDBReader { // Prepare and execute statement to get schema let mut stmt = self .conn - .prepare(sql) + .prepare(&sql) .map_err(|e| GgsqlError::ReaderError(format!("Failed to prepare SQL: {}", e)))?; // Execute to populate schema info @@ -504,12 +500,12 @@ impl Reader for DuckDBReader { Ok(df) } - fn register(&mut self, name: &str, df: DataFrame) -> Result<()> { + fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()> { // Validate table name validate_table_name(name)?; // Check for duplicates - if self.table_exists(name)? { + if !replace && self.table_exists(name)? { return Err(GgsqlError::ReaderError(format!( "Table '{}' already exists", name @@ -530,13 +526,18 @@ // subsequent chunks INSERT into it.
const MAX_ARROW_BATCH_ROWS: usize = 2048; let total_rows = df.height(); + let create_or_replace = if replace { + "CREATE OR REPLACE" + } else { + "CREATE" + }; if total_rows <= MAX_ARROW_BATCH_ROWS { // Small DataFrame: register in a single batch let params = dataframe_to_arrow_params(df)?; let sql = format!( - "CREATE TEMP TABLE \"{}\" AS SELECT * FROM arrow(?, ?)", - name + "{} TEMP TABLE \"{}\" AS SELECT * FROM arrow(?, ?)", + create_or_replace, name ); self.conn.execute(&sql, params).map_err(|e| { GgsqlError::ReaderError(format!("Failed to register table '{}': {}", name, e)) @@ -546,8 +547,8 @@ impl Reader for DuckDBReader { let first_chunk = df.slice(0, MAX_ARROW_BATCH_ROWS); let params = dataframe_to_arrow_params(first_chunk)?; let create_sql = format!( - "CREATE TEMP TABLE \"{}\" AS SELECT * FROM arrow(?, ?)", - name + "{} TEMP TABLE \"{}\" AS SELECT * FROM arrow(?, ?)", + create_or_replace, name ); self.conn.execute(&create_sql, params).map_err(|e| { GgsqlError::ReaderError(format!("Failed to register table '{}': {}", name, e)) @@ -570,14 +571,13 @@ impl Reader for DuckDBReader { } // Track the table so we can unregister it later - self.registered_tables.insert(name.to_string()); - + self.registered_tables.borrow_mut().insert(name.to_string()); Ok(()) } - fn unregister(&mut self, name: &str) -> Result<()> { + fn unregister(&self, name: &str) -> Result<()> { // Only allow unregistering tables we created via register() - if !self.registered_tables.contains(name) { + if !self.registered_tables.borrow().contains(name) { return Err(GgsqlError::ReaderError(format!( "Table '{}' was not registered via this reader", name @@ -591,14 +591,10 @@ impl Reader for DuckDBReader { })?; // Remove from tracking - self.registered_tables.remove(name); + self.registered_tables.borrow_mut().remove(name); Ok(()) } - - fn supports_register(&self) -> bool { - true - } } #[cfg(test)] @@ -677,7 +673,7 @@ mod tests { #[test] fn test_register_and_query() { - let mut reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); // Create a DataFrame let df = DataFrame::new(vec![ @@ -687,7 +683,7 @@ mod tests { .unwrap(); // Register the DataFrame - reader.register("my_table", df).unwrap(); + reader.register("my_table", df, false).unwrap(); // Query the registered table let result = reader @@ -699,16 +695,16 @@ mod tests { #[test] fn test_register_duplicate_name_errors() { - let mut reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); let df1 = DataFrame::new(vec![Column::new("a".into(), vec![1i32])]).unwrap(); let df2 = DataFrame::new(vec![Column::new("b".into(), vec![2i32])]).unwrap(); // First registration should succeed - reader.register("dup_table", df1).unwrap(); + reader.register("dup_table", df1, false).unwrap(); - // Second registration with same name should fail - let result = reader.register("dup_table", df2); + // Second registration with same name should fail (when replace=false) + let result = reader.register("dup_table", df2, false); assert!(result.is_err()); let err = result.unwrap_err().to_string(); assert!(err.contains("already exists")); @@ -716,16 +712,16 @@ mod tests { #[test] fn test_register_invalid_table_names() { - let mut reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); let df = 
DataFrame::new(vec![Column::new("a".into(), vec![1i32])]).unwrap(); // Empty name - let result = reader.register("", df.clone()); + let result = reader.register("", df.clone(), false); assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("cannot be empty")); // Name with double quote - let result = reader.register("bad\"name", df.clone()); + let result = reader.register("bad\"name", df.clone(), false); assert!(result.is_err()); assert!(result .unwrap_err() @@ -733,7 +729,7 @@ mod tests { .contains("invalid character")); // Name with null byte - let result = reader.register("bad\0name", df.clone()); + let result = reader.register("bad\0name", df.clone(), false); assert!(result.is_err()); assert!(result .unwrap_err() @@ -742,7 +738,7 @@ mod tests { // Name too long let long_name = "a".repeat(200); - let result = reader.register(&long_name, df); + let result = reader.register(&long_name, df, false); assert!(result.is_err()); assert!(result .unwrap_err() @@ -750,15 +746,9 @@ mod tests { .contains("exceeds maximum length")); } - #[test] - fn test_supports_register() { - let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); - assert!(reader.supports_register()); - } - #[test] fn test_register_empty_dataframe() { - let mut reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); // Create an empty DataFrame with schema let df = DataFrame::new(vec![ @@ -767,7 +757,7 @@ mod tests { ]) .unwrap(); - reader.register("empty_table", df).unwrap(); + reader.register("empty_table", df, false).unwrap(); // Query should return empty result with correct schema let result = reader.execute_sql("SELECT * FROM empty_table").unwrap(); @@ -777,10 +767,10 @@ mod tests { #[test] fn test_unregister() { - let mut reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); let df = DataFrame::new(vec![Column::new("x".into(), vec![1i32, 2, 3])]).unwrap(); - reader.register("test_data", df).unwrap(); + reader.register("test_data", df, false).unwrap(); // Should be queryable let result = reader.execute_sql("SELECT * FROM test_data").unwrap(); @@ -796,7 +786,7 @@ mod tests { #[test] fn test_unregister_not_registered() { - let mut reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); // Create a table directly (not via register) reader @@ -813,14 +803,14 @@ mod tests { #[test] fn test_reregister_after_unregister() { - let mut reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); let df = DataFrame::new(vec![Column::new("x".into(), vec![1i32, 2, 3])]).unwrap(); - reader.register("data", df.clone()).unwrap(); + reader.register("data", df.clone(), false).unwrap(); reader.unregister("data").unwrap(); // Should be able to register again - reader.register("data", df).unwrap(); + reader.register("data", df, false).unwrap(); let result = reader.execute_sql("SELECT * FROM data").unwrap(); assert_eq!(result.height(), 3); } @@ -829,7 +819,7 @@ mod tests { fn test_register_large_dataframe() { // duckdb-rs Arrow vtab has a vector capacity of 2048 rows. DataFrames // larger than this must be chunked to avoid a panic. 
- let mut reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); let n = 3000; let ids: Vec = (0..n).collect(); @@ -843,7 +833,7 @@ mod tests { ]) .unwrap(); - reader.register("large_table", df).unwrap(); + reader.register("large_table", df, false).unwrap(); // Verify row count let result = reader diff --git a/src/reader/mod.rs b/src/reader/mod.rs index a5c93621..5565668c 100644 --- a/src/reader/mod.rs +++ b/src/reader/mod.rs @@ -27,7 +27,7 @@ //! //! // With DataFrame registration //! let mut reader = DuckDBReader::from_connection_string("duckdb://memory")?; -//! reader.register("my_table", some_dataframe)?; +//! reader.register("my_table", some_dataframe, false)?; //! let spec = reader.execute("SELECT * FROM my_table VISUALISE x, y DRAW point")?; //! ``` @@ -41,6 +41,9 @@ use crate::{DataFrame, GgsqlError, Result}; #[cfg(feature = "duckdb")] pub mod duckdb; +#[cfg(feature = "polars-sql")] +pub mod polars_sql; + pub mod connection; pub mod data; mod spec; @@ -48,6 +51,9 @@ mod spec; #[cfg(feature = "duckdb")] pub use duckdb::DuckDBReader; +#[cfg(feature = "polars-sql")] +pub use polars_sql::PolarsReader; + // ============================================================================ // Spec - Result of reader.execute() // ============================================================================ @@ -91,13 +97,13 @@ pub struct Metadata { /// /// # DataFrame Registration /// -/// Some readers support registering DataFrames as queryable tables using +/// Readers support registering DataFrames as queryable tables using /// the [`register`](Reader::register) method. This allows you to query /// in-memory DataFrames with SQL, join them with other tables, etc. /// /// ```rust,ignore /// // Register a DataFrame (takes ownership) -/// reader.register("sales", sales_df)?; +/// reader.register("sales", sales_df, false)?; /// /// // Now you can query it /// let result = reader.execute_sql("SELECT * FROM sales WHERE amount > 100")?; @@ -132,20 +138,13 @@ pub trait Reader { /// /// * `name` - The table name to register under /// * `df` - The DataFrame to register (ownership is transferred) + /// * `replace` - If true, replace any existing table with the same name. + /// If false, return an error if the table already exists. /// /// # Returns /// - /// `Ok(())` on success, error if registration fails or isn't supported. - /// - /// # Default Implementation - /// - /// Returns an error by default. Override for readers that support registration. - fn register(&mut self, name: &str, _df: DataFrame) -> Result<()> { - Err(GgsqlError::ReaderError(format!( - "This reader does not support DataFrame registration for table '{}'", - name - ))) - } + /// `Ok(())` on success, error if registration fails. + fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()>; /// Unregister a previously registered table /// @@ -159,23 +158,14 @@ pub trait Reader { /// /// # Default Implementation /// - /// Returns an error by default. Override for readers that support registration. - fn unregister(&mut self, name: &str) -> Result<()> { + /// Returns an error by default. Override for readers that support unregistration. 
+ fn unregister(&self, name: &str) -> Result<()> { Err(GgsqlError::ReaderError(format!( "This reader does not support unregistering table '{}'", name ))) } - /// Check if this reader supports DataFrame registration - /// - /// # Returns - /// - /// `true` if [`register`](Reader::register) is implemented, `false` otherwise. - fn supports_register(&self) -> bool { - false - } - /// Execute a ggsql query and return the visualization specification. /// /// This is the main entry point for creating visualizations. It parses the query, @@ -202,14 +192,16 @@ /// use ggsql::reader::{Reader, DuckDBReader}; /// use ggsql::writer::{Writer, VegaLiteWriter}; /// - /// let reader = DuckDBReader::from_connection_string("duckdb://memory")?; + /// let mut reader = DuckDBReader::from_connection_string("duckdb://memory")?; /// let spec = reader.execute("SELECT 1 as x, 2 as y VISUALISE x, y DRAW point")?; /// /// let writer = VegaLiteWriter::new(); /// let json = writer.render(&spec)?; /// ``` - #[cfg(feature = "duckdb")] - fn execute(&self, query: &str) -> Result<Spec> { + fn execute(&self, query: &str) -> Result<Spec> + where + Self: Sized, + { // Run validation first to capture warnings let validated = validate(query)?; let warnings: Vec = validated.warnings().to_vec(); @@ -394,7 +386,7 @@ mod tests { fn test_register_and_query() { use polars::prelude::*; - let mut reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); let df = df! { "x" => [1i32, 2, 3], "y" => [10i32, 20, 30], } .unwrap(); - reader.register("my_data", df).unwrap(); + reader.register("my_data", df, false).unwrap(); let query = "SELECT * FROM my_data VISUALISE x, y DRAW point"; let spec = reader.execute(query).unwrap(); @@ -419,7 +411,7 @@ fn test_register_and_join() { use polars::prelude::*; - let mut reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); let sales = df! { "id" => [1i32, 2, 3], "amount" => [100i32, 200, 300], "product_id" => [1i32, 1, 2], } .unwrap(); let products = df! { "id" => [1i32, 2], "name" => ["Widget", "Gadget"], } .unwrap(); - reader.register("sales", sales).unwrap(); - reader.register("products", products).unwrap(); + reader.register("sales", sales, false).unwrap(); + reader.register("products", products, false).unwrap(); let query = r#" SELECT s.id, s.amount, p.name diff --git a/src/reader/polars_sql.rs b/src/reader/polars_sql.rs new file mode 100644 index 00000000..4bf69868 --- /dev/null +++ b/src/reader/polars_sql.rs @@ -0,0 +1,512 @@ +//! Polars SQL context data source implementation +//! +//! Provides a reader that uses Polars' built-in SQL context for querying DataFrames. + +use crate::reader::Reader; +use crate::{DataFrame, GgsqlError, Result}; +use polars::prelude::*; +use polars::sql::SQLContext; +use std::cell::RefCell; +use std::collections::HashSet; + +/// Polars SQL context reader +/// +/// Executes SQL queries against registered Polars DataFrames using Polars' built-in +/// SQL context. This is a pure in-memory reader with no external database connection. +/// +/// # Examples +/// +/// ```rust,ignore +/// use ggsql::reader::{Reader, PolarsReader}; +/// use polars::prelude::*; +/// +/// // Create an in-memory reader +/// let mut reader = PolarsReader::from_connection_string("polars://memory")?; +/// +/// // Register a DataFrame +/// let df = df!
{ +/// "x" => [1, 2, 3], +/// "y" => [10, 20, 30], +/// }?; +/// reader.register("data", df, false)?; +/// +/// // Query it with SQL +/// let result = reader.execute_sql("SELECT * FROM data WHERE x > 1")?; +/// ``` +pub struct PolarsReader { + ctx: RefCell, + registered_tables: RefCell>, +} + +impl PolarsReader { + /// Create a new Polars reader from a connection string + /// + /// # Arguments + /// + /// * `uri` - Connection string (e.g., "polars://memory" or "polars://") + /// + /// # Returns + /// + /// A configured Polars reader with an empty SQL context + /// + /// # Errors + /// + /// Returns an error if the connection string format is invalid + pub fn from_connection_string(uri: &str) -> Result { + let conn_info = super::connection::parse_connection_string(uri)?; + + match conn_info { + super::connection::ConnectionInfo::PolarsMemory => Ok(Self { + ctx: RefCell::new(SQLContext::new()), + registered_tables: RefCell::new(HashSet::new()), + }), + _ => Err(GgsqlError::ReaderError(format!( + "Connection string '{}' is not supported by PolarsReader", + uri + ))), + } + } + + /// Create a new Polars reader with default settings + /// + /// Equivalent to `from_connection_string("polars://memory")` + pub fn new() -> Self { + Self { + ctx: RefCell::new(SQLContext::new()), + registered_tables: RefCell::new(HashSet::new()), + } + } + + /// Check if a table is registered + fn table_exists(&self, name: &str) -> bool { + self.registered_tables.borrow().contains(name) + } + + /// List registered table names + /// + /// When `internal` is false, filters out internal tables (prefixed with `__ggsql_`). + pub fn list_tables(&self, internal: bool) -> Vec { + self.registered_tables + .borrow() + .iter() + .filter(|name| internal || !name.starts_with("__ggsql_")) + .cloned() + .collect() + } +} + +impl Default for PolarsReader { + fn default() -> Self { + Self::new() + } +} + +/// Validate a table name +fn validate_table_name(name: &str) -> Result<()> { + if name.is_empty() { + return Err(GgsqlError::ReaderError("Table name cannot be empty".into())); + } + + // Reject characters that could break identifiers or cause issues + let forbidden = ['"', '\0', '\n', '\r']; + for ch in forbidden { + if name.contains(ch) { + return Err(GgsqlError::ReaderError(format!( + "Table name '{}' contains invalid character '{}'", + name, + ch.escape_default() + ))); + } + } + + // Reasonable length limit + if name.len() > 128 { + return Err(GgsqlError::ReaderError(format!( + "Table name '{}' exceeds maximum length of 128 characters", + name + ))); + } + + Ok(()) +} + +impl Reader for PolarsReader { + fn execute_sql(&self, sql: &str) -> Result { + // Check if this is a DDL statement - Polars SQL context doesn't support DDL + let trimmed = sql.trim().to_uppercase(); + let is_ddl = trimmed.starts_with("CREATE ") + || trimmed.starts_with("DROP ") + || trimmed.starts_with("INSERT ") + || trimmed.starts_with("UPDATE ") + || trimmed.starts_with("DELETE ") + || trimmed.starts_with("ALTER "); + + if is_ddl { + return Err(GgsqlError::ReaderError( + format!("Polars SQL context does not support DDL statements. Use register() to add tables. 
{}", sql) + )); + } + + // Handle ggsql:name namespaced identifiers (builtin datasets) + #[cfg(feature = "builtin-data")] + { + let dataset_names = super::data::extract_builtin_dataset_names(sql)?; + for name in &dataset_names { + let table_name = crate::naming::builtin_data_table(name); + if !self.table_exists(&table_name) { + let df = super::data::load_builtin_dataframe(name)?; + self.register(&table_name, df, true)?; + } + } + } + + // Rewrite ggsql:name → __ggsql_data_name__ in SQL + let sql = super::data::rewrite_namespaced_sql(sql)?; + + // Execute the query - this returns a LazyFrame + let lazy_frame = self.ctx.borrow_mut().execute(&sql).map_err(|e| { + GgsqlError::ReaderError(format!("Failed to execute SQL `{}`: {}", sql, e)) + })?; + + // Collect the LazyFrame into a DataFrame + let df = lazy_frame.collect().map_err(|e| { + GgsqlError::ReaderError(format!("Failed to collect query result: {}", e)) + })?; + + Ok(df) + } + + fn register(&self, name: &str, df: DataFrame, replace: bool) -> Result<()> { + // Validate table name + validate_table_name(name)?; + + // Handle existing table + if self.table_exists(name) { + if replace { + // Unregister existing table first + self.ctx.borrow_mut().unregister(name); + self.registered_tables.borrow_mut().remove(name); + } else { + return Err(GgsqlError::ReaderError(format!( + "Table '{}' already exists", + name + ))); + } + } + + // Register the DataFrame with the SQL context + // Polars SQLContext takes a LazyFrame + self.ctx.borrow_mut().register(name, df.lazy()); + + // Track the table so we can unregister it later + self.registered_tables.borrow_mut().insert(name.to_string()); + + Ok(()) + } + + fn unregister(&self, name: &str) -> Result<()> { + // Only allow unregistering tables we created via register() + if !self.registered_tables.borrow().contains(name) { + return Err(GgsqlError::ReaderError(format!( + "Table '{}' was not registered via this reader", + name + ))); + } + + // Unregister from the SQL context + self.ctx.borrow_mut().unregister(name); + + // Remove from tracking + self.registered_tables.borrow_mut().remove(name); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_reader() { + let reader = PolarsReader::from_connection_string("polars://memory"); + assert!(reader.is_ok()); + } + + #[test] + fn test_create_reader_default() { + let _reader = PolarsReader::new(); + } + + #[test] + fn test_register_and_query() { + let reader = PolarsReader::new(); + + let df = df! { + "x" => [1i32, 2, 3], + "y" => [10i32, 20, 30], + } + .unwrap(); + + reader.register("my_table", df, false).unwrap(); + + let result = reader + .execute_sql("SELECT * FROM my_table ORDER BY x") + .unwrap(); + assert_eq!(result.shape(), (3, 2)); + assert_eq!(result.get_column_names(), vec!["x", "y"]); + } + + #[test] + fn test_register_and_filter() { + let reader = PolarsReader::new(); + + let df = df! { + "x" => [1i32, 2, 3, 4, 5], + "y" => [10i32, 20, 30, 40, 50], + } + .unwrap(); + + reader.register("data", df, false).unwrap(); + + let result = reader + .execute_sql("SELECT * FROM data WHERE x > 2") + .unwrap(); + assert_eq!(result.height(), 3); + } + + #[test] + fn test_register_duplicate_name_errors() { + let reader = PolarsReader::new(); + + let df1 = df! { "a" => [1i32] }.unwrap(); + let df2 = df! 
{ "b" => [2i32] }.unwrap(); + + // First registration should succeed + reader.register("dup_table", df1, false).unwrap(); + + // Second registration with same name should fail (when replace=false) + let result = reader.register("dup_table", df2, false); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("already exists")); + } + + #[test] + fn test_register_invalid_table_names() { + let reader = PolarsReader::new(); + let df = df! { "a" => [1i32] }.unwrap(); + + // Empty name + let result = reader.register("", df.clone(), false); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("cannot be empty")); + + // Name with double quote + let result = reader.register("bad\"name", df.clone(), false); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("invalid character")); + + // Name with null byte + let result = reader.register("bad\0name", df.clone(), false); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("invalid character")); + + // Name too long + let long_name = "a".repeat(200); + let result = reader.register(&long_name, df, false); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("exceeds maximum length")); + } + + #[test] + fn test_unregister() { + let reader = PolarsReader::new(); + let df = df! { "x" => [1i32, 2, 3] }.unwrap(); + + reader.register("test_data", df, false).unwrap(); + + // Should be queryable + let result = reader.execute_sql("SELECT * FROM test_data").unwrap(); + assert_eq!(result.height(), 3); + + // Unregister + reader.unregister("test_data").unwrap(); + + // Should no longer exist + let result = reader.execute_sql("SELECT * FROM test_data"); + assert!(result.is_err()); + } + + #[test] + fn test_unregister_not_registered() { + let reader = PolarsReader::new(); + + // Should fail - we didn't register anything + let result = reader.unregister("nonexistent"); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("was not registered via this reader")); + } + + #[test] + fn test_reregister_after_unregister() { + let reader = PolarsReader::new(); + let df = df! { "x" => [1i32, 2, 3] }.unwrap(); + + reader.register("data", df.clone(), false).unwrap(); + reader.unregister("data").unwrap(); + + // Should be able to register again + reader.register("data", df, false).unwrap(); + let result = reader.execute_sql("SELECT * FROM data").unwrap(); + assert_eq!(result.height(), 3); + } + + #[test] + fn test_invalid_sql() { + let reader = PolarsReader::new(); + let result = reader.execute_sql("INVALID SQL SYNTAX"); + assert!(result.is_err()); + } + + #[test] + fn test_ddl_not_supported() { + let reader = PolarsReader::new(); + + // CREATE should fail + let result = reader.execute_sql("CREATE TABLE test (x INT)"); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("DDL")); + + // DROP should fail + let result = reader.execute_sql("DROP TABLE test"); + assert!(result.is_err()); + + // INSERT should fail + let result = reader.execute_sql("INSERT INTO test VALUES (1)"); + assert!(result.is_err()); + } + + #[test] + fn test_query_with_aggregation() { + let reader = PolarsReader::new(); + + let df = df! 
{ + "region" => ["US", "US", "EU"], + "revenue" => [100.0f64, 200.0, 150.0], + } + .unwrap(); + + reader.register("sales", df, false).unwrap(); + + let result = reader + .execute_sql("SELECT region, SUM(revenue) as total FROM sales GROUP BY region") + .unwrap(); + + assert_eq!(result.shape(), (2, 2)); + assert_eq!(result.get_column_names(), vec!["region", "total"]); + } + + #[test] + fn test_multiple_tables() { + let reader = PolarsReader::new(); + + let sales = df! { + "id" => [1i32, 2, 3], + "amount" => [100i32, 200, 300], + "product_id" => [1i32, 1, 2], + } + .unwrap(); + + let products = df! { + "id" => [1i32, 2], + "name" => ["Widget", "Gadget"], + } + .unwrap(); + + reader.register("sales", sales, false).unwrap(); + reader.register("products", products, false).unwrap(); + + let result = reader + .execute_sql( + "SELECT s.id, s.amount, p.name + FROM sales s + JOIN products p ON s.product_id = p.id", + ) + .unwrap(); + + assert_eq!(result.height(), 3); + } + + #[test] + fn test_namespaced_sql_with_preregistered_data() { + use crate::naming; + + let reader = PolarsReader::new(); + + let df = df! { + "x" => [1i32, 2, 3], + "y" => [10i32, 20, 30], + } + .unwrap(); + + // Register under the internal table name that ggsql:penguins rewrites to + let table_name = naming::builtin_data_table("penguins"); + reader.register(&table_name, df, false).unwrap(); + + // ggsql:penguins should be rewritten to __ggsql_data_penguins__ and resolve + let result = reader.execute_sql("SELECT * FROM ggsql:penguins").unwrap(); + assert_eq!(result.height(), 3); + } + + #[test] + fn test_namespaced_sql_without_registration_errors() { + let reader = PolarsReader::new(); + + // Without builtin-data feature and without pre-registration, should error + // (when builtin-data is enabled, this test still passes because + // the dataset gets auto-loaded) + let result = reader.execute_sql("SELECT * FROM ggsql:unknown_dataset"); + // Either errors from "not pre-loaded" or from SQL execution failing + assert!(result.is_err()); + } +} + +#[cfg(feature = "builtin-data")] +#[cfg(test)] +mod builtin_data_tests { + use super::*; + + #[test] + fn test_builtin_penguins_auto_loads() { + let reader = PolarsReader::new(); + + // ggsql:penguins should auto-load from embedded parquet + let result = reader + .execute_sql("SELECT * FROM ggsql:penguins LIMIT 5") + .unwrap(); + assert_eq!(result.height(), 5); + assert!(result.width() > 0); + } + + #[test] + fn test_builtin_airquality_auto_loads() { + let reader = PolarsReader::new(); + + let result = reader + .execute_sql("SELECT * FROM ggsql:airquality LIMIT 5") + .unwrap(); + assert_eq!(result.height(), 5); + assert!(result.width() > 0); + } +} diff --git a/tree-sitter-ggsql/bindings/rust/build.rs b/tree-sitter-ggsql/bindings/rust/build.rs index 6072306f..b83511e0 100644 --- a/tree-sitter-ggsql/bindings/rust/build.rs +++ b/tree-sitter-ggsql/bindings/rust/build.rs @@ -1,3 +1,4 @@ +use std::path::Path; use std::path::PathBuf; use std::process::Command; @@ -102,8 +103,35 @@ fn main() { } // The generated files are in the grammar_dir/src directory - cc::Build::new() + let parser_path = src_dir.join("parser.c"); + let mut compiler = cc::Build::new(); + let mut opt_level = "3"; + + // set minimal C sysroot if wasm32-unknown-unknown + if std::env::var("TARGET").unwrap() == "wasm32-unknown-unknown" { + let sysroot_dir = Path::new("bindings/rust/wasm-sysroot"); + compiler + .archiver("llvm-ar") + .include(sysroot_dir.join("include")); + opt_level = "z"; + compiler + .include(&src_dir) + 
.opt_level_str(opt_level) + .file(sysroot_dir.join("src").join("stdio.c")) + .file(sysroot_dir.join("src").join("stdlib.c")) + .file(sysroot_dir.join("src").join("string.c")) + .file(sysroot_dir.join("src").join("wctype.c")) + .compile("stdlib"); + } + + compiler .include(&src_dir) - .file(src_dir.join("parser.c")) - .compile("tree-sitter-ggsql"); + .opt_level_str(opt_level) + .flag_if_supported("-Wno-unused-parameter") + .flag_if_supported("-Wno-unused-but-set-variable") + .flag_if_supported("-Wno-trigraphs") + .file(&parser_path) + .compile("parser"); + + println!("cargo:rerun-if-changed={}", parser_path.to_str().unwrap()); } diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/LICENSE b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/LICENSE new file mode 100644 index 00000000..971b81f9 --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2018 Max Brunsfeld + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/assert.h b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/assert.h new file mode 100644 index 00000000..b1ef2f93 --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/assert.h @@ -0,0 +1,14 @@ +#ifndef TREE_SITTER_WASM_ASSERT_H_ +#define TREE_SITTER_WASM_ASSERT_H_ + +#ifdef NDEBUG +#define assert(e) ((void)0) +#else +__attribute__((noreturn)) static inline void __assert_fail(const char *assertion, const char *file, unsigned line, const char *function) { + __builtin_trap(); +} +#define assert(expression) \ + ((expression) ? 
(void)0 : __assert_fail(#expression, __FILE__, __LINE__, __func__)) +#endif + +#endif // TREE_SITTER_WASM_ASSERT_H_ diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/ctype.h b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/ctype.h new file mode 100644 index 00000000..cea32970 --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/ctype.h @@ -0,0 +1,8 @@ +#ifndef TREE_SITTER_WASM_CTYPE_H_ +#define TREE_SITTER_WASM_CTYPE_H_ + +static inline int isprint(int c) { + return c >= 0x20 && c <= 0x7E; +} + +#endif // TREE_SITTER_WASM_CTYPE_H_ diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/endian.h b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/endian.h new file mode 100644 index 00000000..f35a5962 --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/endian.h @@ -0,0 +1,12 @@ +#ifndef TREE_SITTER_WASM_ENDIAN_H_ +#define TREE_SITTER_WASM_ENDIAN_H_ + +#define be16toh(x) __builtin_bswap16(x) +#define be32toh(x) __builtin_bswap32(x) +#define be64toh(x) __builtin_bswap64(x) +#define le16toh(x) (x) +#define le32toh(x) (x) +#define le64toh(x) (x) + + +#endif // TREE_SITTER_WASM_ENDIAN_H_ diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/inttypes.h b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/inttypes.h new file mode 100644 index 00000000..f5cccd07 --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/inttypes.h @@ -0,0 +1,8 @@ +#ifndef TREE_SITTER_WASM_INTTYPES_H_ +#define TREE_SITTER_WASM_INTTYPES_H_ + +// https://github.com/llvm/llvm-project/blob/0c3cf200f5b918fb5c1114e9f1764c2d54d1779b/libc/include/llvm-libc-macros/inttypes-macros.h#L209 + +#define PRId32 "d" + +#endif // TREE_SITTER_WASM_INTTYPES_H_ diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/stdint.h b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/stdint.h new file mode 100644 index 00000000..10cc35dc --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/stdint.h @@ -0,0 +1,46 @@ +#ifndef TREE_SITTER_WASM_STDINT_H_ +#define TREE_SITTER_WASM_STDINT_H_ + +// https://github.com/llvm/llvm-project/blob/0c3cf200f5b918fb5c1114e9f1764c2d54d1779b/clang/test/Preprocessor/init.c#L1672 + +typedef signed char int8_t; + +typedef short int16_t; + +typedef int int32_t; + +typedef long long int int64_t; + +typedef unsigned char uint8_t; + +typedef unsigned short uint16_t; + +typedef unsigned int uint32_t; + +typedef long long unsigned int uint64_t; + +typedef long unsigned int size_t; + +typedef long unsigned int uintptr_t; + +#define INT8_MAX 127 +#define INT16_MAX 32767 +#define INT32_MAX 2147483647L +#define INT64_MAX 9223372036854775807LL + +#define UINT8_MAX 255 +#define UINT16_MAX 65535 +#define UINT32_MAX 4294967295U +#define UINT64_MAX 18446744073709551615ULL + +#if defined(__wasm32__) + +#define SIZE_MAX 4294967295UL + +#elif defined(__wasm64__) + +#define SIZE_MAX 18446744073709551615UL + +#endif + +#endif // TREE_SITTER_WASM_STDINT_H_ diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/stdio.h b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/stdio.h new file mode 100644 index 00000000..4089cccc --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/stdio.h @@ -0,0 +1,36 @@ +#ifndef TREE_SITTER_WASM_STDIO_H_ +#define TREE_SITTER_WASM_STDIO_H_ + +#include +#include + +typedef struct FILE FILE; + +typedef __builtin_va_list va_list; +#define va_start(ap, last) __builtin_va_start(ap, last) +#define va_end(ap) __builtin_va_end(ap) +#define va_arg(ap, 
type) __builtin_va_arg(ap, type) + +#define stdout ((FILE *)0) + +#define stderr ((FILE *)1) + +#define stdin ((FILE *)2) + +int fclose(FILE *stream); + +FILE *fdopen(int fd, const char *mode); + +int fputc(int c, FILE *stream); + +int fputs(const char *restrict s, FILE *restrict stream); + +size_t fwrite(const void *restrict buffer, size_t size, size_t nmemb, FILE *restrict stream); + +int fprintf(FILE *restrict stream, const char *restrict format, ...); + +int snprintf(char *restrict buffer, size_t buffsz, const char *restrict format, ...); + +int vsnprintf(char *restrict buffer, size_t buffsz, const char *restrict format, va_list vlist); + +#endif // TREE_SITTER_WASM_STDIO_H_ diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/stdlib.h b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/stdlib.h new file mode 100644 index 00000000..2da313ab --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/stdlib.h @@ -0,0 +1,15 @@ +#ifndef TREE_SITTER_WASM_STDLIB_H_ +#define TREE_SITTER_WASM_STDLIB_H_ + +#include + +#define NULL ((void*)0) + +void* malloc(size_t); +void* calloc(size_t, size_t); +void free(void*); +void* realloc(void*, size_t); + +__attribute__((noreturn)) void abort(void); + +#endif // TREE_SITTER_WASM_STDLIB_H_ diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/string.h b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/string.h new file mode 100644 index 00000000..ddbf1e3f --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/string.h @@ -0,0 +1,22 @@ +#ifndef TREE_SITTER_WASM_STRING_H_ +#define TREE_SITTER_WASM_STRING_H_ + +#include + +void *memchr(const void *src, int c, size_t n); + +int memcmp(const void *lhs, const void *rhs, size_t count); + +void *memcpy(void *restrict dst, const void *restrict src, size_t size); + +void *memmove(void *dst, const void *src, size_t count); + +void *memset(void *dst, int value, size_t count); + +char *strchr(const char *str, int c); + +size_t strlen(const char *str); + +int strncmp(const char *left, const char *right, size_t n); + +#endif // TREE_SITTER_WASM_STRING_H_ diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/wctype.h b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/wctype.h new file mode 100644 index 00000000..9c20bad1 --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/include/wctype.h @@ -0,0 +1,176 @@ +#ifndef TREE_SITTER_WASM_WCTYPE_H_ +#define TREE_SITTER_WASM_WCTYPE_H_ + +#include + +typedef int wint_t; + +int iswlower(wint_t wch); + +int iswupper(wint_t wch); + +int iswpunct(wint_t wch); + +static inline bool iswalpha(wint_t wch) { + switch (wch) { + case L'a': + case L'b': + case L'c': + case L'd': + case L'e': + case L'f': + case L'g': + case L'h': + case L'i': + case L'j': + case L'k': + case L'l': + case L'm': + case L'n': + case L'o': + case L'p': + case L'q': + case L'r': + case L's': + case L't': + case L'u': + case L'v': + case L'w': + case L'x': + case L'y': + case L'z': + case L'A': + case L'B': + case L'C': + case L'D': + case L'E': + case L'F': + case L'G': + case L'H': + case L'I': + case L'J': + case L'K': + case L'L': + case L'M': + case L'N': + case L'O': + case L'P': + case L'Q': + case L'R': + case L'S': + case L'T': + case L'U': + case L'V': + case L'W': + case L'X': + case L'Y': + case L'Z': + return true; + default: + return false; + } +} + +static inline bool iswdigit(wint_t wch) { + switch (wch) { + case L'0': + case L'1': + case L'2': + case L'3': + case L'4': + case L'5': + case L'6': + case 
L'7': + case L'8': + case L'9': + return true; + default: + return false; + } +} + +static inline bool iswalnum(wint_t wch) { + switch (wch) { + case L'a': + case L'b': + case L'c': + case L'd': + case L'e': + case L'f': + case L'g': + case L'h': + case L'i': + case L'j': + case L'k': + case L'l': + case L'm': + case L'n': + case L'o': + case L'p': + case L'q': + case L'r': + case L's': + case L't': + case L'u': + case L'v': + case L'w': + case L'x': + case L'y': + case L'z': + case L'A': + case L'B': + case L'C': + case L'D': + case L'E': + case L'F': + case L'G': + case L'H': + case L'I': + case L'J': + case L'K': + case L'L': + case L'M': + case L'N': + case L'O': + case L'P': + case L'Q': + case L'R': + case L'S': + case L'T': + case L'U': + case L'V': + case L'W': + case L'X': + case L'Y': + case L'Z': + case L'0': + case L'1': + case L'2': + case L'3': + case L'4': + case L'5': + case L'6': + case L'7': + case L'8': + case L'9': + return true; + default: + return false; + } +} + +static inline bool iswspace(wint_t wch) { + switch (wch) { + case L' ': + case L'\t': + case L'\n': + case L'\v': + case L'\f': + case L'\r': + return true; + default: + return false; + } +} + +#endif // TREE_SITTER_WASM_WCTYPE_H_ diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/stdio.c b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/stdio.c new file mode 100644 index 00000000..470c1ecc --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/stdio.c @@ -0,0 +1,299 @@ +#include +#include + +typedef struct { + bool left_justify; // - + bool zero_pad; // 0 + bool show_sign; // + + bool space_prefix; // ' ' + bool alternate_form; // # +} format_flags_t; + +static const char* parse_format_spec( + const char *format, + int *width, + int *precision, + format_flags_t *flags +) { + *width = 0; + *precision = -1; + flags->left_justify = false; + flags->zero_pad = false; + flags->show_sign = false; + flags->space_prefix = false; + flags->alternate_form = false; + + const char *p = format; + + // Parse flags + while (*p == '-' || *p == '+' || *p == ' ' || *p == '#' || *p == '0') { + switch (*p) { + case '-': flags->left_justify = true; break; + case '0': flags->zero_pad = true; break; + case '+': flags->show_sign = true; break; + case ' ': flags->space_prefix = true; break; + case '#': flags->alternate_form = true; break; + } + p++; + } + + // width + while (*p >= '0' && *p <= '9') { + *width = (*width * 10) + (*p - '0'); + p++; + } + + // precision + if (*p == '.') { + p++; + *precision = 0; + while (*p >= '0' && *p <= '9') { + *precision = (*precision * 10) + (*p - '0'); + p++; + } + } + + return p; +} + +static int int_to_str( + long long value, + char *buffer, + int base, + bool is_signed, + bool uppercase +) { + if (base < 2 || base > 16) return 0; + + const char *digits = uppercase ? 
"0123456789ABCDEF" : "0123456789abcdef"; + char temp[32]; + int i = 0, len = 0; + bool is_negative = false; + + if (value == 0) { + buffer[0] = '0'; + buffer[1] = '\0'; + return 1; + } + + if (is_signed && value < 0 && base == 10) { + is_negative = true; + value = -value; + } + + unsigned long long uval = (unsigned long long)value; + while (uval > 0) { + temp[i++] = digits[uval % base]; + uval /= base; + } + + if (is_negative) { + buffer[len++] = '-'; + } + + while (i > 0) { + buffer[len++] = temp[--i]; + } + + buffer[len] = '\0'; + return len; +} + +static int ptr_to_str(void *ptr, char *buffer) { + buffer[0] = '0'; + buffer[1] = 'x'; + int len = int_to_str((uintptr_t)ptr, buffer + 2, 16, 0, 0); + return 2 + len; +} + +char *strncpy(char *dest, const char *src, size_t n) { + char *d = dest; + const char *s = src; + while (n-- && (*d++ = *s++)); + if (n == (size_t)-1) *d = '\0'; + return dest; +} + +static int write_formatted_to_buffer( + char *buffer, + size_t buffer_size, + size_t *pos, + const char *str, + int width, + const format_flags_t *flags +) { + int len = strlen(str); + int written = 0; + int pad_len = (width > len) ? (width - len) : 0; + int zero_pad = flags->zero_pad && !flags->left_justify; + + if (!flags->left_justify && pad_len > 0) { + char pad_char = zero_pad ? '0' : ' '; + for (int i = 0; i < pad_len && *pos < buffer_size - 1; i++) { + buffer[(*pos)++] = pad_char; + written++; + } + } + + for (int i = 0; i < len && *pos < buffer_size - 1; i++) { + buffer[(*pos)++] = str[i]; + written++; + } + + if (flags->left_justify && pad_len > 0) { + for (int i = 0; i < pad_len && *pos < buffer_size - 1; i++) { + buffer[(*pos)++] = ' '; + written++; + } + } + + return written; +} + +static int vsnprintf_impl(char *buffer, size_t buffsz, const char *format, va_list args) { + if (!buffer || buffsz == 0 || !format) return -1; + + size_t pos = 0; + int total_chars = 0; + const char *p = format; + + while (*p) { + if (*p == '%') { + p++; + if (*p == '%') { + if (pos < buffsz - 1) buffer[pos++] = '%'; + total_chars++; + p++; + continue; + } + + int width, precision; + format_flags_t flags; + p = parse_format_spec(p, &width, &precision, &flags); + + char temp_buf[64]; + const char *output_str = temp_buf; + + switch (*p) { + case 's': { + const char *str = va_arg(args, const char*); + if (!str) str = "(null)"; + + int str_len = strlen(str); + if (precision >= 0 && str_len > precision) { + strncpy(temp_buf, str, precision); + temp_buf[precision] = '\0'; + output_str = temp_buf; + } else { + output_str = str; + } + break; + } + case 'd': + case 'i': { + int value = va_arg(args, int); + int_to_str(value, temp_buf, 10, true, false); + break; + } + case 'u': { + unsigned int value = va_arg(args, unsigned int); + int_to_str(value, temp_buf, 10, false, false); + break; + } + case 'x': { + unsigned int value = va_arg(args, unsigned int); + int_to_str(value, temp_buf, 16, false, false); + break; + } + case 'X': { + unsigned int value = va_arg(args, unsigned int); + int_to_str(value, temp_buf, 16, false, true); + break; + } + case 'p': { + void *ptr = va_arg(args, void*); + ptr_to_str(ptr, temp_buf); + break; + } + case 'c': { + int c = va_arg(args, int); + temp_buf[0] = (char)c; + temp_buf[1] = '\0'; + break; + } + case 'z': { + if (*(p + 1) == 'u') { + size_t value = va_arg(args, size_t); + int_to_str(value, temp_buf, 10, false, false); + p++; + } else { + temp_buf[0] = 'z'; + temp_buf[1] = '\0'; + } + break; + } + default: + temp_buf[0] = '%'; + temp_buf[1] = *p; + temp_buf[2] = '\0'; + break; + } 
+ + int str_len = strlen(output_str); + int formatted_len = (width > str_len) ? width : str_len; + total_chars += formatted_len; + + if (pos < buffsz - 1) { + write_formatted_to_buffer(buffer, buffsz, &pos, output_str, width, &flags); + } + + } else { + if (pos < buffsz - 1) buffer[pos++] = *p; + total_chars++; + } + p++; + } + + if (buffsz > 0) buffer[pos < buffsz ? pos : buffsz - 1] = '\0'; + + return total_chars; +} + +int snprintf(char *restrict buffer, size_t buffsz, const char *restrict format, ...) { + if (!buffer || buffsz == 0 || !format) return -1; + + va_list args; + va_start(args, format); + int result = vsnprintf_impl(buffer, buffsz, format, args); + va_end(args); + + return result; +} + +int vsnprintf(char *restrict buffer, size_t buffsz, const char *restrict format, va_list vlist) { + return vsnprintf_impl(buffer, buffsz, format, vlist); +} + +int fclose(FILE *stream) { + return 0; +} + +FILE* fdopen(int fd, const char *mode) { + return 0; +} + +int fputc(int c, FILE *stream) { + return c; +} + +int fputs(const char *restrict str, FILE *restrict stream) { + return 0; +} + +size_t fwrite(const void *restrict buffer, size_t size, size_t nmemb, FILE *restrict stream) { + return size * nmemb; +} + +int fprintf(FILE *restrict stream, const char *restrict format, ...) { + return 0; +} diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/stdlib.c b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/stdlib.c new file mode 100644 index 00000000..9719e254 --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/stdlib.c @@ -0,0 +1,163 @@ +// This file implements a very simple allocator for external scanners running +// in Wasm. Allocation is just bumping a static pointer and growing the heap +// as needed, and freeing is just adding the freed region to a free list. +// When additional memory is allocated, the free list is searched first. +// If there is not a suitable region in the free list, the heap is +// grown as necessary, and the allocation is made at the end of the heap. +// When the heap is reset, all allocated memory is considered freed. + +#include +#include +#include + +extern void tree_sitter_debug_message(const char *, size_t); + +#define PAGESIZE 0x10000 +#define MAX_HEAP_SIZE (1024 * 1024 * 1024) + +typedef struct { + size_t size; + struct Region *next; + char data[0]; +} Region; + +static Region *heap_end = NULL; +static Region *heap_start = NULL; +static Region *next = NULL; +static Region *free_list = NULL; + +// Get the region metadata for the given heap pointer. +static inline Region *region_for_ptr(void *ptr) { + return ((Region *)ptr) - 1; +} + +// Get the location of the next region after the given region, +// if the given region had the given size. +static inline Region *region_after(Region *self, size_t len) { + char *address = self->data + len; + char *aligned = (char *)((uintptr_t)(address + 3) & ~0x3); + return (Region *)aligned; +} + +static void *get_heap_end() { + return (void *)(__builtin_wasm_memory_size(0) * PAGESIZE); +} + +static int grow_heap(size_t size) { + size_t new_page_count = ((size - 1) / PAGESIZE) + 1; + return __builtin_wasm_memory_grow(0, new_page_count) != SIZE_MAX; +} + +// Grows the heap if necessary to fit a region at the _end_ of the heap +// ending at `region_end` by `size` bytes. +// +// Returns 0 if the heap could not be grown, 1 otherwise. 
+static inline int grow_heap_for_region(Region *region_end, size_t size) { + if (region_end > heap_end) { + if ((char *)region_end - (char *)heap_start > MAX_HEAP_SIZE) return 0; + if (!grow_heap(size)) return 0; + heap_end = get_heap_end(); + } + return 1; +} + +// Clear out the heap, and move it to the given address. +void reset_heap(void *new_heap_start) { + heap_start = new_heap_start; + next = new_heap_start; + heap_end = get_heap_end(); + free_list = NULL; +} + +void *malloc(size_t size) { + if (size == 0) return NULL; + + Region *prev = NULL; + Region *curr = free_list; + while (curr != NULL) { + if (curr->size >= size) { + if (prev == NULL) { + free_list = curr->next; + } else { + prev->next = curr->next; + } + return &curr->data; + } + prev = curr; + curr = curr->next; + } + + Region *region_end = region_after(next, size); + + if (!grow_heap_for_region(region_end, size)) return NULL; + + void *result = &next->data; + next->size = size; + next = region_end; + + return result; +} + +void free(void *ptr) { + if (ptr == NULL) return; + + Region *region = region_for_ptr(ptr); + Region *region_end = region_after(region, region->size); + + // When freeing the last allocated pointer, re-use that + // pointer for the next allocation. + if (region_end == next) { + next = region; + } else { + region->next = free_list; + free_list = region; + } +} + +void *calloc(size_t count, size_t size) { + void *result = malloc(count * size); + if (!result) return NULL; + memset(result, 0, count * size); + return result; +} + +void *realloc(void *ptr, size_t new_size) { + if (ptr == NULL) { + return malloc(new_size); + } + if (new_size == 0) { + free(ptr); + return NULL; + } + + + Region *region = region_for_ptr(ptr); + Region *region_end = region_after(region, region->size); + + // When reallocating the last allocated region, resize + // in place if possible, return the same pointer, and + // skip copying the data. + if (region_end == next) { + Region *new_region_end = region_after(region, new_size); + + size_t additional_size = (char *)new_region_end - (char *)heap_end; + if (!grow_heap_for_region(new_region_end, additional_size)) return NULL; + + region->size = new_size; + next = new_region_end; + return &region->data; + } + + void *result = malloc(new_size); + if (!result) return NULL; + + size_t copy_size = region->size < new_size ? region->size : new_size; + memcpy(result, &region->data, copy_size); + + free(ptr); + return result; +} + +__attribute__((noreturn)) void abort(void) { + __builtin_trap(); +} diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/string.c b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/string.c new file mode 100644 index 00000000..3f1b9a0f --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/string.c @@ -0,0 +1,84 @@ +#include <string.h> + +// Derived from musl (MIT): https://git.musl-libc.org/cgit/musl/tree/src/string/memchr.c +void *memchr(const void *src, int c, size_t n) { + const unsigned char *s = src; + c = (unsigned char)c; + for (; n && *s != c; s++, n--); + return n ?
(void *)s : 0; +} + +int memcmp(const void *lhs, const void *rhs, size_t count) { + const unsigned char *l = lhs; + const unsigned char *r = rhs; + while (count--) { + if (*l != *r) { + return *l - *r; + } + l++; + r++; + } + return 0; +} + +void *memcpy(void *restrict dst, const void *restrict src, size_t size) { + unsigned char *d = dst; + const unsigned char *s = src; + while (size--) { + *d++ = *s++; + } + return dst; +} + +void *memmove(void *dst, const void *src, size_t count) { + unsigned char *d = dst; + const unsigned char *s = src; + if (d < s) { + while (count--) { + *d++ = *s++; + } + } else if (d > s) { + d += count; + s += count; + while (count--) { + *(--d) = *(--s); + } + } + return dst; +} + +void *memset(void *dst, int value, size_t count) { + unsigned char *p = dst; + while (count--) { + *p++ = (unsigned char)value; + } + return dst; +} + +char *strchr(const char *str, int c) { + while (*str != (char)c) { + if (*str == '\0') { + return 0; + } + str++; + } + return (char *)str; +} + +size_t strlen(const char *str) { + const char *s = str; + while (*s) s++; + return s - str; +} + +int strncmp(const char *left, const char *right, size_t n) { + while (n-- > 0) { + if (*left != *right) { + return *(unsigned char *)left - *(unsigned char *)right; + } + if (*left == '\0') break; + left++; + right++; + } + return 0; +} diff --git a/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/wctype.c b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/wctype.c new file mode 100644 index 00000000..4bcc276f --- /dev/null +++ b/tree-sitter-ggsql/bindings/rust/wasm-sysroot/src/wctype.c @@ -0,0 +1,16 @@ +#include <wctype.h> + +int iswlower(wint_t wch) { + return (unsigned)wch - L'a' < 26; +} + +int iswupper(wint_t wch) { + return (unsigned)wch - L'A' < 26; +} + +int iswpunct(wint_t wch) { + return (wch >= 33 && wch <= 47) || + (wch >= 58 && wch <= 64) || + (wch >= 91 && wch <= 96) || + (wch >= 123 && wch <= 126); +}
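For reference, a minimal end-to-end sketch of how the pieces added in this patch fit together: the `PolarsReader` from `src/reader/polars_sql.rs`, the `replace` flag on `Reader::register`, and the `ggsql:` builtin-dataset rewriting from `src/reader/data.rs`. It mirrors the new tests rather than documenting a stable public API; the `ggsql::reader` import path and the `polars-sql`/`builtin-data` feature names are taken from this diff, so treat it as illustrative only.

```rust
// Sketch only: assumes the `polars-sql` and `builtin-data` features from this patch.
use ggsql::reader::{PolarsReader, Reader};
use polars::prelude::*;

fn main() {
    // Pure in-memory reader backed by Polars' SQLContext.
    let reader = PolarsReader::new();

    // Register a DataFrame; with `replace = false` a duplicate name is an error.
    let sales = df! {
        "region" => ["US", "US", "EU"],
        "revenue" => [100.0f64, 200.0, 150.0],
    }
    .unwrap();
    reader.register("sales", sales, false).unwrap();

    // Plain SQL against the registered table.
    let totals = reader
        .execute_sql("SELECT region, SUM(revenue) AS total FROM sales GROUP BY region")
        .unwrap();
    println!("{totals}");

    // With `builtin-data` enabled, `ggsql:penguins` is auto-loaded from the embedded
    // parquet and rewritten to `__ggsql_data_penguins__` before execution.
    let penguins = reader
        .execute_sql("SELECT * FROM ggsql:penguins LIMIT 5")
        .unwrap();
    println!("{penguins}");
}
```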