def set_overlaps(self, iou_threshold: float = 0.0) -> None:
    """
    Find overlapping nodes within each frame and add their overlap
    relations to the graph.

    Parameters
    ----------
    iou_threshold : float
        Two nodes of the same time point are recorded as overlapping when
        the IoU of their masks is strictly greater than this value.
        With the default of 0, every pair with a non-zero IoU is recorded.

    Raises
    ------
    ValueError
        If `iou_threshold` is not within [0.0, 1.0] (NaN included).

    Examples
    --------
    ```python
    graph.set_overlaps(iou_threshold=0.5)
    ```
    """
    # Negated chained comparison so that NaN is rejected as well; the
    # plain `x < 0 or x > 1` form silently accepts it.
    if not (0.0 <= iou_threshold <= 1.0):
        raise ValueError("iou_threshold must be between 0.0 and 1.0")

    from itertools import combinations

    t_key = DEFAULT_ATTR_KEYS.T
    mask_key = DEFAULT_ATTR_KEYS.MASK

    def _estimate_overlaps(t: int) -> list[list[int]]:
        """Return the `[node_id, node_id]` pairs overlapping at time `t`."""
        node_ids = self.filter_nodes_by_attrs(NodeAttr(t_key) == t)
        masks = self.node_attrs(node_ids=node_ids, attr_keys=[mask_key])[mask_key]
        # NOTE(review): relies on `node_attrs` returning rows in the order of
        # `node_ids`. Pairing once up front avoids indexing the column on
        # every inner-loop iteration.
        labelled = list(zip(node_ids, masks, strict=True))
        # `combinations` yields (i, j) with i < j in the same order as the
        # equivalent nested index loops.
        return [
            [id_a, id_b]
            for (id_a, mask_a), (id_b, mask_b) in combinations(labelled, 2)
            if mask_a.iou(mask_b) > iou_threshold
        ]

    for overlaps in multiprocessing_apply(
        func=_estimate_overlaps,
        sequence=self.time_points(),
        desc="Setting overlaps",
    ):
        # Frames without any overlapping pair have nothing to store.
        if overlaps:
            self.bulk_add_overlaps(overlaps)
@njit
def bbox_interpolation_offset(
    tgt_bbox: np.ndarray,
    src_bbox: np.ndarray,
    w: float,
) -> np.ndarray:
    """
    Interpolate the bounding box between two masks.
    The reference is the target mask and w is relative distance to it:
    ```python
    (target_value - new_value) / (target_value - source_value) = w
    ```
    Hence `w = 0` keeps the target position and `w = 1` moves the box
    onto the source position.

    Parameters
    ----------
    tgt_bbox : np.ndarray
        The target bounding box, laid out as `ndim` start coordinates
        followed by `ndim` end coordinates.
    src_bbox : np.ndarray
        The source bounding box, same layout as `tgt_bbox`.
    w : float
        The weight of the interpolation, between 0 and 1.

    Returns
    -------
    np.ndarray
        The per-axis integer offset to add to the target bounding box.

    Raises
    ------
    ValueError
        If `w` is outside [0, 1].
    """
    if w < 0 or w > 1:
        raise ValueError(f"w = {w} is not between 0 and 1")

    ndim = tgt_bbox.shape[0] // 2
    # Parentheses are required: `//` binds tighter than `+`/`-`, and the
    # center is the midpoint of the start and end corners.
    tgt_center = (tgt_bbox[:ndim] + tgt_bbox[ndim:]) // 2
    src_center = (src_bbox[:ndim] + src_bbox[ndim:]) // 2
    signed_dist = tgt_center - src_center
    # From the contract above: new - target = -w * (target - source).
    offset = -np.round(w * signed_dist).astype(np.int32)

    for i in range(ndim):
        # The bounding box carries no information about the image extent,
        # so the only border that can be enforced here is the origin:
        # shrink a shift that would push the start corner below zero.
        new_start = tgt_bbox[i] + offset[i]
        if new_start < 0:
            offset[i] -= new_start

    return offset
