|
14 | 14 | from vgi._test_fixtures.table._common import ( |
15 | 15 | _cardinality_from_count, |
16 | 16 | ) |
17 | | -from vgi.arguments import Arg |
| 17 | +from vgi.arguments import Arg, TaggedUnion |
18 | 18 | from vgi.invocation import BindResponse |
19 | 19 | from vgi.metadata import FunctionExample |
20 | 20 | from vgi.schema_utils import schema |
@@ -447,3 +447,109 @@ def process( |
447 | 447 | data = {f"v{i}": col for i, col in enumerate(state.rows)} |
448 | 448 | out_schema = schema({f"v{i}": pa.string() for i in range(len(state.rows))}) |
449 | 449 | out.emit(pa.RecordBatch.from_pydict(data, schema=out_schema)) |
| 450 | + |
| 451 | + |
| 452 | +# ============================================================================ |
| 453 | + |
# Arrow type used for every union_varargs argument. It is declared as a sparse
# union because sparse (+us:) is the only union layout DuckDB emits over Arrow,
# which lets values round-trip end-to-end.
UNION_VARARGS_TYPE = pa.sparse_union(
    [
        pa.field("i", pa.int64()),
        pa.field("s", pa.string()),
    ]
)

# Fixed output schema of union_varargs: positional index, active member name,
# and the member value rendered as a string.
UNION_VARARGS_SCHEMA = schema(
    idx=pa.int64(),
    tag=pa.string(),
    value=pa.string(),
)
| 459 | + |
| 460 | + |
@dataclass(kw_only=True)
class UnionVarargsArgs:
    """Bound arguments of the ``union_varargs`` table function.

    The function takes a single variadic parameter; each supplied value is
    typed as ``UNION_VARARGS_TYPE`` and surfaces here as a ``TaggedUnion``.
    """

    # Declared at position 0 with varargs=True, so the whole argument list is
    # collected into this tuple.
    configs: Annotated[
        tuple[TaggedUnion, ...],
        Arg(
            0,
            varargs=True,
            arrow_type=UNION_VARARGS_TYPE,
            doc="Union values whose active member tag is echoed back",
        ),
    ]
| 474 | + |
| 475 | + |
@dataclass(kw_only=True)
class UnionVarargsState(ArrowSerializableDataclass):
    """Serializable scan state for ``union_varargs``.

    The three lists are parallel columns (one entry per union argument) that
    are emitted as a single record batch; ``done`` records whether that batch
    has already been produced.
    """

    # Positional index of each argument.
    idx: list[int] = field(default_factory=list)
    # Active member name of each argument; the annotation allows None.
    tags: list[str | None] = field(default_factory=list)
    # Member value rendered with str().
    # NOTE(review): a None member value would appear as the text "None" —
    # confirm that is the intended output.
    values: list[str] = field(default_factory=list)
    # Set to True once the single output batch has been emitted.
    done: bool = False
| 484 | + |
| 485 | + |
@init_single_worker
@bind_fixed_schema
class UnionVarargsFunction(TableFunctionGenerator[UnionVarargsArgs, UnionVarargsState]):
    """Echo the active member tag and value of each union vararg.

    USE CASE
    --------
    Exercises union-typed varargs. Every argument is delivered as a
    [`TaggedUnion`][vgi.arguments.TaggedUnion], which keeps the active member
    discriminator that a plain ``Scalar.as_py()`` would drop. The function
    produces one row per vararg holding its positional index, the name of the
    active member, and the member value stringified into a single fixed column.

    SCHEMA
    ------
    Fixed: ``{"idx": int64, "tag": string, "value": string}``.

    Example:
        SELECT * FROM union_varargs(
            union_value(i := 1)::UNION(i INT, s VARCHAR),
            union_value(s := 'x')::UNION(i INT, s VARCHAR))
        Returns: (0, 'i', '1'), (1, 's', 'x')

    Attributes:
        FIXED_SCHEMA: The fixed Arrow output schema this function always produces.

    """

    FIXED_SCHEMA: ClassVar[pa.Schema] = UNION_VARARGS_SCHEMA

    class Meta:
        """Registration metadata: name, description, categories and examples."""

        name = "union_varargs"
        description = "Echo the active member tag and value of each union vararg"
        categories = ["generator", "utility"]
        examples = [
            FunctionExample(
                sql=(
                    "SELECT * FROM union_varargs("
                    "union_value(i := 1)::UNION(i INT, s VARCHAR), "
                    "union_value(s := 'x')::UNION(i INT, s VARCHAR))"
                ),
                description="Echo the tag and value of two union arguments",
            ),
        ]

    @classmethod
    def initial_state(cls, params: ProcessParams[UnionVarargsArgs]) -> UnionVarargsState:
        """Build the scan state by splitting every union vararg into columns.

        Args:
            params: Process parameters; only ``params.args.configs`` is read.

        Returns:
            A state whose ``idx``, ``tags`` and ``values`` lists hold one entry
            per union argument, in argument order.
        """
        positions: list[int] = []
        member_tags: list[str | None] = []
        rendered_values: list[str] = []
        for position, union_arg in enumerate(params.args.configs):
            positions.append(position)
            member_tags.append(union_arg.tag)
            # Values are stringified so members of different types share one column.
            rendered_values.append(str(union_arg.value))
        return UnionVarargsState(
            idx=positions,
            tags=member_tags,
            values=rendered_values,
        )

    @classmethod
    def process(cls, params: ProcessParams[UnionVarargsArgs], state: UnionVarargsState, out: OutputCollector) -> None:
        """Emit the single result batch, then finish on the following call.

        Args:
            params: Process parameters (not read here).
            state: Scan state holding the prepared columns and the ``done`` flag.
            out: Collector that receives the batch or the finish signal.
        """
        if not state.done:
            # Flip the flag before building the batch so the next call finishes.
            state.done = True
            columns = {"idx": state.idx, "tag": state.tags, "value": state.values}
            out.emit(pa.RecordBatch.from_pydict(columns, schema=UNION_VARARGS_SCHEMA))
            return
        out.finish()
0 commit comments