diff --git a/datajunction-server/datajunction_server/sql/functions.py b/datajunction-server/datajunction_server/sql/functions.py index 5adc61b31..81df63e6c 100644 --- a/datajunction-server/datajunction_server/sql/functions.py +++ b/datajunction-server/datajunction_server/sql/functions.py @@ -3234,10 +3234,12 @@ class MapEntries(Function): @MapEntries.register # type: ignore def infer_type(map_: ct.MapType) -> ct.ColumnType: + from datajunction_server.sql.parsing import ast + return ct.ListType( element_type=ct.StructType( - ct.NestedField("key", field_type=map_.type.key.type), - ct.NestedField("value", field_type=map_.type.value.type), + ct.NestedField(ast.Name("key"), field_type=map_.type.key.type), + ct.NestedField(ast.Name("value"), field_type=map_.type.value.type), ), ) @@ -4810,6 +4812,68 @@ def infer_type( return [arg.key, arg.value] +class Range(TableFunction): + """ + Spark SQL's `range` table-generating function. + + Forms: + range(end) + range(start, end) + range(start, end, step) + range(start, end, step, numSlices) + + Returns a single-column table with BIGINT column `id` containing + values in the specified range. + """ + + dialects = [Dialect.SPARK] + + +def _range_schema() -> List[ct.NestedField]: + """Build the single-column schema every Range arity returns. Inline + import of ast.Name avoids the module-level circular dep between + sql.functions and sql.parsing.ast.""" + from datajunction_server.sql.parsing import ast + + return [ct.NestedField(name=ast.Name("id"), field_type=ct.BigIntType())] + + +@Range.register +def infer_type(end: ct.IntegerBase) -> List[ct.NestedField]: + """range(end) — generates 0 to end-1.""" + return _range_schema() + + +@Range.register +def infer_type( # type: ignore + start: ct.IntegerBase, + end: ct.IntegerBase, +) -> List[ct.NestedField]: + """range(start, end).""" + return _range_schema() + + +@Range.register +def infer_type( # type: ignore + start: ct.IntegerBase, + end: ct.IntegerBase, + step: ct.IntegerBase, +) -> List[ct.NestedField]: + """range(start, end, step).""" + return _range_schema() + + +@Range.register +def infer_type( # type: ignore + start: ct.IntegerBase, + end: ct.IntegerBase, + step: ct.IntegerBase, + num_slices: ct.IntegerBase, +) -> List[ct.NestedField]: + """range(start, end, step, numSlices).""" + return _range_schema() + + class FunctionRegistryDict(dict): """ Custom dictionary mapping for functions diff --git a/datajunction-server/datajunction_server/sql/parsing/ast.py b/datajunction-server/datajunction_server/sql/parsing/ast.py index c644901fa..ad570048e 100644 --- a/datajunction-server/datajunction_server/sql/parsing/ast.py +++ b/datajunction-server/datajunction_server/sql/parsing/ast.py @@ -2818,7 +2818,10 @@ async def compile(self, ctx): return self._is_compiled = True types = await self._type(ctx) - for type, col in zip_longest(types, self.column_list): + # `column_list or []` covers the zero-column-alias case — e.g. + # `SELECT id FROM range(1, 10)` has no AS-list, so the function's + # inferred NestedField names (here, `id`) are used directly. + for type, col in zip_longest(types, self.column_list or []): if self.column_list: if (type is None) or (col is None): ctx.exception.errors.append( diff --git a/datajunction-server/datajunction_server/sql/parsing/types.py b/datajunction-server/datajunction_server/sql/parsing/types.py index 8b7230309..79a2c0a76 100644 --- a/datajunction-server/datajunction_server/sql/parsing/types.py +++ b/datajunction-server/datajunction_server/sql/parsing/types.py @@ -250,29 +250,20 @@ def __init__( is_optional: bool = True, doc: Optional[str] = None, ): - if isinstance(name, str): # pragma: no cover - from datajunction_server.sql.parsing.ast import Name - - name_str = name - name_obj = Name(name_str) - else: - name_str = name.name - name_obj = name - doc_string = "" if doc is None else f", doc={repr(doc)}" super().__init__( ( - f"{name_str} {field_type}" + f"{name.name} {field_type}" f"{' NOT NULL' if not is_optional else ''}" + ("" if doc is None else f" {doc}") ), - f"NestedField(name={repr(name_obj)}, " + f"NestedField(name={repr(name)}, " f"field_type={repr(field_type)}, " f"is_optional={is_optional}" f"{doc_string})", ) object.__setattr__(self, "_is_optional", is_optional) - object.__setattr__(self, "_name", name_obj) + object.__setattr__(self, "_name", name) object.__setattr__(self, "_type", field_type) object.__setattr__(self, "_doc", doc) diff --git a/datajunction-server/tests/sql/functions_test.py b/datajunction-server/tests/sql/functions_test.py index ed21d4d6e..b9876d34b 100644 --- a/datajunction-server/tests/sql/functions_test.py +++ b/datajunction-server/tests/sql/functions_test.py @@ -1608,6 +1608,27 @@ async def test_explode_outer_func(session: AsyncSession): assert query.select.projection[1].type == ct.StringType() # type: ignore +@pytest.mark.asyncio +@pytest.mark.parametrize( + "call", + [ + "range(10)", + "range(1, 10)", + "range(1, 10, 2)", + "range(1, 10, 2, 4)", + ], +) +async def test_range(session: AsyncSession, call: str): + """Spark's ``range(...)`` table function: all four arities expose a single + ``id`` column typed BIGINT, and the generated query compiles clean.""" + query = parse(f"SELECT id FROM {call}") + exc = DJException() + ctx = ast.CompileContext(session=session, exception=exc) + await query.compile(ctx) + assert not exc.errors, f"unexpected errors for {call!r}: {exc.errors}" + assert query.select.projection[0].type == ct.BigIntType() # type: ignore + + @pytest.mark.asyncio async def test_expm1_func(session: AsyncSession): """