1+ import jax
2+ import jax .numpy as jnp
13import numpy as np
24import pytest
35from scipy .special import expit
46
5- from genmetaballs .core import CPUFloatArray2D , sigmoid
7+ from genmetaballs .core import array2d_float , sigmoid
68
79NUM_RNG_SEEDS_PER_TEST = 5
810NUM_N_VALUES_PER_TEST = 5
@@ -60,11 +62,32 @@ def test_sigmoid_edge_cases(x: float) -> None:
6062 assert actual <= 1.0
6163
6264
63- def test_float_array2d_creation_and_view ():
65+ @pytest .mark .parametrize ("device" , ["cpu" , "gpu" ])
66+ def test_array2d_float_creation_on_jax_devices (device : str ):
6467 """Test creation of Array2D from a numpy array."""
6568 rows , cols = 4 , 5
69+ data = jnp .arange (rows * cols , dtype = jnp .float32 ).reshape ((rows , cols ))
70+ jax_device = jax .devices (device )[0 ]
71+ data = jax .device_put (data , device = jax_device )
72+ array_2d = array2d_float (data , device = device )
73+
74+ assert array_2d .num_rows == rows
75+ assert array_2d .num_cols == cols
76+ assert array_2d .ndim == 2
77+
78+ # then try converting back to numpy array via view
79+ data_view = array_2d .as_jax ()
80+ assert data_view .device == data .device
81+ assert jnp .allclose (data , data_view )
82+
83+ # Note: we can't test writability of shared view here since JAX arrays are immutable
84+
85+
86+ def test_float_array2d_view_numpy ():
87+ """Test creation of Array2D from a numpy array."""
88+ rows , cols = 3 , 4
6689 data = np .arange (rows * cols , dtype = np .float32 ).reshape ((rows , cols ))
67- array_2d = CPUFloatArray2D . from_array (data )
90+ array_2d = array2d_float (data , device = "cpu" )
6891
6992 assert array_2d .num_rows == rows
7093 assert array_2d .num_cols == cols
@@ -84,4 +107,4 @@ def test_create_invalid_array2d():
84107 data = np .arange (12 , dtype = np .float32 ).reshape ((3 , 4 ))
85108
86109 with pytest .raises (TypeError ):
87- CPUFloatArray2D . from_array (data .reshape ((3 , 4 , 1 ))) # not 2D
110+ array2d_float (data .reshape ((3 , 4 , 1 )), device = "cpu" ) # not 2D
0 commit comments