43 changes: 43 additions & 0 deletions README.md
@@ -120,6 +120,49 @@ Advanced options:

See full docs: [`docs/src/dlpack.md`](docs/src/dlpack.md)

## Tensor ingest (PyTorch/JAX)

You can ingest torch or JAX arrays directly with `OMEArrow(...)`.
You can also use explicit helper functions from `ome_arrow.ingest`.

Why this is useful:

- It removes boilerplate conversion code from model and data pipelines that already use torch or JAX tensors, giving those libraries a direct path into OME-Arrow.
- The benefit is mainly clean interoperability rather than dramatic end-to-end speedups, although fewer handoffs may also reduce overhead. Specifically:
  - Dimension ordering can be set in the same call through the `dim_order` argument, with no separate reordering step.
  - Handoffs between tensor layouts and OME-Arrow records become smoother and less error-prone. For example, CPU torch tensors often expose a NumPy view without an extra copy.
- Ingest still materializes OME-Arrow planes/chunks.

```python
from ome_arrow import OMEArrow

# Direct constructor support:
# inferred defaults are rank-based:
# 2D -> "YX", 3D -> "ZYX", 4D -> "TCYX", 5D -> "TCZYX"
oa_torch = OMEArrow(torch_tensor)
oa_jax = OMEArrow(jax_array)

# Optional: override dim order when shape is ambiguous
oa_zyx = OMEArrow(torch_volume, dim_order="ZYX")
```
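
The rank-based defaults listed in the comment above can be sketched as a small lookup. This is an illustration only: `infer_default_dim_order` is a hypothetical helper name, not a function exported by `ome_arrow`, and the library's own inference may treat unsupported ranks differently.

```python
# Hypothetical helper (not part of ome_arrow): illustrates the rank-based
# defaults described above.
_DEFAULT_DIM_ORDER = {2: "YX", 3: "ZYX", 4: "TCYX", 5: "TCZYX"}


def infer_default_dim_order(shape: tuple) -> str:
    """Return the default axis labels for an array with the given shape."""
    rank = len(shape)
    if rank not in _DEFAULT_DIM_ORDER:
        # Assumption: ranks outside 2..5 need an explicit dim_order.
        raise ValueError(f"no default dim_order for rank {rank}")
    return _DEFAULT_DIM_ORDER[rank]


print(infer_default_dim_order((128, 128)))      # YX
print(infer_default_dim_order((16, 128, 128)))  # ZYX
```

A 3D array is therefore read as `ZYX` unless `dim_order` says otherwise, which is why a channel-first image stack needs an explicit `dim_order="CYX"`.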

```python
from ome_arrow.ingest import from_torch_array, from_jax_array

scalar_torch = from_torch_array(torch_tensor, dim_order="TCYX")
scalar_jax = from_jax_array(jax_array, dim_order="TCYX")
```

Notes:

- Torch/JAX support is optional.
- Install extras as needed:
`pip install "ome-arrow[dlpack-torch]"` or `pip install "ome-arrow[dlpack-jax]"`.
- Torch tensors are detached and converted on CPU for ingest.
- `dim_order` is accepted only for NumPy/torch/JAX array inputs.
- Ingest now passes flattened NumPy pixel buffers directly to Arrow.
- This avoids materializing Python `list` payloads per plane/chunk.

## Benchmarking lazy reads

Use the lightweight benchmark utility in `benchmarks/` to compare lazy tensor
63 changes: 63 additions & 0 deletions docs/src/dlpack.md
@@ -42,6 +42,53 @@ flat = torch.utils.dlpack.from_dlpack(capsule)
tensor = flat.reshape(view.shape)
```

You can also ingest torch tensors directly:

```python
from ome_arrow import OMEArrow
import torch

# 2D tensor interpreted as YX by default.
torch_tensor = torch.randint(0, 256, (128, 128), dtype=torch.uint16)
oa = OMEArrow(torch_tensor)

# 3D tensors are inferred as ZYX by default.
# Use dim_order when your tensor is arranged differently (for example CYX).
torch_volume = torch.randint(0, 256, (16, 128, 128), dtype=torch.uint16)
oa_cyx = OMEArrow(torch_volume, dim_order="CYX")
```

Use `dim_order` when the inferred axis order does not match your tensor layout.
`dim_order` is only supported for array/tensor ingest paths.

To persist this interpreted axis mapping, export the resulting OME-Arrow
record (for example, to Parquet):

```python
from ome_arrow import OMEArrow
import torch

torch_volume = torch.randint(0, 256, (16, 128, 128), dtype=torch.uint16)
oa = OMEArrow(torch_volume, dim_order="ZYX")
oa.export(how="parquet", out="volume.ome.parquet")
```

OME-Arrow stores pixels in canonical OME-style fields (`size_t`, `size_c`,
`size_z`, `size_y`, `size_x`) rather than preserving a free-form input label
string. The interpreted mapping is preserved through those axis sizes and can
be read back with `tensor_view(...)` layouts.
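
As a rough illustration of how a `dim_order` string and an array shape translate into those five size fields, consider the sketch below. `canonical_sizes` is a hypothetical helper, not an `ome_arrow` function, and it assumes that any axis absent from `dim_order` is stored with size 1.

```python
# Hypothetical helper (not part of ome_arrow): maps a shape plus a dim_order
# string onto the five canonical size fields named above.
CANONICAL_AXES = "TCZYX"


def canonical_sizes(shape, dim_order):
    """Return size_t/size_c/size_z/size_y/size_x for the given layout."""
    if len(shape) != len(dim_order):
        raise ValueError("dim_order needs one label per array dimension")
    if len(set(dim_order)) != len(dim_order):
        raise ValueError("dim_order labels must be unique")
    unknown = set(dim_order) - set(CANONICAL_AXES)
    if unknown:
        raise ValueError(f"unsupported axis labels: {sorted(unknown)}")
    by_axis = dict(zip(dim_order, shape))
    # Assumption: axes missing from dim_order default to size 1.
    return {f"size_{axis.lower()}": by_axis.get(axis, 1) for axis in CANONICAL_AXES}


as_zyx = canonical_sizes((16, 128, 128), "ZYX")
as_cyx = canonical_sizes((16, 128, 128), "CYX")
print(as_zyx["size_z"], as_zyx["size_c"])  # 16 1
print(as_cyx["size_z"], as_cyx["size_c"])  # 1 16
```

The same `(16, 128, 128)` array produces different size fields depending on the label string, so the interpretation survives in the record even though the string itself is not stored.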

"Batch" dimension note:

- There is no separate `B` axis in the OME-Arrow schema.
- For model batches, map batch to `T` during ingest.
- Examples:
  - `B,C,Y,X` -> use `dim_order="TCYX"`
  - `B,C,Z,Y,X` -> use `dim_order="TCZYX"`
  - `B,Y,X,C` -> use `dim_order="TYXC"`
- If `T` is already meaningful in your data, represent batch as table rows
  (one OME-Arrow record per batch item) instead of overloading another image axis.
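
The batch-to-`T` mapping above is mechanical enough to sketch as a small string rewrite. `batch_layout_to_dim_order` is a hypothetical name used for illustration; it is not provided by `ome_arrow`.

```python
# Hypothetical helper (not part of ome_arrow): rewrites a model-style layout
# label so that the batch axis "B" is expressed as "T".
def batch_layout_to_dim_order(layout: str) -> str:
    """Convert a label such as "B,C,Y,X" into a dim_order such as "TCYX"."""
    labels = layout.replace(",", "").replace(" ", "").upper()
    if "B" in labels and "T" in labels:
        # T is already meaningful: store batch items as separate records.
        raise ValueError("layout already uses T; keep batch as table rows")
    return labels.replace("B", "T")


print(batch_layout_to_dim_order("B,C,Y,X"))    # TCYX
print(batch_layout_to_dim_order("B,C,Z,Y,X"))  # TCZYX
print(batch_layout_to_dim_order("B,Y,X,C"))    # TYXC
```

The result can then be passed as `dim_order` during ingest. The `ValueError` branch mirrors the last bullet: when `T` is already in use, the sketch refuses to overload it.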

## Lazy scan-style slicing

```python
@@ -76,6 +123,22 @@ flat = jnp.from_dlpack(capsule)
arr = flat.reshape(view.shape)
```

You can also ingest JAX arrays directly:

```python
from ome_arrow import OMEArrow
import jax.numpy as jnp

# 2D array interpreted as YX by default.
jax_array = jnp.arange(128 * 128, dtype=jnp.uint16).reshape(128, 128)
oa = OMEArrow(jax_array)

# 3D arrays are inferred as ZYX by default.
# Use dim_order when your array is arranged differently (for example CYX).
jax_volume = jnp.arange(16 * 128 * 128, dtype=jnp.uint16).reshape(16, 128, 128)
oa_cyx = OMEArrow(jax_volume, dim_order="CYX")
```

## Iteration examples

```python
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -14,4 +14,5 @@ maxdepth: 3
---
python-api
dlpack
examples/learning_to_fly_with_ome-arrow
```
67 changes: 65 additions & 2 deletions docs/src/python-api.md
@@ -1,9 +1,45 @@
# Python API

```{eval-rst}
ome_arrow
-------------------
.. automodule:: ome_arrow
    :members:
    :undoc-members:
    :show-inheritance:
```

```{eval-rst}
ome_arrow.core
-------------------
.. automodule:: ome_arrow.core
    :members:
    :undoc-members:
    :show-inheritance:
```

```{eval-rst}
ome_arrow.ingest
-------------------
.. automodule:: ome_arrow.ingest
    :members:
    :undoc-members:
    :show-inheritance:
```

```{eval-rst}
ome_arrow.export
-------------------
.. automodule:: ome_arrow.export
    :members:
    :undoc-members:
    :show-inheritance:
```

```{eval-rst}
ome_arrow.meta
-------------------
-.. automodule:: src.ome_arrow.meta
+.. automodule:: ome_arrow.meta
    :members:
    :private-members:
    :undoc-members:
@@ -13,7 +49,34 @@ ome_arrow.meta
```{eval-rst}
ome_arrow.tensor
-------------------
-.. automodule:: src.ome_arrow.tensor
+.. automodule:: ome_arrow.tensor
    :members:
    :undoc-members:
    :show-inheritance:
```

```{eval-rst}
ome_arrow.transform
-------------------
.. automodule:: ome_arrow.transform
    :members:
    :undoc-members:
    :show-inheritance:
```

```{eval-rst}
ome_arrow.utils
-------------------
.. automodule:: ome_arrow.utils
    :members:
    :undoc-members:
    :show-inheritance:
```

```{eval-rst}
ome_arrow.view
-------------------
.. automodule:: ome_arrow.view
    :members:
    :undoc-members:
    :show-inheritance:
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -31,11 +31,11 @@ dependencies = [
"pyarrow>=22",
]
optional-dependencies.dlpack = [
-"jax>=0.4",
+"jax>=0.4.1",
"torch>=2.1",
]
optional-dependencies.dlpack-jax = [
-"jax>=0.4",
+"jax>=0.4.1",
]
optional-dependencies.dlpack-torch = [
"torch>=2.1",
2 changes: 2 additions & 0 deletions src/ome_arrow/__init__.py
@@ -12,11 +12,13 @@
to_ome_zarr,
)
from ome_arrow.ingest import (
from_jax_array,
from_numpy,
from_ome_parquet,
from_ome_vortex,
from_ome_zarr,
from_tiff,
from_torch_array,
to_ome_arrow,
)
from ome_arrow.meta import OME_ARROW_STRUCT, OME_ARROW_TAG_TYPE, OME_ARROW_TAG_VERSION