From 5733b3866f7b03f6c9896259f26aecc500e45466 Mon Sep 17 00:00:00 2001 From: Thilo von Neumann Date: Thu, 18 Sep 2025 15:09:15 +0200 Subject: [PATCH 1/2] Add "type" identifier to error rate serialization --- meeteval/der/md_eval.py | 28 +++- meeteval/viz/visualize.py | 6 +- meeteval/wer/__main__.py | 8 +- meeteval/wer/wer/cp.py | 2 + meeteval/wer/wer/di_cp.py | 5 + meeteval/wer/wer/error_rate.py | 235 ++++++++++++++++++--------- meeteval/wer/wer/mimo.py | 5 + meeteval/wer/wer/orc.py | 5 + meeteval/wer/wer/time_constrained.py | 28 ---- tests/test_error_rate.py | 127 +++++++++++++++ 10 files changed, 329 insertions(+), 120 deletions(-) create mode 100644 tests/test_error_rate.py diff --git a/meeteval/der/md_eval.py b/meeteval/der/md_eval.py index 26929cd6..6352ecb8 100644 --- a/meeteval/der/md_eval.py +++ b/meeteval/der/md_eval.py @@ -8,7 +8,7 @@ from pathlib import Path import meeteval.io -from meeteval.wer.wer.error_rate import ErrorRate +from meeteval.wer.wer.error_rate import BaseErrorRate def _fix_channel(r): @@ -21,10 +21,12 @@ def _fix_channel(r): @dataclasses.dataclass(frozen=True) -class DiaErrorRate: +class DiaErrorRate(BaseErrorRate): """ """ + identifier = 'diarization-error-rate' + error_rate: 'float | decimal.Decimal' scored_speaker_time: 'float | decimal.Decimal' @@ -36,16 +38,29 @@ class DiaErrorRate: def zero(cls): return cls(0, 0, 0, 0, 0) + @classmethod + def from_dict(cls, d: dict) -> 'Self': + return cls( + d['error_rate'], + d['scored_speaker_time'], + d['missed_speaker_time'], + d['falarm_speaker_time'], + d['speaker_error_time'], + ) + def __post_init__(self): assert self.scored_speaker_time >= 0 assert self.missed_speaker_time >= 0 assert self.falarm_speaker_time >= 0 assert self.speaker_error_time >= 0 errors = self.speaker_error_time + self.falarm_speaker_time + self.missed_speaker_time - error_rate = errors / self.scored_speaker_time + if self.scored_speaker_time > 0: + error_rate = errors / self.scored_speaker_time + else: + error_rate = None if self.error_rate is None: object.__setattr__(self, 'error_rate', error_rate) - else: + elif error_rate is not None: # Since md-eval uses float internally, and the printed numbers are # rounded, it is in corner cases not possible to reproduce the # exact error rate, that is calculated internally by md-eval. @@ -76,6 +91,11 @@ def __add__(self, other: 'DiaErrorRate'): speaker_error_time=self.speaker_error_time + other.speaker_error_time, ) + def asdict(self): + d = dataclasses.asdict(self) + d['type'] = self.identifier + return d + class _FilenameEscaper: """ diff --git a/meeteval/viz/visualize.py b/meeteval/viz/visualize.py index e4ee99f3..9c1385ac 100644 --- a/meeteval/viz/visualize.py +++ b/meeteval/viz/visualize.py @@ -414,7 +414,7 @@ def compress(m): # Add utterances to data. Add total number of words to each utterance data['utterances'] = [{**l, 'total': len(l['words'].split())} for l in u] - data['info']['wer'] = dataclasses.asdict(wer) + data['info']['wer'] = wer.asdict() def wer_by_speaker(speaker): # Get all words from this speaker @@ -434,7 +434,7 @@ def wer_by_speaker(speaker): deletions = len(ref_words.filter( lambda s: not [w for w, _ in s['matches'] if w is not None and words[w]['source'] == 'hypothesis'])) - return dataclasses.asdict(ErrorRate( + return ErrorRate( errors=insertions + deletions + substitutions, length=len(ref_words), insertions=insertions, @@ -442,7 +442,7 @@ def wer_by_speaker(speaker): substitutions=substitutions, reference_self_overlap=None, hypothesis_self_overlap=None, - )) + ).asdict() data['info']['wer_by_speakers'] = { speaker: wer_by_speaker(speaker) diff --git a/meeteval/wer/__main__.py b/meeteval/wer/__main__.py index e65dcd98..beb50303 100644 --- a/meeteval/wer/__main__.py +++ b/meeteval/wer/__main__.py @@ -119,14 +119,14 @@ def to_str(example_id): # Save details _dump({ - to_str(example_id): dataclasses.asdict(error_rate) + to_str(example_id): error_rate.asdict() for example_id, error_rate in per_reco.items() }, per_reco_out.format(parent=parent, stem=stem)) # Compute and save average average = combine_error_rates(*per_reco.values()) _dump( - dataclasses.asdict(average), + average.asdict(), average_out.format(parent=parent, stem=stem), ) if hasattr(average, 'scored_speaker_time'): @@ -455,10 +455,10 @@ def _merge( if average: er = meeteval.wer.combine_error_rates(*[er for _, er in ers]) - out_data = dataclasses.asdict(er) + out_data = er.asdict() else: out_data = { - k: dataclasses.asdict(er) + k: er.asdict() for k, er in ers } assert len(out_data) == len(ers), (len(out_data), len(ers), 'Duplicate filenames') diff --git a/meeteval/wer/wer/cp.py b/meeteval/wer/wer/cp.py index a74a7399..83dfec03 100644 --- a/meeteval/wer/wer/cp.py +++ b/meeteval/wer/wer/cp.py @@ -36,6 +36,8 @@ class CPErrorRate(ErrorRate): >>> combine_error_rates(CPErrorRate(0, 10, 0, 0, 0, None, None, 1, 0, 3), CPErrorRate(5, 10, 0, 0, 5, None, None, 0, 1, 3)) CPErrorRate(error_rate=0.25, errors=5, length=20, insertions=0, deletions=0, substitutions=5, missed_speaker=1, falarm_speaker=1, scored_speaker=6) """ + identifier = 'cp-error-rate' + missed_speaker: int falarm_speaker: int scored_speaker: int diff --git a/meeteval/wer/wer/di_cp.py b/meeteval/wer/wer/di_cp.py index dbc2605a..131c5242 100644 --- a/meeteval/wer/wer/di_cp.py +++ b/meeteval/wer/wer/di_cp.py @@ -19,8 +19,13 @@ @dataclasses.dataclass(frozen=True) class DICPErrorRate(ErrorRate): + identifier = 'di-cp-error-rate' assignment: Tuple[int, ...] + @classmethod + def zero(cls): + return DICPErrorRate(0, 0, 0, 0, 0, None, None, ()) + def apply_assignment(self, reference, hypothesis): return apply_dicp_assignment(self.assignment, reference, hypothesis) diff --git a/meeteval/wer/wer/error_rate.py b/meeteval/wer/wer/error_rate.py index 5c0df849..aae67915 100644 --- a/meeteval/wer/wer/error_rate.py +++ b/meeteval/wer/wer/error_rate.py @@ -1,9 +1,10 @@ import dataclasses -__all__ = ['ErrorRate', 'combine_error_rates'] +__all__ = ['ErrorRate', 'combine_error_rates', 'SelfOverlap'] from typing import Optional, Any import logging +import abc logger = logging.getLogger('error_rate') @@ -66,8 +67,38 @@ def warn(self, name): ) +class BaseErrorRate(abc.ABC): + @abc.abstractclassmethod + def zero(cls): + raise NotImplementedError() + + @abc.abstractclassmethod + def asdict(self): + raise NotImplementedError() + + @abc.abstractmethod + def __add__(self, other: 'ErrorRate') -> 'ErrorRate': + raise NotImplementedError() + + @abc.abstractmethod + def asdict(self) -> dict: + """ + Returns a dictionary representation. Used for dumping into json or + yaml files. + """ + raise NotImplementedError() + + @abc.abstractclassmethod + def from_dict(self, d: dict): + """ + Constructs an error rate object from a dict. Used to load data from + json or yaml files. + """ + raise NotImplementedError() + + @dataclasses.dataclass(frozen=True, repr=False) -class ErrorRate: +class ErrorRate(BaseErrorRate): """ This class represents an error rate. It bundles statistics over the errors and makes sure that no wrong arithmetic operations can be performed on @@ -76,6 +107,8 @@ class ErrorRate: This class is frozen because an error rate should not change after it has been computed. """ + identifier = 'error-rate' + error_rate: float = dataclasses.field(init=False) errors: int @@ -102,13 +135,16 @@ def __post_init__(self): if self.length < 0: raise ValueError() + if self.errors == 0: + error_rate = 0 + elif self.length == 0: + error_rate = None + else: + error_rate = self.errors / self.length # We have to use object.__setattr__ in frozen dataclass. # The alternative would be a property named `error_rate` and a custom # repr - object.__setattr__( - self, 'error_rate', - self.errors / self.length if self.length > 0 else None - ) + object.__setattr__(self, 'error_rate', error_rate) assert self.length == 0 or self.error_rate >= 0 errors = self.insertions + self.deletions + self.substitutions if errors != self.errors: @@ -125,6 +161,11 @@ def __add__(self, other: 'ErrorRate') -> 'ErrorRate': """Combines two error rates""" if not isinstance(other, ErrorRate): return NotImplemented + + # Only allow add between the same type of error or with the base + if not isinstance(other, self.__class__) and type(other) is not ErrorRate: + return NotImplemented + # Return the base class here. Meta information can become # meaningless and should be handled in subclasses return ErrorRate( @@ -138,94 +179,73 @@ def __add__(self, other: 'ErrorRate') -> 'ErrorRate': ) @classmethod - def from_dict(self, d: dict): + def _from_dict(cls, d: dict): + """ + Instantiates `cls` from `d`. Keys in `d` must match the required data + for `cls`. + + Used by `from_dict` after identifying and checking the arguments. + """ + def _get_self_overlap(so): + if so is None: + return None + return SelfOverlap.from_dict(so) + + d = { + 'insertions': None, + 'deletions': None, + 'substitutions': None, + **d, + 'reference_self_overlap': _get_self_overlap(d.get('reference_self_overlap')), + 'hypothesis_self_overlap': _get_self_overlap(d.get('hypothesis_self_overlap')) + } + d.pop('error_rate', None) # Is computed from errors and length + d.pop('type', None) # Is needed for automatic identification + return cls(**d) + + @classmethod + def from_dict(cls, d: dict): """ - >>> ErrorRate.from_dict(dataclasses.asdict(ErrorRate(1, 1, 0, 0, 1, None, None))) + >>> ErrorRate.from_dict(ErrorRate(1, 1, 0, 0, 1, None, None).asdict()) ErrorRate(error_rate=1.0, errors=1, length=1, insertions=0, deletions=0, substitutions=1) >>> from meeteval.wer.wer.cp import CPErrorRate - >>> ErrorRate.from_dict(dataclasses.asdict(CPErrorRate(1, 1, 0, 0, 1, None, None, 1, 1, 1))) + >>> ErrorRate.from_dict(CPErrorRate(1, 1, 0, 0, 1, None, None, 1, 1, 1).asdict()) CPErrorRate(error_rate=1.0, errors=1, length=1, insertions=0, deletions=0, substitutions=1, missed_speaker=1, falarm_speaker=1, scored_speaker=1) >>> from meeteval.wer.wer.orc import OrcErrorRate - >>> ErrorRate.from_dict(dataclasses.asdict(OrcErrorRate(1, 1, 0, 0, 1, None, None, (0, 1)))) + >>> ErrorRate.from_dict(OrcErrorRate(1, 1, 0, 0, 1, None, None, (0, 1)).asdict()) OrcErrorRate(error_rate=1.0, errors=1, length=1, insertions=0, deletions=0, substitutions=1, assignment=(0, 1)) >>> from meeteval.wer.wer.mimo import MimoErrorRate - >>> ErrorRate.from_dict(dataclasses.asdict(MimoErrorRate(1, 1, 0, 0, 1, None, None, [(0, 1)]))) + >>> ErrorRate.from_dict(MimoErrorRate(1, 1, 0, 0, 1, None, None, [(0, 1)]).asdict()) MimoErrorRate(error_rate=1.0, errors=1, length=1, insertions=0, deletions=0, substitutions=1, assignment=[(0, 1)]) - >>> ErrorRate.from_dict(dataclasses.asdict(ErrorRate(1, 1, 0, 0, 1, SelfOverlap(10, 100), SelfOverlap(0, 90)))) + >>> ErrorRate.from_dict(ErrorRate(1, 1, 0, 0, 1, SelfOverlap(10, 100), SelfOverlap(0, 90)).asdict()) ErrorRate(error_rate=1.0, errors=1, length=1, insertions=0, deletions=0, substitutions=1, reference_self_overlap=SelfOverlap(overlap_rate=0.1, overlap_time=10, total_time=100), hypothesis_self_overlap=SelfOverlap(overlap_rate=0.0, overlap_time=0, total_time=90)) """ - # For backward compatibility, set default values. - d.setdefault('insertions', None) - d.setdefault('deletions', None) - d.setdefault('substitutions', None) - d.setdefault('reference_self_overlap', None) - d.setdefault('hypothesis_self_overlap', None) - - def _get_self_overlap(so): - if so is None: - return None - return SelfOverlap.from_dict(so) - - if d.keys() == { - 'errors', 'length', 'error_rate', - 'insertions', 'deletions', 'substitutions', - 'reference_self_overlap', 'hypothesis_self_overlap' - }: - return ErrorRate( - errors=d['errors'], - length=d['length'], - insertions=d['insertions'], - deletions=d['deletions'], - substitutions=d['substitutions'], - reference_self_overlap=_get_self_overlap(d['reference_self_overlap']), - hypothesis_self_overlap=_get_self_overlap(d['hypothesis_self_overlap']), - ) + if cls is ErrorRate: + type_ = _guess_type(d) - if d.keys() == { - 'errors', 'length', 'error_rate', - 'insertions', 'deletions', 'substitutions', - 'missed_speaker', 'falarm_speaker', 'scored_speaker', - 'assignment', - 'reference_self_overlap', 'hypothesis_self_overlap' - }: from meeteval.wer.wer.cp import CPErrorRate - return CPErrorRate( - errors=d['errors'], length=d['length'], - insertions=d['insertions'], - deletions=d['deletions'], - substitutions=d['substitutions'], - missed_speaker=d['missed_speaker'], - falarm_speaker=d['falarm_speaker'], - scored_speaker=d['scored_speaker'], - assignment=d['assignment'], - reference_self_overlap=_get_self_overlap(d['reference_self_overlap']), - hypothesis_self_overlap=_get_self_overlap(d['hypothesis_self_overlap']), - ) - - if d.keys() == { - 'errors', 'length', 'error_rate', - 'insertions', 'deletions', 'substitutions', - 'assignment', - 'reference_self_overlap', 'hypothesis_self_overlap' - }: - if isinstance(d['assignment'][0], (tuple, list)): - from meeteval.wer.wer.mimo import MimoErrorRate - XErrorRate = MimoErrorRate + from meeteval.wer.wer.orc import OrcErrorRate + from meeteval.wer.wer.mimo import MimoErrorRate + from meeteval.wer.wer.di_cp import DICPErrorRate + from meeteval.der.md_eval import DiaErrorRate + + if type_ == cls.identifier: + return ErrorRate._from_dict(d) + elif type_ == CPErrorRate.identifier: + return CPErrorRate._from_dict(d) + elif type_ == OrcErrorRate.identifier: + return OrcErrorRate._from_dict(d) + elif type_ == MimoErrorRate.identifier: + return MimoErrorRate._from_dict(d) + elif type_ == DICPErrorRate.identifier: + return DICPErrorRate._from_dict(d) + elif type_ == DiaErrorRate.identifier: + return DiaErrorRate.from_dict(d) else: - from meeteval.wer.wer.orc import OrcErrorRate - XErrorRate = OrcErrorRate - - return XErrorRate( - errors=d['errors'], - length=d['length'], - insertions=d['insertions'], - deletions=d['deletions'], - substitutions=d['substitutions'], - assignment=d['assignment'], - reference_self_overlap=_get_self_overlap(d['reference_self_overlap']), - hypothesis_self_overlap=_get_self_overlap(d['hypothesis_self_overlap']), - ) - raise ValueError(d.keys(), d) + raise AssertionError(f'Error while detecting ErrorRate type.') + else: + return cls._from_dict(d) + def __repr__(self): return ( @@ -237,6 +257,11 @@ def __repr__(self): ]) + ')' ) + def asdict(self): + d = dataclasses.asdict(self) + d['type'] = self.identifier + return d + def combine_error_rates(*error_rates: ErrorRate) -> ErrorRate: """ @@ -293,3 +318,51 @@ def __repr__(self): if getattr(self, f.name) is not None ]) + ')' ) + +def _guess_type(error_rate_dict: dict) -> str: + """ + Guess the error rate type from an error rate dict. + + Mainly for backwards compatibility for files that do not have the 'type' + field set. + """ + if 'type' in error_rate_dict: + return error_rate_dict['type'] + + # Add keys with defaults for backwards compatibility + keys = set(error_rate_dict.keys()) | { + 'insertions', + 'deletions', + 'substitutions', + 'reference_self_overlap', + 'hypothesis_self_overlap', + } + + # Every error rate must have these keys + required_keys = { + 'errors', 'length', 'error_rate', + 'insertions', 'deletions', 'substitutions', + 'reference_self_overlap', 'hypothesis_self_overlap' + } + + if keys == required_keys: + return ErrorRate.identifier + + if keys == required_keys | { + 'missed_speaker', 'falarm_speaker', 'scored_speaker', 'assignment' + }: + from meeteval.wer.wer.cp import CPErrorRate + return CPErrorRate.identifier + + if keys == required_keys | {'assignment'}: + if isinstance(error_rate_dict['assignment'][0], (tuple, list)): + from meeteval.wer.wer.mimo import MimoErrorRate + return MimoErrorRate.identifier + else: + from meeteval.wer.wer.orc import OrcErrorRate + return OrcErrorRate.identifier + + raise ValueError( + f'Cannot identify error rate type from dict: {keys}', + error_rate_dict, + ) diff --git a/meeteval/wer/wer/mimo.py b/meeteval/wer/wer/mimo.py index 7eb0ce20..007d8023 100644 --- a/meeteval/wer/wer/mimo.py +++ b/meeteval/wer/wer/mimo.py @@ -24,8 +24,13 @@ class MimoErrorRate(ErrorRate): >>> MimoErrorRate(0, 10, 0, 0, 0, None, None, [(0, 0)]) + MimoErrorRate(10, 10, 0, 0, 10, None, None, [(0, 0)]) ErrorRate(error_rate=0.5, errors=10, length=20, insertions=0, deletions=0, substitutions=10) """ + identifier = 'mimo-error-rate' assignment: 'tuple[int, ...]' + @classmethod + def zero(cls): + return MimoErrorRate(0, 0, 0, 0, 0, None, None, ()) + def apply_assignment(self, reference, hypothesis): return apply_mimo_assignment( self.assignment, diff --git a/meeteval/wer/wer/orc.py b/meeteval/wer/wer/orc.py index 779073bb..caf66d35 100644 --- a/meeteval/wer/wer/orc.py +++ b/meeteval/wer/wer/orc.py @@ -33,8 +33,13 @@ class OrcErrorRate(ErrorRate): >>> OrcErrorRate(0, 10, 0, 0, 0, None, None, (0, 1)) + OrcErrorRate(10, 10, 0, 0, 10, None, None, (1, 0, 1)) ErrorRate(error_rate=0.5, errors=10, length=20, insertions=0, deletions=0, substitutions=10) """ + identifier = 'orc-error-rate' assignment: 'tuple[int, ...]' + @classmethod + def zero(cls): + return OrcErrorRate(0, 0, 0, 0, 0, None, None, ()) + def apply_assignment(self, reference, hypothesis): """ >>> OrcErrorRate(0, 10, 0, 0, 0, None, None, (0, 1)).apply_assignment(['a', 'b'], ['a', 'b']) diff --git a/meeteval/wer/wer/time_constrained.py b/meeteval/wer/wer/time_constrained.py index 84c3d32a..0a384722 100644 --- a/meeteval/wer/wer/time_constrained.py +++ b/meeteval/wer/wer/time_constrained.py @@ -898,34 +898,6 @@ def align( segment_index='word' if style == 'index' else False, remove_empty_segments=True, ) - # reference = sort_and_validate( - # reference, - # reference_sort, - # reference_pseudo_word_level_timing, - # 'reference' - # ) - # hypothesis = sort_and_validate( - # hypothesis, - # hypothesis_sort, - # hypothesis_pseudo_word_level_timing, - # 'hypothesis' - # ) - - # Add index for tracking across filtering operations. This is only required - # for the index style since all other styles can be constructed from seglst - # without the index. Especially for `style = 'seglst'` we want to keep - # identity - # if style == 'index': - # reference = SegLST( - # [{**s, '__align_index': i} for i, s in enumerate(reference)] - # ) - # hypothesis = SegLST( - # [{**s, '__align_index': i} for i, s in enumerate(hypothesis)] - # ) - - # Ignore empty segments - # reference = reference.filter(lambda s: s['words']) - # hypothesis = hypothesis.filter(lambda s: s['words']) hypothesis_ = apply_collar(hypothesis, collar=collar) diff --git a/tests/test_error_rate.py b/tests/test_error_rate.py new file mode 100644 index 00000000..29e590e0 --- /dev/null +++ b/tests/test_error_rate.py @@ -0,0 +1,127 @@ +import pytest + +from meeteval.wer.wer.error_rate import ErrorRate +from meeteval.wer.wer.cp import CPErrorRate +from meeteval.wer.wer.orc import OrcErrorRate +from meeteval.wer.wer.mimo import MimoErrorRate +from meeteval.wer.wer.di_cp import DICPErrorRate +from meeteval.der.md_eval import DiaErrorRate +from hypothesis import given, strategies as st + + +all_word_error_rates = [ + ErrorRate, + CPErrorRate, + OrcErrorRate, + MimoErrorRate, + DICPErrorRate, +] + +all_error_rates = all_word_error_rates + [DiaErrorRate] + +@st.composite +def random_error_rate(draw, error_rate_cls): + + if isinstance(error_rate_cls, list): + error_rate_cls = draw(st.sampled_from(error_rate_cls)) + + if error_rate_cls is DiaErrorRate: + # Extreme float values do not appear in practice and cause numerical + # issues in the serialization + floats = st.floats( + min_value=0, + max_value=1e5, + allow_infinity=False, + allow_nan=False, + allow_subnormal=False + ) + return DiaErrorRate( + None, + scored_speaker_time=draw(floats), + missed_speaker_time=draw(floats), + falarm_speaker_time=draw(floats), + speaker_error_time=draw(floats), + ) + else: + integers = st.integers(min_value=0) + insertions = draw(integers) + deletions = draw(integers) + substitutions = draw(integers) + length = draw(st.integers(min_value=substitutions + deletions)) + + if error_rate_cls in [ + OrcErrorRate, + MimoErrorRate, + DICPErrorRate, + ]: + assignment = () + return error_rate_cls( + errors=insertions + deletions + substitutions, + length=length, + insertions=insertions, + deletions=deletions, + substitutions=substitutions, + reference_self_overlap=None, + hypothesis_self_overlap=None, + assignment=assignment, + ) + elif error_rate_cls is CPErrorRate: + return CPErrorRate( + errors=insertions + deletions + substitutions, + length=length, + insertions=insertions, + deletions=deletions, + substitutions=substitutions, + reference_self_overlap=None, + hypothesis_self_overlap=None, + assignment={}, + falarm_speaker=draw(integers), + missed_speaker=draw(integers), + scored_speaker=draw(integers), + ) + else: + return ErrorRate( + errors=insertions + deletions + substitutions, + length=length, + insertions=insertions, + deletions=deletions, + substitutions=substitutions, + reference_self_overlap=None, + hypothesis_self_overlap=None, + ) + + +@st.composite +def list_of_error_rates(draw, error_rate_cls=all_error_rates): + error_rate_cls = draw(st.sampled_from(error_rate_cls)) + error_rate_list = draw(st.lists(random_error_rate(error_rate_cls), min_size=1)) + return error_rate_list + +@given(list_of_error_rates(all_word_error_rates)) +def test_sum_error_rates(list_of_error_rates): + """Test that the sum of error rates works correctly""" + total_errors = sum([er.errors for er in list_of_error_rates]) + total_insertions = sum([er.insertions for er in list_of_error_rates]) + total_deletions = sum([er.deletions for er in list_of_error_rates]) + total_substitutions = sum([er.substitutions for er in list_of_error_rates]) + summed = sum(list_of_error_rates) + assert isinstance(summed, ErrorRate) + assert summed.errors == total_errors + assert summed.insertions == total_insertions + assert summed.deletions == total_deletions + assert summed.substitutions == total_substitutions + +@pytest.mark.parametrize('cls', all_error_rates) +def test_zero(cls: ErrorRate): + """Test that the zero function returns the right type and an error_rate of 0""" + er = cls.zero() + assert isinstance(er, cls) + assert er.error_rate == 0 + +@given(random_error_rate(all_error_rates)) +def test_serialize(error_rate): + serialized = error_rate.asdict() + assert isinstance(serialized, dict) + assert 'type' in serialized + reconstructed = ErrorRate.from_dict(serialized) + assert reconstructed == error_rate From 323ac47cd53e8cabaabf6a89ac9f40ddf976ad12 Mon Sep 17 00:00:00 2001 From: Thilo von Neumann Date: Fri, 19 Sep 2025 10:03:30 +0200 Subject: [PATCH 2/2] Fix flake8 and zero error rate --- meeteval/der/md_eval.py | 8 ++++---- meeteval/wer/wer/error_rate.py | 4 +--- tests/test_error_rate.py | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/meeteval/der/md_eval.py b/meeteval/der/md_eval.py index 6352ecb8..f80d5212 100644 --- a/meeteval/der/md_eval.py +++ b/meeteval/der/md_eval.py @@ -8,7 +8,7 @@ from pathlib import Path import meeteval.io -from meeteval.wer.wer.error_rate import BaseErrorRate +from meeteval.wer.wer.error_rate import BaseErrorRate, ErrorRate def _fix_channel(r): @@ -39,7 +39,7 @@ def zero(cls): return cls(0, 0, 0, 0, 0) @classmethod - def from_dict(cls, d: dict) -> 'Self': + def from_dict(cls, d: dict) -> 'DiaErrorRate': return cls( d['error_rate'], d['scored_speaker_time'], @@ -73,13 +73,13 @@ def __post_init__(self): # Hence, we allow a small difference. assert abs(self.error_rate - error_rate) < 0.00007, (error_rate, self) - def __radd__(self, other: 'int') -> 'ErrorRate': + def __radd__(self, other: 'int') -> 'DiaErrorRate': if isinstance(other, int) and other == 0: # Special case to support sum. return self return NotImplemented - def __add__(self, other: 'DiaErrorRate'): + def __add__(self, other: 'DiaErrorRate') -> 'DiaErrorRate': if not isinstance(other, self.__class__): raise ValueError() diff --git a/meeteval/wer/wer/error_rate.py b/meeteval/wer/wer/error_rate.py index aae67915..b111f1c9 100644 --- a/meeteval/wer/wer/error_rate.py +++ b/meeteval/wer/wer/error_rate.py @@ -135,9 +135,7 @@ def __post_init__(self): if self.length < 0: raise ValueError() - if self.errors == 0: - error_rate = 0 - elif self.length == 0: + if self.length == 0: error_rate = None else: error_rate = self.errors / self.length diff --git a/tests/test_error_rate.py b/tests/test_error_rate.py index 29e590e0..edc53fc7 100644 --- a/tests/test_error_rate.py +++ b/tests/test_error_rate.py @@ -116,7 +116,7 @@ def test_zero(cls: ErrorRate): """Test that the zero function returns the right type and an error_rate of 0""" er = cls.zero() assert isinstance(er, cls) - assert er.error_rate == 0 + assert er.error_rate is None or er.error_rate == 0 @given(random_error_rate(all_error_rates)) def test_serialize(error_rate):