From a5ba9943309336d7d2dee609d45236c84fb89752 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 1 Oct 2025 22:08:56 +0900 Subject: [PATCH 1/4] added strides to graph array view --- src/tracksdata/array/_graph_array.py | 42 +++++++++++++------ .../array/_test/test_graph_array.py | 18 ++++---- src/tracksdata/io/_ctc.py | 2 +- src/tracksdata/nodes/_mask.py | 14 +++++-- 4 files changed, 50 insertions(+), 26 deletions(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index 4a35606e..45706c8b 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -115,6 +115,10 @@ class GraphArrayView(BaseReadOnlyArray): The attribute key to view as an array. offset : int | np.ndarray, optional The offset to apply to the array. + strides : tuple[int, ...] | None, optional + The strides to apply to the array. If None, the default stride of 1 is used. + dtype : np.dtype | None, optional + The dtype of the array. If None, the dtype is inferred from the graph's attribute chunk_shape : tuple[int] | None, optional The chunk shape for the array. If None, the default chunk size is used. buffer_cache_size : int, optional @@ -125,12 +129,14 @@ class GraphArrayView(BaseReadOnlyArray): def __init__( self, graph: BaseGraph, - shape: tuple[int, ...], - attr_key: str = DEFAULT_ATTR_KEYS.BBOX, + full_shape: tuple[int, ...], + attr_key: str, + *, + dtype: np.dtype | None = None, offset: int | np.ndarray = 0, + strides: tuple[int, ...] | None = None, chunk_shape: tuple[int, ...] | int | None = None, buffer_cache_size: int | None = None, - dtype: np.dtype | None = None, ): if attr_key not in graph.node_attr_keys: raise ValueError(f"Attribute key '{attr_key}' not found in graph. Expected '{graph.node_attr_keys}'") @@ -138,6 +144,10 @@ def __init__( self.graph = graph self._attr_key = attr_key self._offset = offset + self._strides = strides if strides is not None else tuple([1] * (len(full_shape) - 1)) + self._original_shape = tuple( + [full_shape[0]] + [fs // st for fs, st in zip(full_shape[1:], strides, strict=True)] + ) if dtype is None: # Infer the dtype from the graph's attribute @@ -152,18 +162,18 @@ def __init__( dtype = np.uint8 self._dtype = dtype - self.original_shape = shape chunk_shape = chunk_shape or get_options().gav_chunk_shape + ndim = len(full_shape) if isinstance(chunk_shape, int): - chunk_shape = (chunk_shape,) * (len(shape) - 1) - elif len(chunk_shape) < len(shape) - 1: - chunk_shape = (1,) * (len(shape) - 1 - len(chunk_shape)) + tuple(chunk_shape) + chunk_shape = (chunk_shape,) * (ndim - 1) + elif len(chunk_shape) < ndim - 1: + chunk_shape = (1,) * (ndim - 1 - len(chunk_shape)) + tuple(chunk_shape) self.chunk_shape = chunk_shape self.buffer_cache_size = buffer_cache_size or get_options().gav_buffer_cache_size - self._indices = tuple(slice(0, s) for s in shape) + self._indices = tuple(slice(0, s) for s in self._original_shape) self._cache = NDChunkCache( compute_func=self._fill_array, shape=self.shape[1:], @@ -181,7 +191,7 @@ def __init__( def shape(self) -> tuple[int, ...]: """Returns the shape of the array.""" - shape = [_get_size(ind, os) for ind, os in zip(self._indices, self.original_shape, strict=True)] + shape = [_get_size(ind, os) for ind, os in zip(self._indices, self._original_shape, strict=True)] return tuple(s for s in shape if s is not None) @property @@ -274,17 +284,15 @@ def __array__( volume_slicing = self._indices[1:] if np.isscalar(time): - try: + if hasattr(time, "item"): time = time.item() # convert from numpy.int to int - except AttributeError: - pass return self._cache.get( time=time, volume_slicing=volume_slicing, ).astype(dtype or self.dtype) else: if isinstance(time, slice): - time = range(self.original_shape[0])[time] + time = range(self._original_shape[0])[time] return np.stack( [ @@ -313,6 +321,14 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda np.ndarray The filled buffer. """ + volume_slicing = [ + slice( + s.start * st if s.start else s.start, + s.stop * st if s.stop else s.stop, + s.step * st if s.step else s.step, + ) + for s, st in zip(volume_slicing, self._strides, strict=True) + ] subgraph = self._spatial_filter[(slice(time, time), *volume_slicing)] df = subgraph.node_attrs( attr_keys=[self._attr_key, DEFAULT_ATTR_KEYS.MASK], diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index b281e1cb..fd72262b 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -35,7 +35,7 @@ def test_graph_array_view_init() -> None: # Add a attribute key graph.add_node_attr_key("label", 0) - array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label", offset=0) + array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="label", offset=0) assert array_view.graph is graph assert array_view.shape == (10, 100, 100) @@ -51,7 +51,7 @@ def test_graph_array_view_init_invalid_attr_key() -> None: graph = RustWorkXGraph() with pytest.raises(ValueError, match="Attribute key 'invalid_key' not found in graph"): - GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="invalid_key") + GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="invalid_key") def test_graph_array_view_getitem_empty_time() -> None: @@ -59,7 +59,7 @@ def test_graph_array_view_getitem_empty_time() -> None: graph = RustWorkXGraph() graph.add_node_attr_key("label", 0) - array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label") + array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="label") # Get data for time point 0 (no nodes) result = array_view[0] @@ -97,7 +97,7 @@ def test_graph_array_view_getitem_with_nodes() -> None: } ) - array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label") + array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="label") # Get data for time point 0 result = array_view[0] @@ -158,7 +158,7 @@ def test_graph_array_view_getitem_multiple_nodes() -> None: } ) - array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label") + array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="label") # Get data for time point 0 result = array_view[0] @@ -200,7 +200,7 @@ def test_graph_array_view_getitem_boolean_dtype() -> None: } ) - array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="is_active") + array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="is_active") # Get data for time point 0 result = array_view[0] @@ -238,7 +238,7 @@ def test_graph_array_view_dtype_inference() -> None: } ) - array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="float_label") + array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="float_label") # Get data to trigger dtype inference _ = array_view[0] @@ -262,7 +262,7 @@ def multi_node_graph_from_image(request) -> GraphArrayView: graph = RustWorkXGraph() nodes_operator = RegionPropsNodes(extra_properties=["label"]) nodes_operator.add_nodes(graph, labels=label) - return GraphArrayView(graph=graph, shape=shape, attr_key="label"), label + return GraphArrayView(graph=graph, full_shape=shape, attr_key="label"), label def test_graph_array_view_equal(multi_node_graph_from_image) -> None: @@ -340,6 +340,6 @@ def test_graph_array_set_options() -> None: with Options(gav_chunk_shape=(512, 512), gav_default_dtype=np.int16): empty_graph = RustWorkXGraph() empty_graph.add_node_attr_key("label", 0) - array_view = GraphArrayView(graph=empty_graph, shape=(10, 100, 100), attr_key="label") + array_view = GraphArrayView(graph=empty_graph, full_shape=(10, 100, 100), attr_key="label") assert array_view.chunk_shape == (512, 512) assert array_view.dtype == np.int16 diff --git a/src/tracksdata/io/_ctc.py b/src/tracksdata/io/_ctc.py index 2ce609e5..70aa7d8d 100644 --- a/src/tracksdata/io/_ctc.py +++ b/src/tracksdata/io/_ctc.py @@ -255,7 +255,7 @@ def to_ctc( output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - view = GraphArrayView(graph, shape=shape, attr_key=track_id_key) + view = GraphArrayView(graph, full_shape=shape, attr_key=track_id_key) n_digits = max(len(str(view.shape[0])), 3) diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index ce50701c..a15c8e55 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -156,6 +156,7 @@ def paint_buffer( buffer: np.ndarray, value: int | float, offset: NDArray[np.integer] | int = 0, + strides: NDArray[np.integer] | int = 1, ) -> None: """ Paint object into a buffer. @@ -168,15 +169,22 @@ def paint_buffer( The value to paint the object. offset : NDArray[np.integer] | int, optional The offset to add to the indices, should be used with bounding box information. + strides : NDArray[np.integer] | int, optional + The strides to apply to the mask when painting. """ if isinstance(offset, int): offset = np.full(self._mask.ndim, offset) + if isinstance(strides, int): + strides = np.full(self._mask.ndim, strides) window = tuple( - slice(i + o, j + o) - for i, j, o in zip(self._bbox[: self._mask.ndim], self._bbox[self._mask.ndim :], offset, strict=True) + slice((i + o) // st, (j + o) // st) + for i, j, o, st in zip( + self._bbox[: self._mask.ndim], self._bbox[self._mask.ndim :], offset, strides, strict=True + ) ) - buffer[window][self._mask] = value + mask_slices = tuple(slice(None, None, st) for st in strides) + buffer[window][self._mask[mask_slices]] = value def iou(self, other: "Mask") -> float: """ From b926525b67fd79766ad3c1f825bba70777ec7b3a Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 1 Oct 2025 22:56:08 +0900 Subject: [PATCH 2/4] test passing --- src/tracksdata/array/_graph_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index b831b395..d6f7b184 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -146,7 +146,7 @@ def __init__( self._offset = offset self._strides = strides if strides is not None else tuple([1] * (len(full_shape) - 1)) self._original_shape = tuple( - [full_shape[0]] + [fs // st for fs, st in zip(full_shape[1:], strides, strict=True)] + [full_shape[0]] + [fs // st for fs, st in zip(full_shape[1:], self._strides, strict=True)] ) if dtype is None: From fbcaca5cd344e97cf1d4dea5d9cc6379e5e6cdde Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Mon, 6 Oct 2025 17:17:49 +0900 Subject: [PATCH 3/4] fixed bug and tested strides --- src/tracksdata/array/_graph_array.py | 4 +- .../array/_test/test_graph_array.py | 142 ++++++++++++++++++ src/tracksdata/nodes/_mask.py | 54 +++++-- 3 files changed, 188 insertions(+), 12 deletions(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index d6f7b184..aa4a67eb 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -146,7 +146,7 @@ def __init__( self._offset = offset self._strides = strides if strides is not None else tuple([1] * (len(full_shape) - 1)) self._original_shape = tuple( - [full_shape[0]] + [fs // st for fs, st in zip(full_shape[1:], self._strides, strict=True)] + [full_shape[0]] + [(fs - 1) // st + 1 for fs, st in zip(full_shape[1:], self._strides, strict=True)] ) if dtype is None: @@ -337,4 +337,4 @@ def _fill_array(self, time: int, volume_slicing: Sequence[slice], buffer: np.nda for mask, value in zip(df[DEFAULT_ATTR_KEYS.MASK], df[self._attr_key], strict=True): mask: Mask - mask.paint_buffer(buffer, value, offset=self._offset) + mask.paint_buffer(buffer, value, offset=self._offset, strides=self._strides) diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index aa46b5fe..723892cb 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -347,3 +347,145 @@ def test_graph_array_set_options() -> None: array_view = GraphArrayView(graph=empty_graph, full_shape=(10, 100, 100), attr_key="label") assert array_view.chunk_shape == (512, 512) assert array_view.dtype == np.int16 + + +def test_graph_array_view_with_strides_downsamples_data() -> None: + graph = RustWorkXGraph() + graph.add_node_attr_key("label", 0) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) + + full_shape = (3, 7, 8) + strides = (2, 3) + reference = np.zeros(full_shape, dtype=np.uint8) + + nodes = [ + { + "time": 0, + "start": (0, 0), + "mask": np.array([[True, False, True], [False, True, True]], dtype=bool), + "value": 5, + }, + { + "time": 1, + "start": (2, 1), + "mask": np.array([[True, True, False], [False, True, True]], dtype=bool), + "value": 7, + }, + { + "time": 2, + "start": (3, 3), + "mask": np.array([[True, True, False], [True, False, True]], dtype=bool), + "value": 9, + }, + ] + + for node in nodes: + start = node["start"] + mask_data = node["mask"] + bbox = np.array( + [ + *start, + *(coord + size for coord, size in zip(start, mask_data.shape, strict=True)), + ] + ) + mask = Mask(mask_data, bbox=bbox) + graph.add_node( + { + DEFAULT_ATTR_KEYS.T: node["time"], + "label": node["value"], + DEFAULT_ATTR_KEYS.MASK: mask, + DEFAULT_ATTR_KEYS.BBOX: mask.bbox, + } + ) + region = tuple(slice(coord, coord + size) for coord, size in zip(start, mask_data.shape, strict=True)) + reference[(node["time"], *region)][mask_data] = node["value"] + + array_view = GraphArrayView( + graph=graph, + full_shape=full_shape, + attr_key="label", + strides=strides, + ) + + expected = reference[(slice(None), *(slice(None, None, step) for step in strides))] + + assert array_view.shape == expected.shape + assert len(array_view) == expected.shape[0] + np.testing.assert_array_equal(np.asarray(array_view), expected) + np.testing.assert_array_equal(np.asarray(array_view[1]), expected[1]) + np.testing.assert_array_equal(np.asarray(array_view[:, 1:, :]), expected[:, 1:, :]) + np.testing.assert_array_equal(np.asarray(array_view[::2, 0]), expected[::2, 0]) + + +def test_graph_array_view_with_strides_3d_data() -> None: + graph = RustWorkXGraph() + graph.add_node_attr_key("label", 0) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) + graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) + + full_shape = (2, 4, 6, 6) + strides = (2, 3, 2) + reference = np.zeros(full_shape, dtype=np.uint8) + + nodes = [ + { + "time": 0, + "start": (0, 1, 1), + "mask": np.array( + [ + [[True, False, True], [False, True, False]], + [[True, True, False], [False, False, True]], + ], + dtype=bool, + ), + "value": 4, + }, + { + "time": 1, + "start": (2, 3, 2), + "mask": np.array( + [ + [[True, True], [True, False], [False, True]], + [[True, False], [True, True], [True, True]], + ], + dtype=bool, + ), + "value": 8, + }, + ] + + for node in nodes: + start = node["start"] + mask_data = node["mask"] + bbox = np.array( + [ + *start, + *(coord + size for coord, size in zip(start, mask_data.shape, strict=True)), + ] + ) + mask = Mask(mask_data, bbox=bbox) + graph.add_node( + { + DEFAULT_ATTR_KEYS.T: node["time"], + "label": node["value"], + DEFAULT_ATTR_KEYS.MASK: mask, + DEFAULT_ATTR_KEYS.BBOX: mask.bbox, + } + ) + region = tuple(slice(coord, coord + size) for coord, size in zip(start, mask_data.shape, strict=True)) + reference[(node["time"], *region)][mask_data] = node["value"] + + array_view = GraphArrayView( + graph=graph, + full_shape=full_shape, + attr_key="label", + strides=strides, + ) + + expected = reference[(slice(None), *(slice(None, None, step) for step in strides))] + + assert array_view.shape == expected.shape + np.testing.assert_array_equal(np.asarray(array_view), expected) + np.testing.assert_array_equal(np.asarray(array_view[:, 1]), expected[:, 1]) + np.testing.assert_array_equal(np.asarray(array_view[:, :, :, 1:]), expected[:, :, :, 1:]) diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index 95ecbdce..0317b2f8 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -174,18 +174,52 @@ def paint_buffer( The strides to apply to the mask when painting. """ if isinstance(offset, int): - offset = np.full(self._mask.ndim, offset) + offset = np.full(self._mask.ndim, offset, dtype=np.int64) + else: + offset = np.asarray(offset, dtype=np.int64) + if isinstance(strides, int): - strides = np.full(self._mask.ndim, strides) + strides = np.full(self._mask.ndim, strides, dtype=np.int64) + else: + strides = np.asarray(strides, dtype=np.int64) - window = tuple( - slice((i + o) // st, (j + o) // st) - for i, j, o, st in zip( - self._bbox[: self._mask.ndim], self._bbox[self._mask.ndim :], offset, strides, strict=True - ) - ) - mask_slices = tuple(slice(None, None, st) for st in strides) - buffer[window][self._mask[mask_slices]] = value + windows: list[slice] = [] + mask_slices: list[slice] = [] + + for start, stop, off, stride in zip( + self._bbox[: self._mask.ndim], + self._bbox[self._mask.ndim :], + offset, + strides, + strict=True, + ): + if stride <= 0: + raise ValueError("Strides must be positive when painting a buffer.") + + global_start = start + off + global_stop = stop + off + + # Find the first coordinate within [global_start, global_stop) + # that aligns with the stride grid starting at zero. + slice_start = (-global_start) % stride + first_coord = global_start + slice_start + + if first_coord >= global_stop: + # Nothing aligns with the stride grid along this axis. + return + + window_start = first_coord // stride + window_stop = (global_stop - 1) // stride + 1 + + windows.append(slice(window_start, window_stop)) + mask_slices.append(slice(slice_start, None, stride)) + + sliced_mask = self._mask[tuple(mask_slices)] + + if sliced_mask.size == 0 or not np.any(sliced_mask): + return + + buffer[tuple(windows)][sliced_mask] = value def iou(self, other: "Mask") -> float: """ From 1a67f006dbf5902183227fea049e1cd5894bee19 Mon Sep 17 00:00:00 2001 From: Yohsuke Fukai Date: Wed, 15 Oct 2025 10:17:10 +0900 Subject: [PATCH 4/4] renamed full_shape to shape --- src/tracksdata/array/_graph_array.py | 8 ++--- .../array/_test/test_graph_array.py | 30 +++++++++---------- src/tracksdata/io/_ctc.py | 2 +- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/tracksdata/array/_graph_array.py b/src/tracksdata/array/_graph_array.py index aa4a67eb..d27b56b6 100644 --- a/src/tracksdata/array/_graph_array.py +++ b/src/tracksdata/array/_graph_array.py @@ -129,7 +129,7 @@ class GraphArrayView(BaseReadOnlyArray): def __init__( self, graph: BaseGraph, - full_shape: tuple[int, ...], + shape: tuple[int, ...], attr_key: str, *, dtype: np.dtype | None = None, @@ -144,9 +144,9 @@ def __init__( self.graph = graph self._attr_key = attr_key self._offset = offset - self._strides = strides if strides is not None else tuple([1] * (len(full_shape) - 1)) + self._strides = strides if strides is not None else tuple([1] * (len(shape) - 1)) self._original_shape = tuple( - [full_shape[0]] + [(fs - 1) // st + 1 for fs, st in zip(full_shape[1:], self._strides, strict=True)] + [shape[0]] + [(fs - 1) // st + 1 for fs, st in zip(shape[1:], self._strides, strict=True)] ) if dtype is None: @@ -164,7 +164,7 @@ def __init__( self._dtype = dtype chunk_shape = chunk_shape or get_options().gav_chunk_shape - ndim = len(full_shape) + ndim = len(shape) if isinstance(chunk_shape, int): chunk_shape = (chunk_shape,) * (ndim - 1) elif len(chunk_shape) < ndim - 1: diff --git a/src/tracksdata/array/_test/test_graph_array.py b/src/tracksdata/array/_test/test_graph_array.py index 723892cb..fef660de 100644 --- a/src/tracksdata/array/_test/test_graph_array.py +++ b/src/tracksdata/array/_test/test_graph_array.py @@ -35,7 +35,7 @@ def test_graph_array_view_init() -> None: # Add a attribute key graph.add_node_attr_key("label", 0) - array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="label", offset=0) + array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label", offset=0) assert array_view.graph is graph assert array_view.shape == (10, 100, 100) @@ -51,7 +51,7 @@ def test_graph_array_view_init_invalid_attr_key() -> None: graph = RustWorkXGraph() with pytest.raises(ValueError, match="Attribute key 'invalid_key' not found in graph"): - GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="invalid_key") + GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="invalid_key") def test_graph_array_view_getitem_empty_time() -> None: @@ -59,7 +59,7 @@ def test_graph_array_view_getitem_empty_time() -> None: graph = RustWorkXGraph() graph.add_node_attr_key("label", 0) - array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="label") + array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label") # Get data for time point 0 (no nodes) result = array_view[0] @@ -97,7 +97,7 @@ def test_graph_array_view_getitem_with_nodes() -> None: } ) - array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="label") + array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label") # Get data for time point 0 result = array_view[0] @@ -162,7 +162,7 @@ def test_graph_array_view_getitem_multiple_nodes() -> None: } ) - array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="label") + array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="label") # Get data for time point 0 result = array_view[0] @@ -204,7 +204,7 @@ def test_graph_array_view_getitem_boolean_dtype() -> None: } ) - array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="is_active") + array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="is_active") # Get data for time point 0 result = array_view[0] @@ -242,7 +242,7 @@ def test_graph_array_view_dtype_inference() -> None: } ) - array_view = GraphArrayView(graph=graph, full_shape=(10, 100, 100), attr_key="float_label") + array_view = GraphArrayView(graph=graph, shape=(10, 100, 100), attr_key="float_label") # Get data to trigger dtype inference _ = array_view[0] @@ -266,7 +266,7 @@ def multi_node_graph_from_image(request) -> GraphArrayView: graph = RustWorkXGraph() nodes_operator = RegionPropsNodes(extra_properties=["label"]) nodes_operator.add_nodes(graph, labels=label) - return GraphArrayView(graph=graph, full_shape=shape, attr_key="label"), label + return GraphArrayView(graph=graph, shape=shape, attr_key="label"), label def test_graph_array_view_equal(multi_node_graph_from_image) -> None: @@ -344,7 +344,7 @@ def test_graph_array_set_options() -> None: with Options(gav_chunk_shape=(512, 512), gav_default_dtype=np.int16): empty_graph = RustWorkXGraph() empty_graph.add_node_attr_key("label", 0) - array_view = GraphArrayView(graph=empty_graph, full_shape=(10, 100, 100), attr_key="label") + array_view = GraphArrayView(graph=empty_graph, shape=(10, 100, 100), attr_key="label") assert array_view.chunk_shape == (512, 512) assert array_view.dtype == np.int16 @@ -355,9 +355,9 @@ def test_graph_array_view_with_strides_downsamples_data() -> None: graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) - full_shape = (3, 7, 8) + shape = (3, 7, 8) strides = (2, 3) - reference = np.zeros(full_shape, dtype=np.uint8) + reference = np.zeros(shape, dtype=np.uint8) nodes = [ { @@ -403,7 +403,7 @@ def test_graph_array_view_with_strides_downsamples_data() -> None: array_view = GraphArrayView( graph=graph, - full_shape=full_shape, + shape=shape, attr_key="label", strides=strides, ) @@ -424,9 +424,9 @@ def test_graph_array_view_with_strides_3d_data() -> None: graph.add_node_attr_key(DEFAULT_ATTR_KEYS.MASK, None) graph.add_node_attr_key(DEFAULT_ATTR_KEYS.BBOX, None) - full_shape = (2, 4, 6, 6) + shape = (2, 4, 6, 6) strides = (2, 3, 2) - reference = np.zeros(full_shape, dtype=np.uint8) + reference = np.zeros(shape, dtype=np.uint8) nodes = [ { @@ -478,7 +478,7 @@ def test_graph_array_view_with_strides_3d_data() -> None: array_view = GraphArrayView( graph=graph, - full_shape=full_shape, + shape=shape, attr_key="label", strides=strides, ) diff --git a/src/tracksdata/io/_ctc.py b/src/tracksdata/io/_ctc.py index c63e6c65..d8f28ca2 100644 --- a/src/tracksdata/io/_ctc.py +++ b/src/tracksdata/io/_ctc.py @@ -267,7 +267,7 @@ def to_ctc( output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) - view = GraphArrayView(graph, full_shape=shape, attr_key=track_id_key) + view = GraphArrayView(graph, shape=shape, attr_key=track_id_key) n_digits = max(len(str(view.shape[0])), 3)