diff --git a/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py b/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py
index 54720f9af..047e7f7b8 100644
--- a/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py
+++ b/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py
@@ -105,6 +105,11 @@ class XorqStatPipeline:
EXTERNAL_KEYS = frozenset(
{"orig_col_name", "rewritten_col_name", "dtype", "length", "min", "max"})
+ # Result column for the table-level length scalar in the Phase-1
+ # aggregate expression. Exposed so callers of compile_batch_expr can
+ # locate the row count by name without hard-coding the literal.
+ TOTAL_LENGTH_KEY = "__total_length__"
+
def __init__(self, stat_funcs: list, backend: Any = None, unit_test: bool = True):
if not HAS_XORQ:
raise ImportError(
@@ -165,30 +170,44 @@ def unit_test(self) -> Tuple[bool, List[StatError]]:
finally:
self.backend = saved_backend
- def process_table(self, table) -> Tuple[SDType, List[StatError]]:
+ def _build_batch_agg_exprs(self, table) -> Tuple[List[Any], List[Tuple[str, StatFunc, Any]], List[StatError]]:
+ """Build the Phase-1 batch aggregate expressions for ``table``.
+
+ Returns ``(agg_exprs, batch_items, errors)``:
+ - ``agg_exprs``: list of ibis scalar expressions, starting with
+ ``table.count().name(TOTAL_LENGTH_KEY)`` and followed by one
+ named expression per (col, batch-stat) pair that survived the
+ column filter and constructed without raising.
+ - ``batch_items``: ``[(col, StatFunc, named_expr), ...]`` —
+ per-result attribution. ``process_table`` uses this to write
+ results back to the per-column accumulator.
+ - ``errors``: construction failures (``sf.func(col=table[c])``
+ raised, or ``.name(...)`` raised). One entry per (col, stat).
+
+ Pure expression building — no execution and no I/O. Safe to call
+ with an ``UnboundTable``.
+ """
schema = table.schema()
columns = list(table.columns)
- # Pre-populate every column accumulator with the externally-provided
- # keys. ``length`` is filled in by the batch query below. ``min`` /
- # ``max`` start as None so dependents (histogram) don't cascade-
- # exclude on non-numeric columns; ``min`` / ``max`` overwrite for
- # numeric cols.
- accumulators: Dict[str, Dict[str, StatResult]] = {}
- for col in columns:
- accumulators[col] = {"orig_col_name": Ok(col), "rewritten_col_name": Ok(col), "dtype": Ok(str(schema[col])),
- "length": Ok(0), "min": Ok(None), "max": Ok(None)}
-
- # ---- Phase 1: batch aggregate ----------------------------------
- # ``length`` is a table-level scalar (same value for every column),
- # so it goes in once as ``__total_length__`` rather than as N
- # per-column expressions.
- TOTAL_LENGTH_KEY = "__total_length__"
+ agg_exprs: List[Any] = [table.count().name(self.TOTAL_LENGTH_KEY)]
batch_items: List[Tuple[str, StatFunc, Any]] = []
+ errors: List[StatError] = []
+
for sf in self.ordered_stat_funcs:
if not _is_batch_func(sf):
continue
+ # Batch funcs are expected to provide exactly one stat — the
+ # named aggregate column ``
|`` is keyed by
+ # ``sf.provides[0].name``. If a batch func ever provides more
+ # than one, only the first would round-trip through the
+ # aggregate, so flag it loudly rather than silently dropping
+ # the rest.
+ assert len(sf.provides) == 1, (
+ f"batch stat {sf.name!r} provides {len(sf.provides)} keys; "
+ f"batch funcs must provide exactly one")
xorq_col_param = next(r.name for r in sf.requires if r.type is XorqColumn)
+ stat_name = sf.provides[0].name
for col in columns:
col_dtype = schema[col]
if sf.column_filter is not None and not sf.column_filter(col_dtype):
@@ -196,24 +215,89 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]:
try:
expr = sf.func(**{xorq_col_param: table[col]})
except Exception as e:
- for sk in sf.provides:
- accumulators[col][sk.name] = Err(error=e, stat_func_name=sf.name, column_name=col,
- inputs={"col": col})
+ errors.append(StatError(
+ column=col, stat_key=stat_name, error=e, stat_func=sf,
+ inputs={"col": col}))
continue
if expr is None:
continue
- stat_name = sf.provides[0].name
try:
expr = expr.name(f"{col}|{stat_name}")
except Exception as e:
- for sk in sf.provides:
- accumulators[col][sk.name] = Err(error=e, stat_func_name=sf.name, column_name=col,
- inputs={"col": col})
+ errors.append(StatError(
+ column=col, stat_key=stat_name, error=e, stat_func=sf,
+ inputs={"col": col}))
continue
+ agg_exprs.append(expr)
batch_items.append((col, sf, expr))
- agg_exprs = [table.count().name(TOTAL_LENGTH_KEY)]
- agg_exprs.extend(e for _, _, e in batch_items)
+ return agg_exprs, batch_items, errors
+
+ def compile_batch_expr(self, table) -> Tuple[Any, List[StatError]]:
+ """Return the Phase-1 batch aggregate expression for ``table``.
+
+ ``table`` may be any ibis Table — including an UnboundTable from
+ ``xo.table(schema, name=...)``. The returned expression is already
+ wrapped in ``table.aggregate(...)`` and has shape
+ ``(1 row) × (1 + N_batch_results)`` with columns
+ ``TOTAL_LENGTH_KEY`` and ``|``.
+
+ Only Phase-1 batched stats are folded in: those with an
+ ``XorqColumn`` parameter and no non-raw dependencies. Per-column
+ histograms (Phase 2) and Python-computed stats (``non_null_count``,
+ ``nan_per``, ``distinct_per``, ``_type``, ``typing_stats``) are NOT
+ in the result — histograms need the resolved scalar min/max from
+ Phase 1, and computed stats are pure Python on resolved values.
+
+ Returns ``(expr, errors)`` where ``errors`` collects construction
+ failures (typically empty).
+
+ Rebind to a concrete source before executing. The replacement
+ source must be schema-compatible — same column names and dtypes
+ the unbound table was built with — or execute will fail when the
+ backend tries to resolve the per-column expressions::
+
+ unbound = xo.table(schema, name="t")
+ expr, _ = pipeline.compile_batch_expr(unbound)
+ bound = expr.op().replace({unbound.op(): real_source.op()}).to_expr()
+ df = bound.execute()
+ """
+ agg_exprs, _batch_items, errors = self._build_batch_agg_exprs(table)
+ return table.aggregate(agg_exprs), errors
+
+ def process_table(self, table) -> Tuple[SDType, List[StatError]]:
+ schema = table.schema()
+ columns = list(table.columns)
+
+ # Pre-populate every column accumulator with the externally-provided
+ # keys. ``length`` is filled in by the batch query below. ``min`` /
+ # ``max`` start as None so dependents (histogram) don't cascade-
+ # exclude on non-numeric columns; ``min`` / ``max`` overwrite for
+ # numeric cols.
+ accumulators: Dict[str, Dict[str, StatResult]] = {}
+ for col in columns:
+ accumulators[col] = {"orig_col_name": Ok(col), "rewritten_col_name": Ok(col), "dtype": Ok(str(schema[col])),
+ "length": Ok(0), "min": Ok(None), "max": Ok(None)}
+
+ # ---- Phase 1: batch aggregate ----------------------------------
+ # ``length`` is a table-level scalar (same value for every column),
+ # so it goes in once as ``__total_length__`` rather than as N
+ # per-column expressions.
+ agg_exprs, batch_items, construction_errors = self._build_batch_agg_exprs(table)
+ # Mirror construction-time failures into per-column accumulators so
+ # downstream resolve sees the same Err shape it always has. The
+ # errors aren't appended to ``all_errors`` directly here — they
+ # reach the caller via ``resolve_accumulator`` (below), which
+ # converts every Err entry it walks into a StatError. Writing them
+ # twice would double-report.
+ for se in construction_errors:
+ sf = se.stat_func
+ if sf is None:
+ continue
+ for sk in sf.provides:
+ accumulators[se.column][sk.name] = Err(
+ error=se.error, stat_func_name=sf.name,
+ column_name=se.column, inputs={"col": se.column})
try:
result_df = self._execute(table.aggregate(agg_exprs))
@@ -224,7 +308,7 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]:
for sk in sf.provides:
accumulators[col][sk.name] = Err(error=e, stat_func_name=sf.name, column_name=col, inputs={})
else:
- total_length = _to_python_scalar(result_df[TOTAL_LENGTH_KEY].iloc[0])
+ total_length = _to_python_scalar(result_df[self.TOTAL_LENGTH_KEY].iloc[0])
if total_length is None:
total_length = 0
for col in columns:
diff --git a/tests/unit/test_xorq_compile_batch_expr.py b/tests/unit/test_xorq_compile_batch_expr.py
new file mode 100644
index 000000000..5ed843148
--- /dev/null
+++ b/tests/unit/test_xorq_compile_batch_expr.py
@@ -0,0 +1,163 @@
+"""Tests for ``XorqStatPipeline.compile_batch_expr``.
+
+Exports the Phase-1 batch aggregate as an ibis expression so callers can:
+ - pass an ``xo.table(schema, name=...)`` UnboundTable and get a portable,
+ reusable stat expression (catalog it, ship it, rebind later);
+ - inspect the SQL/plan via ``ibis.to_sql``;
+ - run it against a backend manually without going through ``process_table``.
+
+Histograms (Phase 2) are intentionally not in the result — they need
+scalar min/max from Phase 1 and therefore can't be folded into one expr.
+"""
+
+import pandas as pd
+import pytest
+
+xo = pytest.importorskip("xorq.api")
+
+from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import ( # noqa: E402
+ XorqStatPipeline,
+ XorqColumn)
+from buckaroo.pluggable_analysis_framework.stat_func import stat # noqa: E402
+from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2 # noqa: E402
+
+
+SCHEMA = {"ints": "int64", "floats": "float64", "strs": "string", "bools": "boolean"}
+
+
+def _unbound():
+ return xo.table(SCHEMA, name="t")
+
+
+def _real():
+ return xo.memtable(pd.DataFrame(
+ {"ints": [1, 2, 3, 4, 5, 6, 7], "floats": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7],
+ "strs": ["a", "b", "c", "d", "e", "f", "g"], "bools": [True, False, True, False, True, False, True]}), name="t")
+
+
+def _rebind(expr, unbound, real):
+ """Substitute the UnboundTable op in ``expr`` with the real table's op."""
+ return expr.op().replace({unbound.op(): real.op()}).to_expr()
+
+
+class TestCompileBatchExpr:
+ def test_returns_unbound_when_given_unbound(self):
+ from xorq.vendor.ibis.expr.operations.relations import UnboundTable
+
+ pipeline = XorqStatPipeline(XORQ_STATS_V2)
+ unbound = _unbound()
+ expr, errors = pipeline.compile_batch_expr(unbound)
+ assert errors == []
+ # Walk the op tree directly — the expression must still reference
+ # the UnboundTable. Using ``op.find`` rather than repr-matching so
+ # the test doesn't depend on ibis' printable format.
+ found = expr.op().find(UnboundTable)
+ assert found, "compile_batch_expr should not bind the input table"
+ assert unbound.op() in found
+
+ def test_output_columns_have_expected_names(self):
+ pipeline = XorqStatPipeline(XORQ_STATS_V2)
+ expr, _ = pipeline.compile_batch_expr(_unbound())
+ names = set(expr.schema().names)
+ assert pipeline.TOTAL_LENGTH_KEY in names
+ # Every column gets null_count + distinct_count (no column_filter).
+ for col in SCHEMA:
+ assert f"{col}|null_count" in names
+ assert f"{col}|distinct_count" in names
+ # mean is numeric-not-bool → ints/floats only.
+ assert "ints|mean" in names
+ assert "floats|mean" in names
+ assert "strs|mean" not in names
+ assert "bools|mean" not in names
+ # min/max use _is_numeric_ibis → numeric per ibis dtype. In xorq's
+ # vendored ibis, boolean.is_numeric() is False, so bools are out
+ # alongside strs.
+ assert "ints|min" in names
+ assert "floats|max" in names
+ assert "strs|min" not in names
+ assert "bools|min" not in names
+
+ def test_no_histogram_in_batch_expr(self):
+ """Histogram is Phase 2 — must not appear in the compiled batch expr."""
+ pipeline = XorqStatPipeline(XORQ_STATS_V2)
+ expr, _ = pipeline.compile_batch_expr(_unbound())
+ names = set(expr.schema().names)
+ for col in SCHEMA:
+ assert f"{col}|histogram" not in names
+
+ def test_rebind_matches_process_table_batch_results(self):
+ """Rebinding the unbound expr to a real source and executing must yield
+ the same scalar values that process_table records in its accumulator
+ for the same batch-phase stats."""
+ pipeline = XorqStatPipeline(XORQ_STATS_V2)
+ unbound = _unbound()
+ real = _real()
+
+ expr, errors = pipeline.compile_batch_expr(unbound)
+ assert errors == []
+ rebound = _rebind(expr, unbound, real)
+ df = rebound.execute()
+
+ baseline, _ = pipeline.process_table(real)
+
+ # __total_length__ matches `length` on every column accumulator.
+ assert int(df[pipeline.TOTAL_LENGTH_KEY].iloc[0]) == baseline["ints"]["length"]
+
+ # Spot-check a few (col, stat) pairs against the accumulator.
+ for col, stat_name in [
+ ("ints", "null_count"),
+ ("ints", "min"),
+ ("ints", "max"),
+ ("ints", "mean"),
+ ("floats", "median"),
+ ("strs", "distinct_count"),
+ ("bools", "null_count"),
+ ]:
+ key = f"{col}|{stat_name}"
+ got = df[key].iloc[0]
+ want = baseline[col][stat_name]
+ # Coerce numpy scalars; allow float NaN-tolerance not needed here.
+ assert float(got) == pytest.approx(float(want)), (
+ f"mismatch on {key}: rebound={got} baseline={want}")
+
+ def test_accepts_real_table_too(self):
+ """compile_batch_expr should accept a bound table — just produces a
+ bound aggregate expression (no rebind needed)."""
+ pipeline = XorqStatPipeline(XORQ_STATS_V2)
+ expr, errors = pipeline.compile_batch_expr(_real())
+ assert errors == []
+ df = expr.execute()
+ assert len(df) == 1
+ assert int(df[pipeline.TOTAL_LENGTH_KEY].iloc[0]) == 7
+
+ def test_construction_error_surfaces_in_errors(self):
+ """A stat that raises while building its ibis expression must appear
+ in the returned errors list, not be silently dropped."""
+
+ @stat()
+ def broken_batch(col: XorqColumn) -> int:
+ raise RuntimeError("intentional construction failure")
+
+ pipeline = XorqStatPipeline([*XORQ_STATS_V2, broken_batch], unit_test=False)
+ expr, errors = pipeline.compile_batch_expr(_unbound())
+ # One error per column the stat would have been built for (no filter
+ # → every column).
+ assert len(errors) == len(SCHEMA)
+ for se in errors:
+ assert isinstance(se.error, RuntimeError)
+ assert se.stat_key == "broken_batch"
+ assert se.column in SCHEMA
+ # The expression should still compile — the broken stat is just absent.
+ names = set(expr.schema().names)
+ for col in SCHEMA:
+ assert f"{col}|broken_batch" not in names
+
+ def test_process_table_still_works_after_refactor(self):
+ """compile_batch_expr is extracted from process_table's Phase 1.
+ process_table itself must continue to produce correct results."""
+ pipeline = XorqStatPipeline(XORQ_STATS_V2)
+ stats, errors = pipeline.process_table(_real())
+ assert errors == []
+ assert stats["ints"]["length"] == 7
+ assert stats["ints"]["mean"] == pytest.approx(4.0)
+ assert stats["strs"]["distinct_count"] == 7