1 change: 1 addition & 0 deletions src/tracksdata/constants.py
@@ -55,6 +55,7 @@ class DefaultAttrKeys:
T = "t"
MASK = "mask"
BBOX = "bbox"
CENTROID = ("z", "y", "x")
SOLUTION = "solution"
TRACKLET_ID = "tracklet_id"

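For reference, a short sketch of how the new CENTROID key is meant to be consumed: it is a tuple of column names rather than a single key, so it can be unpacked into an attribute query and the resulting columns stacked into an array. The load_centroids helper and the graph argument below are hypothetical; only DEFAULT_ATTR_KEYS, node_attrs, and the select(...).to_numpy() pattern appear elsewhere in this diff.

import numpy as np

from tracksdata.constants import DEFAULT_ATTR_KEYS


def load_centroids(graph) -> np.ndarray:
    # hypothetical helper: fetch the (z, y, x) centroid columns as an (n_nodes, 3) array
    cols = list(DEFAULT_ATTR_KEYS.CENTROID)
    nodes_df = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.T, *cols])
    return nodes_df.select(cols).to_numpy()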
15 changes: 15 additions & 0 deletions src/tracksdata/graph/_base_graph.py
@@ -882,6 +882,9 @@ def match(
matched_node_id_key: str = DEFAULT_ATTR_KEYS.MATCHED_NODE_ID,
match_score_key: str = DEFAULT_ATTR_KEYS.MATCH_SCORE,
matched_edge_mask_key: str = DEFAULT_ATTR_KEYS.MATCHED_EDGE_MASK,
matching_mode: str = "mask",
scale: tuple[float, ...] | None = None,
max_distance: float | None = None,
) -> None:
"""
Match the nodes of the graph to the nodes of another graph.
@@ -896,6 +899,15 @@ def match(
            The key of the output attribute storing the match score between matched nodes.
        matched_edge_mask_key : str
            The key of the output boolean attribute indicating whether a corresponding edge exists in the other graph.
matching_mode : str
The mode to use for matching. Either "mask" (default) for mask-based matching
or "distance" for centroid distance-based matching.
scale : tuple[float, ...] | None
Physical scale for each spatial dimension (e.g., (z, y, x)) to account for anisotropy.
Only used when matching_mode is "distance". If None, assumes isotropic data.
max_distance : float | None
Maximum distance between centroids to be considered as a match.
Only used when matching_mode is "distance". If None, all pairs are considered.
"""
from tracksdata.metrics._ctc_metrics import _matching_data

@@ -905,6 +917,9 @@
input_graph_key=DEFAULT_ATTR_KEYS.NODE_ID,
reference_graph_key=DEFAULT_ATTR_KEYS.NODE_ID,
optimal_matching=True,
matching_mode=matching_mode,
scale=scale,
max_distance=max_distance,
)

if matched_node_id_key not in self.node_attr_keys():
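Below is a sketch of how the new keyword arguments might be used from BaseGraph.match, assuming, as the docstring suggests, that the reference graph is passed as the first argument; pred_graph and gt_graph are hypothetical graph instances and the numeric values are illustrative only.

# Distance-based matching of a predicted graph against a ground-truth graph.
pred_graph.match(
    gt_graph,
    matching_mode="distance",  # centroid distance instead of mask overlap
    scale=(2.0, 0.5, 0.5),     # physical (z, y, x) voxel size for anisotropic data
    max_distance=10.0,         # reject candidate pairs farther than 10 physical units apart
)
# Matched node ids and scores are written back as node attributes under the
# DEFAULT_ATTR_KEYS.MATCHED_NODE_ID and DEFAULT_ATTR_KEYS.MATCH_SCORE keys.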
122 changes: 97 additions & 25 deletions src/tracksdata/metrics/_ctc_metrics.py
@@ -38,6 +38,10 @@ def _match_single_frame(
input_graph_key: str,
optimal_matching: bool = False,
min_reference_intersection: float = 0.5,
matching_mode: str = "mask",
max_distance: float | None = None,
centroid_keys: tuple[str, ...] = ("y", "x"),
scale: tuple[float, ...] | None = None,
) -> tuple[list[int], list[int], list[float]]:
"""
Match the groups of the reference and input graphs for a single time point.
@@ -56,11 +60,22 @@
Whether to solve the optimal matching.
min_reference_intersection : float, optional
The minimum intersection between the reference and input masks to be considered as a match.
Only used when matching_mode is "mask".
matching_mode : str, optional
The mode to use for matching. Either "mask" (default) for mask-based matching
or "distance" for centroid distance-based matching.
max_distance : float | None, optional
The maximum distance between centroids to be considered as a match.
Only used when matching_mode is "distance". If None, all pairs are considered.
    centroid_keys : tuple[str, ...], optional
        The keys used to read the centroid coordinates from the dataframe.
        Defaults to ("y", "x"); `_matching_data` passes DEFAULT_ATTR_KEYS.CENTROID, i.e. ("z", "y", "x"), for 3D data.
    scale : tuple[float, ...] | None, optional
        Physical scale for each spatial dimension to account for anisotropy.
        Only used when matching_mode is "distance". If None, assumes isotropic data.

Returns
-------
tuple[list[int], list[int], list[float]]
        For each match, its reference id, input id and score.
        Score is IoU for mask matching or 1/(1+distance) for distance matching.
"""
try:
ref_group = groups_by_time["ref"][t]
@@ -69,35 +84,62 @@
except KeyError:
return [], [], []

    _mapped_ref = []
    _mapped_comp = []
    _ious = []
    _rows = []
    _cols = []

    if matching_mode == "distance":
        # Distance-based matching using centroid coordinates
        ref_centroids = ref_group.select(list(centroid_keys)).to_numpy()
        comp_centroids = comp_group.select(list(centroid_keys)).to_numpy()

        # Apply scale for anisotropic data
        if scale is not None:
            scale_arr = np.array(scale)
            ref_centroids = ref_centroids * scale_arr
            comp_centroids = comp_centroids * scale_arr

        for i, (ref_id, ref_centroid) in enumerate(
            zip(ref_group[reference_graph_key], ref_centroids, strict=True)
        ):
            for j, (comp_id, comp_centroid) in enumerate(
                zip(comp_group[input_graph_key], comp_centroids, strict=True)
            ):
                dist = np.linalg.norm(ref_centroid - comp_centroid)
                if max_distance is None or dist <= max_distance:
                    _mapped_ref.append(ref_id)
                    _mapped_comp.append(comp_id)
                    _rows.append(i)
                    _cols.append(j)
                    # Use 1/(1+dist) so larger = better for bipartite matching
                    _ious.append(1.0 / (1.0 + dist))
    else:
        # Mask-based matching (default)
        if ref_group[DEFAULT_ATTR_KEYS.MASK].dtype == pl.Binary:
            ref_group = column_from_bytes(ref_group, DEFAULT_ATTR_KEYS.MASK)
            comp_group = column_from_bytes(comp_group, DEFAULT_ATTR_KEYS.MASK)

        for i, (ref_id, ref_mask) in enumerate(
            zip(ref_group[reference_graph_key], ref_group[DEFAULT_ATTR_KEYS.MASK], strict=True)
        ):
            for j, (comp_id, comp_mask) in enumerate(
                zip(comp_group[input_graph_key], comp_group[DEFAULT_ATTR_KEYS.MASK], strict=True)
            ):
                # intersection over reference is used to select the matches
                inter = ref_mask.intersection(comp_mask)
                ctc_score = inter / ref_mask.size
                if ctc_score > min_reference_intersection:
                    _mapped_ref.append(ref_id)
                    _mapped_comp.append(comp_id)
                    _rows.append(i)
                    _cols.append(j)

                    # NOTE: there was something weird with IoU, the length when compared with `ctc_metrics`
                    # sometimes it had an extra element
                    iou = inter / (ref_mask.size + comp_mask.size - inter)
                    _ious.append(iou.item())

if optimal_matching and len(_rows) > 0:
LOG.info("Solving optimal matching ...")
@@ -127,6 +169,10 @@ def _matching_data(
reference_graph_key: str,
optimal_matching: bool = False,
min_reference_intersection: float = 0.5,
matching_mode: str = "mask",
max_distance: float | None = None,
centroid_keys: tuple[str, ...] | None = None,
scale: tuple[float, ...] | None = None,
) -> dict[str, list[list]]:
"""
Compute matching data for CTC metrics.
@@ -148,6 +194,16 @@
Whether to solve the optimal matching.
min_reference_intersection : float, optional
Minimum coverage of a reference mask to be considered as a possible match.
Only used when matching_mode is "mask".
matching_mode : str, optional
The mode to use for matching. Either "mask" (default) for mask-based matching
or "distance" for centroid distance-based matching.
max_distance : float | None, optional
The maximum distance between centroids to be considered as a match.
Only used when matching_mode is "distance". If None, all pairs are considered.
    centroid_keys : tuple[str, ...] | None, optional
        The keys to obtain the centroid coordinates from the node attributes.
        If None, defaults to DEFAULT_ATTR_KEYS.CENTROID (("z", "y", "x")).
    scale : tuple[float, ...] | None, optional
        Physical scale for each spatial dimension to account for anisotropy.
        Only used when matching_mode is "distance". If None, assumes isotropic data.

Returns
-------
@@ -163,13 +219,25 @@

n_workers = get_options().n_workers

# Set default centroid keys
if centroid_keys is None:
centroid_keys = DEFAULT_ATTR_KEYS.CENTROID

# Determine which attributes to fetch based on matching mode
if matching_mode == "distance":
attr_keys = [DEFAULT_ATTR_KEYS.T, None, *centroid_keys] # None placeholder for tracklet_id_key
else:
attr_keys = [DEFAULT_ATTR_KEYS.T, None, DEFAULT_ATTR_KEYS.MASK]

# computing unique labels for each graph
for name, graph, tracklet_id_key in [
("ref", reference_graph, reference_graph_key),
("comp", input_graph, input_graph_key),
]:
        # Replace placeholder with actual tracklet_id_key
        fetch_keys = [tracklet_id_key if k is None else k for k in attr_keys]
        nodes_df = graph.node_attrs(attr_keys=fetch_keys)
        if n_workers > 1 and matching_mode == "mask":
# required by multiprocessing
nodes_df = column_to_bytes(nodes_df, DEFAULT_ATTR_KEYS.MASK)
labels = {}
@@ -206,6 +274,10 @@
input_graph_key=input_graph_key,
optimal_matching=optimal_matching,
min_reference_intersection=min_reference_intersection,
matching_mode=matching_mode,
max_distance=max_distance,
centroid_keys=centroid_keys,
scale=scale,
)

for _mapped_ref, _mapped_comp, _ious in multiprocessing_apply(
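To make the new scoring and selection logic concrete, here is a self-contained sketch of distance-based matching for a single frame. It mirrors the scaling, the max_distance cut-off and the 1/(1+distance) score used in `_match_single_frame`, but builds a dense score matrix and uses scipy.optimize.linear_sum_assignment for the optimal one-to-one step, since the PR's own optimal_matching branch is not visible in this excerpt; treat it as an approximation rather than the library's implementation.

import numpy as np
from scipy.optimize import linear_sum_assignment


def match_by_distance(
    ref_centroids: np.ndarray,   # (n_ref, d) reference centroids
    comp_centroids: np.ndarray,  # (n_comp, d) computed/input centroids
    scale: tuple[float, ...] | None = None,
    max_distance: float | None = None,
) -> list[tuple[int, int, float]]:
    """Return (ref_index, comp_index, score) triplets with score = 1 / (1 + distance)."""
    if scale is not None:
        # rescale voxel coordinates to physical units for anisotropic data
        ref_centroids = ref_centroids * np.asarray(scale)
        comp_centroids = comp_centroids * np.asarray(scale)

    # pairwise Euclidean distances, shape (n_ref, n_comp)
    dists = np.linalg.norm(ref_centroids[:, None, :] - comp_centroids[None, :, :], axis=-1)
    scores = 1.0 / (1.0 + dists)

    # forbid pairs beyond the cut-off by zeroing their score
    if max_distance is not None:
        scores[dists > max_distance] = 0.0

    # optimal one-to-one assignment maximizing the total score
    rows, cols = linear_sum_assignment(scores, maximize=True)
    return [(int(i), int(j), float(scores[i, j])) for i, j in zip(rows, cols) if scores[i, j] > 0.0]


# toy example: two reference and two computed points in (z, y, x)
ref = np.array([[1.0, 10.0, 10.0], [2.0, 40.0, 40.0]])
comp = np.array([[1.0, 11.0, 10.0], [2.0, 80.0, 80.0]])
print(match_by_distance(ref, comp, scale=(2.0, 0.5, 0.5), max_distance=5.0))
# -> [(0, 0, 0.666...)]: only the first pair lies within 5 physical units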