Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/tracksdata/edges/_distance_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def _init_edge_attrs(self, graph: BaseGraph) -> None:
"""
Initialize the edge attributes for the graph.
"""
graph.add_edge_attr_key(self.output_key, default_value=-99999.0)
if self.output_key not in graph.edge_attr_keys:
graph.add_edge_attr_key(self.output_key, default_value=-99999.0)

def _add_edges_per_time(
self,
Expand Down
3 changes: 2 additions & 1 deletion src/tracksdata/edges/_generic_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def _init_edge_attrs(self, graph: BaseGraph) -> None:
"""
Initialize the edge attributes for the graph.
"""
graph.add_edge_attr_key(self.output_key, default_value=-99999.0)
if self.output_key not in graph.edge_attr_keys:
graph.add_edge_attr_key(self.output_key, default_value=-99999.0)

def _edge_attrs_per_time(
self,
Expand Down
40 changes: 39 additions & 1 deletion src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import polars as pl
from numpy.typing import ArrayLike

from tracksdata.attrs import AttrComparison
from tracksdata.attrs import AttrComparison, NodeAttr
from tracksdata.constants import DEFAULT_ATTR_KEYS
from tracksdata.utils._logging import LOG
from tracksdata.utils._multiprocessing import multiprocessing_apply

if TYPE_CHECKING:
from tracksdata.graph._graph_view import GraphView
Expand Down Expand Up @@ -792,3 +793,40 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
graph.bulk_add_overlaps(overlaps.tolist())

return graph

def set_overlaps(self, iou_threshold: float = 0.0) -> None:
"""
Find overlapping nodes within each frame and add them their overlap relation into the graph.

Parameters
----------
iou_threshold : float
Nodes with an IoU greater than this threshold are considered overlapping.
If 0, all nodes are considered overlapping.

Examples
--------
```python
graph.set_overlaps(iou_threshold=0.5)
```
"""
if iou_threshold < 0.0 or iou_threshold > 1.0:
raise ValueError("iou_threshold must be between 0.0 and 1.0")

def _estimate_overlaps(t: int) -> list[list[int, 2]]:
node_ids = self.filter_nodes_by_attrs(NodeAttr(DEFAULT_ATTR_KEYS.T) == t)
masks = self.node_attrs(node_ids=node_ids, attr_keys=[DEFAULT_ATTR_KEYS.MASK])[DEFAULT_ATTR_KEYS.MASK]
overlaps = []
for i in range(len(masks)):
mask_i = masks[i]
for j in range(i + 1, len(masks)):
if mask_i.iou(masks[j]) > iou_threshold:
overlaps.append([node_ids[i], node_ids[j]])
return overlaps

for overlaps in multiprocessing_apply(
func=_estimate_overlaps,
sequence=self.time_points(),
desc="Setting overlaps",
):
self.bulk_add_overlaps(overlaps)
3 changes: 2 additions & 1 deletion src/tracksdata/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from tracksdata.nodes._generic_nodes import GenericFuncNodeAttrs
from tracksdata.nodes._mask import Mask
from tracksdata.nodes._node_interpolator import NodeInterpolator
from tracksdata.nodes._random import RandomNodes
from tracksdata.nodes._regionprops import RegionPropsNodes

__all__ = ["GenericFuncNodeAttrs", "Mask", "RandomNodes", "RegionPropsNodes"]
__all__ = ["GenericFuncNodeAttrs", "Mask", "NodeInterpolator", "RandomNodes", "RegionPropsNodes"]
3 changes: 2 additions & 1 deletion src/tracksdata/nodes/_generic_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def _init_node_attrs(self, graph: BaseGraph) -> None:
"""
Initialize the node attributes for the graph.
"""
graph.add_node_attr_key(self.output_key, default_value=self.default_value)
if self.output_key not in graph.node_attr_keys:
graph.add_node_attr_key(self.output_key, default_value=self.default_value)

def add_node_attrs(
self,
Expand Down
58 changes: 57 additions & 1 deletion src/tracksdata/nodes/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,61 @@
import blosc2
import numpy as np
import skimage.morphology as morph
from numba import njit
from numpy.typing import ArrayLike, NDArray

from tracksdata.functional._iou import fast_intersection_with_bbox, fast_iou_with_bbox


@njit
def bbox_interpolation_offset(
tgt_bbox: np.ndarray,
src_bbox: np.ndarray,
w: float,
) -> np.ndarray:
"""
Interpolate the bounding box between two masks.
The reference is the target mask and w is relative distance to it:
```python
(target_value - new_value) / (target_value - source_value) = w
```

Parameters
----------
tgt_bbox : np.ndarray
The target bounding box.
src_bbox : np.ndarray
The source bounding box.
w : float
The weight of the interpolation.

Returns
-------
np.ndarray
The offset to add to the bounding box.
"""
if w < 0 or w > 1:
raise ValueError(f"w = {w} is not between 0 and 1")

ndim = tgt_bbox.shape[0] // 2
tgt_center = tgt_bbox[ndim:] - tgt_bbox[:ndim] // 2
src_center = src_bbox[ndim:] - src_bbox[:ndim] // 2
signed_dist = tgt_center - src_center
offset = -np.round((1 - w) * signed_dist).astype(np.int32)

for i in range(ndim):
if offset[i] > 0:
new_value = tgt_bbox[ndim + i] - offset[i]
dist_to_border = min(new_value - tgt_bbox[ndim + i], 0)
offset[i] += dist_to_border
else:
new_value = tgt_bbox[i] + offset[i]
dist_to_border = max(tgt_bbox[i] - new_value, 0)
offset[i] += dist_to_border

return offset


@lru_cache(maxsize=5)
def _spherical_mask(
radius: int,
Expand Down Expand Up @@ -109,14 +159,20 @@ def crop(
slicing = tuple(slice(self._bbox[i], self._bbox[i + ndim]) for i in range(ndim))

else:
center = (self._bbox[: self._mask.ndim] + self._bbox[self._mask.ndim :]) // 2
center = self.bbox_center()
half_shape = np.asarray(shape) // 2
start = np.maximum(center - half_shape, 0)
end = np.minimum(center + half_shape, image.shape)
slicing = tuple(slice(s, e) for s, e in zip(start, end, strict=True))

return image[slicing]

def bbox_center(self) -> NDArray[np.integer]:
"""
Get the center of the bounding box.
"""
return (self._bbox[: self._mask.ndim] + self._bbox[self._mask.ndim :]) // 2

def mask_indices(
self,
offset: NDArray[np.integer] | int = 0,
Expand Down
Loading