From 978a1122a4881f0c352ed629d20c86d20668ea4c Mon Sep 17 00:00:00 2001 From: Anh Le Date: Wed, 25 Feb 2026 17:01:42 -0800 Subject: [PATCH 1/4] Add range function --- .../datajunction_server/sql/functions.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/datajunction-server/datajunction_server/sql/functions.py b/datajunction-server/datajunction_server/sql/functions.py index 5adc61b31..3fabdae53 100644 --- a/datajunction-server/datajunction_server/sql/functions.py +++ b/datajunction-server/datajunction_server/sql/functions.py @@ -4810,6 +4810,25 @@ def infer_type( return [arg.key, arg.value] +class Range(TableFunction): + """ + range(start[, end[, step[, numSlices]]]) / range(end) + Returns a table with a single BIGINT column `id` containing values + within the specified range. + """ + + dialects = [Dialect.SPARK] + + +@Range.register +def infer_type( + *args: ct.IntegerBase, +) -> List[ct.NestedField]: + from datajunction_server.sql.parsing.ast import Name + + return [ct.NestedField(name=Name("id"), field_type=ct.BigIntType())] + + class FunctionRegistryDict(dict): """ Custom dictionary mapping for functions From df724fb142899ff650ce9d2bbcc88a8c798bc7b1 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Sun, 1 Mar 2026 11:30:16 -0800 Subject: [PATCH 2/4] Fix --- .../datajunction_server/sql/functions.py | 38 +++++++++++++++++-- .../datajunction_server/sql/parsing/ast.py | 9 ++++- .../datajunction_server/sql/parsing/types.py | 3 +- .../tests/sql/functions_test.py | 15 ++++++++ 4 files changed, 58 insertions(+), 7 deletions(-) diff --git a/datajunction-server/datajunction_server/sql/functions.py b/datajunction-server/datajunction_server/sql/functions.py index 3fabdae53..c5e5081ee 100644 --- a/datajunction-server/datajunction_server/sql/functions.py +++ b/datajunction-server/datajunction_server/sql/functions.py @@ -4821,12 +4821,42 @@ class Range(TableFunction): @Range.register -def infer_type( - *args: ct.IntegerBase, +def infer_type(end: ct.IntegerBase) -> List[ct.NestedField]: + """range(end) - generates 0 to end-1""" + return [ct.NestedField(name="id", field_type=ct.BigIntType())] + + +@Range.register +def infer_type( # type: ignore + start: ct.IntegerBase, + end: ct.IntegerBase, +) -> List[ct.NestedField]: + """range(start, end)""" + return [ct.NestedField(name="id", field_type=ct.BigIntType())] + + +@Range.register +def infer_type( # type: ignore + start: ct.IntegerBase, + end: ct.IntegerBase, + step: ct.IntegerBase, ) -> List[ct.NestedField]: - from datajunction_server.sql.parsing.ast import Name + """range(start, end, step)""" + print(f"DEBUG: Range.infer_type called with start={start}, end={end}, step={step}") + result = [ct.NestedField(name="id", field_type=ct.BigIntType())] + print(f"DEBUG: Returning {result}") + return result + - return [ct.NestedField(name=Name("id"), field_type=ct.BigIntType())] +@Range.register +def infer_type( # type: ignore + start: ct.IntegerBase, + end: ct.IntegerBase, + step: ct.IntegerBase, + num_slices: ct.IntegerBase, +) -> List[ct.NestedField]: + """range(start, end, step, numSlices)""" + return [ct.NestedField(name="id", field_type=ct.BigIntType())] class FunctionRegistryDict(dict): diff --git a/datajunction-server/datajunction_server/sql/parsing/ast.py b/datajunction-server/datajunction_server/sql/parsing/ast.py index c644901fa..404ed2815 100644 --- a/datajunction-server/datajunction_server/sql/parsing/ast.py +++ b/datajunction-server/datajunction_server/sql/parsing/ast.py @@ -2811,14 +2811,19 @@ async def _type(self, ctx: Optional[CompileContext] = None) -> List[NestedField] if ctx: await arg.compile(ctx) arg_types.append(arg.type) - return dj_func.infer_type(*arg_types) + print(f"DEBUG _type: About to call {name}.infer_type") + result = dj_func.infer_type(*arg_types) + print(f"DEBUG _type: result={result}") + return result async def compile(self, ctx): if self.is_compiled(): return self._is_compiled = True types = await self._type(ctx) - for type, col in zip_longest(types, self.column_list): + print(f"DEBUG compile: types={types}") + print(f"DEBUG compile: self.column_list={self.column_list}") + for type, col in zip_longest(types, self.column_list or []): if self.column_list: if (type is None) or (col is None): ctx.exception.errors.append( diff --git a/datajunction-server/datajunction_server/sql/parsing/types.py b/datajunction-server/datajunction_server/sql/parsing/types.py index 8b7230309..50bc27609 100644 --- a/datajunction-server/datajunction_server/sql/parsing/types.py +++ b/datajunction-server/datajunction_server/sql/parsing/types.py @@ -20,6 +20,7 @@ Dict, Optional, Tuple, + Union, cast, Callable, ) @@ -245,7 +246,7 @@ class NestedField(ColumnType): def __init__( self, - name: "ast.Name", + name: Union["ast.Name", str], field_type: ColumnType, is_optional: bool = True, doc: Optional[str] = None, diff --git a/datajunction-server/tests/sql/functions_test.py b/datajunction-server/tests/sql/functions_test.py index ed21d4d6e..43976a395 100644 --- a/datajunction-server/tests/sql/functions_test.py +++ b/datajunction-server/tests/sql/functions_test.py @@ -1608,6 +1608,21 @@ async def test_explode_outer_func(session: AsyncSession): assert query.select.projection[1].type == ct.StringType() # type: ignore +@pytest.mark.asyncio +async def test_range(session: AsyncSession): + """ + Test the `range` function + """ + query = parse( + "SELECT id FROM range(1, 10, 2)", + ) + exc = DJException() + ctx = ast.CompileContext(session=session, exception=exc) + await query.compile(ctx) + assert not exc.errors + assert query.select.projection[0].type == ct.IntegerType() # type: ignore + + @pytest.mark.asyncio async def test_expm1_func(session: AsyncSession): """ From 6008644ebd5ea473bc672a29d36ebfe099c1d919 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Fri, 24 Apr 2026 08:58:17 -0700 Subject: [PATCH 3/4] Minor fixes to range --- .../datajunction_server/sql/functions.py | 41 ++++++++++++------- .../datajunction_server/sql/parsing/ast.py | 10 ++--- .../datajunction_server/sql/parsing/types.py | 18 ++------ .../tests/sql/functions_test.py | 24 +++++++---- 4 files changed, 50 insertions(+), 43 deletions(-) diff --git a/datajunction-server/datajunction_server/sql/functions.py b/datajunction-server/datajunction_server/sql/functions.py index c5e5081ee..f1230a9a3 100644 --- a/datajunction-server/datajunction_server/sql/functions.py +++ b/datajunction-server/datajunction_server/sql/functions.py @@ -4812,18 +4812,34 @@ def infer_type( class Range(TableFunction): """ - range(start[, end[, step[, numSlices]]]) / range(end) - Returns a table with a single BIGINT column `id` containing values - within the specified range. + Spark SQL's `range` table-generating function. + + Forms: + range(end) + range(start, end) + range(start, end, step) + range(start, end, step, numSlices) + + Returns a single-column table with BIGINT column `id` containing + values in the specified range. """ dialects = [Dialect.SPARK] +def _range_schema() -> List[ct.NestedField]: + """Build the single-column schema every Range arity returns. Inline + import of ast.Name avoids the module-level circular dep between + sql.functions and sql.parsing.ast.""" + from datajunction_server.sql.parsing import ast + + return [ct.NestedField(name=ast.Name("id"), field_type=ct.BigIntType())] + + @Range.register def infer_type(end: ct.IntegerBase) -> List[ct.NestedField]: - """range(end) - generates 0 to end-1""" - return [ct.NestedField(name="id", field_type=ct.BigIntType())] + """range(end) — generates 0 to end-1.""" + return _range_schema() @Range.register @@ -4831,8 +4847,8 @@ def infer_type( # type: ignore start: ct.IntegerBase, end: ct.IntegerBase, ) -> List[ct.NestedField]: - """range(start, end)""" - return [ct.NestedField(name="id", field_type=ct.BigIntType())] + """range(start, end).""" + return _range_schema() @Range.register @@ -4841,11 +4857,8 @@ def infer_type( # type: ignore end: ct.IntegerBase, step: ct.IntegerBase, ) -> List[ct.NestedField]: - """range(start, end, step)""" - print(f"DEBUG: Range.infer_type called with start={start}, end={end}, step={step}") - result = [ct.NestedField(name="id", field_type=ct.BigIntType())] - print(f"DEBUG: Returning {result}") - return result + """range(start, end, step).""" + return _range_schema() @Range.register @@ -4855,8 +4868,8 @@ def infer_type( # type: ignore step: ct.IntegerBase, num_slices: ct.IntegerBase, ) -> List[ct.NestedField]: - """range(start, end, step, numSlices)""" - return [ct.NestedField(name="id", field_type=ct.BigIntType())] + """range(start, end, step, numSlices).""" + return _range_schema() class FunctionRegistryDict(dict): diff --git a/datajunction-server/datajunction_server/sql/parsing/ast.py b/datajunction-server/datajunction_server/sql/parsing/ast.py index 404ed2815..ad570048e 100644 --- a/datajunction-server/datajunction_server/sql/parsing/ast.py +++ b/datajunction-server/datajunction_server/sql/parsing/ast.py @@ -2811,18 +2811,16 @@ async def _type(self, ctx: Optional[CompileContext] = None) -> List[NestedField] if ctx: await arg.compile(ctx) arg_types.append(arg.type) - print(f"DEBUG _type: About to call {name}.infer_type") - result = dj_func.infer_type(*arg_types) - print(f"DEBUG _type: result={result}") - return result + return dj_func.infer_type(*arg_types) async def compile(self, ctx): if self.is_compiled(): return self._is_compiled = True types = await self._type(ctx) - print(f"DEBUG compile: types={types}") - print(f"DEBUG compile: self.column_list={self.column_list}") + # `column_list or []` covers the zero-column-alias case — e.g. + # `SELECT id FROM range(1, 10)` has no AS-list, so the function's + # inferred NestedField names (here, `id`) are used directly. for type, col in zip_longest(types, self.column_list or []): if self.column_list: if (type is None) or (col is None): diff --git a/datajunction-server/datajunction_server/sql/parsing/types.py b/datajunction-server/datajunction_server/sql/parsing/types.py index 50bc27609..79a2c0a76 100644 --- a/datajunction-server/datajunction_server/sql/parsing/types.py +++ b/datajunction-server/datajunction_server/sql/parsing/types.py @@ -20,7 +20,6 @@ Dict, Optional, Tuple, - Union, cast, Callable, ) @@ -246,34 +245,25 @@ class NestedField(ColumnType): def __init__( self, - name: Union["ast.Name", str], + name: "ast.Name", field_type: ColumnType, is_optional: bool = True, doc: Optional[str] = None, ): - if isinstance(name, str): # pragma: no cover - from datajunction_server.sql.parsing.ast import Name - - name_str = name - name_obj = Name(name_str) - else: - name_str = name.name - name_obj = name - doc_string = "" if doc is None else f", doc={repr(doc)}" super().__init__( ( - f"{name_str} {field_type}" + f"{name.name} {field_type}" f"{' NOT NULL' if not is_optional else ''}" + ("" if doc is None else f" {doc}") ), - f"NestedField(name={repr(name_obj)}, " + f"NestedField(name={repr(name)}, " f"field_type={repr(field_type)}, " f"is_optional={is_optional}" f"{doc_string})", ) object.__setattr__(self, "_is_optional", is_optional) - object.__setattr__(self, "_name", name_obj) + object.__setattr__(self, "_name", name) object.__setattr__(self, "_type", field_type) object.__setattr__(self, "_doc", doc) diff --git a/datajunction-server/tests/sql/functions_test.py b/datajunction-server/tests/sql/functions_test.py index 43976a395..b9876d34b 100644 --- a/datajunction-server/tests/sql/functions_test.py +++ b/datajunction-server/tests/sql/functions_test.py @@ -1609,18 +1609,24 @@ async def test_explode_outer_func(session: AsyncSession): @pytest.mark.asyncio -async def test_range(session: AsyncSession): - """ - Test the `range` function - """ - query = parse( - "SELECT id FROM range(1, 10, 2)", - ) +@pytest.mark.parametrize( + "call", + [ + "range(10)", + "range(1, 10)", + "range(1, 10, 2)", + "range(1, 10, 2, 4)", + ], +) +async def test_range(session: AsyncSession, call: str): + """Spark's ``range(...)`` table function: all four arities expose a single + ``id`` column typed BIGINT, and the generated query compiles clean.""" + query = parse(f"SELECT id FROM {call}") exc = DJException() ctx = ast.CompileContext(session=session, exception=exc) await query.compile(ctx) - assert not exc.errors - assert query.select.projection[0].type == ct.IntegerType() # type: ignore + assert not exc.errors, f"unexpected errors for {call!r}: {exc.errors}" + assert query.select.projection[0].type == ct.BigIntType() # type: ignore @pytest.mark.asyncio From 1d4bf15f4c4c1807d69bcb2ad6f35ff7d30f7a26 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Fri, 24 Apr 2026 08:59:39 -0700 Subject: [PATCH 4/4] Fix nested field construction --- datajunction-server/datajunction_server/sql/functions.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/datajunction-server/datajunction_server/sql/functions.py b/datajunction-server/datajunction_server/sql/functions.py index f1230a9a3..81df63e6c 100644 --- a/datajunction-server/datajunction_server/sql/functions.py +++ b/datajunction-server/datajunction_server/sql/functions.py @@ -3234,10 +3234,12 @@ class MapEntries(Function): @MapEntries.register # type: ignore def infer_type(map_: ct.MapType) -> ct.ColumnType: + from datajunction_server.sql.parsing import ast + return ct.ListType( element_type=ct.StructType( - ct.NestedField("key", field_type=map_.type.key.type), - ct.NestedField("value", field_type=map_.type.value.type), + ct.NestedField(ast.Name("key"), field_type=map_.type.key.type), + ct.NestedField(ast.Name("value"), field_type=map_.type.value.type), ), )