diff --git a/pyproject.toml b/pyproject.toml
index e3493a6..498dea3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,14 +1,12 @@
 [tool.mypy]
-python_version = 3.9
+python_version = "3.9"
 strict_optional = false
 disallow_untyped_defs = true
 disallow_untyped_calls = true
+plugins = ['pydantic.mypy']

 [[tool.mypy.overrides]]
-module = [
-    "setuptools",
-    "click.*",
-]
+module = ["setuptools", "click.*"]
 ignore_missing_imports = true

 [tool.ruff]
diff --git a/seqscore/conll.py b/seqscore/conll.py
index d60bb8b..5f46464 100644
--- a/seqscore/conll.py
+++ b/seqscore/conll.py
@@ -10,7 +10,7 @@
     TextIO,
 )

-from attr import attrib, attrs
+from pydantic import BaseModel
 from tabulate import tabulate

 from seqscore.encoding import Encoding, EncodingError, get_encoding
@@ -43,13 +43,12 @@ class CoNLLFormatError(Exception):
     pass


-@attrs(frozen=True)
-class _CoNLLToken:
-    text: str = attrib()
-    label: str = attrib()
-    is_docstart: bool = attrib()
-    line_num: int = attrib()
-    other_fields: tuple[str, ...] = attrib()
+class _CoNLLToken(BaseModel, frozen=True):
+    text: str
+    label: str
+    is_docstart: bool
+    line_num: int
+    other_fields: tuple[str, ...]

     @classmethod
     def from_line(cls, line: str, line_num: int, source_name: str) -> "_CoNLLToken":
@@ -76,14 +75,19 @@ def from_line(cls, line: str, line_num: int, source_name: str) -> "_CoNLLToken":
         label = splits[-1]
         other_fields = tuple(splits[1:-1])
         is_docstart = text == DOCSTART
-        return cls(text, label, is_docstart, line_num, other_fields)
+        return cls(
+            text=text,
+            label=label,
+            is_docstart=is_docstart,
+            line_num=line_num,
+            other_fields=other_fields,
+        )


