Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
[tool.mypy]
python_version = 3.9
python_version = "3.9"
strict_optional = false
disallow_untyped_defs = true
disallow_untyped_calls = true
plugins = ['pydantic.mypy']

[[tool.mypy.overrides]]
module = [
"setuptools",
"click.*",
]
module = ["setuptools", "click.*"]
ignore_missing_imports = true

[tool.ruff]
Expand Down
54 changes: 31 additions & 23 deletions seqscore/conll.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
TextIO,
)

from attr import attrib, attrs
from pydantic import BaseModel
from tabulate import tabulate

from seqscore.encoding import Encoding, EncodingError, get_encoding
Expand Down Expand Up @@ -43,13 +43,12 @@ class CoNLLFormatError(Exception):
pass


@attrs(frozen=True)
class _CoNLLToken:
text: str = attrib()
label: str = attrib()
is_docstart: bool = attrib()
line_num: int = attrib()
other_fields: tuple[str, ...] = attrib()
class _CoNLLToken(BaseModel, frozen=True):
text: str
label: str
is_docstart: bool
line_num: int
other_fields: tuple[str, ...]

@classmethod
def from_line(cls, line: str, line_num: int, source_name: str) -> "_CoNLLToken":
Expand All @@ -76,14 +75,19 @@ def from_line(cls, line: str, line_num: int, source_name: str) -> "_CoNLLToken":
label = splits[-1]
other_fields = tuple(splits[1:-1])
is_docstart = text == DOCSTART
return cls(text, label, is_docstart, line_num, other_fields)
return cls(
text=text,
label=label,
is_docstart=is_docstart,
line_num=line_num,
other_fields=other_fields,
)


@attrs(frozen=True)
class CoNLLIngester:
encoding: Encoding = attrib()
parse_comment_lines: bool = attrib(default=False, kw_only=True)
ignore_document_boundaries: bool = attrib(default=True, kw_only=True)
class CoNLLIngester(BaseModel, arbitrary_types_allowed=True, frozen=True):
encoding: Encoding
parse_comment_lines: bool = False
ignore_document_boundaries: bool = True

def ingest(
self,
Expand Down Expand Up @@ -183,11 +187,13 @@ def ingest(
) from e

sequences = LabeledSequence(
tokens,
labels,
mentions,
tokens=tokens,
labels=labels,
mentions=tuple(mentions),
other_fields=other_fields,
provenance=SequenceProvenance(line_nums[0], source_name),
provenance=SequenceProvenance(
starting_line=line_nums[0], source=source_name
),
comment=comment,
)
document.append(sequences)
Expand All @@ -197,7 +203,7 @@ def ingest(
document_counter += 1
yield document

def validate(
def conll_validate(
self, source: TextIO, source_name: str
) -> list[list[SequenceValidationResult]]:
all_results: list[list[SequenceValidationResult]] = []
Expand Down Expand Up @@ -330,7 +336,7 @@ def ingest_conll_file(
)

ingester = CoNLLIngester(
mention_encoding,
encoding=mention_encoding,
parse_comment_lines=parse_comment_lines,
ignore_document_boundaries=ignore_document_boundaries,
)
Expand All @@ -349,12 +355,12 @@ def validate_conll_file(
) -> ValidationResult:
encoding = get_encoding(mention_encoding_name)
ingester = CoNLLIngester(
encoding,
encoding=encoding,
parse_comment_lines=parse_comment_lines,
ignore_document_boundaries=ignore_document_boundaries,
)
with open(input_path, encoding=file_encoding) as input_file:
results = ingester.validate(input_file, input_path)
results = ingester.conll_validate(input_file, input_path)

n_docs = len(results)
n_sequences = sum(len(doc_results) for doc_results in results)
Expand All @@ -365,7 +371,9 @@ def validate_conll_file(
result.errors for doc_results in results for result in doc_results
)
)
return ValidationResult(errors, n_tokens, n_sequences, n_docs)
return ValidationResult(
errors=errors, n_tokens=n_tokens, n_sequences=n_sequences, n_docs=n_docs
)


def repair_conll_file(
Expand Down
23 changes: 11 additions & 12 deletions seqscore/encoding.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from abc import abstractmethod
from collections.abc import Sequence
from functools import lru_cache
from typing import (
AbstractSet,
Optional,
Protocol,
)
from typing import AbstractSet, Optional, Protocol, runtime_checkable

from attr import Factory, attrib, attrs
from pydantic import BaseModel

from seqscore.model import LabeledSequence, Mention, Span

Expand All @@ -19,6 +15,7 @@
DEFAULT_OUTSIDE = "O"


@runtime_checkable
class EncodingDialect(Protocol):
label_delim: str
outside: str
Expand Down Expand Up @@ -57,6 +54,7 @@ def __init__(self) -> None:
self.single = "W"


@runtime_checkable
class Encoding(Protocol):
dialect: EncodingDialect

Expand Down Expand Up @@ -660,11 +658,10 @@ def get_encoding(name: str) -> Encoding:
raise ValueError(f"Unknown encoder {repr(name)}")


@attrs
class _MentionBuilder:
start_idx: Optional[int] = attrib(default=None, init=False)
entity_type: Optional[str] = attrib(default=None, init=False)
mentions: list[Mention] = attrib(default=Factory(list), init=False)
class _MentionBuilder(BaseModel):
start_idx: Optional[int] = None
entity_type: Optional[str] = None
mentions: list[Mention] = []

def start_mention(self, start_idx: int, entity_type: str) -> None:
# Check arguments
Expand All @@ -690,7 +687,9 @@ def end_mention(self, end_idx: int) -> None:
assert self.start_idx is not None, "No mention start index"
assert self.entity_type is not None, "No mention entity type"

mention = Mention(Span(self.start_idx, end_idx), self.entity_type)
mention = Mention(
span=Span(start=self.start_idx, end=end_idx), type=self.entity_type
)
self.mentions.append(mention)

self.start_idx = None
Expand Down
89 changes: 48 additions & 41 deletions seqscore/model.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Iterable, Sequence
from itertools import repeat
from typing import Any, Optional, Union, overload
from typing import Annotated, Any, Optional, Union, overload

from attr import Attribute, attrib, attrs
from pydantic import BaseModel, BeforeValidator, Field

from seqscore.util import (
tuplify_optional_nested_strs,
tuplify_strs,
validator_nonempty_str,
)


def _validator_nonnegative(_inst: Any, _attr: Attribute, value: Any) -> None:
def _validator_nonnegative(value: Any) -> Any:
if value < 0:
raise ValueError(f"Negative value: {repr(value)}")
else:
return value


def _tuplify_mentions(
Expand All @@ -22,12 +18,11 @@ def _tuplify_mentions(
return tuple(mentions)


@attrs(frozen=True, slots=True)
class Span:
start: int = attrib(validator=_validator_nonnegative)
end: int = attrib(validator=_validator_nonnegative)
class Span(BaseModel, frozen=True):
    """Immutable half-open span ``[start, end)`` over token indices.

    Both endpoints are validated to be non-negative, and ``end`` must be
    strictly greater than ``start``, so empty or reversed spans are rejected.
    """

    start: Annotated[int, BeforeValidator(_validator_nonnegative)]
    end: Annotated[int, BeforeValidator(_validator_nonnegative)]

    def model_post_init(self, context: Any) -> None:
        # Cross-field check runs after per-field validation has succeeded.
        if not self.end > self.start:
            raise ValueError(
                # Fix: the original message was missing the closing
                # parenthesis after the start value.
                f"End of span ({self.end}) must be greater than start ({self.start})"
            )

    def __len__(self) -> int:
        # Length of a half-open interval.
        return self.end - self.start


class Mention(BaseModel, frozen=True):
    """Immutable labeled mention: a token span together with its entity type."""

    span: Span
    type: str = Field(min_length=1, description="Must be a non-empty string")

    def __len__(self) -> int:
        # A mention is exactly as long as the span it covers.
        return len(self.span)

    def with_type(self, new_type: str) -> "Mention":
        """Return a new mention over the same span whose type is ``new_type``."""
        return Mention(span=self.span, type=new_type)


class SequenceProvenance(BaseModel, frozen=True):
    """Records where a labeled sequence originated."""

    # Line number at which the sequence starts in its source.
    starting_line: int
    # Name of the source (presumably a file path; None when unknown).
    source: Optional[str]


@attrs(frozen=True, slots=True)
class LabeledSequence(Sequence[str]):
tokens: tuple[str, ...] = attrib(converter=tuplify_strs)
labels: tuple[str, ...] = attrib(converter=tuplify_strs)
mentions: tuple[Mention, ...] = attrib(default=(), converter=_tuplify_mentions)
other_fields: Optional[tuple[tuple[str, ...], ...]] = attrib(
default=None, kw_only=True, converter=tuplify_optional_nested_strs
)
provenance: Optional[SequenceProvenance] = attrib(
default=None, eq=False, kw_only=True
)
comment: Optional[str] = attrib(default=None, eq=False, kw_only=True)
class LabeledSequence(BaseModel, frozen=True):
tokens: tuple[str, ...]
labels: tuple[str, ...]
mentions: tuple[Mention, ...] = tuple()
other_fields: Optional[tuple[tuple[str, ...], ...]] = None
provenance: Optional[SequenceProvenance] = None
comment: Optional[str] = None

def __attrs_post_init__(self) -> None:
def model_post_init(self, context: Any) -> None:
# TODO: Check for overlapping mentions

if len(self.tokens) != len(self.labels):
Expand Down Expand Up @@ -97,7 +85,29 @@ def __attrs_post_init__(self) -> None:

def with_mentions(self, mentions: Sequence[Mention]) -> "LabeledSequence":
    """Return a copy of this sequence with its mentions replaced.

    Tokens, labels, and provenance are carried over unchanged.

    NOTE(review): ``other_fields`` and ``comment`` are not propagated to
    the copy — confirm this is intentional.
    """
    return LabeledSequence(
        tokens=self.tokens,
        labels=self.labels,
        mentions=tuple(mentions),
        provenance=self.provenance,
    )

# Pydantic's auto-generated __hash__ for frozen models covers every field,
# and it cannot be told to exclude some; we want `provenance` and `comment`
# left out, so the method is defined by hand.
def __hash__(self) -> int:
    """Hash over content fields only; `provenance` and `comment` are excluded."""
    content = (self.tokens, self.labels, self.mentions, self.other_fields)
    return hash(content)

# Equality deliberately ignores the `provenance` and `comment` fields,
# mirroring the custom __hash__ above it in the original source.
def __eq__(self, other: object) -> bool:
    """Content equality: tokens, labels, mentions, and other_fields."""
    if not isinstance(other, LabeledSequence):
        return NotImplemented

    return (self.tokens, self.labels, self.mentions, self.other_fields) == (
        other.tokens,
        other.labels,
        other.mentions,
        other.other_fields,
    )

@overload
Expand All @@ -111,9 +121,6 @@ def __getitem__(self, index: slice) -> tuple[str, ...]:
def __getitem__(self, i: Union[int, slice]) -> Union[str, tuple[str, ...]]:
return self.tokens[i]

def __iter__(self) -> Iterator[str]:
return iter(self.tokens)

def __len__(self) -> int:
    """Number of tokens; labels have the same length by construction."""
    return len(self.tokens)
Expand Down
Loading