diff --git a/src/tracksdata/constants.py b/src/tracksdata/constants.py
index 762fa38b..53a3bbac 100644
--- a/src/tracksdata/constants.py
+++ b/src/tracksdata/constants.py
@@ -55,6 +55,7 @@ class DefaultAttrKeys:
     T = "t"
     MASK = "mask"
     BBOX = "bbox"
+    CENTROID = ("z", "y", "x")
     SOLUTION = "solution"
     TRACKLET_ID = "tracklet_id"
diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py
index 31f2bec7..70e3efac 100644
--- a/src/tracksdata/graph/_base_graph.py
+++ b/src/tracksdata/graph/_base_graph.py
@@ -882,6 +882,9 @@ def match(
         matched_node_id_key: str = DEFAULT_ATTR_KEYS.MATCHED_NODE_ID,
         match_score_key: str = DEFAULT_ATTR_KEYS.MATCH_SCORE,
         matched_edge_mask_key: str = DEFAULT_ATTR_KEYS.MATCHED_EDGE_MASK,
+        matching_mode: str = "mask",
+        scale: tuple[float, ...] | None = None,
+        max_distance: float | None = None,
     ) -> None:
         """
         Match the nodes of the graph to the nodes of another graph.
@@ -896,6 +899,15 @@
         matched_node_id_key : str
             The key of the output value of the matched node
         match_score_key : str
             The key of the output value of the match score between matched nodes
         matched_edge_mask_key : str
             The key of the output as a boolean value indicating if a corresponding edge exists in the other graph.
+        matching_mode : str
+            The mode to use for matching. Either "mask" (default) for mask-based matching
+            or "distance" for centroid distance-based matching.
+        scale : tuple[float, ...] | None
+            Physical scale for each spatial dimension (e.g., (z, y, x)) to account for anisotropy.
+            Only used when matching_mode is "distance". If None, assumes isotropic data.
+        max_distance : float | None
+            Maximum distance between centroids to be considered as a match.
+            Only used when matching_mode is "distance". If None, all pairs are considered.
         """
         from tracksdata.metrics._ctc_metrics import _matching_data
@@ -905,6 +917,9 @@
             input_graph_key=DEFAULT_ATTR_KEYS.NODE_ID,
             reference_graph_key=DEFAULT_ATTR_KEYS.NODE_ID,
             optimal_matching=True,
+            matching_mode=matching_mode,
+            scale=scale,
+            max_distance=max_distance,
         )
 
         if matched_node_id_key not in self.node_attr_keys():
diff --git a/src/tracksdata/metrics/_ctc_metrics.py b/src/tracksdata/metrics/_ctc_metrics.py
index 8e951bf3..f727056d 100644
--- a/src/tracksdata/metrics/_ctc_metrics.py
+++ b/src/tracksdata/metrics/_ctc_metrics.py
@@ -38,6 +38,10 @@ def _match_single_frame(
     input_graph_key: str,
     optimal_matching: bool = False,
     min_reference_intersection: float = 0.5,
+    matching_mode: str = "mask",
+    max_distance: float | None = None,
+    centroid_keys: tuple[str, ...] = ("y", "x"),
+    scale: tuple[float, ...] | None = None,
 ) -> tuple[list[int], list[int], list[float]]:
     """
     Match the groups of the reference and input graphs for a single time point.
@@ -56,11 +60,22 @@
         Whether to solve the optimal matching.
     min_reference_intersection : float, optional
         The minimum intersection between the reference and input masks to be considered as a match.
+        Only used when matching_mode is "mask".
+    matching_mode : str, optional
+        The mode to use for matching. Either "mask" (default) for mask-based matching
+        or "distance" for centroid distance-based matching.
+    max_distance : float | None, optional
+        The maximum distance between centroids to be considered as a match.
+        Only used when matching_mode is "distance". If None, all pairs are considered.
+    centroid_keys : tuple[str, ...], optional
+        The keys to obtain the centroid coordinates from the dataframe.
+        Default is ("y", "x"); use ("z", "y", "x") for 3D data.
 
     Returns
     -------
     tuple[list[int], list[int], list[float]]
-        For each matching, their respective reference id, input id and IoU.
+        For each matching, their respective reference id, input id and score.
+        Score is IoU for mask matching or 1/(1+distance) for distance matching.
     """
@@ -69,35 +84,62 @@
     try:
         ref_group = groups_by_time["ref"][t]
         comp_group = groups_by_time["comp"][t]
     except KeyError:
         return [], [], []
 
-    if ref_group[DEFAULT_ATTR_KEYS.MASK].dtype == pl.Binary:
-        ref_group = column_from_bytes(ref_group, DEFAULT_ATTR_KEYS.MASK)
-        comp_group = column_from_bytes(comp_group, DEFAULT_ATTR_KEYS.MASK)
-
     _mapped_ref = []
     _mapped_comp = []
     _ious = []
     _rows = []
     _cols = []
 
-    for i, (ref_id, ref_mask) in enumerate(
-        zip(ref_group[reference_graph_key], ref_group[DEFAULT_ATTR_KEYS.MASK], strict=True)
-    ):
-        for j, (comp_id, comp_mask) in enumerate(
-            zip(comp_group[input_graph_key], comp_group[DEFAULT_ATTR_KEYS.MASK], strict=True)
+    if matching_mode == "distance":
+        # Distance-based matching using centroid coordinates
+        ref_centroids = ref_group.select(list(centroid_keys)).to_numpy()
+        comp_centroids = comp_group.select(list(centroid_keys)).to_numpy()
+
+        # Apply scale for anisotropic data
+        if scale is not None:
+            scale_arr = np.array(scale)
+            ref_centroids = ref_centroids * scale_arr
+            comp_centroids = comp_centroids * scale_arr
+
+        for i, (ref_id, ref_centroid) in enumerate(
+            zip(ref_group[reference_graph_key], ref_centroids, strict=True)
         ):
-            # intersection over reference is used to select the matches
-            inter = ref_mask.intersection(comp_mask)
-            ctc_score = inter / ref_mask.size
-            if ctc_score > min_reference_intersection:
-                _mapped_ref.append(ref_id)
-                _mapped_comp.append(comp_id)
-                _rows.append(i)
-                _cols.append(j)
-
-                # NOTE: there was something weird with IoU, the length when compared with `ctc_metrics`
-                # sometimes it had an extra element
-                iou = inter / (ref_mask.size + comp_mask.size - inter)
-                _ious.append(iou.item())
+            for j, (comp_id, comp_centroid) in enumerate(
+                zip(comp_group[input_graph_key], comp_centroids, strict=True)
+            ):
+                dist = np.linalg.norm(ref_centroid - comp_centroid)
+                if max_distance is None or dist <= max_distance:
+                    _mapped_ref.append(ref_id)
+                    _mapped_comp.append(comp_id)
+                    _rows.append(i)
+                    _cols.append(j)
+                    # Use 1/(1+dist) so larger = better for bipartite matching
+                    _ious.append(1.0 / (1.0 + dist))
+    else:
+        # Mask-based matching (default)
+        if ref_group[DEFAULT_ATTR_KEYS.MASK].dtype == pl.Binary:
+            ref_group = column_from_bytes(ref_group, DEFAULT_ATTR_KEYS.MASK)
+            comp_group = column_from_bytes(comp_group, DEFAULT_ATTR_KEYS.MASK)
+
+        for i, (ref_id, ref_mask) in enumerate(
+            zip(ref_group[reference_graph_key], ref_group[DEFAULT_ATTR_KEYS.MASK], strict=True)
+        ):
+            for j, (comp_id, comp_mask) in enumerate(
+                zip(comp_group[input_graph_key], comp_group[DEFAULT_ATTR_KEYS.MASK], strict=True)
+            ):
+                # intersection over reference is used to select the matches
+                inter = ref_mask.intersection(comp_mask)
+                ctc_score = inter / ref_mask.size
+                if ctc_score > min_reference_intersection:
+                    _mapped_ref.append(ref_id)
+                    _mapped_comp.append(comp_id)
+                    _rows.append(i)
+                    _cols.append(j)
+
+                    # NOTE: there was something weird with IoU, the length when compared with `ctc_metrics`
+                    # sometimes it had an extra element
+                    iou = inter / (ref_mask.size + comp_mask.size - inter)
+                    _ious.append(iou.item())
 
     if optimal_matching and len(_rows) > 0:
         LOG.info("Solving optimal matching ...")
@@ -127,6 +169,10 @@ def _matching_data(
     reference_graph_key: str,
     optimal_matching: bool = False,
     min_reference_intersection: float = 0.5,
+    matching_mode: str = "mask",
+    max_distance: float | None = None,
+    centroid_keys: tuple[str, ...] | None = None,
+    scale: tuple[float, ...] | None = None,
 ) -> dict[str, list[list]]:
     """
     Compute matching data for CTC metrics.
@@ -148,6 +194,16 @@
         Whether to solve the optimal matching.
     min_reference_intersection : float, optional
         Minimum coverage of a reference mask to be considered as a possible match.
+        Only used when matching_mode is "mask".
+    matching_mode : str, optional
+        The mode to use for matching. Either "mask" (default) for mask-based matching
+        or "distance" for centroid distance-based matching.
+    max_distance : float | None, optional
+        The maximum distance between centroids to be considered as a match.
+        Only used when matching_mode is "distance". If None, all pairs are considered.
+    centroid_keys : tuple[str, ...] | None, optional
+        The keys to obtain the centroid coordinates from the node attributes.
+        If None, defaults to DEFAULT_ATTR_KEYS.CENTROID (("z", "y", "x")).
 
     Returns
     -------
@@ -163,13 +219,25 @@
     n_workers = get_options().n_workers
 
+    # Set default centroid keys
+    if centroid_keys is None:
+        centroid_keys = DEFAULT_ATTR_KEYS.CENTROID
+
+    # Determine which attributes to fetch based on matching mode
+    if matching_mode == "distance":
+        attr_keys = [DEFAULT_ATTR_KEYS.T, None, *centroid_keys]  # None placeholder for tracklet_id_key
+    else:
+        attr_keys = [DEFAULT_ATTR_KEYS.T, None, DEFAULT_ATTR_KEYS.MASK]
+
     # computing unique labels for each graph
     for name, graph, tracklet_id_key in [
         ("ref", reference_graph, reference_graph_key),
         ("comp", input_graph, input_graph_key),
     ]:
-        nodes_df = graph.node_attrs(attr_keys=[DEFAULT_ATTR_KEYS.T, tracklet_id_key, DEFAULT_ATTR_KEYS.MASK])
-        if n_workers > 1:
+        # Replace placeholder with actual tracklet_id_key
+        fetch_keys = [tracklet_id_key if k is None else k for k in attr_keys]
+        nodes_df = graph.node_attrs(attr_keys=fetch_keys)
+        if n_workers > 1 and matching_mode == "mask":
             # required by multiprocessing
             nodes_df = column_to_bytes(nodes_df, DEFAULT_ATTR_KEYS.MASK)
         labels = {}
@@ -206,6 +274,10 @@
                 input_graph_key=input_graph_key,
                 optimal_matching=optimal_matching,
                 min_reference_intersection=min_reference_intersection,
+                matching_mode=matching_mode,
+                max_distance=max_distance,
+                centroid_keys=centroid_keys,
+                scale=scale,
             )
             for _mapped_ref, _mapped_comp, _ious in multiprocessing_apply(
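
Usage sketch (not part of the patch): the call below is illustrative and assumes two existing tracksdata graphs, pred_graph and gt_graph, whose nodes carry "t" and the ("z", "y", "x") centroid attributes, and that the reference graph is passed as the first positional argument of match(); only the matching_mode, scale and max_distance keywords come from this diff.

    # Hypothetical example: pred_graph / gt_graph are assumed, pre-built graphs.
    pred_graph.match(
        gt_graph,                  # reference graph (assumed positional argument)
        matching_mode="distance",  # centroid-distance matching instead of mask IoU
        scale=(2.0, 0.5, 0.5),     # anisotropic (z, y, x) voxel size
        max_distance=10.0,         # reject pairs farther apart than 10 physical units
    )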
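
For review purposes, a self-contained sketch of the scoring used by the new "distance" mode: centroids are scaled, pairwise Euclidean distances are computed, pairs beyond max_distance are discarded, and each remaining pair is scored as 1 / (1 + distance) so that larger means closer. The one-to-one assignment below uses scipy's linear_sum_assignment purely for illustration; the optimal-matching step inside _match_single_frame may be implemented differently.

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # Toy (z, y, x) centroids for one time point; the values are made up.
    ref = np.array([[0.0, 10.0, 10.0], [1.0, 40.0, 40.0]])
    comp = np.array([[0.0, 11.0, 9.0], [1.0, 80.0, 80.0]])
    scale = np.array([2.0, 0.5, 0.5])  # anisotropic voxel size
    max_distance = 10.0

    # Scaled pairwise Euclidean distances, gated by max_distance.
    dists = np.linalg.norm(ref[:, None, :] * scale - comp[None, :, :] * scale, axis=-1)
    scores = np.where(dists <= max_distance, 1.0 / (1.0 + dists), 0.0)

    # One-to-one assignment maximizing the total score (larger = closer).
    rows, cols = linear_sum_assignment(scores, maximize=True)
    matches = [(i, j, scores[i, j]) for i, j in zip(rows, cols, strict=True) if scores[i, j] > 0]
    print(matches)  # -> [(0, 0, ~0.59)]; the second pair is farther than max_distance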