Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 110 additions & 26 deletions buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ class XorqStatPipeline:
EXTERNAL_KEYS = frozenset(
{"orig_col_name", "rewritten_col_name", "dtype", "length", "min", "max"})

# Result column for the table-level length scalar in the Phase-1
# aggregate expression. Exposed so callers of compile_batch_expr can
# locate the row count by name without hard-coding the literal.
TOTAL_LENGTH_KEY = "__total_length__"

def __init__(self, stat_funcs: list, backend: Any = None, unit_test: bool = True):
if not HAS_XORQ:
raise ImportError(
Expand Down Expand Up @@ -165,55 +170,134 @@ def unit_test(self) -> Tuple[bool, List[StatError]]:
finally:
self.backend = saved_backend

def process_table(self, table) -> Tuple[SDType, List[StatError]]:
def _build_batch_agg_exprs(self, table) -> Tuple[List[Any], List[Tuple[str, StatFunc, Any]], List[StatError]]:
    """Build the Phase-1 batch aggregate expressions for ``table``.

    Walks ``self.ordered_stat_funcs``, keeps the batch funcs, and for each
    (column, batch stat) pair that passes the stat's ``column_filter``
    builds one named scalar expression ``<col>|<stat_name>``.

    Args:
        table: an ibis Table. Only ``schema()``, ``columns``, ``count()``
            and ``table[col]`` are used here, so an ``UnboundTable`` works.

    Returns:
        ``(agg_exprs, batch_items, errors)``:

        - ``agg_exprs``: list of scalar expressions, starting with
          ``table.count().name(TOTAL_LENGTH_KEY)`` and followed by one
          named expression per (col, batch-stat) pair that survived the
          column filter and was constructed without raising.
        - ``batch_items``: ``[(col, StatFunc, named_expr), ...]`` —
          per-result attribution, in the same order as the tail of
          ``agg_exprs``.
        - ``errors``: construction failures (``sf.func(...)`` raised, or
          ``.name(...)`` raised). One ``StatError`` per (col, stat).

    Raises:
        AssertionError: a batch func declares anything other than exactly
            one provided key.
        StopIteration: a batch func has no ``XorqColumn`` requirement
            (presumably ruled out by ``_is_batch_func`` — verify there).

    This method only constructs expressions; it never executes them.
    """
    schema = table.schema()
    columns = list(table.columns)

    # ``length`` is a table-level scalar (same value for every column),
    # so it goes in once under ``TOTAL_LENGTH_KEY`` rather than as N
    # per-column expressions.
    agg_exprs: List[Any] = [table.count().name(self.TOTAL_LENGTH_KEY)]
    batch_items: List[Tuple[str, StatFunc, Any]] = []
    errors: List[StatError] = []

    for sf in self.ordered_stat_funcs:
        if not _is_batch_func(sf):
            continue
        # Batch funcs are expected to provide exactly one stat — the
        # named aggregate column ``<col>|<stat_name>`` is keyed by
        # ``sf.provides[0].name``. If a batch func ever provides more
        # than one, only the first would round-trip through the
        # aggregate, so flag it loudly rather than silently dropping
        # the rest.
        assert len(sf.provides) == 1, (
            f"batch stat {sf.name!r} provides {len(sf.provides)} keys; "
            f"batch funcs must provide exactly one")
        xorq_col_param = next(r.name for r in sf.requires if r.type is XorqColumn)
        stat_name = sf.provides[0].name
        for col in columns:
            if sf.column_filter is not None and not sf.column_filter(schema[col]):
                continue
            try:
                expr = sf.func(**{xorq_col_param: table[col]})
                # A stat func may return None to opt out for this column;
                # only name the expression when there is one.
                if expr is not None:
                    expr = expr.name(f"{col}|{stat_name}")
            except Exception as e:
                # Construction failures are reported, not raised, so one
                # broken stat cannot take down the whole batch.
                errors.append(StatError(
                    column=col, stat_key=stat_name, error=e, stat_func=sf,
                    inputs={"col": col}))
                continue
            if expr is None:
                continue
            agg_exprs.append(expr)
            batch_items.append((col, sf, expr))

    return agg_exprs, batch_items, errors

def compile_batch_expr(self, table) -> Tuple[Any, List[StatError]]:
    """Compile the Phase-1 batch stats for ``table`` into one aggregate.

    ``table`` can be any ibis Table, an UnboundTable created with
    ``xo.table(schema, name=...)`` included. The result is the
    ``table.aggregate(...)`` expression itself: one row, with a
    ``TOTAL_LENGTH_KEY`` column plus one ``<col>|<stat_name>`` column per
    batch result, i.e. ``(1 row) × (1 + N_batch_results)``.

    Scope: only Phase-1 batched stats are included — stats that take an
    ``XorqColumn`` parameter and have no non-raw dependencies. Left out
    on purpose:

    - per-column histograms (Phase 2), which need the resolved scalar
      min/max that Phase 1 produces;
    - Python-computed stats (``non_null_count``, ``nan_per``,
      ``distinct_per``, ``_type``, ``typing_stats``), which are plain
      Python over already-resolved values.

    Returns:
        ``(expr, errors)``; ``errors`` holds expression-construction
        failures and is typically empty.

    An expression compiled from an unbound table must be rebound to a
    concrete source before it is executed. That source has to be
    schema-compatible (same column names and dtypes as the unbound
    table), otherwise execution fails when the backend resolves the
    per-column expressions::

        unbound = xo.table(schema, name="t")
        expr, _ = pipeline.compile_batch_expr(unbound)
        bound = expr.op().replace({unbound.op(): real_source.op()}).to_expr()
        df = bound.execute()
    """
    built = self._build_batch_agg_exprs(table)
    # The middle element (per-result attribution) is only needed by
    # callers that write results back per column; it is unused here.
    scalar_exprs, construction_errors = built[0], built[2]
    batch_expr = table.aggregate(scalar_exprs)
    return batch_expr, construction_errors

def process_table(self, table) -> Tuple[SDType, List[StatError]]:
schema = table.schema()
columns = list(table.columns)

# Pre-populate every column accumulator with the externally-provided
# keys. ``length`` is filled in by the batch query below. ``min`` /
# ``max`` start as None so dependents (histogram) don't cascade-
# exclude on non-numeric columns; ``min`` / ``max`` overwrite for
# numeric cols.
accumulators: Dict[str, Dict[str, StatResult]] = {}
for col in columns:
accumulators[col] = {"orig_col_name": Ok(col), "rewritten_col_name": Ok(col), "dtype": Ok(str(schema[col])),
"length": Ok(0), "min": Ok(None), "max": Ok(None)}

# ---- Phase 1: batch aggregate ----------------------------------
# ``length`` is a table-level scalar (same value for every column),
# so it goes in once as ``__total_length__`` rather than as N
# per-column expressions.
agg_exprs, batch_items, construction_errors = self._build_batch_agg_exprs(table)
# Mirror construction-time failures into per-column accumulators so
# downstream resolve sees the same Err shape it always has. The
# errors aren't appended to ``all_errors`` directly here — they
# reach the caller via ``resolve_accumulator`` (below), which
# converts every Err entry it walks into a StatError. Writing them
# twice would double-report.
for se in construction_errors:
sf = se.stat_func
if sf is None:
continue
for sk in sf.provides:
accumulators[se.column][sk.name] = Err(
error=se.error, stat_func_name=sf.name,
column_name=se.column, inputs={"col": se.column})

try:
result_df = self._execute(table.aggregate(agg_exprs))
Expand All @@ -224,7 +308,7 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]:
for sk in sf.provides:
accumulators[col][sk.name] = Err(error=e, stat_func_name=sf.name, column_name=col, inputs={})
else:
total_length = _to_python_scalar(result_df[TOTAL_LENGTH_KEY].iloc[0])
total_length = _to_python_scalar(result_df[self.TOTAL_LENGTH_KEY].iloc[0])
if total_length is None:
total_length = 0
for col in columns:
Expand Down
163 changes: 163 additions & 0 deletions tests/unit/test_xorq_compile_batch_expr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""Tests for ``XorqStatPipeline.compile_batch_expr``.

