diff --git a/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py b/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py index 54720f9af..047e7f7b8 100644 --- a/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py +++ b/buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py @@ -105,6 +105,11 @@ class XorqStatPipeline: EXTERNAL_KEYS = frozenset( {"orig_col_name", "rewritten_col_name", "dtype", "length", "min", "max"}) + # Result column for the table-level length scalar in the Phase-1 + # aggregate expression. Exposed so callers of compile_batch_expr can + # locate the row count by name without hard-coding the literal. + TOTAL_LENGTH_KEY = "__total_length__" + def __init__(self, stat_funcs: list, backend: Any = None, unit_test: bool = True): if not HAS_XORQ: raise ImportError( @@ -165,30 +170,44 @@ def unit_test(self) -> Tuple[bool, List[StatError]]: finally: self.backend = saved_backend - def process_table(self, table) -> Tuple[SDType, List[StatError]]: + def _build_batch_agg_exprs(self, table) -> Tuple[List[Any], List[Tuple[str, StatFunc, Any]], List[StatError]]: + """Build the Phase-1 batch aggregate expressions for ``table``. + + Returns ``(agg_exprs, batch_items, errors)``: + - ``agg_exprs``: list of ibis scalar expressions, starting with + ``table.count().name(TOTAL_LENGTH_KEY)`` and followed by one + named expression per (col, batch-stat) pair that survived the + column filter and constructed without raising. + - ``batch_items``: ``[(col, StatFunc, named_expr), ...]`` — + per-result attribution. ``process_table`` uses this to write + results back to the per-column accumulator. + - ``errors``: construction failures (``sf.func(col=table[c])`` + raised, or ``.name(...)`` raised). One entry per (col, stat). + + Pure expression building — no execution and no I/O. Safe to call + with an ``UnboundTable``. + """ schema = table.schema() columns = list(table.columns) - # Pre-populate every column accumulator with the externally-provided - # keys. ``length`` is filled in by the batch query below. ``min`` / - # ``max`` start as None so dependents (histogram) don't cascade- - # exclude on non-numeric columns; ``min`` / ``max`` overwrite for - # numeric cols. - accumulators: Dict[str, Dict[str, StatResult]] = {} - for col in columns: - accumulators[col] = {"orig_col_name": Ok(col), "rewritten_col_name": Ok(col), "dtype": Ok(str(schema[col])), - "length": Ok(0), "min": Ok(None), "max": Ok(None)} - - # ---- Phase 1: batch aggregate ---------------------------------- - # ``length`` is a table-level scalar (same value for every column), - # so it goes in once as ``__total_length__`` rather than as N - # per-column expressions. - TOTAL_LENGTH_KEY = "__total_length__" + agg_exprs: List[Any] = [table.count().name(self.TOTAL_LENGTH_KEY)] batch_items: List[Tuple[str, StatFunc, Any]] = [] + errors: List[StatError] = [] + for sf in self.ordered_stat_funcs: if not _is_batch_func(sf): continue + # Batch funcs are expected to provide exactly one stat — the + # named aggregate column ``|`` is keyed by + # ``sf.provides[0].name``. If a batch func ever provides more + # than one, only the first would round-trip through the + # aggregate, so flag it loudly rather than silently dropping + # the rest. + assert len(sf.provides) == 1, ( + f"batch stat {sf.name!r} provides {len(sf.provides)} keys; " + f"batch funcs must provide exactly one") xorq_col_param = next(r.name for r in sf.requires if r.type is XorqColumn) + stat_name = sf.provides[0].name for col in columns: col_dtype = schema[col] if sf.column_filter is not None and not sf.column_filter(col_dtype): @@ -196,24 +215,89 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]: try: expr = sf.func(**{xorq_col_param: table[col]}) except Exception as e: - for sk in sf.provides: - accumulators[col][sk.name] = Err(error=e, stat_func_name=sf.name, column_name=col, - inputs={"col": col}) + errors.append(StatError( + column=col, stat_key=stat_name, error=e, stat_func=sf, + inputs={"col": col})) continue if expr is None: continue - stat_name = sf.provides[0].name try: expr = expr.name(f"{col}|{stat_name}") except Exception as e: - for sk in sf.provides: - accumulators[col][sk.name] = Err(error=e, stat_func_name=sf.name, column_name=col, - inputs={"col": col}) + errors.append(StatError( + column=col, stat_key=stat_name, error=e, stat_func=sf, + inputs={"col": col})) continue + agg_exprs.append(expr) batch_items.append((col, sf, expr)) - agg_exprs = [table.count().name(TOTAL_LENGTH_KEY)] - agg_exprs.extend(e for _, _, e in batch_items) + return agg_exprs, batch_items, errors + + def compile_batch_expr(self, table) -> Tuple[Any, List[StatError]]: + """Return the Phase-1 batch aggregate expression for ``table``. + + ``table`` may be any ibis Table — including an UnboundTable from + ``xo.table(schema, name=...)``. The returned expression is already + wrapped in ``table.aggregate(...)`` and has shape + ``(1 row) × (1 + N_batch_results)`` with columns + ``TOTAL_LENGTH_KEY`` and ``|``. + + Only Phase-1 batched stats are folded in: those with an + ``XorqColumn`` parameter and no non-raw dependencies. Per-column + histograms (Phase 2) and Python-computed stats (``non_null_count``, + ``nan_per``, ``distinct_per``, ``_type``, ``typing_stats``) are NOT + in the result — histograms need the resolved scalar min/max from + Phase 1, and computed stats are pure Python on resolved values. + + Returns ``(expr, errors)`` where ``errors`` collects construction + failures (typically empty). + + Rebind to a concrete source before executing. The replacement + source must be schema-compatible — same column names and dtypes + the unbound table was built with — or execute will fail when the + backend tries to resolve the per-column expressions:: + + unbound = xo.table(schema, name="t") + expr, _ = pipeline.compile_batch_expr(unbound) + bound = expr.op().replace({unbound.op(): real_source.op()}).to_expr() + df = bound.execute() + """ + agg_exprs, _batch_items, errors = self._build_batch_agg_exprs(table) + return table.aggregate(agg_exprs), errors + + def process_table(self, table) -> Tuple[SDType, List[StatError]]: + schema = table.schema() + columns = list(table.columns) + + # Pre-populate every column accumulator with the externally-provided + # keys. ``length`` is filled in by the batch query below. ``min`` / + # ``max`` start as None so dependents (histogram) don't cascade- + # exclude on non-numeric columns; ``min`` / ``max`` overwrite for + # numeric cols. + accumulators: Dict[str, Dict[str, StatResult]] = {} + for col in columns: + accumulators[col] = {"orig_col_name": Ok(col), "rewritten_col_name": Ok(col), "dtype": Ok(str(schema[col])), + "length": Ok(0), "min": Ok(None), "max": Ok(None)} + + # ---- Phase 1: batch aggregate ---------------------------------- + # ``length`` is a table-level scalar (same value for every column), + # so it goes in once as ``__total_length__`` rather than as N + # per-column expressions. + agg_exprs, batch_items, construction_errors = self._build_batch_agg_exprs(table) + # Mirror construction-time failures into per-column accumulators so + # downstream resolve sees the same Err shape it always has. The + # errors aren't appended to ``all_errors`` directly here — they + # reach the caller via ``resolve_accumulator`` (below), which + # converts every Err entry it walks into a StatError. Writing them + # twice would double-report. + for se in construction_errors: + sf = se.stat_func + if sf is None: + continue + for sk in sf.provides: + accumulators[se.column][sk.name] = Err( + error=se.error, stat_func_name=sf.name, + column_name=se.column, inputs={"col": se.column}) try: result_df = self._execute(table.aggregate(agg_exprs)) @@ -224,7 +308,7 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]: for sk in sf.provides: accumulators[col][sk.name] = Err(error=e, stat_func_name=sf.name, column_name=col, inputs={}) else: - total_length = _to_python_scalar(result_df[TOTAL_LENGTH_KEY].iloc[0]) + total_length = _to_python_scalar(result_df[self.TOTAL_LENGTH_KEY].iloc[0]) if total_length is None: total_length = 0 for col in columns: diff --git a/tests/unit/test_xorq_compile_batch_expr.py b/tests/unit/test_xorq_compile_batch_expr.py new file mode 100644 index 000000000..5ed843148 --- /dev/null +++ b/tests/unit/test_xorq_compile_batch_expr.py @@ -0,0 +1,163 @@ +"""Tests for ``XorqStatPipeline.compile_batch_expr``. + +Exports the Phase-1 batch aggregate as an ibis expression so callers can: + - pass an ``xo.table(schema, name=...)`` UnboundTable and get a portable, + reusable stat expression (catalog it, ship it, rebind later); + - inspect the SQL/plan via ``ibis.to_sql``; + - run it against a backend manually without going through ``process_table``. + +Histograms (Phase 2) are intentionally not in the result — they need +scalar min/max from Phase 1 and therefore can't be folded into one expr. +""" + +import pandas as pd +import pytest + +xo = pytest.importorskip("xorq.api") + +from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import ( # noqa: E402 + XorqStatPipeline, + XorqColumn) +from buckaroo.pluggable_analysis_framework.stat_func import stat # noqa: E402 +from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2 # noqa: E402 + + +SCHEMA = {"ints": "int64", "floats": "float64", "strs": "string", "bools": "boolean"} + + +def _unbound(): + return xo.table(SCHEMA, name="t") + + +def _real(): + return xo.memtable(pd.DataFrame( + {"ints": [1, 2, 3, 4, 5, 6, 7], "floats": [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7], + "strs": ["a", "b", "c", "d", "e", "f", "g"], "bools": [True, False, True, False, True, False, True]}), name="t") + + +def _rebind(expr, unbound, real): + """Substitute the UnboundTable op in ``expr`` with the real table's op.""" + return expr.op().replace({unbound.op(): real.op()}).to_expr() + + +class TestCompileBatchExpr: + def test_returns_unbound_when_given_unbound(self): + from xorq.vendor.ibis.expr.operations.relations import UnboundTable + + pipeline = XorqStatPipeline(XORQ_STATS_V2) + unbound = _unbound() + expr, errors = pipeline.compile_batch_expr(unbound) + assert errors == [] + # Walk the op tree directly — the expression must still reference + # the UnboundTable. Using ``op.find`` rather than repr-matching so + # the test doesn't depend on ibis' printable format. + found = expr.op().find(UnboundTable) + assert found, "compile_batch_expr should not bind the input table" + assert unbound.op() in found + + def test_output_columns_have_expected_names(self): + pipeline = XorqStatPipeline(XORQ_STATS_V2) + expr, _ = pipeline.compile_batch_expr(_unbound()) + names = set(expr.schema().names) + assert pipeline.TOTAL_LENGTH_KEY in names + # Every column gets null_count + distinct_count (no column_filter). + for col in SCHEMA: + assert f"{col}|null_count" in names + assert f"{col}|distinct_count" in names + # mean is numeric-not-bool → ints/floats only. + assert "ints|mean" in names + assert "floats|mean" in names + assert "strs|mean" not in names + assert "bools|mean" not in names + # min/max use _is_numeric_ibis → numeric per ibis dtype. In xorq's + # vendored ibis, boolean.is_numeric() is False, so bools are out + # alongside strs. + assert "ints|min" in names + assert "floats|max" in names + assert "strs|min" not in names + assert "bools|min" not in names + + def test_no_histogram_in_batch_expr(self): + """Histogram is Phase 2 — must not appear in the compiled batch expr.""" + pipeline = XorqStatPipeline(XORQ_STATS_V2) + expr, _ = pipeline.compile_batch_expr(_unbound()) + names = set(expr.schema().names) + for col in SCHEMA: + assert f"{col}|histogram" not in names + + def test_rebind_matches_process_table_batch_results(self): + """Rebinding the unbound expr to a real source and executing must yield + the same scalar values that process_table records in its accumulator + for the same batch-phase stats.""" + pipeline = XorqStatPipeline(XORQ_STATS_V2) + unbound = _unbound() + real = _real() + + expr, errors = pipeline.compile_batch_expr(unbound) + assert errors == [] + rebound = _rebind(expr, unbound, real) + df = rebound.execute() + + baseline, _ = pipeline.process_table(real) + + # __total_length__ matches `length` on every column accumulator. + assert int(df[pipeline.TOTAL_LENGTH_KEY].iloc[0]) == baseline["ints"]["length"] + + # Spot-check a few (col, stat) pairs against the accumulator. + for col, stat_name in [ + ("ints", "null_count"), + ("ints", "min"), + ("ints", "max"), + ("ints", "mean"), + ("floats", "median"), + ("strs", "distinct_count"), + ("bools", "null_count"), + ]: + key = f"{col}|{stat_name}" + got = df[key].iloc[0] + want = baseline[col][stat_name] + # Coerce numpy scalars; allow float NaN-tolerance not needed here. + assert float(got) == pytest.approx(float(want)), ( + f"mismatch on {key}: rebound={got} baseline={want}") + + def test_accepts_real_table_too(self): + """compile_batch_expr should accept a bound table — just produces a + bound aggregate expression (no rebind needed).""" + pipeline = XorqStatPipeline(XORQ_STATS_V2) + expr, errors = pipeline.compile_batch_expr(_real()) + assert errors == [] + df = expr.execute() + assert len(df) == 1 + assert int(df[pipeline.TOTAL_LENGTH_KEY].iloc[0]) == 7 + + def test_construction_error_surfaces_in_errors(self): + """A stat that raises while building its ibis expression must appear + in the returned errors list, not be silently dropped.""" + + @stat() + def broken_batch(col: XorqColumn) -> int: + raise RuntimeError("intentional construction failure") + + pipeline = XorqStatPipeline([*XORQ_STATS_V2, broken_batch], unit_test=False) + expr, errors = pipeline.compile_batch_expr(_unbound()) + # One error per column the stat would have been built for (no filter + # → every column). + assert len(errors) == len(SCHEMA) + for se in errors: + assert isinstance(se.error, RuntimeError) + assert se.stat_key == "broken_batch" + assert se.column in SCHEMA + # The expression should still compile — the broken stat is just absent. + names = set(expr.schema().names) + for col in SCHEMA: + assert f"{col}|broken_batch" not in names + + def test_process_table_still_works_after_refactor(self): + """compile_batch_expr is extracted from process_table's Phase 1. + process_table itself must continue to produce correct results.""" + pipeline = XorqStatPipeline(XORQ_STATS_V2) + stats, errors = pipeline.process_table(_real()) + assert errors == [] + assert stats["ints"]["length"] == 7 + assert stats["ints"]["mean"] == pytest.approx(4.0) + assert stats["strs"]["distinct_count"] == 7