-@attrs(frozen=True)
-class CoNLLIngester:
-    encoding: Encoding = attrib()
-    parse_comment_lines: bool = attrib(default=False, kw_only=True)
-    ignore_document_boundaries: bool = attrib(default=True, kw_only=True)
+class CoNLLIngester(BaseModel, arbitrary_types_allowed=True, frozen=True):
+    encoding: Encoding
+    parse_comment_lines: bool = False
+    ignore_document_boundaries: bool = True

     def ingest(
         self,
@@ -183,11 +187,13 @@ def ingest(
                 ) from e

             sequences = LabeledSequence(
-                tokens,
-                labels,
-                mentions,
+                tokens=tokens,
+                labels=labels,
+                mentions=tuple(mentions),
                 other_fields=other_fields,
-                provenance=SequenceProvenance(line_nums[0], source_name),
+                provenance=SequenceProvenance(
+                    starting_line=line_nums[0], source=source_name
+                ),
                 comment=comment,
             )
             document.append(sequences)
@@ -197,7 +203,7 @@ def ingest(
             document_counter += 1
             yield document

-    def validate(
+    def conll_validate(
         self, source: TextIO, source_name: str
     ) -> list[list[SequenceValidationResult]]:
         all_results: list[list[SequenceValidationResult]] = []
@@ -330,7 +336,7 @@ def ingest_conll_file(
     )

     ingester = CoNLLIngester(
-        mention_encoding,
+        encoding=mention_encoding,
         parse_comment_lines=parse_comment_lines,
         ignore_document_boundaries=ignore_document_boundaries,
     )
@@ -349,12 +355,12 @@ def validate_conll_file(
 ) -> ValidationResult:
     encoding = get_encoding(mention_encoding_name)
     ingester = CoNLLIngester(
-        encoding,
+        encoding=encoding,
         parse_comment_lines=parse_comment_lines,
         ignore_document_boundaries=ignore_document_boundaries,
     )
     with open(input_path, encoding=file_encoding) as input_file:
-        results = ingester.validate(input_file, input_path)
+        results = ingester.conll_validate(input_file, input_path)

     n_docs = len(results)
     n_sequences = sum(len(doc_results) for doc_results in results)
@@ -365,7 +371,9 @@ def validate_conll_file(
             result.errors for doc_results in results for result in doc_results
         )
     )
-    return ValidationResult(errors, n_tokens, n_sequences, n_docs)
+    return ValidationResult(
+        errors=errors, n_tokens=n_tokens, n_sequences=n_sequences, n_docs=n_docs
+    )


 def repair_conll_file(
diff --git a/seqscore/encoding.py b/seqscore/encoding.py
index 9efae76..7338805 100644
--- a/seqscore/encoding.py
+++ b/seqscore/encoding.py
@@ -1,13 +1,9 @@
 from abc import abstractmethod
 from collections.abc import Sequence
 from functools import lru_cache
-from typing import (
-    AbstractSet,
-    Optional,
-    Protocol,
-)
+from typing import AbstractSet, Optional, Protocol, runtime_checkable

-from attr import Factory, attrib, attrs
+from pydantic import BaseModel

 from seqscore.model import LabeledSequence, Mention, Span
@@ -19,6 +15,7 @@
 DEFAULT_OUTSIDE = "O"


+@runtime_checkable
 class EncodingDialect(Protocol):
     label_delim: str
     outside: str
@@ -57,6 +54,7 @@ def __init__(self) -> None:
         self.single = "W"


+@runtime_checkable
 class Encoding(Protocol):
     dialect: EncodingDialect

@@ -660,11 +658,10 @@ def get_encoding(name: str) -> Encoding:
     raise ValueError(f"Unknown encoder {repr(name)}")


-@attrs
-class _MentionBuilder:
-    start_idx: Optional[int] = attrib(default=None, init=False)
-    entity_type: Optional[str] = attrib(default=None, init=False)
-    mentions: list[Mention] = attrib(default=Factory(list), init=False)
+class _MentionBuilder(BaseModel):
+    start_idx: Optional[int] = None
+    entity_type: Optional[str] = None
+    mentions: list[Mention] = []

     def start_mention(self, start_idx: int, entity_type: str) -> None:
         # Check arguments
@@ -690,7 +687,9 @@ def end_mention(self, end_idx: int) -> None:
         assert self.start_idx is not None, "No mention start index"
         assert self.entity_type is not None, "No mention entity type"

-        mention = Mention(Span(self.start_idx, end_idx), self.entity_type)
+        mention = Mention(
+            span=Span(start=self.start_idx, end=end_idx), type=self.entity_type
+        )
         self.mentions.append(mention)

         self.start_idx = None
diff --git a/seqscore/model.py b/seqscore/model.py
index 09e67e2..bf6a726 100644
--- a/seqscore/model.py
+++ b/seqscore/model.py
@@ -1,19 +1,15 @@
-from collections.abc import Iterable, Iterator, Sequence
+from collections.abc import Iterable, Sequence
 from itertools import repeat
-from typing import Any, Optional, Union, overload
+from typing import Annotated, Any, Optional, Union, overload

-from attr import Attribute, attrib, attrs
+from pydantic import BaseModel, BeforeValidator, Field

-from seqscore.util import (
-    tuplify_optional_nested_strs,
-    tuplify_strs,
-    validator_nonempty_str,
-)


-def _validator_nonnegative(_inst: Any, _attr: Attribute, value: Any) -> None:
+def _validator_nonnegative(value: Any) -> Any:
     if value < 0:
         raise ValueError(f"Negative value: {repr(value)}")
+    else:
+        return value


 def _tuplify_mentions(
@@ -22,12 +18,11 @@ def _tuplify_mentions(
     return tuple(mentions)


-@attrs(frozen=True, slots=True)
-class Span:
-    start: int = attrib(validator=_validator_nonnegative)
-    end: int = attrib(validator=_validator_nonnegative)
+class Span(BaseModel, frozen=True):
+    start: Annotated[int, BeforeValidator(_validator_nonnegative)]
+    end: Annotated[int, BeforeValidator(_validator_nonnegative)]

-    def __attrs_post_init__(self) -> None:
+    def model_post_init(self, context: Any) -> None:
         if not self.end > self.start:
             raise ValueError(
                 f"End of span ({self.end}) must be greater than start ({self.start})"
@@ -37,38 +32,31 @@ def __len__(self) -> int:
         return self.end - self.start


-@attrs(frozen=True, slots=True)
-class Mention:
-    span: Span = attrib()
-    type: str = attrib(validator=validator_nonempty_str)
+class Mention(BaseModel, frozen=True):
+    span: Span
+    type: str = Field(min_length=1, description="Must be a non-empty string")

     def __len__(self) -> int:
         return len(self.span)

     def with_type(self, new_type: str) -> "Mention":
-        return Mention(self.span, new_type)
+        return Mention(span=self.span, type=new_type)


-@attrs(frozen=True, slots=True)
-class SequenceProvenance:
-    starting_line: int = attrib()
-    source: Optional[str] = attrib()
+class SequenceProvenance(BaseModel, frozen=True):
+    starting_line: int
+    source: Optional[str]


-@attrs(frozen=True, slots=True)
-class LabeledSequence(Sequence[str]):
-    tokens: tuple[str, ...] = attrib(converter=tuplify_strs)
-    labels: tuple[str, ...] = attrib(converter=tuplify_strs)
-    mentions: tuple[Mention, ...] = attrib(default=(), converter=_tuplify_mentions)
-    other_fields: Optional[tuple[tuple[str, ...], ...]] = attrib(
-        default=None, kw_only=True, converter=tuplify_optional_nested_strs
-    )
-    provenance: Optional[SequenceProvenance] = attrib(
-        default=None, eq=False, kw_only=True
-    )
-    comment: Optional[str] = attrib(default=None, eq=False, kw_only=True)
+class LabeledSequence(BaseModel, frozen=True):
+    tokens: tuple[str, ...]
+    labels: tuple[str, ...]
+    mentions: tuple[Mention, ...] = tuple()
+    other_fields: Optional[tuple[tuple[str, ...], ...]] = None
+    provenance: Optional[SequenceProvenance] = None
+    comment: Optional[str] = None

-    def __attrs_post_init__(self) -> None:
+    def model_post_init(self, context: Any) -> None:
         # TODO: Check for overlapping mentions

         if len(self.tokens) != len(self.labels):
@@ -97,7 +85,29 @@ def __attrs_post_init__(self) -> None:

     def with_mentions(self, mentions: Sequence[Mention]) -> "LabeledSequence":
         return LabeledSequence(
-            self.tokens, self.labels, mentions, provenance=self.provenance
+            tokens=self.tokens,
+            labels=self.labels,
+            mentions=tuple(mentions),
+            provenance=self.provenance,
+        )
+
+    # Pydantic doesn't support excluding certain fields when it generates
+    # its default `__hash__` method when `frozen=True`
+    # To get around that limitation, define custom `__hash__` method
+    def __hash__(self) -> int:
+        # Do not hash `provenance` and `comment`
+        return hash((self.tokens, self.labels, self.mentions, self.other_fields))
+
+    # Do not check eq with `provenance` and `comment` fields
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, LabeledSequence):
+            return NotImplemented
+
+        return (
+            self.tokens == other.tokens
+            and self.labels == other.labels
+            and self.mentions == other.mentions
+            and self.other_fields == other.other_fields
         )

     @overload
@@ -111,9 +121,6 @@ def __getitem__(self, index: slice) -> tuple[str, ...]: ...
     def __getitem__(self, i: Union[int, slice]) -> Union[str, tuple[str, ...]]:
         return self.tokens[i]

-    def __iter__(self) -> Iterator[str]:
-        return iter(self.tokens)
-
     def __len__(self) -> int:
         # Guaranteed that labels and tokens are same length by construction
         return len(self.tokens)
diff --git a/seqscore/scoring.py b/seqscore/scoring.py
index 1622613..f73f6b9 100644
--- a/seqscore/scoring.py
+++ b/seqscore/scoring.py
@@ -1,9 +1,9 @@
 from collections import Counter, defaultdict
 from collections.abc import Iterable, Sequence
 from decimal import ROUND_HALF_UP, Decimal
-from typing import DefaultDict, Optional, Union
+from typing import Annotated, DefaultDict, Optional, Union

-from attr import Factory, attrib, attrs
+from pydantic import BaseModel, BeforeValidator, Field

 from seqscore.encoding import Encoding, EncodingError, get_encoding
 from seqscore.model import LabeledSequence, Mention
@@ -15,10 +15,9 @@ def _defaultdict_classification_score() -> DefaultDict[str, "ClassificationScore
     return defaultdict(ClassificationScore)


-@attrs(frozen=True, slots=True)
-class TokensWithType:
-    tokens: tuple[str, ...] = attrib(converter=tuplify_strs)
-    type: str = attrib(validator=validator_nonempty_str)
+class TokensWithType(BaseModel, frozen=True):
+    tokens: Annotated[tuple[str, ...], BeforeValidator(tuplify_strs)]
+    type: Annotated[str, BeforeValidator(validator_nonempty_str)]


 class TokenCountError(ValueError):
@@ -62,22 +61,21 @@ def from_predicted_sequence(
     )


-@attrs
-class ClassificationScore:
-    true_pos: int = attrib(default=0, kw_only=True)
-    false_pos: int = attrib(default=0, kw_only=True)
-    false_neg: int = attrib(default=0, kw_only=True)
-    type_scores: DefaultDict[str, "ClassificationScore"] = attrib(
-        default=Factory(_defaultdict_classification_score), kw_only=True
-    )
-    false_pos_examples: Counter[TokensWithType] = attrib(default=Factory(Counter))
-    false_neg_examples: Counter[TokensWithType] = attrib(default=Factory(Counter))
+class ClassificationScore(BaseModel):
+    true_pos: int = 0
+    false_pos: int = 0
+    false_neg: int = 0
+    type_scores: DefaultDict[str, "ClassificationScore"] = Field(
+        default_factory=_defaultdict_classification_score
+    )
+    false_pos_examples: Counter[TokensWithType] = Counter()
+    false_neg_examples: Counter[TokensWithType] = Counter()

     def count_false_positive(self, tokens: Iterable[str], type_: str) -> None:
-        self.false_pos_examples[TokensWithType(tuple(tokens), type_)] += 1
+        self.false_pos_examples[TokensWithType(tokens=tuple(tokens), type=type_)] += 1

     def count_false_negative(self, tokens: Iterable[str], type_: str) -> None:
-        self.false_neg_examples[TokensWithType(tuple(tokens), type_)] += 1
+        self.false_neg_examples[TokensWithType(tokens=tuple(tokens), type=type_)] += 1

     def update(self, score: "ClassificationScore") -> None:
         self.true_pos += score.true_pos
@@ -117,10 +115,9 @@ def f1(self) -> float:
         return 2 * (precision * recall) / (precision + recall)


-@attrs
-class AccuracyScore:
-    hits: int = attrib(default=0, kw_only=True)
-    total: int = attrib(default=0, kw_only=True)
+class AccuracyScore(BaseModel):
+    hits: int = 0
+    total: int = 0

     @property
     def accuracy(self) -> float:
diff --git a/seqscore/scripts/seqscore.py b/seqscore/scripts/seqscore.py
index 5ec3f0d..72d9ea9 100644
--- a/seqscore/scripts/seqscore.py
+++ b/seqscore/scripts/seqscore.py
@@ -22,6 +22,7 @@
     SUPPORTED_ENCODINGS,
     SUPPORTED_REPAIR_METHODS,
 )
+from seqscore.model import LabeledSequence
 from seqscore.processing import modify_types


@@ -483,7 +484,10 @@ def extract_text(
             else:
                 first_doc = False
             for sentence in doc:
-                print(" ".join(sentence), file=output)
+                if isinstance(sentence, LabeledSequence):
+                    print(" ".join(sentence.tokens), file=output)
+                else:
+                    print(" ".join(sentence), file=output)


 def _normalize_tab(s: str) -> str:
diff --git a/seqscore/util.py b/seqscore/util.py
index 70677fd..e48178d 100644
--- a/seqscore/util.py
+++ b/seqscore/util.py
@@ -5,8 +5,6 @@
 from pathlib import Path
 from typing import Any, Optional, Union

-from attr import Attribute, validators
-
 # Union[str, Path] isn't enough to appease PyCharm's type checker, so adding Path here
 # avoids warnings.
 PathType = Union[str, Path, PathLike]
@@ -58,13 +56,8 @@ def normalize_str_with_path(s: str) -> str:
     return s.replace(os.path.sep, "/")


-# Instantiate in advance for _validator_nonempty_str
-_instance_of_str = validators.instance_of(str)
-
-
-def validator_nonempty_str(_inst: Any, attr: Attribute, value: Any) -> None:
-    # Check type
-    _instance_of_str(value, attr, value)
+def validator_nonempty_str(value: Any) -> Any:
     # Check string isn't empty
     if not value:
         raise ValueError(f"Empty string: {repr(value)}")
+    return value
diff --git a/seqscore/validation.py b/seqscore/validation.py
index 24bff39..b626599 100644
--- a/seqscore/validation.py
+++ b/seqscore/validation.py
@@ -1,7 +1,7 @@
 from collections.abc import Iterable, Sequence
-from typing import Any, Optional
+from typing import Annotated, Any, Optional

-from attr import attrib, attrs
+from pydantic import BaseModel, BeforeValidator

 from seqscore.encoding import _ENCODING_NAMES, Encoding, EncodingError
 from seqscore.util import tuplify_strs
@@ -10,15 +10,14 @@
 VALIDATION_SUPPORTED_ENCODINGS: Sequence[str] = tuple(_ENCODING_NAMES)


-@attrs
-class ValidationError:
-    msg: str = attrib()
-    label: str = attrib()
-    type: str = attrib()
-    state: str = attrib()
-    token: Optional[str] = attrib(default=None)
-    line_num: Optional[int] = attrib(default=None)
-    source_name: Optional[str] = attrib(default=None)
+class ValidationError(BaseModel):
+    msg: str
+    label: str
+    type: str
+    state: str
+    token: Optional[str] = None
+    line_num: Optional[int] = None
+    source_name: Optional[str] = None


 class InvalidStateError(ValidationError):
@@ -39,13 +38,12 @@ def tuplify_errors(errors: Iterable[ValidationError]) -> tuple[ValidationError,
     return tuple(errors)


-@attrs
-class SequenceValidationResult:
-    errors: Sequence[ValidationError] = attrib(converter=tuplify_errors)
-    n_tokens: int = attrib()
-    repaired_labels: Optional[tuple[str, ...]] = attrib(
-        converter=tuplify_strs, default=()
-    )
+class SequenceValidationResult(BaseModel):
+    errors: Annotated[Sequence[ValidationError], BeforeValidator(tuplify_errors)]
+    n_tokens: int
+    repaired_labels: Annotated[
+        Optional[tuple[str, ...]], BeforeValidator(tuplify_strs)
+    ] = ()

     def is_valid(self) -> bool:
         return not self.errors
@@ -57,12 +55,11 @@ def __len__(self) -> int:
         return len(self.errors)


-@attrs(frozen=True)
-class ValidationResult:
-    errors: Sequence[ValidationError] = attrib(converter=tuplify_errors)
-    n_tokens: int = attrib()
-    n_sequences: int = attrib()
-    n_docs: int = attrib()
+class ValidationResult(BaseModel):
+    errors: Annotated[Sequence[ValidationError], BeforeValidator(tuplify_errors)]
+    n_tokens: int
+    n_sequences: int
+    n_docs: int


 def validate_labels(
@@ -120,7 +117,13 @@ def validate_labels(

             errors.append(
                 InvalidStateError(
-                    msg, label, entity_type, state, token, line_num, source_name
+                    msg=msg,
+                    label=label,
+                    type=entity_type if entity_type else "",
+                    state=state,
+                    token=token,
+                    line_num=line_num,
+                    source_name=source_name,
                 )
             )

@@ -145,7 +148,13 @@ def validate_labels(

             errors.append(
                 InvalidTransitionError(
-                    msg, label, entity_type, state, token, line_num, source_name
+                    msg=msg,
+                    label=label,
+                    type=entity_type if entity_type else "",
+                    state=state,
+                    token=token,
+                    line_num=line_num,
+                    source_name=source_name,
                 )
             )
         prev_label, prev_state, prev_entity_type = (
@@ -175,12 +184,19 @@ def validate_labels(

         errors.append(
             InvalidTransitionError(
-                msg, prev_label, prev_entity_type, prev_state, token, line_num
+                msg=msg,
+                label=prev_label,
+                type=prev_entity_type if prev_entity_type else "",
+                state=prev_state,
+                token=token,
+                line_num=line_num,
             )
         )

     if errors and repair:
         repaired_labels = encoding.repair_labels(labels, repair)
-        return SequenceValidationResult(errors, len(labels), repaired_labels)
+        return SequenceValidationResult(
+            errors=errors, n_tokens=len(labels), repaired_labels=tuple(repaired_labels)
+        )
     else:
-        return SequenceValidationResult(errors, len(labels))
+        return SequenceValidationResult(errors=errors, n_tokens=len(labels))
diff --git a/setup.py b/setup.py
index 0c227b3..4db1a61 100755
--- a/setup.py
+++ b/setup.py
@@ -23,7 +23,7 @@ def setup_package() -> None:
         description="SeqScore: Scoring for named entity recognition and other sequence labeling tasks",
         long_description=long_description,
         install_requires=[
-            "attrs>=19.2.0",
+            "pydantic>=2.11.7",
             "click",
             "tabulate",
         ],
diff --git a/tests/test_conll_format.py b/tests/test_conll_format.py
index 0ca1d08..d676852 100644
--- a/tests/test_conll_format.py
+++ b/tests/test_conll_format.py
@@ -9,7 +9,7 @@

 def test_parse_comments_true() -> None:
     mention_encoding = get_encoding("BIO")
-    ingester = CoNLLIngester(mention_encoding, parse_comment_lines=True)
+    ingester = CoNLLIngester(encoding=mention_encoding, parse_comment_lines=True)
     comments_path = Path("tests") / "test_files" / "minimal_comments.bio"
     with comments_path.open(encoding="utf8") as file:
         documents = list(ingester.ingest(file, "test", REPAIR_NONE))
@@ -32,7 +32,7 @@ def test_parse_comments_true() -> None:

 def test_parse_comments_false() -> None:
     mention_encoding = get_encoding("BIO")
-    ingester = CoNLLIngester(mention_encoding)
+    ingester = CoNLLIngester(encoding=mention_encoding)

     comments_path = Path("tests") / "test_files" / "minimal_comments_1.bio"
     with comments_path.open(encoding="utf8") as file:
diff --git a/tests/test_encoding.py b/tests/test_encoding.py
index 97c8ae6..392e24f 100644
--- a/tests/test_encoding.py
+++ b/tests/test_encoding.py
@@ -1,5 +1,5 @@
 import pytest
-from attr import attrs
+from pydantic import BaseModel

 from seqscore.encoding import (
     _ENCODING_NAMES,
@@ -27,16 +27,16 @@
     "BMEOW": ["W-PER", "O", "B-ORG", "E-ORG", "B-ORG", "M-ORG", "E-ORG", "W-LOC"],
 }
 FULL_SENTENCE_MENTS = [
-    Mention(Span(0, 1), "PER"),
-    Mention(Span(2, 4), "ORG"),
-    Mention(Span(4, 7), "ORG"),
-    Mention(Span(7, 8), "LOC"),
+    Mention(span=Span(start=0, end=1), type="PER"),
+    Mention(span=Span(start=2, end=4), type="ORG"),
+    Mention(span=Span(start=4, end=7), type="ORG"),
+    Mention(span=Span(start=7, end=8), type="LOC"),
 ]
 # IO cannot faithfully encode this sentence, so there is just one org
 FULL_SENTENCE_MENTS_IO = [
-    Mention(Span(0, 1), "PER"),
-    Mention(Span(2, 7), "ORG"),
-    Mention(Span(7, 8), "LOC"),
+    Mention(span=Span(start=0, end=1), type="PER"),
+    Mention(span=Span(start=2, end=7), type="ORG"),
+    Mention(span=Span(start=7, end=8), type="LOC"),
 ]
 # Map to sets of encodings that allow that state
 VALID_ENCODING_STATES = {
@@ -51,8 +51,7 @@
 }


-@attrs(auto_attribs=True)
-class EdgeTestSentence:
+class EdgeTestSentence(BaseModel):
     name: str
     mentions: list[Mention]
     encoding_labels: list[tuple[list[str], list[str]]]
@@ -60,9 +59,9 @@ class EdgeTestSentence:

 EDGE_TEST_SENTENCES = [
     EdgeTestSentence(
-        "One token, one mention",
-        [Mention(Span(0, 1), "PER")],
-        [
+        name="One token, one mention",
+        mentions=[Mention(span=Span(start=0, end=1), type="PER")],
+        encoding_labels=[
             (["BIO"], ["B-PER"]),
             (["BIOES", "BMES"], ["S-PER"]),
             (["BILOU"], ["U-PER"]),
@@ -71,9 +70,9 @@ class EdgeTestSentence:
         ],
     ),
     EdgeTestSentence(
-        "Two tokens, one mention covering them all",
-        [Mention(Span(0, 2), "PER")],
2), "PER")], - [ + name="Two tokens, one mention covering them all", + mentions=[Mention(span=Span(start=0, end=2), type="PER")], + encoding_labels=[ (["BIO"], ["B-PER", "I-PER"]), (["BIOES", "BMES", "BMEOW"], ["B-PER", "E-PER"]), (["BILOU"], ["B-PER", "L-PER"]), @@ -81,9 +80,9 @@ class EdgeTestSentence: ], ), EdgeTestSentence( - "Three tokens, one mention covering them all", - [Mention(Span(0, 3), "PER")], - [ + name="Three tokens, one mention covering them all", + mentions=[Mention(span=Span(start=0, end=3), type="PER")], + encoding_labels=[ (["BIO"], ["B-PER", "I-PER", "I-PER"]), (["BIOES"], ["B-PER", "I-PER", "E-PER"]), (["BMES", "BMEOW"], ["B-PER", "M-PER", "E-PER"]), @@ -92,9 +91,12 @@ class EdgeTestSentence: ], ), EdgeTestSentence( - "Adjacent same-type one-token mentions", - [Mention(Span(0, 1), "PER"), Mention(Span(1, 2), "PER")], - [ + name="Adjacent same-type one-token mentions", + mentions=[ + Mention(span=Span(start=0, end=1), type="PER"), + Mention(span=Span(start=1, end=2), type="PER"), + ], + encoding_labels=[ (["BIO"], ["B-PER", "B-PER"]), (["BIOES", "BMES"], ["S-PER", "S-PER"]), (["BILOU"], ["U-PER", "U-PER"]), @@ -104,9 +106,12 @@ class EdgeTestSentence: ], ), EdgeTestSentence( - "Adjacent different-type one-token mentions", - [Mention(Span(0, 1), "PER"), Mention(Span(1, 2), "ORG")], - [ + name="Adjacent different-type one-token mentions", + mentions=[ + Mention(span=Span(start=0, end=1), type="PER"), + Mention(span=Span(start=1, end=2), type="ORG"), + ], + encoding_labels=[ (["BIO"], ["B-PER", "B-ORG"]), (["BIOES", "BMES"], ["S-PER", "S-ORG"]), (["BILOU"], ["U-PER", "U-ORG"]), @@ -115,9 +120,12 @@ class EdgeTestSentence: ], ), EdgeTestSentence( - "Adjacent same-type two-token mentions", - [Mention(Span(0, 2), "PER"), Mention(Span(2, 4), "PER")], - [ + name="Adjacent same-type two-token mentions", + mentions=[ + Mention(span=Span(start=0, end=2), type="PER"), + Mention(span=Span(start=2, end=4), type="PER"), + ], + encoding_labels=[ (["BIO"], ["B-PER", "I-PER", "B-PER", "I-PER"]), (["BIOES", "BMES", "BMEOW"], ["B-PER", "E-PER", "B-PER", "E-PER"]), (["BILOU"], ["B-PER", "L-PER", "B-PER", "L-PER"]), @@ -126,9 +134,12 @@ class EdgeTestSentence: ], ), EdgeTestSentence( - "Adjacent different-type two-token mentions", - [Mention(Span(0, 2), "PER"), Mention(Span(2, 4), "ORG")], - [ + name="Adjacent different-type two-token mentions", + mentions=[ + Mention(span=Span(start=0, end=2), type="PER"), + Mention(span=Span(start=2, end=4), type="ORG"), + ], + encoding_labels=[ (["BIO"], ["B-PER", "I-PER", "B-ORG", "I-ORG"]), (["BIOES", "BMES", "BMEOW"], ["B-PER", "E-PER", "B-ORG", "E-ORG"]), (["BILOU"], ["B-PER", "L-PER", "B-ORG", "L-ORG"]), @@ -159,7 +170,11 @@ def test_basic_encoding() -> None: assert encoding.encode_mentions(mentions, len(labels)) == labels # Also test encoding sentence object, intentionally putting no mentions in the # sentence labels to make sure encoding using the mentions, not the labels - sentence = LabeledSequence(["a"] * len(labels), ["O"] * len(labels), mentions) + sentence = LabeledSequence( + tokens=["a"] * len(labels), + labels=["O"] * len(labels), + mentions=mentions, + ) assert encoding.encode_sequence(sentence) == labels @@ -267,21 +282,21 @@ def test_labeled_sequence() -> None: # Test length mismatch with pytest.raises(ValueError): LabeledSequence( - ["a"] * 10, - ["O"] * 9, + tokens=["a"] * 10, + labels=["O"] * 9, ) def test_decode_bio_invalid_continue() -> None: decoder = get_encoding("BIO") - sent1 = LabeledSequence(("a", "b"), ("B-PER", "I-LOC")) 
+    sent1 = LabeledSequence(tokens=("a", "b"), labels=("B-PER", "I-LOC"))
     with pytest.raises(AssertionError):
         assert decoder.decode_sequence(sent1)


 def test_decode_iob_invalid_begin() -> None:
     decoder = get_encoding("IOB")
-    sent = LabeledSequence(("a", "b"), ("I-PER", "B-LOC"))
+    sent = LabeledSequence(tokens=("a", "b"), labels=("I-PER", "B-LOC"))
     with pytest.raises(AssertionError):
         assert decoder.decode_sequence(sent)

@@ -289,8 +304,8 @@ def test_decode_iob_invalid_begin() -> None:
 def test_decode_bioes_invalid_start() -> None:
     decoder = get_encoding("BIOES")
     sents = [
-        LabeledSequence(("a",), ("I-PER",)),
-        LabeledSequence(("a",), ("E-PER",)),
+        LabeledSequence(tokens=("a",), labels=("I-PER",)),
+        LabeledSequence(tokens=("a",), labels=("E-PER",)),
     ]
     for sent in sents:
         with pytest.raises(AssertionError):
@@ -301,14 +316,14 @@ def test_decode_bioes_invalid_end() -> None:
     decoder = get_encoding("BIOES")
     sents = [
         # Single-token mentions must start (and end) with S
-        LabeledSequence(("a", "b"), ("B-PER", "S-PER")),
+        LabeledSequence(tokens=("a", "b"), labels=("B-PER", "S-PER")),
         # Multi-token mentions must end in E
-        LabeledSequence(("a",), ("B-PER",)),
-        LabeledSequence(("a", "b"), ("B-PER", "I-PER")),
+        LabeledSequence(tokens=("a",), labels=("B-PER",)),
+        LabeledSequence(tokens=("a", "b"), labels=("B-PER", "I-PER")),
         # Ends with wrong type
-        LabeledSequence(("a", "b", "c"), ("B-PER", "I-PER", "E-ORG")),
+        LabeledSequence(tokens=("a", "b", "c"), labels=("B-PER", "I-PER", "E-ORG")),
         # Multi-token mentions cannot end in S
-        LabeledSequence(("a", "b", "c"), ("B-PER", "I-PER", "S-PER")),
+        LabeledSequence(tokens=("a", "b", "c"), labels=("B-PER", "I-PER", "S-PER")),
     ]
     for sent in sents:
         with pytest.raises(AssertionError):
@@ -319,10 +334,10 @@ def test_decode_bioes_invalid_continue() -> None:
     decoder = get_encoding("BIOES")
     sents = [
         # B must be followed by I or E of the same type
-        LabeledSequence(("a", "b"), ("B-PER", "B-PER")),
+        LabeledSequence(tokens=("a", "b"), labels=("B-PER", "B-PER")),
         # Cannot change types mid-mention
-        LabeledSequence(("a", "b"), ("B-PER", "E-ORG")),
-        LabeledSequence(("a", "b", "c"), ("B-PER", "I-PER", "E-ORG")),
+        LabeledSequence(tokens=("a", "b"), labels=("B-PER", "E-ORG")),
+        LabeledSequence(tokens=("a", "b", "c"), labels=("B-PER", "I-PER", "E-ORG")),
     ]
     for sent in sents:
         with pytest.raises(AssertionError):
diff --git a/tests/test_model.py b/tests/test_model.py
index b2e4732..ac415d5 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -4,25 +4,25 @@


 def test_span() -> None:
-    assert len(Span(0, 1)) == 1
-    assert len(Span(1, 2)) == 1
-    assert len(Span(0, 2)) == 2
+    assert len(Span(start=0, end=1)) == 1
+    assert len(Span(start=1, end=2)) == 1
+    assert len(Span(start=0, end=2)) == 2

     with pytest.raises(ValueError):
-        Span(-1, 0)
+        Span(start=-1, end=0)

     with pytest.raises(ValueError):
-        Span(0, 0)
+        Span(start=0, end=0)


 def test_mention() -> None:
-    m1 = Mention(Span(0, 1), "PER")
+    m1 = Mention(span=Span(start=0, end=1), type="PER")
     assert m1.type == "PER"
-    assert m1.span == Span(0, 1)
+    assert m1.span == Span(start=0, end=1)
     assert len(m1) == 1

     with pytest.raises(ValueError):
-        Mention(Span(0, 1), "")
+        Mention(span=Span(start=0, end=1), type="")

     with pytest.raises(TypeError):
         # Intentionally incorrect type
@@ -31,44 +31,48 @@ def test_mention() -> None:

 def test_labeled_sentence() -> None:
     s1 = LabeledSequence(
-        ["a", "b"],
-        ["B-PER", "I-PER"],
-        provenance=SequenceProvenance(7, "test"),
+        tokens=["a", "b"],
+        labels=["B-PER", "I-PER"],
+        provenance=SequenceProvenance(starting_line=7, source="test"),
     )
     assert s1.tokens == ("a", "b")
     assert s1[0] == "a"
     assert s1[0:2] == ("a", "b")
-    assert list(s1) == ["a", "b"]
+    assert list(s1.tokens) == ["a", "b"]
     assert s1.labels == ("B-PER", "I-PER")
-    assert s1.provenance == SequenceProvenance(7, "test")
+    assert s1.provenance == SequenceProvenance(starting_line=7, source="test")
     assert str(s1) == "a/B-PER b/I-PER"
     assert s1.tokens_with_labels() == (("a", "B-PER"), ("b", "I-PER"))
-    assert s1.span_tokens(Span(0, 1)) == ("a",)
-    assert s1.mention_tokens(Mention(Span(0, 1), "PER")) == ("a",)
+    assert s1.span_tokens(Span(start=0, end=1)) == ("a",)
+    assert s1.mention_tokens(Mention(span=Span(start=0, end=1), type="PER")) == ("a",)

-    s2 = LabeledSequence(s1.tokens, s1.labels)
+    s2 = LabeledSequence(tokens=s1.tokens, labels=s1.labels)
     # Provenance not included in equality
     assert s1 == s2

     with pytest.raises(ValueError):
         # Mismatched length
-        LabeledSequence(["a", "b"], ["B-PER"])
+        LabeledSequence(tokens=["a", "b"], labels=["B-PER"])

     with pytest.raises(ValueError):
         # Empty
-        LabeledSequence([], [])
+        LabeledSequence(tokens=[], labels=[])

     with pytest.raises(ValueError):
         # Bad label
-        LabeledSequence(["a"], [""])
+        LabeledSequence(tokens=["a"], labels=[""])

     with pytest.raises(ValueError):
         # Bad token
-        LabeledSequence([""], ["B-PER"])
+        LabeledSequence(tokens=[""], labels=["B-PER"])

-    s2 = s1.with_mentions([Mention(Span(0, 2), "PER")])
-    assert s2.mentions == (Mention(Span(0, 2), "PER"),)
+    s2 = s1.with_mentions([Mention(span=Span(start=0, end=2), type="PER")])
+    assert s2.mentions == (Mention(span=Span(start=0, end=2), type="PER"),)

     with pytest.raises(ValueError):
         # Mismatched length between tokens and other_fields
-        LabeledSequence(["a", "b"], ["B-PER", "I-PER"], other_fields=[["DT"]])
+        LabeledSequence(
+            tokens=["a", "b"],
+            labels=["B-PER", "I-PER"],
+            other_fields=[["DT"]],
+        )
diff --git a/tests/test_scoring.py b/tests/test_scoring.py
index ebb8881..b4cc085 100644
--- a/tests/test_scoring.py
+++ b/tests/test_scoring.py
@@ -46,8 +46,14 @@ def test_score_sentence_labels_invalid() -> None:


 def test_score_sentence_mentions_correct() -> None:
-    ref_mentions = [Mention(Span(0, 2), "PER"), Mention(Span(4, 5), "ORG")]
-    pred_mentions = [Mention(Span(0, 2), "PER"), Mention(Span(4, 5), "ORG")]
+    ref_mentions = [
+        Mention(span=Span(start=0, end=2), type="PER"),
+        Mention(span=Span(start=4, end=5), type="ORG"),
+    ]
+    pred_mentions = [
+        Mention(span=Span(start=0, end=2), type="PER"),
+        Mention(span=Span(start=4, end=5), type="ORG"),
+    ]
     score = ClassificationScore()
     score_sequence_mentions(pred_mentions, ref_mentions, score)
     assert score.true_pos == 2
@@ -66,18 +72,18 @@ def test_score_sentence_mentions_correct() -> None:

 def test_score_sentence_mentions_incorrect1() -> None:
     ref_mentions = [
-        Mention(Span(0, 2), "LOC"),
-        Mention(Span(4, 5), "PER"),
-        Mention(Span(7, 8), "MISC"),
-        Mention(Span(9, 11), "MISC"),
+        Mention(span=Span(start=0, end=2), type="LOC"),
+        Mention(span=Span(start=4, end=5), type="PER"),
+        Mention(span=Span(start=7, end=8), type="MISC"),
+        Mention(span=Span(start=9, end=11), type="MISC"),
     ]
     pred_mentions = [
-        Mention(Span(0, 2), "ORG"),
-        Mention(Span(4, 5), "PER"),
+        Mention(span=Span(start=0, end=2), type="ORG"),
+        Mention(span=Span(start=4, end=5), type="PER"),
         Mention(
-            Span(6, 7), "SPURIOUS"
+            span=Span(start=6, end=7), type="SPURIOUS"
         ),  # Note that this type isn't even in the reference
-        Mention(Span(9, 11), "MISC"),
+        Mention(span=Span(start=9, end=11), type="MISC"),
type="MISC"), ] score = ClassificationScore() score_sequence_mentions(pred_mentions, ref_mentions, score) @@ -196,10 +202,14 @@ def test_token_count_error() -> None: ref_labels = ["O", "B-ORG", "I-ORG", "O"] pred_labels = ["O", "B-ORG", "I-ORG", "O", "O"] ref_sequence = LabeledSequence( - ["a", "b", "c", "d"], ref_labels, provenance=SequenceProvenance(0, "test") + tokens=["a", "b", "c", "d"], + labels=ref_labels, + provenance=SequenceProvenance(starting_line=0, source="test"), ) pred_sequence = LabeledSequence( - ["a", "b", "c", "d", "e"], pred_labels, provenance=SequenceProvenance(0, "test") + tokens=["a", "b", "c", "d", "e"], + labels=pred_labels, + provenance=SequenceProvenance(starting_line=0, source="test"), ) with pytest.raises(TokenCountError): compute_scores([[pred_sequence]], [[ref_sequence]]) @@ -207,7 +217,7 @@ def test_token_count_error() -> None: def test_provenance_none_raises_error() -> None: labels = ["O", "B-ORG"] - sequence = LabeledSequence(["a", "b"], labels, provenance=None) + sequence = LabeledSequence(tokens=["a", "b"], labels=labels, provenance=None) with pytest.raises(ValueError): TokenCountError.from_predicted_sequence(2, sequence) @@ -216,10 +226,14 @@ def test_differing_num_docs() -> None: ref_labels = ["O", "B-ORG"] pred_labels = ["O", "B-LOC"] ref_sequence = LabeledSequence( - ["a", "b"], ref_labels, provenance=SequenceProvenance(0, "test") + tokens=["a", "b"], + labels=ref_labels, + provenance=SequenceProvenance(starting_line=0, source="test"), ) pred_sequence = LabeledSequence( - ["a", "b"], pred_labels, provenance=SequenceProvenance(0, "test") + tokens=["a", "b"], + labels=pred_labels, + provenance=SequenceProvenance(starting_line=0, source="test"), ) with pytest.raises(ValueError): compute_scores([[pred_sequence]], [[ref_sequence], [ref_sequence]]) @@ -229,10 +243,14 @@ def test_differing_doc_length() -> None: ref_labels = ["O", "B-ORG"] pred_labels = ["O", "B-LOC"] ref_sequence = LabeledSequence( - ["a", "b"], ref_labels, provenance=SequenceProvenance(0, "test") + tokens=["a", "b"], + labels=ref_labels, + provenance=SequenceProvenance(starting_line=0, source="test"), ) pred_sequence = LabeledSequence( - ["a", "b"], pred_labels, provenance=SequenceProvenance(0, "test") + tokens=["a", "b"], + labels=pred_labels, + provenance=SequenceProvenance(starting_line=0, source="test"), ) with pytest.raises(ValueError): compute_scores([[pred_sequence]], [[ref_sequence, ref_sequence]]) @@ -242,10 +260,14 @@ def test_differing_pred_and_ref_tokens() -> None: ref_labels = ["O", "B-ORG"] pred_labels = ["O", "B-LOC"] ref_sequence = LabeledSequence( - ["a", "b"], ref_labels, provenance=SequenceProvenance(0, "test") + tokens=["a", "b"], + labels=ref_labels, + provenance=SequenceProvenance(starting_line=0, source="test"), ) pred_sequence = LabeledSequence( - ["a", "c"], pred_labels, provenance=SequenceProvenance(0, "test") + tokens=["a", "c"], + labels=pred_labels, + provenance=SequenceProvenance(starting_line=0, source="test"), ) with pytest.raises(ValueError): compute_scores([[pred_sequence]], [[ref_sequence]]) diff --git a/tests/test_validation.py b/tests/test_validation.py index 614a928..80d83cb 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -1,12 +1,11 @@ import pytest -from attr import attrs +from pydantic import BaseModel from seqscore.encoding import REPAIR_NONE, EncodingError, get_encoding from seqscore.validation import validate_labels -@attrs(auto_attribs=True) -class RepairTest: +class RepairTest(BaseModel): original_labels: list[str] 
     n_errors: int
     repaired_labels: dict[str, list[str]]
@@ -14,39 +13,51 @@ class RepairTest:

 BIO_REPAIRS = [
     RepairTest(
-        ["I-PER"],
-        1,
-        {"conlleval": ["B-PER"], "discard": ["O"]},
+        original_labels=["I-PER"],
+        n_errors=1,
+        repaired_labels={"conlleval": ["B-PER"], "discard": ["O"]},
     ),
     RepairTest(
-        ["I-PER", "I-PER"],
-        1,
-        {"conlleval": ["B-PER", "I-PER"], "discard": ["O", "O"]},
+        original_labels=["I-PER", "I-PER"],
+        n_errors=1,
+        repaired_labels={"conlleval": ["B-PER", "I-PER"], "discard": ["O", "O"]},
     ),
     RepairTest(
-        ["O", "I-PER", "I-PER"],
-        1,
-        {"conlleval": ["O", "B-PER", "I-PER"], "discard": ["O", "O", "O"]},
+        original_labels=["O", "I-PER", "I-PER"],
+        n_errors=1,
+        repaired_labels={
+            "conlleval": ["O", "B-PER", "I-PER"],
+            "discard": ["O", "O", "O"],
+        },
     ),
     RepairTest(
-        ["B-ORG", "I-PER", "I-PER"],
-        1,
-        {"conlleval": ["B-ORG", "B-PER", "I-PER"], "discard": ["B-ORG", "O", "O"]},
+        original_labels=["B-ORG", "I-PER", "I-PER"],
+        n_errors=1,
+        repaired_labels={
+            "conlleval": ["B-ORG", "B-PER", "I-PER"],
+            "discard": ["B-ORG", "O", "O"],
+        },
     ),
     RepairTest(
-        ["I-ORG", "I-PER", "I-PER"],
-        2,
-        {"conlleval": ["B-ORG", "B-PER", "I-PER"], "discard": ["O", "O", "O"]},
+        original_labels=["I-ORG", "I-PER", "I-PER"],
+        n_errors=2,
+        repaired_labels={
+            "conlleval": ["B-ORG", "B-PER", "I-PER"],
+            "discard": ["O", "O", "O"],
+        },
     ),
     RepairTest(
-        ["O", "I-ORG", "I-PER", "I-ORG"],
-        3,
-        {"conlleval": ["O", "B-ORG", "B-PER", "B-ORG"], "discard": ["O", "O", "O", "O"]},
+        original_labels=["O", "I-ORG", "I-PER", "I-ORG"],
+        n_errors=3,
+        repaired_labels={
+            "conlleval": ["O", "B-ORG", "B-PER", "B-ORG"],
+            "discard": ["O", "O", "O", "O"],
+        },
     ),
     RepairTest(
-        ["O", "B-ORG", "B-PER", "I-PER"],
-        0,
-        {
+        original_labels=["O", "B-ORG", "B-PER", "I-PER"],
+        n_errors=0,
+        repaired_labels={
             "conlleval": ["O", "B-ORG", "B-PER", "I-PER"],
             "discard": ["O", "B-ORG", "B-PER", "I-PER"],
         },
@@ -54,46 +65,46 @@ class RepairTest:
 ]
 IOB_REPAIRS = [
     RepairTest(
-        ["B-PER"],
-        1,
-        {"conlleval": ["I-PER"]},
+        original_labels=["B-PER"],
+        n_errors=1,
+        repaired_labels={"conlleval": ["I-PER"]},
     ),
     RepairTest(
-        ["B-PER", "I-PER"],
-        1,
-        {"conlleval": ["I-PER", "I-PER"]},
+        original_labels=["B-PER", "I-PER"],
+        n_errors=1,
+        repaired_labels={"conlleval": ["I-PER", "I-PER"]},
     ),
     RepairTest(
-        ["O", "B-PER", "I-PER"],
-        1,
-        {"conlleval": ["O", "I-PER", "I-PER"]},
+        original_labels=["O", "B-PER", "I-PER"],
+        n_errors=1,
+        repaired_labels={"conlleval": ["O", "I-PER", "I-PER"]},
     ),
     RepairTest(
-        ["B-ORG", "B-PER", "I-PER"],
-        2,
-        {"conlleval": ["I-ORG", "I-PER", "I-PER"]},
+        original_labels=["B-ORG", "B-PER", "I-PER"],
+        n_errors=2,
+        repaired_labels={"conlleval": ["I-ORG", "I-PER", "I-PER"]},
     ),
     RepairTest(
-        ["I-ORG", "B-PER", "I-PER"],
-        1,
-        {"conlleval": ["I-ORG", "I-PER", "I-PER"]},
+        original_labels=["I-ORG", "B-PER", "I-PER"],
+        n_errors=1,
+        repaired_labels={"conlleval": ["I-ORG", "I-PER", "I-PER"]},
     ),
     RepairTest(
-        ["O", "I-ORG", "B-PER", "I-ORG"],
-        1,
-        {"conlleval": ["O", "I-ORG", "I-PER", "I-ORG"]},
+        original_labels=["O", "I-ORG", "B-PER", "I-ORG"],
+        n_errors=1,
+        repaired_labels={"conlleval": ["O", "I-ORG", "I-PER", "I-ORG"]},
     ),
     RepairTest(
-        ["O", "B-ORG", "B-PER", "I-PER"],
-        2,
-        {
+        original_labels=["O", "B-ORG", "B-PER", "I-PER"],
+        n_errors=2,
+        repaired_labels={
             "conlleval": ["O", "I-ORG", "I-PER", "I-PER"],
         },
     ),
     RepairTest(
-        ["O", "B-ORG", "B-ORG", "I-PER"],
-        1,
-        {
+        original_labels=["O", "B-ORG", "B-ORG", "I-PER"],
+        n_errors=1,
+        repaired_labels={
             "conlleval": ["O", "I-ORG", "B-ORG", "I-PER"],
"I-ORG", "B-ORG", "I-PER"], }, ),