
ENH: Add boundary detection metrics, fix #744 (#820)

Merged

NickleDave merged 11 commits into main from
ENH-add-boundary-detection-metrics-fix-#744
on Mar 13, 2026

Conversation

@NickleDave
Collaborator

This fixes #744.

Note I named the module metrics.boundary_detection (not segmentation as originally planned in that issue).

I prefer this more specific name, and it avoids a name clash with other segmentation metrics, like segmental F1 (see #819).

I also chose not to use torchmetrics. Because of the way its abstractions work, we would have to pay the cost of calling find_hits repeatedly to compute precision, recall, F-score, etc., even though we would get back exactly the same hits each time we call it. There might be a clever way around that, e.g., with class hierarchies and/or passing around the LightningModule, but I want to avoid being clever and regretting it later.

Instead we call find_hits once inside the function precision_recall_fscore_rval, and then compute whatever metrics a user asks for, any combination of {"precision", "recall", "fscore", "rval"}. This gets returned as a dict so we can call self.log_dict inside a LightningModule, logging them all in one fell swoop.
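A minimal sketch of this "compute hits once, derive everything" pattern, assuming a simplified `find_hits` that matches boundaries within a tolerance. This is not the actual vak implementation: the real functions also handle `ignore_val`, padded batches, and the R-value metric, all omitted here for brevity.

```python
import torch

def find_hits(boundaries_pred, boundaries_target, tolerance=0.01):
    """Return indices of predicted/target boundary pairs within `tolerance`.

    Simplified stand-in for the real function; assumes 1-D tensors of times.
    """
    # pairwise absolute differences between times, shape (n_pred, n_target)
    diffs = torch.abs(boundaries_pred[:, None] - boundaries_target[None, :])
    hits_pred, hits_target = torch.nonzero(diffs <= tolerance, as_tuple=True)
    return hits_pred, hits_target

def precision_recall_fscore_rval(boundaries_pred, boundaries_target,
                                 metrics=("precision", "recall", "fscore"),
                                 tolerance=0.01):
    # call find_hits exactly once, then derive every requested metric from it
    hits_pred, _ = find_hits(boundaries_pred, boundaries_target, tolerance)
    n_hits = float(len(torch.unique(hits_pred)))
    precision = n_hits / len(boundaries_pred)
    recall = n_hits / len(boundaries_target)
    fscore = (2 * precision * recall / (precision + recall)
              if (precision + recall) > 0 else 0.0)
    all_metrics = {"precision": precision, "recall": recall, "fscore": fscore}
    # return only what was asked for, as a dict ready for self.log_dict
    return {name: all_metrics[name] for name in metrics}
```

Inside a LightningModule's `validation_step`, the returned dict can then be passed straight to `self.log_dict(...)`, logging all requested metrics at once.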

Aside: I am going to stop doing squash merges, based on this comment and this one. I was able to get this PR down from 99 commits to 11 through some (painful) interactive rebasing.

@NickleDave added the "ENH: enhancement" (enhancement; new feature or request) and "Metrics" (issue related to metrics) labels on Mar 13, 2026
WIP: Add vak/metrics/segmentation/

Fixup translation of vak.metrics.segmentation.ir.validators.is_valid_boundaries_tensor to torch

Fixup translation of vak.metrics.segmentation.ir.functional.find_hits to torch

Fix return value names in ir.functional.find_hits docstring: hits_ref -> hits_target, hits_hyp -> hits_preds

Fix type annotations, docstrings, return types in src/vak/metrics/segmentation/ir/functional.py

Fix more return type hints + docstrings in src/vak/metrics/segmentation/ir/functional.py

Flesh out src/vak/metrics/segmentation/ir/ir.py

Rewrite src/vak/metrics/segmentation/ir/ir.py

Add ignore_val to BaseSegmentationIRMetric

Rename -> vak.metrics.boundary_detection

Fixup boundary_detection/validators.py

Rewrite boundary_detection/functional.py

WIP: Add R-value to boundary detection metrics

WIP: Rewrite classes in metrics.boundary_detection

fixup typo in src/vak/metrics/boundary_detection/_boundary_detection.py

Clean up `find_hits` docstring

Add comment in metrics.boundary_detection.functional

Rewrite metrics/boundary_detection/validators.py

Rewrite metrics/boundary_detection/functional.py again

Rewrite PrecisionRecallFScoreRVal

Fixup rewrite of find_hits

Fixup imports in src/vak/metrics/boundary_detection/__init__.py

Fix how we test pre-conditions in find_hits

Import boundary_detection module in vak.metrics

Fixup metrics.boundary_detection.precision_recall_fscore_rval

Fixup tensor/torch issues in precision_recall_fscore_rval

Clean up code for handling edge cases in precision_recall_fscore_rval

Fixup inner loop in precision_recall_fscore_rval

Set default tolerance for PrecisionRecallFscoreRval to 0.01

Change default tolerance, revise docstrings in boundary_detection/functional.py

Revise docstring for precision_recall_fscore_rval

Add docstring for PrecisionRecallFscoreRval

Fix whitespace in _boundary_detection.py

Fixup add metrics.boundary_detection.validators.is_non_negative

Fixup add metrics.boundary_detection.validators.is_strictly_increasing

Add metrics.boundary_detection.validators.is_2d_tensor

Add/fix pre-conditions in precision_recall_fscore_rval

Fix vocalpy -> vak.metrics in metrics.boundary_detection.validators

Fixup add boundary_detection.validators

Fixup add boundary_detection/functional.py

- Fix pre-conditions for precision_recall_fscore_rval

WIP: Add tests/test_metrics/test_segmentation/

Add unit tests for vak.metrics.segmentation.ir.validators.is_1d_tensor

Rename -> test_metrics/test_segmentation/test_ir/test_functional.py

WIP: Fix / add more tests to tests/test_metrics/test_segmentation/test_ir/test_validators.py

Finish adding unit tests in test_ir/test_validators.py

WIP: Add unit tests in tests/test_metrics/test_segmentation/test_ir/test_functional.py

Add tests/test_metrics/test_segmentation/test_ir/__init__.py

Rename -> tests/test_metrics/boundary_detection

WIP: Rewrite tests in boundary_detection/test_functional.py

WIP: Rewrite unit tests in tests/test_metrics/boundary_detection/test_functional.py

WIP: Rewrite unit tests in boundary_detection/test_functional.py

Rename -> test_metrics/test_boundary_detection

Remove tests/test_metrics/test_segmentation.py

WIP: Rewrite unit tests in boundary_detection/test_functional.py

WIP: Rewrite unit tests in boundary_detection/test_functional.py

Get find_hits unit test to at least run without fail

WIP: Fixing up unit tests for precision_recall_fscore_rval

WIP: Fixing up unit tests for precision_recall_fscore_rval

WIP: Fixing up unit tests for precision_recall_fscore_rval

Finish fixing test cases for precision_recall_fscore_rval unit test

Move test case lists into test_boundary_detection/conftest.py

Fix up test cases for find hits

Add unit test for PrecisionRecallFscoreRval

Fix imports in tests/test_metrics/test_boundary_detection/test_validators.py

Fixup add unit tests in test_boundary_detection/test_validators.py

- Fix up / add unit tests for new validators
- Remove unit tests for removed is_valid_boundaries_tensor

WIP: test that find_hits raises expected exceptions

Add unit test that find_hits raises expected exceptions

Add unit_test: test_precision_recall_fscore_rval_raises

Fixup add unit tests for boundary_detection.validators

Fixup add unit tests for metrics.boundary_detection.functional

Fix up comments explaining test cases

This adds functions used by `metrics.boundary_detection`.
Adds `transforms.frame_labels.functional.to_boundary_times`
for use with boundary detection IR metrics.
Also adds `DEFAULT_BOUNDARY_TIMES_PADVAL` to common.constants.

First finds boundary indices using `torch.diff` on
`frame_labels`, and then uses those to index into
`frame_times`, now returned by `InferDatapipe`.
My first attempt at an implementation instead
computed boundary times by getting a frame duration
with `timebins.timebin_dur_from_vec`, but it's better to
just use the times in `frame_times` directly,
instead of adding (more) floating-point noise
by converting boundary indices to times
with a conversion factor.
This also saves us a bit of computation.

Works on 2-D batches.
Because the computed vector of boundary times for each item
in the batch can (and likely will) have different lengths,
we allow a `padval` that defaults to
`common.constants.DEFAULT_BOUNDARY_TIMES_PADVAL`.
We use this constant so that it can also be the default for
the `ignore_val` used by the functions in
`metrics.boundary_detection`.
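
The approach described in the two commit messages above can be sketched roughly as follows. Function names, the padding helper, and `padval=-1.0` are illustrative assumptions here, not vak's actual API.

```python
import torch

def to_boundary_times(frame_labels, frame_times):
    """Return the times at which the frame label changes, plus the onset.

    Sketch: find label-change indices with `torch.diff`, then index into
    `frame_times` directly (no frame-duration conversion factor).
    """
    # indices where the label differs from the previous frame
    change_idx = torch.nonzero(torch.diff(frame_labels) != 0).squeeze(1) + 1
    # include the onset of the first segment as a boundary
    boundary_idx = torch.cat([torch.tensor([0]), change_idx])
    return frame_times[boundary_idx]

def batch_boundary_times(frame_labels_batch, frame_times_batch, padval=-1.0):
    """Pad per-item boundary-time vectors so they stack into a 2-D batch."""
    times = [to_boundary_times(labels, t)
             for labels, t in zip(frame_labels_batch, frame_times_batch)]
    max_len = max(t.numel() for t in times)
    return torch.stack([
        torch.nn.functional.pad(t, (0, max_len - t.numel()), value=padval)
        for t in times
    ])
```

The `padval` plays the role of `DEFAULT_BOUNDARY_TIMES_PADVAL`: items in a batch produce different numbers of boundaries, so shorter vectors are padded, and the same value can serve as the `ignore_val` for the metric functions.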

Adds boundary detection IR metrics to the TweetyNet model definition.
Adds default args to use the default `ignore_val`
and a default `tolerance`.
@NickleDave force-pushed the ENH-add-boundary-detection-metrics-fix-#744 branch from e97ef09 to 47014a2 on March 13, 2026 at 15:22

Fix how the method handles `loss` and `metrics`, so
that it handles both `None` from the config passed in
(because it's a `dict` and we call `get` on it),
as well as an empty dict, which we get from
the defaults for `loss` and `metrics` on the ModelConfig class.

In practice we usually have the latter, but unit tests
were testing for the former.
This is a quick and dirty fix to keep stuff from crashing.
I'm planning to remove all this machinery anyway,
so not going to overthink it.
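A minimal sketch of the `None`-vs-empty-dict handling described above; the function and parameter names are illustrative, not vak's actual code.

```python
def resolve_metrics(config: dict, default_metrics: dict) -> dict:
    """Return the metrics mapping to use, falling back to defaults.

    Handles both `config.get("metrics")` returning `None` (key absent)
    and an empty dict (the ModelConfig-style default).
    """
    metrics = config.get("metrics")  # None if the key is missing
    if not metrics:  # falsy covers both None and {}
        return dict(default_metrics)
    return metrics
```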

Make `InferDatapipe` return `frame_times`,
i.e. the vector of time-bin centers from npz/mat
files containing spectrograms.

This is hacked in because I'm expecting
to replace this class anyway.
It might break things for someone fitting a model
with audio, but my guess is there are exactly zero users
actually doing that right now.

Adds logic to handle boundary detection metrics in
`models.FrameClassification.validation_step`.
Need to cast the parameter `annotated_files` to a list,
so that we don't get an error: `ArrowStringArray` does not
have an attribute `remove`. Not sure why we didn't see
this before, since we were passing in `values` from a
pandas.Series, which would be a `numpy.ndarray`;
I guess maybe that also has a `remove` method?
But newer versions of pandas must be returning this Arrow
array instead. Casting makes sure we have a list
we can `remove` things from, which is consistent with
what the type hints say `annotated_files` is anyway.
@NickleDave force-pushed the ENH-add-boundary-detection-metrics-fix-#744 branch from 47014a2 to 51635d5 on March 13, 2026 at 15:30
@NickleDave merged commit db51e28 into main on Mar 13, 2026
0 of 4 checks passed
@NickleDave deleted the ENH-add-boundary-detection-metrics-fix-#744 branch on March 13, 2026 at 20:25

Successfully merging this pull request may close these issues:

ENH: Add boundary detection IR metrics

1 participant