Skip to content

Commit d5ba35e

Browse files
rustyconoverclaude
andcommitted
feat: decode union-typed table varargs as TaggedUnion
Table-function varargs were always collected as raw pa.Scalar objects, which drops the active-member discriminator of a union value. Keyed on a declared union arrow_type, decode each vararg via _scalar_to_py() so it arrives as a TaggedUnion (matching how non-vararg union args resolve); the raw-scalar contract is untouched for every other type. Add the union_varargs fixture exercising this end-to-end (echoes each union vararg's active tag + value). Incidental: tidy the SecretsAccessor .to_dict() return annotation and reflow copy_from.py to formatter width. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 1cbc777 commit d5ba35e

5 files changed

Lines changed: 134 additions & 18 deletions

File tree

vgi/_test_fixtures/copy_from.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
from vgi.copy_from_function import CopyFromFunction
2626

2727
if TYPE_CHECKING:
28-
from vgi.table_function import ProcessParams
2928
from vgi_rpc.rpc import OutputCollector
3029

30+
from vgi.table_function import ProcessParams
31+
3132
__all__ = ["ExampleLinesCopyFromFunction"]
3233

3334

@@ -37,9 +38,7 @@ class ExampleLinesCopyFromArgs:
3738

3839
null_string: Annotated[str, Arg("null_string", doc="Token parsed as SQL NULL")]
3940
delimiter: Annotated[str, Arg("delimiter", default=",", doc="Field separator")] = ","
40-
skip_rows: Annotated[
41-
int, Arg("skip_rows", default=0, ge=0, doc="Leading lines to skip before data")
42-
] = 0
41+
skip_rows: Annotated[int, Arg("skip_rows", default=0, ge=0, doc="Leading lines to skip before data")] = 0
4342
on_error: Annotated[
4443
str,
4544
Arg(
@@ -87,9 +86,7 @@ def read(
8786
if len(cells) != ncols:
8887
if options.on_error == "skip":
8988
continue
90-
raise ValueError(
91-
f"example_lines: row has {len(cells)} fields, expected {ncols}: {line!r}"
92-
)
89+
raise ValueError(f"example_lines: row has {len(cells)} fields, expected {ncols}: {line!r}")
9390
rows.append(cells)
9491

9592
# Column-major string arrays, NULL where the cell equals null_string,

vgi/_test_fixtures/table/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
MakePairsStrFunction,
7979
RepeatValueIntFunction,
8080
RepeatValueStrFunction,
81+
UnionVarargsFunction,
8182
)
8283
from vgi._test_fixtures.table.partition_columns import (
8384
CountryPartitionedSalesFunction,
@@ -199,6 +200,7 @@
199200
"RegionYearPartitionedFunction",
200201
"RepeatValueIntFunction",
201202
"RepeatValueStrFunction",
203+
"UnionVarargsFunction",
202204
"RFF_MULTI_COLUMNS",
203205
"RFF_NESTED_COLUMNS",
204206
"RFF_NONE_COLUMNS",

vgi/_test_fixtures/table/pairs.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from vgi._test_fixtures.table._common import (
1515
_cardinality_from_count,
1616
)
17-
from vgi.arguments import Arg
17+
from vgi.arguments import Arg, TaggedUnion
1818
from vgi.invocation import BindResponse
1919
from vgi.metadata import FunctionExample
2020
from vgi.schema_utils import schema
@@ -447,3 +447,109 @@ def process(
447447
data = {f"v{i}": col for i, col in enumerate(state.rows)}
448448
out_schema = schema({f"v{i}": pa.string() for i in range(len(state.rows))})
449449
out.emit(pa.RecordBatch.from_pydict(data, schema=out_schema))
450+
451+
452+
# ============================================================================
453+
454+
# Sparse union shared by every union_varargs argument. DuckDB only ever emits
455+
# sparse unions (+us:) over Arrow, so this round-trips end-to-end.
456+
UNION_VARARGS_TYPE = pa.sparse_union([pa.field("i", pa.int64()), pa.field("s", pa.string())])
457+
458+
UNION_VARARGS_SCHEMA = schema(idx=pa.int64(), tag=pa.string(), value=pa.string())
459+
460+
461+
@dataclass(kw_only=True)
462+
class UnionVarargsArgs:
463+
"""Arguments for union_varargs."""
464+
465+
configs: Annotated[
466+
tuple[TaggedUnion, ...],
467+
Arg(
468+
0,
469+
varargs=True,
470+
arrow_type=UNION_VARARGS_TYPE,
471+
doc="Union values whose active member tag is echoed back",
472+
),
473+
]
474+
475+
476+
@dataclass(kw_only=True)
477+
class UnionVarargsState(ArrowSerializableDataclass):
478+
"""State for union_varargs."""
479+
480+
idx: list[int] = field(default_factory=list)
481+
tags: list[str | None] = field(default_factory=list)
482+
values: list[str] = field(default_factory=list)
483+
done: bool = False
484+
485+
486+
@init_single_worker
487+
@bind_fixed_schema
488+
class UnionVarargsFunction(TableFunctionGenerator[UnionVarargsArgs, UnionVarargsState]):
489+
"""Echo the active member tag and value of each union vararg.
490+
491+
USE CASE
492+
--------
493+
Exercises union-typed varargs: each argument arrives as a
494+
[`TaggedUnion`][vgi.arguments.TaggedUnion] so the active member
495+
discriminator (which a plain ``Scalar.as_py()`` would drop) is preserved.
496+
Emits one row per vararg with its positional index, the active member
497+
name, and the member value stringified into a single fixed column.
498+
499+
SCHEMA
500+
------
501+
Fixed: ``{"idx": int64, "tag": string, "value": string}``.
502+
503+
Example:
504+
SELECT * FROM union_varargs(
505+
union_value(i := 1)::UNION(i INT, s VARCHAR),
506+
union_value(s := 'x')::UNION(i INT, s VARCHAR))
507+
Returns: (0, 'i', '1'), (1, 's', 'x')
508+
509+
Attributes:
510+
FIXED_SCHEMA: The fixed Arrow output schema this function always produces.
511+
512+
"""
513+
514+
FIXED_SCHEMA: ClassVar[pa.Schema] = UNION_VARARGS_SCHEMA
515+
516+
class Meta:
517+
"""Function metadata."""
518+
519+
name = "union_varargs"
520+
description = "Echo the active member tag and value of each union vararg"
521+
categories = ["generator", "utility"]
522+
examples = [
523+
FunctionExample(
524+
sql=(
525+
"SELECT * FROM union_varargs("
526+
"union_value(i := 1)::UNION(i INT, s VARCHAR), "
527+
"union_value(s := 'x')::UNION(i INT, s VARCHAR))"
528+
),
529+
description="Echo the tag and value of two union arguments",
530+
),
531+
]
532+
533+
@classmethod
534+
def initial_state(cls, params: ProcessParams[UnionVarargsArgs]) -> UnionVarargsState:
535+
"""Decompose each union vararg into (idx, tag, value) rows."""
536+
configs = params.args.configs
537+
return UnionVarargsState(
538+
idx=list(range(len(configs))),
539+
tags=[cfg.tag for cfg in configs],
540+
values=[str(cfg.value) for cfg in configs],
541+
)
542+
543+
@classmethod
544+
def process(cls, params: ProcessParams[UnionVarargsArgs], state: UnionVarargsState, out: OutputCollector) -> None:
545+
"""Emit one row per union vararg."""
546+
if state.done:
547+
out.finish()
548+
return
549+
state.done = True
550+
out.emit(
551+
pa.RecordBatch.from_pydict(
552+
{"idx": state.idx, "tag": state.tags, "value": state.values},
553+
schema=UNION_VARARGS_SCHEMA,
554+
)
555+
)

vgi/_test_fixtures/worker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@
144144
MakeSeriesRangeFunction,
145145
MakeSeriesStepFunction,
146146
MissingBatchIndexTagFunction,
147+
MultiSecretDemoFunction,
147148
NamedParamsEchoFunction,
148149
NestedSequenceFunction,
149150
NonMonotoneBatchIndexFunction,
@@ -171,7 +172,6 @@
171172
RffStructScanFunction,
172173
RowIdSequenceFunction,
173174
SampleEchoFunction,
174-
MultiSecretDemoFunction,
175175
ScopedSecretDemoFunction,
176176
SecretDemoFunction,
177177
SequenceFunction,
@@ -181,6 +181,7 @@
181181
TenThousandFunction,
182182
TxCachedValueFunction,
183183
TypedProbeFunction,
184+
UnionVarargsFunction,
184185
ValuePruneFunction,
185186
VersionedConstraintsScanFunction,
186187
VersionedDataFunction,
@@ -378,6 +379,7 @@ def _build_enum_stats() -> dict[str, ColumnStatisticsInput]:
378379
MakePairsStrFunction,
379380
RepeatValueIntFunction,
380381
RepeatValueStrFunction,
382+
UnionVarargsFunction,
381383
NamedParamsEchoFunction,
382384
NestedSequenceFunction,
383385
ProfilingDemoFunction,

vgi/table_function.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
TableInput,
3939
_accepts_none,
4040
_extract_setting_secret_params,
41+
_scalar_to_py,
4142
)
4243
from vgi.function_storage import BoundStorage, TransactionBoundStorage, attach_catalog_bytes
4344
from vgi.invocation import (
@@ -240,7 +241,7 @@ def pending_lookups(self) -> list[SecretLookupEntry]:
240241
"""Return the list of pending secret lookups."""
241242
return list(self._pending_lookups)
242243

243-
def to_dict(self) -> "ResolvedSecrets":
244+
def to_dict(self) -> ResolvedSecrets:
244245
"""Return all resolved secrets keyed by secret name.
245246
246247
Resolved secrets are keyed by their unique DuckDB secret name, so several
@@ -316,9 +317,7 @@ def secret_type(self, name: str) -> str | None:
316317

317318
def of_type(self, secret_type: str) -> list[dict[str, Any]]:
318319
"""Every resolved secret whose ``type`` field matches ``secret_type``."""
319-
return [
320-
f for f in self.values() if _secret_scalar_str(f.get("type")) == secret_type
321-
]
320+
return [f for f in self.values() if _secret_scalar_str(f.get("type")) == secret_type]
322321

323322
def for_scope(self, path: str) -> dict[str, Any] | None:
324323
"""The secret whose ``scope`` is the longest prefix of ``path``.
@@ -338,9 +337,7 @@ def field_for(self, path: str, field: str) -> Any | None:
338337
fields = self.for_scope(path)
339338
return None if fields is None else fields.get(field)
340339

341-
def _select_for_scope(
342-
self, path: str, secret_type: str | None
343-
) -> dict[str, Any] | None:
340+
def _select_for_scope(self, path: str, secret_type: str | None) -> dict[str, Any] | None:
344341
best: dict[str, Any] | None = None
345342
best_len = -1
346343
fallback: dict[str, Any] | None = None
@@ -763,9 +760,21 @@ def _parse_arguments(args_class: type[TArgs], arguments: Arguments) -> TArgs:
763760
for meta in get_args(hint)[1:]:
764761
if isinstance(meta, Arg):
765762
if meta.varargs:
766-
# Varargs: collect remaining positional args as raw pa.Scalar objects
763+
# Varargs: collect remaining positional args as raw pa.Scalar
764+
# objects (e.g. constant_columns reads .type / pa.repeat off
765+
# them). Union-typed varargs are the exception: decode each
766+
# scalar to a TaggedUnion so the active member discriminator
767+
# is preserved — matching how non-vararg union args resolve
768+
# via Arguments.get()/_scalar_to_py(). Keyed on the declared
769+
# arrow_type so the raw-scalar contract is untouched otherwise.
767770
assert isinstance(meta.position, int)
768-
kwargs[attr_name] = tuple(arguments.positional[meta.position :])
771+
varargs_scalars = arguments.positional[meta.position :]
772+
if meta.arrow_type is not None and pa.types.is_union(meta.arrow_type):
773+
kwargs[attr_name] = tuple(
774+
_scalar_to_py(s) if s is not None else None for s in varargs_scalars
775+
)
776+
else:
777+
kwargs[attr_name] = tuple(varargs_scalars)
769778
else:
770779
value = arguments.get(meta.position, default=meta.default)
771780
# Reject SQL NULL for non-Optional Args. Without this,

0 commit comments

Comments
 (0)