4141from datafusion ._internal import ParquetColumnOptions as ParquetColumnOptionsInternal
4242from datafusion ._internal import ParquetWriterOptions as ParquetWriterOptionsInternal
4343from datafusion .expr import (
44- EXPR_TYPE_ERROR ,
4544 Expr ,
4645 SortKey ,
47- _to_raw_expr ,
46+ ensure_expr ,
4847 expr_list_to_raw_expr_list ,
4948 sort_list_to_raw_sort_list ,
5049)
5958 import polars as pl
6059 import pyarrow as pa
6160
62- from datafusion ._internal import expr as expr_internal
63-
6461from enum import Enum
6562
6663
@@ -293,27 +290,6 @@ def __init__(
293290 self .bloom_filter_ndv = bloom_filter_ndv
294291
295292
296- def _ensure_expr (value : Expr | str ) -> expr_internal .Expr :
297- """Return the internal expression from ``Expr`` or raise ``TypeError``.
298-
299- This helper rejects plain strings so higher level APIs consistently
300- require explicit :func:`~datafusion.col` or :func:`~datafusion.lit`
301- expressions.
302-
303- Args:
304- value: Candidate expression.
305-
306- Returns:
307- The internal expression representation.
308-
309- Raises:
310- TypeError: If ``value`` is not an instance of :class:`Expr`.
311- """
312- if not isinstance (value , Expr ):
313- raise TypeError (EXPR_TYPE_ERROR )
314- return _to_raw_expr (value )
315-
316-
317293class DataFrame :
318294 """Two dimensional table representation of data.
319295
@@ -460,7 +436,7 @@ def filter(self, *predicates: Expr) -> DataFrame:
460436 """
461437 df = self .df
462438 for p in predicates :
463- df = df .filter (_ensure_expr (p ))
439+ df = df .filter (ensure_expr (p ))
464440 return DataFrame (df )
465441
466442 def with_column (self , name : str , expr : Expr ) -> DataFrame :
@@ -482,7 +458,7 @@ def with_column(self, name: str, expr: Expr) -> DataFrame:
482458 Returns:
483459 DataFrame with the new column.
484460 """
485- return DataFrame (self .df .with_column (name , _ensure_expr (expr )))
461+ return DataFrame (self .df .with_column (name , ensure_expr (expr )))
486462
487463 def with_columns (
488464 self , * exprs : Expr | Iterable [Expr ], ** named_exprs : Expr
@@ -512,30 +488,19 @@ def with_columns(
512488 DataFrame with the new columns added.
513489 """
514490
515- def _simplify_expression (
516- * exprs : Expr | Iterable [Expr ], ** named_exprs : Expr
517- ) -> list [expr_internal .Expr ]:
518- expr_list : list [expr_internal .Expr ] = []
519- for expr in exprs :
491+ def _iter_exprs (items : Iterable [Expr | Iterable [Expr ]]) -> Iterable [Expr | str ]:
492+ for expr in items :
520493 if isinstance (expr , str ):
521- raise TypeError (EXPR_TYPE_ERROR )
522- if isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
523- expr_value = list (expr )
524- if any (isinstance (inner , str ) for inner in expr_value ):
525- raise TypeError (EXPR_TYPE_ERROR )
494+ yield expr
495+ elif isinstance (expr , Iterable ) and not isinstance (expr , Expr ):
496+ yield from _iter_exprs (expr )
526497 else :
527- expr_value = expr
528- try :
529- expr_list .extend (expr_list_to_raw_expr_list (expr_value ))
530- except TypeError as err :
531- raise TypeError (EXPR_TYPE_ERROR ) from err
532- for alias , expr in named_exprs .items ():
533- if not isinstance (expr , Expr ):
534- raise TypeError (EXPR_TYPE_ERROR )
535- expr_list .append (expr .alias (alias ).expr )
536- return expr_list
537-
538- expressions = _simplify_expression (* exprs , ** named_exprs )
498+ yield expr
499+
500+ expressions = [ensure_expr (e ) for e in _iter_exprs (exprs )]
501+ for alias , expr in named_exprs .items ():
502+ ensure_expr (expr )
503+ expressions .append (expr .alias (alias ).expr )
539504
540505 return DataFrame (self .df .with_columns (expressions ))
541506
@@ -574,11 +539,7 @@ def aggregate(
574539 aggs_list = aggs if isinstance (aggs , list ) else [aggs ]
575540
576541 group_by_exprs = expr_list_to_raw_expr_list (group_by_list )
577- aggs_exprs = []
578- for agg in aggs_list :
579- if not isinstance (agg , Expr ):
580- raise TypeError (EXPR_TYPE_ERROR )
581- aggs_exprs .append (agg .expr )
542+ aggs_exprs = [ensure_expr (agg ) for agg in aggs_list ]
582543 return DataFrame (self .df .aggregate (group_by_exprs , aggs_exprs ))
583544
584545 def sort (self , * exprs : SortKey ) -> DataFrame :
@@ -824,7 +785,7 @@ def join_on(
824785 Returns:
825786 DataFrame after join.
826787 """
827- exprs = [_ensure_expr (expr ) for expr in on_exprs ]
788+ exprs = [ensure_expr (expr ) for expr in on_exprs ]
828789 return DataFrame (self .df .join_on (right .df , exprs , how ))
829790
830791 def explain (self , verbose : bool = False , analyze : bool = False ) -> None :
0 commit comments