Skip to content

Commit 1bbb6ad

Browse files
committed
feat: add ensure_expr helper function for internal expression validation
1 parent 29d53a7 commit 1bbb6ad

File tree

3 files changed: 46 additions and 55 deletions.

python/datafusion/dataframe.py

Lines changed: 16 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -41,10 +41,9 @@
4141
from datafusion._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242
from datafusion._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
4343
from datafusion.expr import (
44-
EXPR_TYPE_ERROR,
4544
Expr,
4645
SortKey,
47-
_to_raw_expr,
46+
ensure_expr,
4847
expr_list_to_raw_expr_list,
4948
sort_list_to_raw_sort_list,
5049
)
@@ -59,8 +58,6 @@
5958
import polars as pl
6059
import pyarrow as pa
6160

62-
from datafusion._internal import expr as expr_internal
63-
6461
from enum import Enum
6562

6663

@@ -293,27 +290,6 @@ def __init__(
293290
self.bloom_filter_ndv = bloom_filter_ndv
294291

295292

296-
def _ensure_expr(value: Expr | str) -> expr_internal.Expr:
297-
"""Return the internal expression from ``Expr`` or raise ``TypeError``.
298-
299-
This helper rejects plain strings so higher level APIs consistently
300-
require explicit :func:`~datafusion.col` or :func:`~datafusion.lit`
301-
expressions.
302-
303-
Args:
304-
value: Candidate expression.
305-
306-
Returns:
307-
The internal expression representation.
308-
309-
Raises:
310-
TypeError: If ``value`` is not an instance of :class:`Expr`.
311-
"""
312-
if not isinstance(value, Expr):
313-
raise TypeError(EXPR_TYPE_ERROR)
314-
return _to_raw_expr(value)
315-
316-
317293
class DataFrame:
318294
"""Two dimensional table representation of data.
319295
@@ -460,7 +436,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
460436
"""
461437
df = self.df
462438
for p in predicates:
463-
df = df.filter(_ensure_expr(p))
439+
df = df.filter(ensure_expr(p))
464440
return DataFrame(df)
465441

466442
def with_column(self, name: str, expr: Expr) -> DataFrame:
@@ -482,7 +458,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
482458
Returns:
483459
DataFrame with the new column.
484460
"""
485-
return DataFrame(self.df.with_column(name, _ensure_expr(expr)))
461+
return DataFrame(self.df.with_column(name, ensure_expr(expr)))
486462

487463
def with_columns(
488464
self, *exprs: Expr | Iterable[Expr], **named_exprs: Expr
@@ -512,30 +488,19 @@ def with_columns(
512488
DataFrame with the new columns added.
513489
"""
514490

515-
def _simplify_expression(
516-
*exprs: Expr | Iterable[Expr], **named_exprs: Expr
517-
) -> list[expr_internal.Expr]:
518-
expr_list: list[expr_internal.Expr] = []
519-
for expr in exprs:
491+
def _iter_exprs(items: Iterable[Expr | Iterable[Expr]]) -> Iterable[Expr | str]:
492+
for expr in items:
520493
if isinstance(expr, str):
521-
raise TypeError(EXPR_TYPE_ERROR)
522-
if isinstance(expr, Iterable) and not isinstance(expr, Expr):
523-
expr_value = list(expr)
524-
if any(isinstance(inner, str) for inner in expr_value):
525-
raise TypeError(EXPR_TYPE_ERROR)
494+
yield expr
495+
elif isinstance(expr, Iterable) and not isinstance(expr, Expr):
496+
yield from _iter_exprs(expr)
526497
else:
527-
expr_value = expr
528-
try:
529-
expr_list.extend(expr_list_to_raw_expr_list(expr_value))
530-
except TypeError as err:
531-
raise TypeError(EXPR_TYPE_ERROR) from err
532-
for alias, expr in named_exprs.items():
533-
if not isinstance(expr, Expr):
534-
raise TypeError(EXPR_TYPE_ERROR)
535-
expr_list.append(expr.alias(alias).expr)
536-
return expr_list
537-
538-
expressions = _simplify_expression(*exprs, **named_exprs)
498+
yield expr
499+
500+
expressions = [ensure_expr(e) for e in _iter_exprs(exprs)]
501+
for alias, expr in named_exprs.items():
502+
ensure_expr(expr)
503+
expressions.append(expr.alias(alias).expr)
539504

540505
return DataFrame(self.df.with_columns(expressions))
541506

@@ -574,11 +539,7 @@ def aggregate(
574539
aggs_list = aggs if isinstance(aggs, list) else [aggs]
575540

576541
group_by_exprs = expr_list_to_raw_expr_list(group_by_list)
577-
aggs_exprs = []
578-
for agg in aggs_list:
579-
if not isinstance(agg, Expr):
580-
raise TypeError(EXPR_TYPE_ERROR)
581-
aggs_exprs.append(agg.expr)
542+
aggs_exprs = [ensure_expr(agg) for agg in aggs_list]
582543
return DataFrame(self.df.aggregate(group_by_exprs, aggs_exprs))
583544

584545
def sort(self, *exprs: SortKey) -> DataFrame:
@@ -824,7 +785,7 @@ def join_on(
824785
Returns:
825786
DataFrame after join.
826787
"""
827-
exprs = [_ensure_expr(expr) for expr in on_exprs]
788+
exprs = [ensure_expr(expr) for expr in on_exprs]
828789
return DataFrame(self.df.join_on(right.df, exprs, how))
829790

830791
def explain(self, verbose: bool = False, analyze: bool = False) -> None:

python/datafusion/expr.py

Lines changed: 22 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -218,9 +218,31 @@
218218
"WindowExpr",
219219
"WindowFrame",
220220
"WindowFrameBound",
221+
"ensure_expr",
221222
]
222223

223224

225+
def ensure_expr(value: Expr) -> expr_internal.Expr:
226+
"""Return the internal expression from ``Expr`` or raise ``TypeError``.
227+
228+
This helper rejects plain strings so higher level APIs consistently
229+
require explicit :func:`~datafusion.col` or :func:`~datafusion.lit`
230+
expressions.
231+
232+
Args:
233+
value: Candidate expression.
234+
235+
Returns:
236+
The internal expression representation.
237+
238+
Raises:
239+
TypeError: If ``value`` is not an instance of :class:`Expr`.
240+
"""
241+
if not isinstance(value, Expr):
242+
raise TypeError(EXPR_TYPE_ERROR)
243+
return value.expr
244+
245+
224246
def _to_raw_expr(value: Expr | str) -> expr_internal.Expr:
225247
"""Convert a Python expression or column name to its raw variant.
226248

python/tests/test_expr.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@
4747
TransactionEnd,
4848
TransactionStart,
4949
Values,
50+
ensure_expr,
5051
)
5152

5253

@@ -880,3 +881,10 @@ def test_literal_metadata(ctx):
880881
for expected_field in expected_schema:
881882
actual_field = result[0].schema.field(expected_field.name)
882883
assert expected_field.metadata == actual_field.metadata
884+
885+
886+
def test_ensure_expr():
887+
e = col("a")
888+
assert ensure_expr(e) is e.expr
889+
with pytest.raises(TypeError, match=r"Use col\(\) or lit\(\)"):
890+
ensure_expr("a")

0 commit comments

Comments
 (0)