Skip to content

Commit 7951422

Browse files
rustyconoverclaude
andcommitted
Refactor tests to use schema() helper from vgi.schema_utils
Replace verbose pa.schema([...]) patterns with cleaner schema() helper: - tests/conftest.py: fixtures use schema() - tests/scalar/test_client.py: all test schemas - tests/scalar/test_function.py: all test schemas - tests/table_in_out/test_function.py: all test schemas - tests/test_type_bounds.py: all test schemas 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 9cef5ba commit 7951422

5 files changed

Lines changed: 97 additions & 121 deletions

File tree

tests/conftest.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77
import structlog
88

9+
from vgi import schema
910
from vgi.arguments import Arguments
1011
from vgi.invocation import Invocation, InvocationType
1112

@@ -124,37 +125,28 @@ def example_worker() -> str:
124125
@pytest.fixture
125126
def simple_batches() -> list[pa.RecordBatch]:
126127
"""Create simple test batches with integer and string columns."""
127-
fields: list[pa.Field[Any]] = [
128-
pa.field("id", pa.int64()),
129-
pa.field("value", pa.int64()),
130-
pa.field("name", pa.string()),
131-
]
132-
schema = pa.schema(fields)
128+
s = schema(id=pa.int64(), value=pa.int64(), name=pa.string())
133129
batch1 = pa.RecordBatch.from_pydict(
134130
{"id": [1, 2], "value": [10, 20], "name": ["a", "b"]},
135-
schema=schema,
131+
schema=s,
136132
)
137133
batch2 = pa.RecordBatch.from_pydict(
138134
{"id": [3, 4], "value": [30, 40], "name": ["c", "d"]},
139-
schema=schema,
135+
schema=s,
140136
)
141137
return [batch1, batch2]
142138

143139

144140
@pytest.fixture
145141
def numeric_batches() -> list[pa.RecordBatch]:
146142
"""Create test batches with only numeric columns for sum tests."""
147-
fields: list[pa.Field[Any]] = [
148-
pa.field("a", pa.int32()),
149-
pa.field("b", pa.float64()),
150-
]
151-
schema = pa.schema(fields)
143+
s = schema(a=pa.int32(), b=pa.float64())
152144
batch1 = pa.RecordBatch.from_pydict(
153145
{"a": [1, 2, 3], "b": [1.5, 2.5, 3.0]},
154-
schema=schema,
146+
schema=s,
155147
)
156148
batch2 = pa.RecordBatch.from_pydict(
157149
{"a": [4, 5], "b": [4.0, 5.0]},
158-
schema=schema,
150+
schema=s,
159151
)
160152
return [batch1, batch2]

tests/scalar/test_client.py

Lines changed: 50 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
from tests.conftest import assert_total_rows
11+
from vgi import schema
1112
from vgi.arguments import Arguments
1213
from vgi.client import Client
1314
from vgi.client.client import ClientError
@@ -18,8 +19,8 @@ class TestScalarFunctionClient:
1819

1920
def test_double_column_basic(self, example_worker: str) -> None:
2021
"""Test basic scalar function via Client."""
21-
schema = pa.schema([("x", pa.int64())])
22-
batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=schema)
22+
s = schema(x=pa.int64())
23+
batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=s)
2324

