diff --git a/python/zarrista/__init__.py b/python/zarrista/__init__.py index 00d0d8c..a0509db 100644 --- a/python/zarrista/__init__.py +++ b/python/zarrista/__init__.py @@ -1,6 +1,6 @@ """A low-level Zarr API for Python, binding to Rust's Zarrs.""" -from typing import TypeAlias +from typing import Literal, TypeAlias from . import codec, exceptions from ._zarrista import ( @@ -29,6 +29,27 @@ type before using layout-specific methods. """ +DataTypeName: TypeAlias = Literal[ + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "complex64", + "complex128", + "string", + "bytes", +] +"""The Zarr v3 names of the built-in fixed data types. +""" + __all__ = [ "Array", @@ -37,6 +58,7 @@ "AsyncGroup", "ChunkGrid", "DataType", + "DataTypeName", "DecodedArray", "FilesystemStore", "FillValue", diff --git a/python/zarrista/_dtype.pyi b/python/zarrista/_dtype.pyi index 775762d..c7eee92 100644 --- a/python/zarrista/_dtype.pyi +++ b/python/zarrista/_dtype.pyi @@ -1,10 +1,35 @@ -from typing import Any +from typing import Any, Literal, TypeAlias + +DataTypeName: TypeAlias = Literal[ + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "complex64", + "complex128", + "string", + "bytes", +] +"""The Zarr v3 names of the built-in fixed data types. +""" class DataType: """A Zarr v3 data type.""" - def __init__(self, metadata: dict[str, Any]) -> None: + @staticmethod + def from_metadata(metadata: dict[str, Any]) -> DataType: """Construct a data type from its Zarr v3 metadata.""" + @staticmethod + def from_string(name: DataTypeName | str) -> DataType: + """Construct a data type from its Zarr v3 name (e.g. `"float32"`).""" @property def name(self) -> str | None: """The Zarr v3 data-type name (e.g. `"float64"`).""" diff --git a/src/dtype.rs b/src/dtype.rs index f75d4ab..1caa871 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -3,19 +3,14 @@ use std::borrow::Cow; -use crate::error::ZarristaError; +use crate::error::ZarristaResult; use crate::metadata::PyMetadataV3; -use numpy::prelude::*; -use numpy::IntoPyArray; -use pyo3::exceptions::PyNotImplementedError; use pyo3::prelude::*; -use pyo3::IntoPyObjectExt; -use zarrs::array::{Array, ArraySubset}; -use zarrs::array::{ArrayError, ElementOwned}; use zarrs::array::{DataType, DataTypeSize}; -use zarrs::storage::ReadableListableStorageTraits; +use zarrs::metadata::v3::MetadataV3; -#[pyclass(module = "zarrista", frozen, name = "DataType")] +#[derive(Debug, Clone)] +#[pyclass(module = "zarrista", frozen, name = "DataType", skip_from_py_object)] pub struct PyDataType { inner: DataType, } @@ -28,10 +23,19 @@ impl PyDataType { #[pymethods] impl PyDataType { - #[new] - fn py_new(metadata: PyMetadataV3) -> Self { - let data_type = DataType::from_metadata(&metadata.into_inner()).unwrap(); - PyDataType { inner: data_type } + /// Construct a data type from its Zarr v3 metadata. + #[staticmethod] + fn from_metadata(metadata: PyMetadataV3) -> ZarristaResult { + let data_type = DataType::from_metadata(&metadata.into_inner())?; + Ok(Self { inner: data_type }) + } + + /// Construct a data type from its Zarr v3 name (e.g. `"float32"`). + #[staticmethod] + fn from_string(name: &str) -> ZarristaResult { + let metadata = MetadataV3::new(name); + let data_type = DataType::from_metadata(&metadata)?; + Ok(Self { inner: data_type }) } #[getter] @@ -71,122 +75,3 @@ impl From for DataType { py_data_type.inner } } - -/// The store trait object backing every zarrista array/group. -pub(crate) type DynStorage = dyn ReadableListableStorageTraits; - -// The following helpers back the array-read path (numpy region/chunk reads), -// which is still commented out in `array/sync.rs`. Allow dead code until those -// methods are enabled. - -/// A region of an array to read: either an explicit subset or a whole chunk. -#[allow(dead_code)] -pub(crate) enum Region<'a> { - Subset(&'a ArraySubset), - Chunk(&'a [u64]), -} - -#[allow(dead_code)] -fn retrieve_vec( - array: &Array, - region: &Region<'_>, -) -> Result, ArrayError> { - match region { - Region::Subset(subset) => array.retrieve_array_subset(*subset), - Region::Chunk(indices) => array.retrieve_chunk(indices), - } -} - -#[allow(dead_code)] -fn vec_to_numpy( - py: Python<'_>, - data: Vec, - shape: &[usize], -) -> PyResult> { - let array = data.into_pyarray(py); - let reshaped = array.reshape(shape.to_vec())?; - Ok(reshaped.into_any().unbind()) -} - -/// Read a region of `array` into a C-order numpy array of the given output -/// shape. Only fixed-length numeric and boolean dtypes are supported so far. -#[allow(dead_code)] -pub(crate) fn read_region( - py: Python<'_>, - array: &Array, - region: &Region<'_>, - out_shape: &[usize], -) -> PyResult> { - let name = array.data_type().name_v3(); - - macro_rules! arm { - ($t:ty) => {{ - let data: Vec<$t> = retrieve_vec(array, region).map_err(ZarristaError::from)?; - vec_to_numpy(py, data, out_shape) - }}; - } - - match name.as_deref() { - Some("bool") => arm!(bool), - Some("int8") => arm!(i8), - Some("int16") => arm!(i16), - Some("int32") => arm!(i32), - Some("int64") => arm!(i64), - Some("uint8") => arm!(u8), - Some("uint16") => arm!(u16), - Some("uint32") => arm!(u32), - Some("uint64") => arm!(u64), - Some("float16") => arm!(half::f16), - Some("float32") => arm!(f32), - Some("float64") => arm!(f64), - other => Err(PyNotImplementedError::new_err(format!( - "reading dtype {:?} is not supported yet", - other.unwrap_or("") - ))), - } -} - -/// Convert a fill value (native-endian bytes) into a Python scalar, returning -/// `None` for dtypes we do not yet interpret. -#[allow(dead_code)] -pub(crate) fn fill_value_to_py( - py: Python<'_>, - data_type: &DataType, - bytes: &[u8], -) -> PyResult> { - macro_rules! scalar { - ($t:ty) => {{ - const N: usize = std::mem::size_of::<$t>(); - match <[u8; N]>::try_from(bytes) { - Ok(arr) => <$t>::from_ne_bytes(arr).into_bound_py_any(py)?.unbind(), - Err(_) => py.None(), - } - }}; - } - - let dtype_name = data_type.name_v3(); - let value = match dtype_name.as_deref() { - Some("bool") => (!bytes.is_empty() && bytes[0] != 0) - .into_bound_py_any(py)? - .unbind(), - Some("int8") => scalar!(i8), - Some("int16") => scalar!(i16), - Some("int32") => scalar!(i32), - Some("int64") => scalar!(i64), - Some("uint8") => scalar!(u8), - Some("uint16") => scalar!(u16), - Some("uint32") => scalar!(u32), - Some("uint64") => scalar!(u64), - Some("float16") => match <[u8; 2]>::try_from(bytes) { - Ok(arr) => half::f16::from_ne_bytes(arr) - .to_f32() - .into_bound_py_any(py)? - .unbind(), - Err(_) => py.None(), - }, - Some("float32") => scalar!(f32), - Some("float64") => scalar!(f64), - _ => py.None(), - }; - Ok(value) -} diff --git a/tests/test_dtype.py b/tests/test_dtype.py new file mode 100644 index 0000000..94d2aa9 --- /dev/null +++ b/tests/test_dtype.py @@ -0,0 +1,47 @@ +import pytest +from zarrista import DataType + + +def test_from_metadata(): + dtype = DataType.from_metadata({"name": "float32"}) + assert dtype.name == "float32" + assert dtype.size == 4 + + +def test_from_string(): + dtype = DataType.from_string("float32") + assert dtype.name == "float32" + assert dtype.size == 4 + + +def test_from_string_matches_metadata_construction(): + from_string = DataType.from_string("float32") + from_metadata = DataType.from_metadata({"name": "float32"}) + assert from_string == from_metadata + + +def test_from_string_variable_length_has_no_size(): + assert DataType.from_string("string").size is None + + +def test_from_string_rejects_unknown_name(): + with pytest.raises(Exception): # noqa: B017, PT011 + DataType.from_string("not_a_real_dtype") + + +def test_eq_same_dtype(): + assert DataType.from_string("float32") == DataType.from_string("float32") + + +def test_eq_different_dtype(): + assert DataType.from_string("float32") != DataType.from_string("int8") + + +def test_eq_non_dtype_is_false(): + # __eq__ is strict: a string is not equal to a DataType. Conversion is + # explicit via `from_string`, never implicit. + assert DataType.from_string("float32") != "float32" + + +def test_repr(): + assert repr(DataType.from_string("float32")) == "DataType(float32 /