Exports the Phase-1 batch aggregate as an ibis expression so callers can:
- pass an ``xo.table(schema, name=...)`` UnboundTable and get a portable,
reusable stat expression (catalog it, ship it, rebind later);
- inspect the SQL/plan via ``ibis.to_sql``;
- run it against a backend manually without going through ``process_table``.

Histograms (Phase 2) are intentionally not in the result — they need
scalar min/max from Phase 1 and therefore can't be folded into one expr.
"""

import pandas as pd
import pytest

xo = pytest.importorskip("xorq.api")

from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import ( # noqa: E402
XorqStatPipeline,
XorqColumn)
from buckaroo.pluggable_analysis_framework.stat_func import stat # noqa: E402
from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2 # noqa: E402


# Column name -> dtype string for the fixture table. ``_unbound`` declares an
# unbound table from this mapping, and ``_real`` builds a DataFrame with the
# same four column names.
SCHEMA = {"ints": "int64", "floats": "float64", "strs": "string", "bools": "boolean"}


def _unbound():
    """Return a table expression named ``t`` declared from ``SCHEMA``."""
    unbound_table = xo.table(SCHEMA, name="t")
    return unbound_table


def _real():
    """Return a seven-row memtable named ``t`` with the ``SCHEMA`` columns."""
    column_data = {
        "ints": [1, 2, 3, 4, 5, 6, 7],
        "floats": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7],
        "strs": ["a", "b", "c", "d", "e", "f", "g"],
        "bools": [True, False, True, False, True, False, True],
    }
    frame = pd.DataFrame(column_data)
    return xo.memtable(frame, name="t")


def _rebind(expr, unbound, real):
    """Return ``expr`` with ``unbound``'s op replaced by ``real``'s op.

    The swap is done on the operation tree via ``replace`` and the result
    is converted back to an expression with ``to_expr``.
    """
    root_op = expr.op()
    substitutions = {unbound.op(): real.op()}
    rebound_op = root_op.replace(substitutions)
    return rebound_op.to_expr()


class TestCompileBatchExpr:
    """Checks on the expression returned by ``compile_batch_expr``."""

    def test_returns_unbound_when_given_unbound(self):
        from xorq.vendor.ibis.expr.operations.relations import UnboundTable

        pipe = XorqStatPipeline(XORQ_STATS_V2)
        source = _unbound()
        compiled, build_errors = pipe.compile_batch_expr(source)
        assert build_errors == []
        # Inspect the op tree itself (``op.find``) instead of matching on
        # the repr, so the check is independent of ibis' printed format.
        # The compiled expression must still point at the UnboundTable.
        unbound_ops = compiled.op().find(UnboundTable)
        assert unbound_ops, "compile_batch_expr should not bind the input table"
        assert source.op() in unbound_ops

    def test_output_columns_have_expected_names(self):
        pipe = XorqStatPipeline(XORQ_STATS_V2)
        compiled, _ = pipe.compile_batch_expr(_unbound())
        names = {*compiled.schema().names}
        assert pipe.TOTAL_LENGTH_KEY in names
        # null_count and distinct_count carry no column_filter, so every
        # column gets both.
        for col in SCHEMA:
            for unfiltered_stat in ("null_count", "distinct_count"):
                assert f"{col}|{unfiltered_stat}" in names
        # mean is numeric-not-bool → only ints/floats.
        for expected in ("ints|mean", "floats|mean"):
            assert expected in names
        for unexpected in ("strs|mean", "bools|mean"):
            assert unexpected not in names
        # min/max go through _is_numeric_ibis → numeric per ibis dtype.
        # xorq's vendored ibis reports boolean.is_numeric() as False, so
        # bools are excluded together with strs.
        for expected in ("ints|min", "floats|max"):
            assert expected in names
        for unexpected in ("strs|min", "bools|min"):
            assert unexpected not in names

    def test_no_histogram_in_batch_expr(self):
        """Histograms belong to Phase 2 and must be absent from the batch expr."""
        pipe = XorqStatPipeline(XORQ_STATS_V2)
        compiled, _ = pipe.compile_batch_expr(_unbound())
        names = {*compiled.schema().names}
        histogram_names = {f"{col}|histogram" for col in SCHEMA}
        assert names.isdisjoint(histogram_names)

    def test_rebind_matches_process_table_batch_results(self):
        """An unbound expr rebound to a real source and executed must give the
        same scalars that process_table stores in its accumulator for the
        same batch-phase stats."""
        pipe = XorqStatPipeline(XORQ_STATS_V2)
        source = _unbound()
        concrete = _real()

        compiled, build_errors = pipe.compile_batch_expr(source)
        assert build_errors == []
        result_df = _rebind(compiled, source, concrete).execute()

        baseline, _ = pipe.process_table(concrete)

        # The table-level row count equals ``length`` on a column accumulator.
        total_rows = int(result_df[pipe.TOTAL_LENGTH_KEY].iloc[0])
        assert total_rows == baseline["ints"]["length"]

        # Spot-check a handful of (col, stat) pairs against the accumulator.
        spot_checks = (
            ("ints", "null_count"),
            ("ints", "min"),
            ("ints", "max"),
            ("ints", "mean"),
            ("floats", "median"),
            ("strs", "distinct_count"),
            ("bools", "null_count"),
        )
        for col, stat_name in spot_checks:
            result_key = f"{col}|{stat_name}"
            actual = result_df[result_key].iloc[0]
            expected = baseline[col][stat_name]
            # float() coerces numpy scalars; none of these values is NaN,
            # so plain approx comparison is enough.
            assert float(actual) == pytest.approx(float(expected)), (
                f"mismatch on {result_key}: rebound={actual} baseline={expected}")

    def test_accepts_real_table_too(self):
        """A bound table is accepted as well and yields a bound aggregate
        expression that executes without any rebind."""
        pipe = XorqStatPipeline(XORQ_STATS_V2)
        compiled, build_errors = pipe.compile_batch_expr(_real())
        assert build_errors == []
        result_df = compiled.execute()
        assert len(result_df) == 1
        assert int(result_df[pipe.TOTAL_LENGTH_KEY].iloc[0]) == 7

    def test_construction_error_surfaces_in_errors(self):
        """A stat that raises while its ibis expression is being built must
        show up in the returned errors list instead of vanishing."""

        @stat()
        def broken_batch(col: XorqColumn) -> int:
            raise RuntimeError("intentional construction failure")

        stat_funcs = [*XORQ_STATS_V2, broken_batch]
        pipe = XorqStatPipeline(stat_funcs, unit_test=False)
        compiled, build_errors = pipe.compile_batch_expr(_unbound())
        # The stat has no filter, so it is attempted — and fails — once
        # per column.
        assert len(build_errors) == len(SCHEMA)
        for stat_error in build_errors:
            assert isinstance(stat_error.error, RuntimeError)
            assert stat_error.stat_key == "broken_batch"
            assert stat_error.column in SCHEMA
        # The expression still compiles; the broken stat is simply missing.
        names = {*compiled.schema().names}
        for col in SCHEMA:
            assert f"{col}|broken_batch" not in names

    def test_process_table_still_works_after_refactor(self):
        """compile_batch_expr was carved out of process_table's Phase 1, so
        process_table itself must keep returning correct results."""
        pipe = XorqStatPipeline(XORQ_STATS_V2)
        summary, run_errors = pipe.process_table(_real())
        assert run_errors == []
        ints_stats = summary["ints"]
        assert ints_stats["length"] == 7
        assert ints_stats["mean"] == pytest.approx(4.0)
        assert summary["strs"]["distinct_count"] == 7
Loading