10 changes: 10 additions & 0 deletions packages/seqscore-lib/pyproject.toml
@@ -0,0 +1,10 @@
[project]
name = "seqscore-lib"
version = "0.7.0"
description = "The library for the seqscore CLI application"
requires-python = ">=3.9"
dependencies = []

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
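
As a quick sanity check of the new src layout, the relocated modules should be importable under the seqscore_lib namespace once the package is installed (a rough sketch, not part of this diff: it assumes an editable install of packages/seqscore-lib, that "BIO" is one of the names accepted by get_encoding, and made-up sample tokens):

# Sketch (not part of this diff): smoke-test the relocated modules.
from seqscore_lib.encoding import get_encoding
from seqscore_lib.model import LabeledSequence, Mention, Span

encoding = get_encoding("BIO")  # assumption: "BIO" is a registered encoding name
sequence = LabeledSequence(
    tokens=("John", "Smith"),
    labels=("B-PER", "I-PER"),
    mentions=(Mention(span=Span(start=0, end=2), type="PER"),),
)
print(len(sequence), encoding.dialect.outside)  # expected: 2 O
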
File renamed without changes.
64 changes: 36 additions & 28 deletions seqscore/conll.py → ...es/seqscore-lib/src/seqscore_lib/conll.py
@@ -10,19 +10,19 @@
TextIO,
)

from attr import attrib, attrs
from pydantic import BaseModel
from tabulate import tabulate

from seqscore.encoding import Encoding, EncodingError, get_encoding
from seqscore.model import LabeledSequence, SequenceProvenance
from seqscore.scoring import (
from seqscore_lib.encoding import Encoding, EncodingError, get_encoding
from seqscore_lib.model import LabeledSequence, SequenceProvenance
from seqscore_lib.scoring import (
AccuracyScore,
ClassificationScore,
compute_scores,
convert_score,
)
from seqscore.util import PathType
from seqscore.validation import (
from seqscore_lib.util import PathType
from seqscore_lib.validation import (
InvalidLabelError,
SequenceValidationResult,
ValidationResult,
@@ -43,13 +43,12 @@ class CoNLLFormatError(Exception):
pass


@attrs(frozen=True)
class _CoNLLToken:
text: str = attrib()
label: str = attrib()
is_docstart: bool = attrib()
line_num: int = attrib()
other_fields: tuple[str, ...] = attrib()
class _CoNLLToken(BaseModel, frozen=True):
text: str
label: str
is_docstart: bool
line_num: int
other_fields: tuple[str, ...]

@classmethod
def from_line(cls, line: str, line_num: int, source_name: str) -> "_CoNLLToken":
@@ -76,14 +75,19 @@ def from_line(cls, line: str, line_num: int, source_name: str) -> "_CoNLLToken":
label = splits[-1]
other_fields = tuple(splits[1:-1])
is_docstart = text == DOCSTART
return cls(text, label, is_docstart, line_num, other_fields)
return cls(
text=text,
label=label,
is_docstart=is_docstart,
line_num=line_num,
other_fields=other_fields,
)


@attrs(frozen=True)
class CoNLLIngester:
encoding: Encoding = attrib()
parse_comment_lines: bool = attrib(default=False, kw_only=True)
ignore_document_boundaries: bool = attrib(default=True, kw_only=True)
class CoNLLIngester(BaseModel, arbitrary_types_allowed=True, frozen=True):
encoding: Encoding
parse_comment_lines: bool = False
ignore_document_boundaries: bool = True

def ingest(
self,
@@ -183,11 +187,13 @@ def ingest(
) from e

sequences = LabeledSequence(
tokens,
labels,
mentions,
tokens=tokens,
labels=labels,
mentions=tuple(mentions),
other_fields=other_fields,
provenance=SequenceProvenance(line_nums[0], source_name),
provenance=SequenceProvenance(
starting_line=line_nums[0], source=source_name
),
comment=comment,
)
document.append(sequences)
@@ -197,7 +203,7 @@ def ingest(
document_counter += 1
yield document

def validate(
def conll_validate(
self, source: TextIO, source_name: str
) -> list[list[SequenceValidationResult]]:
all_results: list[list[SequenceValidationResult]] = []
Expand Down Expand Up @@ -330,7 +336,7 @@ def ingest_conll_file(
)

ingester = CoNLLIngester(
mention_encoding,
encoding=mention_encoding,
parse_comment_lines=parse_comment_lines,
ignore_document_boundaries=ignore_document_boundaries,
)
@@ -349,12 +355,12 @@ def validate_conll_file(
) -> ValidationResult:
encoding = get_encoding(mention_encoding_name)
ingester = CoNLLIngester(
encoding,
encoding=encoding,
parse_comment_lines=parse_comment_lines,
ignore_document_boundaries=ignore_document_boundaries,
)
with open(input_path, encoding=file_encoding) as input_file:
results = ingester.validate(input_file, input_path)
results = ingester.conll_validate(input_file, input_path)

n_docs = len(results)
n_sequences = sum(len(doc_results) for doc_results in results)
@@ -365,7 +371,9 @@ def validate_conll_file(
result.errors for doc_results in results for result in doc_results
)
)
return ValidationResult(errors, n_tokens, n_sequences, n_docs)
return ValidationResult(
errors=errors, n_tokens=n_tokens, n_sequences=n_sequences, n_docs=n_docs
)


def repair_conll_file(
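
For context, a rough sketch of how the keyword-only construction and the renamed validate entry point shown above would be exercised (assumptions: "BIO" is a valid encoding name, dev.conll is a placeholder path, and conll_validate returns one list of per-sequence results per document, as in the diff):

# Sketch (not part of this diff): keyword-only CoNLLIngester plus conll_validate.
from seqscore_lib.conll import CoNLLIngester
from seqscore_lib.encoding import get_encoding

ingester = CoNLLIngester(
    encoding=get_encoding("BIO"),  # assumption: "BIO" is a valid encoding name
    parse_comment_lines=False,
    ignore_document_boundaries=True,
)
with open("dev.conll", encoding="utf8") as conll_file:  # placeholder path
    results = ingester.conll_validate(conll_file, "dev.conll")
n_sequences = sum(len(doc_results) for doc_results in results)
print(f"validated {n_sequences} sequences in {len(results)} documents")
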
@@ -1,15 +1,11 @@
from abc import abstractmethod
from collections.abc import Sequence
from functools import lru_cache
from typing import (
AbstractSet,
Optional,
Protocol,
)
from typing import AbstractSet, Optional, Protocol, runtime_checkable

from attr import Factory, attrib, attrs
from pydantic import BaseModel

from seqscore.model import LabeledSequence, Mention, Span
from seqscore_lib.model import LabeledSequence, Mention, Span

REPAIR_CONLL = "conlleval"
REPAIR_DISCARD = "discard"
@@ -19,6 +15,7 @@
DEFAULT_OUTSIDE = "O"


@runtime_checkable
class EncodingDialect(Protocol):
label_delim: str
outside: str
@@ -57,6 +54,7 @@ def __init__(self) -> None:
self.single = "W"


@runtime_checkable
class Encoding(Protocol):
dialect: EncodingDialect

@@ -660,11 +658,10 @@ def get_encoding(name: str) -> Encoding:
raise ValueError(f"Unknown encoder {repr(name)}")


@attrs
class _MentionBuilder:
start_idx: Optional[int] = attrib(default=None, init=False)
entity_type: Optional[str] = attrib(default=None, init=False)
mentions: list[Mention] = attrib(default=Factory(list), init=False)
class _MentionBuilder(BaseModel):
start_idx: Optional[int] = None
entity_type: Optional[str] = None
mentions: list[Mention] = []

def start_mention(self, start_idx: int, entity_type: str) -> None:
# Check arguments
@@ -690,7 +687,9 @@ def end_mention(self, end_idx: int) -> None:
assert self.start_idx is not None, "No mention start index"
assert self.entity_type is not None, "No mention entity type"

mention = Mention(Span(self.start_idx, end_idx), self.entity_type)
mention = Mention(
span=Span(start=self.start_idx, end=end_idx), type=self.entity_type
)
self.mentions.append(mention)

self.start_idx = None
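
A small sketch of the builder's behaviour after the pydantic conversion (illustrative only; _MentionBuilder is internal, and the empty-list default relies on pydantic copying mutable defaults per instance):

# Sketch (not part of this diff): _MentionBuilder with pydantic defaults.
from seqscore_lib.encoding import _MentionBuilder

builder = _MentionBuilder()        # start_idx/entity_type default to None, mentions to []
builder.start_mention(0, "PER")
builder.end_mention(2)             # records Mention(span=Span(start=0, end=2), type="PER")
assert builder.mentions[0].span.end == 2

other = _MentionBuilder()
assert other.mentions == []        # each instance gets its own list, not a shared default
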
89 changes: 48 additions & 41 deletions seqscore/model.py → ...es/seqscore-lib/src/seqscore_lib/model.py
@@ -1,19 +1,15 @@
from collections.abc import Iterable, Iterator, Sequence
from collections.abc import Iterable, Sequence
from itertools import repeat
from typing import Any, Optional, Union, overload
from typing import Annotated, Any, Optional, Union, overload

from attr import Attribute, attrib, attrs
from pydantic import BaseModel, BeforeValidator, Field

from seqscore.util import (
tuplify_optional_nested_strs,
tuplify_strs,
validator_nonempty_str,
)


def _validator_nonnegative(_inst: Any, _attr: Attribute, value: Any) -> None:
def _validator_nonnegative(value: Any) -> Any:
if value < 0:
raise ValueError(f"Negative value: {repr(value)}")
else:
return value


def _tuplify_mentions(
@@ -22,12 +18,11 @@ def _tuplify_mentions(
return tuple(mentions)


@attrs(frozen=True, slots=True)
class Span:
start: int = attrib(validator=_validator_nonnegative)
end: int = attrib(validator=_validator_nonnegative)
class Span(BaseModel, frozen=True):
start: Annotated[int, BeforeValidator(_validator_nonnegative)]
end: Annotated[int, BeforeValidator(_validator_nonnegative)]

def __attrs_post_init__(self) -> None:
def model_post_init(self, context: Any) -> None:
if not self.end > self.start:
raise ValueError(
f"End of span ({self.end}) must be greater than start ({self.start}"
@@ -37,38 +32,31 @@ def __len__(self) -> int:
return self.end - self.start


@attrs(frozen=True, slots=True)
class Mention:
span: Span = attrib()
type: str = attrib(validator=validator_nonempty_str)
class Mention(BaseModel, frozen=True):
span: Span
type: str = Field(min_length=1, description="Must be a non-empty string")

def __len__(self) -> int:
return len(self.span)

def with_type(self, new_type: str) -> "Mention":
return Mention(self.span, new_type)
return Mention(span=self.span, type=new_type)


@attrs(frozen=True, slots=True)
class SequenceProvenance:
starting_line: int = attrib()
source: Optional[str] = attrib()
class SequenceProvenance(BaseModel, frozen=True):
starting_line: int
source: Optional[str]


@attrs(frozen=True, slots=True)
class LabeledSequence(Sequence[str]):
tokens: tuple[str, ...] = attrib(converter=tuplify_strs)
labels: tuple[str, ...] = attrib(converter=tuplify_strs)
mentions: tuple[Mention, ...] = attrib(default=(), converter=_tuplify_mentions)
other_fields: Optional[tuple[tuple[str, ...], ...]] = attrib(
default=None, kw_only=True, converter=tuplify_optional_nested_strs
)
provenance: Optional[SequenceProvenance] = attrib(
default=None, eq=False, kw_only=True
)
comment: Optional[str] = attrib(default=None, eq=False, kw_only=True)
class LabeledSequence(BaseModel, frozen=True):
tokens: tuple[str, ...]
labels: tuple[str, ...]
mentions: tuple[Mention, ...] = tuple()
other_fields: Optional[tuple[tuple[str, ...], ...]] = None
provenance: Optional[SequenceProvenance] = None
comment: Optional[str] = None

def __attrs_post_init__(self) -> None:
def model_post_init(self, context: Any) -> None:
# TODO: Check for overlapping mentions

if len(self.tokens) != len(self.labels):
@@ -97,7 +85,29 @@ def __attrs_post_init__(self) -> None:

def with_mentions(self, mentions: Sequence[Mention]) -> "LabeledSequence":
return LabeledSequence(
self.tokens, self.labels, mentions, provenance=self.provenance
tokens=self.tokens,
labels=self.labels,
mentions=tuple(mentions),
provenance=self.provenance,
)

# Pydantic doesn't support excluding specific fields from the default
# `__hash__` it generates when `frozen=True`, so define a custom
# `__hash__` that skips them.
def __hash__(self) -> int:
# Do not hash `provenance` and `comment`
return hash((self.tokens, self.labels, self.mentions, self.other_fields))

# Equality likewise ignores the `provenance` and `comment` fields
def __eq__(self, other: object) -> bool:
if not isinstance(other, LabeledSequence):
return NotImplemented

return (
self.tokens == other.tokens
and self.labels == other.labels
and self.mentions == other.mentions
and self.other_fields == other.other_fields
)

@overload
@@ -111,9 +121,6 @@ def __getitem__(self, index: slice) -> tuple[str, ...]:
def __getitem__(self, i: Union[int, slice]) -> Union[str, tuple[str, ...]]:
return self.tokens[i]

def __iter__(self) -> Iterator[str]:
return iter(self.tokens)

def __len__(self) -> int:
# Guaranteed that labels and tokens are same length by construction
return len(self.tokens)
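
To make the intent of the custom __hash__/__eq__ concrete, a small illustration (a sketch with made-up values; it assumes nothing in the truncated part of model_post_init rejects these inputs):

# Sketch (not part of this diff): provenance and comment do not affect equality.
from seqscore_lib.model import LabeledSequence, Mention, SequenceProvenance, Span

mentions = (Mention(span=Span(start=0, end=1), type="LOC"),)
a = LabeledSequence(
    tokens=("Paris",), labels=("B-LOC",), mentions=mentions,
    provenance=SequenceProvenance(starting_line=1, source="train.conll"),
)
b = LabeledSequence(
    tokens=("Paris",), labels=("B-LOC",), mentions=mentions,
    provenance=SequenceProvenance(starting_line=500, source="dev.conll"),
    comment="duplicated sentence",
)
assert a == b and hash(a) == hash(b)  # only tokens, labels, mentions, other_fields count
assert len({a, b}) == 1               # both occupy a single slot in a set
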
@@ -1,6 +1,6 @@
from collections.abc import Iterable

from seqscore.model import LabeledSequence, Mention
from seqscore_lib.model import LabeledSequence, Mention


class TypeMapper: