From a732343b5adce8b52c81eb5cdb07a2e6c4380112 Mon Sep 17 00:00:00 2001 From: d33bs Date: Mon, 6 Apr 2026 15:04:00 -0600 Subject: [PATCH 01/19] Add torch/jax array ingest support to OMEArrow --- README.md | 33 +++++++ docs/src/dlpack.md | 20 ++++ docs/src/python-api.md | 18 ++++ src/ome_arrow/__init__.py | 2 + src/ome_arrow/core.py | 54 ++++++++++- src/ome_arrow/ingest.py | 194 ++++++++++++++++++++++++++++++++++++++ tests/test_core.py | 80 ++++++++++++++++ 7 files changed, 397 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index ac7e1fc..a9ec120 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,39 @@ Advanced options: See full docs: [`docs/src/dlpack.md`](docs/src/dlpack.md) +## Tensor ingest (PyTorch/JAX) + +You can now ingest torch or JAX arrays directly with `OMEArrow(...)`, or use +explicit helper functions from `ome_arrow.ingest`. + +```python +from ome_arrow import OMEArrow + +# Direct constructor support: +# inferred defaults are rank-based: +# 2D -> "YX", 3D -> "CYX", 4D -> "TCYX", 5D -> "TCZYX" +oa_torch = OMEArrow(torch_tensor) +oa_jax = OMEArrow(jax_array) + +# Optional: override dim order when shape is ambiguous +oa_zyx = OMEArrow(torch_volume, dim_order="ZYX") +``` + +```python +from ome_arrow.ingest import from_torch_array, from_jax_array + +scalar_torch = from_torch_array(torch_tensor, dim_order="TCYX") +scalar_jax = from_jax_array(jax_array, dim_order="TCYX") +``` + +Notes: + +- Torch/JAX support is optional; install extras as needed: + `pip install "ome-arrow[dlpack-torch]"` or + `pip install "ome-arrow[dlpack-jax]"`. +- Torch tensors are detached and converted on CPU for ingest. +- `dim_order` is accepted only for NumPy/torch/JAX array inputs. + ## Benchmarking lazy reads Use the lightweight benchmark utility in `benchmarks/` to compare lazy tensor diff --git a/docs/src/dlpack.md b/docs/src/dlpack.md index 0b99cf9..ad6b0f2 100644 --- a/docs/src/dlpack.md +++ b/docs/src/dlpack.md @@ -42,6 +42,17 @@ flat = torch.utils.dlpack.from_dlpack(capsule) tensor = flat.reshape(view.shape) ``` +You can also ingest torch tensors directly: + +```python +from ome_arrow import OMEArrow + +oa = OMEArrow(torch_tensor) # inferred dim_order by rank +oa_zyx = OMEArrow(torch_volume, dim_order="ZYX") # explicit override +``` + +`dim_order` is only supported for array/tensor ingest paths. + ## Lazy scan-style slicing ```python @@ -76,6 +87,15 @@ flat = jnp.from_dlpack(capsule) arr = flat.reshape(view.shape) ``` +You can also ingest JAX arrays directly: + +```python +from ome_arrow import OMEArrow + +oa = OMEArrow(jax_array) # inferred dim_order by rank +oa_zyx = OMEArrow(jax_volume, dim_order="ZYX") # explicit override +``` + ## Iteration examples ```python diff --git a/docs/src/python-api.md b/docs/src/python-api.md index 6e5a260..96556ba 100644 --- a/docs/src/python-api.md +++ b/docs/src/python-api.md @@ -1,5 +1,23 @@ # Python API +```{eval-rst} +ome_arrow.core +------------------- +.. automodule:: src.ome_arrow.core + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.ingest +------------------- +.. automodule:: src.ome_arrow.ingest + :members: + :undoc-members: + :show-inheritance: +``` + ```{eval-rst} ome_arrow.meta ------------------- diff --git a/src/ome_arrow/__init__.py b/src/ome_arrow/__init__.py index a446e30..4ca3bb4 100644 --- a/src/ome_arrow/__init__.py +++ b/src/ome_arrow/__init__.py @@ -12,11 +12,13 @@ to_ome_zarr, ) from ome_arrow.ingest import ( + from_jax_array, from_numpy, from_ome_parquet, from_ome_vortex, from_ome_zarr, from_tiff, + from_torch_array, to_ome_arrow, ) from ome_arrow.meta import OME_ARROW_STRUCT, OME_ARROW_TAG_TYPE, OME_ARROW_TAG_VERSION diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index 6a3b5ea..78429cf 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -29,12 +29,16 @@ to_ome_zarr, ) from ome_arrow.ingest import ( + _is_jax_array, + _is_torch_array, + from_jax_array, from_numpy, from_ome_parquet, from_ome_vortex, from_ome_zarr, from_stack_pattern_path, from_tiff, + from_torch_array, open_lazy_plane_source, ) from ome_arrow.meta import OME_ARROW_STRUCT @@ -92,6 +96,7 @@ class OMEArrow: def __init__( self, data: str | dict | pa.StructScalar | "np.ndarray", + dim_order: str | None = None, tcz: Tuple[int, int, int] = (0, 0, 0), column_name: str = "ome_arrow", row_index: int = 0, @@ -107,12 +112,32 @@ def __init__( - a path/URL to a Vortex file (.vortex) - a NumPy ndarray (2D-5D; interpreted with from_numpy defaults) + - a torch.Tensor (2D-5D; inferred dim order by rank unless provided via + `dim_order`) + - a jax.Array (2D-5D; inferred dim order by rank unless provided via + `dim_order`) - a dict already matching the OME-Arrow schema - a pa.StructScalar already typed to OME_ARROW_STRUCT - optionally override/set image_type metadata on ingest - optionally defer source-file ingestion with lazy=True + + Args: + data: Input source or record payload. + dim_order: Axis labels used only for array/tensor ingest + (NumPy, torch, JAX). Ignored inputs are rejected to prevent + silent configuration mistakes. """ + # `dim_order` only applies to in-memory array/tensor ingestion paths. + # Rejecting incompatible combinations avoids silently ignoring user intent. + if dim_order is not None and not ( + isinstance(data, np.ndarray) or _is_torch_array(data) or _is_jax_array(data) + ): + raise ValueError( + "dim_order is supported only for numpy.ndarray, torch.Tensor, " + "or jax.Array inputs." + ) + # set the tcz for viewing self.tcz = tcz self._data: pa.StructScalar | None = None @@ -161,15 +186,35 @@ def __init__( # Uses from_numpy defaults: dim_order="TCZYX", clamp_to_uint16=True, etc. # If the array is YX/ZYX/CYX/etc., # from_numpy will expand/reorder accordingly. - self.data = from_numpy(data, image_type=image_type) + self.data = from_numpy( + data, + dim_order=dim_order or "TCZYX", + image_type=image_type, + ) + + # --- 4) Torch tensor ------------------------------------------------------ + elif _is_torch_array(data): + self.data = from_torch_array( + data, + dim_order=dim_order, + image_type=image_type, + ) + + # --- 5) JAX array -------------------------------------------------------- + elif _is_jax_array(data): + self.data = from_jax_array( + data, + dim_order=dim_order, + image_type=image_type, + ) - # --- 4) Already-typed Arrow scalar --------------------------------------- + # --- 6) Already-typed Arrow scalar --------------------------------------- elif isinstance(data, pa.StructScalar): self.data = data if image_type is not None: self.data = self._wrap_with_image_type(self.data, image_type) - # --- 5) Plain dict matching the schema ----------------------------------- + # --- 7) Plain dict matching the schema ----------------------------------- elif isinstance(data, dict): record = {f.name: data.get(f.name) for f in OME_ARROW_STRUCT} self.data = pa.scalar(record, type=OME_ARROW_STRUCT) @@ -179,7 +224,8 @@ def __init__( # --- otherwise ------------------------------------------------------------ else: raise TypeError( - "input data must be str, dict, pa.StructScalar, or numpy.ndarray" + "input data must be str, dict, pa.StructScalar, numpy.ndarray, " + "torch.Tensor, or jax.Array" ) @classmethod diff --git a/src/ome_arrow/ingest.py b/src/ome_arrow/ingest.py index 38643f8..e727b00 100644 --- a/src/ome_arrow/ingest.py +++ b/src/ome_arrow/ingest.py @@ -809,6 +809,200 @@ def from_numpy( ) +def _is_torch_array(data: Any) -> bool: + """Return True when ``data`` looks like a torch tensor.""" + module = getattr(type(data), "__module__", "") + return module == "torch" or module.startswith("torch.") + + +def _is_jax_array(data: Any) -> bool: + """Return True when ``data`` looks like a JAX array.""" + module = getattr(type(data), "__module__", "") + return module.startswith("jax.") or module.startswith("jaxlib.") + + +def _infer_dim_order_for_tensor_rank(ndim: int) -> str: + """Infer a practical default dim order for tensor backends.""" + if ndim == 2: + return "YX" + if ndim == 3: + return "CYX" + if ndim == 4: + return "TCYX" + if ndim == 5: + return "TCZYX" + raise ValueError( + "Unable to infer dim_order for tensor rank " + f"{ndim}. Provide dim_order explicitly." + ) + + +def from_torch_array( + arr: Any, + *, + dim_order: str | None = None, + image_id: Optional[str] = None, + name: Optional[str] = None, + image_type: Optional[str] = None, + channel_names: Optional[Sequence[str]] = None, + acquisition_datetime: Optional[datetime] = None, + clamp_to_uint16: bool = True, + chunk_shape: Optional[Tuple[int, int, int]] = (1, 512, 512), + chunk_order: str = "ZYX", + build_chunks: bool = True, + # meta + physical_size_x: float = 1.0, + physical_size_y: float = 1.0, + physical_size_z: float = 1.0, + physical_size_unit: str = "µm", + dtype_meta: Optional[str] = None, +) -> pa.StructScalar: + """Build an OME-Arrow StructScalar from a torch tensor. + + Args: + arr: ``torch.Tensor`` image data. + dim_order: Axis labels for ``arr``. If None, infer from rank: + 2D->"YX", 3D->"CYX", 4D->"TCYX", 5D->"TCZYX". + image_id: Optional stable image identifier. + name: Optional human label. + image_type: Open-ended image kind (e.g., "image", "label"). + channel_names: Names for channels; defaults to C0..C{n-1}. + acquisition_datetime: Defaults to now (UTC) if None. + clamp_to_uint16: If True, clamp/cast planes to uint16 before serialization. + chunk_shape: Chunk shape as (Z, Y, X). Defaults to (1, 512, 512). + chunk_order: Flattening order for chunk pixels (default "ZYX"). + build_chunks: If True, build chunked pixels from planes. + physical_size_x: Spatial pixel size (µm) for X. + physical_size_y: Spatial pixel size (µm) for Y. + physical_size_z: Spatial pixel size (µm) for Z when present. + physical_size_unit: Unit string for spatial axes (default "µm"). + dtype_meta: Pixel dtype string to place in metadata. + + Returns: + pa.StructScalar: Typed OME-Arrow record. + """ + try: + import torch + except ImportError as exc: + raise RuntimeError( + "Torch is not installed. Install extras: " + "pip install 'ome-arrow[dlpack-torch]'." + ) from exc + + if not isinstance(arr, torch.Tensor): + raise TypeError("from_torch_array expects a torch.Tensor.") + + tensor = arr.detach() + if tensor.layout != torch.strided: + tensor = tensor.to_dense() + if getattr(tensor, "is_conj", lambda: False)(): + tensor = tensor.resolve_conj() + if getattr(tensor, "is_neg", lambda: False)(): + tensor = tensor.resolve_neg() + if tensor.device.type != "cpu": + # OME-Arrow ingest currently serializes from host memory. + tensor = tensor.to(device="cpu") + + # For CPU strided tensors this is typically a zero-copy NumPy view. + np_arr = tensor.numpy() + resolved_dim_order = dim_order or _infer_dim_order_for_tensor_rank(np_arr.ndim) + return from_numpy( + np_arr, + dim_order=resolved_dim_order, + image_id=image_id, + name=name, + image_type=image_type, + channel_names=channel_names, + acquisition_datetime=acquisition_datetime, + clamp_to_uint16=clamp_to_uint16, + chunk_shape=chunk_shape, + chunk_order=chunk_order, + build_chunks=build_chunks, + physical_size_x=physical_size_x, + physical_size_y=physical_size_y, + physical_size_z=physical_size_z, + physical_size_unit=physical_size_unit, + dtype_meta=dtype_meta, + ) + + +def from_jax_array( + arr: Any, + *, + dim_order: str | None = None, + image_id: Optional[str] = None, + name: Optional[str] = None, + image_type: Optional[str] = None, + channel_names: Optional[Sequence[str]] = None, + acquisition_datetime: Optional[datetime] = None, + clamp_to_uint16: bool = True, + chunk_shape: Optional[Tuple[int, int, int]] = (1, 512, 512), + chunk_order: str = "ZYX", + build_chunks: bool = True, + # meta + physical_size_x: float = 1.0, + physical_size_y: float = 1.0, + physical_size_z: float = 1.0, + physical_size_unit: str = "µm", + dtype_meta: Optional[str] = None, +) -> pa.StructScalar: + """Build an OME-Arrow StructScalar from a JAX array. + + Args: + arr: ``jax.Array`` image data. + dim_order: Axis labels for ``arr``. If None, infer from rank: + 2D->"YX", 3D->"CYX", 4D->"TCYX", 5D->"TCZYX". + image_id: Optional stable image identifier. + name: Optional human label. + image_type: Open-ended image kind (e.g., "image", "label"). + channel_names: Names for channels; defaults to C0..C{n-1}. + acquisition_datetime: Defaults to now (UTC) if None. + clamp_to_uint16: If True, clamp/cast planes to uint16 before serialization. + chunk_shape: Chunk shape as (Z, Y, X). Defaults to (1, 512, 512). + chunk_order: Flattening order for chunk pixels (default "ZYX"). + build_chunks: If True, build chunked pixels from planes. + physical_size_x: Spatial pixel size (µm) for X. + physical_size_y: Spatial pixel size (µm) for Y. + physical_size_z: Spatial pixel size (µm) for Z when present. + physical_size_unit: Unit string for spatial axes (default "µm"). + dtype_meta: Pixel dtype string to place in metadata. + + Returns: + pa.StructScalar: Typed OME-Arrow record. + """ + try: + import jax + except ImportError as exc: + raise RuntimeError( + "JAX is not installed. Install extras: pip install 'ome-arrow[dlpack-jax]'." + ) from exc + + if not isinstance(arr, jax.Array): + raise TypeError("from_jax_array expects a jax.Array.") + + # Materializes a host NumPy view/copy as needed before Arrow serialization. + np_arr = np.asarray(arr) + resolved_dim_order = dim_order or _infer_dim_order_for_tensor_rank(np_arr.ndim) + return from_numpy( + np_arr, + dim_order=resolved_dim_order, + image_id=image_id, + name=name, + image_type=image_type, + channel_names=channel_names, + acquisition_datetime=acquisition_datetime, + clamp_to_uint16=clamp_to_uint16, + chunk_shape=chunk_shape, + chunk_order=chunk_order, + build_chunks=build_chunks, + physical_size_x=physical_size_x, + physical_size_y=physical_size_y, + physical_size_z=physical_size_z, + physical_size_unit=physical_size_unit, + dtype_meta=dtype_meta, + ) + + def from_tiff( tiff_path: str | Path, image_id: Optional[str] = None, diff --git a/tests/test_core.py b/tests/test_core.py index bb15a43..570585f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -395,6 +395,86 @@ def test_parquet_roundtrip_preserves_image_type(tmp_path: pathlib.Path) -> None: assert reloaded.data.as_py()["image_type"] == "label" +def test_constructor_accepts_torch_tensor() -> None: + """Accept torch tensors directly in OMEArrow constructor.""" + torch = pytest.importorskip("torch") + + tensor = torch.arange(2 * 3 * 4).reshape(2, 3, 4).to(dtype=torch.uint16) + oa = OMEArrow(tensor) + + exported = oa.export(how="numpy") + assert exported.shape == (1, 2, 1, 3, 4) + np.testing.assert_array_equal(exported[0, :, 0], tensor.numpy()) + + +def test_constructor_accepts_jax_array() -> None: + """Accept JAX arrays directly in OMEArrow constructor.""" + jnp = pytest.importorskip("jax.numpy") + + arr = jnp.arange(2 * 3 * 4, dtype=jnp.uint16).reshape(2, 3, 4) + oa = OMEArrow(arr) + + exported = oa.export(how="numpy") + assert exported.shape == (1, 2, 1, 3, 4) + np.testing.assert_array_equal(exported[0, :, 0], np.asarray(arr)) + + +def test_from_torch_array_explicit_dim_order() -> None: + """Support explicit dim order when ingesting torch arrays.""" + torch = pytest.importorskip("torch") + + tensor = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).to(dtype=torch.uint16) + scalar = ingest.from_torch_array(tensor, dim_order="TCYX") + oa = OMEArrow(scalar) + + exported = oa.export(how="numpy") + assert exported.shape == (2, 3, 1, 4, 5) + np.testing.assert_array_equal(exported[:, :, 0], tensor.numpy()) + + +def test_from_jax_array_explicit_dim_order() -> None: + """Support explicit dim order when ingesting JAX arrays.""" + jnp = pytest.importorskip("jax.numpy") + + arr = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.uint16).reshape(2, 3, 4, 5) + scalar = ingest.from_jax_array(arr, dim_order="TCYX") + oa = OMEArrow(scalar) + + exported = oa.export(how="numpy") + assert exported.shape == (2, 3, 1, 4, 5) + np.testing.assert_array_equal(exported[:, :, 0], np.asarray(arr)) + + +def test_constructor_dim_order_override_torch_tensor() -> None: + """Allow explicit constructor dim_order for ambiguous torch tensor ranks.""" + torch = pytest.importorskip("torch") + + tensor = torch.arange(3 * 4 * 5).reshape(3, 4, 5).to(dtype=torch.uint16) + oa = OMEArrow(tensor, dim_order="ZYX") + + exported = oa.export(how="numpy") + assert exported.shape == (1, 1, 3, 4, 5) + np.testing.assert_array_equal(exported[0, 0], tensor.numpy()) + + +def test_constructor_dim_order_override_jax_array() -> None: + """Allow explicit constructor dim_order for ambiguous JAX array ranks.""" + jnp = pytest.importorskip("jax.numpy") + + arr = jnp.arange(3 * 4 * 5, dtype=jnp.uint16).reshape(3, 4, 5) + oa = OMEArrow(arr, dim_order="ZYX") + + exported = oa.export(how="numpy") + assert exported.shape == (1, 1, 3, 4, 5) + np.testing.assert_array_equal(exported[0, 0], np.asarray(arr)) + + +def test_constructor_dim_order_rejects_non_array_input() -> None: + """Reject dim_order for non-array sources to avoid silent no-op configs.""" + with pytest.raises(ValueError, match="dim_order is supported only"): + OMEArrow("tests/data/JUMP-BR00117006/BR00117006.ome.parquet", dim_order="ZYX") + + def test_vortex_custom_column_name(tmp_path: pathlib.Path) -> None: """Ensure custom Vortex column names are preserved on round-trip.""" pytest.importorskip( From e79a3ece87691e053998c3be78e44a30240b80a6 Mon Sep 17 00:00:00 2001 From: d33bs Date: Mon, 6 Apr 2026 15:22:04 -0600 Subject: [PATCH 02/19] copilot review changes + docs fixes --- docs/src/index.md | 1 + docs/src/python-api.md | 53 +++++++++++++++++++++-- src/ome_arrow/core.py | 93 ++++++++++++++--------------------------- src/ome_arrow/ingest.py | 12 +++++- 4 files changed, 91 insertions(+), 68 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 32b49c3..0d68586 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -14,4 +14,5 @@ maxdepth: 3 --- python-api dlpack +examples/learning_to_fly_with_ome-arrow ``` diff --git a/docs/src/python-api.md b/docs/src/python-api.md index 96556ba..8e110dd 100644 --- a/docs/src/python-api.md +++ b/docs/src/python-api.md @@ -1,9 +1,18 @@ # Python API +```{eval-rst} +ome_arrow +------------------- +.. automodule:: ome_arrow + :members: + :undoc-members: + :show-inheritance: +``` + ```{eval-rst} ome_arrow.core ------------------- -.. automodule:: src.ome_arrow.core +.. automodule:: ome_arrow.core :members: :undoc-members: :show-inheritance: @@ -12,7 +21,16 @@ ome_arrow.core ```{eval-rst} ome_arrow.ingest ------------------- -.. automodule:: src.ome_arrow.ingest +.. automodule:: ome_arrow.ingest + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.export +------------------- +.. automodule:: ome_arrow.export :members: :undoc-members: :show-inheritance: @@ -21,7 +39,7 @@ ome_arrow.ingest ```{eval-rst} ome_arrow.meta ------------------- -.. automodule:: src.ome_arrow.meta +.. automodule:: ome_arrow.meta :members: :private-members: :undoc-members: @@ -31,7 +49,34 @@ ome_arrow.meta ```{eval-rst} ome_arrow.tensor ------------------- -.. automodule:: src.ome_arrow.tensor +.. automodule:: ome_arrow.tensor + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.transform +------------------- +.. automodule:: ome_arrow.transform + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.utils +------------------- +.. automodule:: ome_arrow.utils + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.view +------------------- +.. automodule:: ome_arrow.view :members: :undoc-members: :show-inheritance: diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index 78429cf..17e02ed 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -188,7 +188,7 @@ def __init__( # from_numpy will expand/reorder accordingly. self.data = from_numpy( data, - dim_order=dim_order or "TCZYX", + dim_order="TCZYX" if dim_order is None else dim_order, image_type=image_type, ) @@ -652,78 +652,47 @@ def view( show_axes: bool = True, scaling_values: tuple[float, float, float] | None = None, ) -> matplotlib.figure.Figure | "pyvista.Plotter": - """ - Render an OME-Arrow record using Matplotlib or PyVista. + """Render an OME-Arrow record using Matplotlib or PyVista. This convenience method supports two rendering backends: - * ``how="matplotlib"`` — renders a single (t, c, z) plane as a 2D image. - Returns a Matplotlib :class:`~matplotlib.figure.Figure` (or whatever - :func:`view_matplotlib` returns) and optionally displays it with - ``plt.show()`` when ``show=True``. - - * ``how="pyvista"`` — creates an interactive 3D PyVista visualization in - Jupyter. When ``show=True``, displays the widget. Independently, a static - PNG snapshot is embedded in the notebook (inside a collapsed - ``
`` block) for non-interactive renderers (e.g., GitHub). + - ``how="matplotlib"`` renders a single ``(t, c, z)`` plane as a 2D + image. + - ``how="pyvista"`` creates an interactive 3D PyVista visualization. Args: - how: Rendering backend. One of ``"matplotlib"`` or ``"pyvista"``. - tcz: The (t, c, z) indices of the plane to display when using Matplotlib. - Defaults to ``(0, 0, 0)``. - autoscale: If ``True`` and ``vmin``/``vmax`` are not provided, infer - display limits from the image data range (Matplotlib path only). - vmin: Lower display limit for intensity scaling (Matplotlib path only). - vmax: Upper display limit for intensity scaling (Matplotlib path only). - cmap: Matplotlib colormap name for single-channel display (Matplotlib only). - show: Whether to display the plot immediately. For Matplotlib, calls - ``plt.show()``. For PyVista, calls ``plotter.show()``. - c: Channel index override for the PyVista view. If ``None``, uses - ``tcz[1]`` (the ``c`` from ``tcz``). - downsample: Integer downsampling factor for the PyVista volume or slices. - Must be ``>= 1``. - opacity: Opacity specification for PyVista. Either a float in ``[0, 1]`` - or the string ``"sigmoid"`` (backend interprets as a preset transfer - function). - clim: Contrast limits (``(low, high)``) for PyVista rendering. - show_axes: If ``True``, display axes in the PyVista scene. - scaling_values: Physical scale multipliers for the (x, y, z) axes used by - PyVista, typically to express anisotropy. If ``None``, uses metadata - scaling from the OME-Arrow record (pixels_meta.physical_size_x/y/z). - These scaling values will default to 1µm if metadata is missing in - source image metadata. + how: Rendering backend. One of ``"matplotlib"`` or ``"pyvista"``. + tcz: ``(t, c, z)`` indices used for plane display. + autoscale: Infer Matplotlib display limits from image range when + ``vmin``/``vmax`` are not provided. + vmin: Lower display limit for Matplotlib intensity scaling. + vmax: Upper display limit for Matplotlib intensity scaling. + cmap: Matplotlib colormap name for single-channel display. + show: Whether to display the plot immediately. + c: Channel index override for PyVista. If ``None``, uses + ``tcz[1]``. + downsample: Integer downsampling factor for PyVista views. + opacity: Opacity for PyVista. Either a float in ``[0, 1]`` or + ``"sigmoid"``. + clim: Contrast limits ``(low, high)`` for PyVista rendering. + show_axes: Whether to display axes in the PyVista scene. + scaling_values: Physical scale multipliers ``(x, y, z)`` used by + PyVista. If ``None``, uses OME metadata-derived scaling. Returns: - matplotlib.figure.Figure | pyvista.Plotter: - * If ``how="matplotlib"``, returns the figure created by - :func:`view_matplotlib` (often a :class:`~matplotlib.figure.Figure`). - * If ``how="pyvista"``, returns the created :class:`pyvista.Plotter`. + matplotlib.figure.Figure | pyvista.Plotter: Matplotlib figure for + ``how="matplotlib"`` or PyVista plotter for ``how="pyvista"``. Raises: - ValueError: If a requested plane (``t,c,z``) is not found or if pixel - array dimensions are inconsistent (propagated from - :func:`view_matplotlib`). - TypeError: If parameter types are invalid (e.g., negative ``downsample``). + ValueError: If a requested plane is not found or the render mode + is unsupported. + TypeError: If parameter types are invalid. Notes: - * The PyVista path embeds a static PNG snapshot via Pillow (``PIL``). If - Pillow is unavailable, the method logs a warning and skips the snapshot, - but the interactive viewer is still returned. - * When ``show=False`` and ``how="pyvista"``, no interactive window is - opened, but the returned :class:`pyvista.Plotter` can be shown later. - - Examples: - Display a single plane with Matplotlib: - - >>> fig = obj.view(how="matplotlib", tcz=(0, 1, 5), cmap="magma") - - Create an interactive PyVista scene in a Jupyter notebook: - - >>> plotter = obj.view(how="pyvista", c=0, downsample=2, show=True) - - Configure PyVista contrast limits and keep axes hidden: - - >>> plotter = obj.view(how="pyvista", clim=(100, 2000), show_axes=False) + - The PyVista path attempts to embed a static PNG snapshot for + non-interactive renderers. + - When ``show=False`` and ``how="pyvista"``, the returned + :class:`pyvista.Plotter` can be shown later. """ self._ensure_materialized() diff --git a/src/ome_arrow/ingest.py b/src/ome_arrow/ingest.py index e727b00..13da10c 100644 --- a/src/ome_arrow/ingest.py +++ b/src/ome_arrow/ingest.py @@ -905,7 +905,11 @@ def from_torch_array( # For CPU strided tensors this is typically a zero-copy NumPy view. np_arr = tensor.numpy() - resolved_dim_order = dim_order or _infer_dim_order_for_tensor_rank(np_arr.ndim) + resolved_dim_order = ( + _infer_dim_order_for_tensor_rank(np_arr.ndim) + if dim_order is None + else dim_order + ) return from_numpy( np_arr, dim_order=resolved_dim_order, @@ -982,7 +986,11 @@ def from_jax_array( # Materializes a host NumPy view/copy as needed before Arrow serialization. np_arr = np.asarray(arr) - resolved_dim_order = dim_order or _infer_dim_order_for_tensor_rank(np_arr.ndim) + resolved_dim_order = ( + _infer_dim_order_for_tensor_rank(np_arr.ndim) + if dim_order is None + else dim_order + ) return from_numpy( np_arr, dim_order=resolved_dim_order, From 04a93cf8fff9a2e36a5d2491db59ffa6d21e3706 Mon Sep 17 00:00:00 2001 From: d33bs Date: Mon, 6 Apr 2026 15:25:44 -0600 Subject: [PATCH 03/19] coderabbit review changes --- pyproject.toml | 4 ++-- uv.lock | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f62006c..f25fa87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,11 +31,11 @@ dependencies = [ "pyarrow>=22", ] optional-dependencies.dlpack = [ - "jax>=0.4", + "jax>=0.4.1", "torch>=2.1", ] optional-dependencies.dlpack-jax = [ - "jax>=0.4", + "jax>=0.4.1", ] optional-dependencies.dlpack-torch = [ "torch>=2.1", diff --git a/uv.lock b/uv.lock index d0474e6..2988c5e 100644 --- a/uv.lock +++ b/uv.lock @@ -2926,8 +2926,8 @@ requires-dist = [ { name = "bioio-tifffile", specifier = ">=1.3" }, { name = "fire", specifier = ">=0.7" }, { name = "ipywidgets", marker = "extra == 'viz'", specifier = ">=8.1.8" }, - { name = "jax", marker = "extra == 'dlpack'", specifier = ">=0.4" }, - { name = "jax", marker = "extra == 'dlpack-jax'", specifier = ">=0.4" }, + { name = "jax", marker = "extra == 'dlpack'", specifier = ">=0.4.1" }, + { name = "jax", marker = "extra == 'dlpack-jax'", specifier = ">=0.4.1" }, { name = "jupyterlab-widgets", marker = "extra == 'viz'", specifier = ">=3.0.16" }, { name = "matplotlib", specifier = ">=3.10.7" }, { name = "numpy", specifier = ">=2.2.6" }, From c97bdf9fd6b0003599df99d056fb7bd6cf8c3544 Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 06:10:39 -0600 Subject: [PATCH 04/19] better docs and zyx default Co-Authored-By: Gregory Way --- README.md | 2 +- docs/src/dlpack.md | 23 +++++++++++++++++++---- src/ome_arrow/ingest.py | 6 +++--- tests/test_core.py | 8 ++++---- 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index a9ec120..e24ff4f 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,7 @@ from ome_arrow import OMEArrow # Direct constructor support: # inferred defaults are rank-based: -# 2D -> "YX", 3D -> "CYX", 4D -> "TCYX", 5D -> "TCZYX" +# 2D -> "YX", 3D -> "ZYX", 4D -> "TCYX", 5D -> "TCZYX" oa_torch = OMEArrow(torch_tensor) oa_jax = OMEArrow(jax_array) diff --git a/docs/src/dlpack.md b/docs/src/dlpack.md index ad6b0f2..f04d068 100644 --- a/docs/src/dlpack.md +++ b/docs/src/dlpack.md @@ -46,11 +46,19 @@ You can also ingest torch tensors directly: ```python from ome_arrow import OMEArrow +import torch + +# 2D tensor interpreted as YX by default. +torch_tensor = torch.randint(0, 256, (128, 128), dtype=torch.uint16) +oa = OMEArrow(torch_tensor) -oa = OMEArrow(torch_tensor) # inferred dim_order by rank -oa_zyx = OMEArrow(torch_volume, dim_order="ZYX") # explicit override +# 3D tensors are inferred as ZYX by default. +# Use dim_order when your tensor is arranged differently (for example CYX). +torch_volume = torch.randint(0, 256, (16, 128, 128), dtype=torch.uint16) +oa_cyx = OMEArrow(torch_volume, dim_order="CYX") ``` +Use `dim_order` when the inferred axis order does not match your tensor layout. `dim_order` is only supported for array/tensor ingest paths. ## Lazy scan-style slicing @@ -91,9 +99,16 @@ You can also ingest JAX arrays directly: ```python from ome_arrow import OMEArrow +import jax.numpy as jnp + +# 2D array interpreted as YX by default. +jax_array = jnp.arange(128 * 128, dtype=jnp.uint16).reshape(128, 128) +oa = OMEArrow(jax_array) -oa = OMEArrow(jax_array) # inferred dim_order by rank -oa_zyx = OMEArrow(jax_volume, dim_order="ZYX") # explicit override +# 3D arrays are inferred as ZYX by default. +# Use dim_order when your array is arranged differently (for example CYX). +jax_volume = jnp.arange(16 * 128 * 128, dtype=jnp.uint16).reshape(16, 128, 128) +oa_cyx = OMEArrow(jax_volume, dim_order="CYX") ``` ## Iteration examples diff --git a/src/ome_arrow/ingest.py b/src/ome_arrow/ingest.py index 13da10c..e479994 100644 --- a/src/ome_arrow/ingest.py +++ b/src/ome_arrow/ingest.py @@ -826,7 +826,7 @@ def _infer_dim_order_for_tensor_rank(ndim: int) -> str: if ndim == 2: return "YX" if ndim == 3: - return "CYX" + return "ZYX" if ndim == 4: return "TCYX" if ndim == 5: @@ -862,7 +862,7 @@ def from_torch_array( Args: arr: ``torch.Tensor`` image data. dim_order: Axis labels for ``arr``. If None, infer from rank: - 2D->"YX", 3D->"CYX", 4D->"TCYX", 5D->"TCZYX". + 2D->"YX", 3D->"ZYX", 4D->"TCYX", 5D->"TCZYX". image_id: Optional stable image identifier. name: Optional human label. image_type: Open-ended image kind (e.g., "image", "label"). @@ -955,7 +955,7 @@ def from_jax_array( Args: arr: ``jax.Array`` image data. dim_order: Axis labels for ``arr``. If None, infer from rank: - 2D->"YX", 3D->"CYX", 4D->"TCYX", 5D->"TCZYX". + 2D->"YX", 3D->"ZYX", 4D->"TCYX", 5D->"TCZYX". image_id: Optional stable image identifier. name: Optional human label. image_type: Open-ended image kind (e.g., "image", "label"). diff --git a/tests/test_core.py b/tests/test_core.py index 570585f..107b7b2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -403,8 +403,8 @@ def test_constructor_accepts_torch_tensor() -> None: oa = OMEArrow(tensor) exported = oa.export(how="numpy") - assert exported.shape == (1, 2, 1, 3, 4) - np.testing.assert_array_equal(exported[0, :, 0], tensor.numpy()) + assert exported.shape == (1, 1, 2, 3, 4) + np.testing.assert_array_equal(exported[0, 0], tensor.numpy()) def test_constructor_accepts_jax_array() -> None: @@ -415,8 +415,8 @@ def test_constructor_accepts_jax_array() -> None: oa = OMEArrow(arr) exported = oa.export(how="numpy") - assert exported.shape == (1, 2, 1, 3, 4) - np.testing.assert_array_equal(exported[0, :, 0], np.asarray(arr)) + assert exported.shape == (1, 1, 2, 3, 4) + np.testing.assert_array_equal(exported[0, 0], np.asarray(arr)) def test_from_torch_array_explicit_dim_order() -> None: From fbe78a3a1c7b700cca166439e3fd5c0494b9622a Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 06:15:25 -0600 Subject: [PATCH 05/19] better docs for channel names Co-Authored-By: Gregory Way --- src/ome_arrow/ingest.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/ome_arrow/ingest.py b/src/ome_arrow/ingest.py index e479994..c582a45 100644 --- a/src/ome_arrow/ingest.py +++ b/src/ome_arrow/ingest.py @@ -676,7 +676,10 @@ def from_numpy( image_id: Optional stable image identifier. name: Optional human label. image_type: Open-ended image kind (e.g., "image", "label"). - channel_names: Names for channels; defaults to C0..C{n-1}. + channel_names: Optional channel names. Defaults to ``None``. When + ``None`` (or length does not match channel count), names are + auto-generated as ``C0..C{n-1}`` (for example, 3 channels become + ``C0``, ``C1``, ``C2``). acquisition_datetime: Defaults to now (UTC) if None. clamp_to_uint16: If True, clamp/cast planes to uint16 before serialization. chunk_shape: Chunk shape as (Z, Y, X). Defaults to (1, 512, 512). @@ -866,7 +869,10 @@ def from_torch_array( image_id: Optional stable image identifier. name: Optional human label. image_type: Open-ended image kind (e.g., "image", "label"). - channel_names: Names for channels; defaults to C0..C{n-1}. + channel_names: Optional channel names. Defaults to ``None``. When + ``None`` (or length does not match channel count), names are + auto-generated as ``C0..C{n-1}`` (for example, 3 channels become + ``C0``, ``C1``, ``C2``). acquisition_datetime: Defaults to now (UTC) if None. clamp_to_uint16: If True, clamp/cast planes to uint16 before serialization. chunk_shape: Chunk shape as (Z, Y, X). Defaults to (1, 512, 512). From c51bd66d5e58f5e1805975d9a2587b764f17e462 Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 06:20:42 -0600 Subject: [PATCH 06/19] better clarity Co-Authored-By: Gregory Way --- src/ome_arrow/ingest.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/ome_arrow/ingest.py b/src/ome_arrow/ingest.py index c582a45..b82a093 100644 --- a/src/ome_arrow/ingest.py +++ b/src/ome_arrow/ingest.py @@ -862,6 +862,11 @@ def from_torch_array( ) -> pa.StructScalar: """Build an OME-Arrow StructScalar from a torch tensor. + This is useful when your pipeline already works with ``torch.Tensor`` + objects (for example model inputs/outputs) and you want a direct path into + the canonical OME-Arrow struct without manually converting and reshaping in + user code. + Args: arr: ``torch.Tensor`` image data. dim_order: Axis labels for ``arr``. If None, infer from rank: @@ -958,6 +963,10 @@ def from_jax_array( ) -> pa.StructScalar: """Build an OME-Arrow StructScalar from a JAX array. + This is useful when your pipeline already works with ``jax.Array`` objects + and you want a direct path into the canonical OME-Arrow struct without + manual conversion boilerplate in user code. + Args: arr: ``jax.Array`` image data. dim_order: Axis labels for ``arr``. If None, infer from rank: From e1f55c81a7c10dbf8f0f8dbeb33de44a461ffbdd Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 09:39:56 -0600 Subject: [PATCH 07/19] better docs Co-Authored-By: Gregory Way --- README.md | 22 +++++++-- src/ome_arrow/ingest.py | 106 ++++++++++++++++++++++++++++------------ 2 files changed, 91 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index e24ff4f..1c7235f 100644 --- a/README.md +++ b/README.md @@ -122,8 +122,18 @@ See full docs: [`docs/src/dlpack.md`](docs/src/dlpack.md) ## Tensor ingest (PyTorch/JAX) -You can now ingest torch or JAX arrays directly with `OMEArrow(...)`, or use -explicit helper functions from `ome_arrow.ingest`. +You can ingest torch or JAX arrays directly with `OMEArrow(...)`. +You can also use explicit helper functions from `ome_arrow.ingest`. + +Why this is useful: + +- It removes conversion boilerplate in model/data pipelines that already use torch or JAX tensors. +- It keeps axis handling in one place (`dim_order`). +- This reduces mistakes when moving between tensor layouts and OME-Arrow records. +- It can reduce overhead in some paths. +- For example, CPU torch tensors often expose a NumPy view without an extra copy. +- Ingest still materializes OME-Arrow planes/chunks. +- This is more about clean interoperability than dramatic end-to-end speedups. ```python from ome_arrow import OMEArrow @@ -147,11 +157,13 @@ scalar_jax = from_jax_array(jax_array, dim_order="TCYX") Notes: -- Torch/JAX support is optional; install extras as needed: - `pip install "ome-arrow[dlpack-torch]"` or - `pip install "ome-arrow[dlpack-jax]"`. +- Torch/JAX support is optional. +- Install extras as needed: + `pip install "ome-arrow[dlpack-torch]"` or `pip install "ome-arrow[dlpack-jax]"`. - Torch tensors are detached and converted on CPU for ingest. - `dim_order` is accepted only for NumPy/torch/JAX array inputs. +- Ingest now passes flattened NumPy pixel buffers directly to Arrow. +- This avoids materializing Python `list` payloads per plane/chunk. ## Benchmarking lazy reads diff --git a/src/ome_arrow/ingest.py b/src/ome_arrow/ingest.py index b82a093..de0bb2f 100644 --- a/src/ome_arrow/ingest.py +++ b/src/ome_arrow/ingest.py @@ -459,7 +459,7 @@ def _build_chunks_from_planes( "shape_z": sz, "shape_y": sy, "shape_x": sx, - "pixels": slab.reshape(-1).tolist(), + "pixels": slab.reshape(-1), } ) return chunks @@ -563,7 +563,14 @@ def to_ome_arrow( ch["illumination"] = str(ch["illumination"]) if planes is None: - planes = [{"z": 0, "t": 0, "c": 0, "pixels": [0] * (size_x * size_y)}] + planes = [ + { + "z": 0, + "t": 0, + "c": 0, + "pixels": np.zeros(size_x * size_y, dtype=np.uint16), + } + ] if chunks is None and build_chunks: chunks = _build_chunks_from_planes( @@ -777,9 +784,7 @@ def from_numpy( for c in range(size_c): for z in range(size_z): plane = tczyx[t, c, z] - planes.append( - {"z": z, "t": t, "c": c, "pixels": plane.ravel().tolist()} - ) + planes.append({"z": z, "t": t, "c": c, "pixels": plane.reshape(-1)}) # Meta dimension_order: mirror your other ingests meta_dim_order = "XYCT" if size_z == 1 else "XYZCT" @@ -840,6 +845,51 @@ def _infer_dim_order_for_tensor_rank(ndim: int) -> str: ) +def _from_array_via_numpy( + np_arr: np.ndarray, + *, + dim_order: str | None, + image_id: Optional[str], + name: Optional[str], + image_type: Optional[str], + channel_names: Optional[Sequence[str]], + acquisition_datetime: Optional[datetime], + clamp_to_uint16: bool, + chunk_shape: Optional[Tuple[int, int, int]], + chunk_order: str, + build_chunks: bool, + physical_size_x: float, + physical_size_y: float, + physical_size_z: float, + physical_size_unit: str, + dtype_meta: Optional[str], +) -> pa.StructScalar: + """Shared array->NumPy->OME-Arrow conversion path.""" + resolved_dim_order = ( + _infer_dim_order_for_tensor_rank(np_arr.ndim) + if dim_order is None + else dim_order + ) + return from_numpy( + np_arr, + dim_order=resolved_dim_order, + image_id=image_id, + name=name, + image_type=image_type, + channel_names=channel_names, + acquisition_datetime=acquisition_datetime, + clamp_to_uint16=clamp_to_uint16, + chunk_shape=chunk_shape, + chunk_order=chunk_order, + build_chunks=build_chunks, + physical_size_x=physical_size_x, + physical_size_y=physical_size_y, + physical_size_z=physical_size_z, + physical_size_unit=physical_size_unit, + dtype_meta=dtype_meta, + ) + + def from_torch_array( arr: Any, *, @@ -916,14 +966,9 @@ def from_torch_array( # For CPU strided tensors this is typically a zero-copy NumPy view. np_arr = tensor.numpy() - resolved_dim_order = ( - _infer_dim_order_for_tensor_rank(np_arr.ndim) - if dim_order is None - else dim_order - ) - return from_numpy( + return _from_array_via_numpy( np_arr, - dim_order=resolved_dim_order, + dim_order=dim_order, image_id=image_id, name=name, image_type=image_type, @@ -974,7 +1019,10 @@ def from_jax_array( image_id: Optional stable image identifier. name: Optional human label. image_type: Open-ended image kind (e.g., "image", "label"). - channel_names: Names for channels; defaults to C0..C{n-1}. + channel_names: Optional channel names. Defaults to ``None``. When + ``None`` (or length does not match channel count), names are + auto-generated as ``C0..C{n-1}`` (for example, 3 channels become + ``C0``, ``C1``, ``C2``). acquisition_datetime: Defaults to now (UTC) if None. clamp_to_uint16: If True, clamp/cast planes to uint16 before serialization. chunk_shape: Chunk shape as (Z, Y, X). Defaults to (1, 512, 512). @@ -1001,14 +1049,9 @@ def from_jax_array( # Materializes a host NumPy view/copy as needed before Arrow serialization. np_arr = np.asarray(arr) - resolved_dim_order = ( - _infer_dim_order_for_tensor_rank(np_arr.ndim) - if dim_order is None - else dim_order - ) - return from_numpy( + return _from_array_via_numpy( np_arr, - dim_order=resolved_dim_order, + dim_order=dim_order, image_id=image_id, name=name, image_type=image_type, @@ -1106,9 +1149,7 @@ def from_tiff( plane = arr[t, c, z] if clamp_to_uint16 and plane.dtype != np.uint16: plane = np.clip(plane, 0, 65535).astype(np.uint16) - planes.append( - {"z": z, "t": t, "c": c, "pixels": plane.ravel().tolist()} - ) + planes.append({"z": z, "t": t, "c": c, "pixels": plane.reshape(-1)}) dim_order = "XYCT" if size_z == 1 else "XYZCT" @@ -1335,7 +1376,12 @@ def _ensure_u16(arr: np.ndarray) -> np.ndarray: if fpath is None: # missing plane: zero-fill planes.append( - {"z": z, "t": t, "c": c, "pixels": [0] * (size_x * size_y)} + { + "z": z, + "t": t, + "c": c, + "pixels": np.zeros(size_x * size_y, dtype=np.uint16), + } ) continue @@ -1355,9 +1401,7 @@ def _ensure_u16(arr: np.ndarray) -> np.ndarray: f" {arr.shape} vs {(size_y, size_x)}" ) arr = _ensure_u16(arr) - planes.append( - {"z": z, "t": t, "c": c, "pixels": arr.ravel().tolist()} - ) + planes.append({"z": z, "t": t, "c": c, "pixels": arr.reshape(-1)}) else: # Treat as TCZYX; extract dims Y, X = arr.shape[-2], arr.shape[-1] @@ -1373,7 +1417,7 @@ def _ensure_u16(arr: np.ndarray) -> np.ndarray: if Tn == 1 and Cn == 1 and Zn == 1: plane2d = _ensure_u16(arr.reshape(Y, X)) planes.append( - {"z": z, "t": t, "c": c, "pixels": plane2d.ravel().tolist()} + {"z": z, "t": t, "c": c, "pixels": plane2d.reshape(-1)} ) # Case B: multi-Z only (expand across Z) elif Tn == 1 and Cn == 1 and Zn > 1: @@ -1388,7 +1432,7 @@ def _ensure_u16(arr: np.ndarray) -> np.ndarray: "z": z_idx, "t": t, "c": c, - "pixels": plane2d.ravel().tolist(), + "pixels": plane2d.reshape(-1), } ) # bump global size_z if we exceeded it @@ -1540,9 +1584,7 @@ def from_ome_zarr( plane = arr[t, c, z] if clamp_to_uint16 and plane.dtype != np.uint16: plane = np.clip(plane, 0, 65535).astype(np.uint16) - planes.append( - {"z": z, "t": t, "c": c, "pixels": plane.ravel().tolist()} - ) + planes.append({"z": z, "t": t, "c": c, "pixels": plane.reshape(-1)}) dim_order = "XYCT" if size_z == 1 else "XYZCT" From d716acd7e42ec3b18957def2746e56d42595b459 Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 10:59:57 -0600 Subject: [PATCH 08/19] address coderabbit review --- tests/test_core.py | 95 +++++++++++++++++++++------------------------- 1 file changed, 43 insertions(+), 52 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 107b7b2..bf720ee 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -395,78 +395,69 @@ def test_parquet_roundtrip_preserves_image_type(tmp_path: pathlib.Path) -> None: assert reloaded.data.as_py()["image_type"] == "label" -def test_constructor_accepts_torch_tensor() -> None: - """Accept torch tensors directly in OMEArrow constructor.""" - torch = pytest.importorskip("torch") +def _backend_symbols(backend: str) -> dict[str, object]: + """Return backend-specific tensor constructors and converters.""" + if backend == "torch": + torch = pytest.importorskip("torch") + return { + "array_from_1d": lambda n, shape: ( + torch.arange(n).reshape(*shape).to(dtype=torch.uint16) + ), + "to_numpy": lambda x: x.numpy(), + "from_array": ingest.from_torch_array, + } - tensor = torch.arange(2 * 3 * 4).reshape(2, 3, 4).to(dtype=torch.uint16) - oa = OMEArrow(tensor) - - exported = oa.export(how="numpy") - assert exported.shape == (1, 1, 2, 3, 4) - np.testing.assert_array_equal(exported[0, 0], tensor.numpy()) + jnp = pytest.importorskip("jax.numpy") + return { + "array_from_1d": lambda n, shape: jnp.arange(n, dtype=jnp.uint16).reshape( + *shape + ), + "to_numpy": np.asarray, + "from_array": ingest.from_jax_array, + } -def test_constructor_accepts_jax_array() -> None: - """Accept JAX arrays directly in OMEArrow constructor.""" - jnp = pytest.importorskip("jax.numpy") +@pytest.mark.parametrize("backend", ["torch", "jax"]) +def test_constructor_accepts_array_backend(backend: str) -> None: + """Accept backend arrays directly in OMEArrow constructor.""" + symbols = _backend_symbols(backend) + arr = symbols["array_from_1d"](2 * 3 * 4, (2, 3, 4)) + expected = symbols["to_numpy"](arr) - arr = jnp.arange(2 * 3 * 4, dtype=jnp.uint16).reshape(2, 3, 4) oa = OMEArrow(arr) - exported = oa.export(how="numpy") + assert exported.shape == (1, 1, 2, 3, 4) - np.testing.assert_array_equal(exported[0, 0], np.asarray(arr)) + np.testing.assert_array_equal(exported[0, 0], expected) -def test_from_torch_array_explicit_dim_order() -> None: - """Support explicit dim order when ingesting torch arrays.""" - torch = pytest.importorskip("torch") +@pytest.mark.parametrize("backend", ["torch", "jax"]) +def test_from_array_explicit_dim_order_backend(backend: str) -> None: + """Support explicit dim_order with backend-specific from_* helpers.""" + symbols = _backend_symbols(backend) + arr = symbols["array_from_1d"](2 * 3 * 4 * 5, (2, 3, 4, 5)) + expected = symbols["to_numpy"](arr) + scalar = symbols["from_array"](arr, dim_order="TCYX") - tensor = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).to(dtype=torch.uint16) - scalar = ingest.from_torch_array(tensor, dim_order="TCYX") oa = OMEArrow(scalar) - exported = oa.export(how="numpy") - assert exported.shape == (2, 3, 1, 4, 5) - np.testing.assert_array_equal(exported[:, :, 0], tensor.numpy()) - - -def test_from_jax_array_explicit_dim_order() -> None: - """Support explicit dim order when ingesting JAX arrays.""" - jnp = pytest.importorskip("jax.numpy") - arr = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.uint16).reshape(2, 3, 4, 5) - scalar = ingest.from_jax_array(arr, dim_order="TCYX") - oa = OMEArrow(scalar) - - exported = oa.export(how="numpy") assert exported.shape == (2, 3, 1, 4, 5) - np.testing.assert_array_equal(exported[:, :, 0], np.asarray(arr)) - - -def test_constructor_dim_order_override_torch_tensor() -> None: - """Allow explicit constructor dim_order for ambiguous torch tensor ranks.""" - torch = pytest.importorskip("torch") - - tensor = torch.arange(3 * 4 * 5).reshape(3, 4, 5).to(dtype=torch.uint16) - oa = OMEArrow(tensor, dim_order="ZYX") - - exported = oa.export(how="numpy") - assert exported.shape == (1, 1, 3, 4, 5) - np.testing.assert_array_equal(exported[0, 0], tensor.numpy()) + np.testing.assert_array_equal(exported[:, :, 0], expected) -def test_constructor_dim_order_override_jax_array() -> None: - """Allow explicit constructor dim_order for ambiguous JAX array ranks.""" - jnp = pytest.importorskip("jax.numpy") +@pytest.mark.parametrize("backend", ["torch", "jax"]) +def test_constructor_dim_order_override_backend(backend: str) -> None: + """Allow explicit constructor dim_order for ambiguous backend array ranks.""" + symbols = _backend_symbols(backend) + arr = symbols["array_from_1d"](3 * 4 * 5, (3, 4, 5)) + expected = symbols["to_numpy"](arr) - arr = jnp.arange(3 * 4 * 5, dtype=jnp.uint16).reshape(3, 4, 5) oa = OMEArrow(arr, dim_order="ZYX") - exported = oa.export(how="numpy") + assert exported.shape == (1, 1, 3, 4, 5) - np.testing.assert_array_equal(exported[0, 0], np.asarray(arr)) + np.testing.assert_array_equal(exported[0, 0], expected) def test_constructor_dim_order_rejects_non_array_input() -> None: From 8af82a05e997c1625e534a8b79bde28fc135bb4e Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 14:50:32 -0600 Subject: [PATCH 09/19] better docs Co-Authored-By: Cameron Mattson <92554334+MattsonCam@users.noreply.github.com> --- docs/src/dlpack.md | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/docs/src/dlpack.md b/docs/src/dlpack.md index f04d068..8648af9 100644 --- a/docs/src/dlpack.md +++ b/docs/src/dlpack.md @@ -61,6 +61,34 @@ oa_cyx = OMEArrow(torch_volume, dim_order="CYX") Use `dim_order` when the inferred axis order does not match your tensor layout. `dim_order` is only supported for array/tensor ingest paths. +To persist with this interpreted axis mapping, export the resulting OME-Arrow +record (for example to parquet): + +```python +from ome_arrow import OMEArrow +import torch + +torch_volume = torch.randint(0, 256, (16, 128, 128), dtype=torch.uint16) +oa = OMEArrow(torch_volume, dim_order="ZYX") +oa.export(how="parquet", out="volume.ome.parquet") +``` + +OME-Arrow stores pixels in canonical OME-style fields (`size_t`, `size_c`, +`size_z`, `size_y`, `size_x`) rather than preserving a free-form input label +string. The interpreted mapping is preserved through those axis sizes and can +be read back with `tensor_view(...)` layouts. + +"Batch" dimension note: + +- There is no separate `B` axis in the OME-Arrow schema. +- For model batches, map batch to `T` during ingest. +- Examples: + - `B,C,Y,X` -> use `dim_order="TCYX"` + - `B,C,Z,Y,X` -> use `dim_order="TCZYX"` + - `B,Y,X,C` -> use `dim_order="TYXC"` +- If `T` is already meaningful in your data, represent batch as table rows + (one OME-Arrow record per batch item) instead of overloading another image axis. + ## Lazy scan-style slicing ```python From b8df72f880fd2d0407e866ce7dd2ddfebcc49fce Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 14:53:19 -0600 Subject: [PATCH 10/19] better docs Co-Authored-By: Cameron Mattson <92554334+MattsonCam@users.noreply.github.com> --- src/ome_arrow/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index 17e02ed..c5227ab 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -124,8 +124,8 @@ def __init__( Args: data: Input source or record payload. dim_order: Axis labels used only for array/tensor ingest - (NumPy, torch, JAX). Ignored inputs are rejected to prevent - silent configuration mistakes. + (NumPy, torch, JAX). Invalid or unrecognized combinations + raise an error instead of being silently ignored. """ # `dim_order` only applies to in-memory array/tensor ingestion paths. From 1b6cec5c4aa2534c57feb42127c4fc868d96cbfb Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 15:28:29 -0600 Subject: [PATCH 11/19] better docs Co-Authored-By: Cameron Mattson <92554334+MattsonCam@users.noreply.github.com> --- src/ome_arrow/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index c5227ab..835b07f 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -128,7 +128,8 @@ def __init__( raise an error instead of being silently ignored. """ - # `dim_order` only applies to in-memory array/tensor ingestion paths. + # `dim_order` applies only when the constructor input itself is a raw + # NumPy/torch/JAX array object (not string/file-path sources). # Rejecting incompatible combinations avoids silently ignoring user intent. if dim_order is not None and not ( isinstance(data, np.ndarray) or _is_torch_array(data) or _is_jax_array(data) From cd3bd66a9f6f2efec2f5d27aa0013b44db4b32ff Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 15:29:12 -0600 Subject: [PATCH 12/19] better error Co-Authored-By: Gregory Way --- src/ome_arrow/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index 835b07f..128a9c0 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -224,9 +224,10 @@ def __init__( # --- otherwise ------------------------------------------------------------ else: + data_type = f"{type(data).__module__}.{type(data).__qualname__}" raise TypeError( "input data must be str, dict, pa.StructScalar, numpy.ndarray, " - "torch.Tensor, or jax.Array" + f"torch.Tensor, or jax.Array; got {data_type}" ) @classmethod From 267152533221d7ffaaf9390f77878e9f43313de0 Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 15:30:39 -0600 Subject: [PATCH 13/19] better docs Co-Authored-By: Gregory Way --- src/ome_arrow/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index 128a9c0..25e919c 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -674,6 +674,8 @@ def view( c: Channel index override for PyVista. If ``None``, uses ``tcz[1]``. downsample: Integer downsampling factor for PyVista views. + Higher values render faster for large volumes but reduce + spatial resolution. opacity: Opacity for PyVista. Either a float in ``[0, 1]`` or ``"sigmoid"``. clim: Contrast limits ``(low, high)`` for PyVista rendering. From 344ee9f67d69e79dc09fe8bb5155781ab503a692 Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 15:35:16 -0600 Subject: [PATCH 14/19] better docs Co-Authored-By: Gregory Way --- src/ome_arrow/core.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index 25e919c..ea54260 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -693,8 +693,11 @@ def view( TypeError: If parameter types are invalid. Notes: - - The PyVista path attempts to embed a static PNG snapshot for - non-interactive renderers. + - The ``how="pyvista"`` mode normally outputs an interactive + visualization, but attempts to embed a static PNG snapshot for + non-interactive renderers (for example, static docs builds, + nbconvert HTML/PDF exports, rendered/read-only notebook views + such as GitHub notebook previews, and CI log viewers). - When ``show=False`` and ``how="pyvista"``, the returned :class:`pyvista.Plotter` can be shown later. """ From 9e75d9c89f2c0ab2a3dacc0dad122f149ba00fd8 Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 15:37:31 -0600 Subject: [PATCH 15/19] better docs Co-Authored-By: Gregory Way --- README.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 1c7235f..e452194 100644 --- a/README.md +++ b/README.md @@ -127,13 +127,11 @@ You can also use explicit helper functions from `ome_arrow.ingest`. Why this is useful: -- It removes conversion boilerplate in model/data pipelines that already use torch or JAX tensors. -- It keeps axis handling in one place (`dim_order`). -- This reduces mistakes when moving between tensor layouts and OME-Arrow records. -- It can reduce overhead in some paths. -- For example, CPU torch tensors often expose a NumPy view without an extra copy. +- It reduces compute overhead by removing conversion code boilerplate in separate model/data pipelines that already use torch or JAX tensors (i.e., it provides a direct port of OME-arrow into popular deep learning libraries). +- However, this is more about clean interoperability than dramatic end-to-end speedups (although we expect fewer handoffs to result in speedups). Specifically: +- It makes it easier for a user to update dimension ordering input in the same place without requiring separate functionality (see argument `dim_order`). +- This smooths handoffs and reduces mistakes when moving between tensor layouts and OME-Arrow records. For example, CPU torch tensors often expose a NumPy view without an extra copy. - Ingest still materializes OME-Arrow planes/chunks. -- This is more about clean interoperability than dramatic end-to-end speedups. ```python from ome_arrow import OMEArrow From e42ba04aa53ad9701843be7d8255ab283a32bb9c Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 16:57:50 -0600 Subject: [PATCH 16/19] address coderabbit review --- src/ome_arrow/core.py | 5 ++++- tests/test_core.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index ea54260..d3e4eb4 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -96,8 +96,9 @@ class OMEArrow: def __init__( self, data: str | dict | pa.StructScalar | "np.ndarray", - dim_order: str | None = None, tcz: Tuple[int, int, int] = (0, 0, 0), + *, + dim_order: str | None = None, column_name: str = "ome_arrow", row_index: int = 0, image_type: str | None = None, @@ -763,6 +764,8 @@ def view( return plotter + raise ValueError(f"Unsupported view mode: {how!r}. Use 'matplotlib' or 'pyvista'.") + def tensor_view( self, *, diff --git a/tests/test_core.py b/tests/test_core.py index bf720ee..a02aeb2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -466,6 +466,22 @@ def test_constructor_dim_order_rejects_non_array_input() -> None: OMEArrow("tests/data/JUMP-BR00117006/BR00117006.ome.parquet", dim_order="ZYX") +def test_constructor_second_positional_arg_still_binds_tcz() -> None: + """Preserve positional ABI: second positional argument is tcz, not dim_order.""" + arr = np.arange(12, dtype=np.uint16).reshape(1, 1, 1, 3, 4) + oa = OMEArrow(arr, (0, 0, 1)) + assert oa.tcz == (0, 0, 1) + + +def test_view_rejects_unsupported_mode() -> None: + """Unsupported view modes should raise a clear ValueError.""" + arr = np.arange(12, dtype=np.uint16).reshape(1, 1, 1, 3, 4) + oa = OMEArrow(arr) + + with pytest.raises(ValueError, match="Unsupported view mode"): + oa.view(how="foo") + + def test_vortex_custom_column_name(tmp_path: pathlib.Path) -> None: """Ensure custom Vortex column names are preserved on round-trip.""" pytest.importorskip( From 40d8491ce4fade6a5ac5fc194bd5a8de20c496a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 22:58:33 +0000 Subject: [PATCH 17/19] [pre-commit.ci lite] apply automatic fixes --- src/ome_arrow/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index d3e4eb4..cf4fb57 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -764,7 +764,9 @@ def view( return plotter - raise ValueError(f"Unsupported view mode: {how!r}. Use 'matplotlib' or 'pyvista'.") + raise ValueError( + f"Unsupported view mode: {how!r}. Use 'matplotlib' or 'pyvista'." + ) def tensor_view( self, From be2f217fcb7e6c29dff4db49ed4cb1d651279d25 Mon Sep 17 00:00:00 2001 From: d33bs Date: Tue, 7 Apr 2026 17:12:33 -0600 Subject: [PATCH 18/19] address coderabbit review --- src/ome_arrow/core.py | 16 +++++++++++++--- tests/test_core.py | 8 ++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index d3e4eb4..05e900e 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -654,7 +654,9 @@ def view( clim: tuple[float, float] | None = None, show_axes: bool = True, scaling_values: tuple[float, float, float] | None = None, - ) -> matplotlib.figure.Figure | "pyvista.Plotter": + ) -> ( + tuple[matplotlib.figure.Figure, Any, Any] | "pyvista.Plotter" + ): """Render an OME-Arrow record using Matplotlib or PyVista. This convenience method supports two rendering backends: @@ -685,8 +687,11 @@ def view( PyVista. If ``None``, uses OME metadata-derived scaling. Returns: - matplotlib.figure.Figure | pyvista.Plotter: Matplotlib figure for - ``how="matplotlib"`` or PyVista plotter for ``how="pyvista"``. + tuple[matplotlib.figure.Figure, matplotlib.axes.Axes, + matplotlib.image.AxesImage] | pyvista.Plotter: + For ``how="matplotlib"``, returns the tuple emitted by + :func:`ome_arrow.view.view_matplotlib` as ``(figure, axes, image)``. + For ``how="pyvista"``, returns a :class:`pyvista.Plotter`. Raises: ValueError: If a requested plane is not found or the render mode @@ -702,6 +707,11 @@ def view( - When ``show=False`` and ``how="pyvista"``, the returned :class:`pyvista.Plotter` can be shown later. """ + if how not in {"matplotlib", "pyvista"}: + raise ValueError( + f"Unsupported view mode: {how!r}. Use 'matplotlib' or 'pyvista'." + ) + self._ensure_materialized() if how == "matplotlib": diff --git a/tests/test_core.py b/tests/test_core.py index a02aeb2..515ac46 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -482,6 +482,14 @@ def test_view_rejects_unsupported_mode() -> None: oa.view(how="foo") +def test_view_rejects_unsupported_mode_before_lazy_materialization() -> None: + """Validate render mode before touching lazy source materialization.""" + oa = OMEArrow.scan("tests/data/does-not-exist.ome.parquet") + + with pytest.raises(ValueError, match="Unsupported view mode"): + oa.view(how="foo") + + def test_vortex_custom_column_name(tmp_path: pathlib.Path) -> None: """Ensure custom Vortex column names are preserved on round-trip.""" pytest.importorskip( From eb8d8ebcf09b1ef5b5045521d0d890bc5193191e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci-lite[bot]" <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> Date: Tue, 7 Apr 2026 23:13:19 +0000 Subject: [PATCH 19/19] [pre-commit.ci lite] apply automatic fixes --- src/ome_arrow/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index 27582d6..993c6e6 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -654,9 +654,7 @@ def view( clim: tuple[float, float] | None = None, show_axes: bool = True, scaling_values: tuple[float, float, float] | None = None, - ) -> ( - tuple[matplotlib.figure.Figure, Any, Any] | "pyvista.Plotter" - ): + ) -> tuple[matplotlib.figure.Figure, Any, Any] | "pyvista.Plotter": """Render an OME-Arrow record using Matplotlib or PyVista. This convenience method supports two rendering backends: