diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 80f6d536..d27b56b6 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -115,6 +115,10 @@ class GraphArrayView(BaseReadOnlyArray): The attribute key to view as an array. offset : int | np.ndarray, optional The offset to apply to the array. + strides : tuple[int, ...] | None, optional + The strides to apply to the array. If None, the default stride of 1 is used. + dtype : np.dtype | None, optional + The dtype of the array. If None, the dtype is inferred from the graph's attribute chunk_shape : tuple[int] | None, optional The chunk shape for the array. If None, the default chunk size is used. buffer_cache_size : int, optional @@ -126,11 +130,13 @@ def __init__( self, graph: BaseGraph, shape: tuple[int, ...], - attr_key: str = DEFAULT_ATTR_KEYS.BBOX, + attr_key: str, + *, + dtype: np.dtype | None = None, offset: int | np.ndarray = 0, + strides: tuple[int, ...] | None = None, chunk_shape: tuple[int, ...] | int | None = None, buffer_cache_size: int | None = None, - dtype: np.dtype | None = None, ): if attr_key not in graph.node_attr_keys: raise ValueError(f"Attribute key '{attr_key}' not found in graph. Expected '{graph.node_attr_keys}'") @@ -138,6 +144,10 @@ def __init__( self.graph = graph self._attr_key = attr_key self._offset = offset + self._strides = strides if strides is not None else tuple([1] * (len(shape) - 1)) + self._original_shape = tuple( + [shape[0]] + [(fs - 1) // st + 1 for fs, st in zip(shape[1:], self._strides, strict=True)] + ) if dtype is None: # Infer the dtype from the graph's attribute @@ -152,18 +162,18 @@ def __init__( dtype = np.uint8 self._dtype = dtype - self.original_shape = shape chunk_shape = chunk_shape or get_options().gav_chunk_shape + ndim = len(shape) if isinstance(chunk_shape, int): - chunk_shape = (chunk_shape,) * (len(shape) - 1) - elif len(chunk_shape) < len(shape) - 1: - chunk_shape = (1,) * (len(shape) - 1 - len(chunk_shape)) + tuple(chunk_shape) + chunk_shape = (chunk_shape,) * (ndim - 1) + elif len(chunk_shape) < ndim - 1: + chunk_shape = (1,) * (ndim - 1 - len(chunk_shape)) + tuple(chunk_shape) self.chunk_shape = chunk_shape self.buffer_cache_size = buffer_cache_size or get_options().gav_buffer_cache_size - self._indices = tuple(slice(0, s) for s in shape) + self._indices = tuple(slice(0, s) for s in self._original_shape) self._cache = NDChunkCache( compute_func=self._fill_array, shape=self.shape[1:], @@ -181,7 +191,7 @@ def __init__( def shape(self) -> tuple[int, ...]: """Returns the shape of the array.""" - shape = [_get_size(ind, os) for ind, os in zip(self._indices, self.original_shape, strict=True)] + shape = [_get_size(ind, os) for ind, os in zip(self._indices, self._original_shape, strict=True)] return tuple(s for s in shape if s is not None) @property @@ -274,10 +284,8 @@ def __array__( volume_slicing = self._indices[1:] if np.isscalar(time): - try: + if hasattr(time, "item"): time = time.item() # convert from numpy.int to int - except AttributeError: - pass result = self._cache.get( time=time, volume_slicing=volume_slicing, @@ -285,7 +293,7 @@ def __array__( return np.array(result) if np.isscalar(result) else result else: if isinstance(time, slice): - time = range(self.original_shape[0])[time] + time = range(self._original_shape[0])[time] return np.stack( [ @@ -314,6 +322,14 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda np.ndarray The filled buffer. """ + volume_slicing = [ + slice( + s.start * st if s.start else s.start, + s.stop * st if s.stop else s.stop, + s.step * st if s.step else s.step, + ) + for s, st in zip(volume_slicing, self._strides, strict=True) + ] subgraph = self._spatial_filter[(slice(time, time), *volume_slicing)] df = subgraph.node_attrs( attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK], @@ -321,4 +337,4 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True): mask: Mask - mask.paint_buffer(buffer, value, offset=self._offset) + mask.paint_buffer(buffer, value, offset=self._offset, strides=self._strides) diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index 59350598..fef660de 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -347,3 +347,145 @@ def test_graph_array_set_options() -> None: array_view = GraphArrayView(graph=empty_graph, shape=(10, 100, 100), attr_key="label") assert array_view.chunk_shape == (512, 512) assert array_view.dtype == np.int16 + + +def test_graph_array_view_with_strides_downsamples_data() -> None: + graph = RustWorkXGraph() + graph.add_node_attr_key("label", 0) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) + + shape = (3, 7, 8) + strides = (2, 3) + reference = np.zeros(shape, dtype=np.uint8) + + nodes = [ + { + "time": 0, + "start": (0, 0), + "mask": np.array([[True, False, True], [False, True, True]], dtype=bool), + "value": 5, + }, + { + "time": 1, + "start": (2, 1), + "mask": np.array([[True, True, False], [False, True, True]], dtype=bool), + "value": 7, + }, + { + "time": 2, + "start": (3, 3), + "mask": np.array([[True, True, False], [True, False, True]], dtype=bool), + "value": 9, + }, + ] + + for node in nodes: + start = node["start"] + mask_data = node["mask"] + bbox = np.array( + [ + *start, + *(coord + size for coord, size in zip(start, mask_data.shape, strict=True)), + ] + ) + mask = Mask(mask_data, bbox=bbox) + graph.add_node( + { + DEFAULT_ATTR_KEYS.T: node["time"], + "label": node["value"], + DEFAULT_ATTR_KEYS.MASK: mask, + DEFAULT_ATTR_KEYS.BBOX: mask.bbox, + } + ) + region = tuple(slice(coord, coord + size) for coord, size in zip(start, mask_data.shape, strict=True)) + reference[(node["time"], *region)][mask_data] = node["value"] + + array_view = GraphArrayView( + graph=graph, + shape=shape, + attr_key="label", + strides=strides, + ) + + expected = reference[(slice(None), *(slice(None, None, step) for step in strides))] + + assert array_view.shape == expected.shape + assert len(array_view) == expected.shape[0] + np.testing.assert_array_equal(np.asarray(array_view), expected) + np.testing.assert_array_equal(np.asarray(array_view[1]), expected[1]) + np.testing.assert_array_equal(np.asarray(array_view[:, 1:, :]), expected[:, 1:, :]) + np.testing.assert_array_equal(np.asarray(array_view[::2, 0]), expected[::2, 0]) + + +def test_graph_array_view_with_strides_3d_data() -> None: + graph = RustWorkXGraph() + graph.add_node_attr_key("label", 0) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) + + shape = (2, 4, 6, 6) + strides = (2, 3, 2) + reference = np.zeros(shape, dtype=np.uint8) + + nodes = [ + { + "time": 0, + "start": (0, 1, 1), + "mask": np.array( + [ + [[True, False, True], [False, True, False]], + [[True, True, False], [False, False, True]], + ], + dtype=bool, + ), + "value": 4, + }, + { + "time": 1, + "start": (2, 3, 2), + "mask": np.array( + [ + [[True, True], [True, False], [False, True]], + [[True, False], [True, True], [True, True]], + ], + dtype=bool, + ), + "value": 8, + }, + ] + + for node in nodes: + start = node["start"] + mask_data = node["mask"] + bbox = np.array( + [ + *start, + *(coord + size for coord, size in zip(start, mask_data.shape, strict=True)), + ] + ) + mask = Mask(mask_data, bbox=bbox) + graph.add_node( + { + DEFAULT_ATTR_KEYS.T: node["time"], + "label": node["value"], + DEFAULT_ATTR_KEYS.MASK: mask, + DEFAULT_ATTR_KEYS.BBOX: mask.bbox, + } + ) + region = tuple(slice(coord, coord + size) for coord, size in zip(start, mask_data.shape, strict=True)) + reference[(node["time"], *region)][mask_data] = node["value"] + + array_view = GraphArrayView( + graph=graph, + shape=shape, + attr_key="label", + strides=strides, + ) + + expected = reference[(slice(None), *(slice(None, None, step) for step in strides))] + + assert array_view.shape == expected.shape + np.testing.assert_array_equal(np.asarray(array_view), expected) + np.testing.assert_array_equal(np.asarray(array_view[:, 1]), expected[:, 1]) + np.testing.assert_array_equal(np.asarray(array_view[:, :, :, 1:]), expected[:, :, :, 1:]) diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index 43ec6b96..bdc55332 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -157,6 +157,7 @@ def paint_buffer( buffer: np.ndarray, value: int | float, offset: NDArray[np.integer] | int = 0, + strides: NDArray[np.integer] | int = 1, ) -> None: """ Paint object into a buffer. @@ -169,15 +170,56 @@ def paint_buffer( The value to paint the object. offset : NDArray[np.integer] | int, optional The offset to add to the indices, should be used with bounding box information. + strides : NDArray[np.integer] | int, optional + The strides to apply to the mask when painting. """ if isinstance(offset, int): - offset = np.full(self._mask.ndim, offset) + offset = np.full(self._mask.ndim, offset, dtype=np.int64) + else: + offset = np.asarray(offset, dtype=np.int64) - window = tuple( - slice(i + o, j + o) - for i, j, o in zip(self._bbox[: self._mask.ndim], self._bbox[self._mask.ndim :], offset, strict=True) - ) - buffer[window][self._mask] = value + if isinstance(strides, int): + strides = np.full(self._mask.ndim, strides, dtype=np.int64) + else: + strides = np.asarray(strides, dtype=np.int64) + + windows: list[slice] = [] + mask_slices: list[slice] = [] + + for start, stop, off, stride in zip( + self._bbox[: self._mask.ndim], + self._bbox[self._mask.ndim :], + offset, + strides, + strict=True, + ): + if stride <= 0: + raise ValueError("Strides must be positive when painting a buffer.") + + global_start = start + off + global_stop = stop + off + + # Find the first coordinate within [global_start, global_stop) + # that aligns with the stride grid starting at zero. + slice_start = (-global_start) % stride + first_coord = global_start + slice_start + + if first_coord >= global_stop: + # Nothing aligns with the stride grid along this axis. + return + + window_start = first_coord // stride + window_stop = (global_stop - 1) // stride + 1 + + windows.append(slice(window_start, window_stop)) + mask_slices.append(slice(slice_start, None, stride)) + + sliced_mask = self._mask[tuple(mask_slices)] + + if sliced_mask.size == 0 or not np.any(sliced_mask): + return + + buffer[tuple(windows)][sliced_mask] = value def iou(self, other: "Mask") -> float: """