+ """ + return (self._bbox[: self._mask.ndim] + self._bbox[self._mask.ndim :]) // 2 + def mask_indices( self, offset: NDArray[np.integer] | int = 0, diff --git a/src/tracksdata/nodes/_node_interpolator.py b/src/tracksdata/nodes/_node_interpolator.py new file mode 100644 index 00000000..6961de7f --- /dev/null +++ b/src/tracksdata/nodes/_node_interpolator.py @@ -0,0 +1,331 @@ +import logging +import math +from copy import deepcopy +from typing import Any, Protocol + +from tqdm import tqdm + +from tracksdata.attrs import NodeAttr +from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.edges._generic_edges import GenericFuncEdgeAttrs +from tracksdata.graph._base_graph import BaseGraph +from tracksdata.nodes._base_nodes import BaseNodesOperator +from tracksdata.nodes._mask import Mask, bbox_interpolation_offset +from tracksdata.options import get_options +from tracksdata.utils._logging import LOG + + +class NodeInterpolationFunc(Protocol): + def __call__( + self, + src_attrs: dict[str, Any], + tgt_attrs: dict[str, Any], + new_t: int, + delta_t: int, + ) -> dict[str, Any]: ... + + +def _replace_if_exists( + new_attrs: dict[str, Any], + attr_key: str, + value: Any, +) -> None: + """ + Replace the value of an attribute if it exists in the new attributes. + """ + if attr_key in new_attrs: + new_attrs[attr_key] = value + + +def default_node_interpolation( + src_attrs: dict[str, Any], + tgt_attrs: dict[str, Any], + new_t: int, + delta_t: int, +) -> dict[str, Any]: + """ + Default node interpolator. + Interpolates the 'mask', 'z', 'y', 'x' attributes between the source and target nodes. + + Parameters + ---------- + src_attrs : dict[str, Any] + Source node attributes. + tgt_attrs : dict[str, Any] + Target node attributes. + new_t : int + Time point of the new node. + delta_t : int + Current delta time. + + Returns + ------- + dict[str, Any] + Interpolated node. 
+ """ + new_attrs = deepcopy(tgt_attrs) + new_attrs.pop(DEFAULT_ATTR_KEYS.NODE_ID, None) + + new_attrs[DEFAULT_ATTR_KEYS.T] = new_t + + t_tgt = tgt_attrs[DEFAULT_ATTR_KEYS.T] + w = (t_tgt - new_t) / delta_t + + if w < 0 or w > 1: + raise ValueError(f"w = {w} is not between 0 and 1") + + ndim = new_attrs[DEFAULT_ATTR_KEYS.MASK].mask.ndim + new_bbox = new_attrs[DEFAULT_ATTR_KEYS.MASK].bbox + + # updating bounding box + offset = bbox_interpolation_offset( + tgt_bbox=tgt_attrs[DEFAULT_ATTR_KEYS.MASK].bbox, + src_bbox=src_attrs[DEFAULT_ATTR_KEYS.MASK].bbox, + w=w, + ) + + new_bbox[ndim:] = new_bbox[ndim:] + offset + new_bbox[:ndim] = new_bbox[:ndim] + offset + new_attrs[DEFAULT_ATTR_KEYS.MASK].bbox = new_bbox + + for o, attr_key in zip(offset[::-1], ["x", "y", "z"], strict=False): + new_attrs[attr_key] += o + + if tgt_attrs.get(DEFAULT_ATTR_KEYS.TRACK_ID) != src_attrs.get(DEFAULT_ATTR_KEYS.TRACK_ID): + new_attrs[DEFAULT_ATTR_KEYS.TRACK_ID] = -1 + + _replace_if_exists( + new_attrs, + DEFAULT_ATTR_KEYS.SOLUTION, + src_attrs.get(DEFAULT_ATTR_KEYS.SOLUTION, False) and tgt_attrs.get(DEFAULT_ATTR_KEYS.SOLUTION, False), + ) + + _replace_if_exists( + new_attrs, + DEFAULT_ATTR_KEYS.MATCHED_NODE_ID, + -1, + ) + _replace_if_exists( + new_attrs, + DEFAULT_ATTR_KEYS.MATCH_SCORE, + 0, + ) + _replace_if_exists( + new_attrs, + DEFAULT_ATTR_KEYS.MATCHED_EDGE_MASK, + False, + ) + + return new_attrs + + +class NodeInterpolationEdgeAttrs(Protocol): + """ + Function to recompute the edge attributes between newly inserted nodes + during node interpolation. + """ + + def __call__( + self, + long_edge: dict[str, Any], + src_attrs: dict[str, Any], + tgt_attrs: dict[str, Any], + new_attrs: dict[str, Any], + ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: ... 
def default_node_interpolation_edge_attrs(
    long_edge: dict[str, Any],
    src_attrs: dict[str, Any],
    tgt_attrs: dict[str, Any],
    new_attrs: dict[str, Any],
) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
    """
    Default node interpolation edge attributes.
    By default, it computes:

    - `delta_t`: the absolute difference between the time points of the
       two nodes of each new edge
    - `edge_weight`: the Euclidean distance between the two nodes of each
       new edge, using the 'z', 'y', 'x' attributes (missing ones count as 0)
    - `solution`: copied from the long edge onto both new edges, only when
       the long edge carries that attribute

    Parameters
    ----------
    long_edge : dict[str, Any]
        Edge defining the interpolation.
    src_attrs : dict[str, Any]
        Source node attributes.
    tgt_attrs : dict[str, Any]
        Target node attributes.
    new_attrs : dict[str, Any]
        New node attributes.

    Returns
    -------
    tuple[dict[str, Any] | None, dict[str, Any] | None]
        Source to new node edge attributes and new to target edge attributes.
        Return `None` if no edge should be added.
    """
    t_key = DEFAULT_ATTR_KEYS.T

    def _between(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
        """Attributes of the edge connecting nodes `a` and `b`."""
        return {
            "delta_t": abs(a[t_key] - b[t_key]),
            DEFAULT_ATTR_KEYS.EDGE_WEIGHT: math.sqrt(
                sum((a.get(attr_key, 0.0) - b.get(attr_key, 0.0)) ** 2 for attr_key in ["z", "y", "x"])
            ),
        }

    src_to_new = _between(src_attrs, new_attrs)
    new_to_tgt = _between(new_attrs, tgt_attrs)

    # No default here: a missing (or null) solution must stay missing rather
    # than introducing the attribute as False on the new edges.
    solution = long_edge.get(DEFAULT_ATTR_KEYS.SOLUTION)
    if solution is not None:
        src_to_new[DEFAULT_ATTR_KEYS.SOLUTION] = solution
        new_to_tgt[DEFAULT_ATTR_KEYS.SOLUTION] = solution

    return src_to_new, new_to_tgt


class NodeInterpolator(BaseNodesOperator):
    """
    Interpolate nodes between non-consecutive time points (delta_t > 1).

    For every edge spanning more than one time point, nodes are inserted
    one time point at a time, walking from the target back to the source.
    Interpolation of an edge stops as soon as a new node's mask collides
    with a mask already present at that time point.

    Parameters
    ----------
    delta_t_key : str
        Edge attribute holding the time difference of each edge. It is
        computed from the node time attribute when missing from the graph.
    node_interpolation_func : NodeInterpolationFunc
        Function creating the attributes of each inserted node.
    edge_attrs_func : NodeInterpolationEdgeAttrs
        Function creating the attributes of the two edges around each
        inserted node.
    validate_keys : bool
        Forwarded to `graph.add_node` and `graph.add_edge`.
    iou_threshold : float
        A new node whose mask IoU with an existing mask of the same time
        point is greater than this value is considered a collision.
    """

    def __init__(
        self,
        delta_t_key: str = "delta_t",
        node_interpolation_func: NodeInterpolationFunc = default_node_interpolation,
        edge_attrs_func: NodeInterpolationEdgeAttrs = default_node_interpolation_edge_attrs,
        validate_keys: bool = False,
        iou_threshold: float = 0.5,
    ):
        super().__init__()
        self.delta_t_key = delta_t_key
        self.node_interpolation_func = node_interpolation_func
        self.edge_attrs_func = edge_attrs_func
        self.validate_keys = validate_keys
        self.iou_threshold = iou_threshold

    def add_nodes(self, graph: BaseGraph, *, t: None = None) -> None:
        """
        Insert interpolated nodes (and their edges) along every long edge.

        Parameters
        ----------
        graph : BaseGraph
            Graph to interpolate; it is modified in place.
        t : None
            Must be None, interpolation always covers the whole graph.

        Raises
        ------
        ValueError
            If `t` is not None.
        """
        if t is not None:
            raise ValueError("'t' must be None for node interpolation")

        if self.delta_t_key not in graph.edge_attr_keys:
            LOG.warning(
                "The key '%s' is not in graph.edge_attrs (%s). Inserting edge attribute '%s' using '%s' attribute",
                self.delta_t_key,
                graph.edge_attr_keys,
                self.delta_t_key,
                DEFAULT_ATTR_KEYS.T,
            )
            GenericFuncEdgeAttrs(
                func=lambda x, y: abs(x - y),
                attr_keys=DEFAULT_ATTR_KEYS.T,
                output_key=self.delta_t_key,
            ).add_edge_attrs(graph)

        edge_attrs = graph.edge_attrs()
        # Shortest gaps first.
        long_edges = edge_attrs.filter(edge_attrs[self.delta_t_key] > 1).sort(self.delta_t_key)

        if len(long_edges) == 0:
            # Nothing to interpolate; avoids querying nodes with an empty id list.
            LOG.info("Number of nodes added: %s", 0)
            return

        selected_node_ids = set(long_edges[DEFAULT_ATTR_KEYS.EDGE_SOURCE]) | set(
            long_edges[DEFAULT_ATTR_KEYS.EDGE_TARGET]
        )
        nodes = graph.node_attrs(node_ids=list(selected_node_ids))
        nodes_by_id: dict[Any, dict[str, Any]] = {
            node[DEFAULT_ATTR_KEYS.NODE_ID]: node for node in nodes.iter_rows(named=True)
        }

        count = 0
        for long_edge in tqdm(
            long_edges.iter_rows(named=True),
            total=len(long_edges),
            disable=not get_options().show_progress,
            desc="Interpolating and adding nodes",
        ):
            count += self._interpolate_edge(graph, long_edge, nodes_by_id)

        LOG.info("Number of nodes added: %s", count)

    def _has_collision(self, graph: BaseGraph, new_node_attrs: dict[str, Any]) -> bool:
        """Tell whether the new node's mask collides with a mask of its time point."""
        node_ids = graph.filter_nodes_by_attrs(
            NodeAttr(DEFAULT_ATTR_KEYS.T) == new_node_attrs[DEFAULT_ATTR_KEYS.T]
        )
        masks_in_t = graph.node_attrs(
            node_ids=node_ids,
            attr_keys=[DEFAULT_ATTR_KEYS.MASK],
        )[DEFAULT_ATTR_KEYS.MASK]
        new_mask: Mask = new_node_attrs[DEFAULT_ATTR_KEYS.MASK]
        return any(new_mask.iou(mask) > self.iou_threshold for mask in masks_in_t)

    def _interpolate_edge(
        self,
        graph: BaseGraph,
        long_edge: dict[str, Any],
        nodes_by_id: dict[Any, dict[str, Any]],
    ) -> int:
        """Fill a single long edge with nodes and return how many were added."""
        delta_t = long_edge[self.delta_t_key]

        src_id = long_edge[DEFAULT_ATTR_KEYS.EDGE_SOURCE]
        tgt_id = long_edge[DEFAULT_ATTR_KEYS.EDGE_TARGET]

        src_attrs = nodes_by_id[src_id]
        tgt_attrs = nodes_by_id[tgt_id]

        n_added = 0

        while delta_t > 1:
            new_node_attrs = self.node_interpolation_func(
                src_attrs=src_attrs,
                tgt_attrs=tgt_attrs,
                new_t=tgt_attrs[DEFAULT_ATTR_KEYS.T] - 1,
                delta_t=delta_t,
            )

            if self._has_collision(graph, new_node_attrs):
                break

            new_node_id = graph.add_node(new_node_attrs, validate_keys=self.validate_keys)
            n_added += 1
            nodes_by_id[new_node_id] = new_node_attrs

            src_to_new_edge_attrs, new_to_tgt_edge_attrs = self.edge_attrs_func(
                long_edge=long_edge,
                src_attrs=src_attrs,
                tgt_attrs=tgt_attrs,
                new_attrs=new_node_attrs,
            )

            src_to_new_edge_id = None
            if src_to_new_edge_attrs is not None:
                # NOTE(review): assumes `add_edge` returns the id of the new
                # edge; it is needed to demote this edge on the next step.
                src_to_new_edge_id = graph.add_edge(
                    source_id=src_id,
                    target_id=new_node_id,
                    attrs=src_to_new_edge_attrs,
                    validate_keys=self.validate_keys,
                )

            if new_to_tgt_edge_attrs is not None:
                graph.add_edge(
                    source_id=new_node_id,
                    target_id=tgt_id,
                    attrs=new_to_tgt_edge_attrs,
                    validate_keys=self.validate_keys,
                )

            # The long edge is now represented by the two shorter edges, so
            # it must leave the solution.
            long_edge_id = long_edge.get(DEFAULT_ATTR_KEYS.EDGE_ID)
            if long_edge.get(DEFAULT_ATTR_KEYS.SOLUTION, False) and long_edge_id is not None:
                graph.update_edge_attrs(
                    attrs={DEFAULT_ATTR_KEYS.SOLUTION: False},
                    edge_ids=[long_edge_id],
                )

            # `.get` because the solution attribute is optional.
            LOG.info("s -> t (before): %s", long_edge.get(DEFAULT_ATTR_KEYS.SOLUTION))
            LOG.info("s -> n: %s", src_to_new_edge_attrs)
            LOG.info("n -> t: %s", new_to_tgt_edge_attrs)

            if src_to_new_edge_attrs is None:
                # Without a source -> new edge there is no remaining long
                # edge to subdivide.
                break

            delta_t -= 1
            tgt_attrs = new_node_attrs
            tgt_id = new_node_id
            # replacing long_edge by new shorter but still long edge, keeping
            # its id so that it can be updated on the next iteration
            long_edge = {
                **src_to_new_edge_attrs,
                DEFAULT_ATTR_KEYS.EDGE_ID: src_to_new_edge_id,
            }

        return n_added