2425
with Client(example_worker) as client:
2526
outputs = list(
@@ -35,9 +36,9 @@ def test_double_column_basic(self, example_worker: str) -> None:
3536

3637
def test_add_columns(self, example_worker: str) -> None:
3738
"""Test add_columns scalar function."""
38-
schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
39+
s = schema(a=pa.int64(), b=pa.int64())
3940
batch = pa.RecordBatch.from_pydict(
40-
{"a": [1, 2, 3], "b": [10, 20, 30]}, schema=schema
41+
{"a": [1, 2, 3], "b": [10, 20, 30]}, schema=s
4142
)
4243

4344
with Client(example_worker) as client:
@@ -54,9 +55,9 @@ def test_add_columns(self, example_worker: str) -> None:
5455

5556
def test_upper_case(self, example_worker: str) -> None:
5657
"""Test upper_case scalar function."""
57-
schema = pa.schema([("name", pa.string())])
58+
s = schema(name=pa.string())
5859
batch = pa.RecordBatch.from_pydict(
59-
{"name": ["alice", "bob", "charlie"]}, schema=schema
60+
{"name": ["alice", "bob", "charlie"]}, schema=s
6061
)
6162

6263
with Client(example_worker) as client:
@@ -73,10 +74,10 @@ def test_upper_case(self, example_worker: str) -> None:
7374

7475
def test_multiple_batches(self, example_worker: str) -> None:
7576
"""Test scalar function with multiple input batches."""
76-
schema = pa.schema([("x", pa.int64())])
77-
batch1 = pa.RecordBatch.from_pydict({"x": [1, 2]}, schema=schema)
78-
batch2 = pa.RecordBatch.from_pydict({"x": [3, 4, 5]}, schema=schema)
79-
batch3 = pa.RecordBatch.from_pydict({"x": [6]}, schema=schema)
77+
s = schema(x=pa.int64())
78+
batch1 = pa.RecordBatch.from_pydict({"x": [1, 2]}, schema=s)
79+
batch2 = pa.RecordBatch.from_pydict({"x": [3, 4, 5]}, schema=s)
80+
batch3 = pa.RecordBatch.from_pydict({"x": [6]}, schema=s)
8081

8182
with Client(example_worker) as client:
8283
outputs = list(
@@ -99,8 +100,8 @@ def test_multiple_batches(self, example_worker: str) -> None:
99100

100101
def test_empty_batch(self, example_worker: str) -> None:
101102
"""Test scalar function with empty batch."""
102-
schema = pa.schema([("x", pa.int64())])
103-
empty_batch = pa.RecordBatch.from_pydict({"x": []}, schema=schema)
103+
s = schema(x=pa.int64())
104+
empty_batch = pa.RecordBatch.from_pydict({"x": []}, schema=s)
104105

105106
with Client(example_worker) as client:
106107
outputs = list(
@@ -132,8 +133,8 @@ def test_empty_iterator(self, example_worker: str) -> None:
132133
def test_scalar_function_not_started_raises(self, example_worker: str) -> None:
133134
"""Calling scalar_function before start should raise ClientError."""
134135
client = Client(example_worker)
135-
schema = pa.schema([("x", pa.int64())])
136-
batch = pa.RecordBatch.from_pydict({"x": [1]}, schema=schema)
136+
s = schema(x=pa.int64())
137+
batch = pa.RecordBatch.from_pydict({"x": [1]}, schema=s)
137138

138139
with pytest.raises(ClientError, match="not started"):
139140
list(
@@ -146,9 +147,9 @@ def test_scalar_function_not_started_raises(self, example_worker: str) -> None:
146147

147148
def test_large_batch(self, example_worker: str) -> None:
148149
"""Test scalar function with a large batch."""
149-
schema = pa.schema([("x", pa.int64())])
150+
s = schema(x=pa.int64())
150151
large_data = list(range(10000))
151-
batch = pa.RecordBatch.from_pydict({"x": large_data}, schema=schema)
152+
batch = pa.RecordBatch.from_pydict({"x": large_data}, schema=s)
152153

153154
with Client(example_worker) as client:
154155
outputs = list(
@@ -170,8 +171,8 @@ def test_large_batch(self, example_worker: str) -> None:
170171

171172
def test_bind_result_callback(self, example_worker: str) -> None:
172173
"""Test that bind_result_callback is invoked."""
173-
schema = pa.schema([("x", pa.int64())])
174-
batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=schema)
174+
s = schema(x=pa.int64())
175+
batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=s)
175176

176177
bind_results: list[pa.RecordBatch] = []
177178

@@ -198,10 +199,8 @@ def capture_bind_result(result: pa.RecordBatch) -> None:
198199

199200
def test_add_columns_accepts_float_columns(self, example_worker: str) -> None:
200201
"""Test that add_columns accepts float columns."""
201-
schema = pa.schema([("a", pa.float64()), ("b", pa.float64())])
202-
batch = pa.RecordBatch.from_pydict(
203-
{"a": [1.5, 2.5], "b": [0.5, 0.5]}, schema=schema
204-
)
202+
s = schema(a=pa.float64(), b=pa.float64())
203+
batch = pa.RecordBatch.from_pydict({"a": [1.5, 2.5], "b": [0.5, 0.5]}, schema=s)
205204

206205
with Client(example_worker) as client:
207206
outputs = list(
@@ -217,8 +216,8 @@ def test_add_columns_accepts_float_columns(self, example_worker: str) -> None:
217216

218217
def test_add_columns_accepts_mixed_int_types(self, example_worker: str) -> None:
219218
"""Test that add_columns accepts mixed integer types and promotes correctly."""
220-
schema = pa.schema([("a", pa.int32()), ("b", pa.int64())])
221-
batch = pa.RecordBatch.from_pydict({"a": [1, 2], "b": [10, 20]}, schema=schema)
219+
s = schema(a=pa.int32(), b=pa.int64())
220+
batch = pa.RecordBatch.from_pydict({"a": [1, 2], "b": [10, 20]}, schema=s)
222221

223222
with Client(example_worker) as client:
224223
outputs = list(
@@ -240,9 +239,9 @@ class TestSumColumns:
240239

241240
def test_sum_two_columns(self, example_worker: str) -> None:
242241
"""Sum of two columns."""
243-
schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
242+
s = schema(a=pa.int64(), b=pa.int64())
244243
batch = pa.RecordBatch.from_pydict(
245-
{"a": [1, 2, 3], "b": [10, 20, 30]}, schema=schema
244+
{"a": [1, 2, 3], "b": [10, 20, 30]}, schema=s
246245
)
247246

248247
with Client(example_worker) as client:
@@ -259,9 +258,9 @@ def test_sum_two_columns(self, example_worker: str) -> None:
259258

260259
def test_sum_three_columns(self, example_worker: str) -> None:
261260
"""Sum of three columns using varargs."""
262-
schema = pa.schema([("a", pa.int64()), ("b", pa.int64()), ("c", pa.int64())])
261+
s = schema(a=pa.int64(), b=pa.int64(), c=pa.int64())
263262
batch = pa.RecordBatch.from_pydict(
264-
{"a": [1, 2], "b": [10, 20], "c": [100, 200]}, schema=schema
263+
{"a": [1, 2], "b": [10, 20], "c": [100, 200]}, schema=s
265264
)
266265

267266
with Client(example_worker) as client:
@@ -280,8 +279,8 @@ def test_sum_three_columns(self, example_worker: str) -> None:
280279

281280
def test_sum_with_type_promotion(self, example_worker: str) -> None:
282281
"""Different int types promote correctly."""
283-
schema = pa.schema([("a", pa.int32()), ("b", pa.int64())])
284-
batch = pa.RecordBatch.from_pydict({"a": [1, 2], "b": [10, 20]}, schema=schema)
282+
s = schema(a=pa.int32(), b=pa.int64())
283+
batch = pa.RecordBatch.from_pydict({"a": [1, 2], "b": [10, 20]}, schema=s)
285284

286285
with Client(example_worker) as client:
287286
outputs = list(
@@ -299,10 +298,8 @@ def test_sum_with_type_promotion(self, example_worker: str) -> None:
299298

300299
def test_sum_rejects_string_column(self, example_worker: str) -> None:
301300
"""Type bound rejects non-numeric columns."""
302-
schema = pa.schema([("a", pa.int64()), ("b", pa.string())]) # type: ignore[arg-type]
303-
batch = pa.RecordBatch.from_pydict(
304-
{"a": [1, 2], "b": ["x", "y"]}, schema=schema
305-
)
301+
s = schema(a=pa.int64(), b=pa.string())
302+
batch = pa.RecordBatch.from_pydict({"a": [1, 2], "b": ["x", "y"]}, schema=s)
306303

307304
with (
308305
Client(example_worker) as client,
@@ -318,9 +315,9 @@ def test_sum_rejects_string_column(self, example_worker: str) -> None:
318315

319316
def test_sum_multiple_batches(self, example_worker: str) -> None:
320317
"""Multiple input batches processed correctly."""
321-
schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
322-
batch1 = pa.RecordBatch.from_pydict({"a": [1, 2], "b": [10, 20]}, schema=schema)
323-
batch2 = pa.RecordBatch.from_pydict({"a": [3, 4], "b": [30, 40]}, schema=schema)
318+
s = schema(a=pa.int64(), b=pa.int64())
319+
batch1 = pa.RecordBatch.from_pydict({"a": [1, 2], "b": [10, 20]}, schema=s)
320+
batch2 = pa.RecordBatch.from_pydict({"a": [3, 4], "b": [30, 40]}, schema=s)
324321

325322
with Client(example_worker) as client:
326323
outputs = list(
@@ -339,8 +336,8 @@ def test_sum_multiple_batches(self, example_worker: str) -> None:
339336

340337
def test_sum_empty_batch(self, example_worker: str) -> None:
341338
"""Empty batch returns empty output."""
342-
schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
343-
empty_batch = pa.RecordBatch.from_pydict({"a": [], "b": []}, schema=schema)
339+
s = schema(a=pa.int64(), b=pa.int64())
340+
empty_batch = pa.RecordBatch.from_pydict({"a": [], "b": []}, schema=s)
344341

345342
with Client(example_worker) as client:
346343
outputs = list(
@@ -356,10 +353,8 @@ def test_sum_empty_batch(self, example_worker: str) -> None:
356353

357354
def test_sum_float_columns(self, example_worker: str) -> None:
358355
"""Sum of float columns."""
359-
schema = pa.schema([("a", pa.float64()), ("b", pa.float64())])
360-
batch = pa.RecordBatch.from_pydict(
361-
{"a": [1.5, 2.5], "b": [0.5, 0.5]}, schema=schema
362-
)
356+
s = schema(a=pa.float64(), b=pa.float64())
357+
batch = pa.RecordBatch.from_pydict({"a": [1.5, 2.5], "b": [0.5, 0.5]}, schema=s)
363358

364359
with Client(example_worker) as client:
365360
outputs = list(
@@ -379,10 +374,10 @@ class TestScalarFunctionParallel:
379374

380375
def test_parallel_double_column(self, example_worker: str) -> None:
381376
"""Test scalar function with multiple workers."""
382-
schema = pa.schema([("x", pa.int64())])
377+
s = schema(x=pa.int64())
383378
batches = [
384379
pa.RecordBatch.from_pydict(
385-
{"x": list(range(i * 100, (i + 1) * 100))}, schema=schema
380+
{"x": list(range(i * 100, (i + 1) * 100))}, schema=s
386381
)
387382
for i in range(10)
388383
]
@@ -409,10 +404,10 @@ def test_parallel_double_column(self, example_worker: str) -> None:
409404

410405
def test_parallel_add_columns(self, example_worker: str) -> None:
411406
"""Test add_columns with multiple workers."""
412-
schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
407+
s = schema(a=pa.int64(), b=pa.int64())
413408
batches = [
414409
pa.RecordBatch.from_pydict(
415-
{"a": [i, i + 1, i + 2], "b": [100, 200, 300]}, schema=schema
410+
{"a": [i, i + 1, i + 2], "b": [100, 200, 300]}, schema=s
416411
)
417412
for i in range(20)
418413
]
@@ -431,13 +426,13 @@ def test_parallel_add_columns(self, example_worker: str) -> None:
431426

432427
def test_parallel_empty_batches_mixed(self, example_worker: str) -> None:
433428
"""Test parallel processing with mix of empty and non-empty batches."""
434-
schema = pa.schema([("x", pa.int64())])
429+
s = schema(x=pa.int64())
435430
batches = [
436-
pa.RecordBatch.from_pydict({"x": [1, 2]}, schema=schema),
437-
pa.RecordBatch.from_pydict({"x": []}, schema=schema), # Empty
438-
pa.RecordBatch.from_pydict({"x": [3]}, schema=schema),
439-
pa.RecordBatch.from_pydict({"x": []}, schema=schema), # Empty
440-
pa.RecordBatch.from_pydict({"x": [4, 5, 6]}, schema=schema),
431+
pa.RecordBatch.from_pydict({"x": [1, 2]}, schema=s),
432+
pa.RecordBatch.from_pydict({"x": []}, schema=s), # Empty
433+
pa.RecordBatch.from_pydict({"x": [3]}, schema=s),
434+
pa.RecordBatch.from_pydict({"x": []}, schema=s), # Empty
435+
pa.RecordBatch.from_pydict({"x": [4, 5, 6]}, schema=s),
441436
]
442437

443438
with Client(example_worker, max_workers=2) as client:
@@ -460,8 +455,8 @@ def test_parallel_empty_batches_mixed(self, example_worker: str) -> None:
460455

461456
def test_parallel_single_batch(self, example_worker: str) -> None:
462457
"""Test parallel mode with just one batch (should still work)."""
463-
schema = pa.schema([("x", pa.int64())])
464-
batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=schema)
458+
s = schema(x=pa.int64())
459+
batch = pa.RecordBatch.from_pydict({"x": [1, 2, 3]}, schema=s)
465460

466461
with Client(example_worker, max_workers=4) as client:
467462
outputs = list(

0 commit comments

Comments
 (0)