Skip to content

Commit 93536ee

Browse files
committed
Add tests for creating Array2D views on CPU/GPU JAX Arrays
1 parent f6d38f1 commit 93536ee

3 files changed

Lines changed: 46 additions & 6 deletions

File tree

genmetaballs/src/cuda/bindings.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ NB_MODULE(_genmetaballs_bindings, m) {
8484
utils.def("sigmoid", sigmoid, nb::arg("x"), "Compute the sigmoid function: 1 / (1 + exp(-x))");
8585

8686
bind_array2d<float, DeviceType::CPU>(utils, "CPUFloatArray2D");
87+
bind_array2d<float, DeviceType::GPU>(utils, "GPUFloatArray2D");
8788

8889
} // NB_MODULE(_genmetaballs_bindings)
8990

genmetaballs/src/genmetaballs/core/__init__.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,26 @@
33
TwoParameterConfidence,
44
ZeroParameterConfidence,
55
)
6-
from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, sigmoid
6+
from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid
7+
8+
9+
def array2d_float(data, device) -> CPUFloatArray2D | GPUFloatArray2D:
10+
"""Create a FloatArray2D on the specified device from an array.
11+
12+
Args:
13+
data: A 2D array of type float32.
14+
device: 'cpu' or 'gpu' to specify the target device.
15+
"""
16+
if device == "cpu":
17+
return CPUFloatArray2D.from_array(data)
18+
elif device == "gpu":
19+
return GPUFloatArray2D.from_array(data)
20+
else:
21+
raise ValueError(f"Unsupported device type: {device}")
22+
723

824
__all__ = [
9-
"CPUFloatArray2D",
25+
"array2d_float",
1026
"ZeroParameterConfidence",
1127
"TwoParameterConfidence",
1228
"geometry",

tests/python_tests/test_utils.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
import jax
2+
import jax.numpy as jnp
13
import numpy as np
24
import pytest
35
from scipy.special import expit
46

5-
from genmetaballs.core import CPUFloatArray2D, sigmoid
7+
from genmetaballs.core import array2d_float, sigmoid
68

79
NUM_RNG_SEEDS_PER_TEST = 5
810
NUM_N_VALUES_PER_TEST = 5
@@ -60,11 +62,32 @@ def test_sigmoid_edge_cases(x: float) -> None:
6062
assert actual <= 1.0
6163

6264

63-
def test_float_array2d_creation_and_view():
65+
@pytest.mark.parametrize("device", ["cpu", "gpu"])
66+
def test_array2d_float_creation_on_jax_devices(device: str):
6467
"""Test creation of Array2D from a numpy array."""
6568
rows, cols = 4, 5
69+
data = jnp.arange(rows * cols, dtype=jnp.float32).reshape((rows, cols))
70+
jax_device = jax.devices(device)[0]
71+
data = jax.device_put(data, device=jax_device)
72+
array_2d = array2d_float(data, device=device)
73+
74+
assert array_2d.num_rows == rows
75+
assert array_2d.num_cols == cols
76+
assert array_2d.ndim == 2
77+
78+
# then try converting back to numpy array via view
79+
data_view = array_2d.as_jax()
80+
assert data_view.device == data.device
81+
assert jnp.allclose(data, data_view)
82+
83+
# Note: we can't test writability of shared view here since JAX arrays are immutable
84+
85+
86+
def test_float_array2d_view_numpy():
87+
"""Test creation of Array2D from a numpy array."""
88+
rows, cols = 3, 4
6689
data = np.arange(rows * cols, dtype=np.float32).reshape((rows, cols))
67-
array_2d = CPUFloatArray2D.from_array(data)
90+
array_2d = array2d_float(data, device="cpu")
6891

6992
assert array_2d.num_rows == rows
7093
assert array_2d.num_cols == cols
@@ -84,4 +107,4 @@ def test_create_invalid_array2d():
84107
data = np.arange(12, dtype=np.float32).reshape((3, 4))
85108

86109
with pytest.raises(TypeError):
87-
CPUFloatArray2D.from_array(data.reshape((3, 4, 1))) # not 2D
110+
array2d_float(data.reshape((3, 4, 1)), device="cpu") # not 2D

0 commit comments

Comments
 (0)