diff --git a/src/pyrecest/_backend/_shared_numpy/random.py b/src/pyrecest/_backend/_shared_numpy/random.py index 5f24adcbe..fc640e028 100644 --- a/src/pyrecest/_backend/_shared_numpy/random.py +++ b/src/pyrecest/_backend/_shared_numpy/random.py @@ -5,8 +5,23 @@ _allow_complex_dtype = _common._allow_complex_dtype +def _rand(*dims, size=None): + """Draw uniform samples while accepting the backend ``size=`` contract. + + ``numpy.random.rand`` only accepts legacy positional dimensions, whereas the + PyRecEst random backend exposes ``rand(size=...)`` like the JAX and PyTorch + implementations. Use ``numpy.random.random`` internally so keyword and + tuple sizes work without dropping support for NumPy's positional form. + """ + if dims: + if size is not None: + raise TypeError("Specify either positional dimensions or size, not both.") + size = dims[0] if len(dims) == 1 else dims + return _np.random.random(size) + + rand = _modify_func_default_dtype( - copy=False, kw_only=True, target=_allow_complex_dtype(target=_np.random.rand) + copy=False, kw_only=True, target=_allow_complex_dtype(target=_rand) ) diff --git a/tests/backend/test_numpy_random_backend.py b/tests/backend/test_numpy_random_backend.py new file mode 100644 index 000000000..cea9a1903 --- /dev/null +++ b/tests/backend/test_numpy_random_backend.py @@ -0,0 +1,25 @@ +import numpy as np +import pytest + +from pyrecest._backend.numpy import random + + +def test_rand_accepts_backend_size_keyword(): + random.seed(0) + + samples = random.rand(size=(2, 3)) + + assert samples.shape == (2, 3) + assert samples.dtype == np.float64 + + +def test_rand_keeps_numpy_positional_dimensions(): + random.seed(0) + + assert random.rand(2, 3).shape == (2, 3) + assert random.rand(4).shape == (4,) + + +def test_rand_rejects_ambiguous_positional_and_size_arguments(): + with pytest.raises(TypeError, match="positional dimensions or size"): + random.rand(2, size=(3,))