diff --git a/src/pyrecest/_backend/pytorch/linalg.py b/src/pyrecest/_backend/pytorch/linalg.py index d3ed0a052..8a0f753d4 100644 --- a/src/pyrecest/_backend/pytorch/linalg.py +++ b/src/pyrecest/_backend/pytorch/linalg.py @@ -58,11 +58,22 @@ def _torch_as_like(value, like): } +def _default_linalg_dtype(): + dtype = get_default_dtype() + if dtype in (_torch.float32, _torch.float64): + return dtype + if dtype == _np.dtype("float32"): + return _torch.float32 + if dtype == _np.dtype("float64"): + return _torch.float64 + return _torch.float64 + + def _as_linalg_tensor(value): """Convert array-like values to a floating/complex tensor for torch.linalg.""" tensor = array(value) if not is_floating(tensor) and not is_complex(tensor): - tensor = cast(tensor, dtype=get_default_dtype()) + tensor = cast(tensor, dtype=_default_linalg_dtype()) return tensor @@ -73,7 +84,7 @@ def _common_linalg_dtype(*tensors): dtype = _torch.promote_types(dtype, tensor.dtype) if dtype.is_floating_point or dtype.is_complex: return dtype - return get_default_dtype() + return _default_linalg_dtype() class _Logm(_torch.autograd.Function): diff --git a/src/pyrecest/_backend/pytorch/random.py b/src/pyrecest/_backend/pytorch/random.py index 6b84654ca..9eb140b2c 100644 --- a/src/pyrecest/_backend/pytorch/random.py +++ b/src/pyrecest/_backend/pytorch/random.py @@ -1,5 +1,6 @@ """Torch based random backend.""" +from math import prod as _prod from numbers import Integral as _Integral import torch as _torch @@ -17,21 +18,42 @@ } +def _size_type_error(): + return TypeError("size must be None, an integer, or a sequence of integers") + + +def _looks_like_integer_dimension(value): + return isinstance(value, _Integral) and not isinstance(value, bool) + + +def _integer_dimension(value): + if not _looks_like_integer_dimension(value): + raise _size_type_error() + value = int(value) + if value < 0: + raise ValueError("size dimensions must be non-negative") + return value + + +def _shape_from_size(size): + if size is None: + return () + if _looks_like_integer_dimension(size): + return (_integer_dimension(size),) + if isinstance(size, (str, bytes)) or not hasattr(size, "__iter__"): + raise _size_type_error() + return tuple(_integer_dimension(dim) for dim in size) + + def _choice_size(size): if size is None: return None, 1 - if not hasattr(size, "__iter__"): - size = (size,) - size = tuple(int(dim) for dim in size) - return size, int(_torch.prod(_torch.tensor(size)).item()) + size = _shape_from_size(size) + return size, _prod(size) if size else 1 def _randint_size(size): - if size is None: - return () - if not hasattr(size, "__iter__") or isinstance(size, (str, bytes)): - return (size,) - return tuple(size) + return _shape_from_size(size) def randint(low, high=None, size=None, *args, **kwargs): @@ -45,9 +67,7 @@ def randint(low, high=None, size=None, *args, **kwargs): def _normal_size(size): if size is None: return None - if not hasattr(size, "__iter__"): - return (size,) - return tuple(int(dim) for dim in size) + return _shape_from_size(size) def _normal_device(*values): @@ -166,11 +186,7 @@ def seed(*args, **kwargs): def rand(size=None, dtype=None): - if size is None: - size = () - elif not hasattr(size, "__iter__"): - size = (size,) - return _torch.rand(size, dtype=dtype) + return _torch.rand(_shape_from_size(size), dtype=dtype) def multinomial(n, pvals): @@ -198,9 +214,7 @@ def normal(loc=0.0, scale=1.0, size=None): def _uniform_size(size, low, high): if size is not None: - if not hasattr(size, "__iter__") or isinstance(size, (str, bytes)): - return (size,) - return tuple(int(dim) for dim in size) + return _shape_from_size(size) try: return tuple(_torch.broadcast_shapes(low.shape, high.shape)) @@ -242,11 +256,7 @@ def _floating_distribution_dtype(*values): def _normal_sample_size(size): - if size is None: - return () - if not hasattr(size, "__iter__"): - return (size,) - return tuple(size) + return _shape_from_size(size) @_modify_func_default_dtype(copy=False, kw_only=True) diff --git a/tests/backend/test_pytorch_random_backend.py b/tests/backend/test_pytorch_random_backend.py new file mode 100644 index 000000000..f364cce20 --- /dev/null +++ b/tests/backend/test_pytorch_random_backend.py @@ -0,0 +1,55 @@ +import pytest + +pytest.importorskip("torch") + +from pyrecest._backend.pytorch import random # noqa: E402 + + +@pytest.mark.parametrize( + "bad_size", + [True, False, (True,), [False, 2], 1.5, (2.0,), "3"], +) +def test_size_arguments_reject_bool_and_non_integral_dimensions(bad_size): + samplers = ( + lambda size: random.rand(size=size), + lambda size: random.uniform(size=size), + lambda size: random.normal(size=size), + lambda size: random.randint(0, 3, size=size), + lambda size: random.choice(3, size=size), + lambda size: random.multivariate_normal([0.0], [[1.0]], size=size), + ) + + for sampler in samplers: + with pytest.raises(TypeError): + sampler(bad_size) + + +@pytest.mark.parametrize("bad_size", [-1, (2, -1)]) +def test_size_arguments_reject_negative_dimensions(bad_size): + samplers = ( + lambda size: random.rand(size=size), + lambda size: random.uniform(size=size), + lambda size: random.normal(size=size), + lambda size: random.randint(0, 3, size=size), + lambda size: random.choice(3, size=size), + lambda size: random.multivariate_normal([0.0], [[1.0]], size=size), + ) + + for sampler in samplers: + with pytest.raises(ValueError): + sampler(bad_size) + + +def test_scalar_and_empty_tuple_sizes_keep_scalar_shape(): + assert random.rand().shape == () + assert random.rand(size=()).shape == () + assert random.normal(size=()).shape == () + assert random.uniform(size=()).shape == () + assert random.randint(0, 3, size=()).shape == () + assert random.multivariate_normal([0.0], [[1.0]], size=()).shape == (1,) + + +def test_zero_sized_choice_still_works_for_empty_population(): + sample = random.choice(0, size=(0,)) + + assert sample.shape == (0,)