ENH: Add boundary detection metrics, fix #744 (#820)
Merged
NickleDave merged 11 commits into main on Mar 13, 2026
Conversation
- WIP: Add vak/metrics/segmentation/
- Fixup translation of vak.metrics.segmentation.ir.validators.is_valid_boundaries_tensor to torch
- Fixup translation of vak.metrics.segmentation.ir.functional.find_hits to torch
- Fix return value names in ir.functional.find_hits docstring: hits_ref -> hits_target, hits_hyp -> hits_preds
- Fix type annotations, docstrings, return types in src/vak/metrics/segmentation/ir/functional.py
- Fix more return type hints + docstrings in src/vak/metrics/segmentation/ir/functional.py
- Flesh out src/vak/metrics/segmentation/ir/ir.py
- Rewrite src/vak/metrics/segmentation/ir/ir.py
- Add ignore_val to BaseSegmentationIRMetric
- Rename -> vak.metrics.boundary_detection
- Fixup boundary_detection/validators.py
- Rewrite boundary_detection/functional.py
- WIP: Add R-value to boundary detection metrics
- WIP: Rewrite classes in metrics.boundary_detection
- fixup typo in src/vak/metrics/boundary_detection/_boundary_detection.py
- Clean up `find_hits` docstring
- Add comment in metrics.boundary_detection.functional
- Rewrite metrics/boundary_detection/validators.py
- Rewrite metrics/boundary_detection/functional.py again
- Rewrite PrecisionRecallFScoreRVal
- Fixup rewrite of find_hits
- Fixup imports in src/vak/metrics/boundary_detection/__init__.py
- Fix how we test pre-conditions in find_hits
- Import boundary_detection module in vak.metrics
- Fixup metrics.boundary_detection.precision_recall_fscore_rval
- Fixup tensor/torch issues in precision_recall_fscore_rval
- Clean up code for handling edge cases in precision_recall_fscore_rval
- Fixup inner loop in precision_recall_fscore_rval
- Set default tolerance for PrecisionRecallFscoreRval to 0.01
- Change default tolerance, revise docstrings in boundary_detection/functional.py
- Revise docstring for precision_recall_fscore_rval
- Add docstring for PrecisionRecallFscoreRval
- Fix whitespace in _boundary_detection.py
- Fixup add metrics.boundary_detection.validators.is_non_negative
- Fixup add metrics.boundary_detection.validators.is_strictly_increasing
- Add metrics.boundary_detection.validators.is_2d_tensor
- Add/fix pre-conditions in precision_recall_fscore_rval
- Fix vocalpy -> vak.metrics in metrics.boundary_detection.validators
- Fixup add boundary_detection.validators
- Fixup add boundary_detection/functional.py
  - Fix pre-conditions for precision_recall_fscore_rval
- WIP: Add tests/test_metrics/test_segmentation/
- Add unit tests for vak.metrics.segmentation.ir.validators.is_1d_tensor
- Rename -> test_metrics/test_segmentation/test_ir/test_functional.py
- WIP: Fix / add more tests to tests/test_metrics/test_segmentation/test_ir/test_validators.py
- Finish adding unit tests in test_ir/test_validators.py
- WIP: Add unit tests in tests/test_metrics/test_segmentation/test_ir/test_functional.py
- Add tests/test_metrics/test_segmentation/test_ir/__init__.py
- Rename -> tests/test_metrics/boundary_detection
- WIP: Rewrite tests in boundary_detection/test_functional.py
- WIP: Rewrite unit tests in tests/test_metrics/boundary_detection/test_functional.py
- WIP: Rewrite unit tests in boundary_detection/test_functional.py
- Rename -> test_metrics/test_boundary_detection
- Remove tests/test_metrics/test_segmentation.py
- WIP: Rewrite unit tests in boundary_detection/test_functional.py
- WIP: Rewrite unit tests in boundary_detection/test_functional.py
- Get find_hits unit test to at least run without fail
- WIP: Fixing up unit tests for precision_recall_fscore_rval
- WIP: Fixing up unit tests for precision_recall_fscore_rval
- WIP: Fixing up unit tests for precision_recall_fscore_rval
- Finish fixing test cases for precision_recall_fscore_rval unit test
- Move test case lists into test_boundary_detection/conftest.py
- Fix up test cases for find hits
- Add unit test for PrecisionRecallFscoreRval
- Fix imports in tests/test_metrics/test_boundary_detection/test_validators.py
- Fixup add unit tests in test_boundary_detection/test_validators.py
  - Fix up / add unit tests for new validators
  - Remove unit tests for removed is_valid_boundaries_tensor
- WIP: test that find_hits raises expected exceptions
- Add unit test that find_hits raises expected exceptions
- Add unit_test: test_precision_recall_fscore_rval_raises
- Fixup add unit tests for boundary_detection.validators
- Fixup add unit tests for metrics.boundary_detection.functional
- Fix up comments explaining test cases
This adds functions used by `metrics.boundary_detection`.
Adds `transforms.frame_labels.functional.to_boundary_times` for use with boundary detection IR metrics. Also adds `DEFAULT_BOUNDARY_TIMES_PADVAL` to `common.constants`.

First finds boundary indices using `torch.diff` on `frame_labels`, and then uses those to index into `frame_times`, now returned by `InferDatapipe`. My first attempt at an implementation instead computed boundary times by getting a frame duration with `timebins.timebin_dur_from_vec`, but it's better to just use the times in `frame_times` directly, instead of adding (more) floating point noise by converting boundary indices to times with a conversion factor. Also this saves us a bit of computation.

Works on 2-D batches. Because the computed vector of boundary times for each item in the batch can (and likely will) have different lengths, we allow a `padval` that defaults to `common.constants.DEFAULT_BOUNDARY_TIMES_PADVAL`. We use this constant so that it can also be the default for the `ignore_val` used by the functions in `metrics.boundary_detection`.
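To make the approach concrete, here is a rough PyTorch sketch of the idea described above. The function names, the `padval` default, and the padding helper are illustrative only, not vak's actual implementation:

```python
import torch

def to_boundary_times(frame_labels, frame_times):
    """Hypothetical sketch: get boundary times from per-frame labels.

    A boundary is any frame whose label differs from the previous frame;
    index 0 is prepended so the first segment's onset counts as a boundary.
    """
    change_inds = torch.nonzero(torch.diff(frame_labels) != 0).squeeze(-1) + 1
    boundary_inds = torch.cat([torch.zeros(1, dtype=torch.long), change_inds])
    # index directly into frame_times instead of multiplying indices by a
    # frame duration, avoiding extra floating point noise
    return frame_times[boundary_inds]

def batch_boundary_times(frame_labels_batch, frame_times_batch, padval=-1.0):
    # per-item boundary vectors have different lengths, so pad with padval
    per_item = [to_boundary_times(lbls, times)
                for lbls, times in zip(frame_labels_batch, frame_times_batch)]
    max_len = max(bt.numel() for bt in per_item)
    return torch.stack([
        torch.nn.functional.pad(bt, (0, max_len - bt.numel()), value=padval)
        for bt in per_item
    ])
```

For example, labels `[0, 0, 1, 1, 0]` with frame times `[0.0, 0.01, 0.02, 0.03, 0.04]` yield boundary times `[0.0, 0.02, 0.04]`.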
Adds boundary detection IR metrics to the TweetyNet model definition. Adds default args that use the default `ignore_val` and a default `tolerance`.
Fix how the method handles `loss` and `metrics`, so that it handles both `None` from the config passed in (because it's a `dict` and we call `get` on it) and the empty dict we get from the defaults for `loss` and `metrics` on the `ModelConfig` class. In practice we usually have the latter, but unit tests were testing for the former. This is a quick and dirty fix to keep stuff from crashing. I'm planning to remove all this machinery anyway, so I'm not going to overthink it.
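A minimal illustration of normalizing the two cases (the helper name is hypothetical, not a function in vak):

```python
def get_subconfig(config: dict, key: str) -> dict:
    # config.get(key) returns None when the key is missing or set to None
    # (what the unit tests passed in); the ModelConfig defaults give an
    # empty dict instead. `or {}` collapses both cases to an empty dict.
    return config.get(key) or {}
```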
Make `InferDatapipe` return `frame_times`, i.e., the vector of time-bin centers from the npz/mat files containing spectrograms. This is hacked in because I'm expecting to replace this class anyway. It might break things for someone fitting a model with audio, but my guess is there are exactly zero users actually doing that right now.
Adds logic to handle boundary detection metrics in `models.FrameClassification.validation_step`.
Need to "cast" the parameter `annotated_files` to a list, so that we don't get an error that `ArrowStringArray` does not have attribute `remove`. Not sure why we didn't see this before, since we were passing in `values` from a `pandas.Series`, which would be a `numpy.ndarray`; I guess maybe that also has a `remove` method? But newer versions of pandas must be returning this Arrow array instead. Casting makes sure we have a list we can `remove` things from, which is congruous with what the type hints say `annotated_files` is anyway.
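A minimal sketch of the cast (`remove_annotated` is a hypothetical helper, not a function in vak):

```python
def remove_annotated(annotated_files, path):
    # cast to list first: numpy arrays and pandas' Arrow-backed string
    # arrays don't provide a .remove method, but list always does
    annotated_files = list(annotated_files)
    annotated_files.remove(path)
    return annotated_files
```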
This fixes #744.
Note I named the module `metrics.boundary_detection` (not `segmentation` as originally planned in that issue). I prefer this more specific name, and it avoids a name clash with other segmentation metrics, like segmental F1 (see #819).
I also chose not to use `torchmetrics`. Because of the way their abstractions work, we would have to pay the cost of repeatedly calling `find_hits` to compute precision, recall, F-score, etc., even though we would get back the exact same hits each time we call it. There might be a clever way around that, e.g., with class hierarchies and/or passing around the `LightningModule`, but I want to avoid being clever and regretting it later.

Instead we call `find_hits` once inside the function `precision_recall_fscore_rval`, and then compute whatever `metrics` a user asks for, any combination of `{"precision", "recall", "fscore", "rval"}`. This gets returned as a `dict` so we can call `self.log_dict` inside a `LightningModule`, logging them all in one fell swoop.

Aside: I am going to stop doing squash merges, based on this comment and this one. I was able to get this PR down from 99 commits to 11 through some (painful) interactive rebasing.
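To make the "call `find_hits` once, derive everything from it" design concrete, here is a minimal pure-Python sketch. The function names mirror the ones discussed above, but the greedy matching, the tolerance default, and the R-value formula (my reading of Räsänen et al. 2009) are a reconstruction, not vak's actual implementation:

```python
import math

def find_hits(boundary_times_preds, boundary_times_targets, tolerance=0.01):
    """Greedily match each predicted boundary to at most one reference
    boundary within `tolerance` seconds; return the number of hits."""
    hits = 0
    unmatched = list(boundary_times_targets)
    for pred in boundary_times_preds:
        for target in unmatched:
            if abs(pred - target) <= tolerance:
                hits += 1
                unmatched.remove(target)
                break
    return hits

def precision_recall_fscore_rval(preds, targets, tolerance=0.01,
                                 metrics=("precision", "recall", "fscore", "rval")):
    # call find_hits exactly once, then derive every requested metric from it
    n_hits = find_hits(preds, targets, tolerance)
    precision = n_hits / len(preds) if preds else 0.0
    recall = n_hits / len(targets) if targets else 0.0
    fscore = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    # R-value combines recall with an over-segmentation measure
    os_ = recall / precision - 1.0 if precision else 0.0
    r1 = math.sqrt((1.0 - recall) ** 2 + os_ ** 2)
    r2 = (-os_ + recall - 1.0) / math.sqrt(2.0)
    rval = 1.0 - (abs(r1) + abs(r2)) / 2.0
    computed = {"precision": precision, "recall": recall,
                "fscore": fscore, "rval": rval}
    # return a dict so a LightningModule can pass it straight to self.log_dict
    return {name: computed[name] for name in metrics}
```

With a perfect segmentation the dict comes back with all four values equal to 1.0, and asking for `metrics=("precision",)` returns only that key.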