diff --git a/.gitignore b/.gitignore index 06e87f2..3632b66 100644 --- a/.gitignore +++ b/.gitignore @@ -45,6 +45,9 @@ pnpm-debug.log* # Local AI agent instructions .github/copilot-instructions.md +# Local agent state +.agents/state/ + # MkDocs build output site/ diff --git a/docs/architecture.md b/docs/architecture.md index b7dfb54..cf7ca86 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -129,40 +129,40 @@ Session is intentionally outside RFC 002’s normative emitted contract. It cons For current package behavior, see [Execution context (Reference)][execution-ref] and [Execution context (Explanation)][execution-expl]. -## Current implementation shape +## Package implementation shape -The package currently uses the following implementation shape: +The package uses the following implementation shape: - author-facing carrier types live in [mod.incn](../src/dataset/mod.incn) - canonical relational operator helpers live in [ops.incn](../src/dataset/ops.incn) - Substrait emission lives under [substrait/](../src/substrait/) - Prism internals live under [prism/](../src/prism/) -- `LazyFrame[T]` currently routes through a backend-native `PrismCursor[T]` -- `DataFrame[T]` and `DataStream[T]` are not yet fully converged on the same internal backing model as `LazyFrame[T]` +- `LazyFrame[T]` routes through a backend-native `PrismCursor[T]` +- `DataFrame[T]` and `DataStream[T]` keep their carrier-specific backing shapes while sharing the public dataset surface -This is enough to explain the package architecture while keeping current API behavior in language docs and follow-on gaps in RFCs, issues, and release notes. +This keeps package architecture in this document while detailed API behavior lives in language docs and future surface expansion stays in RFCs, issues, and release notes. ## Repository layout -| Path | Role | -| -------------------------------- | -------------------------------------------------- | -| `incan.toml` | Package metadata and Rust dependency declarations | -| `src/lib.incn` | Public package exports | -| `src/dataset/mod.incn` | Carrier types and trait surface | -| `src/dataset/ops.incn` | Canonical relational operator helpers | -| `src/prism/mod.incn` | Internal Prism graph, cursor, and lowering logic | -| `src/substrait/relations.incn` | Concrete `Rel` builders and relation lowering | -| `src/substrait/plans.incn` | Top-level `Plan` assembly helpers | -| `src/substrait/inspect.incn` | Relation/plan inspection and output-column helpers | -| `src/substrait/schema_registry.incn` | Named-table schema registration and lookup | -| `src/substrait/extensions.incn` | Extension anchors, URIs, and declaration helpers | -| `src/substrait/expr_lowering.incn` | Builder-to-Substrait expression lowering | -| `src/substrait/conformance.incn` | Typed conformance facade over catalog + validators | -| `src/substrait/schema.incn` | Model/schema to Substrait type bridging | -| `tests/` | Package tests run through `incan test` | -| `docs/language/` | Current package docs | -| `docs/rfcs/` | Normative RFC series | -| `docs/release_notes/` | Release-facing notes | +| Path | Role | +| ------------------------------------ | -------------------------------------------------- | +| `incan.toml` | Package metadata and Rust dependency declarations | +| `src/lib.incn` | Public package exports | +| `src/dataset/mod.incn` | Carrier types and trait surface | +| `src/dataset/ops.incn` | Canonical relational operator helpers | +| `src/prism/mod.incn` | Internal Prism graph, cursor, and lowering logic | +| `src/substrait/relations.incn` | Concrete `Rel` builders and relation lowering | +| `src/substrait/plans.incn` | Top-level `Plan` assembly helpers | +| `src/substrait/inspect.incn` | Relation/plan inspection and output-column helpers | +| `src/substrait/schema_registry.incn` | Named-table schema registration and lookup | +| `src/substrait/extensions.incn` | Extension anchors, URIs, and declaration helpers | +| `src/substrait/expr_lowering.incn` | Builder-to-Substrait expression lowering | +| `src/substrait/conformance.incn` | Typed conformance facade over catalog + validators | +| `src/substrait/schema.incn` | Model/schema to Substrait type bridging | +| `tests/` | Package tests run through `incan test` | +| `docs/language/` | Current package docs | +| `docs/rfcs/` | Normative RFC series | +| `docs/release_notes/` | Release-facing notes | Normative behavior lives in the RFC series first. Current package behavior and usage belong in the language docs. If code and RFCs disagree, treat that as a bug or transition state to resolve explicitly. diff --git a/docs/language/explanation/dataset_carriers.md b/docs/language/explanation/dataset_carriers.md index 72fa6d7..378f480 100644 --- a/docs/language/explanation/dataset_carriers.md +++ b/docs/language/explanation/dataset_carriers.md @@ -56,11 +56,11 @@ Use `LazyFrame[T]` when you want to compose operations before execution: ```incan from pub::inql import LazyFrame -from pub::inql.functions import col, gt, int_lit +from pub::inql.functions import col, gt, lit from models import Order def high_value_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: - return orders.filter(gt(col("amount"), int_lit(100))) + return orders.filter(gt(col("amount"), lit(100))) ``` ### `DataStream[T]` — streaming @@ -69,11 +69,11 @@ Use `DataStream[T]` for streaming/unbounded data: ```incan from pub::inql import DataStream -from pub::inql.functions import col, eq, str_lit +from pub::inql.functions import col, eq, lit from models import Event def important_events(events: DataStream[Event]) -> DataStream[Event]: - return events.filter(eq(col("severity"), str_lit("critical"))) + return events.filter(eq(col("severity"), lit("critical"))) ``` `DataStream[T]` shares the same operation API as batch carriers, but signals that its source is unbounded. Static streaming constraints are specified in RFC 001 and enforced as the compiler gains analysis for `UnboundedDataSet[T]`. @@ -122,9 +122,9 @@ Current relational authoring is explicit and builder-based. That is deliberate: Today there are three concrete builder families: -- filters: `eq(...)`, `gt(...)`, `int_lit(...)`, `str_lit(...)` +- filters: `eq(...)`, `gt(...)`, `lit(...)` - aggregates: `col(...)`, `sum(...)`, `count()` -- projections: `with_column(...)`, `add(...)`, `mul(...)`, `int_expr(...)`, `str_expr(...)`, `bool_expr(...)` +- projections: `with_column(...)`, `add(...)`, `mul(...)`, `lit(...)` ### Aggregate helpers @@ -148,15 +148,15 @@ That is the current semantic target for future sugar such as `.customer_id` or ` Computed columns now have one real entrypoint: `with_column(name, expr)`. ```incan -from pub::inql.functions import add, col, int_expr, mul +from pub::inql.functions import add, col, lit, mul from pub::inql import LazyFrame from models import Order def enrich_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: return ( orders - .with_column("amount_x2", mul(col("amount"), int_expr(2))) - .with_column("amount_plus_one", add(col("amount"), int_expr(1))) + .with_column("amount_x2", mul(col("amount"), lit(2))) + .with_column("amount_plus_one", add(col("amount"), lit(1))) ) ``` @@ -177,14 +177,14 @@ The most useful way to read the current surface is to separate: This is real current InQL, not aspirational pseudocode: ```incan -from pub::inql.functions import add, col, count, int_expr, sum +from pub::inql.functions import add, col, count, lit, sum from pub::inql import LazyFrame from models import Order def summarize_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: grouped = ( orders - .with_column("amount_plus_one", add(col("amount"), int_expr(1))) + .with_column("amount_plus_one", add(col("amount"), lit(1))) .group_by([col("customer_id")]) .agg([sum(col("amount")), count()]) ) diff --git a/docs/language/explanation/execution_context.md b/docs/language/explanation/execution_context.md index 466de64..bae51b7 100644 --- a/docs/language/explanation/execution_context.md +++ b/docs/language/explanation/execution_context.md @@ -52,36 +52,36 @@ This is the boundary where deferred relational work becomes local data in hand. Some convenience APIs are nicer when they do not force the session parameter through every call site. `lazy.collect()` is one of those cases. -That convenience still needs a real execution context underneath, so it resolves through the active session at call time. +That convenience needs a real execution context underneath, so it resolves through the active session at call time. - `session.activate()` sets the current active session - `lazy.collect()` uses that active session If there is no active session, the convenience API fails clearly instead of pretending execution context can be ambient without definition. -## Writing is still Session-owned +## Writing is Session-owned `session.write_csv(...)` and `session.write_parquet(...)` remain explicit Session methods because writing is not just a carrier concern. It requires binding, execution, and sink ownership. -So the current ergonomic split is: +The ergonomic split is: - convenience materialization: `lazy.collect()` - explicit writes: `session.write_csv(...)`, `session.write_parquet(...)` -This is a current package ergonomics choice, not a statement that all future convenience APIs must keep the same shape. +This keeps materialization convenient while leaving sink ownership explicit at the session boundary. ## Typical flow ```incan -from pub::inql import Session -from pub::inql.functions import col, gt, int_expr, int_lit, mul +from pub::inql import LazyFrame, Session +from pub::inql.functions import col, gt, lit, mul from models import Order session = Session.default() -orders = session.read_csv[Order]("orders", "orders.csv")? -enriched = orders.with_column("amount_x2", mul(col("amount"), int_expr(2))) -filtered = enriched.filter(gt(col("amount"), int_lit(100))).limit(10) +orders: LazyFrame[Order] = session.read_csv("orders", "orders.csv")? +enriched = orders.with_column("amount_x2", mul(col("amount"), lit(2))) +filtered = enriched.filter(gt(col("amount"), lit(100))).limit(10) session.activate() preview = filtered.collect()? @@ -98,14 +98,14 @@ This pattern is intentionally simple: For the exact method surface, see [Dataset methods (Reference)](../reference/dataset_methods.md). -## Current limitation +## Materialized carrier shape -`DataFrame[T]` is already the materialized carrier, but its row-level user API is still intentionally narrow. The important current semantic distinction is already in place: +`DataFrame[T]` is the materialized carrier. The important semantic distinction is: - `LazyFrame[T]` = deferred - `DataFrame[T]` = local materialized -Today that materialized carrier exposes structured collection metadata first: +The materialized carrier exposes structured collection metadata: - resolved columns - row count diff --git a/docs/language/reference/builders/aggregates.md b/docs/language/reference/builders/aggregates.md index d9830f7..b066e43 100644 --- a/docs/language/reference/builders/aggregates.md +++ b/docs/language/reference/builders/aggregates.md @@ -1,24 +1,25 @@ # Aggregate builders (Reference) -Current aggregate authoring is explicit and builder-based. +Current aggregate authoring is explicit and scalar-expression-based. ## Functions -| Builder | Signature | Meaning | -| ------- | ----------------------------------------------- | ---------------------------------------------------------------------- | -| `col` | `def col(name: str) -> ColumnExpr` | Column reference builder used by aggregates, filters, and projections. | -| `sum` | `def sum(expr: ColumnExpr) -> AggregateMeasure` | Sum one selected numeric column. | -| `count` | `def count() -> AggregateMeasure` | Count rows in the current relation or group. | +| Builder | Signature | Meaning | +| ------- | ----------------------------------------------------------- | ---------------------------------------------------------------------- | +| `col` | `def col(name: str) -> ColumnExpr` | Column reference builder used by aggregates, filters, and projections. | +| `lit` | `def lit(value: int \| float \| str \| bool) -> ColumnExpr` | Canonical scalar literal helper. | +| `sum` | `def sum(expr: ColumnExpr) -> AggregateMeasure` | Sum one scalar expression. | +| `count` | `def count() -> AggregateMeasure` | Count rows in the current relation or group. | ## Example ```incan -from pub::inql.functions import col, count, sum +from pub::inql.functions import add, col, count, lit, sum -grouped = orders.group_by([col("customer_id")]).agg([sum(col("amount")), count()]) +grouped = orders.group_by([col("customer_id")]).agg([sum(add(col("amount"), lit(5))), count()]) ``` ## Notes -- The current package slice requires explicit `col(...)` builders. +- Aggregate inputs use the same scalar-expression model as filters, projections, and grouping keys. - Future `.column` sugar and scoped aggregate symbols should lower to this same surface rather than replacing its semantics. diff --git a/docs/language/reference/builders/filters.md b/docs/language/reference/builders/filters.md index d49c3e0..fd2a033 100644 --- a/docs/language/reference/builders/filters.md +++ b/docs/language/reference/builders/filters.md @@ -1,32 +1,34 @@ # Filter builders (Reference) -Current filter authoring is explicit and builder-based. +Current filter authoring uses the shared scalar-expression builder model. ## Functions -| Builder | Signature | Meaning | -| -------------- | ----------------------------------------------------------------------- | ------------------------------------------------------ | -| `always_true` | `def always_true() -> FilterPredicate` | Trivial predicate; canonical rewrite can eliminate it. | -| `always_false` | `def always_false() -> FilterPredicate` | Predicate that rejects every row. | -| `eq` | `def eq(column: ColumnExpr, literal: FilterLiteral) -> FilterPredicate` | Equality predicate. | -| `gt` | `def gt(column: ColumnExpr, literal: FilterLiteral) -> FilterPredicate` | Greater-than predicate. | -| `int_lit` | `def int_lit(value: int) -> FilterLiteral` | Integer literal for filter predicates. | -| `str_lit` | `def str_lit(value: str) -> FilterLiteral` | String literal for filter predicates. | -| `bool_lit` | `def bool_lit(value: bool) -> FilterLiteral` | Boolean literal for filter predicates. | +| Builder | Signature | Meaning | +| -------------- | ----------------------------------------------------------- | ---------------------------------------------------------------------- | +| `always_true` | `def always_true() -> ColumnExpr` | Trivial boolean scalar expression; canonical rewrite can eliminate it. | +| `always_false` | `def always_false() -> ColumnExpr` | Boolean scalar expression that rejects every row. | +| `eq` | `def eq(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr` | Equality predicate scalar expression. | +| `gt` | `def gt(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr` | Greater-than predicate scalar expression. | +| `lit` | `def lit(value: int \| float \| str \| bool) -> ColumnExpr` | Canonical scalar literal helper. | +| `int_lit` | `def int_lit(value: int) -> ColumnExpr` | Typed integer literal helper. | +| `str_lit` | `def str_lit(value: str) -> ColumnExpr` | Typed string literal helper. | +| `bool_lit` | `def bool_lit(value: bool) -> ColumnExpr` | Typed boolean literal helper. | ## Example ```incan -from pub::inql.functions import col, eq, gt, int_lit, str_lit +from pub::inql.functions import col, eq, gt, lit filtered = ( orders - .filter(gt(col("amount"), int_lit(100))) - .filter(eq(col("status"), str_lit("open"))) + .filter(gt(col("amount"), lit(100))) + .filter(eq(col("status"), lit("open"))) ) ``` ## Notes -- Filter predicates currently operate on one explicit column builder plus one explicit literal. -- Rich boolean composition is follow-up work. +- Filter predicates are scalar expressions, not a separate predicate-only builder hierarchy. +- The typed `*_lit(...)` helpers construct the same scalar-literal representation as `lit(...)`. +- Boolean composition belongs to the broader scalar-function surface. diff --git a/docs/language/reference/builders/projections.md b/docs/language/reference/builders/projections.md index 37ccd60..fe12df1 100644 --- a/docs/language/reference/builders/projections.md +++ b/docs/language/reference/builders/projections.md @@ -1,18 +1,21 @@ # Projection builders (Reference) -Projection builders are the current semantic target for computed columns. +Projection builders are the current semantic target for scalar expressions in computed columns and other row-level positions. ## Functions | Builder | Signature | Meaning | | ------------ | ------------------------------------------------------------ | --------------------------- | | `col` | `def col(name: str) -> ColumnExpr` | Named column reference. | +| `lit` | `def lit(value: int \| float \| str \| bool) -> ColumnExpr` | Canonical scalar literal. | | `int_expr` | `def int_expr(value: int) -> ColumnExpr` | Integer literal expression. | | `float_expr` | `def float_expr(value: float) -> ColumnExpr` | Float literal expression. | | `str_expr` | `def str_expr(value: str) -> ColumnExpr` | String literal expression. | | `bool_expr` | `def bool_expr(value: bool) -> ColumnExpr` | Boolean literal expression. | | `add` | `def add(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr` | Binary addition. | | `mul` | `def mul(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr` | Binary multiplication. | +| `eq` | `def eq(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr` | Equality predicate. | +| `gt` | `def gt(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr` | Greater-than predicate. | ## Dataset entrypoint @@ -26,17 +29,17 @@ def with_column(self, name: str, expr: ColumnExpr) -> Self ## Example ```incan -from pub::inql.functions import add, col, int_expr, mul +from pub::inql.functions import add, col, lit, mul projected = ( orders - .with_column("amount_x2", mul(col("amount"), int_expr(2))) - .with_column("amount_plus_one", add(col("amount"), int_expr(1))) + .with_column("amount_x2", mul(col("amount"), lit(2))) + .with_column("amount_plus_one", add(col("amount"), lit(1))) ) ``` -## Current limits +## Capability notes -- No argument-bearing `select(...)` yet. -- No query-block projection sugar yet. -- No alias-free symbolic surface like `.amount * 2` yet. +- `with_column(...)` is the explicit computed-column entrypoint. +- Projection-list selection, query-block projection sugar, and alias-free symbolic surfaces lower to this scalar-expression model when exposed. +- The typed literal helpers construct the same scalar-literal representation as `lit(...)`. diff --git a/docs/language/reference/dataset_methods.md b/docs/language/reference/dataset_methods.md index f402300..9ee0963 100644 --- a/docs/language/reference/dataset_methods.md +++ b/docs/language/reference/dataset_methods.md @@ -11,17 +11,17 @@ The Substrait helper surface behind these methods is split by semantic role: ## Shared method surface -| Method | Signature | Meaning | -| ------------- | ------------------------------------------------------------ | --------------------------------------------------------------------------------------------- | -| `filter` | `def filter(self, predicate: FilterPredicate) -> Self` | Restrict rows by an explicit builder-backed predicate. | -| `join` | `def join(self, other: Self, on: bool) -> Self` | Combine with another same-carrier relation on the current placeholder join condition surface. | -| `select` | `def select(self) -> Self` | Identity projection today; full argument-bearing `select(...)` is follow-up work. | -| `with_column` | `def with_column(self, name: str, expr: ColumnExpr) -> Self` | Add or replace one projected column using an explicit projection builder expression. | -| `group_by` | `def group_by(self, columns: list[ColumnExpr]) -> Self` | Define grouping keys using explicit column builders. | -| `agg` | `def agg(self, measures: list[AggregateMeasure]) -> Self` | Apply aggregate measures over the current relation or current grouping. | -| `order_by` | `def order_by(self) -> Self` | Placeholder sort entrypoint until richer order expressions land. | -| `limit` | `def limit(self, n: int) -> Self` | Cap row count. | -| `explode` | `def explode(self) -> Self` | Expand a nested list column into rows. | +| Method | Signature | Meaning | +| ------------- | ------------------------------------------------------------ | ---------------------------------------------------------------------------------------------- | +| `filter` | `def filter(self, predicate: ColumnExpr) -> Self` | Restrict rows by a boolean scalar expression. | +| `join` | `def join(self, other: Self, on: bool) -> Self` | Combine with another same-carrier relation using the package's boolean join predicate surface. | +| `select` | `def select(self) -> Self` | Preserve the current projection shape as an identity projection. | +| `with_column` | `def with_column(self, name: str, expr: ColumnExpr) -> Self` | Add or replace one projected column using a scalar expression. | +| `group_by` | `def group_by(self, columns: list[ColumnExpr]) -> Self` | Define grouping keys using scalar expressions. | +| `agg` | `def agg(self, measures: list[AggregateMeasure]) -> Self` | Apply aggregate measures over the current relation or current grouping. | +| `order_by` | `def order_by(self) -> Self` | Preserve order-planning shape for the package sort boundary. | +| `limit` | `def limit(self, n: int) -> Self` | Cap row count. | +| `explode` | `def explode(self) -> Self` | Expand a nested list column into rows. | ## `with_column` @@ -36,36 +36,39 @@ def with_column(self, name: str, expr: ColumnExpr) -> Self - If `name` does not already exist, the new projected column is appended at the end. - If `name` already exists, that slot is replaced in place. - Replacement preserves ordinal position. -- The current first-slice expression surface is: +- The scalar-expression surface is: - `col(name)` + - `lit(value)` - `int_expr(...)` - `float_expr(...)` - `str_expr(...)` - `bool_expr(...)` - `add(left, right)` - `mul(left, right)` + - `eq(left, right)` + - `gt(left, right)` ### Example ```incan from pub::inql import LazyFrame -from pub::inql.functions import add, col, int_expr, mul +from pub::inql.functions import add, col, lit, mul from models import Order def enrich(orders: LazyFrame[Order]) -> LazyFrame[Order]: return ( orders - .with_column("amount_x2", mul(col("amount"), int_expr(2))) - .with_column("amount_plus_one", add(col("amount"), int_expr(1))) + .with_column("amount_x2", mul(col("amount"), lit(2))) + .with_column("amount_plus_one", add(col("amount"), lit(1))) ) ``` -## Current limits +## Capability notes -- `join(...)` is still structurally temporary: same-carrier input type and placeholder join condition. -- `select(...)` does not yet accept explicit projection arguments. -- `DataFrame[T]` still exposes materialized metadata and preview text, not a full public row API. -- Query-block sugar and scoped DSL symbol surfaces are later compiler work. The current builder APIs are the semantic target for that future sugar. +- `join(...)` is constrained to same-carrier inputs and the boolean join predicate surface shown in the signature. +- `select(...)` preserves projection shape; explicit projection lists are represented today through `with_column(...)` and scalar-expression builders. +- `DataFrame[T]` exposes materialized metadata and preview text; row-level accessors belong to the materialized DataFrame API surface. +- Query-block and scoped DSL surfaces lower into these builder APIs rather than defining separate method semantics. ## Related builder references diff --git a/docs/language/reference/execution_context.md b/docs/language/reference/execution_context.md index 6fbb5e8..1ac5099 100644 --- a/docs/language/reference/execution_context.md +++ b/docs/language/reference/execution_context.md @@ -1,6 +1,6 @@ # Execution context (Reference) -This page documents the current public execution surface in the InQL package. It describes the API as implemented today. Normative design intent still lives in [RFC 004][rfc-004]. +This page documents the public execution surface in the InQL package. Normative design intent lives in [RFC 004][rfc-004]. ## Core types @@ -22,10 +22,10 @@ This page documents the current public execution surface in the InQL package. It | API | Returns | Notes | | ------------------------------------ | ------------------------------------ | -------------------------------------------------------- | | `session.register(name, source)` | `Result[None, SessionError]` | Bind a logical relation name to a source definition | -| `session.table[T](name)` | `Result[LazyFrame[T], SessionError]` | Resolve a registered logical relation by name | -| `session.read_csv[T](name, uri)` | `Result[LazyFrame[T], SessionError]` | Register and return a deferred CSV-backed relation | -| `session.read_parquet[T](name, uri)` | `Result[LazyFrame[T], SessionError]` | Register and return a deferred Parquet-backed relation | -| `session.read_arrow[T](name, uri)` | `Result[LazyFrame[T], SessionError]` | Register and return a deferred Arrow IPC-backed relation | +| `session.table(name)` | `Result[LazyFrame[T], SessionError]` | Resolve a registered logical relation by name | +| `session.read_csv(name, uri)` | `Result[LazyFrame[T], SessionError]` | Register and return a deferred CSV-backed relation | +| `session.read_parquet(name, uri)` | `Result[LazyFrame[T], SessionError]` | Register and return a deferred Parquet-backed relation | +| `session.read_arrow(name, uri)` | `Result[LazyFrame[T], SessionError]` | Register and return a deferred Arrow IPC-backed relation | All read APIs return `LazyFrame[T]`. They create deferred logical work; they do not fetch rows immediately. @@ -66,15 +66,15 @@ If no active session exists when a convenience API needs one, the operation fail - `LazyFrame[T]` is the deferred carrier for bounded work. - `DataFrame[T]` is the materialized local carrier. -- Current `collect(...)` materialization stores structured metadata plus preview text: +- `collect(...)` materialization stores structured metadata plus preview text: - resolved output columns - row count - preview text for display/debugging -- Row-level user APIs remain follow-on work; preview text is not the canonical schema contract. +- Preview text is for display/debugging; resolved output columns are the schema contract surfaced by collection. -## Current backend note +## Backend note -DataFusion is the only implemented execution backend today. The public builder/configuration surface is designed so additional backends can be added later without changing the `Session` entry point. +DataFusion is the implemented execution backend. The public builder/configuration surface is designed so additional backends can be added without changing the `Session` entry point. ## Related docs diff --git a/docs/language/reference/functions/index.md b/docs/language/reference/functions/index.md index 5b97dca..8546998 100644 --- a/docs/language/reference/functions/index.md +++ b/docs/language/reference/functions/index.md @@ -8,4 +8,6 @@ Today the concrete shipped surfaces are documented here: - [Aggregate builders](../builders/aggregates.md) - [Projection builders](../builders/projections.md) +The canonical scalar literal helper is `lit(...)`. Typed literal helpers construct the same scalar-expression representation. + Future ANSI-style families should grow under this section instead of bloating `dataset_types` or `dataset_methods`. diff --git a/docs/language/reference/substrait/conformance.md b/docs/language/reference/substrait/conformance.md index b001eca..0ce47da 100644 --- a/docs/language/reference/substrait/conformance.md +++ b/docs/language/reference/substrait/conformance.md @@ -16,10 +16,10 @@ The corpus uses typed models/enums (`SubstraitConformanceScenario`, `Conformance Canonical operation semantics flow through `src/dataset/ops.incn`, while proto-backed Substrait relation building, plan assembly, and plan inspection now live in the focused `src/substrait/*.incn` helper modules. -For the current package-level profile, conformance checks are intentionally split between: +For the package-level profile, conformance checks are split between: -- real boundary facts that the package can prove now (relation kind, read kind, join variant, set operation, reference ordinal, extension URI presence) -- richer planning semantics that remain deferred to future `query {}` lowering and Prism work +- relation boundary facts that the package validates directly (relation kind, read kind, join variant, set operation, reference ordinal, extension URI presence) +- expression and planning semantics that are validated through package tests for the implemented method-chain surface ## Representation contract @@ -73,7 +73,7 @@ Downstream tooling should consume scenario catalog and validator functions from Conformance validation for the v1 profile is expected to run against canonical operation functions in `src/dataset/ops.incn`, emitted proto-backed plans from the current `src/substrait/*.incn` helper surface, and typed model/schema helpers where needed. -The current `ProjectRel` and `AggregateRel` scenarios are boundary-shape scaffolds, not proof that full computed-column, window, grouping-set, or distinct semantics are already implemented in package code. +`ProjectRel` and `AggregateRel` scenarios validate the Substrait relation boundary. Package tests cover the implemented scalar computed-column and grouped-aggregate method-chain behavior; windows, grouping sets, and distinct semantics have their own capability rows in the operator catalog. Historical design context is captured in [InQL RFC 002][rfc-002], but this page is the source of truth for the current conformance corpus representation. diff --git a/docs/language/reference/substrait/operator_catalog.md b/docs/language/reference/substrait/operator_catalog.md index 4e283dc..4560185 100644 --- a/docs/language/reference/substrait/operator_catalog.md +++ b/docs/language/reference/substrait/operator_catalog.md @@ -26,12 +26,12 @@ The following table maps InQL plan capabilities to Substrait logical relations a | Literal or embedded rows | `ReadRel` + `VirtualTable` | core | | Predicate pushdown into scan | `ReadRel` filter fields and/or separate `FilterRel` — producer policy; **must** be documented per implementation | core | | Row filter | `FilterRel` | core | -| Add or replace computed columns | `ProjectRel`; current package code provides a boundary scaffold, while richer computed/window semantics remain deferred | core | +| Add or replace computed columns | `ProjectRel` with scalar-expression payloads for package-authored computed columns; window expressions use the window capability below | core | | Inner join | `JoinRel` (inner variant) | core | | Left join | `JoinRel` (left outer variant) | core | | Semi, anti, single, mark join variants | `JoinRel` (respective variant; optional `post_join_filter`) | core | | Cross join | `CrossRel` | core | -| Group by / aggregates | `AggregateRel`; current package code provides a boundary scaffold, while richer grouping/measure semantics remain deferred | core | +| Group by / aggregates | `AggregateRel` with scalar grouping keys and aggregate measures; grouping sets are tracked as a distinct capability below | core | | Rollup / cube / grouping sets | `AggregateRel` with multiple groupings | core | | Distinct rows | `AggregateRel` with grouping keys and no measures | core | | Window / analytic functions | `ProjectRel` with window expressions | core | diff --git a/docs/language/reference/substrait/read_root_binding_contract.md b/docs/language/reference/substrait/read_root_binding_contract.md index db191ae..46f7941 100644 --- a/docs/language/reference/substrait/read_root_binding_contract.md +++ b/docs/language/reference/substrait/read_root_binding_contract.md @@ -45,7 +45,7 @@ The execution context ([InQL RFC 004][rfc-004] `Session`) **must**: ## Adapter boundary -Product SDKs and higher operational layers **may** provide convenience read APIs (for example, `session.read_csv[T](name, uri)` wrapping registration + `ReadRel` resolution). These convenience surfaces: +Product SDKs and higher operational layers **may** provide convenience read APIs (for example, `session.read_csv(name, uri)` returning a typed `LazyFrame[T]` by inference) wrapping registration + `ReadRel` resolution. These convenience surfaces: - **Must** produce a `ReadRel` in the normative Substrait interchange with the logical identity encoded appropriately for the variant. - **Must not** embed execution-context state — resolved credentials, session tokens, resolved endpoint URLs — in the `ReadRel` payload of the normative plan. @@ -59,10 +59,10 @@ The following table summarizes how each `Session` read method maps to a `ReadRel | `Session` method | Returns | `ReadRel` variant | | ------------------------------------ | -------------- | ------------------------------------------------- | -| `session.table[T](name)` | `LazyFrame[T]` | `NamedTable` | -| `session.read_csv[T](name, uri)` | `LazyFrame[T]` | `NamedTable` (via Session registration + binding) | -| `session.read_parquet[T](name, uri)` | `LazyFrame[T]` | `NamedTable` (via Session registration + binding) | -| `session.read_arrow[T](name, uri)` | `LazyFrame[T]` | `NamedTable` (via Session registration + binding) | +| `session.table(name)` | `LazyFrame[T]` | `NamedTable` | +| `session.read_csv(name, uri)` | `LazyFrame[T]` | `NamedTable` (via Session registration + binding) | +| `session.read_parquet(name, uri)` | `LazyFrame[T]` | `NamedTable` (via Session registration + binding) | +| `session.read_arrow(name, uri)` | `LazyFrame[T]` | `NamedTable` (via Session registration + binding) | In all cases the `LazyFrame[T]` holds a deferred plan — no data is fetched until Session execution is triggered by `session.execute(...)`, `session.collect(...)`, a convenience call such as `lazy.collect()`, or a Session-owned write. The `ReadRel` in the deferred plan carries only the logical identity; resolution to a physical source happens at execution time per the execution context obligations described above. diff --git a/docs/release_notes/v0_1.md b/docs/release_notes/v0_1.md index 40ab165..f737105 100644 --- a/docs/release_notes/v0_1.md +++ b/docs/release_notes/v0_1.md @@ -11,6 +11,7 @@ Entries will be filled in as work lands (link RFCs and PRs when applicable). - **Plans:** Apache Substrait as the logical interchange contract. - **Authoring:** method-chain lowering into a real Substrait boundary today, with `query {}` work still ahead. - **Aggregates:** builder-based `col`, `sum`, and `count` helpers now lower grouped and global aggregates through Prism, Substrait, and Session execution. +- **Scalar expressions:** RFC 012 unifies filter predicates, computed projection values, grouping keys, and aggregate inputs around one `ColumnExpr` surface with canonical `lit(...)` and typed literal helpers. - **Projection:** builder-based `with_column`, `add`, `mul`, and literal expression helpers now lower derived columns through Prism, Substrait, and Session execution. - **Substrait internals:** RFC 002 helpers are now split into focused owner modules for relation building, plan assembly, inspection, schema registry, extension bookkeeping, and expression lowering instead of one `substrait.plan` godmodule. - **Prism:** `LazyFrame` lowering applies safe canonical rewrites (`Filter(true)` elimination and adjacent `Limit`/`Project`/`OrderBy` collapse) before RFC 002 plan emission. diff --git a/docs/rfcs/004_inql_execution_context.md b/docs/rfcs/004_inql_execution_context.md index c37991c..922ce49 100644 --- a/docs/rfcs/004_inql_execution_context.md +++ b/docs/rfcs/004_inql_execution_context.md @@ -33,7 +33,7 @@ InQL RFCs 000–003 define a typed query language that produces portable logical Choosing Apache DataFusion as the reference backend is a pragmatic decision: it is Rust-native, Substrait-aware, provides serious query optimization, and operates on Apache Arrow — the de facto columnar data interchange format in the modern data ecosystem. Naming it explicitly avoids the trap of an abstract "pluggable backend" with no concrete implementation. -The `Session` surface should also feel familiar to users coming from established data runtimes such as Spark: one obvious entry point for reading data, registering logical names, executing plans, and writing results. That familiarity is an ergonomic goal, but not a semantic dependency. InQL keeps its own typed `DataSet[T]` model, explicit execution boundary, and RFC-defined semantics rather than inheriting historical API baggage from other systems. +The `Session` surface should also feel familiar to users coming from established data runtimes such as Spark: one obvious entry point for reading data, registering logical names, executing plans, and writing results. That familiarity is an ergonomic goal, but not a semantic dependency. InQL keeps its own typed `DataSet[T]` model, explicit execution boundary, and RFC-defined semantics rather than inheriting runtime-specific API details from other systems. ## Goals diff --git a/docs/rfcs/007_prism_planning_engine.md b/docs/rfcs/007_prism_planning_engine.md index 090e463..368b954 100644 --- a/docs/rfcs/007_prism_planning_engine.md +++ b/docs/rfcs/007_prism_planning_engine.md @@ -55,7 +55,7 @@ This matters for more than simple query lowering. Complex multi-hop pipelines, f From an author's point of view, Prism is not something they use directly. Authors work with InQL carriers such as `LazyFrame[T]`, `DataFrame[T]`, and (later) `DataStream[T]`. Those carriers build or operate over logical work that Prism stores and optimizes internally. ```incan -orders = session.table[Order]("orders")? +orders: LazyFrame[Order] = session.table("orders")? cutoff = ... # some appropriate value high_value = orders.filter(.amount > 1000) diff --git a/docs/rfcs/008_optimizer_boundary_stats_cbo_aqe.md b/docs/rfcs/008_optimizer_boundary_stats_cbo_aqe.md index 6636613..e019a06 100644 --- a/docs/rfcs/008_optimizer_boundary_stats_cbo_aqe.md +++ b/docs/rfcs/008_optimizer_boundary_stats_cbo_aqe.md @@ -81,8 +81,8 @@ from models import Customer, Order, RegionalSummary session = Session.default() -orders: LazyFrame[Order] = session.table[Order]("orders")? -customers: LazyFrame[Customer] = session.table[Customer]("customers")? +orders: LazyFrame[Order] = session.table("orders")? +customers: LazyFrame[Customer] = session.table("customers")? summary: LazyFrame[RegionalSummary] = query { FROM orders diff --git a/docs/rfcs/009_session_format_handler_registry.md b/docs/rfcs/009_session_format_handler_registry.md index af7006b..91e89e2 100644 --- a/docs/rfcs/009_session_format_handler_registry.md +++ b/docs/rfcs/009_session_format_handler_registry.md @@ -44,7 +44,7 @@ A handler registry makes format support extensible without destabilizing Session Authors should be able to register a handler and then read data through that format key. ```incan -from session import Session +from pub::inql import LazyFrame, Session from session.formats import SessionFormatHandler class FooFormatHandler with SessionFormatHandler: @@ -60,14 +60,14 @@ class FooFormatHandler with SessionFormatHandler: def main() -> None: mut session = Session.default() session.register_format_handler(FooFormatHandler())? - rows = session.read_format[OrderLine]("orders", "foo://bucket/orders.foo", "foo")? + rows: LazyFrame[OrderLine] = session.read_format("orders", "foo://bucket/orders.foo", "foo")? ``` Built-ins should follow the same path: ```incan session = Session.default() -rows = session.read_format[OrderLine]("orders", "tests/fixtures/orders.csv", "csv")? +rows: LazyFrame[OrderLine] = session.read_format("orders", "tests/fixtures/orders.csv", "csv")? ``` Convenience APIs like `read_csv` remain available and delegate to `read_format(..., "csv")`. diff --git a/docs/rfcs/010_csv_ingestion_contract.md b/docs/rfcs/010_csv_ingestion_contract.md index 7c89a87..afc54d0 100644 --- a/docs/rfcs/010_csv_ingestion_contract.md +++ b/docs/rfcs/010_csv_ingestion_contract.md @@ -61,7 +61,7 @@ Authors should think of CSV reads as a contract between their program and the ru The public surface should therefore expose one structured options value for CSV reads. ```incan -from pub::inql import Session +from pub::inql import LazyFrame, Session from pub::inql.formats import ( CsvDialect, CsvHeaderMode, @@ -74,7 +74,7 @@ from models import Order session = Session.default() -orders = session.read_csv[Order]( +orders: LazyFrame[Order] = session.read_csv( "orders", "s3://warehouse/orders.csv", CsvReadOptions( @@ -96,7 +96,7 @@ orders = session.read_csv[Order]( The same contract should be reachable through the generic format-dispatch path. `read_csv` is convenience API surface, not a separate semantic contract. ```incan -orders = session.read_format[Order]( +orders: LazyFrame[Order] = session.read_format( logical_name="orders", source="s3://warehouse/orders.csv", format="csv", @@ -119,7 +119,7 @@ orders = session.read_format[Order]( If an author does not provide options, InQL still has a defined default contract rather than an implementation accident. ```incan -orders = session.read_csv[Order]("orders", "s3://warehouse/orders.csv") +orders: LazyFrame[Order] = session.read_csv("orders", "s3://warehouse/orders.csv") ``` That default should mean "standard CSV with headers, quoted-field support, strict malformed-row handling, and deterministic type interpretation rules", not "whatever the current backend happened to do this week." @@ -129,7 +129,7 @@ Headerless CSV is still allowed, but it should be explicit because it changes ho ```incan from pub::inql.formats import CsvDialect, CsvHeaderMode, CsvReadOptions -rows = session.read_csv[Order]( +rows: LazyFrame[Order] = session.read_csv( "orders", "file:///tmp/orders_no_header.csv", CsvReadOptions( @@ -183,7 +183,7 @@ CSV ingestion must distinguish three schema layers: - **planned schema**: the schema InQL can establish before execution from configuration, headers, and compatible schema analysis - **resolved schema**: the schema observed from materialized output after execution -When `T` is supplied in `read_csv[T](...)`, `T` is the declared schema contract. CSV ingestion must validate compatibility between the file and `T` according to the configured header and inference policy. A backend must not silently reorder columns or reinterpret field names in a way that breaks the declared schema. +When the call context requires `LazyFrame[T]`, `T` is the declared schema contract. CSV ingestion must validate compatibility between the file and `T` according to the configured header and inference policy. A backend must not silently reorder columns or reinterpret field names in a way that breaks the declared schema. When CSV options require token interpretation beyond raw strings, `CsvInference` must control that behavior. At minimum it must define: diff --git a/docs/rfcs/011_source_discovery_contract.md b/docs/rfcs/011_source_discovery_contract.md index 8b071cc..57821e3 100644 --- a/docs/rfcs/011_source_discovery_contract.md +++ b/docs/rfcs/011_source_discovery_contract.md @@ -58,13 +58,13 @@ For the common path, authors should not have to restate obvious discovery behavi For a single file, the discovery contract is trivial. ```incan -from pub::inql import Session +from pub::inql import LazyFrame, Session from pub::inql.sources import SourceDiscovery, SourceTarget from models import Order session = Session.default() -orders = session.read_format[Order]( +orders: LazyFrame[Order] = session.read_format( logical_name="orders", source=SourceTarget.file("s3://warehouse/orders.csv"), format="csv", @@ -76,7 +76,7 @@ For a directory or prefix, the discovery contract becomes explicit: discover mul ```incan from pub::inql.sources import ParseUnitExpansion, SourceDiscovery, SourceTarget -orders = session.read_format[Order]( +orders: LazyFrame[Order] = session.read_format( logical_name="orders", source=SourceTarget.directory("s3://warehouse/orders/"), format="csv", diff --git a/docs/rfcs/012_unified_scalar_expression_surface.md b/docs/rfcs/012_unified_scalar_expression_surface.md index 8230c57..28c9bba 100644 --- a/docs/rfcs/012_unified_scalar_expression_surface.md +++ b/docs/rfcs/012_unified_scalar_expression_surface.md @@ -1,6 +1,6 @@ # InQL RFC 012: Unified scalar expression surface -- **Status:** Draft +- **Status:** Implemented - **Created:** 2026-04-22 - **Author(s):** Danny Meijer (@dannymeijer) - **Related:** @@ -8,12 +8,15 @@ - InQL RFC 003 (`query {}` blocks and relational authoring) - InQL RFC 004 (execution context and backend execution boundary) - InQL RFC 007 (Prism logical planning and optimization engine) + - Incan RFC 025 (multi-instantiation trait dispatch) + - Incan RFC 028 (trait-based operator overloading) + - Incan RFC 029 (union types and type narrowing) - Incan RFC 040 (scoped DSL glyph surfaces) - Incan RFC 045 (scoped DSL symbol surfaces) - **Issue:** [InQL #25](https://github.com/dannys-code-corner/InQL/issues/25) - **RFC PR:** — -- **Written against:** Incan v0.2 -- **Shipped in:** — +- **Written against:** Incan v0.3.0 +- **Shipped in:** v0.1 ## Summary @@ -52,7 +55,7 @@ This RFC is also needed before InQL can take proper advantage of the scoped DSL Authors should be able to think in terms of one row-level expression language. -The exact public literal helper spelling is still unresolved in this Draft. The examples below use `lit(...)` illustratively to show the semantic model, not to settle the final helper name. +The canonical public literal helper is `lit(...)`. Typed literal helpers are entrypoints over the same scalar-literal representation rather than a separate predicate-only or projection-only literal family. If an author filters rows or computes a new column, those operations should be using the same underlying scalar expression model: @@ -120,7 +123,7 @@ At minimum, the canonical scalar expression model must be able to represent: - scalar literals - scalar function or operator application over scalar-expression inputs -Separate public wrapper types may exist during migration, but they must lower into the same canonical scalar expression model rather than remaining semantically independent systems. +Separate public wrapper types may exist as implementation details, but they must lower into the same canonical scalar expression model rather than remaining semantically independent systems. ### Row-level consumers @@ -160,11 +163,13 @@ The following behaviors are forbidden: ### Canonical literal concept -The semantic concept of a scalar literal must be unified. InQL may expose one canonical helper such as `lit(...)`, or a migration family of typed helpers that all lower into the same scalar-literal representation, but the system must not preserve separate literal hierarchies for filters, projections, and other row-level positions. +The semantic concept of a scalar literal must be unified. InQL standardizes `lit(...)` as the canonical public helper for scalar literals. Because Incan RFC 029 gives Incan first-class union types, `lit(...)` may accept a closed union of supported literal input types such as `int | float | str | bool` while still returning one scalar-expression representation. + +Typed helpers such as `int_expr(...)`, `float_expr(...)`, `str_expr(...)`, `bool_expr(...)`, `int_lit(...)`, `str_lit(...)`, and `bool_lit(...)` must lower into the same scalar-literal representation as `lit(...)`. The system must not preserve separate literal hierarchies for filters, projections, and other row-level positions. ### Lowering target for future authoring surfaces -If InQL adopts future concise method-chain sugar or richer `query {}` syntax using RFC 040 and RFC 045 facilities, those surfaces must lower into the canonical scalar expression model for row-level meaning and the canonical aggregate-measure model for aggregate meaning. +If InQL adopts future concise method-chain sugar or richer `query {}` syntax using Incan RFC 025, RFC 028, RFC 029, RFC 040, and RFC 045 facilities, those surfaces must lower into the canonical scalar expression model for row-level meaning and the canonical aggregate-measure model for aggregate meaning. ## Design details @@ -183,7 +188,7 @@ The core semantic split is: Boolean predicates are ordinary scalar expressions whose result type is `bool`; they are not a separate semantic species. -Grouping keys belong on the scalar-expression side. They determine grouping identity by evaluating one scalar expression per input row. InQL may initially support only a subset of scalar expressions for grouping, but if so, that restriction must be explicit and diagnosable. +Grouping keys belong on the scalar-expression side. They determine grouping identity by evaluating one scalar expression per input row. The north-star contract allows deterministic, row-level scalar expressions as grouping keys when their result type is valid for grouping. InQL may initially support only a subset of scalar expressions for grouping, but if so, that restriction must be explicit and diagnosable, not a permanent semantic narrowing and not a silent fallback to direct-column grouping. ### Interaction with other InQL surfaces @@ -193,18 +198,18 @@ Grouping keys belong on the scalar-expression side. They determine grouping iden - **Prism (InQL RFC 007):** Prism should represent row-level expression meaning once and reuse it across logical operators instead of duplicating per-surface expression semantics. - **Incan `model` types and lexical scope:** model fields remain the source of column naming and typing, and ordinary lexical scope rules still govern explicit helper references until scoped DSL facilities are adopted. -### Compatibility / migration +### Surface cleanup -This RFC is additive as a design direction, but it may require cleanup of existing public builder families. +This RFC consolidates builder families before InQL's first released API. -The expected migration shape is: +The expected cleanup shape is: -- legacy typed literal helpers may remain temporarily as compatibility shims -- legacy predicate-specific wrappers may remain temporarily if they lower into the same scalar-expression model +- typed literal helpers must not create separate scalar-literal hierarchies +- predicate-specific wrappers must lower into the same scalar-expression model when they are exposed - docs and diagnostics should steer authors toward one canonical scalar-expression concept -- split row-level helper families should be deprecated and eventually removed once compatibility windows close +- split row-level helper families should be collapsed before release -Correctness takes precedence over convenience during migration. If a permissive compatibility path would silently change semantics, InQL should reject that path instead. +Correctness takes precedence over convenience. If a permissive helper path would silently change semantics, InQL should reject that path instead. ## Alternatives considered @@ -215,7 +220,7 @@ Correctness takes precedence over convenience during migration. If a permissive ## Drawbacks -- Existing InQL surfaces that grew independently may need migration and deprecation work. +- InQL surfaces that grew independently may need cleanup before release. - Tooling and diagnostics become more demanding because the system must enforce the scalar-versus-aggregate boundary more consistently. - Some previously tolerated expression shapes may need to become hard errors if they only "worked" through accidental or degraded behavior. - The RFC makes inconsistencies more visible, which can force earlier cleanup across docs, examples, planning, and lowering. @@ -228,11 +233,90 @@ Correctness takes precedence over convenience during migration. If a permissive - **Execution / interchange** — Prism and Substrait lowering must preserve the scalar-versus-aggregate boundary and must not silently rewrite unsupported expression shapes. - **Documentation** — user docs and reference pages should describe one row-level expression model instead of multiple parallel mini-DSLs. -## Unresolved questions +## Implementation Plan + +### Phase 1: Canonical scalar expression model + +- Introduce one canonical scalar-expression model that can represent column references, scalar literals, and scalar function/operator applications. +- Add `lit(...)` as the canonical literal helper using the Incan RFC 029 union-type surface for supported literal input types. +- Keep typed literal helpers as constructors that lower into the canonical scalar-expression model. +- Keep public helper names aligned with the canonical scalar-expression model. + +### Phase 2: Row-level consumers + +- Make `filter(...)` consume a scalar expression whose result contract is boolean. +- Make `with_column(...)` and projection assignment helpers consume row-level scalar expressions. +- Make `group_by(...)` consume grouping-key scalar expressions, with explicit errors for scalar expression shapes that the current implementation cannot yet lower faithfully. +- Make aggregate helpers consume scalar expressions as inputs while continuing to produce aggregate measures. + +### Phase 3: Prism and Substrait lowering + +- Store scalar expressions consistently in Prism filter, projection, grouping, and aggregate-input positions. +- Lower scalar expressions through one shared Substrait expression lowering path where the target position supports the expression. +- Preserve aggregate measures as a distinct group-level representation rather than collapsing them into row-level scalar expressions. +- Replace silent fallback or direct-column-only assumptions with explicit diagnostics or planning errors for unsupported expression shapes. + +### Phase 4: Tests, docs, and surface coherence + +- Add package tests covering `lit(...)` and typed literal helpers across filters, computed projections, grouping keys, and aggregate inputs. +- Add negative tests for unsupported scalar-expression shapes in grouping and aggregate positions when the implementation cannot lower them faithfully. +- Update reference and explanation docs so users see one scalar expression model instead of separate filter/projection literal families. +- Decide whether this user-visible package change requires an InQL package version bump, and if so keep `incan.toml` and `src/metadata.incn` synchronized. + +## Implementation Log + +### Spec / design + +- [x] Resolve canonical literal helper spelling as `lit(...)`. +- [x] Resolve grouping keys as broad scalar-expression positions with explicit temporary capability errors allowed. +- [x] Resolve aggregate outputs as distinct aggregate measures, not row-level scalar expressions. +- [x] Resolve first-class metadata as a registry/tooling north star owned by the function registry RFCs rather than by RFC 012 alone. +- [x] Verify Incan RFC 025, RFC 028, and RFC 029 dependencies against Incan `origin/main`. + +### Scalar expression model + +- [x] Define the canonical scalar-expression model in package code. +- [x] Represent column references, scalar literals, and scalar function/operator applications in that model. +- [x] Add canonical `lit(...)` helper for supported literal input types. +- [x] Convert typed literal helpers into constructors over the canonical model. +- [x] Keep public helper imports aligned with the canonical scalar-expression model. + +### Row-level consumers + +- [x] Update filter APIs to consume scalar expressions. +- [x] Update computed projection APIs to consume scalar expressions. +- [x] Update grouping APIs to consume grouping-key scalar expressions. +- [x] Update aggregate helpers to consume scalar-expression inputs and produce aggregate measures. +- [x] Add explicit errors for accepted expression shapes that cannot yet be represented or lowered faithfully in a target position. + +### Prism and Substrait + +- [x] Store scalar expressions consistently in Prism filter, projection, grouping, and aggregate-input nodes. +- [x] Lower scalar expressions through one shared Substrait expression-lowering path. +- [x] Preserve aggregate-measure lowering as a separate group-level path. +- [x] Add coverage for direct-column grouping and computed grouping expressions. +- [x] Add coverage for aggregate inputs that are scalar expressions rather than direct column names only. + +### Tests + +- [x] Add package tests for `lit(...)` in filters. +- [x] Add package tests for `lit(...)` in computed projections. +- [x] Add package tests for scalar-expression grouping keys or explicit unsupported-shape diagnostics. +- [x] Add package tests for aggregate scalar-expression inputs or explicit unsupported-shape diagnostics. +- [x] Keep typed helper tests passing as shared-model coverage. + +### Documentation and release + +- [x] Update `docs/language/reference/dataset_methods.md`. +- [x] Update builder/function reference docs to teach canonical scalar expressions and typed helper entrypoints. +- [x] Update explanation docs and examples that currently teach separate filter/projection literal families. +- [x] Add release notes if package behavior changes are user-visible. +- [x] No package version bump was required for this pre-release v0.1 slice. -- Should InQL standardize a canonical public literal helper spelling such as `lit(...)`, or only standardize the semantic concept and allow multiple helper spellings during a long compatibility period? -- Should grouping keys eventually accept all scalar expressions that projections accept, or should InQL permanently support a narrower grouping subset? -- Should aggregate outputs remain completely disallowed inside row-level scalar expressions, or is there future design space for explicitly scoped mixed aggregate projections? -- How much of this contract should be represented as shared InQL conventions versus first-class vocabulary metadata that tooling can inspect directly? +## Design Decisions - +- **Canonical literal helper:** `lit(...)` is the canonical public scalar-literal helper. Typed literal helpers construct the same scalar-literal representation and are not separate semantic systems. +- **Incan dependency boundary:** Incan RFC 029 enables a clean union-typed `lit(...)` input surface. Incan RFC 025 and RFC 028 are relevant to later operator-heavy expression sugar, including multiple operator trait instantiations for different right-hand operand types. Those RFCs do not replace the canonical scalar-expression model; they provide authoring surfaces that lower into it. +- **Grouping keys:** Grouping keys are scalar expressions in the north-star contract. Implementations may temporarily reject unsupported grouping expression shapes, but they must do so explicitly and must not permanently encode direct-column grouping as the only semantic model. +- **Aggregate boundary:** Aggregate outputs remain aggregate measures, not row-level scalar expressions. Mixed aggregate/scalar expression semantics require a later RFC. +- **Metadata boundary:** RFC 012 defines the semantic scalar-versus-aggregate contract. First-class function and vocabulary metadata belongs to the function registry and catalog RFCs, but implementations should avoid prose-only or helper-name-only behavior that tooling cannot eventually inspect. diff --git a/docs/rfcs/015_core_scalar_functions.md b/docs/rfcs/015_core_scalar_functions.md index b2b0bde..2d2f569 100644 --- a/docs/rfcs/015_core_scalar_functions.md +++ b/docs/rfcs/015_core_scalar_functions.md @@ -96,9 +96,9 @@ The minimum core scalar catalog is: `query {}` syntax may present SQL-familiar operators such as `==`, `>`, `AND`, `OR`, and `IS NULL`, but those operators must lower to the same scalar function semantics defined here. Dataframe methods must consume the same scalar expression model. -### Compatibility / migration +### Surface cleanup -Existing helpers such as `int_expr`, `float_expr`, `str_expr`, `bool_expr`, `int_lit`, `str_lit`, and `bool_lit` may remain temporarily as compatibility shims to `lit`. Existing `add`, `mul`, `eq`, and `gt` should become registered core scalar entries. +Typed literal helpers such as `int_expr`, `float_expr`, `str_expr`, `bool_expr`, `int_lit`, `str_lit`, and `bool_lit` should route through `lit` or the same scalar-literal representation. Existing `add`, `mul`, `eq`, and `gt` should become registered core scalar entries. ## Alternatives considered @@ -109,16 +109,16 @@ Existing helpers such as `int_expr`, `float_expr`, `str_expr`, `bool_expr`, `int ## Drawbacks - Defining null, cast, and numeric failure behavior exposes hard choices earlier. -- Compatibility shims may temporarily increase the number of public helper names. +- Typed helper entrypoints may increase the number of public helper names. - Operator sugar and function helper names can drift unless tooling treats registry entries as canonical. ## Layers affected - **InQL specification** — the core scalar vocabulary must remain consistent with InQL RFC 012 and InQL RFC 014. -- **InQL library package** — public helpers should expose the core scalar entries and compatibility shims. +- **InQL library package** — public helpers should expose the core scalar entries and selected typed helper entrypoints. - **Incan compiler** — query syntax and any future operator sugar must lower to the same scalar function semantics. - **Execution / interchange** — Prism and Substrait lowering must preserve casts, null-safe equality, boolean logic, ordering null placement, and `try_` behavior. -- **Documentation** — scalar function reference docs should distinguish canonical names from compatibility helpers. +- **Documentation** — scalar function reference docs should distinguish canonical names from typed helper entrypoints. ## Unresolved questions diff --git a/docs/rfcs/README.md b/docs/rfcs/README.md index cd79bdb..f83a49f 100644 --- a/docs/rfcs/README.md +++ b/docs/rfcs/README.md @@ -18,7 +18,7 @@ InQL uses its **own** RFC series (starting at 000), independent of the [Incan la | [009][rfc-009] | Draft | Session format handler registry (plugin-style source format registration) | | | [010][rfc-010] | Draft | CSV dialect and interpretation contract | | | [011][rfc-011] | Draft | Source discovery and parse-unit expansion | | -| [012][rfc-012] | Draft | Unified scalar expression surface | | +| [012][rfc-012] | Implemented | Unified scalar expression surface | | | [013][rfc-013] | Draft | Function catalog program | | | [014][rfc-014] | Draft | Function registry and catalog governance | | | [015][rfc-015] | Draft | Core scalar functions and operators | | diff --git a/examples/README.md b/examples/README.md index f83caa8..f2e7291 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,26 +1,30 @@ # InQL examples -Examples demonstrating InQL dataset types and Session execution patterns. +Examples demonstrating InQL model-shaped dataset types, scalar-expression builders, and Session execution patterns. -## Current status +## Overview -Most examples are still focused on compile-safe RFC 001 type contracts, and the Session examples exercise the RFC 004 execution path end-to-end. +The examples are split between compile-safe API shape examples and executable Session flows. Together they show how model +types, typed carriers, scalar expressions, grouped aggregates, reads, writes, collection, and display fit together in +ordinary Incan code. ## Example structure -- `dataset_api.incn` — Demonstrates the DataSet[T] operation API +- `dataset_api.incn` — Demonstrates typed DataSet[T] pipeline helper shapes - `trait_hierarchy.incn` — Demonstrates trait hierarchy usage - `bounded_vs_unbounded.incn` — Demonstrates bounded vs unbounded type signatures -- `session_read_transform_write_csv.incn` — Demonstrates `Session.read_csv[T](name, uri) -> LazyFrame transform -> Session.write_csv(...) -> session.activate() -> display(...)` +- `session_read_transform_write_csv.incn` — Demonstrates `Session.read_csv(...) -> LazyFrame[OrderId] -> transform -> write/display` - `session_read_transform_write_order_lines_csv.incn` — Same flow with a realistic multi-column `OrderLine` model and fixture -- `session_grouped_aggregate_csv.incn` — Real grouped aggregate example over CSV using `col(...)`, `sum(...)`, and `count()` -- `session_with_column_csv.incn` — Real derived-column example over CSV using `with_column(...)`, `mul(...)`, and `int_expr(...)` -- `models.incn` — Placeholder models for examples +- `session_grouped_aggregate_csv.incn` — Grouped aggregate over `LazyFrame[AggregateOrder]` using `col(...)`, `sum(...)`, and `count()` +- `session_with_column_csv.incn` — Derived-column example over `LazyFrame[AggregateOrder]` using `with_column(...)`, `mul(...)`, and `lit(...)` +- `models.incn` — Shared `@derive(Clone)` row models for examples ## Running examples ```bash incan run examples/dataset_api.incn +incan run examples/trait_hierarchy.incn +incan run examples/bounded_vs_unbounded.incn incan run examples/session_read_transform_write_csv.incn incan run examples/session_read_transform_write_order_lines_csv.incn incan run examples/session_grouped_aggregate_csv.incn @@ -31,16 +35,17 @@ incan run examples/session_with_column_csv.incn ## What these examples show -These examples document the API patterns for the current InQL dataset and Session surface: +These examples document the API patterns for the InQL dataset and Session surface: -1. **RFC 001** contracts are represented as compile-safe signatures and trait assignments -2. Builder-based aggregation is now concrete and runnable through `col(...)`, `sum(...)`, and `count()` -3. Builder-based projection is now concrete and runnable through `with_column(...)`, `add(...)`, `mul(...)`, and literal builders -4. **RFC 004** provides execution behavior (`execute`, `collect`, and write sinks over DataFusion) +1. Model contracts are visible at source boundaries as `LazyFrame[OrderId]`, `LazyFrame[OrderLine]`, and `LazyFrame[AggregateOrder]` +2. Carrier transformations remain typed Incan functions rather than stringly runtime scripts +3. Builder-based aggregation runs through `col(...)`, `sum(...)`, and `count()` +4. Builder-based scalar expressions run through `col(...)`, `lit(...)`, `eq(...)`, `gt(...)`, `add(...)`, and `mul(...)` +5. Session execution provides `collect`, `display`, and write sinks over DataFusion -Once those are in place, these examples will serve as: +They serve three purposes: -- **Regression tests** — verifying the patterns still work +- **Regression tests** — verifying the patterns remain valid - **Documentation** — showing users how to use the API - **Examples** — providing starting points for real code @@ -51,7 +56,8 @@ Once those are in place, these examples will serve as: These RFCs provide the trait and interop foundation InQL builds on. -What's still needed: +## Scope boundaries -- **Materialized row APIs** — `DataFrame[T]` row-level accessors remain out of scope in the current slice -- **Additional convenience APIs** — broader transformation ergonomics continue in follow-on RFC slices +- **Materialized row access** — Session collection exposes typed `DataFrame[T]` materialization metadata and preview text; row iteration/accessor design belongs to the DataFrame API work. +- **Output row retargeting** — projection and aggregate methods currently preserve the carrier type parameter while planned columns expose shape changes. +- **Convenience authoring** — these examples use the explicit builder surface that concise query and operator surfaces lower into. diff --git a/examples/bounded_vs_unbounded.incn b/examples/bounded_vs_unbounded.incn index edf56af..9e19f1a 100644 --- a/examples/bounded_vs_unbounded.incn +++ b/examples/bounded_vs_unbounded.incn @@ -1,82 +1,44 @@ -""" -Example: Bounded vs unbounded type signatures (RFC 001). - -This example focuses on compile-time type signatures and keeps runtime bodies as placeholders. - -## Patterns shown - -- BoundedDataSet[T]: All operations allowed -- UnboundedDataSet[T]: Unbounded-state operations rejected at compile time -- Static capability gating based on trait bounds -""" +"""Example: bounded vs unbounded model-carrier contracts.""" +from functions import col, eq, lit from dataset import DataSet, BoundedDataSet, UnboundedDataSet from models import Order, Event -# ---- Bounded-only functions ---- -def write_to_parquet(data: BoundedDataSet[Order]) -> None: - """Write to Parquet requires bounded data (finite extent).""" - pass - - -def materialize_to_dataframe(_data: BoundedDataSet[Order]) -> None: - """Materialize to DataFrame requires bounded data.""" - pass +def open_batch_orders(orders: BoundedDataSet[Order]) -> BoundedDataSet[Order]: + """Batch helpers can require finite Order data explicitly.""" + return orders.filter(eq(col("status"), lit("open"))) -def batch_aggregate(_data: BoundedDataSet[Order]) -> None: - """Batch aggregation requires bounded data.""" - pass +def order_snapshot(orders: BoundedDataSet[Order]) -> BoundedDataSet[Order]: + """Snapshot helpers can keep boundedness in the return type.""" + return open_batch_orders(orders).limit(100) -# ---- Streaming-only functions ---- -def write_to_kafka(events: UnboundedDataSet[Event]) -> None: - """Write to Kafka requires streaming data.""" - pass +def high_priority_events(events: UnboundedDataSet[Event]) -> UnboundedDataSet[Event]: + """Streaming helpers keep the event-stream carrier visible.""" + return events.filter(eq(col("severity"), lit("high"))) -def process_stream(events: UnboundedDataSet[Event]) -> UnboundedDataSet[Event]: - """Process streaming data.""" - return events +def event_stream_window(events: UnboundedDataSet[Event]) -> UnboundedDataSet[Event]: + """Unbounded pipelines can still use row-local transforms.""" + return high_priority_events(events).limit(1000) -# ---- Generic functions with boundedness constraints ---- -def process_any[T](data: DataSet[T]) -> int: - """Generic function that works with any carrier.""" - return 0 +def require_batch_orders(orders: BoundedDataSet[Order]) -> None: + """Sink boundaries can demand a bounded Order carrier.""" + _ = orders -def process_batch[T](data: BoundedDataSet[T]) -> BoundedDataSet[T]: - """Function that works only with bounded data.""" - return data +def require_event_stream(events: UnboundedDataSet[Event]) -> None: + """Sink boundaries can demand an unbounded Event carrier.""" + _ = events -# ---- Static capability gating examples ---- -def safe_filter[T](data: DataSet[T]) -> DataSet[T]: - """Filter is available on all carriers.""" - # Intended shape: data.filter(...) +def pass_through_any_model[T](data: DataSet[T]) -> DataSet[T]: + """Generic helpers stay carrier-neutral when they only need shared DataSet operations.""" return data -def safe_join[T](data: DataSet[T], other: DataSet[T]) -> DataSet[T]: - """Join uses `other: Self` (same carrier type at the trait level; logical row typing per RFC 003).""" - # Intended shape when stubs are executable: data.join(other, ...) - _ = other - return data - - -def bounded_only_operation(data: BoundedDataSet[Order]) -> None: - """Function that requires bounded data.""" - pass - - -def streaming_safe_operation(data: UnboundedDataSet[Event]) -> None: - """Function that works with streaming data.""" - pass - - -# ---- Main ---- def main() -> None: - """Example main function.""" pass diff --git a/examples/dataset_api.incn b/examples/dataset_api.incn index 5d108d8..9e5546d 100644 --- a/examples/dataset_api.incn +++ b/examples/dataset_api.incn @@ -1,69 +1,69 @@ """ -Example: DataSet[T] operation API patterns (RFC 001). +Example: typed DataSet[T] operation API patterns. -This file stays compile-safe today, but it now uses the parts of the method-chain surface that are already real: +These helpers read like ordinary application code: model types mark row contracts, while explicit builder expressions +describe portable relational intent. -- builder-based filters via `eq(...)`, `gt(...)`, `int_lit(...)`, and `str_lit(...)` -- builder-based projection via `with_column(...)`, `add(...)`, `mul(...)`, and `int_expr(...)` +- builder-based filters via `eq(...)`, `gt(...)`, and `lit(...)` +- builder-based projection via `with_column(...)`, `add(...)`, `mul(...)`, and `lit(...)` - builder-based grouping and aggregation via `col(...)`, `sum(...)`, and `count()` - -Query-expression sugar such as `.amount > 100` and rich `select/order_by` expressions still depend on later compiler -work, so those remain documented as intended shapes rather than executable examples. """ -from functions import add, col, count, eq, gt, int_expr, int_lit, mul, str_lit, sum +from functions import add, col, count, eq, gt, lit, mul, sum from dataset import DataSet, BoundedDataSet, DataFrame, LazyFrame -from models import Order, Customer +from models import AggregateOrder, Customer, Order, OrderLine def filter_high_value_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: - return orders.filter(gt(col("amount"), int_lit(100))) + return orders.filter(gt(col("amount"), lit(100.0))) def filter_active_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: - return orders.filter(eq(col("status"), str_lit("active"))) + return orders.filter(eq(col("status"), lit("active"))) -def filter_combined(orders: LazyFrame[Order]) -> LazyFrame[Order]: - return orders.filter(gt(col("amount"), int_lit(100))).filter(eq(col("status"), str_lit("active"))) +def active_high_value_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: + return filter_high_value_orders(filter_active_orders(orders)) -def join_orders_with_customers(orders: LazyFrame[Order], customers: LazyFrame[Customer]) -> LazyFrame[Order]: - # Trait join uses `other: Self`; heterogeneous Order/Customer join is a logical/RFC 003 concern. - # Same-schema illustration: orders.join(orders, True) when both are LazyFrame[Order]. - _ = customers - return orders.join(orders, true) +def customer_dimension(customers: LazyFrame[Customer]) -> LazyFrame[Customer]: + return customers.select() -def select_basic_fields(orders: LazyFrame[Order]) -> LazyFrame[Order]: - # Intended shape: orders.select(.id, .amount, .status) - return orders +def join_orders_with_revision(orders: LazyFrame[Order], revised_orders: LazyFrame[Order]) -> LazyFrame[Order]: + return orders.join(revised_orders, true) + + +def customer_dimension_preview(customers: LazyFrame[Customer]) -> LazyFrame[Customer]: + return customer_dimension(customers).limit(25) -def derive_double_amount(orders: LazyFrame[Order]) -> LazyFrame[Order]: - return orders.with_column("amount_x2", mul(col("amount"), int_expr(2))) +def select_order_fields(orders: LazyFrame[Order]) -> LazyFrame[Order]: + return orders.select() -def derive_amount_plus_one(orders: LazyFrame[Order]) -> LazyFrame[Order]: - return orders.with_column("amount_plus_one", add(col("amount"), int_expr(1))) +def derive_amount_with_tax(orders: LazyFrame[Order]) -> LazyFrame[Order]: + return orders.with_column("amount_with_tax", mul(col("amount"), lit(1.21))) -def group_by_customer_sum_amount(orders: LazyFrame[Order]) -> LazyFrame[Order]: +def derive_amount_plus_shipping(orders: LazyFrame[Order]) -> LazyFrame[Order]: + return orders.with_column("amount_plus_shipping", add(col("amount"), lit(4.95))) + + +def customer_amount_rollup(orders: LazyFrame[AggregateOrder]) -> LazyFrame[AggregateOrder]: return orders.group_by([col("customer_id")]).agg([sum(col("amount"))]) -def group_by_multiple_keys(orders: LazyFrame[Order]) -> LazyFrame[Order]: +def customer_status_rollup(orders: LazyFrame[Order]) -> LazyFrame[Order]: return orders.group_by([col("customer_id"), col("status")]).agg([sum(col("amount")), count()]) -def order_by_amount_desc(orders: LazyFrame[Order]) -> LazyFrame[Order]: - # Intended shape: orders.order_by(.amount.desc()) - return orders +def open_order_lines(lines: LazyFrame[OrderLine]) -> LazyFrame[OrderLine]: + return lines.filter(eq(col("status"), lit("open"))) -def order_by_limit_top_10(orders: LazyFrame[Order]) -> LazyFrame[Order]: - # Today only the `limit(...)` half is executable without richer order expressions. - return orders.limit(10) +def priced_order_lines(lines: LazyFrame[OrderLine]) -> LazyFrame[OrderLine]: + return open_order_lines(lines).with_column("discounted_unit_price", mul(col("unit_price"), lit(0.9))) def process_orders_batch(orders: DataFrame[Order]) -> DataFrame[Order]: diff --git a/examples/models.incn b/examples/models.incn index 0e0073f..64390db 100644 --- a/examples/models.incn +++ b/examples/models.incn @@ -1,18 +1,26 @@ """ -Placeholder models for InQL examples. +Compact schemas used by the InQL examples. -These are simplified models for demonstration purposes. In real code, these would -be defined in a separate module with proper field definitions. +Each executable example binds source reads to `LazyFrame[Model]` so the row contract is visible at the boundary where +data enters InQL. ## Models -- `Order`: Represents an order with id, customer_id, amount, status -- `Customer`: Represents a customer with id, name, email -- `Event`: Represents an event with id, type, severity +- `OrderId`: One-column order-id fixture row +- `Order`: General order row used by API-shape examples +- `AggregateOrder`: Compact aggregate fixture row +- `OrderLine`: Multi-column order-line fixture row +- `Customer`: Customer dimension row +- `Event`: Streaming/event row """ -# In real code, these would have proper field definitions: +@derive(Clone) +pub model OrderId: + pub id: int + + +@derive(Clone) pub model Order: pub id: int pub customer_id: int @@ -20,12 +28,32 @@ pub model Order: pub status: str +@derive(Clone) +pub model AggregateOrder: + pub customer_id: str + pub amount: int + + +@derive(Clone) +pub model OrderLine: + pub order_id: int + pub line_id: int + pub sku: str + pub qty: int + pub unit_price: float + pub currency: str + pub status: str + pub created_at: str + + +@derive(Clone) pub model Customer: pub id: int pub name: str pub email: str +@derive(Clone) pub model Event: pub id: int pub event_type: str diff --git a/examples/session_grouped_aggregate_csv.incn b/examples/session_grouped_aggregate_csv.incn index 9a8d346..72e96d1 100644 --- a/examples/session_grouped_aggregate_csv.incn +++ b/examples/session_grouped_aggregate_csv.incn @@ -1,15 +1,11 @@ -"""Example: Session grouped aggregate over one CSV fixture using real builder-based aggregation.""" +"""Example: typed Session grouped aggregate over a CSV fixture.""" +from dataset import LazyFrame from functions import col, count, display, sum +from models import AggregateOrder from session import Session, SessionError, report_session_error -@derive(Clone) -pub model AggregateOrder: - pub customer_id: str - pub amount: int - - def main() -> None: mut session = Session.default() input_uri = "tests/fixtures/aggregate_orders.csv" @@ -21,10 +17,21 @@ def main() -> None: def _group_collect_write(mut session: Session, input_uri: str, output_uri: str) -> Result[None, SessionError]: - orders = session.read_csv[AggregateOrder]("aggregate_orders", input_uri)? + orders: LazyFrame[AggregateOrder] = session.read_csv("aggregate_orders", input_uri)? session.activate() - grouped = orders.group_by([col("customer_id")]).agg([sum(col("amount")), count()]) + grouped = customer_amount_rollup(orders) + print_lazy_schema("customer rollup", grouped) display(grouped.clone()) session.write_csv(grouped, output_uri)? return Ok(None) + + +def customer_amount_rollup(orders: LazyFrame[AggregateOrder]) -> LazyFrame[AggregateOrder]: + return orders.group_by([col("customer_id")]).agg([sum(col("amount")), count()]) + + +def print_lazy_schema(label: str, frame: LazyFrame[AggregateOrder]) -> None: + schema = frame.schema() + println(f"{label} declared columns: {schema.declared_columns:?}") + println(f"{label} planned columns: {schema.planned_columns:?}") diff --git a/examples/session_read_transform_write_csv.incn b/examples/session_read_transform_write_csv.incn index b73ca77..4bb5819 100644 --- a/examples/session_read_transform_write_csv.incn +++ b/examples/session_read_transform_write_csv.incn @@ -1,14 +1,11 @@ -"""Example: Session read -> transform -> write CSV using LazyFrame.""" +"""Example: typed Session read -> transform -> write CSV.""" -from functions import col, display, eq, int_lit +from dataset import LazyFrame +from functions import col, display, eq, lit +from models import OrderId from session import Session, SessionError, report_session_error -@derive(Clone) -pub model CsvOrder: - pub id: int - - def main() -> None: mut session = Session.default() input_uri = "tests/fixtures/orders.csv" @@ -20,13 +17,21 @@ def main() -> None: def _read_transform_write(mut session: Session, input_uri: str, output_uri: str) -> Result[None, SessionError]: - # Read path stays deferred: this returns LazyFrame, not local materialized data. - orders = session.read_csv[CsvOrder]("orders", input_uri)? - # Activate after registration so LazyFrame.collect/display sees the registered sources. + orders: LazyFrame[OrderId] = session.read_csv("orders", input_uri)? session.activate() - # Keep transform simple and schema-safe for the current fixture. - transformed = orders.filter(eq(col("id"), int_lit(1))).limit(2) + transformed = first_order_ids(orders) + print_lazy_schema("orders", transformed) session.write_csv(transformed, output_uri)? display(transformed) return Ok(None) + + +def first_order_ids(orders: LazyFrame[OrderId]) -> LazyFrame[OrderId]: + return orders.filter(eq(col("id"), lit(1))).limit(2) + + +def print_lazy_schema(label: str, frame: LazyFrame[OrderId]) -> None: + schema = frame.schema() + println(f"{label} declared columns: {schema.declared_columns:?}") + println(f"{label} planned columns: {schema.planned_columns:?}") diff --git a/examples/session_read_transform_write_order_lines_csv.incn b/examples/session_read_transform_write_order_lines_csv.incn index 4409489..0c8b9ad 100644 --- a/examples/session_read_transform_write_order_lines_csv.incn +++ b/examples/session_read_transform_write_order_lines_csv.incn @@ -1,21 +1,11 @@ """Example: Session read -> transform -> collect/display -> write CSV with multi-column order lines.""" -from functions import display +from dataset import LazyFrame +from functions import col, display, eq, lit, mul +from models import OrderLine from session import Session, SessionError, report_session_error -@derive(Clone) -pub model OrderLine: - pub order_id: int - pub line_id: int - pub sku: str - pub qty: int - pub unit_price: float - pub currency: str - pub status: str - pub created_at: str - - def main() -> None: mut session = Session.default() input_uri = "tests/fixtures/order_lines.csv" @@ -31,14 +21,25 @@ def _read_transform_collect_display_write( input_uri: str, output_uri: str, ) -> Result[None, SessionError]: - raw = session.read_csv[OrderLine]("order_lines", input_uri)? - # Activate after registration so LazyFrame.collect/display sees the registered sources. + lines: LazyFrame[OrderLine] = session.read_csv("order_lines", input_uri)? session.activate() - # Apply one concrete transform with visible effect: reduce to top-2 rows. - # (Note: More expressive predicates/order keys are tracked as separate surface work.) - transformed = raw.limit(2) + transformed = open_eur_lines_with_discount(lines) + print_lazy_schema("order lines", transformed) session.write_csv(transformed, output_uri)? display(transformed) return Ok(None) + + +def open_eur_lines_with_discount(lines: LazyFrame[OrderLine]) -> LazyFrame[OrderLine]: + return lines.filter(eq(col("status"), lit("open"))).filter(eq(col("currency"), lit("EUR"))).with_column( + "discounted_unit_price", + mul(col("unit_price"), lit(0.9)), + ) + + +def print_lazy_schema(label: str, frame: LazyFrame[OrderLine]) -> None: + schema = frame.schema() + println(f"{label} declared columns: {schema.declared_columns:?}") + println(f"{label} planned columns: {schema.planned_columns:?}") diff --git a/examples/session_with_column_csv.incn b/examples/session_with_column_csv.incn index fe829b7..465ba53 100644 --- a/examples/session_with_column_csv.incn +++ b/examples/session_with_column_csv.incn @@ -1,15 +1,11 @@ -"""Example: Session read -> with_column -> collect/display -> write CSV using a real fixture.""" +"""Example: typed Session read -> with_column -> display -> write CSV.""" -from functions import col, display, int_expr, mul +from dataset import LazyFrame +from functions import col, display, lit, mul +from models import AggregateOrder from session import Session, SessionError, report_session_error -@derive(Clone) -pub model AggregateOrder: - pub customer_id: str - pub amount: int - - def main() -> None: mut session = Session.default() input_uri = "tests/fixtures/aggregate_orders.csv" @@ -21,10 +17,20 @@ def main() -> None: def _read_transform_write(mut session: Session, input_uri: str, output_uri: str) -> Result[None, SessionError]: - orders = session.read_csv[AggregateOrder]("aggregate_orders", input_uri)? + orders: LazyFrame[AggregateOrder] = session.read_csv("aggregate_orders", input_uri)? session.activate() - # add a column with double the amount, using a simple multiply expression over the "amount" column and an integer literal - projected = orders.with_column("double_amount", mul(col("amount"), int_expr(2))) + projected = add_double_amount(orders) + print_lazy_schema("projected orders", projected) display(projected) session.write_csv(projected, output_uri)? return Ok(None) + + +def add_double_amount(orders: LazyFrame[AggregateOrder]) -> LazyFrame[AggregateOrder]: + return orders.with_column("double_amount", mul(col("amount"), lit(2))) + + +def print_lazy_schema(label: str, frame: LazyFrame[AggregateOrder]) -> None: + schema = frame.schema() + println(f"{label} declared columns: {schema.declared_columns:?}") + println(f"{label} planned columns: {schema.planned_columns:?}") diff --git a/examples/trait_hierarchy.incn b/examples/trait_hierarchy.incn index adf56c0..ef0acfd 100644 --- a/examples/trait_hierarchy.incn +++ b/examples/trait_hierarchy.incn @@ -1,90 +1,52 @@ -""" -Example: Trait hierarchy patterns (RFC 001). - -This example demonstrates compile-time trait hierarchy usage patterns. - -## Patterns shown - -- Trait inheritance: BoundedDataSet[T] extends DataSet[T] -- Trait inheritance: UnboundedDataSet[T] extends DataSet[T] -- Concrete type implementation: DataFrame[T] implements BoundedDataSet[T] -- Concrete type implementation: LazyFrame[T] implements BoundedDataSet[T] -- Concrete type implementation: DataStream[T] implements UnboundedDataSet[T] -""" +"""Example: dataset carrier trait hierarchy with model-shaped rows.""" from dataset import DataSet, BoundedDataSet, UnboundedDataSet, DataFrame, LazyFrame, DataStream from models import Order, Event -# ---- Trait hierarchy assignment patterns ---- -def assign_data_frame_to_bounded(data: DataFrame[Order]) -> BoundedDataSet[Order]: - """Concrete bounded carrier upcasts to its bounded trait.""" - return data - - -def assign_lazy_frame_to_bounded(data: LazyFrame[Order]) -> BoundedDataSet[Order]: - """Deferred bounded carrier upcasts to its bounded trait.""" - return data +def bounded_order_frame(frame: DataFrame[Order]) -> BoundedDataSet[Order]: + return frame -def assign_data_stream_to_unbounded(data: DataStream[Event]) -> UnboundedDataSet[Event]: - """Streaming carrier upcasts to its unbounded trait.""" - return data +def bounded_order_plan(plan: LazyFrame[Order]) -> BoundedDataSet[Order]: + return plan -def assign_bounded_to_data_set(data: BoundedDataSet[Order]) -> DataSet[Order]: - """Bounded trait upcasts to the root DataSet trait.""" - return data +def unbounded_event_stream(stream: DataStream[Event]) -> UnboundedDataSet[Event]: + return stream -def assign_unbounded_to_data_set(data: UnboundedDataSet[Event]) -> DataSet[Event]: - """Unbounded trait upcasts to the root DataSet trait.""" - return data +def any_order_dataset(orders: BoundedDataSet[Order]) -> DataSet[Order]: + return orders -def chain_data_frame_to_root(data: DataFrame[Order]) -> DataSet[Order]: - """Concrete-to-supertrait chain: DataFrame -> BoundedDataSet -> DataSet.""" - return assign_bounded_to_data_set(assign_data_frame_to_bounded(data)) +def any_event_dataset(events: UnboundedDataSet[Event]) -> DataSet[Event]: + return events -def chain_data_stream_to_root(data: DataStream[Event]) -> DataSet[Event]: - """Concrete-to-supertrait chain: DataStream -> UnboundedDataSet -> DataSet.""" - return assign_unbounded_to_data_set(assign_data_stream_to_unbounded(data)) +def materialized_orders_as_root(frame: DataFrame[Order]) -> DataSet[Order]: + return any_order_dataset(bounded_order_frame(frame)) -# ---- Function signature patterns ---- -def accept_any_data_set[T](data: DataSet[T]) -> None: - """Function accepts any dataset carrier (most restrictive).""" - # Only intersection of bounded + unbounded operations available - pass +def planned_orders_as_root(plan: LazyFrame[Order]) -> DataSet[Order]: + return any_order_dataset(bounded_order_plan(plan)) -def accept_bounded_only(data: BoundedDataSet[Order]) -> None: - """Function accepts only bounded datasets.""" - # All operations available - pass +def streaming_events_as_root(stream: DataStream[Event]) -> DataSet[Event]: + return any_event_dataset(unbounded_event_stream(stream)) -def accept_unbounded_only(data: UnboundedDataSet[Event]) -> None: - """Function accepts only unbounded datasets.""" - # All operations except unbounded-state operations available - pass +def accepts_any_model_carrier[T](data: DataSet[T]) -> DataSet[T]: + return data -# ---- Generic function patterns ---- -def process_any[T](data: DataSet[T]) -> int: - """Generic function that works with any dataset carrier.""" - # Can only use operations available on DataSet[T] - return 0 +def accepts_bounded_orders(data: BoundedDataSet[Order]) -> BoundedDataSet[Order]: + return data -def process_batch_only(data: BoundedDataSet[Order]) -> int: - """Function that works only with bounded datasets.""" - # Can use all operations - return 0 +def accepts_unbounded_events(data: UnboundedDataSet[Event]) -> UnboundedDataSet[Event]: + return data -# ---- Main ---- def main() -> None: - """Example main function.""" pass diff --git a/incan.lock b/incan.lock index a7fc0e0..c1d45a0 100644 --- a/incan.lock +++ b/incan.lock @@ -3,9 +3,8 @@ [incan] format = 1 -incan-version = "0.3.0-dev.7" -generated = "2026-04-27T09:48:22.385894Z" -deps-fingerprint = "sha256:26aeed0759cb9a430f8f88d0e3c814d74ddb61831d3b496c1b509f94d3a51f3b" +incan-version = "0.3.0-rc6" +deps-fingerprint = "sha256:424fd53b12f0e810ffe3ba4188a82203bb011b5dd0769ad2b33ab1b854027487" cargo-features = [] cargo-no-default-features = false cargo-all-features = false @@ -104,9 +103,9 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "arrow" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d441fdda254b65f3e9025910eb2c2066b6295d9c8ed409522b8d2ace1ff8574c" +checksum = "378530e55cd479eda3c14eb345310799717e6f76d0c332041e8487022166b471" dependencies = [ "arrow-arith", "arrow-array", @@ -125,9 +124,9 @@ dependencies = [ [[package]] name = "arrow-arith" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ced5406f8b720cc0bc3aa9cf5758f93e8593cda5490677aa194e4b4b383f9a59" +checksum = "a0ab212d2c1886e802f51c5212d78ebbcbb0bec980fff9dadc1eb8d45cd0b738" dependencies = [ "arrow-array", "arrow-buffer", @@ -139,9 +138,9 @@ dependencies = [ [[package]] name = "arrow-array" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "772bd34cacdda8baec9418d80d23d0fb4d50ef0735685bd45158b83dfeb6e62d" +checksum = "cfd33d3e92f207444098c75b42de99d329562be0cf686b307b097cc52b4e999e" dependencies = [ "ahash", "arrow-buffer", @@ -150,7 +149,7 @@ dependencies = [ "chrono", "chrono-tz", "half", - "hashbrown 0.16.1", + "hashbrown 0.17.1", "num-complex", "num-integer", "num-traits", @@ -158,9 +157,9 @@ dependencies = [ [[package]] name = "arrow-buffer" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "898f4cf1e9598fdb77f356fdf2134feedfd0ee8d5a4e0a5f573e7d0aec16baa4" +checksum = "0c6cd424c2693bcdbc150d843dc9d4d137dd2de4782ce6df491ad11a3a0416c0" dependencies = [ "bytes", "half", @@ -170,9 +169,9 @@ dependencies = [ [[package]] name = "arrow-cast" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b0127816c96533d20fc938729f48c52d3e48f99717e7a0b5ade77d742510736d" +checksum = "4c5aefb56a2c02e9e2b30746241058b85f8983f0fcff2ba0c6d09006e1cded7f" dependencies = [ "arrow-array", "arrow-buffer", @@ -192,9 +191,9 @@ dependencies = [ [[package]] name = "arrow-csv" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca025bd0f38eeecb57c2153c0123b960494138e6a957bbda10da2b25415209fe" +checksum = "e94e8cf7e517657a52b91ea1263acf38c4ca62a84655d72458a3359b12ab97de" dependencies = [ "arrow-array", "arrow-cast", @@ -207,9 +206,9 @@ dependencies = [ [[package]] name = "arrow-data" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42d10beeab2b1c3bb0b53a00f7c944a178b622173a5c7bcabc3cb45d90238df4" +checksum = "3c88210023a2bfee1896af366309a3028fc3bcbd6515fa29a7990ee1baa08ee0" dependencies = [ "arrow-buffer", "arrow-schema", @@ -220,9 +219,9 @@ dependencies = [ [[package]] name = "arrow-ipc" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "609a441080e338147a84e8e6904b6da482cefb957c5cdc0f3398872f69a315d0" +checksum = "238438f0834483703d88896db6fe5a7138b2230debc31b34c0336c2996e3c64f" dependencies = [ "arrow-array", "arrow-buffer", @@ -236,15 +235,16 @@ dependencies = [ [[package]] name = "arrow-json" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ead0914e4861a531be48fe05858265cf854a4880b9ed12618b1d08cba9bebc8" +checksum = "205ca2119e6d679d5c133c6f30e68f027738d95ed948cf77677ea69c7800036b" dependencies = [ "arrow-array", "arrow-buffer", "arrow-cast", - "arrow-data", + "arrow-ord", "arrow-schema", + "arrow-select", "chrono", "half", "indexmap", @@ -260,9 +260,9 @@ dependencies = [ [[package]] name = "arrow-ord" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763a7ba279b20b52dad300e68cfc37c17efa65e68623169076855b3a9e941ca5" +checksum = "1bffd8fd2579286a5d63bac898159873e5094a79009940bcb42bbfce4f19f1d0" dependencies = [ "arrow-array", "arrow-buffer", @@ -273,9 +273,9 @@ dependencies = [ [[package]] name = "arrow-row" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e14fe367802f16d7668163ff647830258e6e0aeea9a4d79aaedf273af3bdcd3e" +checksum = "bab5994731204603c73ba69267616c50f80780774c6bb0476f1f830625115e0c" dependencies = [ "arrow-array", "arrow-buffer", @@ -286,9 +286,9 @@ dependencies = [ [[package]] name = "arrow-schema" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c30a1365d7a7dc50cc847e54154e6af49e4c4b0fddc9f607b687f29212082743" +checksum = "f633dbfdf39c039ada1bf9e34c694816eb71fbb7dc78f613993b7245e078a1ed" dependencies = [ "serde_core", "serde_json", @@ -296,9 +296,9 @@ dependencies = [ [[package]] name = "arrow-select" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78694888660a9e8ac949853db393af2a8b8fc82c19ce333132dfa2e72cc1a7fe" +checksum = "8cd065c54172ac787cf3f2f8d4107e0d3fdc26edba76fdf4f4cc170258942222" dependencies = [ "ahash", "arrow-array", @@ -310,9 +310,9 @@ dependencies = [ [[package]] name = "arrow-string" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "61e04a01f8bb73ce54437514c5fd3ee2aa3e8abe4c777ee5cc55853b1652f79e" +checksum = "29dd7cda3ab9692f43a2e4acc444d760cc17b12bb6d8232ddf64e9bab7c06b42" dependencies = [ "arrow-array", "arrow-buffer", @@ -481,9 +481,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.61" +version = "1.2.62" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" dependencies = [ "find-msvc-tools", "jobserver", @@ -662,9 +662,9 @@ dependencies = [ [[package]] name = "dashmap" -version = "6.1.0" +version = "6.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +checksum = "e6361d5c062261c78a176addb82d4c821ae42bed6089de0e12603cd25de2059c" dependencies = [ "cfg-if", "crossbeam-utils", @@ -1388,9 +1388,18 @@ checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" + +[[package]] +name = "encoding_rs" +version = "0.8.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] [[package]] name = "equivalent" @@ -1648,9 +1657,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.17.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" [[package]] name = "heck" @@ -1805,9 +1814,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" +checksum = "cb68373c0d6620ef8105e855e7745e18b0d00d3bdb07fb532e434244cdb9a714" dependencies = [ "icu_normalizer", "icu_properties", @@ -1815,14 +1824,14 @@ dependencies = [ [[package]] name = "incan_core" -version = "0.3.0-dev.7" +version = "0.3.0-rc6" dependencies = [ "serde", ] [[package]] name = "incan_derive" -version = "0.3.0-dev.7" +version = "0.3.0-rc6" dependencies = [ "proc-macro2", "quote", @@ -1831,11 +1840,12 @@ dependencies = [ [[package]] name = "incan_stdlib" -version = "0.3.0-dev.7" +version = "0.3.0-rc6" dependencies = [ "incan_core", "incan_derive", "tokio", + "xxhash-rust", ] [[package]] @@ -1845,23 +1855,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.17.0", + "hashbrown 0.17.1", "serde", "serde_core", ] [[package]] name = "inql" -version = "0.3.0-dev.7" +version = "0.3.0-rc6" dependencies = [ + "byteorder", "datafusion", "datafusion-substrait", + "encoding_rs", "incan_derive", "incan_stdlib", "prost", "prost-types", + "rustix", "substrait 0.63.0", - "tokio", ] [[package]] @@ -1897,9 +1909,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.95" +version = "0.3.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2964e92d1d9dc3364cae4d718d93f227e3abb088e747d92e0395bfdedf1c12ca" +checksum = "67df7112613f8bfd9150013a0314e196f4800d3201ae742489d999db2f979f08" dependencies = [ "cfg-if", "futures-util", @@ -1972,9 +1984,9 @@ dependencies = [ [[package]] name = "libbz2-rs-sys" -version = "0.2.3" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3a6a8c165077efc8f3a971534c50ea6a1a18b329ef4a66e897a7e3a1494565f" +checksum = "34b357333733e8260735ba5894eb928c02ecc69c78715f01a8019e7fa7f2db4c" [[package]] name = "libc" @@ -2037,9 +2049,9 @@ checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lz4_flex" -version = "0.13.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db9a0d582c2874f68138a16ce1867e0ffde6c0bb0a0df85e1f36d04146db488a" +checksum = "7ef0d4ed8669f8f8826eb00dc878084aa8f253506c4fd5e8f58f5bce72ddb97e" dependencies = [ "twox-hash", ] @@ -2200,9 +2212,9 @@ dependencies = [ [[package]] name = "parquet" -version = "58.1.0" +version = "58.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d3f9f2205199603564127932b89695f52b62322f541d0fc7179d57c2e1c9877" +checksum = "5dafa7d01085b62a47dd0c1829550a0a36710ea9c4fe358a05a85477cec8a908" dependencies = [ "ahash", "arrow-array", @@ -2218,7 +2230,7 @@ dependencies = [ "flate2", "futures", "half", - "hashbrown 0.16.1", + "hashbrown 0.17.1", "lz4_flex", "num-bigint", "num-integer", @@ -2682,9 +2694,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "indexmap", "itoa", @@ -2750,9 +2762,9 @@ checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" [[package]] name = "siphasher" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" +checksum = "8ee5873ec9cce0195efcb7a4e9507a04cd49aec9c83d0389df45b1ef7ba2e649" [[package]] name = "slab" @@ -2966,9 +2978,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.52.1" +version = "1.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" dependencies = [ "bytes", "libc", @@ -3206,9 +3218,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.118" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf938a0bacb0469e83c1e148908bd7d5a6010354cf4fb73279b7447422e3a89" +checksum = "49ace1d07c165b0864824eee619580c4689389afa9dc9ed3a4c75040d82e6790" dependencies = [ "cfg-if", "once_cell", @@ -3219,9 +3231,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.68" +version = "0.4.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f371d383f2fb139252e0bfac3b81b265689bf45b6874af544ffa4c975ac1ebf8" +checksum = "96492d0d3ffba25305a7dc88720d250b1401d7edca02cc3bcd50633b424673b8" dependencies = [ "js-sys", "wasm-bindgen", @@ -3229,9 +3241,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.118" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eeff24f84126c0ec2db7a449f0c2ec963c6a49efe0698c4242929da037ca28ed" +checksum = "8e68e6f4afd367a562002c05637acb8578ff2dea1943df76afb9e83d177c8578" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -3239,9 +3251,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.118" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d08065faf983b2b80a79fd87d8254c409281cf7de75fc4b773019824196c904" +checksum = "d95a9ec35c64b2a7cb35d3fead40c4238d0940c86d107136999567a4703259f2" dependencies = [ "bumpalo", "proc-macro2", @@ -3252,9 +3264,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.118" +version = "0.2.121" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fd04d9e306f1907bd13c6361b5c6bfc7b3b3c095ed3f8a9246390f8dbdee129" +checksum = "c4e0100b01e9f0d03189a92b96772a1fb998639d981193d7dbab487302513441" dependencies = [ "unicode-ident", ] @@ -3480,6 +3492,12 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" +[[package]] +name = "xxhash-rust" +version = "0.8.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" + [[package]] name = "yoke" version = "0.8.2" @@ -3525,9 +3543,9 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" +checksum = "0ec05a11813ea801ff6d75110ad09cd0824ddba17dfe17128ea0d5f68e6c5272" dependencies = [ "zerofrom-derive", ] diff --git a/src/aggregate_builders.incn b/src/aggregate_builders.incn index 12b0c5c..9be0647 100644 --- a/src/aggregate_builders.incn +++ b/src/aggregate_builders.incn @@ -6,15 +6,15 @@ Today they provide a library-owned way to express grouping columns and aggregate typechecker changes in the Incan compiler. """ -from projection_builders import ColumnExpr, col as col_expr, require_column_expr_name +from projection_builders import ColumnExpr, col as col_expr @derive(Clone) -pub enum AggregateKind: +pub enum AggregateKind(str): """Supported aggregate function kinds in the current package slice.""" - Sum - Count + Sum = "sum" + Count = "count" @derive(Clone) @@ -22,7 +22,7 @@ pub model AggregateMeasure: """Aggregate measure description carried through dataset, Prism, and Substrait boundaries.""" pub kind: AggregateKind - pub column_name: str + pub expr: ColumnExpr pub def col(name: str) -> ColumnExpr: @@ -31,10 +31,10 @@ pub def col(name: str) -> ColumnExpr: pub def sum(expr: ColumnExpr) -> AggregateMeasure: - """Build one `sum` aggregate measure over a named column reference.""" - return AggregateMeasure(kind=AggregateKind.Sum, column_name=require_column_expr_name(expr, str("sum(...)"))) + """Build one `sum` aggregate measure over a scalar expression.""" + return AggregateMeasure(kind=AggregateKind.Sum, expr=expr) pub def count() -> AggregateMeasure: """Build one zero-argument `count` aggregate measure.""" - return AggregateMeasure(kind=AggregateKind.Count, column_name="") + return AggregateMeasure(kind=AggregateKind.Count, expr=col_expr("")) diff --git a/src/backends.incn b/src/backends.incn index a796992..a4f0f7a 100644 --- a/src/backends.incn +++ b/src/backends.incn @@ -19,15 +19,15 @@ pub class DataFusion: @derive(Clone) -pub enum BackendKind: - DataFusionEngine +pub enum BackendKind(str): + DataFusionEngine = "datafusion" @derive(Clone) -pub enum SourceKind: - Csv - Parquet - Arrow +pub enum SourceKind(str): + Csv = "csv" + Parquet = "parquet" + Arrow = "arrow" @derive(Clone) @@ -51,10 +51,7 @@ pub class BackendSelection: def clone(self) -> Self: """Return one cloned backend selection preserving engine-specific options.""" - mut kind = BackendKind.DataFusionEngine - if self.kind == BackendKind.DataFusionEngine: - kind = BackendKind.DataFusionEngine - return BackendSelection(kind=kind, datafusion=self.datafusion) + return BackendSelection(kind=self.kind, datafusion=self.datafusion) pub def default_backend_selection() -> BackendSelection: @@ -64,18 +61,12 @@ pub def default_backend_selection() -> BackendSelection: pub def backend_kind_name(selection: BackendSelection) -> str: """Return one stable backend name for diagnostics and UX.""" - if selection.kind == BackendKind.DataFusionEngine: - return str("datafusion") - return str("unknown") + return selection.kind.value() pub def source_kind_name(kind: SourceKind) -> str: """Return one stable source-kind name for diagnostics and UX.""" - if kind == SourceKind.Csv: - return "csv" - if kind == SourceKind.Parquet: - return "parquet" - return "arrow" + return kind.value() pub def csv_source(uri: str) -> TableSource: diff --git a/src/dataset/mod.incn b/src/dataset/mod.incn index 7ebbfbb..6f4ebb5 100644 --- a/src/dataset/mod.incn +++ b/src/dataset/mod.incn @@ -18,7 +18,7 @@ state (`LazyFrame.collect()`) intentionally route through Session-owned active-s The current method-chain surface in this module is the explicit builder-based API: -- `filter(predicate: FilterPredicate)` +- `filter(predicate: ColumnExpr)` - `with_column(name: str, expr: ColumnExpr)` - `group_by(columns: list[ColumnExpr])` - `agg(measures: list[AggregateMeasure])` @@ -28,14 +28,14 @@ Illustrative current-shape examples: ```incan from pub::inql import LazyFrame -from pub::inql.functions import col, gt, int_lit, add, int_expr, sum, count +from pub::inql.functions import add, col, count, gt, lit, sum from models import Order, OrderSummary def high_value_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: - return orders.filter(gt(col("amount"), int_lit(100))) + return orders.filter(gt(col("amount"), lit(100))) def enrich_orders(orders: LazyFrame[Order]) -> LazyFrame[Order]: - return orders.with_column("amount_x2", add(col("amount"), int_expr(2))) + return orders.with_column("amount_x2", add(col("amount"), lit(2))) def summarize_orders(orders: LazyFrame[Order]) -> LazyFrame[OrderSummary]: return orders.group_by([col("customer_id")]).agg([sum(col("amount")), count()]) @@ -53,9 +53,9 @@ See also: from rust::substrait::proto import Plan, Rel from aggregate_builders import AggregateMeasure -from filter_builders import FilterPredicate from projection_builders import ColumnExpr from dataset.materialization import DataFrameMaterialization +from substrait.errors import SubstraitLoweringError from substrait.schema_registry import named_table_columns from substrait.plans import plan_from_root_relation from substrait.inspect import relation_output_columns, root_rel, source_named_table_name @@ -89,20 +89,20 @@ from prism import ( # ---- DataSet trait ---- pub trait DataSet[T with Clone]: - def to_substrait_plan(self) -> Plan: ... - def filter(self, predicate: FilterPredicate) -> Self: ... - # TODO(InQL RFC 004): Replace the temporary Self-only join surface with heterogeneous join typing plus an explicit output-schema contract. - def join(self, other: Self, on: bool) -> Self: ... - def select(self) -> Self: ... - def with_column(self, name: str, expr: ColumnExpr) -> Self: ... - def group_by(self, columns: list[ColumnExpr]) -> Self: ... - def agg(self, measures: list[AggregateMeasure]) -> Self: ... - def order_by(self) -> Self: ... - def limit(self, n: int) -> Self: ... - def explode(self) -> Self: ... - - -# ---- BoundedDataSet trait and concrete types ---- + def to_substrait_plan(self) -> Plan + def try_to_substrait_plan(self) -> Result[Plan, SubstraitLoweringError] + def filter(self, predicate: ColumnExpr) -> Self + # TODO(InQL RFC 004): Replace the Self-only join surface with heterogeneous join typing and an explicit output-schema contract. + def join(self, other: Self, on: bool) -> Self + def select(self) -> Self + def with_column(self, name: str, expr: ColumnExpr) -> Self + def group_by(self, columns: list[ColumnExpr]) -> Self + def agg(self, measures: list[AggregateMeasure]) -> Self + def order_by(self) -> Self + def limit(self, n: int) -> Self + def explode(self) -> Self + + pub trait BoundedDataSet[T with Clone] with DataSet[T]: pass @@ -122,10 +122,7 @@ pub model CarrierSchema: pub class DataFrame[T with Clone] with BoundedDataSet: - # Generic witness so runtime representation still carries T in the generated Rust type. pub _type_witness: list[T] - # Structured materialization produced by Session.collect(...). - # Preview text is display-oriented; resolved columns and row count are the canonical collected metadata. pub _materialization: DataFrameMaterialization pub _substrait_rel: Rel @@ -134,7 +131,7 @@ pub class DataFrame[T with Clone] with BoundedDataSet: return str(self._materialization.preview_text) def materialized_payload(self) -> str: - """Compatibility wrapper returning the collected preview text for older callers.""" + """Return the collected preview text through the payload-oriented helper name.""" return self.preview_text() def row_count(self) -> int: @@ -172,51 +169,55 @@ pub class DataFrame[T with Clone] with BoundedDataSet: """Rebuild a one-root Substrait plan from this frame's root relation.""" return plan_from_root_relation(self._substrait_rel, self.planned_columns()) - def filter(self, predicate: FilterPredicate) -> Self: + def try_to_substrait_plan(self) -> Result[Plan, SubstraitLoweringError]: + """Fallibly rebuild a one-root Substrait plan from this frame's already-lowered root relation.""" + return Ok(self.to_substrait_plan()) + + def filter(self, predicate: ColumnExpr) -> Self: """Return one new DataFrame with a filter stage appended and stale materialization cleared.""" - return _data_frame_with_invalidated_materialization[T]( + return _data_frame_with_invalidated_materialization( filter_ds_of_columns(self._substrait_rel, self.planned_columns(), predicate), ) def join(self, other: Self, on: bool) -> Self: """Return one new DataFrame with a join stage and stale materialization cleared.""" - return _data_frame_with_invalidated_materialization[T](join_ds(self._substrait_rel, other._substrait_rel, on)) + return _data_frame_with_invalidated_materialization(join_ds(self._substrait_rel, other._substrait_rel, on)) def select(self) -> Self: """Return one new DataFrame with an identity projection stage and stale materialization cleared.""" - return _data_frame_with_invalidated_materialization[T]( + return _data_frame_with_invalidated_materialization( select_ds_of_columns(self._substrait_rel, self.planned_columns()), ) def with_column(self, name: str, expr: ColumnExpr) -> Self: """Return one new DataFrame with an add-or-replace projection stage and stale materialization cleared.""" - return _data_frame_with_invalidated_materialization[T]( + return _data_frame_with_invalidated_materialization( with_column_ds(self._substrait_rel, self.planned_columns(), name, expr), ) def group_by(self, columns: list[ColumnExpr]) -> Self: """Return one new DataFrame with a grouping stage and stale materialization cleared.""" - return _data_frame_with_invalidated_materialization[T]( + return _data_frame_with_invalidated_materialization( group_by_ds_of_columns(self._substrait_rel, self.planned_columns(), columns), ) def agg(self, measures: list[AggregateMeasure]) -> Self: """Return one new DataFrame with an aggregation stage and stale materialization cleared.""" - return _data_frame_with_invalidated_materialization[T]( + return _data_frame_with_invalidated_materialization( agg_ds_of_columns(self._substrait_rel, self.planned_columns(), measures), ) def order_by(self) -> Self: """Return one new DataFrame with an ordering stage and stale materialization cleared.""" - return _data_frame_with_invalidated_materialization[T](order_by_ds(self._substrait_rel)) + return _data_frame_with_invalidated_materialization(order_by_ds(self._substrait_rel)) def limit(self, n: int) -> Self: """Return one new DataFrame with a row-limit stage and stale materialization cleared.""" - return _data_frame_with_invalidated_materialization[T](limit_ds(self._substrait_rel, n)) + return _data_frame_with_invalidated_materialization(limit_ds(self._substrait_rel, n)) def explode(self) -> Self: """Return one new DataFrame with an explode/unnest stage and stale materialization cleared.""" - return _data_frame_with_invalidated_materialization[T](explode_ds(self._substrait_rel)) + return _data_frame_with_invalidated_materialization(explode_ds(self._substrait_rel)) pub class LazyFrame[T with Clone] with BoundedDataSet: @@ -257,7 +258,11 @@ pub class LazyFrame[T with Clone] with BoundedDataSet: """Lower this lazy pipeline to a Substrait `Plan` at the boundary.""" return plan_from_root_relation(self._cursor.lower_to_rel(), self.planned_columns()) - def filter(self, predicate: FilterPredicate) -> Self: + def try_to_substrait_plan(self) -> Result[Plan, SubstraitLoweringError]: + """Fallibly lower this lazy pipeline to a Substrait `Plan` at the boundary.""" + return Ok(plan_from_root_relation(self._cursor.try_lower_to_rel()?, self.planned_columns())) + + def filter(self, predicate: ColumnExpr) -> Self: """Return one new lazy carrier with an appended filter stage.""" return LazyFrame(_cursor=prism_cursor_apply_filter(self._cursor, predicate)) @@ -295,12 +300,12 @@ pub class LazyFrame[T with Clone] with BoundedDataSet: def collect(self) -> Result[DataFrame[T], SessionError]: """Collect through the currently active Session and return one materialized DataFrame.""" - return collect_with_active_session[T](self.clone()) + return collect_with_active_session(self.clone()) pub def lazy_frame_named_table[T with Clone](table_name: str) -> LazyFrame[T]: """Create one deferred lazy carrier rooted at one logical named-table read.""" - return LazyFrame(_cursor=prism_cursor_named_table[T](table_name)) + return LazyFrame(_cursor=prism_cursor_named_table(table_name)) def _empty_type_witness[T with Clone]() -> list[T]: @@ -310,8 +315,8 @@ def _empty_type_witness[T with Clone]() -> list[T]: def _data_frame_with_invalidated_materialization[T with Clone](rel: Rel) -> DataFrame[T]: """Build one DataFrame whose logical relation changed and whose prior collected materialization is no longer valid.""" - return DataFrame[T]( - _type_witness=_empty_type_witness[T](), + return DataFrame( + _type_witness=_empty_type_witness(), _materialization=DataFrameMaterialization.empty(), _substrait_rel=rel, ) @@ -325,7 +330,6 @@ def _declared_columns_from_rel(rel: Rel) -> list[str]: return named_table_columns(table_name) -# ---- UnboundedDataSet trait and concrete types ---- pub trait UnboundedDataSet[T with Clone] with DataSet[T]: pass @@ -358,7 +362,11 @@ pub class DataStream[T with Clone] with UnboundedDataSet: """Rebuild a one-root Substrait plan from this stream relation.""" return plan_from_root_relation(self._substrait_rel, self.planned_columns()) - def filter(self, predicate: FilterPredicate) -> Self: + def try_to_substrait_plan(self) -> Result[Plan, SubstraitLoweringError]: + """Fallibly rebuild a one-root Substrait plan from this stream's already-lowered root relation.""" + return Ok(self.to_substrait_plan()) + + def filter(self, predicate: ColumnExpr) -> Self: """Return one new DataStream with a filter stage.""" return DataStream( _row_schema_marker=self._row_schema_marker.clone(), diff --git a/src/dataset/ops.incn b/src/dataset/ops.incn index e88e659..0367186 100644 --- a/src/dataset/ops.incn +++ b/src/dataset/ops.incn @@ -1,5 +1,5 @@ """ -Canonical dataset-operation seam over concrete relation builders. +Canonical dataset-operation layer over concrete relation builders. These helpers keep carrier-level operation names (`filter`, `join`, `group_by`, ...) distinct from the exact Substrait builder functions used underneath. Projection now flows through explicit input-column lists so carrier-level schema @@ -8,8 +8,7 @@ views stay aligned with the lowered relation tree. from rust::substrait::proto import Rel from aggregate_builders import AggregateMeasure -from filter_builders import FilterPredicate -from projection_builders import ColumnExpr, ProjectionAssignment, require_column_expr_name, with_column_assignment +from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment from substrait.extensions import explode_extension_uri from substrait.inspect import relation_output_columns from substrait.relations import ( @@ -23,13 +22,13 @@ from substrait.relations import ( ) -pub def filter_ds(rel: Rel, predicate: FilterPredicate) -> Rel: +pub def filter_ds(rel: Rel, predicate: ColumnExpr) -> Rel: """ Apply dataset-level filter intent to one relation. - Args: - rel: Input relation to filter. - predicate: Filter predicate for the current package slice. + Args: + rel: Input relation to filter. + predicate: Boolean scalar expression to apply. Returns: A relation shaped as a filter over the input relation. @@ -37,7 +36,7 @@ pub def filter_ds(rel: Rel, predicate: FilterPredicate) -> Rel: return filter_ds_of_columns(rel, relation_output_columns(rel.clone()), predicate) -pub def filter_ds_of_columns(rel: Rel, input_columns: list[str], predicate: FilterPredicate) -> Rel: +pub def filter_ds_of_columns(rel: Rel, input_columns: list[str], predicate: ColumnExpr) -> Rel: """Apply dataset-level filter intent using explicit input-column names.""" return filter_rel_of_columns(rel, input_columns, predicate) @@ -49,7 +48,7 @@ pub def join_ds(left_rel: Rel, right_rel: Rel, on: bool) -> Rel: Args: left_rel: Left input relation. right_rel: Right input relation. - on: Placeholder join condition for the current package slice. + on: Boolean join condition accepted by the current join builder. Returns: A relation shaped as a join over the two input relations. @@ -91,6 +90,7 @@ pub def group_by_ds(rel: Rel, columns: list[ColumnExpr]) -> Rel: Args: rel: Input relation to group. + columns: Scalar expressions that define grouping keys. Returns: A relation shaped as the grouping form of aggregate over the input relation. @@ -100,8 +100,7 @@ pub def group_by_ds(rel: Rel, columns: list[ColumnExpr]) -> Rel: pub def group_by_ds_of_columns(rel: Rel, input_columns: list[str], columns: list[ColumnExpr]) -> Rel: """Apply dataset-level grouping intent using explicit input-column names.""" - column_names = [require_column_expr_name(column, str("group_by(...)")) for column in columns] - return aggregate_rel_of_columns(rel, input_columns, column_names, []) + return aggregate_rel_of_columns(rel, input_columns, columns, []) pub def agg_ds(rel: Rel, measures: list[AggregateMeasure]) -> Rel: @@ -110,6 +109,7 @@ pub def agg_ds(rel: Rel, measures: list[AggregateMeasure]) -> Rel: Args: rel: Input relation to aggregate. + measures: Aggregate measures to apply over the relation or grouping. Returns: A relation shaped as the aggregating form of aggregate over the input relation. diff --git a/src/filter_builders.incn b/src/filter_builders.incn index f21b458..e73bd47 100644 --- a/src/filter_builders.incn +++ b/src/filter_builders.incn @@ -1,90 +1,43 @@ """ -Explicit filter builder surface for the current InQL-only predicate slice. +Filter-oriented builder surface for the current InQL scalar-expression slice. -These builders are the semantic target for future compiler sugar such as `.amount > 100` and query-block predicate -lowering. Today they provide a library-owned way to express real filter intent without requiring parser or typechecker -changes in the Incan compiler. +These names provide predicate-friendly helper spellings over the same scalar-expression model used by filters, +projections, grouping keys, and aggregate inputs. """ -from projection_builders import ColumnExpr, require_column_expr_name +from projection_builders import ColumnExpr, bool_expr, eq as scalar_eq, gt as scalar_gt, int_expr, str_expr -@derive(Clone) -pub enum FilterLiteralKind: - """Supported literal kinds for the current filter-builder surface.""" +pub def int_lit(value: int) -> ColumnExpr: + """Build one integer scalar literal through the filter-helper naming style.""" + return int_expr(value) - Int - String - Bool +pub def str_lit(value: str) -> ColumnExpr: + """Build one string scalar literal through the filter-helper naming style.""" + return str_expr(value) -@derive(Clone) -pub model FilterLiteral: - """Literal value captured by one filter predicate.""" - pub kind: FilterLiteralKind - pub int_value: int - pub string_value: str - pub bool_value: bool +pub def bool_lit(value: bool) -> ColumnExpr: + """Build one boolean scalar literal through the filter-helper naming style.""" + return bool_expr(value) -@derive(Clone) -pub enum FilterPredicateKind: - """Supported predicate kinds for the current filter-builder surface.""" +pub def always_true() -> ColumnExpr: + """Build one no-op scalar predicate that canonical rewrite can eliminate.""" + return bool_expr(true) - AlwaysTrue - AlwaysFalse - Eq - Gt +pub def always_false() -> ColumnExpr: + """Build one scalar predicate that rejects every row.""" + return bool_expr(false) -@derive(Clone) -pub model FilterPredicate: - """Filter predicate description carried through dataset, Prism, and Substrait boundaries.""" - pub kind: FilterPredicateKind - pub column_name: str - pub literal: FilterLiteral +pub def eq(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """Build one equality scalar predicate.""" + return scalar_eq(left, right) -pub def int_lit(value: int) -> FilterLiteral: - """Build one integer literal for filter expressions.""" - return FilterLiteral(kind=FilterLiteralKind.Int, int_value=value, string_value=str(""), bool_value=false) - - -pub def str_lit(value: str) -> FilterLiteral: - """Build one string literal for filter expressions.""" - return FilterLiteral(kind=FilterLiteralKind.String, int_value=0, string_value=value, bool_value=false) - - -pub def bool_lit(value: bool) -> FilterLiteral: - """Build one boolean literal for filter expressions.""" - return FilterLiteral(kind=FilterLiteralKind.Bool, int_value=0, string_value=str(""), bool_value=value) - - -pub def always_true() -> FilterPredicate: - """Build one no-op filter predicate that canonical rewrite can eliminate.""" - return FilterPredicate(kind=FilterPredicateKind.AlwaysTrue, column_name=str(""), literal=bool_lit(true)) - - -pub def always_false() -> FilterPredicate: - """Build one filter predicate that rejects every row.""" - return FilterPredicate(kind=FilterPredicateKind.AlwaysFalse, column_name=str(""), literal=bool_lit(false)) - - -pub def eq(column: ColumnExpr, literal: FilterLiteral) -> FilterPredicate: - """Build one equality predicate over a named column and literal.""" - return FilterPredicate( - kind=FilterPredicateKind.Eq, - column_name=require_column_expr_name(column, str("eq(...)")), - literal=literal, - ) - - -pub def gt(column: ColumnExpr, literal: FilterLiteral) -> FilterPredicate: - """Build one greater-than predicate over a named column and literal.""" - return FilterPredicate( - kind=FilterPredicateKind.Gt, - column_name=require_column_expr_name(column, str("gt(...)")), - literal=literal, - ) +pub def gt(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """Build one greater-than scalar predicate.""" + return scalar_gt(left, right) diff --git a/src/functions.incn b/src/functions.incn index 4677f72..92bbda0 100644 --- a/src/functions.incn +++ b/src/functions.incn @@ -5,10 +5,8 @@ These symbols **must** be imported by authors. The current package slice uses ex semantic target for future compiler sugar and query-block lowering. """ -from aggregate_builders import AggregateMeasure, count as count_builder, sum as sum_builder +from aggregate_builders import count as count_builder, sum as sum_builder from filter_builders import ( - FilterLiteral, - FilterPredicate, bool_lit as bool_lit_builder, always_false as always_false_builder, always_true as always_true_builder, @@ -18,96 +16,34 @@ from filter_builders import ( str_lit as str_lit_builder, ) from projection_builders import ( - ColumnExpr, add as add_builder, bool_expr as bool_expr_builder, col as col_builder, float_expr as float_expr_builder, int_expr as int_expr_builder, + lit as lit_builder, mul as mul_builder, str_expr as str_expr_builder, ) from dataset import DataFrame, LazyFrame - -pub def col(name: str) -> ColumnExpr: - """Build one named column reference expression.""" - return col_builder(name) - - -pub def sum(expr: ColumnExpr) -> AggregateMeasure: - """Build one `sum` aggregate measure over a named column reference.""" - return sum_builder(expr) - - -pub def count() -> AggregateMeasure: - """Build one zero-argument `count` aggregate measure.""" - return count_builder() - - -pub def int_expr(value: int) -> ColumnExpr: - """Build one integer literal expression for projection builders.""" - return int_expr_builder(value) - - -pub def float_expr(value: float) -> ColumnExpr: - """Build one float literal expression for projection builders.""" - return float_expr_builder(value) - - -pub def str_expr(value: str) -> ColumnExpr: - """Build one string literal expression for projection builders.""" - return str_expr_builder(value) - - -pub def bool_expr(value: bool) -> ColumnExpr: - """Build one boolean literal expression for projection builders.""" - return bool_expr_builder(value) - - -pub def add(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: - """Build one binary addition projection expression.""" - return add_builder(left, right) - - -pub def mul(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: - """Build one binary multiply projection expression.""" - return mul_builder(left, right) - - -pub def int_lit(value: int) -> FilterLiteral: - """Build one integer literal for filter expressions.""" - return int_lit_builder(value) - - -pub def str_lit(value: str) -> FilterLiteral: - """Build one string literal for filter expressions.""" - return str_lit_builder(value) - - -pub def bool_lit(value: bool) -> FilterLiteral: - """Build one boolean literal for filter expressions.""" - return bool_lit_builder(value) - - -pub def always_true() -> FilterPredicate: - """Build one filter predicate that canonical rewrite can eliminate.""" - return always_true_builder() - - -pub def always_false() -> FilterPredicate: - """Build one filter predicate that rejects every row.""" - return always_false_builder() - - -pub def eq(column: ColumnExpr, literal: FilterLiteral) -> FilterPredicate: - """Build one equality predicate over a named column and literal.""" - return eq_builder(column, literal) - - -pub def gt(column: ColumnExpr, literal: FilterLiteral) -> FilterPredicate: - """Build one greater-than predicate over a named column and literal.""" - return gt_builder(column, literal) +pub col = alias col_builder +pub lit = alias lit_builder +pub sum = alias sum_builder +pub count = alias count_builder +pub int_expr = alias int_expr_builder +pub float_expr = alias float_expr_builder +pub str_expr = alias str_expr_builder +pub bool_expr = alias bool_expr_builder +pub add = alias add_builder +pub mul = alias mul_builder +pub int_lit = alias int_lit_builder +pub str_lit = alias str_lit_builder +pub bool_lit = alias bool_lit_builder +pub always_true = alias always_true_builder +pub always_false = alias always_false_builder +pub eq = alias eq_builder +pub gt = alias gt_builder pub def display[T with Clone](data: LazyFrame[T]) -> None: diff --git a/src/lib.incn b/src/lib.incn index 02d69bc..b65a5ca 100644 --- a/src/lib.incn +++ b/src/lib.incn @@ -8,8 +8,22 @@ Consumers depend on this package via `[dependencies]` and import with `from pub: pub from dataset import BoundedDataSet, DataFrame, DataSet, DataStream, LazyFrame, UnboundedDataSet pub from dataset.ops import agg_ds, explode_ds, filter_ds, group_by_ds, join_ds, limit_ds, order_by_ds, select_ds pub from aggregate_builders import AggregateKind, AggregateMeasure -pub from filter_builders import FilterLiteral, FilterLiteralKind, FilterPredicate, FilterPredicateKind -pub from projection_builders import ColumnExpr, ColumnExprKind, ProjectionAssignment +pub from projection_builders import ( + AddExpr, + BoolLiteralExpr, + ColumnExpr, + ColumnExprKind, + ColumnRefExpr, + EqExpr, + FloatLiteralExpr, + GtExpr, + IntLiteralExpr, + MultiplyExpr, + ProjectionAssignment, + StringLiteralExpr, + column_expr_kind, + column_expr_name, +) pub from functions import ( add, always_false, @@ -24,6 +38,7 @@ pub from functions import ( gt, int_expr, int_lit, + lit, mul, str_expr, str_lit, @@ -32,7 +47,8 @@ pub from functions import ( pub from backends import BackendKind, DataFusion, TableSource, csv_source, parquet_source pub from metadata import inql_version pub from session.types import Session, SessionBuilder -pub from session.errors import SessionError, format_session_diagnostic, report_session_error +pub from session.errors import SessionError, SessionErrorKind, format_session_diagnostic, report_session_error +pub from substrait.errors import SubstraitLoweringError, SubstraitLoweringErrorKind pub from substrait.schema import ( DemoCustomer, RowColumnSpec, diff --git a/src/prism/lower.incn b/src/prism/lower.incn index fb10d12..e44edba 100644 --- a/src/prism/lower.incn +++ b/src/prism/lower.incn @@ -1,19 +1,21 @@ """Lowering from Prism optimized views into Substrait relations and plans.""" from rust::substrait::proto import Plan, Rel +from rust::incan_stdlib::errors import raise_value_error from prism.output_columns import rewritten_output_columns from substrait.extensions import explode_extension_uri from substrait.plans import plan_from_root_relation from substrait.relations import ( - aggregate_rel_of_columns, extension_single_rel, fetch_rel, - filter_rel_of_columns, join_rel, - project_rel_of_columns, read_named_table_rel, sort_rel, + try_aggregate_rel_of_columns, + try_filter_rel_of_columns, + try_project_rel_of_columns, ) +from substrait.errors import SubstraitLoweringError from prism.rewrite import derive_rewritten_view, rewritten_node_at from prism.types import PrismNodeKind, PrismOptimizedView, PrismStoreId @@ -24,12 +26,26 @@ pub def lower_prism_tip(store_id: PrismStoreId, tip_id: int) -> Rel: This is the stable boundary from Prism's internal graph representation into concrete Substrait relations. """ - return lower_optimized_view(derive_rewritten_view(store_id, tip_id)) + return _lowered_rel_or_raise(try_lower_prism_tip(store_id, tip_id)) + + +pub def try_lower_prism_tip(store_id: PrismStoreId, tip_id: int) -> Result[Rel, SubstraitLoweringError]: + """ + Fallibly lower one authored Prism tip through the canonical rewrite view. + + User-authored scalar-expression errors are reported through `SubstraitLoweringError` instead of assertion failure. + """ + return try_lower_optimized_view(derive_rewritten_view(store_id, tip_id)) pub def lower_optimized_view(view: PrismOptimizedView) -> Rel: """Lower one rewritten Prism view to its root Substrait relation.""" - return _lower_node(view, view.root_tip_id) + return _lowered_rel_or_raise(try_lower_optimized_view(view)) + + +pub def try_lower_optimized_view(view: PrismOptimizedView) -> Result[Rel, SubstraitLoweringError]: + """Fallibly lower one rewritten Prism view to its root Substrait relation.""" + return _try_lower_node(view, view.root_tip_id) pub def prism_rel_to_plan(rel: Rel) -> Plan: @@ -43,30 +59,46 @@ pub def prism_rel_to_plan(rel: Rel) -> Plan: def _lower_node(view: PrismOptimizedView, node_id: int) -> Rel: """Lower one Prism node to its corresponding Substrait relation.""" + return _lowered_rel_or_raise(_try_lower_node(view, node_id)) + + +def _lowered_rel_or_raise(result: Result[Rel, SubstraitLoweringError]) -> Rel: + """Return one lowered relation or raise the structured lowering diagnostic for infallible call sites.""" + match result: + Ok(rel) => return rel + Err(err) => + message = err.error_message() + return raise_value_error(message) + + +def _try_lower_node(view: PrismOptimizedView, node_id: int) -> Result[Rel, SubstraitLoweringError]: + """Fallibly lower one Prism node to its corresponding Substrait relation.""" node = rewritten_node_at(view, node_id) match node.kind: - PrismNodeKind.ReadNamedTable => return read_named_table_rel(node.named_table) + PrismNodeKind.ReadNamedTable => return Ok(read_named_table_rel(node.named_table)) PrismNodeKind.Filter => - return filter_rel_of_columns( - _lower_node(view, node.input_ids[0]), + return try_filter_rel_of_columns( + _try_lower_node(view, node.input_ids[0])?, rewritten_output_columns(view, node.input_ids[0]), node.filter_predicate, ) PrismNodeKind.Join => - return join_rel( - _lower_node(view, node.input_ids[0]), - _lower_node(view, node.input_ids[1]), - node.join_predicate, + return Ok( + join_rel( + _try_lower_node(view, node.input_ids[0])?, + _try_lower_node(view, node.input_ids[1])?, + node.join_predicate, + ), ) PrismNodeKind.Project => - return project_rel_of_columns( - _lower_node(view, node.input_ids[0]), + return try_project_rel_of_columns( + _try_lower_node(view, node.input_ids[0])?, rewritten_output_columns(view, node.input_ids[0]), node.projection_assignments, ) PrismNodeKind.GroupBy => - return aggregate_rel_of_columns( - _lower_node(view, node.input_ids[0]), + return try_aggregate_rel_of_columns( + _try_lower_node(view, node.input_ids[0])?, rewritten_output_columns(view, node.input_ids[0]), node.group_columns, [], @@ -74,19 +106,19 @@ def _lower_node(view: PrismOptimizedView, node_id: int) -> Rel: PrismNodeKind.Aggregate => input_node = rewritten_node_at(view, node.input_ids[0]) if input_node.kind == PrismNodeKind.GroupBy: - return aggregate_rel_of_columns( - _lower_node(view, input_node.input_ids[0]), + return try_aggregate_rel_of_columns( + _try_lower_node(view, input_node.input_ids[0])?, rewritten_output_columns(view, input_node.input_ids[0]), input_node.group_columns, node.aggregate_measures, ) - return aggregate_rel_of_columns( - _lower_node(view, node.input_ids[0]), + return try_aggregate_rel_of_columns( + _try_lower_node(view, node.input_ids[0])?, rewritten_output_columns(view, node.input_ids[0]), [], node.aggregate_measures, ) - PrismNodeKind.OrderBy => return sort_rel(_lower_node(view, node.input_ids[0])) - PrismNodeKind.Limit => return fetch_rel(_lower_node(view, node.input_ids[0]), 0, node.limit_count) + PrismNodeKind.OrderBy => return Ok(sort_rel(_try_lower_node(view, node.input_ids[0])?)) + PrismNodeKind.Limit => return Ok(fetch_rel(_try_lower_node(view, node.input_ids[0])?, 0, node.limit_count)) PrismNodeKind.Explode => - return extension_single_rel(_lower_node(view, node.input_ids[0]), explode_extension_uri()) + return Ok(extension_single_rel(_try_lower_node(view, node.input_ids[0])?, explode_extension_uri())) diff --git a/src/prism/mod.incn b/src/prism/mod.incn index 5ac8ae9..1754b53 100644 --- a/src/prism/mod.incn +++ b/src/prism/mod.incn @@ -12,9 +12,13 @@ This façade keeps one stable internal import surface while the implementation i from rust::substrait::proto import Plan, Rel from aggregate_builders import AggregateMeasure -from filter_builders import FilterPredicate, always_true -from projection_builders import ColumnExpr, ProjectionAssignment, require_column_expr_name, with_column_assignment -from prism.lower import lower_prism_tip as lower_prism_tip_impl, prism_rel_to_plan +from filter_builders import always_true +from projection_builders import ColumnExpr, ProjectionAssignment, with_column_assignment +from prism.lower import ( + lower_prism_tip as lower_prism_tip_impl, + prism_rel_to_plan, + try_lower_prism_tip as try_lower_prism_tip_impl, +) from prism.output_columns import rewritten_output_columns from prism.rewrite import ( derive_optimized_view as derive_optimized_view_impl, @@ -32,6 +36,7 @@ from prism.store import ( store_nodes, ) from prism.types import authored_node_kind_name, PrismNodeKind, PrismOptimizedView, PrismRewriteExplain, PrismStoreId +from substrait.errors import SubstraitLoweringError pub class PrismCursor[T with Clone]: @@ -51,7 +56,7 @@ pub class PrismCursor[T with Clone]: """Return one cloned cursor handle to the same store and tip.""" return PrismCursor(store_id=self.store_id, tip_id=self.tip_id, _type_marker=[]) - def filter(self, predicate: FilterPredicate) -> Self: + def filter(self, predicate: ColumnExpr) -> Self: """Append one filter node in the shared authored store and return the derived tip.""" next_tip_id = append_node( store_id=self.store_id, @@ -132,7 +137,6 @@ pub class PrismCursor[T with Clone]: def group_by(self, columns: list[ColumnExpr]) -> Self: """Append one grouping node and return the derived tip.""" - column_names = [require_column_expr_name(column, str("group_by(...)")) for column in columns] next_tip_id = append_node( store_id=self.store_id, kind=PrismNodeKind.GroupBy, @@ -141,7 +145,7 @@ pub class PrismCursor[T with Clone]: join_predicate=false, filter_predicate=always_true(), limit_count=0, - group_columns=column_names, + group_columns=columns, aggregate_measures=[], projection_assignments=[], ) @@ -223,10 +227,18 @@ pub class PrismCursor[T with Clone]: """Lower the current tip through canonical rewrites to one Substrait root relation.""" return lower_prism_tip(self.store_id, self.tip_id) + def try_lower_to_rel(self) -> Result[Rel, SubstraitLoweringError]: + """Fallibly lower the current tip through canonical rewrites to one Substrait root relation.""" + return try_lower_prism_tip(self.store_id, self.tip_id) + def to_substrait_plan(self) -> Plan: """Lower the current tip all the way to one Substrait plan.""" return prism_rel_to_plan(self.lower_to_rel()) + def try_to_substrait_plan(self) -> Result[Plan, SubstraitLoweringError]: + """Fallibly lower the current tip all the way to one Substrait plan.""" + return Ok(prism_rel_to_plan(self.try_lower_to_rel()?)) + pub def prism_cursor_named_table[T with Clone](table_name: str) -> PrismCursor[T]: """Create the minimal shared-store cursor rooted at one named-table read.""" @@ -246,7 +258,7 @@ pub def prism_cursor_named_table[T with Clone](table_name: str) -> PrismCursor[T return PrismCursor(store_id=store_id, tip_id=tip_id, _type_marker=[]) -pub def prism_cursor_apply_filter[T with Clone](cursor: PrismCursor[T], predicate: FilterPredicate) -> PrismCursor[T]: +pub def prism_cursor_apply_filter[T with Clone](cursor: PrismCursor[T], predicate: ColumnExpr) -> PrismCursor[T]: """Apply dataset-level filter intent through Prism's shared authored store.""" return cursor.filter(predicate) @@ -312,6 +324,11 @@ pub def lower_prism_tip(store_id: PrismStoreId, tip_id: int) -> Rel: return lower_prism_tip_impl(store_id, tip_id) +pub def try_lower_prism_tip(store_id: PrismStoreId, tip_id: int) -> Result[Rel, SubstraitLoweringError]: + """Fallibly lower one authored Prism tip through canonical rewrite view to the `Rel` boundary.""" + return try_lower_prism_tip_impl(store_id, tip_id) + + pub def derive_optimized_view(store_id: PrismStoreId, tip_id: int) -> PrismOptimizedView: """Derive the current identity-shaped view for one authored tip.""" return derive_optimized_view_impl(store_id, tip_id) diff --git a/src/prism/output_columns.incn b/src/prism/output_columns.incn index e8199ea..f1de58c 100644 --- a/src/prism/output_columns.incn +++ b/src/prism/output_columns.incn @@ -3,14 +3,19 @@ from prism.store import node_at from prism.rewrite import rewritten_node_at from prism.types import PrismNodeKind, PrismOptimizedView, PrismStoreId -from projection_builders import project_output_columns +from projection_builders import ColumnExpr, project_output_columns, scalar_expr_output_name from substrait.inspect import aggregate_measure_output_names from substrait.schema_registry import named_table_columns def _is_passthrough_output_kind(kind: PrismNodeKind) -> bool: """Return whether a Prism node preserves its input columns unchanged.""" - return (kind == PrismNodeKind.Filter or kind == PrismNodeKind.OrderBy or kind == PrismNodeKind.Limit or kind == PrismNodeKind.Explode) + return ( + kind == PrismNodeKind.Filter + or kind == PrismNodeKind.OrderBy + or kind == PrismNodeKind.Limit + or kind == PrismNodeKind.Explode + ) pub def authored_output_columns(store_id: PrismStoreId, tip_id: int) -> list[str]: @@ -33,7 +38,7 @@ pub def authored_output_columns(store_id: PrismStoreId, tip_id: int) -> list[str columns.extend(rhs_columns) return columns if node.kind == PrismNodeKind.GroupBy: - return [column for column in node.group_columns] + return _group_output_columns(node.group_columns) if node.kind == PrismNodeKind.Aggregate: # Aggregate output preserves grouping-prefix columns first, # then appends one output column per aggregate measure in declaration order. @@ -64,7 +69,7 @@ pub def rewritten_output_columns(view: PrismOptimizedView, node_id: int) -> list columns.extend(rhs_columns) return columns if node.kind == PrismNodeKind.GroupBy: - return [column for column in node.group_columns] + return _group_output_columns(node.group_columns) if node.kind == PrismNodeKind.Aggregate: # Rewritten aggregate nodes follow the same output contract as authored nodes: # grouping-prefix columns first, aggregate measure outputs after that. @@ -80,7 +85,7 @@ def _authored_aggregate_prefix_columns(store_id: PrismStoreId, tip_id: int) -> l """Return grouping-prefix columns that survive into one authored aggregate node output.""" input_node = node_at(store_id, tip_id) if input_node.kind == PrismNodeKind.GroupBy: - return [column for column in input_node.group_columns] + return _group_output_columns(input_node.group_columns) if input_node.kind == PrismNodeKind.Aggregate: return _authored_aggregate_prefix_columns(store_id, input_node.input_ids[0]) return [] @@ -90,7 +95,12 @@ def _rewritten_aggregate_prefix_columns(view: PrismOptimizedView, node_id: int) """Return grouping-prefix columns that survive into one rewritten aggregate node output.""" input_node = rewritten_node_at(view, node_id) if input_node.kind == PrismNodeKind.GroupBy: - return [column for column in input_node.group_columns] + return _group_output_columns(input_node.group_columns) if input_node.kind == PrismNodeKind.Aggregate: return _rewritten_aggregate_prefix_columns(view, input_node.input_ids[0]) return [] + + +def _group_output_columns(group_columns: list[ColumnExpr]) -> list[str]: + """Return stable output column names for grouping scalar expressions.""" + return [scalar_expr_output_name(expr, f"group_{idx}") for idx, expr in enumerate(group_columns)] diff --git a/src/prism/rewrite.incn b/src/prism/rewrite.incn index b7b9a54..d31460b 100644 --- a/src/prism/rewrite.incn +++ b/src/prism/rewrite.incn @@ -1,8 +1,9 @@ """Prism optimized-view derivation, canonical rewrite rules, and explain artifacts.""" -from filter_builders import FilterPredicateKind, always_true +from filter_builders import always_true from prism.types import PrismNode, PrismNodeKind, PrismOptimizedView, PrismRewriteExplain, PrismStoreId from prism.store import node_at, reachable_node_ids, remap_input_ids, store_nodes +from projection_builders import is_bool_literal_expr model PrismRewriteResult: @@ -133,7 +134,7 @@ def _derive_rewrite_result(store_id: PrismStoreId, tip_id: int) -> PrismRewriteR def _can_eliminate_filter_true(node: PrismNode) -> bool: """Return whether a filter node is a no-op `filter(true)` and can be erased.""" - return node.kind == PrismNodeKind.Filter and node.filter_predicate.kind == FilterPredicateKind.AlwaysTrue + return node.kind == PrismNodeKind.Filter and is_bool_literal_expr(node.filter_predicate, true) def _can_collapse_adjacent_limit(node: PrismNode, remapped_inputs: list[int], rewritten_nodes: list[PrismNode]) -> bool: diff --git a/src/prism/store.incn b/src/prism/store.incn index 4005f46..4b50635 100644 --- a/src/prism/store.incn +++ b/src/prism/store.incn @@ -1,8 +1,19 @@ """Append-only Prism store allocation, storage, reachability, and cross-store adoption.""" from aggregate_builders import AggregateMeasure -from filter_builders import FilterLiteralKind, FilterPredicate, FilterPredicateKind -from projection_builders import ColumnExpr, ProjectionAssignment +from projection_builders import ( + AddExpr, + BoolLiteralExpr, + ColumnExpr, + ColumnRefExpr, + EqExpr, + FloatLiteralExpr, + GtExpr, + IntLiteralExpr, + MultiplyExpr, + ProjectionAssignment, + StringLiteralExpr, +) from prism.types import PrismNode, PrismNodeKind, PrismStoreAdoption, PrismStoreId @@ -36,9 +47,9 @@ pub def append_node( input_ids: list[int], named_table: str, join_predicate: bool, - filter_predicate: FilterPredicate, + filter_predicate: ColumnExpr, limit_count: int, - group_columns: list[str], + group_columns: list[ColumnExpr], aggregate_measures: list[AggregateMeasure], projection_assignments: list[ProjectionAssignment], ) -> int: @@ -91,6 +102,7 @@ pub def adopt_cursor_subgraph( # Inputs are remapped first so structural equality is evaluated in the target store's id space. remapped_input_ids = remap_input_ids(source_node.input_ids, id_map) + # Dedup is intentionally local to the nodes adopted during this pass. reused_adopted_id = _find_existing_adopted_node_id( target_store_nodes, source_node, @@ -135,6 +147,8 @@ pub def adopt_cursor_subgraph( id_map[source_node_id] = adopted_id adopted_node_ids.append(adopted_id) adopted_node_count += 1 + + # The source tip is guaranteed reachable, so its remapped target-store id is the adopted tip for the caller. return PrismStoreAdoption(adopted_tip_id=id_map[tip_id], adopted_node_count=adopted_node_count) @@ -150,6 +164,7 @@ pub def reachable_node_ids(store_id: PrismStoreId, tip_id: int) -> list[int]: mut visited: list[bool] = [false for _node in store_nodes_for_id] mut pending: list[int] = [tip_id] mut cursor = 0 + # Walk outward through inputs from the tip so later passes only see the live authored slice. while cursor < len(pending): node_id = pending[cursor] cursor += 1 @@ -159,6 +174,8 @@ pub def reachable_node_ids(store_id: PrismStoreId, tip_id: int) -> list[int]: node = store_nodes_for_id[node_id] for input_id in node.input_ids: pending.append(input_id) + + # Return reachable node ids in store order for stable iteration in later passes. ordered = [node.node_id for node in store_nodes_for_id if visited[node.node_id]] return ordered @@ -197,21 +214,21 @@ def _nodes_structurally_equal(candidate: PrismNode, source_node: PrismNode, rema return false if candidate.join_predicate != source_node.join_predicate: return false - if not _filter_predicates_structurally_equal( - candidate.filter_predicate.clone(), - source_node.filter_predicate.clone(), - ): + if not _filter_predicates_structurally_equal(candidate.filter_predicate, source_node.filter_predicate): return false if candidate.limit_count != source_node.limit_count: return false - if candidate.group_columns != source_node.group_columns: + if not _column_expr_lists_structurally_equal(candidate.group_columns, source_node.group_columns): return false if len(candidate.aggregate_measures) != len(source_node.aggregate_measures): return false for idx in range(len(candidate.aggregate_measures)): if candidate.aggregate_measures[idx].kind != source_node.aggregate_measures[idx].kind: return false - if candidate.aggregate_measures[idx].column_name != source_node.aggregate_measures[idx].column_name: + if not _column_exprs_structurally_equal( + candidate.aggregate_measures[idx].expr, + source_node.aggregate_measures[idx].expr, + ): return false if not _projection_assignments_structurally_equal( candidate.projection_assignments, @@ -221,21 +238,9 @@ def _nodes_structurally_equal(candidate: PrismNode, source_node: PrismNode, rema return candidate.input_ids == remapped_input_ids -def _filter_predicates_structurally_equal(left: FilterPredicate, right: FilterPredicate) -> bool: - """Return whether two filter predicates are structurally equivalent.""" - if left.kind != right.kind: - return false - if left.column_name != right.column_name: - return false - if left.literal.kind != right.literal.kind: - return false - if left.literal.kind == FilterLiteralKind.Int: - return left.literal.int_value == right.literal.int_value - if left.literal.kind == FilterLiteralKind.String: - return left.literal.string_value == right.literal.string_value - if left.literal.kind == FilterLiteralKind.Bool: - return left.literal.bool_value == right.literal.bool_value - return true +def _filter_predicates_structurally_equal(left: ColumnExpr, right: ColumnExpr) -> bool: + """Return whether two filter scalar expressions are structurally equivalent.""" + return _column_exprs_structurally_equal(left, right) def _projection_assignments_structurally_equal( @@ -253,28 +258,77 @@ def _projection_assignments_structurally_equal( return true -def _column_exprs_structurally_equal(left: ColumnExpr, right: ColumnExpr) -> bool: - """Return whether two projection expressions are structurally equivalent.""" - if left.kind != right.kind: - return false - if left.column_name != right.column_name: - return false - if left.int_value != right.int_value: - return false - if left.float_value != right.float_value: - return false - if left.string_value != right.string_value: - return false - if left.bool_value != right.bool_value: - return false - if len(left.arguments) != len(right.arguments): +def _column_expr_lists_structurally_equal(left: list[ColumnExpr], right: list[ColumnExpr]) -> bool: + """Return whether two scalar-expression lists are structurally equivalent.""" + if len(left) != len(right): return false - for idx in range(len(left.arguments)): - if not _column_exprs_structurally_equal(left.arguments[idx], right.arguments[idx]): + for idx in range(len(left)): + if not _column_exprs_structurally_equal(left[idx], right[idx]): return false return true +def _column_exprs_structurally_equal(left: ColumnExpr, right: ColumnExpr) -> bool: + """Return whether two projection expressions are structurally equivalent.""" + match left: + ColumnRefExpr(left_column) => + match right: + ColumnRefExpr(right_column) => return left_column.column_name == right_column.column_name + _ => return false + IntLiteralExpr(left_literal) => + match right: + IntLiteralExpr(right_literal) => return left_literal.value == right_literal.value + _ => return false + FloatLiteralExpr(left_literal) => + match right: + FloatLiteralExpr(right_literal) => return left_literal.value == right_literal.value + _ => return false + StringLiteralExpr(left_literal) => + match right: + StringLiteralExpr(right_literal) => return left_literal.value == right_literal.value + _ => return false + BoolLiteralExpr(left_literal) => + match right: + BoolLiteralExpr(right_literal) => return left_literal.value == right_literal.value + _ => return false + AddExpr(left_add) => + match right: + AddExpr(right_add) => + if len(left_add.arguments) != 2 or len(right_add.arguments) != 2: + return false + if not _column_exprs_structurally_equal(left_add.arguments[0], right_add.arguments[0]): + return false + return _column_exprs_structurally_equal(left_add.arguments[1], right_add.arguments[1]) + _ => return false + MultiplyExpr(left_multiply) => + match right: + MultiplyExpr(right_multiply) => + if len(left_multiply.arguments) != 2 or len(right_multiply.arguments) != 2: + return false + if not _column_exprs_structurally_equal(left_multiply.arguments[0], right_multiply.arguments[0]): + return false + return _column_exprs_structurally_equal(left_multiply.arguments[1], right_multiply.arguments[1]) + _ => return false + EqExpr(left_eq) => + match right: + EqExpr(right_eq) => + if len(left_eq.arguments) != 2 or len(right_eq.arguments) != 2: + return false + if not _column_exprs_structurally_equal(left_eq.arguments[0], right_eq.arguments[0]): + return false + return _column_exprs_structurally_equal(left_eq.arguments[1], right_eq.arguments[1]) + _ => return false + GtExpr(left_gt) => + match right: + GtExpr(right_gt) => + if len(left_gt.arguments) != 2 or len(right_gt.arguments) != 2: + return false + if not _column_exprs_structurally_equal(left_gt.arguments[0], right_gt.arguments[0]): + return false + return _column_exprs_structurally_equal(left_gt.arguments[1], right_gt.arguments[1]) + _ => return false + + def _latest_store_next_node_id(store_id: PrismStoreId) -> int: """ Return next store-local node id for one store. @@ -287,11 +341,4 @@ def _latest_store_next_node_id(store_id: PrismStoreId) -> int: candidate = prism_store_node_counts[idx] if candidate.store_id_raw == store_id.0: return candidate.next_node_id - - # Dedup is intentionally local to this adoption pass. We do not attempt global memoization across stores yet. - return 0 - # The source tip is guaranteed reachable, so its remapped target-store id is the adopted tip for the caller. - # Walk backward through inputs from the tip so later lowering/rewrite steps only see the live authored slice. - # Compute reachable node ids in store order for stable iteration in later passes. - # Only scan nodes appended during this adoption pass; older target-store nodes are out of scope for this slice. diff --git a/src/prism/types.incn b/src/prism/types.incn index 9e39e2c..138b35f 100644 --- a/src/prism/types.incn +++ b/src/prism/types.incn @@ -1,26 +1,25 @@ """Shared Prism types that define the internal planning substrate contract.""" from aggregate_builders import AggregateMeasure -from filter_builders import FilterPredicate -from projection_builders import ProjectionAssignment +from projection_builders import ColumnExpr, ProjectionAssignment pub type PrismStoreId = newtype int @derive(Clone) -pub enum PrismNodeKind: +pub enum PrismNodeKind(str): """Logical node kinds Prism currently supports in authored and rewritten views.""" - ReadNamedTable - Filter - Join - Project - GroupBy - Aggregate - OrderBy - Limit - Explode + ReadNamedTable = "ReadNamedTable" + Filter = "Filter" + Join = "Join" + Project = "Project" + GroupBy = "GroupBy" + Aggregate = "Aggregate" + OrderBy = "OrderBy" + Limit = "Limit" + Explode = "Explode" @derive(Clone) @@ -37,9 +36,9 @@ pub model PrismNode: pub input_ids: list[int] pub named_table: str pub join_predicate: bool - pub filter_predicate: FilterPredicate + pub filter_predicate: ColumnExpr pub limit_count: int - pub group_columns: list[str] + pub group_columns: list[ColumnExpr] pub aggregate_measures: list[AggregateMeasure] pub projection_assignments: list[ProjectionAssignment] @@ -76,13 +75,4 @@ pub model PrismRewriteExplain: pub def authored_node_kind_name(node: PrismNode) -> str: """Render one stable human-readable node-kind name for tests and diagnostics.""" - match node.kind: - PrismNodeKind.ReadNamedTable => return str("ReadNamedTable") - PrismNodeKind.Filter => return str("Filter") - PrismNodeKind.Join => return str("Join") - PrismNodeKind.Project => return str("Project") - PrismNodeKind.GroupBy => return str("GroupBy") - PrismNodeKind.Aggregate => return str("Aggregate") - PrismNodeKind.OrderBy => return str("OrderBy") - PrismNodeKind.Limit => return str("Limit") - PrismNodeKind.Explode => return str("Explode") + return node.kind.value() diff --git a/src/projection_builders.incn b/src/projection_builders.incn index d7837e0..a5b5000 100644 --- a/src/projection_builders.incn +++ b/src/projection_builders.incn @@ -1,40 +1,95 @@ """ -Explicit projection-expression builder surface for the current InQL-only projection slice. +Explicit scalar-expression builder surface for the current InQL-only row-expression slice. -These builders are the semantic target for future compiler sugar such as `.amount * 1.2` and query-block projection -lowering. Today they provide a library-owned way to express derived-column intent without requiring parser or -typechecker changes in the Incan compiler. +These builders are the semantic target for future compiler sugar such as `.amount * 1.2`, predicate expressions, +grouping keys, aggregate inputs, and query-block lowering. Today they provide a library-owned way to express row-level +scalar intent without requiring parser or typechecker changes in the Incan compiler. """ -import std.testing +from rust::incan_stdlib::errors import raise_value_error @derive(Clone) -pub enum ColumnExprKind: - """Supported expression kinds for the current projection-builder surface.""" +pub enum ColumnExprKind(str): + """Supported expression kinds for the current scalar-expression surface.""" - Column - IntLiteral - FloatLiteral - StringLiteral - BoolLiteral - Add - Multiply + Column = "column" + IntLiteral = "int_literal" + FloatLiteral = "float_literal" + StringLiteral = "string_literal" + BoolLiteral = "bool_literal" + Add = "add" + Multiply = "multiply" + Eq = "eq" + Gt = "gt" @derive(Clone) -pub model ColumnExpr: - """Projection expression description carried through dataset, Prism, and Substrait boundaries.""" +pub model ColumnRefExpr: + """Named column reference expression.""" - pub kind: ColumnExprKind pub column_name: str - pub int_value: int - pub float_value: float - pub string_value: str - pub bool_value: bool + + +@derive(Clone) +pub model IntLiteralExpr: + """Integer literal scalar expression.""" + + pub value: int + + +@derive(Clone) +pub model FloatLiteralExpr: + """Float literal scalar expression.""" + + pub value: float + + +@derive(Clone) +pub model StringLiteralExpr: + """String literal scalar expression.""" + + pub value: str + + +@derive(Clone) +pub model BoolLiteralExpr: + """Boolean literal scalar expression.""" + + pub value: bool + + +@derive(Clone) +pub model AddExpr: + """Binary addition scalar expression.""" + + pub arguments: list[ColumnExpr] + + +@derive(Clone) +pub model MultiplyExpr: + """Binary multiplication scalar expression.""" + pub arguments: list[ColumnExpr] +@derive(Clone) +pub model EqExpr: + """Binary equality scalar expression.""" + + pub arguments: list[ColumnExpr] + + +@derive(Clone) +pub model GtExpr: + """Binary greater-than scalar expression.""" + + pub arguments: list[ColumnExpr] + + +pub type ColumnExpr = Union[ColumnRefExpr, IntLiteralExpr, FloatLiteralExpr, StringLiteralExpr, BoolLiteralExpr, AddExpr, MultiplyExpr, EqExpr, GtExpr] + + @derive(Clone) pub model ProjectionAssignment: """One ordered add-or-replace projection assignment used by `with_column(...)` and merged project nodes.""" @@ -45,93 +100,56 @@ pub model ProjectionAssignment: pub def col(name: str) -> ColumnExpr: """Build one named column reference expression.""" - return ColumnExpr( - kind=ColumnExprKind.Column, - column_name=name, - int_value=0, - float_value=0.0, - string_value="", - bool_value=false, - arguments=[], - ) + return ColumnRefExpr(column_name=name) pub def int_expr(value: int) -> ColumnExpr: """Build one integer literal expression.""" - return ColumnExpr( - kind=ColumnExprKind.IntLiteral, - column_name="", - int_value=value, - float_value=0.0, - string_value="", - bool_value=false, - arguments=[], - ) + return IntLiteralExpr(value=value) pub def float_expr(value: float) -> ColumnExpr: """Build one float literal expression.""" - return ColumnExpr( - kind=ColumnExprKind.FloatLiteral, - column_name="", - int_value=0, - float_value=value, - string_value="", - bool_value=false, - arguments=[], - ) + return FloatLiteralExpr(value=value) pub def str_expr(value: str) -> ColumnExpr: """Build one string literal expression.""" - return ColumnExpr( - kind=ColumnExprKind.StringLiteral, - column_name="", - int_value=0, - float_value=0.0, - string_value=value, - bool_value=false, - arguments=[], - ) + return StringLiteralExpr(value=value) pub def bool_expr(value: bool) -> ColumnExpr: """Build one boolean literal expression.""" - return ColumnExpr( - kind=ColumnExprKind.BoolLiteral, - column_name="", - int_value=0, - float_value=0.0, - string_value="", - bool_value=value, - arguments=[], - ) + return BoolLiteralExpr(value=value) + + +pub def lit(value: Union[int, float, str, bool]) -> ColumnExpr: + """Build one canonical scalar literal expression.""" + match value: + bool(flag) => return bool_expr(flag) + int(number) => return int_expr(number) + float(number) => return float_expr(number) + str(text) => return str_expr(text) pub def add(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: """Build one binary addition expression.""" - return ColumnExpr( - kind=ColumnExprKind.Add, - column_name="", - int_value=0, - float_value=0.0, - string_value="", - bool_value=false, - arguments=[left, right], - ) + return AddExpr(arguments=[left, right]) pub def mul(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: """Build one binary multiply expression.""" - return ColumnExpr( - kind=ColumnExprKind.Multiply, - column_name="", - int_value=0, - float_value=0.0, - string_value="", - bool_value=false, - arguments=[left, right], - ) + return MultiplyExpr(arguments=[left, right]) + + +pub def eq(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """Build one equality predicate scalar expression.""" + return EqExpr(arguments=[left, right]) + + +pub def gt(left: ColumnExpr, right: ColumnExpr) -> ColumnExpr: + """Build one greater-than predicate scalar expression.""" + return GtExpr(arguments=[left, right]) pub def with_column_assignment(name: str, expr: ColumnExpr) -> ProjectionAssignment: @@ -141,15 +159,46 @@ pub def with_column_assignment(name: str, expr: ColumnExpr) -> ProjectionAssignm pub def column_expr_name(expr: ColumnExpr) -> str: """Return the referenced column name when the expression is a direct column reference, otherwise empty.""" - if expr.kind == ColumnExprKind.Column: - return expr.column_name - return "" + match expr: + ColumnRefExpr(column) => return column.column_name + _ => return "" pub def require_column_expr_name(expr: ColumnExpr, context: str) -> str: """Return one direct column name or fail with a clear message for unsupported expression shapes.""" - assert expr.kind == ColumnExprKind.Column, f"{context} currently requires a direct col(\"name\") reference" - return expr.column_name + match expr: + ColumnRefExpr(column) => return column.column_name + _ => + message = f"{context} currently requires a direct col(\"name\") reference" + return raise_value_error(message) + + +pub def scalar_expr_output_name(expr: ColumnExpr, fallback: str) -> str: + """Return a stable output name for one scalar expression in positions that expose grouping output columns.""" + match expr: + ColumnRefExpr(column) => return column.column_name + _ => return fallback + + +pub def column_expr_kind(expr: ColumnExpr) -> ColumnExprKind: + """Return the public expression-kind label for one scalar expression.""" + match expr: + ColumnRefExpr(_) => return ColumnExprKind.Column + IntLiteralExpr(_) => return ColumnExprKind.IntLiteral + FloatLiteralExpr(_) => return ColumnExprKind.FloatLiteral + StringLiteralExpr(_) => return ColumnExprKind.StringLiteral + BoolLiteralExpr(_) => return ColumnExprKind.BoolLiteral + AddExpr(_) => return ColumnExprKind.Add + MultiplyExpr(_) => return ColumnExprKind.Multiply + EqExpr(_) => return ColumnExprKind.Eq + GtExpr(_) => return ColumnExprKind.Gt + + +pub def is_bool_literal_expr(expr: ColumnExpr, expected: bool) -> bool: + """Return whether one scalar expression is the requested boolean literal.""" + match expr: + BoolLiteralExpr(literal) => return literal.value == expected + _ => return false pub def project_output_columns(input_columns: list[str], assignments: list[ProjectionAssignment]) -> list[str]: @@ -158,13 +207,10 @@ pub def project_output_columns(input_columns: list[str], assignments: list[Proje output_columns.extend(input_columns) for assignment in assignments: existing_idx = _index_of_column(output_columns, assignment.output_name) - # TODO(#121): remove this explicit owned-string materialization once v0.3 fixes the borrowed-loop-item field - # path structurally instead of case by case. - output_name = str(assignment.output_name) if existing_idx >= 0: - output_columns[existing_idx] = output_name + output_columns[existing_idx] = assignment.output_name else: - output_columns.append(output_name) + output_columns.append(assignment.output_name) return output_columns diff --git a/src/session/active.incn b/src/session/active.incn index d04483b..0f685d8 100644 --- a/src/session/active.incn +++ b/src/session/active.incn @@ -8,6 +8,13 @@ object to keep ownership boundaries explicit across session submodules. from backends import BackendSelection, TableSource, default_backend_selection +@derive(Clone) +pub enum ActiveSessionErrorKind(str): + """Stable categories for active-session lookup failures.""" + + NoActiveSession = "no_active_session" + + @derive(Clone) pub class ActiveRegistration: pub logical_name: str @@ -29,7 +36,7 @@ pub class ActiveSessionState: pub model ActiveSessionError: - pub kind: str + pub kind: ActiveSessionErrorKind pub message: str @@ -63,7 +70,7 @@ pub def get_active_session_state() -> Result[ActiveSessionState, ActiveSessionEr latest_has_active = flag found_flag = true if found_flag is false or latest_has_active is false: - return Err(ActiveSessionError(kind="no_active_session", message=err_msg)) + return Err(ActiveSessionError(kind=ActiveSessionErrorKind.NoActiveSession, message=err_msg)) # State snapshots are also append-only. # The latest entry is the active snapshot payload associated with the most recent lifecycle marker. @@ -73,6 +80,6 @@ pub def get_active_session_state() -> Result[ActiveSessionState, ActiveSessionEr active = candidate.clone() found_state = true if found_state is false: - return Err(ActiveSessionError(kind="no_active_session", message=err_msg)) + return Err(ActiveSessionError(kind=ActiveSessionErrorKind.NoActiveSession, message=err_msg)) return Ok(active) diff --git a/src/session/csv_schema.incn b/src/session/csv_schema.incn index f27ee78..f878ca3 100644 --- a/src/session/csv_schema.incn +++ b/src/session/csv_schema.incn @@ -1,6 +1,6 @@ """CSV schema inference helpers for Session read paths.""" -from rust::std::fs import read_to_string +from std.fs import Path from session.errors import SessionError, invalid_registration from substrait.schema import RowColumnSpec, SubstraitPrimitiveKind @@ -11,10 +11,9 @@ pub def infer_csv_schema_columns(uri: str) -> Result[list[RowColumnSpec], Sessio # It only looks at header presence, empty cells, and coarse primitive-kind merging. # Revisit this once the fuller CSV ingestion contract lands so dialect rules, quoting/escaping, # richer typing, and stricter schema/header policies stop living behind this lightweight helper. - match read_to_string(uri): + match Path(uri).read_text("utf-8", "strict"): Ok(content) => - normalized = str(content).replace(str("\r"), str("")) - lines = normalized.split(str("\n")) + lines = content.split("\n") if len(lines) == 0: return Err(invalid_registration(f"csv source '{uri}' is empty")) header = lines[0].strip() @@ -31,7 +30,7 @@ pub def infer_csv_schema_columns(uri: str) -> Result[list[RowColumnSpec], Sessio line = raw_line.strip() if len(line) == 0: continue - cells = line.split(str(",")) + cells = line.split(",") for col_idx in range(len(columns)): if col_idx >= len(cells): nullable[col_idx] = true @@ -53,12 +52,12 @@ pub def infer_csv_schema_columns(uri: str) -> Result[list[RowColumnSpec], Sessio kind = SubstraitPrimitiveKind.String specs.append(RowColumnSpec(name=columns[spec_idx], kind=kind, nullable=nullable[spec_idx])) return Ok(specs) - Err(err) => return Err(invalid_registration(f"failed to read csv source '{uri}': {err.to_string()}")) + Err(err) => return Err(invalid_registration(f"failed to read csv source '{uri}': {err.message()}")) def _non_empty_csv_header_columns(header: str) -> list[str]: """Parse one CSV header row into non-empty trimmed column names.""" - return [raw_name.strip() for raw_name in header.split(str(",")) if len(raw_name.strip()) > 0] + return [raw_name.strip() for raw_name in header.split(",") if len(raw_name.strip()) > 0] def _infer_primitive_kind(value: str) -> SubstraitPrimitiveKind: @@ -81,7 +80,10 @@ def _merge_primitive_kind(current: SubstraitPrimitiveKind, incoming: SubstraitPr return SubstraitPrimitiveKind.String if current == incoming: return current - if (current == SubstraitPrimitiveKind.I64 and incoming == SubstraitPrimitiveKind.F64) or (current == SubstraitPrimitiveKind.F64 and incoming == SubstraitPrimitiveKind.I64): + if (current == SubstraitPrimitiveKind.I64 and incoming == SubstraitPrimitiveKind.F64) or ( + current == SubstraitPrimitiveKind.F64 + and incoming == SubstraitPrimitiveKind.I64 + ): return SubstraitPrimitiveKind.F64 return SubstraitPrimitiveKind.String @@ -91,7 +93,7 @@ def _is_integer_like(value: str) -> bool: if len(value) == 0: return false mut start = 0 - if value[0] == str("-"): + if value[0] == "-": if len(value) == 1: return false start = 1 @@ -103,7 +105,7 @@ def _is_integer_like(value: str) -> bool: def _strip_csv_quotes(value: str) -> str: """Strip one pair of surrounding double quotes from a CSV cell value.""" - if len(value) >= 2 and value[0] == str("\"") and value[len(value) - 1] == str("\""): + if len(value) >= 2 and value[0] == "\"" and value[len(value) - 1] == "\"": return value[1:len(value) - 1] return value @@ -120,7 +122,7 @@ def _is_float_like(value: str) -> bool: return false # Optional leading sign. mut start_idx = 0 - if value[0] == str("-"): + if value[0] == "-": if len(value) == 1: return false start_idx = 1 @@ -128,7 +130,7 @@ def _is_float_like(value: str) -> bool: mut seen_dot = false for i in range(start_idx, len(value)): ch = value[i] - if ch == str("."): + if ch == ".": if seen_dot: return false seen_dot = true @@ -140,9 +142,9 @@ def _is_float_like(value: str) -> bool: def _is_ascii_digit(ch: str) -> bool: """Return whether one single-character string is an ASCII decimal digit.""" - return ch >= str("0") and ch <= str("9") + return ch >= "0" and ch <= "9" def _is_timestamp_like(value: str) -> bool: """Return whether a value looks like an ISO-like timestamp token.""" - return value.contains(str("T")) and value.contains(str(":")) + return value.contains("T") and value.contains(":") diff --git a/src/session/datafusion_backend.incn b/src/session/datafusion_backend.incn index fb032a9..41c8d33 100644 --- a/src/session/datafusion_backend.incn +++ b/src/session/datafusion_backend.incn @@ -15,6 +15,16 @@ from dataset.materialization import DataFrameMaterialization from substrait.inspect import root_names +@derive(Clone) +pub enum BackendErrorKind(str): + """Stable categories for DataFusion backend failures.""" + + BackendPlanningError = "backend_planning_error" + BackendExecutionError = "backend_execution_error" + BackendSinkError = "backend_sink_error" + BackendRegistrationError = "backend_registration_error" + + @derive(Clone) pub class BackendRegistration: pub logical_name: str @@ -26,11 +36,11 @@ pub class BackendRegistration: pub model BackendError: - pub kind: str + pub kind: BackendErrorKind pub message: str -def _backend_error(kind: str, message: str) -> BackendError: +def _backend_error(kind: BackendErrorKind, message: str) -> BackendError: """Build one backend error with stable kind/message fields.""" return BackendError(kind=kind, message=message) @@ -49,11 +59,11 @@ pub async def datafusion_execute_async( Ok(logical_plan) => match await ctx.execute_logical_plan(logical_plan): Ok(df) => match await df.collect(): Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error("backend_execution_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error("backend_execution_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error("backend_planning_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) pub async def datafusion_collect_materialization_async( @@ -74,7 +84,7 @@ pub async def datafusion_collect_materialization_async( Ok(rendered) => mut row_count = 0 for batch in batches: - row_count += _rust_usize_to_int(batch.num_rows()) + row_count += _rust_usize_to_int(batch.num_rows())? return Ok( DataFrameMaterialization( resolved_columns=resolved_columns, @@ -82,13 +92,13 @@ pub async def datafusion_collect_materialization_async( preview_text=rendered, ), ) - Err(err) => return Err(_backend_error("backend_execution_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error("backend_execution_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error("backend_execution_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error("backend_planning_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) pub async def datafusion_write_csv_async( @@ -106,11 +116,11 @@ pub async def datafusion_write_csv_async( Ok(logical_plan) => match await ctx.execute_logical_plan(logical_plan): Ok(df) => match await df.write_csv(uri, DataFrameWriteOptions.new(), None): Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error("backend_sink_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendSinkError, err.to_string())) - Err(err) => return Err(_backend_error("backend_execution_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error("backend_planning_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) pub async def datafusion_write_parquet_async( @@ -128,11 +138,11 @@ pub async def datafusion_write_parquet_async( Ok(logical_plan) => match await ctx.execute_logical_plan(logical_plan): Ok(df) => match await df.write_parquet(uri, DataFrameWriteOptions.new(), None): Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error("backend_sink_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendSinkError, err.to_string())) - Err(err) => return Err(_backend_error("backend_execution_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, err.to_string())) - Err(err) => return Err(_backend_error("backend_planning_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) async def _register_sources( @@ -147,22 +157,22 @@ async def _register_sources( async def _register_one(ctx: SessionContext, logical_name: str, source: TableSource) -> Result[None, BackendError]: """Register one logical source by kind on the DataFusion context.""" - if source.source_kind == SourceKind.Csv: - csv_opts = CsvReadOptions.new().has_header(true) - match await ctx.register_csv(logical_name, source.uri, csv_opts): - Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error("backend_registration_error", err.to_string())) - - if source.source_kind == SourceKind.Parquet: - parquet_opts = ParquetReadOptions.default() - match await ctx.register_parquet(logical_name, source.uri, parquet_opts): - Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error("backend_registration_error", err.to_string())) - - arrow_opts = ArrowReadOptions.default() - match await ctx.register_arrow(logical_name, source.uri, arrow_opts): - Ok(_) => return Ok(None) - Err(err) => return Err(_backend_error("backend_registration_error", err.to_string())) + match source.source_kind: + SourceKind.Csv => + csv_opts = CsvReadOptions.new().has_header(true) + match await ctx.register_csv(logical_name, source.uri, csv_opts): + Ok(_) => return Ok(None) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendRegistrationError, err.to_string())) + SourceKind.Parquet => + parquet_opts = ParquetReadOptions.default() + match await ctx.register_parquet(logical_name, source.uri, parquet_opts): + Ok(_) => return Ok(None) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendRegistrationError, err.to_string())) + SourceKind.Arrow => + arrow_opts = ArrowReadOptions.default() + match await ctx.register_arrow(logical_name, source.uri, arrow_opts): + Ok(_) => return Ok(None) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendRegistrationError, err.to_string())) def _consumer_plan_from_current_plan(plan: Plan) -> Result[ConsumerPlan, BackendError]: @@ -170,11 +180,11 @@ def _consumer_plan_from_current_plan(plan: Plan) -> Result[ConsumerPlan, Backend encoded = plan.encode_to_vec() match ConsumerPlan.decode(encoded.as_slice()): Ok(decoded) => return Ok(decoded) - Err(err) => return Err(_backend_error("backend_planning_error", err.to_string())) + Err(err) => return Err(_backend_error(BackendErrorKind.BackendPlanningError, err.to_string())) -def _rust_usize_to_int(value: RustUsize) -> int: +def _rust_usize_to_int(value: RustUsize) -> Result[int, BackendError]: """Convert one Rust `usize` count into Incan `int` through a checked `i64` boundary.""" match RustI64.try_from(value): - Ok(converted) => return converted.into() - Err(_) => return 0 + Ok(converted) => return Ok(converted.into()) + Err(_) => return Err(_backend_error(BackendErrorKind.BackendExecutionError, "row count does not fit Incan int")) diff --git a/src/session/domain.incn b/src/session/domain.incn index 395f6e7..e5935c7 100644 --- a/src/session/domain.incn +++ b/src/session/domain.incn @@ -1,16 +1,14 @@ """Session domain newtypes and internal enums.""" -from session.errors import SessionError, invalid_registration, invalid_sink - @derive(Clone) pub type LogicalName = newtype str: """Validated logical table name used by Session registration and lookup paths.""" - def from_underlying(value: str) -> Result[LogicalName, SessionError]: + def from_underlying(value: str) -> Result[LogicalName, ValidationError]: """Validate and construct one logical table-name value.""" if len(value) == 0: - return Err(invalid_registration("logical_name must not be empty")) + return Err(ValidationError("logical_name must not be empty")) return Ok(LogicalName(value)) @@ -18,10 +16,10 @@ pub type LogicalName = newtype str: pub type SourceUri = newtype str: """Validated source URI used by Session registration/read paths.""" - def from_underlying(value: str) -> Result[SourceUri, SessionError]: + def from_underlying(value: str) -> Result[SourceUri, ValidationError]: """Validate and construct one source URI value.""" if len(value) == 0: - return Err(invalid_registration("source uri must not be empty")) + return Err(ValidationError("source uri must not be empty")) return Ok(SourceUri(value)) @@ -29,14 +27,14 @@ pub type SourceUri = newtype str: pub type SinkUri = newtype str: """Validated sink URI used by Session write paths.""" - def from_underlying(value: str) -> Result[SinkUri, SessionError]: + def from_underlying(value: str) -> Result[SinkUri, ValidationError]: """Validate and construct one sink URI value.""" if len(value) == 0: - return Err(invalid_sink("sink uri must not be empty")) + return Err(ValidationError("sink uri must not be empty")) return Ok(SinkUri(value)) @derive(Clone) -pub enum SinkKind: - Csv - Parquet +pub enum SinkKind(str): + Csv = "csv" + Parquet = "parquet" diff --git a/src/session/errors.incn b/src/session/errors.incn index 4723b1c..bb65c54 100644 --- a/src/session/errors.incn +++ b/src/session/errors.incn @@ -1,15 +1,35 @@ """Session error model and diagnostics helpers.""" +@derive(Clone) +pub enum SessionErrorKind(str): + """Stable categories for Session-facing failures.""" + + InvalidRegistration = "invalid_registration" + InvalidSink = "invalid_sink" + DuplicateRegistration = "duplicate_registration" + UnknownTable = "unknown_table" + NoActiveSession = "no_active_session" + RuntimeInitError = "runtime_init_error" + UnsupportedBackend = "unsupported_backend" + BackendPlanningError = "backend_planning_error" + BackendExecutionError = "backend_execution_error" + BackendSinkError = "backend_sink_error" + BackendRegistrationError = "backend_registration_error" + InvalidScalarExpression = "invalid_scalar_expression" + UnknownScalarColumn = "unknown_scalar_column" + UnsupportedScalarExpression = "unsupported_scalar_expression" + + pub model SessionError: """Typed error envelope for Session-facing APIs.""" - pub kind: str + pub kind: SessionErrorKind pub message: str def error_message(self) -> str: """Return one stable `kind: message` string for diagnostics.""" - return f"{self.kind}: {self.message}" + return f"{self.kind.value()}: {self.message}" pub def format_session_diagnostic(stage: str, err: SessionError) -> str: @@ -24,9 +44,9 @@ pub def report_session_error(stage: str, err: SessionError) -> None: pub def invalid_registration(message: str) -> SessionError: """Create one stable `invalid_registration` Session error.""" - return SessionError(kind=str("invalid_registration"), message=message) + return SessionError(kind=SessionErrorKind.InvalidRegistration, message=message) pub def invalid_sink(message: str) -> SessionError: """Create one stable `invalid_sink` Session error.""" - return SessionError(kind=str("invalid_sink"), message=message) + return SessionError(kind=SessionErrorKind.InvalidSink, message=message) diff --git a/src/session/mod.incn b/src/session/mod.incn index a635e45..51e0a01 100644 --- a/src/session/mod.incn +++ b/src/session/mod.incn @@ -11,4 +11,4 @@ Module layout: """ pub from session.types import Session, SessionBuilder -pub from session.errors import SessionError, format_session_diagnostic, report_session_error +pub from session.errors import SessionError, SessionErrorKind, format_session_diagnostic, report_session_error diff --git a/src/session/types.incn b/src/session/types.incn index 5d4b7f7..1b20703 100644 --- a/src/session/types.incn +++ b/src/session/types.incn @@ -7,6 +7,7 @@ from backends import ( BackendKind, BackendSelection, DataFusion, + SourceKind, TableSource, backend_kind_name, source_kind_name, @@ -16,12 +17,14 @@ from backends import ( parquet_source, ) from dataset import DataFrame, LazyFrame, lazy_frame_named_table -pub from session.errors import SessionError, format_session_diagnostic, report_session_error +pub from session.errors import SessionError, SessionErrorKind, format_session_diagnostic, report_session_error +from session.errors import invalid_registration, invalid_sink from session.domain import LogicalName, SourceUri, SinkUri, SinkKind from session.csv_schema import infer_csv_schema_columns from session.active import ( ActiveRegistration, ActiveSessionError, + ActiveSessionErrorKind, ActiveSessionState, clear_active_session_state, get_active_session_state, @@ -29,12 +32,15 @@ from session.active import ( ) from session.datafusion_backend import ( BackendError, + BackendErrorKind, BackendRegistration, datafusion_collect_materialization_async, datafusion_execute_async, datafusion_write_csv_async, datafusion_write_parquet_async, ) +from substrait.schema import RowColumnSpec +from substrait.errors import SubstraitLoweringError, SubstraitLoweringErrorKind from substrait.schema_registry import register_named_table_schema from substrait.inspect import root_rel, read_named_table_name @@ -99,52 +105,53 @@ pub class Session: def register(mut self, logical_name: str, source: TableSource) -> Result[None, SessionError]: """Register one logical name to a portable table source descriptor.""" - name = LogicalName.from_underlying(logical_name)? - _ = SourceUri.from_underlying(source.uri)? + name = _logical_name_from_text(logical_name)? + _ = _source_uri_from_text(source.uri)? + _ensure_registration_available(self._registrations, name.0)? + columns = _schema_columns_for_source(source)? self._register_parsed(name, source)? + register_named_table_schema(name.0, columns) return Ok(None) def _register_parsed(mut self, logical_name: LogicalName, source: TableSource) -> Result[None, SessionError]: """Register one source after boundary values have been validated into typed wrappers.""" - if _find_registration_index(self._registrations, logical_name.0) >= 0: - return Err( - SessionError(kind="duplicate_registration", message=f"table '{logical_name.0}' is already registered"), - ) + _ensure_registration_available(self._registrations, logical_name.0)? self._registrations.append(SessionRegistration(logical_name=logical_name.0, source=source)) return Ok(None) def read_csv[T with Clone](mut self, logical_name: str, uri: str) -> Result[LazyFrame[T], SessionError]: """Convenience read path: register one CSV source then return one lazy named-table plan.""" - name = LogicalName.from_underlying(logical_name)? - source_uri = SourceUri.from_underlying(uri)? - self._register_parsed(name, csv_source(source_uri.0))? + name = _logical_name_from_text(logical_name)? + source_uri = _source_uri_from_text(uri)? + _ensure_registration_available(self._registrations, name.0)? columns = infer_csv_schema_columns(source_uri.0)? + self._register_parsed(name, csv_source(source_uri.0))? # TODO(InQL RFC 004): Formalize logical-name schema binding as an explicit catalog/snapshot model and surface overwrite diagnostics. register_named_table_schema(name.0, columns) - return self.table[T](name.0) + return self.table(name.0) def read_parquet[T with Clone](mut self, logical_name: str, uri: str) -> Result[LazyFrame[T], SessionError]: """Convenience read path: register one Parquet source then return one lazy named-table plan.""" - name = LogicalName.from_underlying(logical_name)? - source_uri = SourceUri.from_underlying(uri)? + name = _logical_name_from_text(logical_name)? + source_uri = _source_uri_from_text(uri)? self._register_parsed(name, parquet_source(source_uri.0))? register_named_table_schema(name.0, []) - return self.table[T](name.0) + return self.table(name.0) def read_arrow[T with Clone](mut self, logical_name: str, uri: str) -> Result[LazyFrame[T], SessionError]: """Convenience read path: register one Arrow IPC source then return one lazy named-table plan.""" - name = LogicalName.from_underlying(logical_name)? - source_uri = SourceUri.from_underlying(uri)? + name = _logical_name_from_text(logical_name)? + source_uri = _source_uri_from_text(uri)? self._register_parsed(name, arrow_source(source_uri.0))? register_named_table_schema(name.0, []) - return self.table[T](name.0) + return self.table(name.0) def table[T with Clone](self, logical_name: str) -> Result[LazyFrame[T], SessionError]: """Return one lazy named-table carrier for a previously registered logical name.""" - name = LogicalName.from_underlying(logical_name)? + name = _logical_name_from_text(logical_name)? if _find_registration_index(self._registrations, name.0) < 0: - return Err(SessionError(kind="unknown_table", message=f"table '{name.0}' is not registered")) - return Ok(lazy_frame_named_table[T](name.0)) + return Err(SessionError(kind=SessionErrorKind.UnknownTable, message=f"table '{name.0}' is not registered")) + return Ok(lazy_frame_named_table(name.0)) def activate(self) -> None: """Set this Session as the current active Session used by LazyFrame.collect().""" @@ -165,7 +172,7 @@ pub class Session: def execute[T with Clone](self, data: LazyFrame[T]) -> Result[LazyFrame[T], SessionError]: """Validate and execute one lazy plan while preserving deferred carrier shape.""" _ensure_datafusion_backend(self._backend)? - plan = data.to_substrait_plan() + plan = _plan_from_lazy_frame(data)? _validate_named_table_binding(self._registrations, plan)? match block_on(datafusion_execute_async(_to_backend_registrations(self._registrations), plan)): @@ -173,12 +180,13 @@ pub class Session: match exec_result: Ok(_) => return Ok(data) Err(err) => return Err(_session_error_from_backend_error(err)) - Err(err) => return Err(SessionError(kind="runtime_init_error", message=err.message())) + Err(err) => return Err(SessionError(kind=SessionErrorKind.RuntimeInitError, message=err.message())) def collect[T with Clone](self, data: LazyFrame[T]) -> Result[DataFrame[T], SessionError]: """Validate and execute one lazy plan, returning a structured materialized DataFrame.""" _ensure_datafusion_backend(self._backend)? - plan = data.to_substrait_plan() + plan = _plan_from_lazy_frame(data)? + rel = root_rel(plan.clone()) _validate_named_table_binding(self._registrations, plan)? match block_on(datafusion_collect_materialization_async(_to_backend_registrations(self._registrations), plan)): @@ -186,34 +194,34 @@ pub class Session: match collect_result: Ok(materialization) => return Ok( - DataFrame[T]( - _type_witness=_empty_type_witness[T](), + DataFrame( + _type_witness=_empty_type_witness(), _materialization=materialization, - _substrait_rel=root_rel(data.to_substrait_plan()), + _substrait_rel=rel, ), ) Err(err) => return Err(_session_error_from_backend_error(err)) - Err(err) => return Err(SessionError(kind="runtime_init_error", message=err.message())) + Err(err) => return Err(SessionError(kind=SessionErrorKind.RuntimeInitError, message=err.message())) def write_csv[T with Clone](self, data: LazyFrame[T], uri: str) -> Result[None, SessionError]: """Execute one lazy plan and write result rows to a CSV sink URI.""" - return self._write_plan_to_sink(data.to_substrait_plan(), uri, SinkKind.Csv) + return self._write_plan_to_sink(_plan_from_lazy_frame(data)?, uri, SinkKind.Csv) def write_parquet[T with Clone](self, data: LazyFrame[T], uri: str) -> Result[None, SessionError]: """Execute one lazy plan and write result rows to a Parquet sink URI.""" - return self._write_plan_to_sink(data.to_substrait_plan(), uri, SinkKind.Parquet) + return self._write_plan_to_sink(_plan_from_lazy_frame(data)?, uri, SinkKind.Parquet) def write_csv_df[T with Clone](self, data: DataFrame[T], uri: str) -> Result[None, SessionError]: """Write one materialized DataFrame through the Session CSV sink path.""" - return self._write_plan_to_sink(data.to_substrait_plan(), uri, SinkKind.Csv) + return self._write_plan_to_sink(_plan_from_data_frame(data)?, uri, SinkKind.Csv) def write_parquet_df[T with Clone](self, data: DataFrame[T], uri: str) -> Result[None, SessionError]: """Write one materialized DataFrame through the Session Parquet sink path.""" - return self._write_plan_to_sink(data.to_substrait_plan(), uri, SinkKind.Parquet) + return self._write_plan_to_sink(_plan_from_data_frame(data)?, uri, SinkKind.Parquet) def _write_plan_to_sink(self, plan: Plan, uri: str, sink_kind: SinkKind) -> Result[None, SessionError]: """Run one validated plan through the selected sink writer and normalize runtime/backend errors.""" - sink_uri = SinkUri.from_underlying(uri)? + sink_uri = _sink_uri_from_text(uri)? _ensure_datafusion_backend(self._backend)? _validate_named_table_binding(self._registrations, plan)? registrations = _to_backend_registrations(self._registrations) @@ -222,11 +230,11 @@ pub class Session: SinkKind.Csv => match block_on(datafusion_write_csv_async(registrations, plan, sink_uri.0)): Ok(write_result) => return _write_result_from_backend_result(write_result) - Err(err) => return Err(SessionError(kind="runtime_init_error", message=err.message())) + Err(err) => return Err(SessionError(kind=SessionErrorKind.RuntimeInitError, message=err.message())) SinkKind.Parquet => match block_on(datafusion_write_parquet_async(registrations, plan, sink_uri.0)): Ok(write_result) => return _write_result_from_backend_result(write_result) - Err(err) => return Err(SessionError(kind="runtime_init_error", message=err.message())) + Err(err) => return Err(SessionError(kind=SessionErrorKind.RuntimeInitError, message=err.message())) pub class SessionBuilder: @@ -257,7 +265,9 @@ def _session_from_active_state(state: ActiveSessionState) -> Session: def _session_error_from_active_error(err: ActiveSessionError) -> SessionError: """Translate active-session errors into public SessionError values.""" - return SessionError(kind=err.kind, message=err.message) + match err.kind: + ActiveSessionErrorKind.NoActiveSession => + return SessionError(kind=SessionErrorKind.NoActiveSession, message=err.message) def _to_backend_registrations(registrations: list[SessionRegistration]) -> list[BackendRegistration]: @@ -267,14 +277,97 @@ def _to_backend_registrations(registrations: list[SessionRegistration]) -> list[ def _session_error_from_backend_error(err: BackendError) -> SessionError: """Translate backend errors into public SessionError values.""" - return SessionError(kind=err.kind, message=err.message) + match err.kind: + BackendErrorKind.BackendPlanningError => + return SessionError(kind=SessionErrorKind.BackendPlanningError, message=err.message) + BackendErrorKind.BackendExecutionError => + return SessionError(kind=SessionErrorKind.BackendExecutionError, message=err.message) + BackendErrorKind.BackendSinkError => + return SessionError(kind=SessionErrorKind.BackendSinkError, message=err.message) + BackendErrorKind.BackendRegistrationError => + return SessionError(kind=SessionErrorKind.BackendRegistrationError, message=err.message) + + +def _session_error_from_lowering_error(err: SubstraitLoweringError) -> SessionError: + """Translate Substrait lowering errors into public SessionError values.""" + match err.kind: + SubstraitLoweringErrorKind.InvalidScalarExpression => + return SessionError(kind=SessionErrorKind.InvalidScalarExpression, message=err.message) + SubstraitLoweringErrorKind.UnknownScalarColumn => + return SessionError(kind=SessionErrorKind.UnknownScalarColumn, message=err.message) + SubstraitLoweringErrorKind.UnsupportedScalarExpression => + return SessionError(kind=SessionErrorKind.UnsupportedScalarExpression, message=err.message) + + +def _duplicate_registration_error(logical_name: str) -> SessionError: + """Build the stable duplicate-registration error for one logical name.""" + return SessionError( + kind=SessionErrorKind.DuplicateRegistration, + message=f"table '{logical_name}' is already registered", + ) + + +def _ensure_registration_available( + registrations: list[SessionRegistration], + logical_name: str, +) -> Result[None, SessionError]: + """Ensure one logical name is not already bound in the Session registry.""" + if _find_registration_index(registrations, logical_name) >= 0: + return Err(_duplicate_registration_error(logical_name)) + return Ok(None) + + +def _schema_columns_for_source(source: TableSource) -> Result[list[RowColumnSpec], SessionError]: + """Return planned schema columns that can be established from a source descriptor.""" + match source.source_kind: + SourceKind.Csv => return infer_csv_schema_columns(source.uri) + SourceKind.Parquet => return Ok(_empty_schema_columns()) + SourceKind.Arrow => return Ok(_empty_schema_columns()) + + +def _empty_schema_columns() -> list[RowColumnSpec]: + """Return an empty schema binding for sources whose planned schema is not inferred yet.""" + return [] + + +def _plan_from_lazy_frame[T with Clone](data: LazyFrame[T]) -> Result[Plan, SessionError]: + """Lower one lazy frame to a plan while preserving structured planning errors.""" + return data.try_to_substrait_plan().map_err(_session_error_from_lowering_error) + + +def _plan_from_data_frame[T with Clone](data: DataFrame[T]) -> Result[Plan, SessionError]: + """Return one materialized frame plan through the same structured planning-error envelope.""" + return data.try_to_substrait_plan().map_err(_session_error_from_lowering_error) def _write_result_from_backend_result(write_result: Result[None, BackendError]) -> Result[None, SessionError]: """Normalize one backend write result into the Session-facing write result contract.""" - match write_result: - Ok(_) => return Ok(None) - Err(err) => return Err(_session_error_from_backend_error(err)) + return write_result.map_err(_session_error_from_backend_error) + + +def _invalid_registration_from_validation_error(err: ValidationError) -> SessionError: + """Translate a validated input failure into a Session registration error.""" + return invalid_registration(err.to_string()) + + +def _invalid_sink_from_validation_error(err: ValidationError) -> SessionError: + """Translate a validated sink failure into a Session sink error.""" + return invalid_sink(err.to_string()) + + +def _logical_name_from_text(logical_name: str) -> Result[LogicalName, SessionError]: + """Validate one logical name and normalize validation failures for Session callers.""" + return LogicalName.from_underlying(logical_name).map_err(_invalid_registration_from_validation_error) + + +def _source_uri_from_text(uri: str) -> Result[SourceUri, SessionError]: + """Validate one source URI and normalize validation failures for Session callers.""" + return SourceUri.from_underlying(uri).map_err(_invalid_registration_from_validation_error) + + +def _sink_uri_from_text(uri: str) -> Result[SinkUri, SessionError]: + """Validate one sink URI and normalize validation failures for Session callers.""" + return SinkUri.from_underlying(uri).map_err(_invalid_sink_from_validation_error) def _ensure_datafusion_backend(selection: BackendSelection) -> Result[None, SessionError]: @@ -282,7 +375,10 @@ def _ensure_datafusion_backend(selection: BackendSelection) -> Result[None, Sess if selection.kind == BackendKind.DataFusionEngine: return Ok(None) return Err( - SessionError(kind="unsupported_backend", message=f"backend '{backend_kind_name(selection)}' is not implemented"), + SessionError( + kind=SessionErrorKind.UnsupportedBackend, + message=f"backend '{backend_kind_name(selection)}' is not implemented", + ), ) @@ -303,7 +399,7 @@ def _validate_named_table_binding(registrations: list[SessionRegistration], plan if len(table_name) == 0: return Ok(None) if _find_registration_index(registrations, table_name) < 0: - return Err(SessionError(kind="unknown_table", message=f"table '{table_name}' is not registered")) + return Err(SessionError(kind=SessionErrorKind.UnknownTable, message=f"table '{table_name}' is not registered")) return Ok(None) diff --git a/src/substrait/errors.incn b/src/substrait/errors.incn new file mode 100644 index 0000000..9c63070 --- /dev/null +++ b/src/substrait/errors.incn @@ -0,0 +1,39 @@ +"""Structured errors for Substrait relation and expression lowering.""" + + +@derive(Clone) +pub enum SubstraitLoweringErrorKind(str): + """Stable categories for Substrait lowering failures.""" + + InvalidScalarExpression = "invalid_scalar_expression" + UnknownScalarColumn = "unknown_scalar_column" + UnsupportedScalarExpression = "unsupported_scalar_expression" + + +pub model SubstraitLoweringError: + """Typed error envelope for failures while lowering author-facing builders into Substrait.""" + + pub kind: SubstraitLoweringErrorKind + pub message: str + + def error_message(self) -> str: + """Return one stable `kind: message` string for diagnostics.""" + return f"{self.kind.value()}: {self.message}" + + +pub def invalid_scalar_expression(message: str) -> SubstraitLoweringError: + """Create one scalar-expression validation error.""" + return SubstraitLoweringError(kind=SubstraitLoweringErrorKind.InvalidScalarExpression, message=message) + + +pub def unknown_scalar_column(name: str) -> SubstraitLoweringError: + """Create one unknown-column scalar-expression error.""" + return SubstraitLoweringError( + kind=SubstraitLoweringErrorKind.UnknownScalarColumn, + message=f"unknown scalar expression column '{name}'", + ) + + +pub def unsupported_scalar_expression(message: str) -> SubstraitLoweringError: + """Create one unsupported scalar-expression validation error.""" + return SubstraitLoweringError(kind=SubstraitLoweringErrorKind.UnsupportedScalarExpression, message=message) diff --git a/src/substrait/expr_lowering.incn b/src/substrait/expr_lowering.incn index cf38ac3..f6a76a9 100644 --- a/src/substrait/expr_lowering.incn +++ b/src/substrait/expr_lowering.incn @@ -1,14 +1,26 @@ """ Expression lowering helpers for the current Substrait surface. -This module owns builder-to-Substrait expression conversion for projection expressions, filter predicates, and the -small numeric/index conversion helpers needed by proto fields along that path. +This module owns builder-to-Substrait expression conversion for scalar expressions and the small numeric/index +conversion helpers needed by proto fields along that path. """ +from rust::incan_stdlib::errors import raise_value_error from rust::std::boxed import Box from rust::std::primitive import i32 as RustI32, u32 as RustU32 -from filter_builders import FilterLiteral, FilterLiteralKind, FilterPredicate, FilterPredicateKind -from projection_builders import ColumnExpr, ColumnExprKind, ProjectionAssignment +from projection_builders import ( + AddExpr, + BoolLiteralExpr, + ColumnExpr, + ColumnRefExpr, + EqExpr, + FloatLiteralExpr, + GtExpr, + IntLiteralExpr, + MultiplyExpr, + ProjectionAssignment, + StringLiteralExpr, +) from rust::substrait::proto import Expression, FunctionArgument, ProjectRel from rust::substrait::proto::expression import FieldReference, Literal, ReferenceSegment, RexType, ScalarFunction from rust::substrait::proto::expression::field_reference import ( @@ -20,6 +32,7 @@ from rust::substrait::proto::expression::literal import LiteralType from rust::substrait::proto::expression::reference_segment import ReferenceType as SegmentReferenceType, StructField from rust::substrait::proto::function_argument import ArgType from rust::substrait::proto::rel_common import EmitKind +from substrait.errors import SubstraitLoweringError, invalid_scalar_expression, unknown_scalar_column from substrait.extensions import ( ADD_FUNCTION_ANCHOR, EQUAL_FUNCTION_ANCHOR, @@ -103,22 +116,14 @@ pub def field_ref_expr(index: int) -> Expression: ) -pub def filter_literal_expr(literal: FilterLiteral) -> Expression: - """Build one literal expression from a filter literal builder value.""" - match literal.kind: - FilterLiteralKind.Int => return i64_expr(literal.int_value) - FilterLiteralKind.String => return string_expr(literal.string_value) - FilterLiteralKind.Bool => return bool_expr(literal.bool_value) - - -pub def scalar_function_expr(function_reference: int, arguments: list[Expression]) -> Expression: +pub def scalar_function_expr(function_reference: u32, arguments: list[Expression]) -> Expression: """Build one scalar-function expression over positional value arguments.""" lowered_arguments = [FunctionArgument(arg_type=Some(ArgType.Value(argument))) for argument in arguments] return Expression( rex_type=Some( RexType.ScalarFunction( ScalarFunction( - function_reference=to_rust_u32(function_reference), + function_reference=function_reference, arguments=lowered_arguments, output_type=None, args=[], @@ -129,6 +134,7 @@ pub def scalar_function_expr(function_reference: int, arguments: list[Expression ) +@derive(Clone) model ResolvedProjectionBinding: """One output-column binding mapped to a lowered Substrait expression relative to the current project input.""" @@ -137,6 +143,13 @@ model ResolvedProjectionBinding: output_index: int +@derive(Clone) +model ResolvedProjectExpression: + """One projected expression payload kept appendable before extracting the proto expression list.""" + + expr: Expression + + pub model LoweredProject: """One lowered ProjectRel payload containing explicit expressions plus emit mapping.""" @@ -173,43 +186,81 @@ pub def index_of_column(columns: list[str], name: str) -> int: return _index_of_name(columns, name) -def _resolved_projection_expr(bindings: list[ResolvedProjectionBinding], expr: ColumnExpr) -> Expression: - """Lower one projection expression against the current resolved binding set.""" - if expr.kind == ColumnExprKind.Column: - binding_idx = _index_of_binding(bindings, expr.column_name) - if binding_idx >= 0: - return bindings[binding_idx].expr.clone() - return bool_expr(false) - if expr.kind == ColumnExprKind.IntLiteral: - return i64_expr(expr.int_value) - if expr.kind == ColumnExprKind.FloatLiteral: - return f64_expr(expr.float_value) - if expr.kind == ColumnExprKind.StringLiteral: - return string_expr(expr.string_value) - if expr.kind == ColumnExprKind.BoolLiteral: - return bool_expr(expr.bool_value) - mut argument_count = 0 - for _argument in expr.arguments: - argument_count += 1 - if argument_count < 2: - return bool_expr(false) - left = _resolved_projection_expr(bindings, expr.arguments[0]) - right = _resolved_projection_expr(bindings, expr.arguments[1]) - if expr.kind == ColumnExprKind.Multiply: - return scalar_function_expr(MULTIPLY_FUNCTION_ANCHOR, [left, right]) - return scalar_function_expr(ADD_FUNCTION_ANCHOR, [left, right]) +def _resolved_projection_expr( + bindings: list[ResolvedProjectionBinding], + expr: ColumnExpr, +) -> Result[Expression, SubstraitLoweringError]: + """Lower one scalar expression against the current resolved binding set.""" + match expr: + ColumnRefExpr(column) => + binding_idx = _index_of_binding(bindings, column.column_name) + if binding_idx >= 0: + return Ok(bindings[binding_idx].expr.clone()) + return Err(unknown_scalar_column(column.column_name)) + IntLiteralExpr(literal) => return Ok(i64_expr(literal.value)) + FloatLiteralExpr(literal) => return Ok(f64_expr(literal.value)) + StringLiteralExpr(literal) => return Ok(string_expr(literal.value)) + BoolLiteralExpr(literal) => return Ok(bool_expr(literal.value)) + AddExpr(add_expr) => + if len(add_expr.arguments) != 2: + return Err(invalid_scalar_expression("binary scalar expression requires exactly two arguments")) + return _resolved_binary_scalar_function_expr( + bindings, + add_expr.arguments[0], + add_expr.arguments[1], + ADD_FUNCTION_ANCHOR, + ) + MultiplyExpr(mul_expr) => + if len(mul_expr.arguments) != 2: + return Err(invalid_scalar_expression("binary scalar expression requires exactly two arguments")) + return _resolved_binary_scalar_function_expr( + bindings, + mul_expr.arguments[0], + mul_expr.arguments[1], + MULTIPLY_FUNCTION_ANCHOR, + ) + EqExpr(eq_expr) => + if len(eq_expr.arguments) != 2: + return Err(invalid_scalar_expression("binary scalar expression requires exactly two arguments")) + return _resolved_binary_scalar_function_expr( + bindings, + eq_expr.arguments[0], + eq_expr.arguments[1], + EQUAL_FUNCTION_ANCHOR, + ) + GtExpr(gt_expr) => + if len(gt_expr.arguments) != 2: + return Err(invalid_scalar_expression("binary scalar expression requires exactly two arguments")) + return _resolved_binary_scalar_function_expr( + bindings, + gt_expr.arguments[0], + gt_expr.arguments[1], + GT_FUNCTION_ANCHOR, + ) + + +def _resolved_binary_scalar_function_expr( + bindings: list[ResolvedProjectionBinding], + left_expr: ColumnExpr, + right_expr: ColumnExpr, + function_anchor: u32, +) -> Result[Expression, SubstraitLoweringError]: + """Lower one binary scalar-function expression against resolved bindings.""" + left = _resolved_projection_expr(bindings, left_expr.clone())? + right = _resolved_projection_expr(bindings, right_expr.clone())? + return Ok(scalar_function_expr(function_anchor, [left, right])) def _apply_projection_assignments_to_bindings( input_columns: list[str], assignments: list[ProjectionAssignment], -) -> list[ResolvedProjectionBinding]: +) -> Result[list[ResolvedProjectionBinding], SubstraitLoweringError]: """Apply ordered add-or-replace assignments to the current binding set.""" input_column_count = len(input_columns) mut bindings = _resolved_projection_bindings_for_input_columns(input_columns) mut next_output_index = input_column_count for assignment in assignments: - resolved_expr = _resolved_projection_expr(bindings, assignment.expr) + resolved_expr = _resolved_projection_expr(bindings, assignment.expr)? existing_idx = _index_of_binding(bindings, assignment.output_name) if existing_idx >= 0: bindings[existing_idx] = ResolvedProjectionBinding( @@ -226,18 +277,21 @@ def _apply_projection_assignments_to_bindings( ), ) next_output_index += 1 - return bindings + return Ok(bindings) -pub def lower_project(input_columns: list[str], assignments: list[ProjectionAssignment]) -> LoweredProject: +pub def lower_project( + input_columns: list[str], + assignments: list[ProjectionAssignment], +) -> Result[LoweredProject, SubstraitLoweringError]: """Lower one add-or-replace projection into explicit project expressions plus emit mapping.""" input_column_count = len(input_columns) mut bindings = _resolved_projection_bindings_for_input_columns(input_columns) - mut expressions: list[Expression] = [] + mut expressions: list[ResolvedProjectExpression] = [] mut mapping_indexes: list[int] = [] for assignment in assignments: - resolved_expr = _resolved_projection_expr(bindings, assignment.expr) - expressions.append(resolved_expr.clone()) + resolved_expr = _resolved_projection_expr(bindings, assignment.expr)? + expressions.append(ResolvedProjectExpression(expr=resolved_expr.clone())) next_output_index = input_column_count + len(expressions) - 1 existing_idx = _index_of_binding(bindings, assignment.output_name) if existing_idx >= 0: @@ -256,34 +310,41 @@ pub def lower_project(input_columns: list[str], assignments: list[ProjectionAssi ) for binding in bindings: mapping_indexes.append(binding.output_index) - return LoweredProject(expressions=expressions, output_mapping=[to_rust_i32(index) for index in mapping_indexes]) - - -pub def filter_predicate_expr(input_columns: list[str], predicate: FilterPredicate) -> Expression: - """Lower one filter predicate builder into a Substrait expression.""" - # Short-circuit the always-true and always-false cases to avoid unnecessary column lookups and function calls. - if predicate.kind == FilterPredicateKind.AlwaysTrue: - return bool_expr(true) - if predicate.kind == FilterPredicateKind.AlwaysFalse: - return bool_expr(false) - # When the column is absent, return an always-false condition to filter out all rows. - column_index = index_of_column(input_columns, predicate.column_name) - if column_index < 0: - return bool_expr(false) - # match the predicate kind to the appropriate function call over the column reference and literal value. - left = field_ref_expr(column_index) - right = filter_literal_expr(predicate.literal) - if predicate.kind == FilterPredicateKind.Gt: - return scalar_function_expr(GT_FUNCTION_ANCHOR, [left, right]) - else: - return scalar_function_expr(EQUAL_FUNCTION_ANCHOR, [left, right]) + return Ok( + LoweredProject( + expressions=[resolved.expr.clone() for resolved in expressions], + output_mapping=[to_rust_i32(index) for index in mapping_indexes], + ), + ) + + +pub def scalar_expr(input_columns: list[str], expr: ColumnExpr) -> Result[Expression, SubstraitLoweringError]: + """Lower one scalar-expression builder against an input-column list.""" + return _resolved_projection_expr(_resolved_projection_bindings_for_input_columns(input_columns), expr) + + +pub def filter_predicate_expr( + input_columns: list[str], + predicate: ColumnExpr, +) -> Result[Expression, SubstraitLoweringError]: + """Lower one filter predicate through the shared scalar-expression path.""" + return scalar_expr(input_columns, predicate) pub def to_rust_i32(value: int) -> RustI32: """Convert one Incan `int` into Rust `i32` for proto fields that require 32-bit indices.""" - match RustI32.try_from(value): + match try_to_rust_i32(value): Ok(converted) => return converted - Err(_) => return 0 + Err(err) => + message = err.error_message() + return raise_value_error(message) + + +pub def try_to_rust_i32(value: int) -> Result[RustI32, SubstraitLoweringError]: + """Fallibly convert one Incan `int` into Rust `i32` for proto fields.""" + match RustI32.try_from(value): + Ok(converted) => return Ok(converted) + Err(_) => return Err(invalid_scalar_expression(f"value {value} does not fit Rust i32")) pub def rust_i32_to_int(value: RustI32) -> int: @@ -298,32 +359,28 @@ pub def rust_u32_to_int(value: RustU32) -> int: pub def to_rust_u32(value: int) -> RustU32: """Convert one non-negative Incan `int` into Rust `u32` for proto index vectors.""" - match RustU32.try_from(value): + match try_to_rust_u32(value): Ok(converted) => return converted - Err(_) => return 0 + Err(err) => + message = err.error_message() + return raise_value_error(message) -def _field_index_from_reference_segment(segment: ReferenceSegment) -> int: - """Extract one struct-field index from a direct reference segment, or `-1` when the segment is not a field.""" - match segment.reference_type: - Some(SegmentReferenceType.StructField(field)) => return (field.field).into() - _ => return -1 - - -def _field_index_from_field_reference(field_ref: FieldReference) -> int: - """Extract one struct-field index from a direct field reference, or `-1` when the reference is not direct.""" - match field_ref.reference_type: - Some(FieldReferenceType.DirectReference(segment)) => return _field_index_from_reference_segment(segment) - _ => return -1 +pub def try_to_rust_u32(value: int) -> Result[RustU32, SubstraitLoweringError]: + """Fallibly convert one non-negative Incan `int` into Rust `u32` for proto index vectors.""" + match RustU32.try_from(value): + Ok(converted) => return Ok(converted) + Err(_) => return Err(invalid_scalar_expression(f"value {value} does not fit Rust u32")) -# TODO(Incan RFC 049): revisit this direct-reference unwrapping path once `if let` lands. The helper split keeps the -# nested proto walk local today, but `if let` should let the selection/direct-reference happy path read much flatter. pub def field_index_from_expression(expr: Expression) -> int: """Extract the selected field index from one direct field-reference expression, or `-1` when absent.""" - match expr.rex_type: - Some(RexType.Selection(field_ref)) => return _field_index_from_field_reference(field_ref.as_ref().clone()) - _ => return -1 + if let Some(RexType.Selection(field_ref)) = expr.rex_type: + field_reference = field_ref.as_ref().clone() + if let Some(FieldReferenceType.DirectReference(segment)) = field_reference.reference_type: + if let Some(SegmentReferenceType.StructField(field)) = segment.reference_type: + return (field.field).into() + return -1 pub def grouping_reference_indexes(count: int) -> list[RustU32]: @@ -341,7 +398,7 @@ pub def project_output_columns_from_expressions(input_columns: list[str], expres continue match expr.rex_type.clone(): Some(RexType.ScalarFunction(fun)) => - columns.append(scalar_function_name_from_anchor(rust_u32_to_int(fun.function_reference))) + columns.append(scalar_function_name_from_anchor(fun.function_reference)) Some(RexType.Literal(literal)) => match literal.literal_type: Some(LiteralType.String(_)) => columns.append("string") @@ -372,6 +429,9 @@ pub def project_output_columns_with_emit(input_columns: list[str], project_rel: match project_rel.common.clone(): Some(common) => match common.emit_kind: - Some(EmitKind.Emit(emit)) => return _emitted_output_columns(default_columns, emit.output_mapping) + Some(emit_kind) => + match emit_kind: + EmitKind.Emit(emit) => return _emitted_output_columns(default_columns, emit.output_mapping) + _ => return default_columns _ => return default_columns _ => return default_columns diff --git a/src/substrait/extensions.incn b/src/substrait/extensions.incn index ad89d5f..a7e01e6 100644 --- a/src/substrait/extensions.incn +++ b/src/substrait/extensions.incn @@ -5,25 +5,50 @@ This module owns stable function anchors, extension URIs, anchor-name mapping, a collection over relation and expression trees. """ +from rust::incan_stdlib::errors import raise_value_error from rust::substrait::proto import Expression, Rel from rust::substrait::proto::extensions import SimpleExtensionDeclaration, SimpleExtensionUrn from rust::substrait::proto::extensions::simple_extension_declaration import ExtensionFunction, MappingType from rust::substrait::proto::function_argument import ArgType from rust::substrait::proto::expression import RexType from rust::substrait::proto::rel import RelType -from aggregate_builders import AggregateKind -from substrait.expr_lowering import to_rust_u32 - -# TODO(Incan RFC 009): use builtin `u32` constants for these Substrait function anchors once sized numeric -# types are available, then remove the `to_rust_u32(...)` conversion calls at proto field boundaries. -pub const SUM_FUNCTION_ANCHOR: int = 0 -pub const COUNT_FUNCTION_ANCHOR: int = 1 -pub const EQUAL_FUNCTION_ANCHOR: int = 2 -pub const GT_FUNCTION_ANCHOR: int = 3 -pub const ADD_FUNCTION_ANCHOR: int = 4 -pub const MULTIPLY_FUNCTION_ANCHOR: int = 5 -const FUNCTION_EXTENSION_URN_ANCHOR: int = 0 -const RELATION_EXTENSION_URN_ANCHOR: int = 1 +from substrait.errors import SubstraitLoweringError, invalid_scalar_expression +from substrait.traversal import relation_children + + +@derive(Clone) +model ExtensionUrnSpec: + """One Clone-able extension URN descriptor used before constructing Rust proto payloads.""" + + anchor: u32 + urn: str + + +@derive(Clone) +enum ExtensionFunctionKind(str): + """Function extension categories used by the current Substrait extension registry.""" + + Aggregate = "aggregate" + Scalar = "scalar" + + +@derive(Clone) +model FunctionExtensionSpec: + """One function-extension anchor/name/kind fact.""" + + anchor: u32 + name: str + kind: ExtensionFunctionKind + + +pub const SUM_FUNCTION_ANCHOR: u32 = 0 +pub const COUNT_FUNCTION_ANCHOR: u32 = 1 +pub const EQUAL_FUNCTION_ANCHOR: u32 = 2 +pub const GT_FUNCTION_ANCHOR: u32 = 3 +pub const ADD_FUNCTION_ANCHOR: u32 = 4 +pub const MULTIPLY_FUNCTION_ANCHOR: u32 = 5 +const FUNCTION_EXTENSION_URN_ANCHOR: u32 = 0 +const RELATION_EXTENSION_URN_ANCHOR: u32 = 1 const FUNCTION_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/functions.yaml" const EXPLODE_EXTENSION_URI: str = "https://inql.io/extensions/v0.1/unnest.yaml#explode" @@ -43,168 +68,143 @@ pub def registered_substrait_extension_uris() -> list[str]: return [FUNCTION_EXTENSION_URI, EXPLODE_EXTENSION_URI] -pub def aggregate_function_name_from_anchor(anchor: int) -> str: +pub def aggregate_function_name_from_anchor(anchor: u32) -> str: """Resolve one known aggregate-function anchor back to its registered function name.""" - if anchor == COUNT_FUNCTION_ANCHOR: - return "count" - return "sum" + return _function_name_from_anchor_or_raise(anchor, ExtensionFunctionKind.Aggregate) -pub def scalar_function_name_from_anchor(anchor: int) -> str: +pub def scalar_function_name_from_anchor(anchor: u32) -> str: """Resolve one known scalar-function anchor back to its registered function name.""" - if anchor == ADD_FUNCTION_ANCHOR: - return "add" - if anchor == MULTIPLY_FUNCTION_ANCHOR: - return "multiply" - if anchor == GT_FUNCTION_ANCHOR: - return "gt" - return "equal" - - -def _aggregate_extension_decl(anchor: int) -> SimpleExtensionDeclaration: - """Build one aggregate-function extension declaration for the provided anchor.""" + return _function_name_from_anchor_or_raise(anchor, ExtensionFunctionKind.Scalar) + + +def _function_extension_specs() -> list[FunctionExtensionSpec]: + """Return all known function-extension specs in stable declaration order.""" + return [FunctionExtensionSpec(anchor=SUM_FUNCTION_ANCHOR, name="sum", kind=ExtensionFunctionKind.Aggregate), FunctionExtensionSpec( + anchor=COUNT_FUNCTION_ANCHOR, + name="count", + kind=ExtensionFunctionKind.Aggregate, + ), FunctionExtensionSpec(anchor=EQUAL_FUNCTION_ANCHOR, name="equal", kind=ExtensionFunctionKind.Scalar), FunctionExtensionSpec( + anchor=GT_FUNCTION_ANCHOR, + name="gt", + kind=ExtensionFunctionKind.Scalar, + ), FunctionExtensionSpec(anchor=ADD_FUNCTION_ANCHOR, name="add", kind=ExtensionFunctionKind.Scalar), FunctionExtensionSpec( + anchor=MULTIPLY_FUNCTION_ANCHOR, + name="multiply", + kind=ExtensionFunctionKind.Scalar, + )] + + +def _function_spec_from_anchor(anchor: u32) -> Result[FunctionExtensionSpec, SubstraitLoweringError]: + """Resolve one function anchor to its registry spec.""" + for spec in _function_extension_specs(): + if spec.anchor == anchor: + return Ok(spec) + return Err(invalid_scalar_expression(f"unknown Substrait function anchor {anchor}")) + + +def _function_spec_from_anchor_of_kind( + anchor: u32, + expected_kind: ExtensionFunctionKind, +) -> Result[FunctionExtensionSpec, SubstraitLoweringError]: + """Resolve one function anchor and verify the expected extension kind.""" + spec = _function_spec_from_anchor(anchor)? + if spec.kind == expected_kind: + return Ok(spec) + return Err(invalid_scalar_expression(f"Substrait anchor {anchor} is not a {expected_kind.value()} function")) + + +def _function_name_from_anchor_or_raise(anchor: u32, expected_kind: ExtensionFunctionKind) -> str: + """Resolve one function name or raise a structured diagnostic for infallible inspection helpers.""" + match _function_spec_from_anchor_of_kind(anchor, expected_kind): + Ok(spec) => return spec.name + Err(err) => + message = err.error_message() + return raise_value_error(message) + + +def _function_extension_decl(spec: FunctionExtensionSpec) -> SimpleExtensionDeclaration: + """Build one function-extension declaration from the typed registry spec.""" return SimpleExtensionDeclaration( mapping_type=Some( MappingType.ExtensionFunction( ExtensionFunction( - extension_urn_reference=to_rust_u32(FUNCTION_EXTENSION_URN_ANCHOR), - function_anchor=to_rust_u32(anchor), - name=aggregate_function_name_from_anchor(anchor), + extension_urn_reference=FUNCTION_EXTENSION_URN_ANCHOR, + function_anchor=spec.anchor, + name=spec.name, ), ), ), ) -def _scalar_extension_decl(anchor: int) -> SimpleExtensionDeclaration: +def _aggregate_extension_decl(anchor: u32) -> SimpleExtensionDeclaration: + """Build one aggregate-function extension declaration for the provided anchor.""" + match _function_spec_from_anchor_of_kind(anchor, ExtensionFunctionKind.Aggregate): + Ok(spec) => return _function_extension_decl(spec) + Err(err) => + message = err.error_message() + return raise_value_error(message) + + +def _scalar_extension_decl(anchor: u32) -> SimpleExtensionDeclaration: """Build one scalar-function extension declaration for the provided anchor.""" - return SimpleExtensionDeclaration( - mapping_type=Some( - MappingType.ExtensionFunction( - ExtensionFunction( - extension_urn_reference=to_rust_u32(FUNCTION_EXTENSION_URN_ANCHOR), - function_anchor=to_rust_u32(anchor), - name=scalar_function_name_from_anchor(anchor), - ), - ), - ), - ) + match _function_spec_from_anchor_of_kind(anchor, ExtensionFunctionKind.Scalar): + Ok(spec) => return _function_extension_decl(spec) + Err(err) => + message = err.error_message() + return raise_value_error(message) -def _relation_children(rel: Rel) -> list[Rel]: - """Return direct child relations for one relation node while walking extension usage.""" - match rel.rel_type: - Some(RelType.Filter(filter)) => - match filter.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - Some(RelType.Project(project)) => - match project.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - Some(RelType.Join(join)) => - mut children: list[Rel] = [] - match join.left: - Some(left) => children.append(left.as_ref().clone()) - None => pass - - match join.right: - Some(right) => children.append(right.as_ref().clone()) - None => pass - - return children - Some(RelType.Cross(cross)) => - mut children: list[Rel] = [] - match cross.left: - Some(left) => children.append(left.as_ref().clone()) - None => pass - - match cross.right: - Some(right) => children.append(right.as_ref().clone()) - None => pass - - return children - Some(RelType.Aggregate(aggregate)) => - match aggregate.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - Some(RelType.Sort(sort)) => - match sort.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - Some(RelType.Fetch(fetch)) => - match fetch.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - Some(RelType.Set(set_rel)) => return set_rel.inputs - Some(RelType.ExtensionSingle(extension)) => - match extension.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - _ => return [] - - -def _expr_uses_scalar_function_anchor(expr: Expression, expected_anchor: int) -> bool: +def _expr_uses_scalar_function_anchor(expr: Expression, expected_anchor: u32) -> bool: """Return whether one expression tree uses the requested scalar-function anchor.""" - expected_function_anchor = to_rust_u32(expected_anchor) - match expr.rex_type: - Some(RexType.ScalarFunction(fun)) => - if fun.function_reference == expected_function_anchor: - return true - for argument in fun.arguments: - match argument.arg_type: - Some(ArgType.Value(value)) => - if _expr_uses_scalar_function_anchor(value, expected_anchor): - return true - _ => pass - - return false - _ => return false - - -def _rel_uses_aggregate_function_anchor(rel: Rel, expected_anchor: int) -> bool: + if let Some(RexType.ScalarFunction(fun)) = expr.rex_type: + if fun.function_reference == expected_anchor: + return true + + for argument in fun.arguments: + if let Some(ArgType.Value(value)) = argument.arg_type: + if _expr_uses_scalar_function_anchor(value, expected_anchor): + return true + return false + + +def _rel_uses_aggregate_function_anchor(rel: Rel, expected_anchor: u32) -> bool: """Return whether one relation subtree uses the requested aggregate-function anchor.""" - expected_function_anchor = to_rust_u32(expected_anchor) - match rel.rel_type.clone(): - Some(RelType.Aggregate(aggregate_rel)) => - for measure in aggregate_rel.measures: - match measure.measure: - Some(agg_fn) => - if agg_fn.function_reference == expected_function_anchor: - return true - None => pass - _ => pass + if let Some(RelType.Aggregate(aggregate_rel)) = rel.rel_type.clone(): + for measure in aggregate_rel.measures: + if let Some(agg_fn) = measure.measure: + if agg_fn.function_reference == expected_anchor: + return true - for child in _relation_children(rel): + for child in relation_children(rel): if _rel_uses_aggregate_function_anchor(child, expected_anchor): return true return false -def _rel_uses_scalar_function_anchor(rel: Rel, expected_anchor: int) -> bool: +def _rel_uses_scalar_function_anchor(rel: Rel, expected_anchor: u32) -> bool: """Return whether one relation subtree uses the requested scalar-function anchor.""" match rel.rel_type.clone(): Some(RelType.Filter(filter_rel)) => - match filter_rel.condition: - Some(condition) => - if _expr_uses_scalar_function_anchor(condition.as_ref().clone(), expected_anchor): - return true - None => pass + if let Some(condition) = filter_rel.condition: + if _expr_uses_scalar_function_anchor(condition.as_ref().clone(), expected_anchor): + return true Some(RelType.Project(project_rel)) => for expr in project_rel.expressions: if _expr_uses_scalar_function_anchor(expr, expected_anchor): return true _ => pass - for child in _relation_children(rel): + for child in relation_children(rel): if _rel_uses_scalar_function_anchor(child, expected_anchor): return true return false -def _aggregate_extension_anchors_for_rel(rel: Rel) -> list[int]: +def _aggregate_extension_anchors_for_rel(rel: Rel) -> list[u32]: """Collect aggregate-function anchors used by one relation subtree in stable declaration order.""" - mut anchors: list[int] = [] + mut anchors: list[u32] = [] if _rel_uses_aggregate_function_anchor(rel.clone(), SUM_FUNCTION_ANCHOR): anchors.append(SUM_FUNCTION_ANCHOR) if _rel_uses_aggregate_function_anchor(rel, COUNT_FUNCTION_ANCHOR): @@ -212,9 +212,9 @@ def _aggregate_extension_anchors_for_rel(rel: Rel) -> list[int]: return anchors -def _scalar_extension_anchors_for_rel(rel: Rel) -> list[int]: +def _scalar_extension_anchors_for_rel(rel: Rel) -> list[u32]: """Collect scalar-function anchors used by one relation subtree in stable declaration order.""" - mut anchors: list[int] = [] + mut anchors: list[u32] = [] if _rel_uses_scalar_function_anchor(rel.clone(), EQUAL_FUNCTION_ANCHOR): anchors.append(EQUAL_FUNCTION_ANCHOR) if _rel_uses_scalar_function_anchor(rel.clone(), GT_FUNCTION_ANCHOR): @@ -226,24 +226,26 @@ def _scalar_extension_anchors_for_rel(rel: Rel) -> list[int]: return anchors -def _aggregate_extension_decls(anchors: list[int]) -> list[SimpleExtensionDeclaration]: +def _aggregate_extension_decls(anchors: list[u32]) -> list[SimpleExtensionDeclaration]: """Lower aggregate-function anchors into extension declarations in the provided order.""" return [_aggregate_extension_decl(anchor) for anchor in anchors] -def _scalar_extension_decls(anchors: list[int]) -> list[SimpleExtensionDeclaration]: +def _scalar_extension_decls(anchors: list[u32]) -> list[SimpleExtensionDeclaration]: """Lower scalar-function anchors into extension declarations in the provided order.""" return [_scalar_extension_decl(anchor) for anchor in anchors] -def _extension_decl_for_anchor(anchor: int) -> SimpleExtensionDeclaration: +def _extension_decl_for_anchor(anchor: u32) -> SimpleExtensionDeclaration: """Lower one known aggregate or scalar function anchor into its extension declaration.""" - if anchor == SUM_FUNCTION_ANCHOR or anchor == COUNT_FUNCTION_ANCHOR: - return _aggregate_extension_decl(anchor) - return _scalar_extension_decl(anchor) + match _function_spec_from_anchor(anchor): + Ok(spec) => return _function_extension_decl(spec) + Err(err) => + message = err.error_message() + return raise_value_error(message) -def _plan_extension_anchors_for_rel(rel: Rel) -> list[int]: +def _plan_extension_anchors_for_rel(rel: Rel) -> list[u32]: """Collect aggregate and scalar function anchors used by one relation subtree in stable plan order.""" mut anchors = _aggregate_extension_anchors_for_rel(rel.clone()) anchors.extend(_scalar_extension_anchors_for_rel(rel)) @@ -255,6 +257,15 @@ def _function_extension_urn_is_required(rel: Rel) -> bool: return len(_plan_extension_anchors_for_rel(rel)) > 0 +def _direct_extension_urn_for_rel(rel: Rel) -> Option[str]: + """Return the extension URI carried directly by one extension relation node.""" + if let Some(RelType.ExtensionSingle(extension)) = rel.rel_type: + if let Some(detail) = extension.detail: + if detail.type_url != "": + return Some(detail.type_url) + return None + + pub def aggregate_extensions_for_rel(rel: Rel) -> list[SimpleExtensionDeclaration]: """Collect aggregate-function extension declarations required by one relation subtree.""" return _aggregate_extension_decls(_aggregate_extension_anchors_for_rel(rel)) @@ -273,16 +284,10 @@ pub def plan_extensions_for_rel(rel: Rel) -> list[SimpleExtensionDeclaration]: def _collect_extension_urn_strings(rel: Rel) -> list[str]: """Recursively collect extension URI strings from one relation subtree.""" mut urns: list[str] = [] - match rel.rel_type.clone(): - Some(RelType.ExtensionSingle(extension)) => - match extension.detail: - Some(detail) => - if detail.type_url != "": - urns.append(detail.type_url) - None => pass - _ => pass + if let Some(urn) = _direct_extension_urn_for_rel(rel.clone()): + urns.append(urn) - for child in _relation_children(rel): + for child in relation_children(rel): for urn in _collect_extension_urn_strings(child): urns.append(urn) return urns @@ -290,14 +295,9 @@ def _collect_extension_urn_strings(rel: Rel) -> list[str]: pub def extension_urns_for_rel(rel: Rel) -> list[SimpleExtensionUrn]: """Collect extension URNs required by one relation subtree.""" - mut urns: list[SimpleExtensionUrn] = [] + mut specs: list[ExtensionUrnSpec] = [] if _function_extension_urn_is_required(rel.clone()): - urns.append( - SimpleExtensionUrn( - extension_urn_anchor=to_rust_u32(FUNCTION_EXTENSION_URN_ANCHOR), - urn=FUNCTION_EXTENSION_URI, - ), - ) + specs.append(ExtensionUrnSpec(anchor=FUNCTION_EXTENSION_URN_ANCHOR, urn=FUNCTION_EXTENSION_URI)) for urn in _collect_extension_urn_strings(rel): - urns.append(SimpleExtensionUrn(extension_urn_anchor=to_rust_u32(RELATION_EXTENSION_URN_ANCHOR), urn=urn)) - return urns + specs.append(ExtensionUrnSpec(anchor=RELATION_EXTENSION_URN_ANCHOR, urn=urn)) + return [SimpleExtensionUrn(extension_urn_anchor=spec.anchor, urn=spec.urn) for spec in specs] diff --git a/src/substrait/inspect.incn b/src/substrait/inspect.incn index 62a5cbe..d14b53f 100644 --- a/src/substrait/inspect.incn +++ b/src/substrait/inspect.incn @@ -9,7 +9,6 @@ from rust::std::boxed import Box from rust::std::primitive import i32 as RustI32 from rust::substrait::proto import AggregateRel, Expression, Plan, ReadRel, Rel, RelCommon from rust::substrait::proto::aggregate_rel import Measure -from rust::substrait::proto::expression::literal import LiteralType from rust::substrait::proto::function_argument import ArgType from rust::substrait::proto::plan_rel import RelType as PlanRelType from rust::substrait::proto::read_rel import NamedTable as ReadNamedTable, ReadType @@ -17,9 +16,11 @@ from rust::substrait::proto::rel import RelType from rust::substrait::proto::rel_common import Direct, EmitKind from rust::substrait::proto::set_rel import SetOp from aggregate_builders import AggregateKind, AggregateMeasure -from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit, to_rust_u32 +from projection_builders import scalar_expr_output_name +from substrait.expr_lowering import field_index_from_expression, project_output_columns_with_emit from substrait.extensions import COUNT_FUNCTION_ANCHOR from substrait.schema_registry import named_table_columns, unknown_named_struct +from substrait.traversal import relation_children def _malformed_root_rel() -> Rel: @@ -44,8 +45,8 @@ def _set_op_value(value: SetOp) -> RustI32: def _measure_output_name(measure: AggregateMeasure) -> str: """Return the projected output column name for one aggregate measure.""" if measure.kind == AggregateKind.Count: - return "count" - return "sum_" + measure.column_name + return measure.kind.value() + return measure.kind.value() + "_" + scalar_expr_output_name(measure.expr, "expr") pub def aggregate_measure_output_names(measures: list[AggregateMeasure]) -> list[str]: @@ -72,10 +73,12 @@ def _aggregate_input_columns(aggregate_rel: AggregateRel) -> list[str]: def _grouping_output_columns(input_columns: list[str], grouping_expressions: list[Expression]) -> list[str]: """Resolve grouping output names from field-reference expressions against one input-column list.""" mut columns: list[str] = [] - for expr in grouping_expressions: + for idx, expr in enumerate(grouping_expressions): field_index = field_index_from_expression(expr) if field_index >= 0 and field_index < len(input_columns): columns.append(input_columns[field_index]) + else: + columns.append(f"group_{idx}") return columns @@ -83,16 +86,19 @@ def _aggregate_measure_output_name(input_columns: list[str], measure: Measure) - """Resolve one aggregate output name from the lowered Substrait aggregate measure payload.""" match measure.measure: Some(agg_fn) => - if agg_fn.function_reference == to_rust_u32(COUNT_FUNCTION_ANCHOR): + if agg_fn.function_reference == COUNT_FUNCTION_ANCHOR: return "count" if len(agg_fn.arguments) == 0: return "sum" match agg_fn.arguments[0].arg_type: - Some(ArgType.Value(expr)) => - field_index = field_index_from_expression(expr) - if field_index >= 0 and field_index < len(input_columns): - return "sum_" + input_columns[field_index] - return "sum" + Some(arg_type) => + match arg_type: + ArgType.Value(expr) => + field_index = field_index_from_expression(expr) + if field_index >= 0 and field_index < len(input_columns): + return "sum_" + input_columns[field_index] + return "sum" + _ => return "sum" _ => return "sum" None => return "" @@ -187,11 +193,14 @@ pub def root_rel(plan: Plan) -> Rel: return _malformed_root_rel() relation = plan.relations[0] match relation.rel_type: - Some(PlanRelType.Root(root)) => - match root.input: - Some(rel) => return rel - None => return _malformed_root_rel() - Some(PlanRelType.Rel(rel)) => return rel + Some(rel_type) => + match rel_type: + PlanRelType.Root(root) => + match root.input: + Some(rel) => return rel + None => return _malformed_root_rel() + PlanRelType.Rel(rel) => return rel + _ => return _malformed_root_rel() _ => return _malformed_root_rel() @@ -201,7 +210,10 @@ pub def root_names(plan: Plan) -> list[str]: return [] relation = plan.relations[0] match relation.rel_type: - Some(PlanRelType.Root(root)) => return root.names + Some(rel_type) => + match rel_type: + PlanRelType.Root(root) => return root.names + _ => return [] _ => return [] @@ -288,55 +300,6 @@ pub def reference_subtree_ordinal(rel: Rel) -> RustI32: _ => return -1 -pub def relation_children(rel: Rel) -> list[Rel]: - """Return direct child relations for one relation node.""" - match rel.rel_type: - Some(RelType.Filter(filter)) => - match filter.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - Some(RelType.Project(project)) => - match project.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - Some(RelType.Join(join)) => - mut children: list[Rel] = [] - match join.left: - Some(left) => children.append(left.as_ref().clone()) - None => pass - match join.right: - Some(right) => children.append(right.as_ref().clone()) - None => pass - return children - Some(RelType.Cross(cross)) => - mut children: list[Rel] = [] - match cross.left: - Some(left) => children.append(left.as_ref().clone()) - None => pass - match cross.right: - Some(right) => children.append(right.as_ref().clone()) - None => pass - return children - Some(RelType.Aggregate(aggregate)) => - match aggregate.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - Some(RelType.Sort(sort)) => - match sort.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - Some(RelType.Fetch(fetch)) => - match fetch.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - Some(RelType.Set(set_rel)) => return set_rel.inputs - Some(RelType.ExtensionSingle(extension)) => - match extension.input: - Some(child) => return [child.as_ref().clone()] - None => return [] - _ => return [] - - pub def rel_contains_kind(rel: Rel, expected_kind: str) -> bool: """Return whether one relation subtree contains the requested public kind label.""" if relation_kind_name(rel) == expected_kind: diff --git a/src/substrait/mod.incn b/src/substrait/mod.incn index 55020e1..e11ac2f 100644 --- a/src/substrait/mod.incn +++ b/src/substrait/mod.incn @@ -4,6 +4,7 @@ Substrait support for InQL. The implementation is split into focused owner modules: - `substrait.schema_registry` for named-table schema binding +- `substrait.errors` for structured lowering errors - `substrait.expr_lowering` for builder-to-expression lowering - `substrait.extensions` for extension anchors and declaration collection - `substrait.relations` for concrete `Rel` construction @@ -13,6 +14,7 @@ The implementation is split into focused owner modules: from rust::prost import Message from rust::substrait::proto import Plan +pub from substrait.errors import SubstraitLoweringError, SubstraitLoweringErrorKind pub from substrait.schema_registry import named_table_columns, register_named_table_schema pub from substrait.relations import ( SubstraitJoinKind, diff --git a/src/substrait/plans.incn b/src/substrait/plans.incn index 0cdde58..48a9c46 100644 --- a/src/substrait/plans.incn +++ b/src/substrait/plans.incn @@ -85,7 +85,7 @@ pub def plan_from_local_parquet(uri: str) -> Plan: pub def plan_from_local_files(uri: str) -> Plan: - """Compatibility wrapper for the current Parquet-backed local-files plan helper.""" + """Build a minimal local-files plan through the current Parquet-backed helper.""" return plan_from_root_relation(read_local_files_rel(uri), []) diff --git a/src/substrait/relations.incn b/src/substrait/relations.incn index 0075957..4e2d374 100644 --- a/src/substrait/relations.incn +++ b/src/substrait/relations.incn @@ -6,8 +6,8 @@ relation variants. It depends on schema lookup, expression lowering, and output- plan assembly or plan inspection. """ -import std.testing from rust::prost_types import Any +from rust::incan_stdlib::errors import raise_value_error from rust::std::boxed import Box from rust::std::primitive import i32 as RustI32 from rust::substrait::proto import ( @@ -44,19 +44,17 @@ from rust::substrait::proto::rel_common import Direct, Emit, EmitKind from rust::substrait::proto::set_rel import SetOp from rust::substrait::proto::sort_field import SortDirection, SortKind from aggregate_builders import AggregateKind, AggregateMeasure -from filter_builders import FilterPredicate -from projection_builders import ProjectionAssignment +from projection_builders import ColumnExpr, ProjectionAssignment from substrait.expr_lowering import ( bool_expr, - field_ref_expr, filter_predicate_expr, grouping_reference_indexes, i64_expr, - index_of_column, lower_project, + scalar_expr, string_expr, - to_rust_u32, ) +from substrait.errors import SubstraitLoweringError, invalid_scalar_expression from substrait.extensions import COUNT_FUNCTION_ANCHOR, SUM_FUNCTION_ANCHOR from substrait.inspect import relation_output_columns from substrait.schema_registry import named_table_base_schema, unknown_named_struct @@ -64,7 +62,14 @@ from substrait.schema_registry import named_table_base_schema, unknown_named_str model ResolvedAggregateMeasure: kind: AggregateKind - column_index: int + expr: Expression + + +@derive(Clone) +model ResolvedRelationExpression: + """One relation expression payload kept appendable before extracting the proto expression list.""" + + expr: Expression pub enum SubstraitJoinKind: @@ -173,15 +178,17 @@ pub def set_op_from_kind(operation: SubstraitSetOperation) -> RustI32: SubstraitSetOperation.UnionAll => return SetOp.UnionAll.into() -def _set_operation_from_legacy_name(operation: str) -> SubstraitSetOperation: - """Map legacy set-operation names to typed Substrait set-operation enums.""" +def _set_operation_from_name(operation: str) -> Result[SubstraitSetOperation, SubstraitLoweringError]: + """Map string set-operation names to typed Substrait set-operation enums.""" if operation == "Union": - return SubstraitSetOperation.UnionDistinct + return Ok(SubstraitSetOperation.UnionDistinct) elif operation == "UnionAll": - return SubstraitSetOperation.UnionAll + return Ok(SubstraitSetOperation.UnionAll) elif operation == "Intersect": - return SubstraitSetOperation.IntersectPrimary - return SubstraitSetOperation.MinusPrimary + return Ok(SubstraitSetOperation.IntersectPrimary) + elif operation == "Minus": + return Ok(SubstraitSetOperation.MinusPrimary) + return Err(invalid_scalar_expression(f"unknown set operation '{operation}'")) def _resolved_measure_to_substrait(measure: ResolvedAggregateMeasure) -> Measure: @@ -190,12 +197,11 @@ def _resolved_measure_to_substrait(measure: ResolvedAggregateMeasure) -> Measure mut function_reference = COUNT_FUNCTION_ANCHOR if measure.kind == AggregateKind.Sum: function_reference = SUM_FUNCTION_ANCHOR - if measure.column_index >= 0: - arguments.append(FunctionArgument(arg_type=Some(ArgType.Value(field_ref_expr(measure.column_index))))) + arguments = [FunctionArgument(arg_type=Some(ArgType.Value(measure.expr.clone())))] return Measure( measure=Some( AggregateFunction( - function_reference=to_rust_u32(function_reference), + function_reference=function_reference, arguments=arguments, sorts=[], output_type=None, @@ -209,27 +215,23 @@ def _resolved_measure_to_substrait(measure: ResolvedAggregateMeasure) -> Measure ) -def _grouping_indexes(input_columns: list[str], grouping_columns: list[str]) -> list[int]: - """Resolve grouping column names to input-column indexes, keeping only the lookup phase here.""" - return [index_of_column(input_columns, column_name) for column_name in grouping_columns] - - -def _grouping_expressions_from_indexes(grouping_indexes: list[int]) -> list[Expression]: - """Build field-reference expressions for grouping indexes that resolved successfully.""" - return [field_ref_expr(index) for index in grouping_indexes if index >= 0] - - -def _resolved_measure(measure: AggregateMeasure, input_columns: list[str]) -> ResolvedAggregateMeasure: +def _resolved_measure( + measure: AggregateMeasure, + input_columns: list[str], +) -> Result[ResolvedAggregateMeasure, SubstraitLoweringError]: """Resolve one aggregate measure against the current input-column list.""" if measure.kind == AggregateKind.Count: - return ResolvedAggregateMeasure(kind=measure.kind, column_index=-1) - assert len(measure.column_name) > 0, "sum(...) currently requires a direct col(\"name\") reference" - return ResolvedAggregateMeasure(kind=measure.kind, column_index=index_of_column(input_columns, measure.column_name)) + return Ok(ResolvedAggregateMeasure(kind=measure.kind, expr=bool_expr(true))) + return Ok(ResolvedAggregateMeasure(kind=measure.kind, expr=scalar_expr(input_columns, measure.expr)?)) -def _resolved_measures(input_columns: list[str], measures: list[AggregateMeasure]) -> list[ResolvedAggregateMeasure]: - """Resolve all aggregate measures before lowering them into Substrait measure payloads.""" - return [_resolved_measure(measure, input_columns) for measure in measures] +def _lowered_rel_or_raise(result: Result[Rel, SubstraitLoweringError]) -> Rel: + """Return one lowered relation or raise the structured lowering diagnostic for infallible call sites.""" + match result: + Ok(rel) => return rel + Err(err) => + message = err.error_message() + return raise_value_error(message) pub def read_named_table_rel(table_name: str) -> Rel: @@ -268,7 +270,7 @@ pub def read_local_parquet_rel(uri: str) -> Rel: pub def read_local_files_rel(uri: str) -> Rel: - """Compatibility wrapper for the current Parquet-backed `LocalFiles` read helper.""" + """Construct a logical `ReadRel(LocalFiles)` through the current Parquet-backed helper.""" return read_local_parquet_rel(uri) @@ -287,42 +289,75 @@ pub def read_virtual_table_rel(table_name: str) -> Rel: return _rel_read(read) -pub def filter_rel(input: Rel, predicate: FilterPredicate) -> Rel: +pub def filter_rel(input: Rel, predicate: ColumnExpr) -> Rel: """Wrap a child relation in `FilterRel` with one builder-backed predicate.""" - return filter_rel_of_columns(input, relation_output_columns(input.clone()), predicate) + return _lowered_rel_or_raise(try_filter_rel(input, predicate)) + +pub def try_filter_rel(input: Rel, predicate: ColumnExpr) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in `FilterRel` with one builder-backed predicate.""" + return try_filter_rel_of_columns(input.clone(), relation_output_columns(input), predicate) -pub def filter_rel_of_columns(input: Rel, input_columns: list[str], predicate: FilterPredicate) -> Rel: + +pub def filter_rel_of_columns(input: Rel, input_columns: list[str], predicate: ColumnExpr) -> Rel: """Wrap a child relation in `FilterRel` using explicit input-column names.""" - return _rel_filter( - FilterRel( - common=Some(_direct_common()), - input=Some(Box.new(input)), - condition=Some(Box.new(filter_predicate_expr(input_columns, predicate))), - advanced_extension=None, + return _lowered_rel_or_raise(try_filter_rel_of_columns(input, input_columns, predicate)) + + +pub def try_filter_rel_of_columns( + input: Rel, + input_columns: list[str], + predicate: ColumnExpr, +) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in `FilterRel` using explicit input-column names.""" + condition = filter_predicate_expr(input_columns, predicate)? + return Ok( + _rel_filter( + FilterRel( + common=Some(_direct_common()), + input=Some(Box.new(input)), + condition=Some(Box.new(condition)), + advanced_extension=None, + ), ), ) pub def project_rel(input: Rel) -> Rel: """Wrap a child relation in an identity `ProjectRel` over all current input columns.""" - return project_rel_of_columns(input, relation_output_columns(input.clone()), []) + return _lowered_rel_or_raise(try_project_rel(input)) + + +pub def try_project_rel(input: Rel) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in an identity `ProjectRel` over all current input columns.""" + return try_project_rel_of_columns(input.clone(), relation_output_columns(input), []) pub def project_rel_of_columns(input: Rel, input_columns: list[str], assignments: list[ProjectionAssignment]) -> Rel: """Wrap a child relation in `ProjectRel` using explicit input-column names and ordered add-or-replace assignments.""" - lowered = lower_project(input_columns, assignments) + return _lowered_rel_or_raise(try_project_rel_of_columns(input, input_columns, assignments)) + + +pub def try_project_rel_of_columns( + input: Rel, + input_columns: list[str], + assignments: list[ProjectionAssignment], +) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in `ProjectRel` using explicit input-column names and assignments.""" + lowered = lower_project(input_columns, assignments)? common = RelCommon( hint=None, advanced_extension=None, emit_kind=Some(EmitKind.Emit(Emit(output_mapping=lowered.output_mapping))), ) - return _rel_project( - ProjectRel( - common=Some(common), - input=Some(Box.new(input)), - expressions=lowered.expressions, - advanced_extension=None, + return Ok( + _rel_project( + ProjectRel( + common=Some(common), + input=Some(Box.new(input)), + expressions=lowered.expressions, + advanced_extension=None, + ), ), ) @@ -359,35 +394,58 @@ pub def cross_rel(left: Rel, right: Rel) -> Rel: ) -pub def aggregate_rel(input: Rel, grouping_columns: list[str], measures: list[AggregateMeasure]) -> Rel: +pub def aggregate_rel(input: Rel, grouping_columns: list[ColumnExpr], measures: list[AggregateMeasure]) -> Rel: """Wrap a child relation in `AggregateRel` with explicit grouping columns and aggregate measures.""" - return aggregate_rel_of_columns(input, relation_output_columns(input.clone()), grouping_columns, measures) + return _lowered_rel_or_raise(try_aggregate_rel(input, grouping_columns, measures)) + + +pub def try_aggregate_rel( + input: Rel, + grouping_columns: list[ColumnExpr], + measures: list[AggregateMeasure], +) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in `AggregateRel` with grouping columns and aggregate measures.""" + return try_aggregate_rel_of_columns(input.clone(), relation_output_columns(input), grouping_columns, measures) pub def aggregate_rel_of_columns( input: Rel, input_columns: list[str], - grouping_columns: list[str], + grouping_columns: list[ColumnExpr], measures: list[AggregateMeasure], ) -> Rel: """Wrap a child relation in `AggregateRel` using explicit input-column names.""" - grouping_indexes = _grouping_indexes(input_columns, grouping_columns) - grouping_exprs = _grouping_expressions_from_indexes(grouping_indexes) + return _lowered_rel_or_raise(try_aggregate_rel_of_columns(input, input_columns, grouping_columns, measures)) + + +pub def try_aggregate_rel_of_columns( + input: Rel, + input_columns: list[str], + grouping_columns: list[ColumnExpr], + measures: list[AggregateMeasure], +) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap a child relation in `AggregateRel` using explicit input-column names.""" + mut resolved_grouping_exprs: list[ResolvedRelationExpression] = [] + for column in grouping_columns: + grouping_expr = scalar_expr(input_columns, column)? + resolved_grouping_exprs.append(ResolvedRelationExpression(expr=grouping_expr.clone())) + grouping_exprs = [resolved.expr.clone() for resolved in resolved_grouping_exprs] grouping_expr_copies = [expr.clone() for expr in grouping_exprs] groupings = [Grouping( grouping_expressions=grouping_expr_copies, expression_references=grouping_reference_indexes(len(grouping_exprs)), )] - resolved_measures = _resolved_measures(input_columns, measures) - lowered_measures = [_resolved_measure_to_substrait(measure) for measure in resolved_measures] - return _rel_aggregate( - AggregateRel( - common=Some(_direct_common()), - input=Some(Box.new(input)), - groupings=groupings, - measures=lowered_measures, - grouping_expressions=grouping_exprs, - advanced_extension=None, + lowered_measures = [_resolved_measure_to_substrait(_resolved_measure(measure, input_columns)?) for measure in measures] + return Ok( + _rel_aggregate( + AggregateRel( + common=Some(_direct_common()), + input=Some(Box.new(input)), + groupings=groupings, + measures=lowered_measures, + grouping_expressions=grouping_exprs, + advanced_extension=None, + ), ), ) @@ -417,8 +475,13 @@ pub def fetch_rel(input: Rel, offset: int, count: int) -> Rel: pub def set_rel(left: Rel, right: Rel, operation: str) -> Rel: - """Compatibility wrapper for `SetRel` using the historic string-based operation name.""" - return set_rel_of_kind(left, right, _set_operation_from_legacy_name(operation)) + """Wrap two child relations in `SetRel` using a string operation name.""" + return _lowered_rel_or_raise(try_set_rel(left, right, operation)) + + +pub def try_set_rel(left: Rel, right: Rel, operation: str) -> Result[Rel, SubstraitLoweringError]: + """Fallibly wrap two child relations in `SetRel` using a string operation name.""" + return Ok(set_rel_of_kind(left, right, _set_operation_from_name(operation)?)) pub def set_rel_of_kind(left: Rel, right: Rel, operation: SubstraitSetOperation) -> Rel: diff --git a/src/substrait/schema.incn b/src/substrait/schema.incn index 9c08fc8..efe73d8 100644 --- a/src/substrait/schema.incn +++ b/src/substrait/schema.incn @@ -31,12 +31,12 @@ pub model DemoCustomer: @derive(Clone) -pub enum SubstraitPrimitiveKind: - Bool - I64 - F64 - Timestamp - String +pub enum SubstraitPrimitiveKind(str): + Bool = "bool" + I64 = "i64" + F64 = "f64" + Timestamp = "timestamp" + String = "string" @derive(Clone) @@ -74,22 +74,23 @@ def _type_from_primitive(kind: SubstraitPrimitiveKind, nullable: bool) -> Type: mut n = Nullability.Required if nullable: n = Nullability.Nullable - # `incan fmt` currently rewrites `match`/`case` into invalid Rust-style patterns; use branches. - if kind == SubstraitPrimitiveKind.Bool: - return Type(kind=Some(Kind.Bool(Boolean(type_variation_reference=0, nullability=n.into())))) - if kind == SubstraitPrimitiveKind.I64: - return Type(kind=Some(Kind.I64(I64(type_variation_reference=0, nullability=n.into())))) - if kind == SubstraitPrimitiveKind.F64: - return Type(kind=Some(Kind.Fp64(Fp64(type_variation_reference=0, nullability=n.into())))) - if kind == SubstraitPrimitiveKind.Timestamp: - return Type( - kind=Some( - Kind.PrecisionTimestamp( - PrecisionTimestamp(type_variation_reference=0, nullability=n.into(), precision=6), + match kind: + SubstraitPrimitiveKind.Bool => + return Type(kind=Some(Kind.Bool(Boolean(type_variation_reference=0, nullability=n.into())))) + SubstraitPrimitiveKind.I64 => + return Type(kind=Some(Kind.I64(I64(type_variation_reference=0, nullability=n.into())))) + SubstraitPrimitiveKind.F64 => + return Type(kind=Some(Kind.Fp64(Fp64(type_variation_reference=0, nullability=n.into())))) + SubstraitPrimitiveKind.Timestamp => + return Type( + kind=Some( + Kind.PrecisionTimestamp( + PrecisionTimestamp(type_variation_reference=0, nullability=n.into(), precision=6), + ), ), - ), - ) - return Type(kind=Some(Kind.String(SubstraitString(type_variation_reference=0, nullability=n.into())))) + ) + SubstraitPrimitiveKind.String => + return Type(kind=Some(Kind.String(SubstraitString(type_variation_reference=0, nullability=n.into())))) def _row_shape_field_types(shape: RowShapeSpec) -> list[Type]: diff --git a/src/substrait/schema_registry.incn b/src/substrait/schema_registry.incn index 916fe51..88153de 100644 --- a/src/substrait/schema_registry.incn +++ b/src/substrait/schema_registry.incn @@ -58,6 +58,10 @@ pub def named_table_base_schema(table_name: str) -> NamedStruct: pub def register_named_table_schema(table_name: str, columns: list[RowColumnSpec]) -> None: """Register one logical named-table schema used to bind `ReadRel` base_schema.""" + for idx, entry in enumerate(_named_table_schema_registry): + if entry.table_name == table_name: + _named_table_schema_registry[idx] = NamedTableSchemaEntry(table_name=table_name, columns=columns) + return _named_table_schema_registry.append(NamedTableSchemaEntry(table_name=table_name, columns=columns)) return diff --git a/src/substrait/traversal.incn b/src/substrait/traversal.incn new file mode 100644 index 0000000..93ea530 --- /dev/null +++ b/src/substrait/traversal.incn @@ -0,0 +1,53 @@ +"""Shared relation traversal helpers for Substrait tree walks.""" + +from rust::substrait::proto import Rel +from rust::substrait::proto::rel import RelType + + +pub def relation_children(rel: Rel) -> list[Rel]: + """Return direct child relations for one relation node.""" + match rel.rel_type: + Some(RelType.Filter(filter)) => + match filter.input: + Some(child) => return [child.as_ref().clone()] + None => return [] + Some(RelType.Project(project)) => + match project.input: + Some(child) => return [child.as_ref().clone()] + None => return [] + Some(RelType.Join(join)) => + mut children: list[Rel] = [] + match join.left: + Some(left) => children.append(left.as_ref().clone()) + None => pass + match join.right: + Some(right) => children.append(right.as_ref().clone()) + None => pass + return children + Some(RelType.Cross(cross)) => + mut children: list[Rel] = [] + match cross.left: + Some(left) => children.append(left.as_ref().clone()) + None => pass + match cross.right: + Some(right) => children.append(right.as_ref().clone()) + None => pass + return children + Some(RelType.Aggregate(aggregate)) => + match aggregate.input: + Some(child) => return [child.as_ref().clone()] + None => return [] + Some(RelType.Sort(sort)) => + match sort.input: + Some(child) => return [child.as_ref().clone()] + None => return [] + Some(RelType.Fetch(fetch)) => + match fetch.input: + Some(child) => return [child.as_ref().clone()] + None => return [] + Some(RelType.Set(set_rel)) => return set_rel.inputs + Some(RelType.ExtensionSingle(extension)) => + match extension.input: + Some(child) => return [child.as_ref().clone()] + None => return [] + _ => return [] diff --git a/tests/test_dataset.incn b/tests/test_dataset.incn index 3434763..bbde8b5 100644 --- a/tests/test_dataset.incn +++ b/tests/test_dataset.incn @@ -1,11 +1,26 @@ """Test: dataset types and RFC 001 contract (types, hierarchy, functions import).""" -from std.testing import assert_eq +from std.testing import fail_t, parametrize from metadata import inql_version from dataset import DataSet, BoundedDataSet, UnboundedDataSet, DataFrame, LazyFrame, DataStream, lazy_frame_named_table from dataset.materialization import DataFrameMaterialization from dataset.ops import filter_ds, join_ds -from functions import always_false, always_true, col, count, int_expr, mul, sum +from functions import ( + always_false, + always_true, + bool_lit, + col, + count, + float_expr, + int_expr, + int_lit, + lit, + mul, + str_expr, + str_lit, + sum, +) +from projection_builders import ColumnExprKind, column_expr_kind, column_expr_name from substrait.extensions import explode_extension_uri from substrait.inspect import plan_contains_relation_kind, plan_has_extension_urn, relation_kind_name, root_rel from substrait.plans import plan_encoded_len, plan_from_named_table, plan_from_root_relation @@ -24,6 +39,30 @@ model Event: id: int +def _register_order_schema(table_name: str) -> None: + register_named_table_schema(table_name, [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false)]) + + +def _column_expr_kind_name(kind: ColumnExprKind) -> str: + match kind: + ColumnExprKind.IntLiteral => return "IntLiteral" + ColumnExprKind.FloatLiteral => return "FloatLiteral" + ColumnExprKind.StringLiteral => return "StringLiteral" + ColumnExprKind.BoolLiteral => return "BoolLiteral" + _ => return "Other" + + +def _literal_helper_kind_name(helper_name: str) -> str: + match helper_name: + "int_expr" => return _column_expr_kind_name(column_expr_kind(int_expr(2))) + "float_expr" => return _column_expr_kind_name(column_expr_kind(float_expr(2.5))) + "str_expr" => return _column_expr_kind_name(column_expr_kind(str_expr("open"))) + "int_lit" => return _column_expr_kind_name(column_expr_kind(int_lit(3))) + "str_lit" => return _column_expr_kind_name(column_expr_kind(str_lit("closed"))) + "bool_lit" => return _column_expr_kind_name(column_expr_kind(bool_lit(true))) + _ => return fail_t(f"unknown literal helper: {helper_name}") + + def _accept_data_set_generic[T with Clone](data: DataSet[T]) -> DataSet[T]: return data @@ -110,15 +149,27 @@ def test_smoke__dataset_types_are_published() -> None: count_result = count() # -- Assert -- - assert_eq(sum_result.column_name, "amount", "sum should preserve the selected input column name") - assert_eq(count_result.column_name, "", "count should remain a zero-argument aggregate helper") + assert column_expr_name(sum_result.expr) == "amount", "sum should preserve the selected input column expression" + assert column_expr_name(count_result.expr) == "", "count should remain a zero-argument aggregate helper" + + +@parametrize("helper_name, expected_kind", [("int_expr", "IntLiteral"), ("float_expr", "FloatLiteral"), ("str_expr", "StringLiteral"), ("int_lit", "IntLiteral"), ("str_lit", "StringLiteral"), ("bool_lit", "BoolLiteral")]) +def test_scalar_literals__filter_helper_aliases_share_column_expr_model(helper_name: str, expected_kind: str) -> None: + """Typed filter literal helper aliases should build canonical column expressions.""" + # -- Arrange -- + + # -- Act -- + actual_kind = _literal_helper_kind_name(helper_name) + + # -- Assert -- + assert actual_kind == expected_kind, f"{helper_name} should build a {expected_kind} scalar literal" def test_type_contracts__signature_tiers_compile() -> None: """RFC 001: DataSet / BoundedDataSet / UnboundedDataSet accept concrete carriers.""" # -- Arrange -- row = Order(id=1) - df: DataFrame[Order] = DataFrame[Order]( + df: DataFrame[Order] = DataFrame( _type_witness=_order_witness(), _materialization=DataFrameMaterialization.empty(), _substrait_rel=read_named_table_rel("orders"), @@ -143,12 +194,12 @@ def test_type_contracts__signature_tiers_compile() -> None: def test_type_contracts__concrete_carriers_compile() -> None: """RFC 001: concrete carriers flow through generic helper signatures.""" # -- Arrange -- - df: DataFrame[Order] = DataFrame[Order]( + df: DataFrame[Order] = DataFrame( _type_witness=_order_witness(), _materialization=DataFrameMaterialization.empty(), _substrait_rel=read_named_table_rel("orders"), ) - lf: LazyFrame[Order] = lazy_frame_named_table[Order]("orders") + lf: LazyFrame[Order] = lazy_frame_named_table("orders") ev = Event(id=4) st: DataStream[Event] = DataStream(_row_schema_marker=ev, _substrait_rel=read_named_table_rel("events")) @@ -166,12 +217,12 @@ def test_type_contracts__concrete_carriers_compile() -> None: def test_hierarchy__concrete_and_supertrait_assignability() -> None: """RFC 001: concrete -> bounded/unbounded trait -> DataSet chains compile.""" # -- Arrange -- - df: DataFrame[Order] = DataFrame[Order]( + df: DataFrame[Order] = DataFrame( _type_witness=_order_witness(), _materialization=DataFrameMaterialization.empty(), _substrait_rel=read_named_table_rel("orders"), ) - lf: LazyFrame[Order] = lazy_frame_named_table[Order]("orders") + lf: LazyFrame[Order] = lazy_frame_named_table("orders") ev = Event(id=6) st: DataStream[Event] = DataStream(_row_schema_marker=ev, _substrait_rel=read_named_table_rel("events")) @@ -185,21 +236,21 @@ def test_hierarchy__concrete_and_supertrait_assignability() -> None: def test_type_contracts__concrete_and_trait_types_match_generic_arguments() -> None: - """Generic parameter T matches for DataFrame/LazyFrame/DataStream carriers.""" + """Generic parameter T matches through supported callable trait boundaries.""" # -- Arrange -- - df: DataFrame[Order] = DataFrame[Order]( + df: DataFrame[Order] = DataFrame( _type_witness=_order_witness(), _materialization=DataFrameMaterialization.empty(), _substrait_rel=read_named_table_rel("orders"), ) - lf: LazyFrame[Order] = lazy_frame_named_table[Order]("orders") + lf: LazyFrame[Order] = lazy_frame_named_table("orders") ev = Event(id=8) st: DataStream[Event] = DataStream(_row_schema_marker=ev, _substrait_rel=read_named_table_rel("events")) # -- Act -- - bounded_df: BoundedDataSet[Order] = df - bounded_lf: BoundedDataSet[Order] = lf - ub: UnboundedDataSet[Event] = st + bounded_df = _accept_bounded_generic(df) + bounded_lf = _accept_bounded_generic(lf) + ub = _accept_unbounded_generic(st) # -- Assert -- _touch(bounded_df) @@ -216,13 +267,13 @@ def test_version__inql_version_is_published() -> None: version = inql_version() # -- Assert -- - assert_eq(version, expected_version, "InQL version should be 0.1.0") + assert version == expected_version, "InQL version should be 0.1.0" def test_dataset_ops__method_wrapper_matches_canonical_function() -> None: """Method wrappers should preserve canonical dataset_ops semantics.""" # -- Arrange -- - base_df = DataFrame[Order]( + base_df = DataFrame( _type_witness=_order_witness(), _materialization=DataFrameMaterialization.empty(), _substrait_rel=read_named_table_rel("orders"), @@ -233,23 +284,15 @@ def test_dataset_ops__method_wrapper_matches_canonical_function() -> None: via_method = base_df.filter(always_true()) # -- Assert -- - assert_eq(relation_kind_name(via_function), "FilterRel", "canonical function should produce FilterRel root") - assert_eq( - plan_encoded_len(via_method.to_substrait_plan()) > 0, - true, - "method wrapper should still emit a real proto plan", - ) - assert_eq( - relation_kind_name(root_rel(via_method.to_substrait_plan())), - "FilterRel", - "method wrapper plan should preserve filter root", - ) + assert relation_kind_name(via_function) == "FilterRel", "canonical function should produce FilterRel root" + assert plan_encoded_len(via_method.to_substrait_plan()) > 0, "method wrapper should still emit a real proto plan" + assert relation_kind_name(root_rel(via_method.to_substrait_plan())) == "FilterRel", "method wrapper plan should preserve filter root" def test_dataset_ops__data_frame_transforms_invalidate_stale_materialization() -> None: """DataFrame transforms should not retain payload-derived schema once the logical relation changes.""" # -- Arrange -- - base_df = DataFrame[Order]( + base_df = DataFrame( _type_witness=_order_witness(), _materialization=_preview_materialization(["id"], 1, "| id |\n| 1 |"), _substrait_rel=read_named_table_rel("orders"), @@ -259,28 +302,20 @@ def test_dataset_ops__data_frame_transforms_invalidate_stale_materialization() - transformed = base_df.filter(always_true()) # -- Assert -- - assert_eq(transformed.preview_text(), "", "transformed DataFrame should clear stale preview text") - assert_eq(transformed.row_count(), 0, "transformed DataFrame should clear stale collected row-count metadata") - assert_eq( - len(transformed.resolved_columns()), - 0, - "transformed DataFrame should clear resolved columns derived from stale materialization", - ) - assert_eq( - relation_kind_name(root_rel(transformed.to_substrait_plan())), - "FilterRel", - "transformed DataFrame should still preserve the updated logical relation", - ) + assert transformed.preview_text() == "", "transformed DataFrame should clear stale preview text" + assert transformed.row_count() == 0, "transformed DataFrame should clear stale collected row-count metadata" + assert len(transformed.resolved_columns()) == 0, "transformed DataFrame should clear resolved columns derived from stale materialization" + assert relation_kind_name(root_rel(transformed.to_substrait_plan())) == "FilterRel", "transformed DataFrame should still preserve the updated logical relation" def test_dataset_ops__all_carriers_emit_real_plans() -> None: # -- Arrange -- - df = DataFrame[Order]( + df = DataFrame( _type_witness=_order_witness(), _materialization=DataFrameMaterialization.empty(), _substrait_rel=read_named_table_rel("orders"), ) - lf = lazy_frame_named_table[Order]("orders") + lf: LazyFrame[Order] = lazy_frame_named_table("orders") st = DataStream(_row_schema_marker=Event(id=12), _substrait_rel=read_named_table_rel("events")) # -- Act -- @@ -289,39 +324,35 @@ def test_dataset_ops__all_carriers_emit_real_plans() -> None: st_plan = st.to_substrait_plan() # -- Assert -- - assert_eq(plan_encoded_len(df_plan) > 0, true, "DataFrame should emit a real plan") - assert_eq(plan_encoded_len(lf_plan) > 0, true, "LazyFrame should emit a real plan") - assert_eq(plan_encoded_len(st_plan) > 0, true, "DataStream should emit a real plan") + assert plan_encoded_len(df_plan) > 0, "DataFrame should emit a real plan" + assert plan_encoded_len(lf_plan) > 0, "LazyFrame should emit a real plan" + assert plan_encoded_len(st_plan) > 0, "DataStream should emit a real plan" def test_dataset_ops__with_column_updates_lazy_planned_columns_with_add_or_replace_semantics() -> None: """with_column should append new names and replace existing names in place at lazy plan time.""" # -- Arrange -- _register_projection_test_schema("orders_projection") - base = lazy_frame_named_table[Order]("orders_projection") + base: LazyFrame[Order] = lazy_frame_named_table("orders_projection") # -- Act -- - augmented = base.with_column("double_id", mul(col("id"), int_expr(2))) - replaced = augmented.with_column("id", int_expr(0)) + augmented = base.with_column("double_id", mul(col("id"), lit(2))) + replaced = augmented.with_column("id", lit(0)) # -- Assert -- augmented_cols = augmented.planned_columns() replaced_cols = replaced.planned_columns() - assert_eq(len(augmented_cols), 2, "with_column should append one new projected output column") - assert_eq(augmented_cols[0], "id", "existing source columns should stay ahead of appended projections") - assert_eq(augmented_cols[1], "double_id", "with_column should append the requested output name") - assert_eq(len(replaced_cols), 2, "replacing an existing name should preserve output column count") - assert_eq(replaced_cols[0], "id", "replacing an existing name should preserve that slot's ordinal position") - assert_eq( - replaced_cols[1], - "double_id", - "later projected columns should keep their relative order after replacement", - ) + assert len(augmented_cols) == 2, "with_column should append one new projected output column" + assert augmented_cols[0] == "id", "existing source columns should stay ahead of appended projections" + assert augmented_cols[1] == "double_id", "with_column should append the requested output name" + assert len(replaced_cols) == 2, "replacing an existing name should preserve output column count" + assert replaced_cols[0] == "id", "replacing an existing name should preserve that slot's ordinal position" + assert replaced_cols[1] == "double_id", "later projected columns should keep their relative order after replacement" def test_dataset_ops__api_lowered_boundary_facts_stay_stable() -> None: # -- Arrange -- - left = DataFrame[Order]( + left = DataFrame( _type_witness=_order_witness(), _materialization=DataFrameMaterialization.empty(), _substrait_rel=read_named_table_rel("orders"), @@ -335,106 +366,63 @@ def test_dataset_ops__api_lowered_boundary_facts_stay_stable() -> None: exploded_plan = left.explode().to_substrait_plan() # -- Assert -- - assert_eq( - relation_kind_name(root_rel(joined_plan)), - "JoinRel", - "canonical join function should still lower to a JoinRel root", - ) - assert_eq( - plan_has_extension_urn(exploded_plan, explode_extension_uri()), - true, - "explode method should emit the registered extension URI", - ) + assert relation_kind_name(root_rel(joined_plan)) == "JoinRel", "canonical join function should still lower to a JoinRel root" + assert plan_has_extension_urn(exploded_plan, explode_extension_uri()), "explode method should emit the registered extension URI" def test_lazy_frame__immutable_branching_and_origin_mapping_hold() -> None: # -- Arrange -- - base = lazy_frame_named_table[Order]("orders") + base: LazyFrame[Order] = lazy_frame_named_table("orders") # -- Act -- filtered: LazyFrame[Order] = base.filter(always_false()) # -- Assert -- - assert_eq( - plan_encoded_len(base.to_substrait_plan()) > 0, - true, - "base lazy carrier should still emit a real proto-backed plan", - ) - assert_eq( - relation_kind_name(root_rel(filtered.to_substrait_plan())), - "FilterRel", - "Prism-backed LazyFrame should still lower to FilterRel", - ) - assert_eq( - plan_contains_relation_kind(filtered.to_substrait_plan(), "ReadRel"), - true, - "derived filter plans should still preserve the read root", - ) + assert plan_encoded_len(base.to_substrait_plan()) > 0, "base lazy carrier should still emit a real proto-backed plan" + assert relation_kind_name(root_rel(filtered.to_substrait_plan())) == "FilterRel", "Prism-backed LazyFrame should still lower to FilterRel" + assert plan_contains_relation_kind(filtered.to_substrait_plan(), "ReadRel"), "derived filter plans should still preserve the read root" def test_lazy_frame__independent_roots_can_join_and_lower() -> None: # -- Arrange -- - left: LazyFrame[Order] = lazy_frame_named_table[Order]("orders").filter(always_false()) - right: LazyFrame[Order] = lazy_frame_named_table[Order]("orders_archive").filter(always_false()) + left: LazyFrame[Order] = lazy_frame_named_table("orders").filter(always_false()) + right: LazyFrame[Order] = lazy_frame_named_table("orders_archive").filter(always_false()) # -- Act -- joined: LazyFrame[Order] = left.join(right, true) plan = joined.to_substrait_plan() # -- Assert -- - assert_eq(relation_kind_name(root_rel(plan)), "JoinRel", "joined lazy carriers should lower to a JoinRel root") - assert_eq( - plan_contains_relation_kind(plan, "FilterRel"), - true, - "joined lazy carriers should preserve nested filters", - ) - assert_eq( - plan_contains_relation_kind(plan, "ReadRel"), - true, - "joined lazy carriers should still reach the original read roots", - ) + assert relation_kind_name(root_rel(plan)) == "JoinRel", "joined lazy carriers should lower to a JoinRel root" + assert plan_contains_relation_kind(plan, "FilterRel"), "joined lazy carriers should preserve nested filters" + assert plan_contains_relation_kind(plan, "ReadRel"), "joined lazy carriers should still reach the original read roots" def test_lazy_frame__native_prism_ops_preserve_current_boundary_shapes() -> None: # -- Arrange -- - projected = lazy_frame_named_table[Order]("orders").select() - grouped = lazy_frame_named_table[Order]("orders").group_by([col("id")]) + _register_order_schema("orders") + projected: LazyFrame[Order] = lazy_frame_named_table("orders").select() + grouped: LazyFrame[Order] = lazy_frame_named_table("orders").group_by([col("id")]) # -- Act -- aggregated = grouped.agg([count()]) - ordered = lazy_frame_named_table[Order]("orders").order_by() - limited = lazy_frame_named_table[Order]("orders").limit(10) - exploded = lazy_frame_named_table[Order]("orders").explode() + ordered: LazyFrame[Order] = lazy_frame_named_table("orders").order_by() + limited: LazyFrame[Order] = lazy_frame_named_table("orders").limit(10) + exploded: LazyFrame[Order] = lazy_frame_named_table("orders").explode() # -- Assert -- - assert_eq( - relation_kind_name(root_rel(projected.to_substrait_plan())), - "ProjectRel", - "select should lower through the project boundary shape", - ) - assert_eq( - relation_kind_name(root_rel(grouped.to_substrait_plan())), - "AggregateRel", - "group_by should lower to the aggregate scaffold boundary shape", - ) - assert_eq( - relation_kind_name(root_rel(aggregated.to_substrait_plan())), - "AggregateRel", - "agg should lower to AggregateRel", - ) - assert_eq(relation_kind_name(root_rel(ordered.to_substrait_plan())), "SortRel", "order_by should lower to SortRel") - assert_eq(relation_kind_name(root_rel(limited.to_substrait_plan())), "FetchRel", "limit should lower to FetchRel") - assert_eq( - plan_has_extension_urn(exploded.to_substrait_plan(), explode_extension_uri()), - true, - "explode should keep emitting the registered extension boundary", - ) + assert relation_kind_name(root_rel(projected.to_substrait_plan())) == "ProjectRel", "select should lower through the project boundary shape" + assert relation_kind_name(root_rel(grouped.to_substrait_plan())) == "AggregateRel", "group_by should lower to the aggregate scaffold boundary shape" + assert relation_kind_name(root_rel(aggregated.to_substrait_plan())) == "AggregateRel", "agg should lower to AggregateRel" + assert relation_kind_name(root_rel(ordered.to_substrait_plan())) == "SortRel", "order_by should lower to SortRel" + assert relation_kind_name(root_rel(limited.to_substrait_plan())) == "FetchRel", "limit should lower to FetchRel" + assert plan_has_extension_urn(exploded.to_substrait_plan(), explode_extension_uri()), "explode should keep emitting the registered extension boundary" def test_lazy_frame__deeper_independent_roots_still_lower_with_stable_shapes() -> None: # -- Arrange -- - left: LazyFrame[Order] = lazy_frame_named_table[Order]("orders").filter(always_false()) - right_base: LazyFrame[Order] = lazy_frame_named_table[Order]("orders_archive") + left: LazyFrame[Order] = lazy_frame_named_table("orders").filter(always_false()) + right_base: LazyFrame[Order] = lazy_frame_named_table("orders_archive") # -- Act -- right_joined: LazyFrame[Order] = right_base.filter(always_false()).order_by().join( @@ -445,63 +433,31 @@ def test_lazy_frame__deeper_independent_roots_still_lower_with_stable_shapes() - plan = joined.to_substrait_plan() # -- Assert -- - assert_eq( - relation_kind_name(root_rel(plan)), - "JoinRel", - "deeper independent-root join compositions should still lower to JoinRel", - ) - assert_eq( - plan_contains_relation_kind(plan, "SortRel"), - true, - "deeper branch compositions should preserve sort nodes", - ) - assert_eq( - plan_contains_relation_kind(plan, "FilterRel"), - true, - "deeper branch compositions should preserve filter nodes", - ) - assert_eq( - plan_contains_relation_kind(plan, "ReadRel"), - true, - "deeper branch compositions should still preserve read roots", - ) + assert relation_kind_name(root_rel(plan)) == "JoinRel", "deeper independent-root join compositions should still lower to JoinRel" + assert plan_contains_relation_kind(plan, "SortRel"), "deeper branch compositions should preserve sort nodes" + assert plan_contains_relation_kind(plan, "FilterRel"), "deeper branch compositions should preserve filter nodes" + assert plan_contains_relation_kind(plan, "ReadRel"), "deeper branch compositions should still preserve read roots" def test_lazy_frame__canonical_rewrite_removes_filter_true_and_preserves_boundary_validity() -> None: # -- Arrange -- - simplified = lazy_frame_named_table[Order]("orders").filter(always_true()) + simplified: LazyFrame[Order] = lazy_frame_named_table("orders").filter(always_true()) # -- Act -- plan = simplified.to_substrait_plan() # -- Assert -- - assert_eq( - plan_encoded_len(plan) > 0, - true, - "canonicalized plans should still encode to a non-empty Substrait payload", - ) - assert_eq( - relation_kind_name(root_rel(plan)), - "ReadRel", - "filter(true) should canonicalize away before RFC 002 lowering", - ) - assert_eq( - plan_contains_relation_kind(plan, "ReadRel"), - true, - "canonicalized plans should still preserve the read root", - ) + assert plan_encoded_len(plan) > 0, "canonicalized plans should still encode to a non-empty Substrait payload" + assert relation_kind_name(root_rel(plan)) == "ReadRel", "filter(true) should canonicalize away before RFC 002 lowering" + assert plan_contains_relation_kind(plan, "ReadRel"), "canonicalized plans should still preserve the read root" def test_lazy_frame__canonical_rewrite_keeps_extension_boundary_for_explode() -> None: # -- Arrange -- - exploded = lazy_frame_named_table[Order]("orders").filter(always_true()).explode() + exploded: LazyFrame[Order] = lazy_frame_named_table("orders").filter(always_true()).explode() # -- Act -- plan = exploded.to_substrait_plan() # -- Assert -- - assert_eq( - plan_has_extension_urn(plan, explode_extension_uri()), - true, - "canonical rewrites must keep explode extension URI emission intact", - ) + assert plan_has_extension_urn(plan, explode_extension_uri()), "canonical rewrites must keep explode extension URI emission intact" diff --git a/tests/test_model_substrait_schema.incn b/tests/test_model_substrait_schema.incn index cb2a01c..c705cbb 100644 --- a/tests/test_model_substrait_schema.incn +++ b/tests/test_model_substrait_schema.incn @@ -1,6 +1,5 @@ """Tests for model → Substrait NamedStruct prototype.""" -from std.testing import assert_eq from substrait.schema import ( demo_customer_row_shape, row_shape_field_names, @@ -17,10 +16,10 @@ def test_demo_customer_row_shape__field_order() -> None: names = row_shape_field_names(shape) # -- Assert -- - assert_eq(shape.model_name, "DemoCustomer", "model name") - assert_eq(len(names), 2, "column count") - assert_eq(names[0], "id", "first field is declaration order") - assert_eq(names[1], "name", "second field") + assert shape.model_name == "DemoCustomer", "model name" + assert len(names) == 2, "column count" + assert names[0] == "id", "first field is declaration order" + assert names[1] == "name", "second field" def test_row_shape_to_substrait_struct_type__encodes() -> None: @@ -32,4 +31,4 @@ def test_row_shape_to_substrait_struct_type__encodes() -> None: encoded = substrait_row_type_encoded_len(row_ty) # -- Assert -- - assert_eq(encoded > 0, true, "protobuf encoding should be non-empty") + assert encoded > 0, "protobuf encoding should be non-empty" diff --git a/tests/test_prism.incn b/tests/test_prism.incn index 23c1da1..9c3aa49 100644 --- a/tests/test_prism.incn +++ b/tests/test_prism.incn @@ -1,7 +1,6 @@ """Internal Prism engine tests against the shared-store cursor substrate.""" -from std.testing import assert_eq -from functions import always_false, always_true, col, count, int_expr, mul, sum +from functions import always_false, always_true, col, count, lit, mul, sum from prism import ( PrismCursor, prism_cursor_apply_filter, @@ -38,60 +37,44 @@ def _register_projection_test_schema(table_name: str) -> None: def test_prism__branching_keeps_base_reachable_history_small() -> None: # -- Arrange -- - base: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")) + base: PrismCursor[Order] = prism_cursor_named_table(str("orders")) # -- Act -- filtered: PrismCursor[Order] = base.filter(always_false()) plan = filtered.to_substrait_plan() # -- Assert -- - assert_eq( - prism_cursors_share_store(base, filtered), - true, - "unary branching should share one append-only Prism store", - ) - assert_eq(prism_cursor_authored_node_count(base), 1, "base cursor should still only reach the original read node") - assert_eq( - prism_cursor_store_node_count(base), - 2, - "shared store should contain the appended filter node exactly once", - ) - assert_eq(prism_cursor_tip_kind_name(base), str("ReadNamedTable"), "base cursor tip should remain the read root") - assert_eq( - prism_cursor_authored_node_count(filtered), - 2, - "derived cursor should reach the read node plus the new filter", - ) - assert_eq(prism_cursor_tip_kind_name(filtered), str("Filter"), "filtered cursor tip should become the filter node") - assert_eq( - prism_cursor_tip_origin_id(filtered), - 1, - "identity optimized view should preserve the appended filter as the tip origin", - ) - assert_eq(plan_encoded_len(plan) > 0, true, "lazy carrier should still emit a real proto-backed plan") - assert_eq(relation_kind_name(root_rel(plan)), str("FilterRel"), "Prism lowering should preserve filter root shape") - assert_eq(plan_contains_relation_kind(plan, str("ReadRel")), true, "Prism lowering should still reach the read root") + assert prism_cursors_share_store(base, filtered), "unary branching should share one append-only Prism store" + assert prism_cursor_authored_node_count(base) == 1, "base cursor should still only reach the original read node" + assert prism_cursor_store_node_count(base) == 2, "shared store should contain the appended filter node exactly once" + assert prism_cursor_tip_kind_name(base) == str("ReadNamedTable"), "base cursor tip should remain the read root" + assert prism_cursor_authored_node_count(filtered) == 2, "derived cursor should reach the read node plus the new filter" + assert prism_cursor_tip_kind_name(filtered) == str("Filter"), "filtered cursor tip should become the filter node" + assert prism_cursor_tip_origin_id(filtered) == 1, "identity optimized view should preserve the appended filter as the tip origin" + assert plan_encoded_len(plan) > 0, "lazy carrier should still emit a real proto-backed plan" + assert relation_kind_name(root_rel(plan)) == str("FilterRel"), "Prism lowering should preserve filter root shape" + assert plan_contains_relation_kind(plan, str("ReadRel")), "Prism lowering should still reach the read root" def test_prism__independent_roots_get_distinct_store_ids() -> None: # -- Arrange -- - left: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")) - right: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders_archive")) + left: PrismCursor[Order] = prism_cursor_named_table(str("orders")) + right: PrismCursor[Order] = prism_cursor_named_table(str("orders_archive")) # -- Act -- left_store_id = prism_cursor_store_id_value(left) right_store_id = prism_cursor_store_id_value(right) # -- Assert -- - assert_eq(prism_cursors_share_store(left, right), false, "independent read roots should not share one store id") - assert_eq(left_store_id >= 0, true, "store ids should be non-negative") - assert_eq(right_store_id >= 0, true, "store ids should be non-negative") - assert_eq(left_store_id == right_store_id, false, "allocator should assign distinct ids to distinct stores") + assert prism_cursors_share_store(left, right) is false, "independent read roots should not share one store id" + assert left_store_id >= 0, "store ids should be non-negative" + assert right_store_id >= 0, "store ids should be non-negative" + assert left_store_id != right_store_id, "allocator should assign distinct ids to distinct stores" def test_prism__same_store_join_reuses_shared_history() -> None: # -- Arrange -- - base: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")) + base: PrismCursor[Order] = prism_cursor_named_table(str("orders")) left: PrismCursor[Order] = base.filter(always_false()) right: PrismCursor[Order] = base.select() @@ -100,54 +83,22 @@ def test_prism__same_store_join_reuses_shared_history() -> None: plan = joined.to_substrait_plan() # -- Assert -- - assert_eq( - prism_cursors_share_store(left, right), - true, - "branches from the same base should still share one store before the join", - ) - assert_eq( - prism_cursor_authored_node_count(base), - 1, - "joining later should not mutate the base cursor's reachable history", - ) - assert_eq( - prism_cursor_store_node_count(joined), - 4, - "same-store join should append one join node without duplicating the shared read root", - ) - assert_eq( - prism_cursor_authored_node_count(joined), - 4, - "joined cursor should reach the shared read plus both unary nodes and the join", - ) - assert_eq(prism_cursor_tip_kind_name(joined), str("Join"), "same-store join tip should become a join node") - assert_eq( - prism_cursor_tip_origin_id(joined), - 3, - "identity optimized view should preserve the appended join node as the tip origin", - ) - assert_eq(plan_encoded_len(plan) > 0, true, "branch-and-join carrier should still emit a real proto-backed plan") - assert_eq(relation_kind_name(root_rel(plan)), str("JoinRel"), "binary authored nodes should lower back to JoinRel") - assert_eq( - plan_contains_relation_kind(plan, str("FilterRel")), - true, - "join lowering should preserve the left filter input", - ) - assert_eq( - plan_contains_relation_kind(plan, str("ProjectRel")), - true, - "join lowering should preserve the right project input", - ) - assert_eq( - plan_contains_relation_kind(plan, str("ReadRel")), - true, - "join lowering should still reach the original read root", - ) + assert prism_cursors_share_store(left, right), "branches from the same base should still share one store before the join" + assert prism_cursor_authored_node_count(base) == 1, "joining later should not mutate the base cursor's reachable history" + assert prism_cursor_store_node_count(joined) == 4, "same-store join should append one join node without duplicating the shared read root" + assert prism_cursor_authored_node_count(joined) == 4, "joined cursor should reach the shared read plus both unary nodes and the join" + assert prism_cursor_tip_kind_name(joined) == str("Join"), "same-store join tip should become a join node" + assert prism_cursor_tip_origin_id(joined) == 3, "identity optimized view should preserve the appended join node as the tip origin" + assert plan_encoded_len(plan) > 0, "branch-and-join carrier should still emit a real proto-backed plan" + assert relation_kind_name(root_rel(plan)) == str("JoinRel"), "binary authored nodes should lower back to JoinRel" + assert plan_contains_relation_kind(plan, str("FilterRel")), "join lowering should preserve the left filter input" + assert plan_contains_relation_kind(plan, str("ProjectRel")), "join lowering should preserve the right project input" + assert plan_contains_relation_kind(plan, str("ReadRel")), "join lowering should still reach the original read root" def test_prism__same_store_join_with_longer_branches_is_still_one_append() -> None: # -- Arrange -- - base: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")) + base: PrismCursor[Order] = prism_cursor_named_table(str("orders")) left: PrismCursor[Order] = base.filter(always_false()).order_by() right: PrismCursor[Order] = base.select().limit(5) @@ -156,79 +107,39 @@ def test_prism__same_store_join_with_longer_branches_is_still_one_append() -> No plan = joined.to_substrait_plan() # -- Assert -- - assert_eq( - prism_cursors_share_store(left, right), - true, - "branches from one base should continue sharing a store as branch depth grows", - ) - assert_eq( - prism_cursor_authored_node_count(base), - 1, - "deeper branch joins should still keep the base handle immutable", - ) - assert_eq( - prism_cursor_store_node_count(joined), - 6, - "same-store join should still append one join node after deeper branch history", - ) - assert_eq( - prism_cursor_authored_node_count(joined), - 6, - "joined cursor should still reach exactly read + branch nodes + join", - ) - assert_eq( - relation_kind_name(root_rel(plan)), - str("JoinRel"), - "deeper same-store branches should still lower through JoinRel", - ) - assert_eq(plan_contains_relation_kind(plan, str("SortRel")), true, "left branch sort should be preserved") - assert_eq(plan_contains_relation_kind(plan, str("FetchRel")), true, "right branch limit should be preserved") - assert_eq( - plan_contains_relation_kind(plan, str("ReadRel")), - true, - "joined deeper branches should still reach the original read root", - ) + assert prism_cursors_share_store(left, right), "branches from one base should continue sharing a store as branch depth grows" + assert prism_cursor_authored_node_count(base) == 1, "deeper branch joins should still keep the base handle immutable" + assert prism_cursor_store_node_count(joined) == 6, "same-store join should still append one join node after deeper branch history" + assert prism_cursor_authored_node_count(joined) == 6, "joined cursor should still reach exactly read + branch nodes + join" + assert relation_kind_name(root_rel(plan)) == str("JoinRel"), "deeper same-store branches should still lower through JoinRel" + assert plan_contains_relation_kind(plan, str("SortRel")), "left branch sort should be preserved" + assert plan_contains_relation_kind(plan, str("FetchRel")), "right branch limit should be preserved" + assert plan_contains_relation_kind(plan, str("ReadRel")), "joined deeper branches should still reach the original read root" def test_prism__cross_store_join_adopts_reachable_subgraph() -> None: # -- Arrange -- - left: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).filter(always_false()) - right: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders_archive")).filter(always_false()) + left: PrismCursor[Order] = prism_cursor_named_table(str("orders")).filter(always_false()) + right: PrismCursor[Order] = prism_cursor_named_table(str("orders_archive")).filter(always_false()) # -- Act -- joined: PrismCursor[Order] = left.join(right.clone(), true) plan = joined.to_substrait_plan() # -- Assert -- - assert_eq(prism_cursors_share_store(left, right), false, "independent roots should begin on distinct stores") - assert_eq( - prism_cursor_store_node_count(joined), - 5, - "cross-store join should adopt two reachable rhs nodes then append one join node", - ) - assert_eq( - prism_cursor_authored_node_count(joined), - 5, - "joined cursor should reach both adopted branches plus the join", - ) - assert_eq(prism_cursor_tip_kind_name(joined), str("Join"), "cross-store join tip should become a join node") - assert_eq(relation_kind_name(root_rel(plan)), str("JoinRel"), "cross-store joins should still lower to JoinRel") - assert_eq( - plan_contains_relation_kind(plan, str("FilterRel")), - true, - "cross-store join lowering should preserve nested filters", - ) - assert_eq( - plan_contains_relation_kind(plan, str("ReadRel")), - true, - "cross-store join lowering should still reach both read roots", - ) + assert prism_cursors_share_store(left, right) is false, "independent roots should begin on distinct stores" + assert prism_cursor_store_node_count(joined) == 5, "cross-store join should adopt two reachable rhs nodes then append one join node" + assert prism_cursor_authored_node_count(joined) == 5, "joined cursor should reach both adopted branches plus the join" + assert prism_cursor_tip_kind_name(joined) == str("Join"), "cross-store join tip should become a join node" + assert relation_kind_name(root_rel(plan)) == str("JoinRel"), "cross-store joins should still lower to JoinRel" + assert plan_contains_relation_kind(plan, str("FilterRel")), "cross-store join lowering should preserve nested filters" + assert plan_contains_relation_kind(plan, str("ReadRel")), "cross-store join lowering should still reach both read roots" def test_prism__cross_store_join_dedups_equivalent_reachable_rhs_nodes() -> None: # -- Arrange -- - left: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).filter(always_false()) - right_base: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders_archive")) + left: PrismCursor[Order] = prism_cursor_named_table(str("orders")).filter(always_false()) + right_base: PrismCursor[Order] = prism_cursor_named_table(str("orders_archive")) right_left: PrismCursor[Order] = right_base.filter(always_false()) right_right: PrismCursor[Order] = right_base.filter(always_false()) @@ -238,42 +149,18 @@ def test_prism__cross_store_join_dedups_equivalent_reachable_rhs_nodes() -> None plan = joined.to_substrait_plan() # -- Assert -- - assert_eq( - prism_cursors_share_store(left, right_joined), - false, - "left and rhs join trees should still begin on distinct stores", - ) - assert_eq( - prism_cursor_store_node_count(joined), - 6, - "cross-store adoption should dedup equivalent rhs filter branches instead of cloning both copies", - ) - assert_eq( - prism_cursor_authored_node_count(joined), - 6, - "joined cursor should reach the left branch plus one deduped rhs branch-join subgraph", - ) - assert_eq( - relation_kind_name(root_rel(plan)), - str("JoinRel"), - "cross-store join with rhs dedup should still lower to JoinRel", - ) - assert_eq( - plan_contains_relation_kind(plan, str("FilterRel")), - true, - "deduped cross-store join lowering should still preserve nested filters", - ) - assert_eq( - plan_contains_relation_kind(plan, str("ReadRel")), - true, - "deduped cross-store join lowering should still preserve read roots", - ) + assert prism_cursors_share_store(left, right_joined) is false, "left and rhs join trees should still begin on distinct stores" + assert prism_cursor_store_node_count(joined) == 6, "cross-store adoption should dedup equivalent rhs filter branches instead of cloning both copies" + assert prism_cursor_authored_node_count(joined) == 6, "joined cursor should reach the left branch plus one deduped rhs branch-join subgraph" + assert relation_kind_name(root_rel(plan)) == str("JoinRel"), "cross-store join with rhs dedup should still lower to JoinRel" + assert plan_contains_relation_kind(plan, str("FilterRel")), "deduped cross-store join lowering should still preserve nested filters" + assert plan_contains_relation_kind(plan, str("ReadRel")), "deduped cross-store join lowering should still preserve read roots" def test_prism__cross_store_join_dedups_equivalent_rhs_multistep_branches() -> None: # -- Arrange -- - left: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).filter(always_false()) - right_base: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders_archive")) + left: PrismCursor[Order] = prism_cursor_named_table(str("orders")).filter(always_false()) + right_base: PrismCursor[Order] = prism_cursor_named_table(str("orders_archive")) right_left: PrismCursor[Order] = right_base.filter(always_false()).order_by() right_right: PrismCursor[Order] = right_base.filter(always_false()).order_by() @@ -283,241 +170,141 @@ def test_prism__cross_store_join_dedups_equivalent_rhs_multistep_branches() -> N plan = joined.to_substrait_plan() # -- Assert -- - assert_eq( - prism_cursors_share_store(left, right_joined), - false, - "final join should still cross stores for independently rooted branches", - ) - assert_eq( - prism_cursor_store_node_count(joined), - 7, - "cross-store adoption should dedup structurally equivalent multi-step rhs branches", - ) - assert_eq( - prism_cursor_authored_node_count(joined), - 7, - "joined cursor should still reach exactly one adopted copy of each rhs branch shape", - ) - assert_eq( - relation_kind_name(root_rel(plan)), - str("JoinRel"), - "deduped cross-store multistep branches should still lower to JoinRel", - ) - assert_eq( - plan_contains_relation_kind(plan, str("SortRel")), - true, - "rhs multistep branch sort should remain present after adoption", - ) - assert_eq( - plan_contains_relation_kind(plan, str("FilterRel")), - true, - "rhs multistep branch filter should remain present after adoption", - ) - assert_eq( - plan_contains_relation_kind(plan, str("ReadRel")), - true, - "deduped multistep adoption should still preserve read roots", - ) + assert prism_cursors_share_store(left, right_joined) is false, "final join should still cross stores for independently rooted branches" + assert prism_cursor_store_node_count(joined) == 7, "cross-store adoption should dedup structurally equivalent multi-step rhs branches" + assert prism_cursor_authored_node_count(joined) == 7, "joined cursor should still reach exactly one adopted copy of each rhs branch shape" + assert relation_kind_name(root_rel(plan)) == str("JoinRel"), "deduped cross-store multistep branches should still lower to JoinRel" + assert plan_contains_relation_kind(plan, str("SortRel")), "rhs multistep branch sort should remain present after adoption" + assert plan_contains_relation_kind(plan, str("FilterRel")), "rhs multistep branch filter should remain present after adoption" + assert plan_contains_relation_kind(plan, str("ReadRel")), "deduped multistep adoption should still preserve read roots" def test_prism__cursor_native_nodes_cover_current_method_surface() -> None: # -- Arrange -- - projected: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).select() - grouped: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).group_by([col("id")]) + _register_projection_test_schema(str("orders")) + projected: PrismCursor[Order] = prism_cursor_named_table(str("orders")).select() + grouped: PrismCursor[Order] = prism_cursor_named_table(str("orders")).group_by([col("id")]) # -- Act -- aggregated: PrismCursor[Order] = grouped.agg([count()]) - ordered: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).order_by() - limited: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).limit(10) - exploded: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).explode() + ordered: PrismCursor[Order] = prism_cursor_named_table(str("orders")).order_by() + limited: PrismCursor[Order] = prism_cursor_named_table(str("orders")).limit(10) + exploded: PrismCursor[Order] = prism_cursor_named_table(str("orders")).explode() # -- Assert -- - assert_eq(prism_cursor_tip_kind_name(projected), str("Project"), "select should append a native project node") - assert_eq(prism_cursor_tip_kind_name(grouped), str("GroupBy"), "group_by should append a native group node") - assert_eq(prism_cursor_tip_kind_name(aggregated), str("Aggregate"), "agg should append a native aggregate node") - assert_eq(prism_cursor_tip_kind_name(ordered), str("OrderBy"), "order_by should append a native sort node") - assert_eq(prism_cursor_tip_kind_name(limited), str("Limit"), "limit should append a native limit node") - assert_eq(prism_cursor_tip_kind_name(exploded), str("Explode"), "explode should append a native explode node") + assert prism_cursor_tip_kind_name(projected) == str("Project"), "select should append a native project node" + assert prism_cursor_tip_kind_name(grouped) == str("GroupBy"), "group_by should append a native group node" + assert prism_cursor_tip_kind_name(aggregated) == str("Aggregate"), "agg should append a native aggregate node" + assert prism_cursor_tip_kind_name(ordered) == str("OrderBy"), "order_by should append a native sort node" + assert prism_cursor_tip_kind_name(limited) == str("Limit"), "limit should append a native limit node" + assert prism_cursor_tip_kind_name(exploded) == str("Explode"), "explode should append a native explode node" def test_prism__rewrite_eliminates_filter_true_by_default() -> None: # -- Arrange -- - cursor: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).filter(always_true()) + cursor: PrismCursor[Order] = prism_cursor_named_table(str("orders")).filter(always_true()) # -- Act -- plan = cursor.to_substrait_plan() # -- Assert -- - assert_eq(prism_cursor_authored_node_count(cursor), 2, "authored history should still keep the explicit filter node") - assert_eq(prism_cursor_rewritten_node_count(cursor), 1, "rewrite should eliminate filter(true) from lowered shape") - assert_eq( - prism_cursor_rewritten_tip_kind_name(cursor), - str("ReadNamedTable"), - "rewritten tip should become the read root after eliminating filter(true)", - ) - assert_eq(prism_cursor_rewrite_applied_rule_count(cursor), 1, "one rewrite rule should be recorded for filter(true)") - assert_eq( - prism_cursor_rewrite_applied_rule_name_at(cursor, 0), - str("eliminate_filter_true"), - "rewrite trace should include the filter elimination rule", - ) - assert_eq( - relation_kind_name(root_rel(plan)), - str("ReadRel"), - "filter(true) should lower to ReadRel root after canonical rewrite", - ) + assert prism_cursor_authored_node_count(cursor) == 2, "authored history should still keep the explicit filter node" + assert prism_cursor_rewritten_node_count(cursor) == 1, "rewrite should eliminate filter(true) from lowered shape" + assert prism_cursor_rewritten_tip_kind_name(cursor) == str("ReadNamedTable"), "rewritten tip should become the read root after eliminating filter(true)" + assert prism_cursor_rewrite_applied_rule_count(cursor) == 1, "one rewrite rule should be recorded for filter(true)" + assert prism_cursor_rewrite_applied_rule_name_at(cursor, 0) == str("eliminate_filter_true"), "rewrite trace should include the filter elimination rule" + assert relation_kind_name(root_rel(plan)) == str("ReadRel"), "filter(true) should lower to ReadRel root after canonical rewrite" def test_prism__rewrite_collapses_adjacent_limits_projects_and_order_by() -> None: # -- Arrange -- - limited: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).limit(10).limit(3) - projected: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).select().select() + limited: PrismCursor[Order] = prism_cursor_named_table(str("orders")).limit(10).limit(3) + projected: PrismCursor[Order] = prism_cursor_named_table(str("orders")).select().select() # -- Act -- - ordered: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).order_by().order_by() + ordered: PrismCursor[Order] = prism_cursor_named_table(str("orders")).order_by().order_by() # -- Assert -- - assert_eq(prism_cursor_authored_node_count(limited), 3, "authored history should keep both limit nodes") - assert_eq( - prism_cursor_rewritten_node_count(limited), - 2, - "rewritten view should collapse adjacent limits to one node", - ) - assert_eq( - prism_cursor_rewrite_applied_rule_name_at(limited, 0), - str("collapse_adjacent_limit"), - "limit rewrite trace should record adjacent limit collapse", - ) - assert_eq(prism_cursor_authored_node_count(projected), 3, "authored history should keep both project nodes") - assert_eq(prism_cursor_rewritten_node_count(projected), 2, "rewritten view should collapse adjacent projects") - assert_eq( - prism_cursor_rewrite_applied_rule_name_at(projected, 0), - str("collapse_adjacent_project"), - "project rewrite trace should record adjacent project collapse", - ) - assert_eq(prism_cursor_authored_node_count(ordered), 3, "authored history should keep both order_by nodes") - assert_eq(prism_cursor_rewritten_node_count(ordered), 2, "rewritten view should collapse adjacent order_by nodes") - assert_eq( - prism_cursor_rewrite_applied_rule_name_at(ordered, 0), - str("collapse_adjacent_order_by"), - "order_by rewrite trace should record adjacent order_by collapse", - ) + assert prism_cursor_authored_node_count(limited) == 3, "authored history should keep both limit nodes" + assert prism_cursor_rewritten_node_count(limited) == 2, "rewritten view should collapse adjacent limits to one node" + assert prism_cursor_rewrite_applied_rule_name_at(limited, 0) == str("collapse_adjacent_limit"), "limit rewrite trace should record adjacent limit collapse" + assert prism_cursor_authored_node_count(projected) == 3, "authored history should keep both project nodes" + assert prism_cursor_rewritten_node_count(projected) == 2, "rewritten view should collapse adjacent projects" + assert prism_cursor_rewrite_applied_rule_name_at(projected, 0) == str("collapse_adjacent_project"), "project rewrite trace should record adjacent project collapse" + assert prism_cursor_authored_node_count(ordered) == 3, "authored history should keep both order_by nodes" + assert prism_cursor_rewritten_node_count(ordered) == 2, "rewritten view should collapse adjacent order_by nodes" + assert prism_cursor_rewrite_applied_rule_name_at(ordered, 0) == str("collapse_adjacent_order_by"), "order_by rewrite trace should record adjacent order_by collapse" def test_prism__rewrite_collapses_adjacent_aggregates_by_merging_measures() -> None: # -- Arrange -- - aggregated: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).group_by([col("id")]).agg( - [sum(col("id"))], - ).agg([count()]) + _register_projection_test_schema(str("orders")) + aggregated: PrismCursor[Order] = prism_cursor_named_table(str("orders")).group_by([col("id")]).agg([sum(col("id"))]).agg( + [count()], + ) # -- Act -- output_cols = prism_cursor_output_columns(aggregated) plan = aggregated.to_substrait_plan() # -- Assert -- - assert_eq( - prism_cursor_authored_node_count(aggregated), - 4, - "authored history should keep read, group_by, and both aggregate nodes", - ) - assert_eq( - prism_cursor_rewritten_node_count(aggregated), - 3, - "rewritten view should collapse adjacent aggregate nodes", - ) - assert_eq( - prism_cursor_rewrite_applied_rule_name_at(aggregated, 0), - str("collapse_adjacent_aggregate"), - "aggregate rewrite trace should record adjacent aggregate collapse", - ) - assert_eq(len(output_cols), 3, "merged aggregate output columns should preserve group key plus both measures") - assert_eq(output_cols[0], "id", "group key should stay first after aggregate rewrite collapse") - assert_eq(output_cols[1], "sum_id", "existing aggregate measures should survive adjacent aggregate collapse") - assert_eq(output_cols[2], "count", "later aggregate measures should append after existing measures") - assert_eq( - relation_kind_name(root_rel(plan)), - str("AggregateRel"), - "collapsed aggregate rewrite should still lower through AggregateRel", - ) + assert prism_cursor_authored_node_count(aggregated) == 4, "authored history should keep read, group_by, and both aggregate nodes" + assert prism_cursor_rewritten_node_count(aggregated) == 3, "rewritten view should collapse adjacent aggregate nodes" + assert prism_cursor_rewrite_applied_rule_name_at(aggregated, 0) == str("collapse_adjacent_aggregate"), "aggregate rewrite trace should record adjacent aggregate collapse" + assert len(output_cols) == 3, "merged aggregate output columns should preserve group key plus both measures" + assert output_cols[0] == "id", "group key should stay first after aggregate rewrite collapse" + assert output_cols[1] == "sum_id", "existing aggregate measures should survive adjacent aggregate collapse" + assert output_cols[2] == "count", "later aggregate measures should append after existing measures" + assert relation_kind_name(root_rel(plan)) == str("AggregateRel"), "collapsed aggregate rewrite should still lower through AggregateRel" def test_prism__with_column_tracks_output_columns_and_collapses_adjacent_projects() -> None: # -- Arrange -- _register_projection_test_schema("orders_projection_prism") - base: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders_projection_prism")) + base: PrismCursor[Order] = prism_cursor_named_table(str("orders_projection_prism")) # -- Act -- - projected: PrismCursor[Order] = base.with_column("double_id", mul(col("id"), int_expr(2))).with_column( + projected: PrismCursor[Order] = base.with_column("double_id", mul(col("id"), lit(2))).with_column( "triple_id", - mul(col("id"), int_expr(3)), + mul(col("id"), lit(3)), ) output_cols = prism_cursor_output_columns(projected) plan = projected.to_substrait_plan() # -- Assert -- - assert_eq(prism_cursor_tip_kind_name(projected), str("Project"), "with_column should append a native project node") - assert_eq( - len(output_cols), - 3, - "projection output columns should include the original column plus both derived names", - ) - assert_eq(output_cols[0], "id", "projection output columns should preserve existing source ordinal positions") - assert_eq(output_cols[1], "double_id", "projection output columns should preserve appended derived-column names") - assert_eq(output_cols[2], "triple_id", "multiple derived columns should preserve append order") - assert_eq( - prism_cursor_rewritten_node_count(projected), - 2, - "adjacent with_column projections should collapse into one rewritten project", - ) - assert_eq( - prism_cursor_rewrite_applied_rule_name_at(projected, 0), - str("collapse_adjacent_project"), - "projection rewrite trace should record adjacent project collapse", - ) - assert_eq( - relation_kind_name(root_rel(plan)), - str("ProjectRel"), - "with_column should lower through ProjectRel at the plan boundary", - ) + assert prism_cursor_tip_kind_name(projected) == str("Project"), "with_column should append a native project node" + assert len(output_cols) == 3, "projection output columns should include the original column plus both derived names" + assert output_cols[0] == "id", "projection output columns should preserve existing source ordinal positions" + assert output_cols[1] == "double_id", "projection output columns should preserve appended derived-column names" + assert output_cols[2] == "triple_id", "multiple derived columns should preserve append order" + assert prism_cursor_rewritten_node_count(projected) == 2, "adjacent with_column projections should collapse into one rewritten project" + assert prism_cursor_rewrite_applied_rule_name_at(projected, 0) == str("collapse_adjacent_project"), "projection rewrite trace should record adjacent project collapse" + assert relation_kind_name(root_rel(plan)) == str("ProjectRel"), "with_column should lower through ProjectRel at the plan boundary" def test_prism__rewrite_explain_origin_mapping_matches_rewritten_nodes() -> None: # -- Arrange -- - cursor: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")).filter(always_true()).limit(5).limit(4) + cursor: PrismCursor[Order] = prism_cursor_named_table(str("orders")).filter(always_true()).limit(5).limit(4) # -- Act -- rewritten_nodes = prism_cursor_rewritten_node_count(cursor) rewritten_origins = prism_cursor_rewritten_origin_count(cursor) # -- Assert -- - assert_eq( - rewritten_nodes, - rewritten_origins, - "rewritten origin mapping cardinality should match rewritten node cardinality", - ) + assert rewritten_nodes == rewritten_origins, "rewritten origin mapping cardinality should match rewritten node cardinality" def test_prism__cursor_methods_match_apply_helpers() -> None: # -- Arrange -- - base: PrismCursor[Order] = prism_cursor_named_table[Order](str("orders")) + base: PrismCursor[Order] = prism_cursor_named_table(str("orders")) # -- Act -- via_methods = base.filter(always_false()).select().limit(3) via_helpers = prism_cursor_apply_limit(prism_cursor_apply_select(prism_cursor_apply_filter(base, always_false())), 3) # -- Assert -- - assert_eq( - prism_cursor_tip_kind_name(via_methods), - prism_cursor_tip_kind_name(via_helpers), - "method and helper paths should produce the same tip kind", - ) - assert_eq( - prism_cursor_authored_node_count(via_methods), - prism_cursor_authored_node_count(via_helpers), - "method and helper paths should produce equivalent authored history shape", - ) - assert_eq( - relation_kind_name(root_rel(via_methods.to_substrait_plan())), - relation_kind_name(root_rel(via_helpers.to_substrait_plan())), - "method and helper paths should lower to equivalent root relation kinds", - ) + assert prism_cursor_tip_kind_name(via_methods) == prism_cursor_tip_kind_name(via_helpers), "method and helper paths should produce the same tip kind" + assert prism_cursor_authored_node_count(via_methods) == prism_cursor_authored_node_count(via_helpers), "method and helper paths should produce equivalent authored history shape" + assert relation_kind_name(root_rel(via_methods.to_substrait_plan())) == relation_kind_name( + root_rel(via_helpers.to_substrait_plan()), + ), "method and helper paths should lower to equivalent root relation kinds" diff --git a/tests/test_session.incn b/tests/test_session.incn index 0507c32..176d861 100644 --- a/tests/test_session.incn +++ b/tests/test_session.incn @@ -1,10 +1,10 @@ """RFC 004 Phase 2: Session construction and backend seam.""" -from std.testing import parametrize +from std.testing import assert_is_err, assert_is_ok, fail, parametrize from rust::std::path import Path from backends import DataFusion, csv_source, parquet_source -from dataset import lazy_frame_named_table -from session import Session +from dataset import LazyFrame, lazy_frame_named_table +from session import Session, SessionError, SessionErrorKind from substrait.inspect import read_kind_name, root_rel @@ -70,7 +70,7 @@ def test_session__public_types_construct_locally() -> None: # -- Assert -- # Compile-shape test: construction and generic touch calls typecheck for public Session surface. - assert true, "public Session surface types should be constructible in ordinary local code" + pass def test_session__register_tracks_named_sources() -> None: @@ -79,16 +79,11 @@ def test_session__register_tracks_named_sources() -> None: mut session = Session.default() # -- Act -- - # register two sources with different logical names and kinds, asserting that no unexpected errors are raised - match session.register("orders", csv_source(ORDERS_CSV_FIXTURE)): - Ok(_) => pass - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - - match session.register("orders_archive", parquet_source(ORDERS_PARQUET_FIXTURE)): - Ok(_) => pass - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + assert_is_ok(session.register("orders", csv_source(ORDERS_CSV_FIXTURE)), "orders registration should succeed") + assert_is_ok( + session.register("orders_archive", parquet_source(ORDERS_PARQUET_FIXTURE)), + "orders_archive registration should succeed", + ) # -- Assert -- assert session.registration_count() == 2, "register should retain two source bindings" @@ -105,17 +100,14 @@ def test_session__table_returns_prism_backed_named_table_plan() -> None: """ # -- Arrange -- mut session = Session.default() - match session.register("orders", csv_source(ORDERS_CSV_FIXTURE)): - Ok(_) => pass - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + assert_is_ok(session.register("orders", csv_source(ORDERS_CSV_FIXTURE)), "orders registration should succeed") - # -- Act & Assert -- - match session.table[Order]("orders"): - Ok(lazy) => - assert read_kind_name(root_rel(lazy.to_substrait_plan())) == "NamedTable", "table lookup should create a logical named-table read" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok(session.table("orders"), "registered table lookup should succeed") + + # -- Assert -- + assert read_kind_name(root_rel(lazy.to_substrait_plan())) == "NamedTable", "table lookup should create a logical named-table read" + assert len(lazy.schema().declared_columns) > 0, "registered CSV tables should bind declared columns for downstream execution" def test_session__read_csv_registers_source_and_returns_named_table_plan() -> None: @@ -126,42 +118,45 @@ def test_session__read_csv_registers_source_and_returns_named_table_plan() -> No # -- Arrange -- mut session = Session.default() - # -- Act & Assert -- - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(lazy) => - assert session.registration_count() == 1, "read_csv should register one CSV source" - assert session.has_registration("orders") is true, "read_csv should register the provided logical name" - assert session.registered_source_kind("orders") == "csv", "read_csv should register CSV source kind" - assert read_kind_name(root_rel(lazy.to_substrait_plan())) == "NamedTable", "read_csv should still lower through a logical named-table root" - declared_cols = lazy.declared_columns() - planned_cols = lazy.planned_columns() - assert len(declared_cols) > 0, "read_csv should infer declared columns from the CSV header" - assert declared_cols[0] == "id", "declared columns should preserve source column order" - assert len(planned_cols) == len(declared_cols), "lazy planned_columns() should currently mirror declared columns" - assert len(lazy.columns()) == len(planned_cols), "lazy columns() should mirror planned columns in deferred mode" - lazy_schema = lazy.schema() - assert len(lazy_schema.declared_columns) == len(declared_cols), "lazy schema() should surface declared columns" - assert len(lazy_schema.planned_columns) == len(planned_cols), "lazy schema() should surface planned columns" - assert len(lazy_schema.resolved_columns) == 0, "lazy schema() should keep resolved columns empty pre-collect" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok( + session.read_csv("orders", ORDERS_CSV_FIXTURE), + "orders CSV fixture should load", + ) + declared_cols = lazy.declared_columns() + planned_cols = lazy.planned_columns() + lazy_schema = lazy.schema() + + # -- Assert -- + assert session.registration_count() == 1, "read_csv should register one CSV source" + assert session.has_registration("orders") is true, "read_csv should register the provided logical name" + assert session.registered_source_kind("orders") == "csv", "read_csv should register CSV source kind" + assert read_kind_name(root_rel(lazy.to_substrait_plan())) == "NamedTable", "read_csv should still lower through a logical named-table root" + assert len(declared_cols) > 0, "read_csv should infer declared columns from the CSV header" + assert declared_cols[0] == "id", "declared columns should preserve source column order" + assert len(planned_cols) == len(declared_cols), "lazy planned_columns() should currently mirror declared columns" + assert len(lazy.columns()) == len(planned_cols), "lazy columns() should mirror planned columns in deferred mode" + assert len(lazy_schema.declared_columns) == len(declared_cols), "lazy schema() should surface declared columns" + assert len(lazy_schema.planned_columns) == len(planned_cols), "lazy schema() should surface planned columns" + assert len(lazy_schema.resolved_columns) == 0, "lazy schema() should keep resolved columns empty pre-collect" def test_session__read_csv_rejects_duplicate_logical_names() -> None: """read_csv should reject duplicate logical registration names with a typed duplicate_registration error.""" # -- Arrange -- mut session = Session.default() - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(_) => pass - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + lazy: LazyFrame[Order] = assert_is_ok( + session.read_csv("orders", ORDERS_CSV_FIXTURE), + "initial orders CSV read should succeed", + ) + _touch(lazy) + + # -- Act -- + duplicate_result: Result[LazyFrame[Order], SessionError] = session.read_csv("orders", ORDERS_CSV_FIXTURE) + err = assert_is_err(duplicate_result, "duplicate read_csv logical name should fail") - # -- Act & Assert -- - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(_) => - assert false, "expected duplicate read_csv logical name to fail" - Err(err) => - assert err.kind == "duplicate_registration", "read_csv should surface duplicate-registration errors" + # -- Assert -- + assert err.kind == SessionErrorKind.DuplicateRegistration, "read_csv should surface duplicate-registration errors" def test_session__read_parquet_registers_source_and_executes_named_table_plan() -> None: @@ -171,27 +166,21 @@ def test_session__read_parquet_registers_source_and_executes_named_table_plan() parquet_uri = _target_uri("session_read_parquet_input") # -- Act -- - match session.read_csv[Order]("orders_csv", ORDERS_CSV_FIXTURE): - Ok(csv_lazy) => - match session.write_parquet(csv_lazy, parquet_uri): - Ok(_) => pass - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + csv_lazy: LazyFrame[Order] = assert_is_ok( + session.read_csv("orders_csv", ORDERS_CSV_FIXTURE), + "orders CSV fixture should load", + ) + assert_is_ok(session.write_parquet(csv_lazy, parquet_uri), "CSV fixture should be written as Parquet") + lazy: LazyFrame[Order] = assert_is_ok( + session.read_parquet("orders_parquet", parquet_uri), + "generated Parquet fixture should load", + ) # -- Assert -- - match session.read_parquet[Order]("orders_parquet", parquet_uri): - Ok(lazy) => - assert session.has_registration("orders_parquet") is true, "read_parquet should register the provided logical name" - assert session.registered_source_kind("orders_parquet") == "parquet", "read_parquet should register Parquet source kind" - assert read_kind_name(root_rel(lazy.to_substrait_plan())) == "NamedTable", "read_parquet should lower through a logical named-table root" - match session.execute(lazy): - Ok(_) => pass - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + assert session.has_registration("orders_parquet") is true, "read_parquet should register the provided logical name" + assert session.registered_source_kind("orders_parquet") == "parquet", "read_parquet should register Parquet source kind" + assert read_kind_name(root_rel(lazy.to_substrait_plan())) == "NamedTable", "read_parquet should lower through a logical named-table root" + assert_is_ok(session.execute(lazy), "generated Parquet fixture should execute") def test_session__read_arrow_registers_source_and_returns_named_table_plan() -> None: @@ -199,15 +188,17 @@ def test_session__read_arrow_registers_source_and_returns_named_table_plan() -> # -- Arrange -- mut session = Session.default() - # -- Act & Assert -- - match session.read_arrow[Order]("orders_arrow", ORDERS_ARROW_FIXTURE): - Ok(lazy) => - assert session.registration_count() == 1, "read_arrow should register one Arrow source" - assert session.has_registration("orders_arrow") is true, "read_arrow should register the provided logical name" - assert session.registered_source_kind("orders_arrow") == "arrow", "read_arrow should register Arrow source kind" - assert read_kind_name(root_rel(lazy.to_substrait_plan())) == "NamedTable", "read_arrow should lower through a logical named-table root" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok( + session.read_arrow("orders_arrow", ORDERS_ARROW_FIXTURE), + "Arrow fixture should load", + ) + + # -- Assert -- + assert session.registration_count() == 1, "read_arrow should register one Arrow source" + assert session.has_registration("orders_arrow") is true, "read_arrow should register the provided logical name" + assert session.registered_source_kind("orders_arrow") == "arrow", "read_arrow should register Arrow source kind" + assert read_kind_name(root_rel(lazy.to_substrait_plan())) == "NamedTable", "read_arrow should lower through a logical named-table root" def test_session__execute_runs_plan_and_preserves_carrier_shape() -> None: @@ -215,16 +206,15 @@ def test_session__execute_runs_plan_and_preserves_carrier_shape() -> None: # -- Arrange -- mut session = Session.default() - # -- Act & Assert -- - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(lazy) => - match session.execute(lazy): - Ok(executed) => - assert read_kind_name(root_rel(executed.to_substrait_plan())) == "NamedTable", "execute should preserve deferred named-table carrier shape on success" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok( + session.read_csv("orders", ORDERS_CSV_FIXTURE), + "orders CSV fixture should load", + ) + executed = assert_is_ok(session.execute(lazy), "named-table execution should succeed") + + # -- Assert -- + assert read_kind_name(root_rel(executed.to_substrait_plan())) == "NamedTable", "execute should preserve deferred named-table carrier shape on success" def test_session__collect_returns_materialized_dataframe() -> None: @@ -232,44 +222,42 @@ def test_session__collect_returns_materialized_dataframe() -> None: # -- Arrange -- mut session = Session.default() - # -- Act & Assert -- - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(lazy) => - match session.collect(lazy): - Ok(df) => - assert read_kind_name(root_rel(df.to_substrait_plan())) == "NamedTable", "collect should return a typed DataFrame carrier" - assert len(df.preview_text()) > 0, "collect should populate a non-empty preview render for display/debug paths" - assert df.row_count() > 0, "collect should record structural row-count metadata" - declared_cols = df.declared_columns() - planned_cols = df.planned_columns() - resolved_cols = df.resolved_columns() - assert len(declared_cols) > 0, "collect should preserve declared schema columns" - assert len(planned_cols) == len(declared_cols), "DataFrame planned_columns() should currently mirror declared columns" - assert len(resolved_cols) > 0, "collect should expose resolved columns from structured materialization metadata" - assert len(df.columns()) == len(resolved_cols), "DataFrame columns() should prefer resolved columns when available" - assert resolved_cols[0] == "id", "resolved columns should preserve output column order" - schema = df.schema() - assert len(schema.declared_columns) == len(declared_cols), "schema() should surface declared columns" - assert len(schema.planned_columns) == len(planned_cols), "schema() should surface planned columns" - assert len(schema.resolved_columns) == len(resolved_cols), "schema() should surface resolved columns" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok( + session.read_csv("orders", ORDERS_CSV_FIXTURE), + "orders CSV fixture should load", + ) + df = assert_is_ok(session.collect(lazy), "orders collect should succeed") + declared_cols = df.declared_columns() + planned_cols = df.planned_columns() + resolved_cols = df.resolved_columns() + schema = df.schema() + + # -- Assert -- + assert read_kind_name(root_rel(df.to_substrait_plan())) == "NamedTable", "collect should return a typed DataFrame carrier" + assert len(df.preview_text()) > 0, "collect should populate a non-empty preview render for display/debug paths" + assert df.row_count() > 0, "collect should record structural row-count metadata" + assert len(declared_cols) > 0, "collect should preserve declared schema columns" + assert len(planned_cols) == len(declared_cols), "DataFrame planned_columns() should currently mirror declared columns" + assert len(resolved_cols) > 0, "collect should expose resolved columns from structured materialization metadata" + assert len(df.columns()) == len(resolved_cols), "DataFrame columns() should prefer resolved columns when available" + assert resolved_cols[0] == "id", "resolved columns should preserve output column order" + assert len(schema.declared_columns) == len(declared_cols), "schema() should surface declared columns" + assert len(schema.planned_columns) == len(planned_cols), "schema() should surface planned columns" + assert len(schema.resolved_columns) == len(resolved_cols), "schema() should surface resolved columns" def test_session__lazy_collect_fails_without_active_session() -> None: """lazy.collect should fail with no_active_session when no Session has been activated.""" # -- Arrange -- Session.clear_active_session() - lazy = lazy_frame_named_table[Order]("orders_never_registered_in_active_session_test") + lazy: LazyFrame[Order] = lazy_frame_named_table("orders_never_registered_in_active_session_test") - # -- Act & Assert -- - match lazy.collect(): - Ok(_) => - assert false, "expected lazy.collect() to fail without an active Session" - Err(err) => - assert err.kind == "no_active_session", "lazy.collect should fail with no_active_session when no session is active" + # -- Act -- + err = assert_is_err(lazy.collect(), "lazy.collect should fail without an active Session") + + # -- Assert -- + assert err.kind == SessionErrorKind.NoActiveSession, "lazy.collect should fail with no_active_session when no session is active" def test_session__activate_enables_lazy_collect_with_live_session() -> None: @@ -278,17 +266,16 @@ def test_session__activate_enables_lazy_collect_with_live_session() -> None: Session.clear_active_session() mut session = Session.default() - # -- Act & Assert -- - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(lazy) => - session.activate() - match lazy.collect(): - Ok(df) => - assert read_kind_name(root_rel(df.to_substrait_plan())) == "NamedTable", "lazy.collect should delegate through the active session" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok( + session.read_csv("orders", ORDERS_CSV_FIXTURE), + "orders CSV fixture should load", + ) + session.activate() + df = assert_is_ok(lazy.collect(), "lazy.collect should delegate through the active session") + + # -- Assert -- + assert read_kind_name(root_rel(df.to_substrait_plan())) == "NamedTable", "lazy.collect should delegate through the active session" def test_session__lazy_collect_uses_currently_active_session() -> None: @@ -298,18 +285,17 @@ def test_session__lazy_collect_uses_currently_active_session() -> None: mut source_session = Session.default() other_session = Session.default() - # -- Act & Assert -- - match source_session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(lazy) => - source_session.activate() - other_session.activate() - match lazy.collect(): - Ok(_) => - assert false, "expected lazy.collect to resolve against the currently active session" - Err(err) => - assert err.kind == "unknown_table", "active-session switching should route collect through the latest active session" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok( + source_session.read_csv("orders", ORDERS_CSV_FIXTURE), + "orders CSV fixture should load", + ) + source_session.activate() + other_session.activate() + err = assert_is_err(lazy.collect(), "lazy.collect should resolve against the currently active session") + + # -- Assert -- + assert err.kind == SessionErrorKind.UnknownTable, "active-session switching should route collect through the latest active session" def test_session__clear_active_session_forces_no_active_session() -> None: @@ -317,18 +303,17 @@ def test_session__clear_active_session_forces_no_active_session() -> None: # -- Arrange -- mut session = Session.default() - # -- Act & Assert -- - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(lazy) => - session.activate() - Session.clear_active_session() - match lazy.collect(): - Ok(_) => - assert false, "expected clear_active_session to remove active-session binding" - Err(err) => - assert err.kind == "no_active_session", "clear_active_session should force no_active_session for lazy.collect" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok( + session.read_csv("orders", ORDERS_CSV_FIXTURE), + "orders CSV fixture should load", + ) + session.activate() + Session.clear_active_session() + err = assert_is_err(lazy.collect(), "clear_active_session should remove active-session binding") + + # -- Assert -- + assert err.kind == SessionErrorKind.NoActiveSession, "clear_active_session should force no_active_session for lazy.collect" def test_session__read_transform_collect_write_csv_end_to_end() -> None: @@ -337,23 +322,19 @@ def test_session__read_transform_collect_write_csv_end_to_end() -> None: mut session = Session.default() output_uri = _target_uri("session_collect_then_write_output.csv") - # -- Act & Assert -- - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(lazy) => - transformed = lazy.limit(2) - match session.collect(transformed.clone()): - Ok(df) => - assert len(df.preview_text()) > 0, "collect should materialize preview text before downstream sink operations" - assert df.row_count() > 0, "collect should record row-count metadata before downstream sink operations" - match session.write_csv(transformed, output_uri): - Ok(_) => - assert Path.new(output_uri).exists() is true, "write_csv should still succeed on the Session path after collect" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok( + session.read_csv("orders", ORDERS_CSV_FIXTURE), + "orders CSV fixture should load", + ) + transformed = lazy.limit(2) + df = assert_is_ok(session.collect(transformed.clone()), "transformed collect should succeed") + assert_is_ok(session.write_csv(transformed, output_uri), "write_csv should succeed after collect") + + # -- Assert -- + assert len(df.preview_text()) > 0, "collect should materialize preview text before downstream sink operations" + assert df.row_count() > 0, "collect should record row-count metadata before downstream sink operations" + assert Path.new(output_uri).exists() is true, "write_csv should still succeed on the Session path after collect" @parametrize("sink_kind, output_name", [("csv", "session_write_csv_output.csv"), ("parquet", "session_write_parquet_output")]) @@ -363,25 +344,18 @@ def test_session__write_succeeds_for_deferred_lazyframe(sink_kind: str, output_n mut session = Session.default() output_uri = _target_uri(output_name) - # -- Act & Assert -- - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(lazy) => - if sink_kind == "csv": - match session.write_csv(lazy, output_uri): - Ok(_) => - assert true, "write_csv should complete successfully" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - elif sink_kind == "parquet": - match session.write_parquet(lazy, output_uri): - Ok(_) => - assert true, "write_parquet should complete successfully" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - else: - assert false, f"unexpected sink kind: {sink_kind}" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok( + session.read_csv("orders", ORDERS_CSV_FIXTURE), + "orders CSV fixture should load", + ) + match sink_kind: + "csv" => assert_is_ok(session.write_csv(lazy, output_uri), "CSV write should succeed") + "parquet" => assert_is_ok(session.write_parquet(lazy, output_uri), "Parquet write should succeed") + _ => fail(f"unexpected sink kind: {sink_kind}") + + # -- Assert -- + assert Path.new(output_uri).exists() is true, "write should create the requested output artifact" def test_session__write_rejects_empty_sink_uri() -> None: @@ -389,33 +363,31 @@ def test_session__write_rejects_empty_sink_uri() -> None: # -- Arrange -- mut session = Session.default() - # -- Act & Assert -- - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(lazy) => - match session.write_csv(lazy, ""): - Ok(_) => - assert false, "expected empty CSV sink URI to fail" - Err(err) => - assert err.kind == "invalid_sink", "write_csv should reject empty sink URIs" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[Order] = assert_is_ok( + session.read_csv("orders", ORDERS_CSV_FIXTURE), + "orders CSV fixture should load", + ) + err = assert_is_err(session.write_csv(lazy, ""), "empty CSV sink URI should fail") + + # -- Assert -- + assert err.kind == SessionErrorKind.InvalidSink, "write_csv should reject empty sink URIs" def test_session__register_rejects_duplicates_with_typed_error() -> None: """register should reject duplicate logical names with duplicate_registration.""" # -- Arrange -- mut session = Session.default() - match session.register("orders", csv_source(ORDERS_CSV_FIXTURE)): - Ok(_) => pass - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + assert_is_ok( + session.register("orders", csv_source(ORDERS_CSV_FIXTURE)), + "initial orders registration should succeed", + ) - # -- Act & Assert -- - match session.register("orders", csv_source(ORDERS_CSV_FIXTURE)): - Ok(_) => - assert false, "expected duplicate registration to fail" - Err(err) => - assert err.kind == "duplicate_registration", "duplicate names should surface a typed registration error" + # -- Act -- + err = assert_is_err(session.register("orders", csv_source(ORDERS_CSV_FIXTURE)), "duplicate registration should fail") + + # -- Assert -- + assert err.kind == SessionErrorKind.DuplicateRegistration, "duplicate names should surface a typed registration error" def test_session__table_reports_unknown_name_with_typed_error() -> None: @@ -423,12 +395,12 @@ def test_session__table_reports_unknown_name_with_typed_error() -> None: # -- Arrange -- session = Session.default() - # -- Act & Assert -- - match session.table[Order]("missing"): - Ok(_) => - assert false, "expected unknown table lookup to fail" - Err(err) => - assert err.kind == "unknown_table", "unknown table lookup should surface a typed binding error" + # -- Act -- + table_result: Result[LazyFrame[Order], SessionError] = session.table("missing") + err = assert_is_err(table_result, "unknown table lookup should fail") + + # -- Assert -- + assert err.kind == SessionErrorKind.UnknownTable, "unknown table lookup should surface a typed binding error" def test_session__empty_registration_input_is_rejected() -> None: @@ -436,15 +408,13 @@ def test_session__empty_registration_input_is_rejected() -> None: # -- Arrange -- mut session = Session.default() - # -- Act & Assert -- - match session.register("", csv_source(ORDERS_CSV_FIXTURE)): - Ok(_) => - assert false, "expected empty logical name to fail" - Err(err) => - assert err.kind == "invalid_registration", "empty logical names should be rejected" - - match session.register("orders", csv_source("")): - Ok(_) => - assert false, "expected empty source uri to fail" - Err(err) => - assert err.kind == "invalid_registration", "empty source uris should be rejected" + # -- Act -- + empty_name_err = assert_is_err( + session.register("", csv_source(ORDERS_CSV_FIXTURE)), + "empty logical name should fail", + ) + empty_uri_err = assert_is_err(session.register("orders", csv_source("")), "empty source URI should fail") + + # -- Assert -- + assert empty_name_err.kind == SessionErrorKind.InvalidRegistration, "empty logical names should be rejected" + assert empty_uri_err.kind == SessionErrorKind.InvalidRegistration, "empty source uris should be rejected" diff --git a/tests/test_session_aggregates.incn b/tests/test_session_aggregates.incn index d081fe1..f40c31e 100644 --- a/tests/test_session_aggregates.incn +++ b/tests/test_session_aggregates.incn @@ -1,8 +1,9 @@ """End-to-end Session aggregate execution tests over the DataFusion backend.""" -from std.testing import assert_eq from functions import col, count, sum +from dataset import LazyFrame from session import Session +from std.testing import assert_is_ok @derive(Clone) @@ -25,40 +26,24 @@ def test_session_aggregates__grouped_collect_executes_sum_and_count() -> None: mut session = Session.default() # -- Act -- - match session.read_csv[AggregateOrder]("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE): - Ok(lazy) => - grouped = lazy.group_by([col("customer_id")]).agg([sum(col("amount")), count()]) - match session.collect(grouped): - Ok(df) => - payload = df.preview_text() - assert_eq(df.row_count(), 2, "grouped aggregate collect should produce one row per customer") - resolved = df.resolved_columns() - # -- Assert -- - assert_eq(len(resolved), 3, "grouped aggregate output should contain one key plus two measures") - assert_eq(payload.contains("A"), true, "grouped aggregate output should contain customer A") - assert_eq(payload.contains("B"), true, "grouped aggregate output should contain customer B") - assert_eq( - payload.contains("25"), - true, - "grouped aggregate output should contain the summed amount for customer A", - ) - assert_eq( - payload.contains("7"), - true, - "grouped aggregate output should contain the summed amount for customer B", - ) - assert_eq( - payload.contains("2"), - true, - "grouped aggregate output should contain the row-count for customer A", - ) - assert_eq( - payload.contains("1"), - true, - "grouped aggregate output should contain the row-count for customer B", - ) - Err(err) => assert_eq(true, false, err.error_message()) - Err(err) => assert_eq(true, false, err.error_message()) + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + grouped = lazy.group_by([col("customer_id")]).agg([sum(col("amount")), count()]) + df = assert_is_ok(session.collect(grouped), "grouped aggregate collect should execute") + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 2, "grouped aggregate collect should produce one row per customer" + assert len(resolved) == 3, "grouped aggregate output should contain one key plus two measures" + assert payload.contains("A"), "grouped aggregate output should contain customer A" + assert payload.contains("B"), "grouped aggregate output should contain customer B" + assert payload.contains("25"), "grouped aggregate output should contain the summed amount for customer A" + assert payload.contains("7"), "grouped aggregate output should contain the summed amount for customer B" + assert payload.contains("2"), "grouped aggregate output should contain the row-count for customer A" + assert payload.contains("1"), "grouped aggregate output should contain the row-count for customer B" def test_session_aggregates__global_collect_executes_count() -> None: @@ -66,23 +51,16 @@ def test_session_aggregates__global_collect_executes_count() -> None: mut session = Session.default() # -- Act -- - match session.read_csv[Order]("orders", ORDERS_CSV_FIXTURE): - Ok(lazy) => - aggregated = lazy.agg([count()]) - match session.collect(aggregated): - Ok(df) => - payload = df.preview_text() - assert_eq(df.row_count(), 1, "global count aggregate should produce one output row") - resolved = df.resolved_columns() - # -- Assert -- - assert_eq(len(resolved), 1, "global count aggregate should return one measure column") - assert_eq( - payload.contains("3"), - true, - "global count aggregate should report the three rows in orders.csv", - ) - Err(err) => assert_eq(true, false, err.error_message()) - Err(err) => assert_eq(true, false, err.error_message()) + lazy: LazyFrame[Order] = assert_is_ok(session.read_csv("orders", ORDERS_CSV_FIXTURE), "orders fixture should load") + aggregated = lazy.agg([count()]) + df = assert_is_ok(session.collect(aggregated), "global count aggregate collect should execute") + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 1, "global count aggregate should produce one output row" + assert len(resolved) == 1, "global count aggregate should return one measure column" + assert payload.contains("3"), "global count aggregate should report the three rows in orders.csv" def test_session_aggregates__adjacent_agg_calls_merge_measures() -> None: @@ -90,42 +68,22 @@ def test_session_aggregates__adjacent_agg_calls_merge_measures() -> None: mut session = Session.default() # -- Act -- - match session.read_csv[AggregateOrder]("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE): - Ok(lazy) => - grouped = lazy.group_by([col("customer_id")]).agg([sum(col("amount"))]).agg([count()]) - match session.collect(grouped): - Ok(df) => - payload = df.preview_text() - assert_eq(df.row_count(), 2, "adjacent aggregate merge should still produce one row per group") - resolved = df.resolved_columns() - # -- Assert -- - assert_eq( - len(resolved), - 3, - "adjacent agg calls should preserve both the existing and appended measure columns", - ) - assert_eq(resolved[0], "customer_id", "group key should remain the first output column") - assert_eq(resolved[1], "sum_amount", "existing aggregate measure should survive a later agg append") - assert_eq(resolved[2], "count", "later aggregate measure should append after existing measures") - assert_eq( - payload.contains("25"), - true, - "adjacent agg merge should retain the summed amount for customer A", - ) - assert_eq( - payload.contains("7"), - true, - "adjacent agg merge should retain the summed amount for customer B", - ) - assert_eq( - payload.contains("2"), - true, - "adjacent agg merge should still include the later count measure for customer A", - ) - assert_eq( - payload.contains("1"), - true, - "adjacent agg merge should still include the later count measure for customer B", - ) - Err(err) => assert_eq(true, false, err.error_message()) - Err(err) => assert_eq(true, false, err.error_message()) + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + grouped = lazy.group_by([col("customer_id")]).agg([sum(col("amount"))]).agg([count()]) + df = assert_is_ok(session.collect(grouped), "adjacent aggregate collect should execute") + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 2, "adjacent aggregate merge should still produce one row per group" + assert len(resolved) == 3, "adjacent agg calls should preserve both the existing and appended measure columns" + assert resolved[0] == "customer_id", "group key should remain the first output column" + assert resolved[1] == "sum_amount", "existing aggregate measure should survive a later agg append" + assert resolved[2] == "count", "later aggregate measure should append after existing measures" + assert payload.contains("25"), "adjacent agg merge should retain the summed amount for customer A" + assert payload.contains("7"), "adjacent agg merge should retain the summed amount for customer B" + assert payload.contains("2"), "adjacent agg merge should still include the later count measure for customer A" + assert payload.contains("1"), "adjacent agg merge should still include the later count measure for customer B" diff --git a/tests/test_session_backend_datafusion.incn b/tests/test_session_backend_datafusion.incn index f4f2b60..2f2f2a3 100644 --- a/tests/test_session_backend_datafusion.incn +++ b/tests/test_session_backend_datafusion.incn @@ -1,6 +1,7 @@ -from std.testing import assert_eq from backends import csv_source +from dataset import LazyFrame from session import Session +from std.testing import assert_is_ok from rust::std::path import Path @@ -12,37 +13,27 @@ pub model Order: def test_session_backend_datafusion__registered_named_table_executes_via_substrait() -> None: # -- Arrange -- mut session = Session.default() - fixture_uri = "../../../tests/fixtures/orders.csv" - match session.register("orders", csv_source(fixture_uri)): - Ok(_) => pass - Err(err) => assert_eq(true, false, err.error_message()) + fixture_uri = "tests/fixtures/orders.csv" + assert_is_ok(session.register("orders", csv_source(fixture_uri)), "fixture registration should succeed") # -- Act -- - match session.table[Order]("orders"): - Ok(lazy) => - match session.execute(lazy): - Ok(_) => - # -- Assert -- - pass - Err(err) => assert_eq(true, false, err.error_message()) - Err(err) => assert_eq(true, false, err.error_message()) + lazy: LazyFrame[Order] = assert_is_ok(session.table("orders"), "registered table lookup should succeed") + execute_result = session.execute(lazy) + + # -- Assert -- + assert_is_ok(execute_result, "registered named-table execution should succeed") def test_session_backend_datafusion__session_write_csv_routes_through_execution_path() -> None: # -- Arrange -- mut session = Session.default() - output_uri = "../../../tests/target/session_backend_datafusion_output.csv" - fixture_uri = "../../../tests/fixtures/orders.csv" - match session.register("orders", csv_source(fixture_uri)): - Ok(_) => pass - Err(err) => assert_eq(true, false, err.error_message()) + output_uri = "tests/target/session_backend_datafusion_output.csv" + fixture_uri = "tests/fixtures/orders.csv" + assert_is_ok(session.register("orders", csv_source(fixture_uri)), "fixture registration should succeed") # -- Act -- - match session.table[Order]("orders"): - Ok(lazy) => - match session.write_csv(lazy, output_uri): - Ok(_) => - # -- Assert -- - assert_eq(Path.new(output_uri).exists(), true, "session.write_csv should produce an output artifact") - Err(err) => assert_eq(true, false, err.error_message()) - Err(err) => assert_eq(true, false, err.error_message()) + lazy: LazyFrame[Order] = assert_is_ok(session.table("orders"), "registered table lookup should succeed") + assert_is_ok(session.write_csv(lazy, output_uri), "session.write_csv should route through execution") + + # -- Assert -- + assert Path.new(output_uri).exists(), "session.write_csv should produce an output artifact" diff --git a/tests/test_session_filters.incn b/tests/test_session_filters.incn index 866b10f..73fc626 100644 --- a/tests/test_session_filters.incn +++ b/tests/test_session_filters.incn @@ -1,8 +1,9 @@ """End-to-end Session filter execution tests over the DataFusion backend.""" -from std.testing import assert_eq -from functions import col, eq, gt, int_lit, str_lit +from functions import col, eq, gt, lit +from dataset import LazyFrame from session import Session +from std.testing import assert_is_ok @derive(Clone) @@ -25,21 +26,20 @@ def test_session_filters__collect_executes_gt_predicate() -> None: # -- Arrange -- mut session = Session.default() - # -- Act & Assert -- - match session.read_csv[OrderLine]("order_lines", ORDER_LINES_CSV_FIXTURE): - Ok(lazy) => - filtered = lazy.filter(gt(col("qty"), int_lit(2))) - match session.collect(filtered): - Ok(df) => - payload = df.preview_text() - assert_eq(df.row_count(), 2, "qty > 2 should materialize two matching rows") - assert payload.contains("SKU-LAMP-11"), "qty > 2 should keep the row with qty 4" - assert payload.contains("SKU-MAT-07"), "qty > 2 should keep the row with qty 3" - assert payload.contains("SKU-CHAIR-01") is false, "qty > 2 should exclude rows with qty 2" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[OrderLine] = assert_is_ok( + session.read_csv("order_lines", ORDER_LINES_CSV_FIXTURE), + "order lines fixture should load", + ) + filtered = lazy.filter(gt(col("qty"), lit(2))) + df = assert_is_ok(session.collect(filtered), "gt filter collect should execute") + payload = df.preview_text() + + # -- Assert -- + assert df.row_count() == 2, "qty > 2 should materialize two matching rows" + assert payload.contains("SKU-LAMP-11"), "qty > 2 should keep the row with qty 4" + assert payload.contains("SKU-MAT-07"), "qty > 2 should keep the row with qty 3" + assert payload.contains("SKU-CHAIR-01") is false, "qty > 2 should exclude rows with qty 2" def test_session_filters__collect_executes_eq_predicate() -> None: @@ -47,20 +47,19 @@ def test_session_filters__collect_executes_eq_predicate() -> None: # -- Arrange -- mut session = Session.default() - # -- Act & Assert -- - match session.read_csv[OrderLine]("order_lines", ORDER_LINES_CSV_FIXTURE): - Ok(lazy) => - filtered = lazy.filter(eq(col("status"), str_lit("open"))) - match session.collect(filtered): - Ok(df) => - payload = df.preview_text() - assert_eq(df.row_count(), 3, "status == open should materialize the three open rows") - assert payload.contains("SKU-CHAIR-01"), "status == open should keep open rows" - assert payload.contains("SKU-DESK-02"), "status == open should keep all open rows" - assert payload.contains("SKU-MAT-07"), "status == open should keep the third open row" - assert payload.contains("SKU-LAMP-11") is false, "status == open should exclude closed rows" - assert payload.contains("SKU-MONITOR-24") is false, "status == open should exclude cancelled rows" - Err(err) => - assert err.kind == "__unexpected__", err.error_message() - Err(err) => - assert err.kind == "__unexpected__", err.error_message() + # -- Act -- + lazy: LazyFrame[OrderLine] = assert_is_ok( + session.read_csv("order_lines", ORDER_LINES_CSV_FIXTURE), + "order lines fixture should load", + ) + filtered = lazy.filter(eq(col("status"), lit("open"))) + df = assert_is_ok(session.collect(filtered), "eq filter collect should execute") + payload = df.preview_text() + + # -- Assert -- + assert df.row_count() == 3, "status == open should materialize the three open rows" + assert payload.contains("SKU-CHAIR-01"), "status == open should keep open rows" + assert payload.contains("SKU-DESK-02"), "status == open should keep all open rows" + assert payload.contains("SKU-MAT-07"), "status == open should keep the third open row" + assert payload.contains("SKU-LAMP-11") is false, "status == open should exclude closed rows" + assert payload.contains("SKU-MONITOR-24") is false, "status == open should exclude cancelled rows" diff --git a/tests/test_session_projection.incn b/tests/test_session_projection.incn index 8e37202..d4ee9ee 100644 --- a/tests/test_session_projection.incn +++ b/tests/test_session_projection.incn @@ -1,8 +1,9 @@ """End-to-end Session projection execution tests over the DataFusion backend.""" -from std.testing import assert_eq -from functions import add, col, int_expr, mul -from session import Session +from functions import add, col, lit, mul +from dataset import LazyFrame +from session import Session, SessionErrorKind +from std.testing import assert_is_err, assert_is_ok from substrait.inspect import root_names @@ -20,16 +21,18 @@ def test_session_projection__plan_root_names_match_append_projection() -> None: mut session = Session.default() # -- Act -- - match session.read_csv[AggregateOrder]("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE): - Ok(lazy) => - projected = lazy.with_column("double_amount", mul(col("amount"), int_expr(2))) - # -- Assert -- - names = root_names(projected.to_substrait_plan()) - assert_eq(len(names), 3, "project append path should publish one root name per projected output column") - assert_eq(names[0], "customer_id", "project append path should preserve the first input name") - assert_eq(names[1], "amount", "project append path should preserve the second input name") - assert_eq(names[2], "double_amount", "project append path should publish the appended alias") - Err(err) => assert_eq(true, false, err.error_message()) + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + projected = lazy.with_column("double_amount", mul(col("amount"), lit(2))) + + # -- Assert -- + names = root_names(projected.to_substrait_plan()) + assert len(names) == 3, "project append path should publish one root name per projected output column" + assert names[0] == "customer_id", "project append path should preserve the first input name" + assert names[1] == "amount", "project append path should preserve the second input name" + assert names[2] == "double_amount", "project append path should publish the appended alias" def test_session_projection__plan_root_names_match_replace_projection() -> None: @@ -37,15 +40,17 @@ def test_session_projection__plan_root_names_match_replace_projection() -> None: mut session = Session.default() # -- Act -- - match session.read_csv[AggregateOrder]("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE): - Ok(lazy) => - projected = lazy.with_column("amount", add(col("amount"), int_expr(1))) - # -- Assert -- - names = root_names(projected.to_substrait_plan()) - assert_eq(len(names), 2, "project replace path should preserve one root name per output column") - assert_eq(names[0], "customer_id", "project replace path should preserve the first input name") - assert_eq(names[1], "amount", "project replace path should preserve the replaced alias") - Err(err) => assert_eq(true, false, err.error_message()) + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + projected = lazy.with_column("amount", add(col("amount"), lit(1))) + + # -- Assert -- + names = root_names(projected.to_substrait_plan()) + assert len(names) == 2, "project replace path should preserve one root name per output column" + assert names[0] == "customer_id", "project replace path should preserve the first input name" + assert names[1] == "amount", "project replace path should preserve the replaced alias" def test_session_projection__collect_executes_with_column_append() -> None: @@ -53,49 +58,29 @@ def test_session_projection__collect_executes_with_column_append() -> None: mut session = Session.default() # -- Act -- - match session.read_csv[AggregateOrder]("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE): - Ok(lazy) => - projected = lazy.with_column("double_amount", mul(col("amount"), int_expr(2))) - planned = projected.planned_columns() - plan_names = root_names(projected.to_substrait_plan()) - match session.collect(projected): - Ok(df) => - payload = df.preview_text() - assert_eq(df.row_count(), 3, "project append collect should preserve the three input rows") - resolved = df.resolved_columns() - # -- Assert -- - assert_eq(len(planned), 3, "with_column should append one projected output column before execution") - assert_eq( - len(plan_names), - 3, - "projected plan should publish one root name per projected output column", - ) - assert_eq( - plan_names[2], - "double_amount", - "projected plan root names should preserve the requested alias", - ) - assert_eq( - planned[2], - "double_amount", - "lazy planned columns should preserve the requested output alias", - ) - assert_eq(len(resolved), 3, "collected projection should expose one appended resolved column") - assert_eq( - resolved[2], - "double_amount", - "root output names should preserve the requested projected alias after collect", - ) - assert_eq( - payload.contains("double_amount"), - true, - "preview text should render the projected alias as a column header", - ) - assert_eq(payload.contains("20"), true, "materialized projection should include 10 * 2") - assert_eq(payload.contains("30"), true, "materialized projection should include 15 * 2") - assert_eq(payload.contains("14"), true, "materialized projection should include 7 * 2") - Err(err) => assert_eq(true, false, err.error_message()) - Err(err) => assert_eq(true, false, err.error_message()) + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + projected = lazy.with_column("double_amount", mul(col("amount"), lit(2))) + planned = projected.planned_columns() + plan_names = root_names(projected.to_substrait_plan()) + df = assert_is_ok(session.collect(projected), "project append collect should execute") + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 3, "project append collect should preserve the three input rows" + assert len(planned) == 3, "with_column should append one projected output column before execution" + assert len(plan_names) == 3, "projected plan should publish one root name per projected output column" + assert plan_names[2] == "double_amount", "projected plan root names should preserve the requested alias" + assert planned[2] == "double_amount", "lazy planned columns should preserve the requested output alias" + assert len(resolved) == 3, "collected projection should expose one appended resolved column" + assert resolved[2] == "double_amount", "root output names should preserve the requested projected alias after collect" + assert payload.contains("double_amount"), "preview text should render the projected alias as a column header" + assert payload.contains("20"), "materialized projection should include 10 * 2" + assert payload.contains("30"), "materialized projection should include 15 * 2" + assert payload.contains("14"), "materialized projection should include 7 * 2" def test_session_projection__collect_executes_identity_select() -> None: @@ -103,22 +88,35 @@ def test_session_projection__collect_executes_identity_select() -> None: mut session = Session.default() # -- Act -- - match session.read_csv[AggregateOrder]("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE): - Ok(lazy) => - selected = lazy.select() - match session.collect(selected): - Ok(df) => - resolved = df.resolved_columns() - # -- Assert -- - assert_eq( - len(resolved), - 2, - "identity select should preserve the original column count after collect", - ) - assert_eq(resolved[0], "customer_id", "identity select should preserve the first input name") - assert_eq(resolved[1], "amount", "identity select should preserve the second input name") - Err(err) => assert_eq(true, false, err.error_message()) - Err(err) => assert_eq(true, false, err.error_message()) + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + selected = lazy.select() + df = assert_is_ok(session.collect(selected), "identity select collect should execute") + resolved = df.resolved_columns() + + # -- Assert -- + assert len(resolved) == 2, "identity select should preserve the original column count after collect" + assert resolved[0] == "customer_id", "identity select should preserve the first input name" + assert resolved[1] == "amount", "identity select should preserve the second input name" + + +def test_session_projection__collect_unknown_projection_column_returns_planning_error() -> None: + # -- Arrange -- + mut session = Session.default() + + # -- Act -- + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + projected = lazy.with_column("bad_amount", add(col("missing_amount"), lit(1))) + err = assert_is_err(session.collect(projected), "expected unknown projection input column to fail during planning") + + # -- Assert -- + assert err.kind == SessionErrorKind.UnknownScalarColumn, err.error_message() + assert err.message.contains("missing_amount"), err.error_message() def test_session_projection__collect_executes_with_column_replace() -> None: @@ -126,41 +124,25 @@ def test_session_projection__collect_executes_with_column_replace() -> None: mut session = Session.default() # -- Act -- - match session.read_csv[AggregateOrder]("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE): - Ok(lazy) => - projected = lazy.with_column("amount", add(col("amount"), int_expr(1))) - planned = projected.planned_columns() - plan_names = root_names(projected.to_substrait_plan()) - match session.collect(projected): - Ok(df) => - payload = df.preview_text() - assert_eq(df.row_count(), 3, "project replace collect should preserve the three input rows") - resolved = df.resolved_columns() - # -- Assert -- - assert_eq(len(planned), 2, "replacing an existing name should preserve projected column count") - assert_eq( - len(plan_names), - 2, - "replacement projection should publish one root name per output column", - ) - assert_eq( - plan_names[1], - "amount", - "replacement projection root names should preserve the replaced alias", - ) - assert_eq(planned[1], "amount", "replacement semantics should preserve the original ordinal slot") - assert_eq( - len(resolved), - 2, - "replacement projection should preserve resolved column count after collect", - ) - assert_eq( - resolved[1], - "amount", - "root output names should preserve the replaced column alias after collect", - ) - assert_eq(payload.contains("11"), true, "replacement projection should include 10 + 1") - assert_eq(payload.contains("16"), true, "replacement projection should include 15 + 1") - assert_eq(payload.contains("8"), true, "replacement projection should include 7 + 1") - Err(err) => assert_eq(true, false, err.error_message()) - Err(err) => assert_eq(true, false, err.error_message()) + lazy: LazyFrame[AggregateOrder] = assert_is_ok( + session.read_csv("aggregate_orders", AGGREGATE_ORDERS_CSV_FIXTURE), + "aggregate orders fixture should load", + ) + projected = lazy.with_column("amount", add(col("amount"), lit(1))) + planned = projected.planned_columns() + plan_names = root_names(projected.to_substrait_plan()) + df = assert_is_ok(session.collect(projected), "project replace collect should execute") + payload = df.preview_text() + resolved = df.resolved_columns() + + # -- Assert -- + assert df.row_count() == 3, "project replace collect should preserve the three input rows" + assert len(planned) == 2, "replacing an existing name should preserve projected column count" + assert len(plan_names) == 2, "replacement projection should publish one root name per output column" + assert plan_names[1] == "amount", "replacement projection root names should preserve the replaced alias" + assert planned[1] == "amount", "replacement semantics should preserve the original ordinal slot" + assert len(resolved) == 2, "replacement projection should preserve resolved column count after collect" + assert resolved[1] == "amount", "root output names should preserve the replaced column alias after collect" + assert payload.contains("11"), "replacement projection should include 10 + 1" + assert payload.contains("16"), "replacement projection should include 15 + 1" + assert payload.contains("8"), "replacement projection should include 7 + 1" diff --git a/tests/test_substrait_plan.incn b/tests/test_substrait_plan.incn index 9ae0419..bc7aefa 100644 --- a/tests/test_substrait_plan.incn +++ b/tests/test_substrait_plan.incn @@ -1,8 +1,7 @@ """Tests for RFC 002 proto-backed Substrait emission and conformance alignment.""" -from std.testing import assert_eq from aggregate_builders import AggregateMeasure, AggregateKind -from functions import add, always_true, col, int_expr, mul +from functions import add, always_true, col, lit, mul from projection_builders import with_column_assignment from substrait.extensions import explode_extension_uri, function_extension_uri, registered_substrait_extension_uris from substrait.inspect import ( @@ -46,7 +45,7 @@ from substrait.relations import ( set_rel_of_kind, sort_rel, ) -from substrait.schema_registry import register_named_table_schema +from substrait.schema_registry import named_table_columns, register_named_table_schema from substrait.schema import RowColumnSpec, SubstraitPrimitiveKind from substrait.conformance import ( ConformanceCapabilityTags, @@ -116,6 +115,32 @@ def _register_orders_schema() -> None: register_named_table_schema("orders", [RowColumnSpec(name="id", kind=SubstraitPrimitiveKind.I64, nullable=false)]) +def _register_fixture_schema(table_name: str) -> None: + register_named_table_schema( + table_name, + [RowColumnSpec(name=_fixture_col_primary(), kind=SubstraitPrimitiveKind.I64, nullable=false)], + ) + + +def test_plan__named_table_schema_registry_prefers_latest_registration() -> None: + # -- Arrange -- + table_name = "__schema_replace_probe__" + register_named_table_schema( + table_name, + [RowColumnSpec(name="old_id", kind=SubstraitPrimitiveKind.I64, nullable=false)], + ) + register_named_table_schema( + table_name, + [RowColumnSpec(name="new_id", kind=SubstraitPrimitiveKind.I64, nullable=false)], + ) + + # -- Act -- + columns = named_table_columns(table_name) + + # -- Assert -- + assert columns == ["new_id"], "named-table schema lookup should use the latest registered shape" + + def test_plan__named_table_root_kind() -> None: # -- Arrange -- plan = plan_from_named_table("orders") @@ -126,9 +151,9 @@ def test_plan__named_table_root_kind() -> None: read_kind = read_kind_name(root) # -- Assert -- - assert_eq(plan_encoded_len(plan) > 0, true, "proto-backed plan should encode") - assert_eq(root_kind, "ReadRel", "named table plan should expose ReadRel root") - assert_eq(read_kind, "NamedTable", "named table plan should expose NamedTable root kind") + assert plan_encoded_len(plan) > 0, "proto-backed plan should encode" + assert root_kind == "ReadRel", "named table plan should expose ReadRel root" + assert read_kind == "NamedTable", "named table plan should expose NamedTable root kind" def test_plan__read_root_kinds_are_distinguished() -> None: @@ -143,33 +168,33 @@ def test_plan__read_root_kinds_are_distinguished() -> None: virtual_kind = read_kind_name(root_rel(virtual)) # -- Assert -- - assert_eq(named_kind, "NamedTable", "named table root should remain inspectable") - assert_eq(local_kind, "LocalFiles", "local files root should remain inspectable") - assert_eq(virtual_kind, "VirtualTable", "virtual table root should remain inspectable") + assert named_kind == "NamedTable", "named table root should remain inspectable" + assert local_kind == "LocalFiles", "local files root should remain inspectable" + assert virtual_kind == "VirtualTable", "virtual table root should remain inspectable" def test_plan__local_files_helpers_are_explicitly_parquet_backed() -> None: # -- Arrange -- uri = "fixture://orders.parquet" - rel_via_compat = read_local_files_rel(uri) + rel_via_default = read_local_files_rel(uri) rel_via_parquet = read_local_parquet_rel(uri) - plan_via_compat = plan_from_local_files(uri) + plan_via_default = plan_from_local_files(uri) plan_via_parquet = plan_from_local_parquet(uri) # -- Act -- - compat_kind = relation_kind_name(rel_via_compat) - compat_read_kind = read_kind_name(rel_via_compat) + default_kind = relation_kind_name(rel_via_default) + default_read_kind = read_kind_name(rel_via_default) parquet_kind = relation_kind_name(rel_via_parquet) parquet_read_kind = read_kind_name(rel_via_parquet) - compat_encoded = plan_encoded_len(plan_via_compat) + default_encoded = plan_encoded_len(plan_via_default) parquet_encoded = plan_encoded_len(plan_via_parquet) # -- Assert -- - assert_eq(compat_kind, "ReadRel", "compat local-files helper should still emit ReadRel") - assert_eq(compat_read_kind, "LocalFiles", "compat local-files helper should still emit LocalFiles read kind") - assert_eq(parquet_kind, "ReadRel", "explicit parquet local-files helper should emit ReadRel") - assert_eq(parquet_read_kind, "LocalFiles", "explicit parquet local-files helper should emit LocalFiles read kind") - assert_eq(compat_encoded, parquet_encoded, "compat and explicit parquet plan helpers should remain equivalent") + assert default_kind == "ReadRel", "default local-files helper should emit ReadRel" + assert default_read_kind == "LocalFiles", "default local-files helper should emit LocalFiles read kind" + assert parquet_kind == "ReadRel", "explicit parquet local-files helper should emit ReadRel" + assert parquet_read_kind == "LocalFiles", "explicit parquet local-files helper should emit LocalFiles read kind" + assert default_encoded == parquet_encoded, "default and explicit parquet plan helpers should remain equivalent" def test_plan__combinators_compose_deterministically() -> None: @@ -183,11 +208,11 @@ def test_plan__combinators_compose_deterministically() -> None: plan = plan_from_root_relation(with_fetch, ["id"]) # -- Assert -- - assert_eq(relation_kind_name(with_fetch), "FetchRel", "last combinator should become the root rel") - assert_eq(plan_encoded_len(plan) > 0, true, "composed plan should still encode") - assert_eq(plan_contains_relation_kind(plan, "FilterRel"), true, "recursive traversal should find nested filter") - assert_eq(plan_contains_relation_kind(plan, "SortRel"), true, "recursive traversal should find nested sort") - assert_eq(plan_contains_relation_kind(plan, "FetchRel"), true, "recursive traversal should find fetch root") + assert relation_kind_name(with_fetch) == "FetchRel", "last combinator should become the root rel" + assert plan_encoded_len(plan) > 0, "composed plan should still encode" + assert plan_contains_relation_kind(plan, "FilterRel"), "recursive traversal should find nested filter" + assert plan_contains_relation_kind(plan, "SortRel"), "recursive traversal should find nested sort" + assert plan_contains_relation_kind(plan, "FetchRel"), "recursive traversal should find fetch root" def test_plan__project_rel_supports_projection_expression_payloads() -> None: @@ -196,32 +221,24 @@ def test_plan__project_rel_supports_projection_expression_payloads() -> None: projected = project_rel_of_columns( base, ["id"], - [with_column_assignment("double_id", mul(col("id"), int_expr(2))), with_column_assignment( + [with_column_assignment("double_id", mul(col("id"), lit(2))), with_column_assignment( "id_plus_one", - add(col("id"), int_expr(1)), + add(col("id"), lit(1)), )], ) # -- Act -- plan = plan_from_root_relation(projected, ["id", "double_id", "id_plus_one"]) # -- Assert -- - assert_eq(relation_kind_name(projected), "ProjectRel", "projection builder lowering should emit ProjectRel roots") - assert_eq(plan_encoded_len(plan) > 0, true, "projection builder lowering should still encode into a valid plan") - assert_eq( - plan_contains_relation_kind(plan, "ProjectRel"), - true, - "encoded plan should retain the projected root relation kind", - ) - assert_eq( - plan_has_extension_urn(plan, registered_substrait_extension_uris()[0]), - true, - "scalar function plans should register the shared function extension URN", - ) + assert relation_kind_name(projected) == "ProjectRel", "projection builder lowering should emit ProjectRel roots" + assert plan_encoded_len(plan) > 0, "projection builder lowering should still encode into a valid plan" + assert plan_contains_relation_kind(plan, "ProjectRel"), "encoded plan should retain the projected root relation kind" + assert plan_has_extension_urn(plan, registered_substrait_extension_uris()[0]), "scalar function plans should register the shared function extension URN" root_output_names = root_names(plan) - assert_eq(len(root_output_names), 3, "project root should publish one output name per projected column") - assert_eq(root_output_names[0], "id", "project root should preserve passthrough column names") - assert_eq(root_output_names[1], "double_id", "project root should preserve derived aliases") - assert_eq(root_output_names[2], "id_plus_one", "project root should preserve all derived aliases") + assert len(root_output_names) == 3, "project root should publish one output name per projected column" + assert root_output_names[0] == "id", "project root should preserve passthrough column names" + assert root_output_names[1] == "double_id", "project root should preserve derived aliases" + assert root_output_names[2] == "id_plus_one", "project root should preserve all derived aliases" def test_plan__aggregate_rel_surfaces_group_and_measure_output_columns() -> None: @@ -230,10 +247,10 @@ def test_plan__aggregate_rel_surfaces_group_and_measure_output_columns() -> None base = read_named_table_rel("orders") aggregated = aggregate_rel( base, - ["id"], - [AggregateMeasure(kind=AggregateKind.Sum, column_name="id"), AggregateMeasure( + [col("id")], + [AggregateMeasure(kind=AggregateKind.Sum, expr=col("id")), AggregateMeasure( kind=AggregateKind.Count, - column_name="", + expr=col(""), )], ) plan = plan_from_root_relation(aggregated, ["id", "sum_id", "count"]) @@ -242,16 +259,32 @@ def test_plan__aggregate_rel_surfaces_group_and_measure_output_columns() -> None output_columns = relation_output_columns(aggregated) # -- Assert -- - assert_eq(relation_kind_name(aggregated), "AggregateRel", "aggregate lowering should emit AggregateRel") - assert_eq( - plan_has_extension_urn(plan, registered_substrait_extension_uris()[0]), - true, - "aggregate function plans should register the shared function extension URN", + assert relation_kind_name(aggregated) == "AggregateRel", "aggregate lowering should emit AggregateRel" + assert plan_has_extension_urn(plan, registered_substrait_extension_uris()[0]), "aggregate function plans should register the shared function extension URN" + assert len(output_columns) == 3, "aggregate output columns should include group key plus measure outputs" + assert output_columns[0] == "id", "group key should remain the first aggregate output column" + assert output_columns[1] == "sum_id", "sum measure output columns should use the stable prefixed name" + assert output_columns[2] == "count", "count measure output columns should use the stable count name" + + +def test_plan__aggregate_rel_accepts_scalar_group_and_measure_expressions() -> None: + # -- Arrange -- + _register_orders_schema() + base = read_named_table_rel("orders") + aggregated = aggregate_rel( + base, + [add(col("id"), lit(1))], + [AggregateMeasure(kind=AggregateKind.Sum, expr=add(col("id"), lit(2)))], ) - assert_eq(len(output_columns), 3, "aggregate output columns should include group key plus measure outputs") - assert_eq(output_columns[0], "id", "group key should remain the first aggregate output column") - assert_eq(output_columns[1], "sum_id", "sum measure output columns should use the stable prefixed name") - assert_eq(output_columns[2], "count", "count measure output columns should use the stable count name") + + # -- Act -- + output_columns = relation_output_columns(aggregated) + + # -- Assert -- + assert relation_kind_name(aggregated) == "AggregateRel", "scalar aggregate lowering should emit AggregateRel" + assert len(output_columns) == 2, "scalar aggregate output should include computed group key and measure" + assert output_columns[0] == "group_0", "computed grouping expressions should get stable fallback names" + assert output_columns[1] == "sum", "computed aggregate measures should use stable fallback names" def test_plan__set_rel_uses_operation_enum() -> None: @@ -264,9 +297,9 @@ def test_plan__set_rel_uses_operation_enum() -> None: union_plan = plan_from_root_relation(union_rel, ["id"]) # -- Assert -- - assert_eq(relation_kind_name(union_rel), "SetRel", "set combinator must produce SetRel root") - assert_eq(plan_encoded_len(union_plan) > 0, true, "set plan should still encode") - assert_eq(plan_contains_relation_kind(union_plan, "ReadRel"), true, "set traversal should still reach child reads") + assert relation_kind_name(union_rel) == "SetRel", "set combinator must produce SetRel root" + assert plan_encoded_len(union_plan) > 0, "set plan should still encode" + assert plan_contains_relation_kind(union_plan, "ReadRel"), "set traversal should still reach child reads" def test_plan__enum_backed_join_and_set_builders_expose_boundary_facts() -> None: @@ -288,8 +321,8 @@ def test_plan__enum_backed_join_and_set_builders_expose_boundary_facts() -> None set_kind = set_operation_name(union_distinct_rel) # -- Assert -- - assert_eq(join_kind, "JoinRel", "enum-backed join builder should still emit a JoinRel root") - assert_eq(set_kind, "UnionDistinct", "enum-backed set builder should preserve the selected set operation") + assert join_kind == "JoinRel", "enum-backed join builder should still emit a JoinRel root" + assert set_kind == "UnionDistinct", "enum-backed set builder should preserve the selected set operation" def test_plan__reference_rel_preserves_subtree_ordinal() -> None: @@ -301,8 +334,8 @@ def test_plan__reference_rel_preserves_subtree_ordinal() -> None: ordinal = reference_subtree_ordinal(rel) # -- Assert -- - assert_eq(kind, "ReferenceRel", "reference builder should still expose ReferenceRel root kind") - assert_eq(ordinal, 7, "reference builder should preserve the requested subtree ordinal") + assert kind == "ReferenceRel", "reference builder should still expose ReferenceRel root kind" + assert ordinal == 7, "reference builder should preserve the requested subtree ordinal" def test_plan__extension_urns_are_surfaced() -> None: @@ -314,8 +347,8 @@ def test_plan__extension_urns_are_surfaced() -> None: plan = plan_from_root_relation(rel, ["id"]) # -- Assert -- - assert_eq(plan_has_extension_urn(plan, extension_uri), true, "extension relation should populate extension URNs") - assert_eq(plan_contains_relation_kind(plan, "ExtensionSingleRel"), true, "extension root should remain inspectable") + assert plan_has_extension_urn(plan, extension_uri), "extension relation should populate extension URNs" + assert plan_contains_relation_kind(plan, "ExtensionSingleRel"), "extension root should remain inspectable" def test_plan__revision_pin_and_extension_registry_are_exported() -> None: @@ -327,22 +360,23 @@ def test_plan__revision_pin_and_extension_registry_are_exported() -> None: producer = substrait_producer_name() # -- Assert -- - assert_eq(tag, "v0.63.0", "revision helpers should expose the currently targeted Substrait release tag") - assert_eq(producer, "inql-rfc002", "revision helpers should expose the package producer label") - assert_eq(len(registered), 2, "current package boundary should register both extension URIs") - assert_eq(registered[0], function_extension_uri(), "registry should include the shared function extension URI first") - assert_eq(registered[1], explode_extension_uri(), "registry should include the emitted explode extension URI") + assert tag == "v0.63.0", "revision helpers should expose the currently targeted Substrait release tag" + assert producer == "inql-rfc002", "revision helpers should expose the package producer label" + assert len(registered) == 2, "current package boundary should register both extension URIs" + assert registered[0] == function_extension_uri(), "registry should include the shared function extension URI first" + assert registered[1] == explode_extension_uri(), "registry should include the emitted explode extension URI" def test_conformance__core_scenarios_validate_emission_output() -> None: # -- Arrange -- + _register_fixture_schema(_fixture_table_main()) scenarios = core_scenarios() # -- Act -- - assert_eq(len(scenarios), 12, "core scenario count should stay stable") + assert len(scenarios) == 12, "core scenario count should stay stable" for scenario in scenarios: - assert_eq(len(scenario.capability_tags.0) > 0, true, "capability tags must remain non-empty") - assert_eq(len(scenario.references.0) > 0, true, "references must remain non-empty") + assert len(scenario.capability_tags.0) > 0, "capability tags must remain non-empty" + assert len(scenario.references.0) > 0, "references must remain non-empty" # -- Assert -- for key in _core_keys(): @@ -384,8 +418,8 @@ def test_conformance__core_scenarios_validate_emission_output() -> None: plan = plan_from_root_relation( aggregate_rel( read_named_table_rel(_fixture_table_main()), - [_fixture_col_primary()], - [AggregateMeasure(kind=AggregateKind.Count, column_name="")], + [col(_fixture_col_primary())], + [AggregateMeasure(kind=AggregateKind.Count, expr=col(""))], ), [_fixture_col_primary()], ) @@ -411,21 +445,9 @@ def test_conformance__core_scenarios_validate_emission_output() -> None: CoreScenarioKey.ReferenceRelSharedSubplan => plan = plan_from_root_relation(reference_rel(7), [_fixture_col_primary()]) - assert_eq( - scenario_matches_root_shape(scenario, plan), - true, - "scenario should validate root shape against test-owned fixture plan", - ) - assert_eq( - scenario_has_required_rel_kinds(scenario, plan), - true, - "fixture plan should contain all scenario-required relation kinds", - ) - assert_eq( - scenario_matches_key_invariants(key, plan), - true, - "fixture plan should satisfy scenario-specific invariant checks", - ) + assert scenario_matches_root_shape(scenario, plan), "scenario should validate root shape against test-owned fixture plan" + assert scenario_has_required_rel_kinds(scenario, plan), "fixture plan should contain all scenario-required relation kinds" + assert scenario_matches_key_invariants(key, plan), "fixture plan should satisfy scenario-specific invariant checks" def test_conformance__scenario_shape_requires_all_declared_required_rels() -> None: @@ -449,13 +471,5 @@ def test_conformance__scenario_shape_requires_all_declared_required_rels() -> No named_only_plan = plan_from_named_table("orders") # -- Assert -- - assert_eq( - scenario_matches_root_shape(scenario, named_only_plan), - false, - "shape validation should fail when the root relation does not match the declared root contract", - ) - assert_eq( - scenario_has_required_rel_kinds(scenario, named_only_plan), - false, - "recursive relation-kind validation should fail when one required relation is missing", - ) + assert scenario_matches_root_shape(scenario, named_only_plan) is false, "shape validation should fail when the root relation does not match the declared root contract" + assert scenario_has_required_rel_kinds(scenario, named_only_plan) is false, "recursive relation-kind validation should fail when one required relation is missing"