Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions datajunction-server/datajunction_server/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3234,10 +3234,12 @@ class MapEntries(Function):

@MapEntries.register  # type: ignore
def infer_type(map_: ct.MapType) -> ct.ColumnType:
    """
    Infer the result type of `map_entries` for a map argument.

    The result is a list of structs with two fields, `key` and `value`,
    typed after the map's key type and value type respectively.
    """
    # Imported inside the function: a module-level import would create a
    # circular dependency between sql.functions and sql.parsing.ast.
    from datajunction_server.sql.parsing import ast

    # Field names are passed as ast.Name objects (not bare strings), and each
    # field is declared exactly once.
    return ct.ListType(
        element_type=ct.StructType(
            ct.NestedField(ast.Name("key"), field_type=map_.type.key.type),
            ct.NestedField(ast.Name("value"), field_type=map_.type.value.type),
        ),
    )

Expand Down Expand Up @@ -4810,6 +4812,68 @@ def infer_type(
return [arg.key, arg.value]


class Range(TableFunction):
    """
    Table-generating function mirroring Spark SQL's `range`.

    Supported call shapes:
        range(end)
        range(start, end)
        range(start, end, step)
        range(start, end, step, numSlices)

    Whatever the arity, the result is a table with exactly one column,
    `id`, of type BIGINT, holding the values of the requested range.
    """

    # Only registered for the Spark dialect.
    dialects = [Dialect.SPARK]


def _range_schema() -> List[ct.NestedField]:
    """
    Return the one-column schema shared by every arity of Range.

    The schema is a single BIGINT field named `id`. A fresh list (and a
    fresh field object) is built on each call.
    """
    # Local import: importing ast at module level would introduce a circular
    # dependency between sql.functions and sql.parsing.ast.
    from datajunction_server.sql.parsing import ast

    id_field = ct.NestedField(name=ast.Name("id"), field_type=ct.BigIntType())
    return [id_field]


@Range.register
def infer_type(end: ct.IntegerBase) -> List[ct.NestedField]:
    """
    One-argument form, range(end): generates 0 to end-1.

    Returns the single-column `id` schema common to all Range arities.
    """
    schema = _range_schema()
    return schema


@Range.register
def infer_type(  # type: ignore
    start: ct.IntegerBase,
    end: ct.IntegerBase,
) -> List[ct.NestedField]:
    """
    Two-argument form, range(start, end).

    Returns the single-column `id` schema common to all Range arities.
    """
    schema = _range_schema()
    return schema


@Range.register
def infer_type(  # type: ignore
    start: ct.IntegerBase,
    end: ct.IntegerBase,
    step: ct.IntegerBase,
) -> List[ct.NestedField]:
    """
    Three-argument form, range(start, end, step).

    Returns the single-column `id` schema common to all Range arities.
    """
    schema = _range_schema()
    return schema


@Range.register
def infer_type(  # type: ignore
    start: ct.IntegerBase,
    end: ct.IntegerBase,
    step: ct.IntegerBase,
    num_slices: ct.IntegerBase,
) -> List[ct.NestedField]:
    """
    Four-argument form, range(start, end, step, numSlices).

    Returns the single-column `id` schema common to all Range arities.
    """
    schema = _range_schema()
    return schema


class FunctionRegistryDict(dict):
"""
Custom dictionary mapping for functions
Expand Down
5 changes: 4 additions & 1 deletion datajunction-server/datajunction_server/sql/parsing/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2818,7 +2818,10 @@ async def compile(self, ctx):
return
self._is_compiled = True
types = await self._type(ctx)
for type, col in zip_longest(types, self.column_list):
# `column_list or []` covers the zero-column-alias case — e.g.
# `SELECT id FROM range(1, 10)` has no AS-list, so the function's
# inferred NestedField names (here, `id`) are used directly.
for type, col in zip_longest(types, self.column_list or []):
if self.column_list:
if (type is None) or (col is None):
ctx.exception.errors.append(
Expand Down
15 changes: 3 additions & 12 deletions datajunction-server/datajunction_server/sql/parsing/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,29 +250,20 @@ def __init__(
is_optional: bool = True,
doc: Optional[str] = None,
):
if isinstance(name, str): # pragma: no cover
from datajunction_server.sql.parsing.ast import Name

name_str = name
name_obj = Name(name_str)
else:
name_str = name.name
name_obj = name

doc_string = "" if doc is None else f", doc={repr(doc)}"
super().__init__(
(
f"{name_str} {field_type}"
f"{name.name} {field_type}"
f"{' NOT NULL' if not is_optional else ''}"
+ ("" if doc is None else f" {doc}")
),
f"NestedField(name={repr(name_obj)}, "
f"NestedField(name={repr(name)}, "
f"field_type={repr(field_type)}, "
f"is_optional={is_optional}"
f"{doc_string})",
)
object.__setattr__(self, "_is_optional", is_optional)
object.__setattr__(self, "_name", name_obj)
object.__setattr__(self, "_name", name)
object.__setattr__(self, "_type", field_type)
object.__setattr__(self, "_doc", doc)

Expand Down
21 changes: 21 additions & 0 deletions datajunction-server/tests/sql/functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,6 +1608,27 @@ async def test_explode_outer_func(session: AsyncSession):
assert query.select.projection[1].type == ct.StringType() # type: ignore


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "call",
    [
        "range(10)",
        "range(1, 10)",
        "range(1, 10, 2)",
        "range(1, 10, 2, 4)",
    ],
)
async def test_range(session: AsyncSession, call: str):
    """
    Compile ``SELECT id FROM <call>`` for each of the four ``range(...)``
    arities and check that compilation records no errors and that the
    projected ``id`` column is typed BIGINT.
    """
    exc = DJException()
    query = parse(f"SELECT id FROM {call}")
    await query.compile(ast.CompileContext(session=session, exception=exc))
    assert not exc.errors, f"unexpected errors for {call!r}: {exc.errors}"
    id_column = query.select.projection[0]  # type: ignore
    assert id_column.type == ct.BigIntType()


@pytest.mark.asyncio
async def test_expm1_func(session: AsyncSession):
"""
Expand Down
Loading