Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ class GraphArrayView(BaseReadOnlyArray):
The attribute key to view as an array.
offset : int | np.ndarray, optional
The offset to apply to the array.
strides : tuple[int, ...] | None, optional
The strides to apply to the array. If None, the default stride of 1 is used.
dtype : np.dtype | None, optional
The dtype of the array. If None, the dtype is inferred from the graph's attribute
chunk_shape : tuple[int] | None, optional
The chunk shape for the array. If None, the default chunk size is used.
buffer_cache_size : int, optional
Expand All @@ -126,18 +130,24 @@ def __init__(
self,
graph: BaseGraph,
shape: tuple[int, ...],
attr_key: str = DEFAULT_ATTR_KEYS.BBOX,
attr_key: str,
*,
dtype: np.dtype | None = None,
offset: int | np.ndarray = 0,
strides: tuple[int, ...] | None = None,
chunk_shape: tuple[int, ...] | int | None = None,
buffer_cache_size: int | None = None,
dtype: np.dtype | None = None,
):
if attr_key not in graph.node_attr_keys:
raise ValueError(f"Attribute key '{attr_key}' not found in graph. Expected '{graph.node_attr_keys}'")

self.graph = graph
self._attr_key = attr_key
self._offset = offset
self._strides = strides if strides is not None else tuple([1] * (len(shape) - 1))
self._original_shape = tuple(
[shape[0]] + [(fs - 1) // st + 1 for fs, st in zip(shape[1:], self._strides, strict=True)]
)

if dtype is None:
# Infer the dtype from the graph's attribute
Expand All @@ -152,18 +162,18 @@ def __init__(
dtype = np.uint8

self._dtype = dtype
self.original_shape = shape

chunk_shape = chunk_shape or get_options().gav_chunk_shape
ndim = len(shape)
if isinstance(chunk_shape, int):
chunk_shape = (chunk_shape,) * (len(shape) - 1)
elif len(chunk_shape) < len(shape) - 1:
chunk_shape = (1,) * (len(shape) - 1 - len(chunk_shape)) + tuple(chunk_shape)
chunk_shape = (chunk_shape,) * (ndim - 1)
elif len(chunk_shape) < ndim - 1:
chunk_shape = (1,) * (ndim - 1 - len(chunk_shape)) + tuple(chunk_shape)

self.chunk_shape = chunk_shape
self.buffer_cache_size = buffer_cache_size or get_options().gav_buffer_cache_size

self._indices = tuple(slice(0, s) for s in shape)
self._indices = tuple(slice(0, s) for s in self._original_shape)
self._cache = NDChunkCache(
compute_func=self._fill_array,
shape=self.shape[1:],
Expand All @@ -181,7 +191,7 @@ def __init__(
def shape(self) -> tuple[int, ...]:
"""Returns the shape of the array."""

shape = [_get_size(ind, os) for ind, os in zip(self._indices, self.original_shape, strict=True)]
shape = [_get_size(ind, os) for ind, os in zip(self._indices, self._original_shape, strict=True)]
return tuple(s for s in shape if s is not None)

@property
Expand Down Expand Up @@ -274,18 +284,16 @@ def __array__(
volume_slicing = self._indices[1:]

if np.isscalar(time):
try:
if hasattr(time, "item"):
time = time.item() # convert from numpy.int to int
except AttributeError:
pass
result = self._cache.get(
time=time,
volume_slicing=volume_slicing,
).astype(dtype or self.dtype)
return np.array(result) if np.isscalar(result) else result
else:
if isinstance(time, slice):
time = range(self.original_shape[0])[time]
time = range(self._original_shape[0])[time]

return np.stack(
[
Expand Down Expand Up @@ -314,11 +322,19 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda
np.ndarray
The filled buffer.
"""
volume_slicing = [
slice(
s.start * st if s.start else s.start,
s.stop * st if s.stop else s.stop,
s.step * st if s.step else s.step,
)
for s, st in zip(volume_slicing, self._strides, strict=True)
]
subgraph = self._spatial_filter[(slice(time, time), *volume_slicing)]
df = subgraph.node_attrs(
attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK],
)

for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True):
mask: Mask
mask.paint_buffer(buffer, value, offset=self._offset)
mask.paint_buffer(buffer, value, offset=self._offset, strides=self._strides)
142 changes: 142 additions & 0 deletions src/tracksdata/array/_test/test_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,145 @@ def test_graph_array_set_options() -> None:
array_view = GraphArrayView(graph=empty_graph, shape=(10, 100, 100), attr_key="label")
assert array_view.chunk_shape == (512, 512)
assert array_view.dtype == np.int16


def test_graph_array_view_with_strides_downsamples_data() -> None:
graph = RustWorkXGraph()
graph.add_node_attr_key("label", 0)
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None)

shape = (3, 7, 8)
strides = (2, 3)
reference = np.zeros(shape, dtype=np.uint8)

nodes = [
{
"time": 0,
"start": (0, 0),
"mask": np.array([[True, False, True], [False, True, True]], dtype=bool),
"value": 5,
},
{
"time": 1,
"start": (2, 1),
"mask": np.array([[True, True, False], [False, True, True]], dtype=bool),
"value": 7,
},
{
"time": 2,
"start": (3, 3),
"mask": np.array([[True, True, False], [True, False, True]], dtype=bool),
"value": 9,
},
]

for node in nodes:
start = node["start"]
mask_data = node["mask"]
bbox = np.array(
[
*start,
*(coord + size for coord, size in zip(start, mask_data.shape, strict=True)),
]
)
mask = Mask(mask_data, bbox=bbox)
graph.add_node(
{
DEFAULT_ATTR_KEYS.T: node["time"],
"label": node["value"],
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)
region = tuple(slice(coord, coord + size) for coord, size in zip(start, mask_data.shape, strict=True))
reference[(node["time"], *region)][mask_data] = node["value"]

array_view = GraphArrayView(
graph=graph,
shape=shape,
attr_key="label",
strides=strides,
)

expected = reference[(slice(None), *(slice(None, None, step) for step in strides))]

assert array_view.shape == expected.shape
assert len(array_view) == expected.shape[0]
np.testing.assert_array_equal(np.asarray(array_view), expected)
np.testing.assert_array_equal(np.asarray(array_view[1]), expected[1])
np.testing.assert_array_equal(np.asarray(array_view[:, 1:, :]), expected[:, 1:, :])
np.testing.assert_array_equal(np.asarray(array_view[::2, 0]), expected[::2, 0])


def test_graph_array_view_with_strides_3d_data() -> None:
graph = RustWorkXGraph()
graph.add_node_attr_key("label", 0)
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None)
graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None)

shape = (2, 4, 6, 6)
strides = (2, 3, 2)
reference = np.zeros(shape, dtype=np.uint8)

nodes = [
{
"time": 0,
"start": (0, 1, 1),
"mask": np.array(
[
[[True, False, True], [False, True, False]],
[[True, True, False], [False, False, True]],
],
dtype=bool,
),
"value": 4,
},
{
"time": 1,
"start": (2, 3, 2),
"mask": np.array(
[
[[True, True], [True, False], [False, True]],
[[True, False], [True, True], [True, True]],
],
dtype=bool,
),
"value": 8,
},
]

for node in nodes:
start = node["start"]
mask_data = node["mask"]
bbox = np.array(
[
*start,
*(coord + size for coord, size in zip(start, mask_data.shape, strict=True)),
]
)
mask = Mask(mask_data, bbox=bbox)
graph.add_node(
{
DEFAULT_ATTR_KEYS.T: node["time"],
"label": node["value"],
DEFAULT_ATTR_KEYS.MASK: mask,
DEFAULT_ATTR_KEYS.BBOX: mask.bbox,
}
)
region = tuple(slice(coord, coord + size) for coord, size in zip(start, mask_data.shape, strict=True))
reference[(node["time"], *region)][mask_data] = node["value"]

array_view = GraphArrayView(
graph=graph,
shape=shape,
attr_key="label",
strides=strides,
)

expected = reference[(slice(None), *(slice(None, None, step) for step in strides))]

assert array_view.shape == expected.shape
np.testing.assert_array_equal(np.asarray(array_view), expected)
np.testing.assert_array_equal(np.asarray(array_view[:, 1]), expected[:, 1])
np.testing.assert_array_equal(np.asarray(array_view[:, :, :, 1:]), expected[:, :, :, 1:])
54 changes: 48 additions & 6 deletions src/tracksdata/nodes/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def paint_buffer(
buffer: np.ndarray,
value: int | float,
offset: NDArray[np.integer] | int = 0,
strides: NDArray[np.integer] | int = 1,
) -> None:
"""
Paint object into a buffer.
Expand All @@ -169,15 +170,56 @@ def paint_buffer(
The value to paint the object.
offset : NDArray[np.integer] | int, optional
The offset to add to the indices, should be used with bounding box information.
strides : NDArray[np.integer] | int, optional
The strides to apply to the mask when painting.
"""
if isinstance(offset, int):
offset = np.full(self._mask.ndim, offset)
offset = np.full(self._mask.ndim, offset, dtype=np.int64)
else:
offset = np.asarray(offset, dtype=np.int64)

window = tuple(
slice(i + o, j + o)
for i, j, o in zip(self._bbox[: self._mask.ndim], self._bbox[self._mask.ndim :], offset, strict=True)
)
buffer[window][self._mask] = value
if isinstance(strides, int):
strides = np.full(self._mask.ndim, strides, dtype=np.int64)
else:
strides = np.asarray(strides, dtype=np.int64)

windows: list[slice] = []
mask_slices: list[slice] = []

for start, stop, off, stride in zip(
self._bbox[: self._mask.ndim],
self._bbox[self._mask.ndim :],
offset,
strides,
strict=True,
):
if stride <= 0:
raise ValueError("Strides must be positive when painting a buffer.")

global_start = start + off
global_stop = stop + off

# Find the first coordinate within [global_start, global_stop)
# that aligns with the stride grid starting at zero.
slice_start = (-global_start) % stride
first_coord = global_start + slice_start

if first_coord >= global_stop:
# Nothing aligns with the stride grid along this axis.
return

window_start = first_coord // stride
window_stop = (global_stop - 1) // stride + 1

windows.append(slice(window_start, window_stop))
mask_slices.append(slice(slice_start, None, stride))

sliced_mask = self._mask[tuple(mask_slices)]

if sliced_mask.size == 0 or not np.any(sliced_mask):
return

buffer[tuple(windows)][sliced_mask] = value

def iou(self, other: "Mask") -> float:
"""
Expand Down
Loading