Skip to content

Commit 8f3d167

Browse files
timsaucer and claude committed
Make map the primary function with make_map as alias
map() now supports three calling conventions matching upstream:

- map({"a": 1, "b": 2}) — from a Python dictionary
- map([keys], [values]) — two lists that get zipped
- map(k1, v1, k2, v2, ...) — variadic key-value pairs

Non-Expr keys and values are automatically converted to literals.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent: fac2c24 · commit: 8f3d167

File tree

2 files changed

+86
-41
lines changed

2 files changed

+86
-41
lines changed

python/datafusion/functions.py

Lines changed: 38 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -202,6 +202,7 @@
202202
"make_date",
203203
"make_list",
204204
"make_map",
205+
"map",
205206
"map_entries",
206207
"map_extract",
207208
"map_keys",
@@ -3347,49 +3348,57 @@ def empty(array: Expr) -> Expr:
33473348
# map functions
33483349

33493350

3350-
def make_map(
3351-
data: dict[Any, Any] | None = None,
3352-
keys: list[Any] | None = None,
3353-
values: list[Any] | None = None,
3354-
) -> Expr:
3351+
def map(*args: Any) -> Expr:
33553352
"""Returns a map expression.
33563353
3357-
Can be called with either a Python dictionary or separate ``keys``
3358-
and ``values`` lists. Keys and values that are not already
3359-
:py:class:`~datafusion.expr.Expr` are automatically converted to
3360-
literal expressions.
3354+
Supports three calling conventions:
33613355
3362-
Args:
3363-
data: A Python dictionary of key-value pairs.
3364-
keys: A list of keys (use with ``values`` for column expressions).
3365-
values: A list of values (use with ``keys``).
3356+
- ``map({"a": 1, "b": 2})`` — from a Python dictionary.
3357+
- ``map([keys], [values])`` — two lists that get zipped.
3358+
- ``map(k1, v1, k2, v2, ...)`` — variadic key-value pairs.
3359+
3360+
Keys and values that are not already :py:class:`~datafusion.expr.Expr`
3361+
are automatically converted to literal expressions.
33663362
33673363
Examples:
33683364
>>> ctx = dfn.SessionContext()
33693365
>>> df = ctx.from_pydict({"a": [1]})
33703366
>>> result = df.select(
3371-
... dfn.functions.make_map({"a": 1, "b": 2}).alias("map"))
3372-
>>> result.collect_column("map")[0].as_py()
3367+
... dfn.functions.map({"a": 1, "b": 2}).alias("m"))
3368+
>>> result.collect_column("m")[0].as_py()
33733369
[('a', 1), ('b', 2)]
33743370
"""
3375-
if data is not None:
3376-
if keys is not None or values is not None:
3377-
msg = "Cannot specify both data and keys/values"
3378-
raise ValueError(msg)
3379-
key_list = list(data.keys())
3380-
value_list = list(data.values())
3381-
elif keys is not None and values is not None:
3382-
key_list = keys
3383-
value_list = values
3371+
if len(args) == 1 and isinstance(args[0], dict):
3372+
key_list = list(args[0].keys())
3373+
value_list = list(args[0].values())
3374+
elif (
3375+
len(args) == 2 # noqa: PLR2004
3376+
and isinstance(args[0], list)
3377+
and isinstance(args[1], list)
3378+
):
3379+
key_list = args[0]
3380+
value_list = args[1]
3381+
elif len(args) >= 2 and len(args) % 2 == 0: # noqa: PLR2004
3382+
key_list = list(args[0::2])
3383+
value_list = list(args[1::2])
33843384
else:
3385-
msg = "Must specify either data or both keys and values"
3385+
msg = "map expects a dict, two lists, or an even number of key-value arguments"
33863386
raise ValueError(msg)
33873387

33883388
key_exprs = [k if isinstance(k, Expr) else Expr.literal(k) for k in key_list]
33893389
val_exprs = [v if isinstance(v, Expr) else Expr.literal(v) for v in value_list]
33903390
return Expr(f.make_map([k.expr for k in key_exprs], [v.expr for v in val_exprs]))
33913391

