|
6 | 6 |
|
7 | 7 | from tests.conftest import make_invocation |
8 | 8 | from vgi.arguments import Arguments |
| 9 | +from vgi.client import Client |
9 | 10 | from vgi.examples.table import SequenceFunction |
10 | 11 | from vgi.testing import assert_table_function_output, batch |
11 | 12 |
|
@@ -79,3 +80,58 @@ def test_large_sequence_batches( |
79 | 80 | table = pa.Table.from_batches(outputs) |
80 | 81 | assert table.num_rows == 2500 |
81 | 82 | assert table.column("n").to_pylist() == list(range(2500)) |
| 83 | + |
| 84 | + def test_custom_batch_size(self, run_table_function_mode: RunnerWithMode) -> None: |
| 85 | + """Custom batch size should control output batch sizes.""" |
| 86 | + runner, mode = run_table_function_mode |
| 87 | + # Generate 250 values with batch size of 100 |
| 88 | + outputs, logs = runner(SequenceFunction, (250, 100)) |
| 89 | + |
| 90 | + # Should produce 3 batches: 100, 100, 50 |
| 91 | + assert len(outputs) == 3 |
| 92 | + assert outputs[0].num_rows == 100 |
| 93 | + assert outputs[1].num_rows == 100 |
| 94 | + assert outputs[2].num_rows == 50 |
| 95 | + |
| 96 | + table = pa.Table.from_batches(outputs) |
| 97 | + assert table.column("n").to_pylist() == list(range(250)) |
| 98 | + |
| 99 | + def test_batch_size_larger_than_count( |
| 100 | + self, run_table_function_mode: RunnerWithMode |
| 101 | + ) -> None: |
| 102 | + """Batch size larger than count should produce single batch.""" |
| 103 | + runner, mode = run_table_function_mode |
| 104 | + outputs, logs = runner(SequenceFunction, (50, 1000)) |
| 105 | + |
| 106 | + assert len(outputs) == 1 |
| 107 | + assert outputs[0].num_rows == 50 |
| 108 | + assert outputs[0].column("n").to_pylist() == list(range(50)) |
| 109 | + |
| 110 | + |
| 111 | +class TestSequenceFunctionClient: |
| 112 | + """Tests for SequenceFunction via Client (wire protocol).""" |
| 113 | + |
| 114 | + def test_cardinality_returned_in_bind_result(self) -> None: |
| 115 | + """Cardinality should be returned in bind_result via Client.""" |
| 116 | + bind_results: list[pa.RecordBatch] = [] |
| 117 | + |
| 118 | + def capture_bind_result(result: pa.RecordBatch) -> None: |
| 119 | + bind_results.append(result) |
| 120 | + |
| 121 | + with Client("vgi-example-worker") as client: |
| 122 | + list( |
| 123 | + client.table_function( |
| 124 | + function_name="sequence", |
| 125 | + arguments=Arguments(positional=(pa.scalar(100),)), |
| 126 | + bind_result_callback=capture_bind_result, |
| 127 | + ) |
| 128 | + ) |
| 129 | + |
| 130 | + assert len(bind_results) == 1 |
| 131 | + bind_result = bind_results[0] |
| 132 | + |
| 133 | + # Verify cardinality fields are present and correct |
| 134 | + assert "cardinality_estimated" in bind_result.schema.names |
| 135 | + assert "cardinality_max" in bind_result.schema.names |
| 136 | + assert bind_result.column("cardinality_estimated")[0].as_py() == 100 |
| 137 | + assert bind_result.column("cardinality_max")[0].as_py() == 100 |
0 commit comments