Skip to content

Commit 1755a78

Browse files
committed
test: add alias-equivalence and length-mismatch tests for array fns
Pin the contracts the doctests don't cover: list_compact/list_normalize must produce the same output as their array_* primaries, and cosine_distance/inner_product must reject length-mismatched inputs at execution time.
1 parent 2a9292e commit 1755a78

1 file changed

Lines changed: 27 additions & 0 deletions

File tree

python/tests/test_functions.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,33 @@ def test_array_function_obj_tests(stmt, py_expr):
717717
assert a == b
718718

719719

720+
@pytest.mark.parametrize(
721+
("alias_fn", "primary_fn", "data"),
722+
[
723+
(f.list_compact, f.array_compact, [[1.0, None, 2.0, None, 3.0]]),
724+
(f.list_normalize, f.array_normalize, [[3.0, 4.0]]),
725+
],
726+
)
727+
def test_array_function_aliases(alias_fn, primary_fn, data):
728+
"""list_* helpers should be exact aliases for their array_* counterparts."""
729+
ctx = SessionContext()
730+
df = ctx.from_pydict({"a": data})
731+
alias_result = df.select(alias_fn(column("a")).alias("r")).collect()
732+
primary_result = df.select(primary_fn(column("a")).alias("r")).collect()
733+
assert (
734+
alias_result[0].column(0).to_pylist() == primary_result[0].column(0).to_pylist()
735+
)
736+
737+
738+
@pytest.mark.parametrize("fn", [f.cosine_distance, f.inner_product])
739+
def test_array_distance_length_mismatch_raises(fn):
740+
"""Length-mismatched inputs to vector distance fns should raise at execute."""
741+
ctx = SessionContext()
742+
df = ctx.from_pydict({"a": [[1.0, 2.0]], "b": [[1.0, 2.0, 3.0]]})
743+
with pytest.raises(Exception, match="same length"):
744+
df.select(fn(column("a"), column("b")).alias("r")).collect()
745+
746+
720747
@pytest.mark.parametrize(
721748
("args", "expected"),
722749
[

0 commit comments

Comments
 (0)