From fee29fd37b20043c50b36d1f0bca2a2a620079a1 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 1 Jul 2025 15:59:53 -0700 Subject: [PATCH 1/8] wip node interpolator -- missing tests --- src/tracksdata/nodes/__init__.py | 3 +- src/tracksdata/nodes/_mask.py | 26 +-- src/tracksdata/nodes/_node_interpolator.py | 215 +++++++++++++++++++++ 3 files changed, 232 insertions(+), 12 deletions(-) create mode 100644 src/tracksdata/nodes/_node_interpolator.py diff --git a/src/tracksdata/nodes/__init__.py b/src/tracksdata/nodes/__init__.py index d285b271..1fdb0215 100644 --- a/src/tracksdata/nodes/__init__.py +++ b/src/tracksdata/nodes/__init__.py @@ -2,7 +2,8 @@ from tracksdata.nodes._crop_attrs import CropFuncAttrs from tracksdata.nodes._mask import Mask +from tracksdata.nodes._node_interpolator import NodeInterpolator from tracksdata.nodes._random import RandomNodes from tracksdata.nodes._regionprops import RegionPropsNodes -__all__ = ["CropFuncAttrs", "Mask", "RandomNodes", "RegionPropsNodes"] +__all__ = ["CropFuncAttrs", "Mask", "NodeInterpolator", "RandomNodes", "RegionPropsNodes"] diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index 6ca2795e..c6617de3 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -31,18 +31,8 @@ def __init__( mask: NDArray[np.bool_], bbox: np.ndarray, ): - bbox = np.asarray(bbox, dtype=np.int64) - - if mask.ndim != bbox.shape[0] // 2: - raise ValueError(f"Mask dimension {mask.ndim} does not match bbox dimension {bbox.shape[0]} // 2") - - bbox_size = bbox[mask.ndim :] - bbox[: mask.ndim] - - if np.any(mask.shape != bbox_size): - raise ValueError(f"Mask shape {mask.shape} does not match bbox size {bbox_size}") - self._mask = mask - self._bbox = bbox + self.bbox = bbox def __getstate__(self) -> dict: data_dict = self.__dict__.copy() @@ -61,6 +51,20 @@ def mask(self) -> NDArray[np.bool_]: def bbox(self) -> np.ndarray: return self._bbox + @bbox.setter + def bbox(self, bbox: np.ndarray) -> None: + bbox = np.asarray(bbox, dtype=np.int64) + + if self._mask.ndim != bbox.shape[0] // 2: + raise ValueError(f"Mask dimension {self._mask.ndim} does not match bbox dimension {bbox.shape[0]} // 2") + + bbox_size = bbox[self._mask.ndim :] - bbox[: self._mask.ndim] + + if np.any(self._mask.shape != bbox_size): + raise ValueError(f"Mask shape {self._mask.shape} does not match bbox size {bbox_size}") + + self._bbox = bbox + def crop( self, image: NDArray, diff --git a/src/tracksdata/nodes/_node_interpolator.py b/src/tracksdata/nodes/_node_interpolator.py new file mode 100644 index 00000000..0c2375e4 --- /dev/null +++ b/src/tracksdata/nodes/_node_interpolator.py @@ -0,0 +1,215 @@ +import math +from copy import deepcopy +from typing import Any, Protocol + +from tqdm import tqdm + +from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.edges._generic_edges import GenericNodeFunctionEdgeAttrs +from tracksdata.graph._base_graph import BaseGraph +from tracksdata.nodes._base_nodes import BaseNodesOperator +from tracksdata.utils._logging import LOG + + +class NodeInterpolationFunc(Protocol): + def __call__( + self, + src_attrs: dict[str, Any], + tgt_attrs: dict[str, Any], + new_t: int, + delta_t: int, + ) -> dict[str, Any]: ... + + +def default_node_interpolation( + src_attrs: dict[str, Any], + tgt_attrs: dict[str, Any], + new_t: int, + delta_t: int, +) -> dict[str, Any]: + """ + Default node interpolator. + Interpolates the 'mask', 'z', 'y', 'x' attributes between the source and target nodes. + + Parameters + ---------- + src_attrs : dict[str, Any] + Source node attributes. + tgt_attrs : dict[str, Any] + Target node attributes. + new_t : int + Time point of the new node. + delta_t : int + Current delta time. + + Returns + ------- + dict[str, Any] + Interpolated node. + """ + new_attrs = deepcopy(tgt_attrs) + new_attrs.pop(DEFAULT_ATTR_KEYS.NODE_ID, None) + + new_attrs[DEFAULT_ATTR_KEYS.T] = new_t + + t_tgt = tgt_attrs[DEFAULT_ATTR_KEYS.T] + w = (t_tgt - new_t) / delta_t + + if w < 0 or w > 1: + raise ValueError(f"w = {w} is not between 0 and 1") + + ndim = new_attrs[DEFAULT_ATTR_KEYS.MASK].mask.ndim + bbox = new_attrs[DEFAULT_ATTR_KEYS.MASK].bbox + + for i, attr_key in enumerate(["x", "y", "z"][:ndim]): + if attr_key not in src_attrs or attr_key not in tgt_attrs: + continue + # TODO: make this part more clear + src_val = src_attrs[attr_key] + tgt_val = tgt_attrs[attr_key] + new_val = w * src_val + (1 - w) * tgt_val + new_attrs[attr_key] = new_val + offset = round(new_val - tgt_val) - 1 + bbox[ndim - i - 1] = bbox[ndim - i - 1] + offset + bbox[2 * ndim - i - 1] = bbox[2 * ndim - i - 1] + offset + + new_attrs[DEFAULT_ATTR_KEYS.MASK].bbox = bbox + + return new_attrs + + +class NodeInterpolationEdgeAttrs(Protocol): + """ + Function to recompute the edge attributes between newly inserted nodes + during node interpolation. + """ + + def __call__( + self, + src_attrs: dict[str, Any], + tgt_attrs: dict[str, Any], + ) -> dict[str, Any]: ... + + +def default_node_interpolation_edge_attrs( + src_attrs: dict[str, Any], + tgt_attrs: dict[str, Any], +) -> dict[str, Any]: + """ + Default node interpolation edge attributes. + By default, it computes: + + - `delta_t`: the absolute difference between the source and target time points + - `edge_weight`: the Euclidean distance between the source and target nodes + + Parameters + ---------- + src_attrs : dict[str, Any] + Source node attributes. + tgt_attrs : dict[str, Any] + Target node attributes. + + Returns + ------- + dict[str, Any] + Node interpolation edge attributes. + """ + return { + "delta_t": abs(src_attrs[DEFAULT_ATTR_KEYS.T] - tgt_attrs[DEFAULT_ATTR_KEYS.T]), + DEFAULT_ATTR_KEYS.EDGE_WEIGHT: math.sqrt( + sum((src_attrs.get(attr_key, 0.0) - tgt_attrs.get(attr_key, 0.0)) ** 2 for attr_key in ["z", "y", "x"]) + ), + } + + +class NodeInterpolator(BaseNodesOperator): + """ + Interpolate nodes between non-consecutive time points (delta_t > 1). + """ + + def __init__( + self, + delta_t_key: str = "delta_t", + show_progress: bool = True, + node_interpolation_func: NodeInterpolationFunc = default_node_interpolation, + edge_attrs_func: NodeInterpolationEdgeAttrs = default_node_interpolation_edge_attrs, + ): + super().__init__(show_progress=show_progress) + self.delta_t_key = delta_t_key + self.node_interpolation_func = node_interpolation_func + self.edge_attrs_func = edge_attrs_func + + def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None: + if t is not None: + raise ValueError("'t' must be None for node interpolation") + + if self.delta_t_key not in graph.edge_attr_keys: + LOG.warning( + "The key '%s' is not in graph.edge_attrs (%s). Inserting edge attribute `delta_t` using '%s' attribute", + self.delta_t_key, + graph.edge_attr_keys, + DEFAULT_ATTR_KEYS.T, + ) + GenericNodeFunctionEdgeAttrs( + func=lambda x, y: abs(x - y), + attr_keys=DEFAULT_ATTR_KEYS.T, + output_key=self.delta_t_key, + show_progress=self.show_progress, + ).add_edge_attrs(graph) + + edge_attrs = graph.edge_attrs(attr_keys=[self.delta_t_key]) + long_edges = edge_attrs.filter(edge_attrs[self.delta_t_key] > 1) + + selected_node_ids = set(long_edges[DEFAULT_ATTR_KEYS.EDGE_SOURCE]) | set( + long_edges[DEFAULT_ATTR_KEYS.EDGE_TARGET] + ) + + nodes = graph.node_attrs( + node_ids=list(selected_node_ids), + ) + nodes_by_id = {node[DEFAULT_ATTR_KEYS.NODE_ID]: node for node in nodes.iter_rows(named=True)} + + for long_edge in tqdm( + list(long_edges.iter_rows(named=True)), + disable=not self.show_progress, + desc="Interpolating nodes", + ): + delta_t = long_edge[self.delta_t_key] + + src_id = long_edge[DEFAULT_ATTR_KEYS.EDGE_SOURCE] + tgt_id = long_edge[DEFAULT_ATTR_KEYS.EDGE_TARGET] + + src_attrs = nodes_by_id[src_id] + tgt_attrs = nodes_by_id[tgt_id] + + while delta_t > 1: + new_node_attrs = self.node_interpolation_func( + src_attrs=src_attrs, + tgt_attrs=tgt_attrs, + new_t=tgt_attrs[DEFAULT_ATTR_KEYS.T] - 1, + delta_t=delta_t, + ) + + new_node_id = graph.add_node(new_node_attrs) + nodes_by_id[new_node_id] = new_node_attrs + + graph.add_edge( + source_id=src_id, + target_id=new_node_id, + attrs=self.edge_attrs_func( + src_attrs=src_attrs, + tgt_attrs=new_node_attrs, + ), + ) + + graph.add_edge( + source_id=new_node_id, + target_id=tgt_id, + attrs=self.edge_attrs_func( + src_attrs=new_node_attrs, + tgt_attrs=tgt_attrs, + ), + ) + + delta_t -= 1 + tgt_attrs = new_node_attrs From 851f4d6706191cd8d0b19c72867c16c95c2fb236 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 1 Jul 2025 17:48:07 -0700 Subject: [PATCH 2/8] add bbox center method to mask --- src/tracksdata/nodes/_mask.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index c6617de3..50330826 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -90,7 +90,7 @@ def crop( slicing = tuple(slice(self._bbox[i], self._bbox[i + ndim]) for i in range(ndim)) else: - center = (self._bbox[: self._mask.ndim] + self._bbox[self._mask.ndim :]) // 2 + center = self.bbox_center() half_shape = np.asarray(shape) // 2 start = np.maximum(center - half_shape, 0) end = np.minimum(center + half_shape, image.shape) @@ -98,6 +98,12 @@ def crop( return image[slicing] + def bbox_center(self) -> NDArray[np.integer]: + """ + Get the center of the bounding box. + """ + return (self._bbox[: self._mask.ndim] + self._bbox[self._mask.ndim :]) // 2 + def mask_indices( self, offset: NDArray[np.integer] | int = 0, From fa192c67aaf0e68e6dc3f4bbacf82410e19dd527 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 1 Jul 2025 17:48:51 -0700 Subject: [PATCH 3/8] WIP testing node interpolation --- src/tracksdata/nodes/_node_interpolator.py | 40 ++++++++++++++-------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/src/tracksdata/nodes/_node_interpolator.py b/src/tracksdata/nodes/_node_interpolator.py index 0c2375e4..9c1b0021 100644 --- a/src/tracksdata/nodes/_node_interpolator.py +++ b/src/tracksdata/nodes/_node_interpolator.py @@ -2,6 +2,7 @@ from copy import deepcopy from typing import Any, Protocol +import numpy as np from tqdm import tqdm from tracksdata.constants import DEFAULT_ATTR_KEYS @@ -59,21 +60,30 @@ def default_node_interpolation( raise ValueError(f"w = {w} is not between 0 and 1") ndim = new_attrs[DEFAULT_ATTR_KEYS.MASK].mask.ndim - bbox = new_attrs[DEFAULT_ATTR_KEYS.MASK].bbox - - for i, attr_key in enumerate(["x", "y", "z"][:ndim]): - if attr_key not in src_attrs or attr_key not in tgt_attrs: - continue - # TODO: make this part more clear - src_val = src_attrs[attr_key] - tgt_val = tgt_attrs[attr_key] - new_val = w * src_val + (1 - w) * tgt_val - new_attrs[attr_key] = new_val - offset = round(new_val - tgt_val) - 1 - bbox[ndim - i - 1] = bbox[ndim - i - 1] + offset - bbox[2 * ndim - i - 1] = bbox[2 * ndim - i - 1] + offset - - new_attrs[DEFAULT_ATTR_KEYS.MASK].bbox = bbox + new_bbox = new_attrs[DEFAULT_ATTR_KEYS.MASK].bbox + + # updating bounding box + tgt_center = tgt_attrs[DEFAULT_ATTR_KEYS.MASK].bbox_center() + src_center = src_attrs[DEFAULT_ATTR_KEYS.MASK].bbox_center() + signed_dist = tgt_center - src_center + offset = -np.round((1 - w) * signed_dist).astype(int) + + for i in range(ndim): + if offset[i] > 0: + new_value = new_bbox[ndim + i] - offset[i] + dist_to_border = min(new_value - tgt_attrs[DEFAULT_ATTR_KEYS.MASK].bbox[ndim + i], 0) + offset[i] += dist_to_border + else: + new_value = new_bbox[i] + offset[i] + dist_to_border = max(tgt_attrs[DEFAULT_ATTR_KEYS.MASK].bbox[i] - new_value, 0) + offset[i] += dist_to_border + + new_bbox[ndim:] = new_bbox[ndim:] + offset + new_bbox[:ndim] = new_bbox[:ndim] + offset + new_attrs[DEFAULT_ATTR_KEYS.MASK].bbox = new_bbox + + for o, attr_key in zip(offset[::-1], ["x", "y", "z"], strict=False): + new_attrs[attr_key] += o return new_attrs From c2cc4d9857aba07cc038ffdd70c15e61df9350d2 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 1 Jul 2025 17:59:10 -0700 Subject: [PATCH 4/8] numba implementation --- src/tracksdata/nodes/_mask.py | 50 ++++++++++++++++++++++ src/tracksdata/nodes/_node_interpolator.py | 23 +++------- 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/src/tracksdata/nodes/_mask.py b/src/tracksdata/nodes/_mask.py index 50330826..fb4450b1 100644 --- a/src/tracksdata/nodes/_mask.py +++ b/src/tracksdata/nodes/_mask.py @@ -2,11 +2,61 @@ import blosc2 import numpy as np +from numba import njit from numpy.typing import NDArray from tracksdata.functional._iou import fast_intersection_with_bbox, fast_iou_with_bbox +@njit +def bbox_interpolation_offset( + tgt_bbox: np.ndarray, + src_bbox: np.ndarray, + w: float, +) -> np.ndarray: + """ + Interpolate the bounding box between two masks. + The reference is the target mask and w is relative distance to it: + ```python + (target_value - new_value) / (target_value - source_value) = w + ``` + + Parameters + ---------- + tgt_bbox : np.ndarray + The target bounding box. + src_bbox : np.ndarray + The source bounding box. + w : float + The weight of the interpolation. + + Returns + ------- + np.ndarray + The offset to add to the bounding box. + """ + if w < 0 or w > 1: + raise ValueError(f"w = {w} is not between 0 and 1") + + ndim = tgt_bbox.shape[0] // 2 + tgt_center = tgt_bbox[ndim:] - tgt_bbox[:ndim] // 2 + src_center = src_bbox[ndim:] - src_bbox[:ndim] // 2 + signed_dist = tgt_center - src_center + offset = -np.round((1 - w) * signed_dist).astype(np.int32) + + for i in range(ndim): + if offset[i] > 0: + new_value = tgt_bbox[ndim + i] - offset[i] + dist_to_border = min(new_value - tgt_bbox[ndim + i], 0) + offset[i] += dist_to_border + else: + new_value = tgt_bbox[i] + offset[i] + dist_to_border = max(tgt_bbox[i] - new_value, 0) + offset[i] += dist_to_border + + return offset + + class Mask: """ Object used to store an individual segmentation mask of a single instance (object) diff --git a/src/tracksdata/nodes/_node_interpolator.py b/src/tracksdata/nodes/_node_interpolator.py index 9c1b0021..14c76294 100644 --- a/src/tracksdata/nodes/_node_interpolator.py +++ b/src/tracksdata/nodes/_node_interpolator.py @@ -2,13 +2,13 @@ from copy import deepcopy from typing import Any, Protocol -import numpy as np from tqdm import tqdm from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.edges._generic_edges import GenericNodeFunctionEdgeAttrs from tracksdata.graph._base_graph import BaseGraph from tracksdata.nodes._base_nodes import BaseNodesOperator +from tracksdata.nodes._mask import bbox_interpolation_offset from tracksdata.utils._logging import LOG @@ -63,20 +63,11 @@ def default_node_interpolation( new_bbox = new_attrs[DEFAULT_ATTR_KEYS.MASK].bbox # updating bounding box - tgt_center = tgt_attrs[DEFAULT_ATTR_KEYS.MASK].bbox_center() - src_center = src_attrs[DEFAULT_ATTR_KEYS.MASK].bbox_center() - signed_dist = tgt_center - src_center - offset = -np.round((1 - w) * signed_dist).astype(int) - - for i in range(ndim): - if offset[i] > 0: - new_value = new_bbox[ndim + i] - offset[i] - dist_to_border = min(new_value - tgt_attrs[DEFAULT_ATTR_KEYS.MASK].bbox[ndim + i], 0) - offset[i] += dist_to_border - else: - new_value = new_bbox[i] + offset[i] - dist_to_border = max(tgt_attrs[DEFAULT_ATTR_KEYS.MASK].bbox[i] - new_value, 0) - offset[i] += dist_to_border + offset = bbox_interpolation_offset( + tgt_bbox=tgt_attrs[DEFAULT_ATTR_KEYS.MASK].bbox, + src_bbox=src_attrs[DEFAULT_ATTR_KEYS.MASK].bbox, + w=w, + ) new_bbox[ndim:] = new_bbox[ndim:] + offset new_bbox[:ndim] = new_bbox[:ndim] + offset @@ -182,7 +173,7 @@ def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None: for long_edge in tqdm( list(long_edges.iter_rows(named=True)), disable=not self.show_progress, - desc="Interpolating nodes", + desc="Interpolating and adding nodes", ): delta_t = long_edge[self.delta_t_key] From 033353e3cd7180f9521b18b95138a414b975e2a1 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Wed, 2 Jul 2025 09:18:48 -0700 Subject: [PATCH 5/8] wip: bug fixing interpolator and adding new features --- src/tracksdata/nodes/_node_interpolator.py | 143 +++++++++++++++++---- 1 file changed, 116 insertions(+), 27 deletions(-) diff --git a/src/tracksdata/nodes/_node_interpolator.py b/src/tracksdata/nodes/_node_interpolator.py index 14c76294..5a3fa2fb 100644 --- a/src/tracksdata/nodes/_node_interpolator.py +++ b/src/tracksdata/nodes/_node_interpolator.py @@ -1,3 +1,4 @@ +import logging import math from copy import deepcopy from typing import Any, Protocol @@ -22,6 +23,18 @@ def __call__( ) -> dict[str, Any]: ... +def _replace_if_exists( + new_attrs: dict[str, Any], + attr_key: str, + value: Any, +) -> None: + """ + Replace the value of an attribute if it exists in the new attributes. + """ + if attr_key in new_attrs: + new_attrs[attr_key] = value + + def default_node_interpolation( src_attrs: dict[str, Any], tgt_attrs: dict[str, Any], @@ -76,6 +89,31 @@ def default_node_interpolation( for o, attr_key in zip(offset[::-1], ["x", "y", "z"], strict=False): new_attrs[attr_key] += o + if tgt_attrs.get(DEFAULT_ATTR_KEYS.TRACK_ID) != src_attrs.get(DEFAULT_ATTR_KEYS.TRACK_ID): + new_attrs[DEFAULT_ATTR_KEYS.TRACK_ID] = -1 + + _replace_if_exists( + new_attrs, + DEFAULT_ATTR_KEYS.SOLUTION, + src_attrs.get(DEFAULT_ATTR_KEYS.SOLUTION, False) and tgt_attrs.get(DEFAULT_ATTR_KEYS.SOLUTION, False), + ) + + _replace_if_exists( + new_attrs, + DEFAULT_ATTR_KEYS.MATCHED_NODE_ID, + -1, + ) + _replace_if_exists( + new_attrs, + DEFAULT_ATTR_KEYS.MATCH_SCORE, + 0, + ) + _replace_if_exists( + new_attrs, + DEFAULT_ATTR_KEYS.MATCHED_EDGE_MASK, + False, + ) + return new_attrs @@ -87,40 +125,64 @@ class NodeInterpolationEdgeAttrs(Protocol): def __call__( self, + long_edge: dict[str, Any], src_attrs: dict[str, Any], tgt_attrs: dict[str, Any], - ) -> dict[str, Any]: ... + new_attrs: dict[str, Any], + ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: ... def default_node_interpolation_edge_attrs( + long_edge: dict[str, Any], src_attrs: dict[str, Any], tgt_attrs: dict[str, Any], -) -> dict[str, Any]: + new_attrs: dict[str, Any], +) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: """ Default node interpolation edge attributes. By default, it computes: - `delta_t`: the absolute difference between the source and target time points - `edge_weight`: the Euclidean distance between the source and target nodes + - `solution`: it removes the solution attribute from the source and target nodes + and adding to the new edges. Parameters ---------- + long_edge : dict[str, Any] + Edge defining the interpolation. src_attrs : dict[str, Any] Source node attributes. tgt_attrs : dict[str, Any] Target node attributes. + new_attrs : dict[str, Any] + New node attributes. Returns ------- - dict[str, Any] - Node interpolation edge attributes. + tuple[dict[str, Any] | None, dict[str, Any] | None] + Source to new node edge attributes and new to target edge attributes. + Return `None` if no edge should be added. """ - return { - "delta_t": abs(src_attrs[DEFAULT_ATTR_KEYS.T] - tgt_attrs[DEFAULT_ATTR_KEYS.T]), - DEFAULT_ATTR_KEYS.EDGE_WEIGHT: math.sqrt( - sum((src_attrs.get(attr_key, 0.0) - tgt_attrs.get(attr_key, 0.0)) ** 2 for attr_key in ["z", "y", "x"]) - ), - } + + new_edge_attrs = [] + + for x, y in [(src_attrs, new_attrs), (new_attrs, tgt_attrs)]: + new_edge_attrs.append( + { + "delta_t": abs(x[DEFAULT_ATTR_KEYS.T] - y[DEFAULT_ATTR_KEYS.T]), + DEFAULT_ATTR_KEYS.EDGE_WEIGHT: math.sqrt( + sum((x.get(attr_key, 0.0) - y.get(attr_key, 0.0)) ** 2 for attr_key in ["z", "y", "x"]) + ), + } + ) + + solution = long_edge.get(DEFAULT_ATTR_KEYS.SOLUTION, False) + if solution is not None: + new_edge_attrs[0][DEFAULT_ATTR_KEYS.SOLUTION] = solution + new_edge_attrs[1][DEFAULT_ATTR_KEYS.SOLUTION] = solution + + return new_edge_attrs class NodeInterpolator(BaseNodesOperator): @@ -134,11 +196,13 @@ def __init__( show_progress: bool = True, node_interpolation_func: NodeInterpolationFunc = default_node_interpolation, edge_attrs_func: NodeInterpolationEdgeAttrs = default_node_interpolation_edge_attrs, + validate_keys: bool = False, ): super().__init__(show_progress=show_progress) self.delta_t_key = delta_t_key self.node_interpolation_func = node_interpolation_func self.edge_attrs_func = edge_attrs_func + self.validate_keys = validate_keys def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None: if t is not None: @@ -158,7 +222,7 @@ def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None: show_progress=self.show_progress, ).add_edge_attrs(graph) - edge_attrs = graph.edge_attrs(attr_keys=[self.delta_t_key]) + edge_attrs = graph.edge_attrs() long_edges = edge_attrs.filter(edge_attrs[self.delta_t_key] > 1) selected_node_ids = set(long_edges[DEFAULT_ATTR_KEYS.EDGE_SOURCE]) | set( @@ -191,26 +255,51 @@ def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None: delta_t=delta_t, ) - new_node_id = graph.add_node(new_node_attrs) + new_node_id = graph.add_node(new_node_attrs, validate_keys=self.validate_keys) nodes_by_id[new_node_id] = new_node_attrs - graph.add_edge( - source_id=src_id, - target_id=new_node_id, - attrs=self.edge_attrs_func( - src_attrs=src_attrs, - tgt_attrs=new_node_attrs, - ), + src_to_new_edge_attrs, new_to_tgt_edge_attrs = self.edge_attrs_func( + long_edge=long_edge, + src_attrs=src_attrs, + tgt_attrs=tgt_attrs, + new_attrs=new_node_attrs, ) - graph.add_edge( - source_id=new_node_id, - target_id=tgt_id, - attrs=self.edge_attrs_func( - src_attrs=new_node_attrs, - tgt_attrs=tgt_attrs, - ), - ) + if src_to_new_edge_attrs is not None: + graph.add_edge( + source_id=src_id, + target_id=new_node_id, + attrs=src_to_new_edge_attrs, + validate_keys=self.validate_keys, + ) + + if new_to_tgt_edge_attrs is not None: + graph.add_edge( + source_id=new_node_id, + target_id=tgt_id, + attrs=new_to_tgt_edge_attrs, + validate_keys=self.validate_keys, + ) + + if long_edge.get(DEFAULT_ATTR_KEYS.SOLUTION, False): + graph.update_edge_attrs( + attrs={DEFAULT_ATTR_KEYS.SOLUTION: False}, + edge_ids=[long_edge[DEFAULT_ATTR_KEYS.EDGE_ID]], + ) + + if LOG.isEnabledFor(logging.INFO): + LOG.info("s -> t (before): %s", long_edge[DEFAULT_ATTR_KEYS.SOLUTION]) + LOG.info( + "s -> t (after): %s", + graph.edge_attrs(attr_keys=[DEFAULT_ATTR_KEYS.SOLUTION]) + .filter(edge_id=long_edge[DEFAULT_ATTR_KEYS.EDGE_ID]) + .to_dicts()[0], + ) + LOG.info("s -> n: %s", src_to_new_edge_attrs) + LOG.info("n -> t: %s", new_to_tgt_edge_attrs) delta_t -= 1 tgt_attrs = new_node_attrs + tgt_id = new_node_id + # replacing long_edge by new shorter but still long edge + long_edge = src_to_new_edge_attrs.copy() From ca581e1fc12217b99cb3b8198d6dc08637bb8db8 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 8 Jul 2025 16:27:40 -0700 Subject: [PATCH 6/8] adding collision to node interpolator --- src/tracksdata/nodes/_node_interpolator.py | 42 +++++++++++++++++----- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/src/tracksdata/nodes/_node_interpolator.py b/src/tracksdata/nodes/_node_interpolator.py index 5a3fa2fb..6961de7f 100644 --- a/src/tracksdata/nodes/_node_interpolator.py +++ b/src/tracksdata/nodes/_node_interpolator.py @@ -5,11 +5,13 @@ from tqdm import tqdm +from tracksdata.attrs import NodeAttr from tracksdata.constants import DEFAULT_ATTR_KEYS -from tracksdata.edges._generic_edges import GenericNodeFunctionEdgeAttrs +from tracksdata.edges._generic_edges import GenericFuncEdgeAttrs from tracksdata.graph._base_graph import BaseGraph from tracksdata.nodes._base_nodes import BaseNodesOperator -from tracksdata.nodes._mask import bbox_interpolation_offset +from tracksdata.nodes._mask import Mask, bbox_interpolation_offset +from tracksdata.options import get_options from tracksdata.utils._logging import LOG @@ -193,16 +195,17 @@ class NodeInterpolator(BaseNodesOperator): def __init__( self, delta_t_key: str = "delta_t", - show_progress: bool = True, node_interpolation_func: NodeInterpolationFunc = default_node_interpolation, edge_attrs_func: NodeInterpolationEdgeAttrs = default_node_interpolation_edge_attrs, validate_keys: bool = False, + iou_threshold: float = 0.5, ): - super().__init__(show_progress=show_progress) + super().__init__() self.delta_t_key = delta_t_key self.node_interpolation_func = node_interpolation_func self.edge_attrs_func = edge_attrs_func self.validate_keys = validate_keys + self.iou_threshold = iou_threshold def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None: if t is not None: @@ -215,15 +218,14 @@ def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None: graph.edge_attr_keys, DEFAULT_ATTR_KEYS.T, ) - GenericNodeFunctionEdgeAttrs( + GenericFuncEdgeAttrs( func=lambda x, y: abs(x - y), attr_keys=DEFAULT_ATTR_KEYS.T, output_key=self.delta_t_key, - show_progress=self.show_progress, ).add_edge_attrs(graph) edge_attrs = graph.edge_attrs() - long_edges = edge_attrs.filter(edge_attrs[self.delta_t_key] > 1) + long_edges = edge_attrs.filter(edge_attrs[self.delta_t_key] > 1).sort(self.delta_t_key) selected_node_ids = set(long_edges[DEFAULT_ATTR_KEYS.EDGE_SOURCE]) | set( long_edges[DEFAULT_ATTR_KEYS.EDGE_TARGET] @@ -234,9 +236,11 @@ def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None: ) nodes_by_id = {node[DEFAULT_ATTR_KEYS.NODE_ID]: node for node in nodes.iter_rows(named=True)} + count = 0 + for long_edge in tqdm( list(long_edges.iter_rows(named=True)), - disable=not self.show_progress, + disable=not get_options().show_progress, desc="Interpolating and adding nodes", ): delta_t = long_edge[self.delta_t_key] @@ -255,7 +259,26 @@ def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None: delta_t=delta_t, ) + masks_in_t = graph.node_attrs( + node_ids=graph.filter_nodes_by_attrs( + NodeAttr(DEFAULT_ATTR_KEYS.T) == new_node_attrs[DEFAULT_ATTR_KEYS.T] + ), + attr_keys=[DEFAULT_ATTR_KEYS.MASK], + )[DEFAULT_ATTR_KEYS.MASK] + new_mask: Mask = new_node_attrs[DEFAULT_ATTR_KEYS.MASK] + + found_collision = False + for mask in masks_in_t: + iou = new_mask.iou(mask) + if iou > self.iou_threshold: + found_collision = True + break + + if found_collision: + break + new_node_id = graph.add_node(new_node_attrs, validate_keys=self.validate_keys) + count += 1 nodes_by_id[new_node_id] = new_node_attrs src_to_new_edge_attrs, new_to_tgt_edge_attrs = self.edge_attrs_func( @@ -303,3 +326,6 @@ def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None: tgt_id = new_node_id # replacing long_edge by new shorter but still long edge long_edge = src_to_new_edge_attrs.copy() + + LOG.info("Number of nodes added: %s", count) + print(f"Number of nodes added: {count}") From 65f0ede5b1126ccd35a2caee2a71bf139ff06e76 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Tue, 8 Jul 2025 16:30:10 -0700 Subject: [PATCH 7/8] fixing when adding attribute to existing key --- src/tracksdata/edges/_distance_edges.py | 3 ++- src/tracksdata/edges/_generic_edges.py | 3 ++- src/tracksdata/nodes/_generic_nodes.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/tracksdata/edges/_distance_edges.py b/src/tracksdata/edges/_distance_edges.py index 3b03b1d2..414b03c9 100644 --- a/src/tracksdata/edges/_distance_edges.py +++ b/src/tracksdata/edges/_distance_edges.py @@ -108,7 +108,8 @@ def _init_edge_attrs(self, graph: BaseGraph) -> None: """ Initialize the edge attributes for the graph. """ - graph.add_edge_attr_key(self.output_key, default_value=-99999.0) + if self.output_key not in graph.edge_attr_keys: + graph.add_edge_attr_key(self.output_key, default_value=-99999.0) def _add_edges_per_time( self, diff --git a/src/tracksdata/edges/_generic_edges.py b/src/tracksdata/edges/_generic_edges.py index f7b26914..bbd100a2 100644 --- a/src/tracksdata/edges/_generic_edges.py +++ b/src/tracksdata/edges/_generic_edges.py @@ -53,7 +53,8 @@ def _init_edge_attrs(self, graph: BaseGraph) -> None: """ Initialize the edge attributes for the graph. """ - graph.add_edge_attr_key(self.output_key, default_value=-99999.0) + if self.output_key not in graph.edge_attr_keys: + graph.add_edge_attr_key(self.output_key, default_value=-99999.0) def _edge_attrs_per_time( self, diff --git a/src/tracksdata/nodes/_generic_nodes.py b/src/tracksdata/nodes/_generic_nodes.py index 609d4e79..0d777d3b 100644 --- a/src/tracksdata/nodes/_generic_nodes.py +++ b/src/tracksdata/nodes/_generic_nodes.py @@ -108,7 +108,8 @@ def _init_node_attrs(self, graph: BaseGraph) -> None: """ Initialize the node attributes for the graph. """ - graph.add_node_attr_key(self.output_key, default_value=self.default_value) + if self.output_key not in graph.node_attr_keys: + graph.add_node_attr_key(self.output_key, default_value=self.default_value) def add_node_attrs( self, From df5c022d19092906edcff83f8b3dabc443039698 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Thu, 10 Jul 2025 09:48:50 -0700 Subject: [PATCH 8/8] adding function to compute overlaps --- src/tracksdata/graph/_base_graph.py | 40 ++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index 1d5af642..d4661cdd 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -9,9 +9,10 @@ import polars as pl from numpy.typing import ArrayLike -from tracksdata.attrs import AttrComparison +from tracksdata.attrs import AttrComparison, NodeAttr from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.utils._logging import LOG +from tracksdata.utils._multiprocessing import multiprocessing_apply if TYPE_CHECKING: from tracksdata.graph._graph_view import GraphView @@ -792,3 +793,40 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T: graph.bulk_add_overlaps(overlaps.tolist()) return graph + + def set_overlaps(self, iou_threshold: float = 0.0) -> None: + """ + Find overlapping nodes within each frame and add them their overlap relation into the graph. + + Parameters + ---------- + iou_threshold : float + Nodes with an IoU greater than this threshold are considered overlapping. + If 0, all nodes are considered overlapping. + + Examples + -------- + ```python + graph.set_overlaps(iou_threshold=0.5) + ``` + """ + if iou_threshold < 0.0 or iou_threshold > 1.0: + raise ValueError("iou_threshold must be between 0.0 and 1.0") + + def _estimate_overlaps(t: int) -> list[list[int, 2]]: + node_ids = self.filter_nodes_by_attrs(NodeAttr(DEFAULT_ATTR_KEYS.T) == t) + masks = self.node_attrs(node_ids=node_ids, attr_keys=[DEFAULT_ATTR_KEYS.MASK])[DEFAULT_ATTR_KEYS.MASK] + overlaps = [] + for i in range(len(masks)): + mask_i = masks[i] + for j in range(i + 1, len(masks)): + if mask_i.iou(masks[j]) > iou_threshold: + overlaps.append([node_ids[i], node_ids[j]]) + return overlaps + + for overlaps in multiprocessing_apply( + func=_estimate_overlaps, + sequence=self.time_points(), + desc="Setting overlaps", + ): + self.bulk_add_overlaps(overlaps)