33923392

3393+
def make_map(*args: Any) -> Expr:
3394+
"""Returns a map expression.
3395+
3396+
See Also:
3397+
This is an alias for :py:func:`map`.
3398+
"""
3399+
return map(*args)
3400+
3401+
33933402
def map_keys(map: Expr) -> Expr:
33943403
"""Returns a list of all keys in the map.
33953404
@@ -3398,7 +3407,7 @@ def map_keys(map: Expr) -> Expr:
33983407
>>> df = ctx.from_pydict({"a": [1]})
33993408
>>> result = df.select(
34003409
... dfn.functions.map_keys(
3401-
... dfn.functions.make_map({"x": 1, "y": 2})
3410+
... dfn.functions.map({"x": 1, "y": 2})
34023411
... ).alias("keys"))
34033412
>>> result.collect_column("keys")[0].as_py()
34043413
['x', 'y']
@@ -3414,7 +3423,7 @@ def map_values(map: Expr) -> Expr:
34143423
>>> df = ctx.from_pydict({"a": [1]})
34153424
>>> result = df.select(
34163425
... dfn.functions.map_values(
3417-
... dfn.functions.make_map({"x": 1, "y": 2})
3426+
... dfn.functions.map({"x": 1, "y": 2})
34183427
... ).alias("vals"))
34193428
>>> result.collect_column("vals")[0].as_py()
34203429
[1, 2]
@@ -3430,7 +3439,7 @@ def map_extract(map: Expr, key: Expr) -> Expr:
34303439
>>> df = ctx.from_pydict({"a": [1]})
34313440
>>> result = df.select(
34323441
... dfn.functions.map_extract(
3433-
... dfn.functions.make_map({"x": 1, "y": 2}),
3442+
... dfn.functions.map({"x": 1, "y": 2}),
34343443
... dfn.lit("x"),
34353444
... ).alias("val"))
34363445
>>> result.collect_column("val")[0].as_py()
@@ -3447,7 +3456,7 @@ def map_entries(map: Expr) -> Expr:
34473456
>>> df = ctx.from_pydict({"a": [1]})
34483457
>>> result = df.select(
34493458
... dfn.functions.map_entries(
3450-
... dfn.functions.make_map({"x": 1, "y": 2})
3459+
... dfn.functions.map({"x": 1, "y": 2})
34513460
... ).alias("entries"))
34523461
>>> result.collect_column("entries")[0].as_py()
34533462
[{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}]

python/tests/test_functions.py

Lines changed: 48 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -668,29 +668,29 @@ def test_array_function_obj_tests(stmt, py_expr):
668668
assert a == b
669669

670670

671-
def test_make_map():
671+
def test_map_from_dict():
672672
ctx = SessionContext()
673673
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
674674
df = ctx.create_dataframe([[batch]])
675675

676-
result = df.select(f.make_map({"x": 1, "y": 2}).alias("map")).collect()[0].column(0)
676+
result = df.select(f.map({"x": 1, "y": 2}).alias("m")).collect()[0].column(0)
677677
assert result[0].as_py() == [("x", 1), ("y", 2)]
678678

679679

680-
def test_make_map_with_expr_values():
680+
def test_map_from_dict_with_expr_values():
681681
ctx = SessionContext()
682682
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
683683
df = ctx.create_dataframe([[batch]])
684684

685685
result = (
686-
df.select(f.make_map({"x": literal(1), "y": literal(2)}).alias("map"))
686+
df.select(f.map({"x": literal(1), "y": literal(2)}).alias("m"))
687687
.collect()[0]
688688
.column(0)
689689
)
690690
assert result[0].as_py() == [("x", 1), ("y", 2)]
691691

692692

693-
def test_make_map_with_column_data():
693+
def test_map_from_two_lists():
694694
ctx = SessionContext()
695695
batch = pa.RecordBatch.from_arrays(
696696
[
@@ -701,7 +701,7 @@ def test_make_map_with_column_data():
701701
)
702702
df = ctx.create_dataframe([[batch]])
703703

704-
m = f.make_map(keys=[column("keys")], values=[column("vals")])
704+
m = f.map([column("keys")], [column("vals")])
705705
result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0)
706706
for i, expected in enumerate(["k1", "k2", "k3"]):
707707
assert result[i].as_py() == [expected]
@@ -711,12 +711,48 @@ def test_make_map_with_column_data():
711711
assert result[i].as_py() == [expected]
712712

713713

714+
def test_map_from_variadic_pairs():
715+
ctx = SessionContext()
716+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
717+
df = ctx.create_dataframe([[batch]])
718+
719+
result = df.select(f.map("x", 1, "y", 2).alias("m")).collect()[0].column(0)
720+
assert result[0].as_py() == [("x", 1), ("y", 2)]
721+
722+
723+
def test_map_variadic_with_exprs():
724+
ctx = SessionContext()
725+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
726+
df = ctx.create_dataframe([[batch]])
727+
728+
result = (
729+
df.select(f.map(literal("x"), literal(1), literal("y"), literal(2)).alias("m"))
730+
.collect()[0]
731+
.column(0)
732+
)
733+
assert result[0].as_py() == [("x", 1), ("y", 2)]
734+
735+
736+
def test_map_odd_args_raises():
737+
with pytest.raises(ValueError, match="map expects"):
738+
f.map("x", 1, "y")
739+
740+
741+
def test_make_map_is_alias():
742+
ctx = SessionContext()
743+
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
744+
df = ctx.create_dataframe([[batch]])
745+
746+
result = df.select(f.make_map({"x": 1, "y": 2}).alias("m")).collect()[0].column(0)
747+
assert result[0].as_py() == [("x", 1), ("y", 2)]
748+
749+
714750
def test_map_keys():
715751
ctx = SessionContext()
716752
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
717753
df = ctx.create_dataframe([[batch]])
718754

719-
m = f.make_map({"x": 1, "y": 2})
755+
m = f.map({"x": 1, "y": 2})
720756
result = df.select(f.map_keys(m).alias("keys")).collect()[0].column(0)
721757
assert result[0].as_py() == ["x", "y"]
722758

@@ -726,7 +762,7 @@ def test_map_values():
726762
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
727763
df = ctx.create_dataframe([[batch]])
728764

729-
m = f.make_map({"x": 1, "y": 2})
765+
m = f.map({"x": 1, "y": 2})
730766
result = df.select(f.map_values(m).alias("vals")).collect()[0].column(0)
731767
assert result[0].as_py() == [1, 2]
732768

@@ -736,7 +772,7 @@ def test_map_extract():
736772
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
737773
df = ctx.create_dataframe([[batch]])
738774

739-
m = f.make_map({"x": 1, "y": 2})
775+
m = f.map({"x": 1, "y": 2})
740776
result = (
741777
df.select(f.map_extract(m, literal("x")).alias("val")).collect()[0].column(0)
742778
)
@@ -748,7 +784,7 @@ def test_map_extract_missing_key():
748784
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
749785
df = ctx.create_dataframe([[batch]])
750786

751-
m = f.make_map({"x": 1})
787+
m = f.map({"x": 1})
752788
result = (
753789
df.select(f.map_extract(m, literal("z")).alias("val")).collect()[0].column(0)
754790
)
@@ -760,7 +796,7 @@ def test_map_entries():
760796
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
761797
df = ctx.create_dataframe([[batch]])
762798

763-
m = f.make_map({"x": 1, "y": 2})
799+
m = f.map({"x": 1, "y": 2})
764800
result = df.select(f.map_entries(m).alias("entries")).collect()[0].column(0)
765801
assert result[0].as_py() == [
766802
{"key": "x", "value": 1},
@@ -773,7 +809,7 @@ def test_element_at():
773809
batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
774810
df = ctx.create_dataframe([[batch]])
775811

776-
m = f.make_map({"a": 10, "b": 20})
812+
m = f.map({"a": 10, "b": 20})
777813
result = (
778814
df.select(f.element_at(m, literal("b")).alias("val")).collect()[0].column(0)
779815
)

0 commit comments

Comments
 (0)