diff --git a/README.md b/README.md index ac7e1fc..e452194 100644 --- a/README.md +++ b/README.md @@ -120,6 +120,49 @@ Advanced options: See full docs: [`docs/src/dlpack.md`](docs/src/dlpack.md) +## Tensor ingest (PyTorch/JAX) + +You can ingest torch or JAX arrays directly with `OMEArrow(...)`. +You can also use explicit helper functions from `ome_arrow.ingest`. + +Why this is useful: + +- It removes conversion boilerplate from model/data pipelines that already use torch or JAX tensors (i.e., it brings OME-Arrow ingestion directly to popular deep learning libraries). +- However, this is more about clean interoperability than dramatic end-to-end speedups (although we expect fewer handoffs to result in speedups). Specifically: +- It makes it easier for a user to update dimension ordering input in the same place without requiring separate functionality (see argument `dim_order`). +- This smooths handoffs and reduces mistakes when moving between tensor layouts and OME-Arrow records. For example, CPU torch tensors often expose a NumPy view without an extra copy. +- Ingest still materializes OME-Arrow planes/chunks. + +```python +from ome_arrow import OMEArrow + +# Direct constructor support: +# inferred defaults are rank-based: +# 2D -> "YX", 3D -> "ZYX", 4D -> "TCYX", 5D -> "TCZYX" +oa_torch = OMEArrow(torch_tensor) +oa_jax = OMEArrow(jax_array) + +# Optional: override dim order when shape is ambiguous +oa_zyx = OMEArrow(torch_volume, dim_order="ZYX") +``` + +```python +from ome_arrow.ingest import from_torch_array, from_jax_array + +scalar_torch = from_torch_array(torch_tensor, dim_order="TCYX") +scalar_jax = from_jax_array(jax_array, dim_order="TCYX") +``` + +Notes: + +- Torch/JAX support is optional. +- Install extras as needed: + `pip install "ome-arrow[dlpack-torch]"` or `pip install "ome-arrow[dlpack-jax]"`. +- Torch tensors are detached and converted on CPU for ingest. 
+- `dim_order` is accepted only for NumPy/torch/JAX array inputs. +- Ingest now passes flattened NumPy pixel buffers directly to Arrow. +- This avoids materializing Python `list` payloads per plane/chunk. + ## Benchmarking lazy reads Use the lightweight benchmark utility in `benchmarks/` to compare lazy tensor diff --git a/docs/src/dlpack.md b/docs/src/dlpack.md index 0b99cf9..8648af9 100644 --- a/docs/src/dlpack.md +++ b/docs/src/dlpack.md @@ -42,6 +42,53 @@ flat = torch.utils.dlpack.from_dlpack(capsule) tensor = flat.reshape(view.shape) ``` +You can also ingest torch tensors directly: + +```python +from ome_arrow import OMEArrow +import torch + +# 2D tensor interpreted as YX by default. +torch_tensor = torch.randint(0, 256, (128, 128), dtype=torch.uint16) +oa = OMEArrow(torch_tensor) + +# 3D tensors are inferred as ZYX by default. +# Use dim_order when your tensor is arranged differently (for example CYX). +torch_volume = torch.randint(0, 256, (16, 128, 128), dtype=torch.uint16) +oa_cyx = OMEArrow(torch_volume, dim_order="CYX") +``` + +Use `dim_order` when the inferred axis order does not match your tensor layout. +`dim_order` is only supported for array/tensor ingest paths. + +To persist with this interpreted axis mapping, export the resulting OME-Arrow +record (for example to parquet): + +```python +from ome_arrow import OMEArrow +import torch + +torch_volume = torch.randint(0, 256, (16, 128, 128), dtype=torch.uint16) +oa = OMEArrow(torch_volume, dim_order="ZYX") +oa.export(how="parquet", out="volume.ome.parquet") +``` + +OME-Arrow stores pixels in canonical OME-style fields (`size_t`, `size_c`, +`size_z`, `size_y`, `size_x`) rather than preserving a free-form input label +string. The interpreted mapping is preserved through those axis sizes and can +be read back with `tensor_view(...)` layouts. + +"Batch" dimension note: + +- There is no separate `B` axis in the OME-Arrow schema. +- For model batches, map batch to `T` during ingest. 
+- Examples: + - `B,C,Y,X` -> use `dim_order="TCYX"` + - `B,C,Z,Y,X` -> use `dim_order="TCZYX"` + - `B,Y,X,C` -> use `dim_order="TYXC"` +- If `T` is already meaningful in your data, represent batch as table rows + (one OME-Arrow record per batch item) instead of overloading another image axis. + ## Lazy scan-style slicing ```python @@ -76,6 +123,22 @@ flat = jnp.from_dlpack(capsule) arr = flat.reshape(view.shape) ``` +You can also ingest JAX arrays directly: + +```python +from ome_arrow import OMEArrow +import jax.numpy as jnp + +# 2D array interpreted as YX by default. +jax_array = jnp.arange(128 * 128, dtype=jnp.uint16).reshape(128, 128) +oa = OMEArrow(jax_array) + +# 3D arrays are inferred as ZYX by default. +# Use dim_order when your array is arranged differently (for example CYX). +jax_volume = jnp.arange(16 * 128 * 128, dtype=jnp.uint16).reshape(16, 128, 128) +oa_cyx = OMEArrow(jax_volume, dim_order="CYX") +``` + ## Iteration examples ```python diff --git a/docs/src/index.md b/docs/src/index.md index 32b49c3..0d68586 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -14,4 +14,5 @@ maxdepth: 3 --- python-api dlpack +examples/learning_to_fly_with_ome-arrow ``` diff --git a/docs/src/python-api.md b/docs/src/python-api.md index 6e5a260..8e110dd 100644 --- a/docs/src/python-api.md +++ b/docs/src/python-api.md @@ -1,9 +1,45 @@ # Python API +```{eval-rst} +ome_arrow +------------------- +.. automodule:: ome_arrow + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.core +------------------- +.. automodule:: ome_arrow.core + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.ingest +------------------- +.. automodule:: ome_arrow.ingest + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.export +------------------- +.. automodule:: ome_arrow.export + :members: + :undoc-members: + :show-inheritance: +``` + ```{eval-rst} ome_arrow.meta ------------------- -.. 
automodule:: src.ome_arrow.meta +.. automodule:: ome_arrow.meta :members: :private-members: :undoc-members: @@ -13,7 +49,34 @@ ome_arrow.meta ```{eval-rst} ome_arrow.tensor ------------------- -.. automodule:: src.ome_arrow.tensor +.. automodule:: ome_arrow.tensor + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.transform +------------------- +.. automodule:: ome_arrow.transform + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.utils +------------------- +.. automodule:: ome_arrow.utils + :members: + :undoc-members: + :show-inheritance: +``` + +```{eval-rst} +ome_arrow.view +------------------- +.. automodule:: ome_arrow.view :members: :undoc-members: :show-inheritance: diff --git a/pyproject.toml b/pyproject.toml index f62006c..f25fa87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,11 +31,11 @@ dependencies = [ "pyarrow>=22", ] optional-dependencies.dlpack = [ - "jax>=0.4", + "jax>=0.4.1", "torch>=2.1", ] optional-dependencies.dlpack-jax = [ - "jax>=0.4", + "jax>=0.4.1", ] optional-dependencies.dlpack-torch = [ "torch>=2.1", diff --git a/src/ome_arrow/__init__.py b/src/ome_arrow/__init__.py index a446e30..4ca3bb4 100644 --- a/src/ome_arrow/__init__.py +++ b/src/ome_arrow/__init__.py @@ -12,11 +12,13 @@ to_ome_zarr, ) from ome_arrow.ingest import ( + from_jax_array, from_numpy, from_ome_parquet, from_ome_vortex, from_ome_zarr, from_tiff, + from_torch_array, to_ome_arrow, ) from ome_arrow.meta import OME_ARROW_STRUCT, OME_ARROW_TAG_TYPE, OME_ARROW_TAG_VERSION diff --git a/src/ome_arrow/core.py b/src/ome_arrow/core.py index 6a3b5ea..993c6e6 100644 --- a/src/ome_arrow/core.py +++ b/src/ome_arrow/core.py @@ -29,12 +29,16 @@ to_ome_zarr, ) from ome_arrow.ingest import ( + _is_jax_array, + _is_torch_array, + from_jax_array, from_numpy, from_ome_parquet, from_ome_vortex, from_ome_zarr, from_stack_pattern_path, from_tiff, + from_torch_array, open_lazy_plane_source, ) from ome_arrow.meta import 
OME_ARROW_STRUCT @@ -93,6 +97,8 @@ def __init__( self, data: str | dict | pa.StructScalar | "np.ndarray", tcz: Tuple[int, int, int] = (0, 0, 0), + *, + dim_order: str | None = None, column_name: str = "ome_arrow", row_index: int = 0, image_type: str | None = None, @@ -107,12 +113,33 @@ def __init__( - a path/URL to a Vortex file (.vortex) - a NumPy ndarray (2D-5D; interpreted with from_numpy defaults) + - a torch.Tensor (2D-5D; inferred dim order by rank unless provided via + `dim_order`) + - a jax.Array (2D-5D; inferred dim order by rank unless provided via + `dim_order`) - a dict already matching the OME-Arrow schema - a pa.StructScalar already typed to OME_ARROW_STRUCT - optionally override/set image_type metadata on ingest - optionally defer source-file ingestion with lazy=True + + Args: + data: Input source or record payload. + dim_order: Axis labels used only for array/tensor ingest + (NumPy, torch, JAX). Invalid or unrecognized combinations + raise an error instead of being silently ignored. """ + # `dim_order` applies only when the constructor input itself is a raw + # NumPy/torch/JAX array object (not string/file-path sources). + # Rejecting incompatible combinations avoids silently ignoring user intent. + if dim_order is not None and not ( + isinstance(data, np.ndarray) or _is_torch_array(data) or _is_jax_array(data) + ): + raise ValueError( + "dim_order is supported only for numpy.ndarray, torch.Tensor, " + "or jax.Array inputs." + ) + # set the tcz for viewing self.tcz = tcz self._data: pa.StructScalar | None = None @@ -161,15 +188,35 @@ def __init__( # Uses from_numpy defaults: dim_order="TCZYX", clamp_to_uint16=True, etc. # If the array is YX/ZYX/CYX/etc., # from_numpy will expand/reorder accordingly. 
- self.data = from_numpy(data, image_type=image_type) + self.data = from_numpy( + data, + dim_order="TCZYX" if dim_order is None else dim_order, + image_type=image_type, + ) + + # --- 4) Torch tensor ------------------------------------------------------ + elif _is_torch_array(data): + self.data = from_torch_array( + data, + dim_order=dim_order, + image_type=image_type, + ) + + # --- 5) JAX array -------------------------------------------------------- + elif _is_jax_array(data): + self.data = from_jax_array( + data, + dim_order=dim_order, + image_type=image_type, + ) - # --- 4) Already-typed Arrow scalar --------------------------------------- + # --- 6) Already-typed Arrow scalar --------------------------------------- elif isinstance(data, pa.StructScalar): self.data = data if image_type is not None: self.data = self._wrap_with_image_type(self.data, image_type) - # --- 5) Plain dict matching the schema ----------------------------------- + # --- 7) Plain dict matching the schema ----------------------------------- elif isinstance(data, dict): record = {f.name: data.get(f.name) for f in OME_ARROW_STRUCT} self.data = pa.scalar(record, type=OME_ARROW_STRUCT) @@ -178,8 +225,10 @@ def __init__( # --- otherwise ------------------------------------------------------------ else: + data_type = f"{type(data).__module__}.{type(data).__qualname__}" raise TypeError( - "input data must be str, dict, pa.StructScalar, or numpy.ndarray" + "input data must be str, dict, pa.StructScalar, numpy.ndarray, " + f"torch.Tensor, or jax.Array; got {data_type}" ) @classmethod @@ -605,80 +654,62 @@ def view( clim: tuple[float, float] | None = None, show_axes: bool = True, scaling_values: tuple[float, float, float] | None = None, - ) -> matplotlib.figure.Figure | "pyvista.Plotter": - """ - Render an OME-Arrow record using Matplotlib or PyVista. + ) -> tuple[matplotlib.figure.Figure, Any, Any] | "pyvista.Plotter": + """Render an OME-Arrow record using Matplotlib or PyVista. 
This convenience method supports two rendering backends: - * ``how="matplotlib"`` — renders a single (t, c, z) plane as a 2D image. - Returns a Matplotlib :class:`~matplotlib.figure.Figure` (or whatever - :func:`view_matplotlib` returns) and optionally displays it with - ``plt.show()`` when ``show=True``. - - * ``how="pyvista"`` — creates an interactive 3D PyVista visualization in - Jupyter. When ``show=True``, displays the widget. Independently, a static - PNG snapshot is embedded in the notebook (inside a collapsed - ``
`` block) for non-interactive renderers (e.g., GitHub). + - ``how="matplotlib"`` renders a single ``(t, c, z)`` plane as a 2D + image. + - ``how="pyvista"`` creates an interactive 3D PyVista visualization. Args: - how: Rendering backend. One of ``"matplotlib"`` or ``"pyvista"``. - tcz: The (t, c, z) indices of the plane to display when using Matplotlib. - Defaults to ``(0, 0, 0)``. - autoscale: If ``True`` and ``vmin``/``vmax`` are not provided, infer - display limits from the image data range (Matplotlib path only). - vmin: Lower display limit for intensity scaling (Matplotlib path only). - vmax: Upper display limit for intensity scaling (Matplotlib path only). - cmap: Matplotlib colormap name for single-channel display (Matplotlib only). - show: Whether to display the plot immediately. For Matplotlib, calls - ``plt.show()``. For PyVista, calls ``plotter.show()``. - c: Channel index override for the PyVista view. If ``None``, uses - ``tcz[1]`` (the ``c`` from ``tcz``). - downsample: Integer downsampling factor for the PyVista volume or slices. - Must be ``>= 1``. - opacity: Opacity specification for PyVista. Either a float in ``[0, 1]`` - or the string ``"sigmoid"`` (backend interprets as a preset transfer - function). - clim: Contrast limits (``(low, high)``) for PyVista rendering. - show_axes: If ``True``, display axes in the PyVista scene. - scaling_values: Physical scale multipliers for the (x, y, z) axes used by - PyVista, typically to express anisotropy. If ``None``, uses metadata - scaling from the OME-Arrow record (pixels_meta.physical_size_x/y/z). - These scaling values will default to 1µm if metadata is missing in - source image metadata. + how: Rendering backend. One of ``"matplotlib"`` or ``"pyvista"``. + tcz: ``(t, c, z)`` indices used for plane display. + autoscale: Infer Matplotlib display limits from image range when + ``vmin``/``vmax`` are not provided. + vmin: Lower display limit for Matplotlib intensity scaling. 
+ vmax: Upper display limit for Matplotlib intensity scaling. + cmap: Matplotlib colormap name for single-channel display. + show: Whether to display the plot immediately. + c: Channel index override for PyVista. If ``None``, uses + ``tcz[1]``. + downsample: Integer downsampling factor for PyVista views. + Higher values render faster for large volumes but reduce + spatial resolution. + opacity: Opacity for PyVista. Either a float in ``[0, 1]`` or + ``"sigmoid"``. + clim: Contrast limits ``(low, high)`` for PyVista rendering. + show_axes: Whether to display axes in the PyVista scene. + scaling_values: Physical scale multipliers ``(x, y, z)`` used by + PyVista. If ``None``, uses OME metadata-derived scaling. Returns: - matplotlib.figure.Figure | pyvista.Plotter: - * If ``how="matplotlib"``, returns the figure created by - :func:`view_matplotlib` (often a :class:`~matplotlib.figure.Figure`). - * If ``how="pyvista"``, returns the created :class:`pyvista.Plotter`. + tuple[matplotlib.figure.Figure, matplotlib.axes.Axes, + matplotlib.image.AxesImage] | pyvista.Plotter: + For ``how="matplotlib"``, returns the tuple emitted by + :func:`ome_arrow.view.view_matplotlib` as ``(figure, axes, image)``. + For ``how="pyvista"``, returns a :class:`pyvista.Plotter`. Raises: - ValueError: If a requested plane (``t,c,z``) is not found or if pixel - array dimensions are inconsistent (propagated from - :func:`view_matplotlib`). - TypeError: If parameter types are invalid (e.g., negative ``downsample``). + ValueError: If a requested plane is not found or the render mode + is unsupported. + TypeError: If parameter types are invalid. Notes: - * The PyVista path embeds a static PNG snapshot via Pillow (``PIL``). If - Pillow is unavailable, the method logs a warning and skips the snapshot, - but the interactive viewer is still returned. - * When ``show=False`` and ``how="pyvista"``, no interactive window is - opened, but the returned :class:`pyvista.Plotter` can be shown later. 
- - Examples: - Display a single plane with Matplotlib: - - >>> fig = obj.view(how="matplotlib", tcz=(0, 1, 5), cmap="magma") - - Create an interactive PyVista scene in a Jupyter notebook: - - >>> plotter = obj.view(how="pyvista", c=0, downsample=2, show=True) - - Configure PyVista contrast limits and keep axes hidden: - - >>> plotter = obj.view(how="pyvista", clim=(100, 2000), show_axes=False) + - The ``how="pyvista"`` mode normally outputs an interactive + visualization, but attempts to embed a static PNG snapshot for + non-interactive renderers (for example, static docs builds, + nbconvert HTML/PDF exports, rendered/read-only notebook views + such as GitHub notebook previews, and CI log viewers). + - When ``show=False`` and ``how="pyvista"``, the returned + :class:`pyvista.Plotter` can be shown later. """ + if how not in {"matplotlib", "pyvista"}: + raise ValueError( + f"Unsupported view mode: {how!r}. Use 'matplotlib' or 'pyvista'." + ) + self._ensure_materialized() if how == "matplotlib": @@ -741,6 +772,10 @@ def view( return plotter + raise ValueError( + f"Unsupported view mode: {how!r}. Use 'matplotlib' or 'pyvista'." + ) + def tensor_view( self, *, diff --git a/src/ome_arrow/ingest.py b/src/ome_arrow/ingest.py index 38643f8..de0bb2f 100644 --- a/src/ome_arrow/ingest.py +++ b/src/ome_arrow/ingest.py @@ -459,7 +459,7 @@ def _build_chunks_from_planes( "shape_z": sz, "shape_y": sy, "shape_x": sx, - "pixels": slab.reshape(-1).tolist(), + "pixels": slab.reshape(-1), } ) return chunks @@ -563,7 +563,14 @@ def to_ome_arrow( ch["illumination"] = str(ch["illumination"]) if planes is None: - planes = [{"z": 0, "t": 0, "c": 0, "pixels": [0] * (size_x * size_y)}] + planes = [ + { + "z": 0, + "t": 0, + "c": 0, + "pixels": np.zeros(size_x * size_y, dtype=np.uint16), + } + ] if chunks is None and build_chunks: chunks = _build_chunks_from_planes( @@ -676,7 +683,10 @@ def from_numpy( image_id: Optional stable image identifier. name: Optional human label. 
image_type: Open-ended image kind (e.g., "image", "label"). - channel_names: Names for channels; defaults to C0..C{n-1}. + channel_names: Optional channel names. Defaults to ``None``. When + ``None`` (or length does not match channel count), names are + auto-generated as ``C0..C{n-1}`` (for example, 3 channels become + ``C0``, ``C1``, ``C2``). acquisition_datetime: Defaults to now (UTC) if None. clamp_to_uint16: If True, clamp/cast planes to uint16 before serialization. chunk_shape: Chunk shape as (Z, Y, X). Defaults to (1, 512, 512). @@ -774,9 +784,7 @@ def from_numpy( for c in range(size_c): for z in range(size_z): plane = tczyx[t, c, z] - planes.append( - {"z": z, "t": t, "c": c, "pixels": plane.ravel().tolist()} - ) + planes.append({"z": z, "t": t, "c": c, "pixels": plane.reshape(-1)}) # Meta dimension_order: mirror your other ingests meta_dim_order = "XYCT" if size_z == 1 else "XYZCT" @@ -809,6 +817,258 @@ def from_numpy( ) +def _is_torch_array(data: Any) -> bool: + """Return True when ``data`` looks like a torch tensor.""" + module = getattr(type(data), "__module__", "") + return module == "torch" or module.startswith("torch.") + + +def _is_jax_array(data: Any) -> bool: + """Return True when ``data`` looks like a JAX array.""" + module = getattr(type(data), "__module__", "") + return module.startswith("jax.") or module.startswith("jaxlib.") + + +def _infer_dim_order_for_tensor_rank(ndim: int) -> str: + """Infer a practical default dim order for tensor backends.""" + if ndim == 2: + return "YX" + if ndim == 3: + return "ZYX" + if ndim == 4: + return "TCYX" + if ndim == 5: + return "TCZYX" + raise ValueError( + "Unable to infer dim_order for tensor rank " + f"{ndim}. Provide dim_order explicitly." 
+ ) + + +def _from_array_via_numpy( + np_arr: np.ndarray, + *, + dim_order: str | None, + image_id: Optional[str], + name: Optional[str], + image_type: Optional[str], + channel_names: Optional[Sequence[str]], + acquisition_datetime: Optional[datetime], + clamp_to_uint16: bool, + chunk_shape: Optional[Tuple[int, int, int]], + chunk_order: str, + build_chunks: bool, + physical_size_x: float, + physical_size_y: float, + physical_size_z: float, + physical_size_unit: str, + dtype_meta: Optional[str], +) -> pa.StructScalar: + """Shared array->NumPy->OME-Arrow conversion path.""" + resolved_dim_order = ( + _infer_dim_order_for_tensor_rank(np_arr.ndim) + if dim_order is None + else dim_order + ) + return from_numpy( + np_arr, + dim_order=resolved_dim_order, + image_id=image_id, + name=name, + image_type=image_type, + channel_names=channel_names, + acquisition_datetime=acquisition_datetime, + clamp_to_uint16=clamp_to_uint16, + chunk_shape=chunk_shape, + chunk_order=chunk_order, + build_chunks=build_chunks, + physical_size_x=physical_size_x, + physical_size_y=physical_size_y, + physical_size_z=physical_size_z, + physical_size_unit=physical_size_unit, + dtype_meta=dtype_meta, + ) + + +def from_torch_array( + arr: Any, + *, + dim_order: str | None = None, + image_id: Optional[str] = None, + name: Optional[str] = None, + image_type: Optional[str] = None, + channel_names: Optional[Sequence[str]] = None, + acquisition_datetime: Optional[datetime] = None, + clamp_to_uint16: bool = True, + chunk_shape: Optional[Tuple[int, int, int]] = (1, 512, 512), + chunk_order: str = "ZYX", + build_chunks: bool = True, + # meta + physical_size_x: float = 1.0, + physical_size_y: float = 1.0, + physical_size_z: float = 1.0, + physical_size_unit: str = "µm", + dtype_meta: Optional[str] = None, +) -> pa.StructScalar: + """Build an OME-Arrow StructScalar from a torch tensor. 
+ + This is useful when your pipeline already works with ``torch.Tensor`` + objects (for example model inputs/outputs) and you want a direct path into + the canonical OME-Arrow struct without manually converting and reshaping in + user code. + + Args: + arr: ``torch.Tensor`` image data. + dim_order: Axis labels for ``arr``. If None, infer from rank: + 2D->"YX", 3D->"ZYX", 4D->"TCYX", 5D->"TCZYX". + image_id: Optional stable image identifier. + name: Optional human label. + image_type: Open-ended image kind (e.g., "image", "label"). + channel_names: Optional channel names. Defaults to ``None``. When + ``None`` (or length does not match channel count), names are + auto-generated as ``C0..C{n-1}`` (for example, 3 channels become + ``C0``, ``C1``, ``C2``). + acquisition_datetime: Defaults to now (UTC) if None. + clamp_to_uint16: If True, clamp/cast planes to uint16 before serialization. + chunk_shape: Chunk shape as (Z, Y, X). Defaults to (1, 512, 512). + chunk_order: Flattening order for chunk pixels (default "ZYX"). + build_chunks: If True, build chunked pixels from planes. + physical_size_x: Spatial pixel size (µm) for X. + physical_size_y: Spatial pixel size (µm) for Y. + physical_size_z: Spatial pixel size (µm) for Z when present. + physical_size_unit: Unit string for spatial axes (default "µm"). + dtype_meta: Pixel dtype string to place in metadata. + + Returns: + pa.StructScalar: Typed OME-Arrow record. + """ + try: + import torch + except ImportError as exc: + raise RuntimeError( + "Torch is not installed. Install extras: " + "pip install 'ome-arrow[dlpack-torch]'." 
+ ) from exc + + if not isinstance(arr, torch.Tensor): + raise TypeError("from_torch_array expects a torch.Tensor.") + + tensor = arr.detach() + if tensor.layout != torch.strided: + tensor = tensor.to_dense() + if getattr(tensor, "is_conj", lambda: False)(): + tensor = tensor.resolve_conj() + if getattr(tensor, "is_neg", lambda: False)(): + tensor = tensor.resolve_neg() + if tensor.device.type != "cpu": + # OME-Arrow ingest currently serializes from host memory. + tensor = tensor.to(device="cpu") + + # For CPU strided tensors this is typically a zero-copy NumPy view. + np_arr = tensor.numpy() + return _from_array_via_numpy( + np_arr, + dim_order=dim_order, + image_id=image_id, + name=name, + image_type=image_type, + channel_names=channel_names, + acquisition_datetime=acquisition_datetime, + clamp_to_uint16=clamp_to_uint16, + chunk_shape=chunk_shape, + chunk_order=chunk_order, + build_chunks=build_chunks, + physical_size_x=physical_size_x, + physical_size_y=physical_size_y, + physical_size_z=physical_size_z, + physical_size_unit=physical_size_unit, + dtype_meta=dtype_meta, + ) + + +def from_jax_array( + arr: Any, + *, + dim_order: str | None = None, + image_id: Optional[str] = None, + name: Optional[str] = None, + image_type: Optional[str] = None, + channel_names: Optional[Sequence[str]] = None, + acquisition_datetime: Optional[datetime] = None, + clamp_to_uint16: bool = True, + chunk_shape: Optional[Tuple[int, int, int]] = (1, 512, 512), + chunk_order: str = "ZYX", + build_chunks: bool = True, + # meta + physical_size_x: float = 1.0, + physical_size_y: float = 1.0, + physical_size_z: float = 1.0, + physical_size_unit: str = "µm", + dtype_meta: Optional[str] = None, +) -> pa.StructScalar: + """Build an OME-Arrow StructScalar from a JAX array. + + This is useful when your pipeline already works with ``jax.Array`` objects + and you want a direct path into the canonical OME-Arrow struct without + manual conversion boilerplate in user code. 
+ + Args: + arr: ``jax.Array`` image data. + dim_order: Axis labels for ``arr``. If None, infer from rank: + 2D->"YX", 3D->"ZYX", 4D->"TCYX", 5D->"TCZYX". + image_id: Optional stable image identifier. + name: Optional human label. + image_type: Open-ended image kind (e.g., "image", "label"). + channel_names: Optional channel names. Defaults to ``None``. When + ``None`` (or length does not match channel count), names are + auto-generated as ``C0..C{n-1}`` (for example, 3 channels become + ``C0``, ``C1``, ``C2``). + acquisition_datetime: Defaults to now (UTC) if None. + clamp_to_uint16: If True, clamp/cast planes to uint16 before serialization. + chunk_shape: Chunk shape as (Z, Y, X). Defaults to (1, 512, 512). + chunk_order: Flattening order for chunk pixels (default "ZYX"). + build_chunks: If True, build chunked pixels from planes. + physical_size_x: Spatial pixel size (µm) for X. + physical_size_y: Spatial pixel size (µm) for Y. + physical_size_z: Spatial pixel size (µm) for Z when present. + physical_size_unit: Unit string for spatial axes (default "µm"). + dtype_meta: Pixel dtype string to place in metadata. + + Returns: + pa.StructScalar: Typed OME-Arrow record. + """ + try: + import jax + except ImportError as exc: + raise RuntimeError( + "JAX is not installed. Install extras: pip install 'ome-arrow[dlpack-jax]'." + ) from exc + + if not isinstance(arr, jax.Array): + raise TypeError("from_jax_array expects a jax.Array.") + + # Materializes a host NumPy view/copy as needed before Arrow serialization. 
+ np_arr = np.asarray(arr) + return _from_array_via_numpy( + np_arr, + dim_order=dim_order, + image_id=image_id, + name=name, + image_type=image_type, + channel_names=channel_names, + acquisition_datetime=acquisition_datetime, + clamp_to_uint16=clamp_to_uint16, + chunk_shape=chunk_shape, + chunk_order=chunk_order, + build_chunks=build_chunks, + physical_size_x=physical_size_x, + physical_size_y=physical_size_y, + physical_size_z=physical_size_z, + physical_size_unit=physical_size_unit, + dtype_meta=dtype_meta, + ) + + def from_tiff( tiff_path: str | Path, image_id: Optional[str] = None, @@ -889,9 +1149,7 @@ def from_tiff( plane = arr[t, c, z] if clamp_to_uint16 and plane.dtype != np.uint16: plane = np.clip(plane, 0, 65535).astype(np.uint16) - planes.append( - {"z": z, "t": t, "c": c, "pixels": plane.ravel().tolist()} - ) + planes.append({"z": z, "t": t, "c": c, "pixels": plane.reshape(-1)}) dim_order = "XYCT" if size_z == 1 else "XYZCT" @@ -1118,7 +1376,12 @@ def _ensure_u16(arr: np.ndarray) -> np.ndarray: if fpath is None: # missing plane: zero-fill planes.append( - {"z": z, "t": t, "c": c, "pixels": [0] * (size_x * size_y)} + { + "z": z, + "t": t, + "c": c, + "pixels": np.zeros(size_x * size_y, dtype=np.uint16), + } ) continue @@ -1138,9 +1401,7 @@ def _ensure_u16(arr: np.ndarray) -> np.ndarray: f" {arr.shape} vs {(size_y, size_x)}" ) arr = _ensure_u16(arr) - planes.append( - {"z": z, "t": t, "c": c, "pixels": arr.ravel().tolist()} - ) + planes.append({"z": z, "t": t, "c": c, "pixels": arr.reshape(-1)}) else: # Treat as TCZYX; extract dims Y, X = arr.shape[-2], arr.shape[-1] @@ -1156,7 +1417,7 @@ def _ensure_u16(arr: np.ndarray) -> np.ndarray: if Tn == 1 and Cn == 1 and Zn == 1: plane2d = _ensure_u16(arr.reshape(Y, X)) planes.append( - {"z": z, "t": t, "c": c, "pixels": plane2d.ravel().tolist()} + {"z": z, "t": t, "c": c, "pixels": plane2d.reshape(-1)} ) # Case B: multi-Z only (expand across Z) elif Tn == 1 and Cn == 1 and Zn > 1: @@ -1171,7 +1432,7 @@ def 
_ensure_u16(arr: np.ndarray) -> np.ndarray: "z": z_idx, "t": t, "c": c, - "pixels": plane2d.ravel().tolist(), + "pixels": plane2d.reshape(-1), } ) # bump global size_z if we exceeded it @@ -1323,9 +1584,7 @@ def from_ome_zarr( plane = arr[t, c, z] if clamp_to_uint16 and plane.dtype != np.uint16: plane = np.clip(plane, 0, 65535).astype(np.uint16) - planes.append( - {"z": z, "t": t, "c": c, "pixels": plane.ravel().tolist()} - ) + planes.append({"z": z, "t": t, "c": c, "pixels": plane.reshape(-1)}) dim_order = "XYCT" if size_z == 1 else "XYZCT" diff --git a/tests/test_core.py b/tests/test_core.py index bb15a43..515ac46 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -395,6 +395,101 @@ def test_parquet_roundtrip_preserves_image_type(tmp_path: pathlib.Path) -> None: assert reloaded.data.as_py()["image_type"] == "label" +def _backend_symbols(backend: str) -> dict[str, object]: + """Return backend-specific tensor constructors and converters.""" + if backend == "torch": + torch = pytest.importorskip("torch") + return { + "array_from_1d": lambda n, shape: ( + torch.arange(n).reshape(*shape).to(dtype=torch.uint16) + ), + "to_numpy": lambda x: x.numpy(), + "from_array": ingest.from_torch_array, + } + + jnp = pytest.importorskip("jax.numpy") + return { + "array_from_1d": lambda n, shape: jnp.arange(n, dtype=jnp.uint16).reshape( + *shape + ), + "to_numpy": np.asarray, + "from_array": ingest.from_jax_array, + } + + +@pytest.mark.parametrize("backend", ["torch", "jax"]) +def test_constructor_accepts_array_backend(backend: str) -> None: + """Accept backend arrays directly in OMEArrow constructor.""" + symbols = _backend_symbols(backend) + arr = symbols["array_from_1d"](2 * 3 * 4, (2, 3, 4)) + expected = symbols["to_numpy"](arr) + + oa = OMEArrow(arr) + exported = oa.export(how="numpy") + + assert exported.shape == (1, 1, 2, 3, 4) + np.testing.assert_array_equal(exported[0, 0], expected) + + +@pytest.mark.parametrize("backend", ["torch", "jax"]) +def 
test_from_array_explicit_dim_order_backend(backend: str) -> None: + """Support explicit dim_order with backend-specific from_* helpers.""" + symbols = _backend_symbols(backend) + arr = symbols["array_from_1d"](2 * 3 * 4 * 5, (2, 3, 4, 5)) + expected = symbols["to_numpy"](arr) + scalar = symbols["from_array"](arr, dim_order="TCYX") + + oa = OMEArrow(scalar) + exported = oa.export(how="numpy") + + assert exported.shape == (2, 3, 1, 4, 5) + np.testing.assert_array_equal(exported[:, :, 0], expected) + + +@pytest.mark.parametrize("backend", ["torch", "jax"]) +def test_constructor_dim_order_override_backend(backend: str) -> None: + """Allow explicit constructor dim_order for ambiguous backend array ranks.""" + symbols = _backend_symbols(backend) + arr = symbols["array_from_1d"](3 * 4 * 5, (3, 4, 5)) + expected = symbols["to_numpy"](arr) + + oa = OMEArrow(arr, dim_order="ZYX") + exported = oa.export(how="numpy") + + assert exported.shape == (1, 1, 3, 4, 5) + np.testing.assert_array_equal(exported[0, 0], expected) + + +def test_constructor_dim_order_rejects_non_array_input() -> None: + """Reject dim_order for non-array sources to avoid silent no-op configs.""" + with pytest.raises(ValueError, match="dim_order is supported only"): + OMEArrow("tests/data/JUMP-BR00117006/BR00117006.ome.parquet", dim_order="ZYX") + + +def test_constructor_second_positional_arg_still_binds_tcz() -> None: + """Preserve positional ABI: second positional argument is tcz, not dim_order.""" + arr = np.arange(12, dtype=np.uint16).reshape(1, 1, 1, 3, 4) + oa = OMEArrow(arr, (0, 0, 1)) + assert oa.tcz == (0, 0, 1) + + +def test_view_rejects_unsupported_mode() -> None: + """Unsupported view modes should raise a clear ValueError.""" + arr = np.arange(12, dtype=np.uint16).reshape(1, 1, 1, 3, 4) + oa = OMEArrow(arr) + + with pytest.raises(ValueError, match="Unsupported view mode"): + oa.view(how="foo") + + +def test_view_rejects_unsupported_mode_before_lazy_materialization() -> None: + """Validate render 
mode before touching lazy source materialization.""" + oa = OMEArrow.scan("tests/data/does-not-exist.ome.parquet") + + with pytest.raises(ValueError, match="Unsupported view mode"): + oa.view(how="foo") + + def test_vortex_custom_column_name(tmp_path: pathlib.Path) -> None: """Ensure custom Vortex column names are preserved on round-trip.""" pytest.importorskip( diff --git a/uv.lock b/uv.lock index d0474e6..2988c5e 100644 --- a/uv.lock +++ b/uv.lock @@ -2926,8 +2926,8 @@ requires-dist = [ { name = "bioio-tifffile", specifier = ">=1.3" }, { name = "fire", specifier = ">=0.7" }, { name = "ipywidgets", marker = "extra == 'viz'", specifier = ">=8.1.8" }, - { name = "jax", marker = "extra == 'dlpack'", specifier = ">=0.4" }, - { name = "jax", marker = "extra == 'dlpack-jax'", specifier = ">=0.4" }, + { name = "jax", marker = "extra == 'dlpack'", specifier = ">=0.4.1" }, + { name = "jax", marker = "extra == 'dlpack-jax'", specifier = ">=0.4.1" }, { name = "jupyterlab-widgets", marker = "extra == 'viz'", specifier = ">=3.0.16" }, { name = "matplotlib", specifier = ">=3.10.7" }, { name = "numpy", specifier = ">=2.2.6" },