diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d38ad55..c4f349e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ default_stages: [pre-commit] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v6.0.0 hooks: - id: check-ast - id: check-vcs-permalinks diff --git a/demos/sayt/sayt_artifact_example.py b/demos/sayt/sayt_artifact_example.py new file mode 100644 index 0000000..bbf0b37 --- /dev/null +++ b/demos/sayt/sayt_artifact_example.py @@ -0,0 +1,94 @@ +"""Build and reload a persisted SAYT artifact with the built-in retrievers.""" + +# pylint: disable=duplicate-code + +# %% +import json +from pathlib import Path +from tempfile import TemporaryDirectory + +from industrial_classification_utils.sayt import ( + NgramRetrieverSpec, + PrefixRetrieverSpec, + SAYTBuilder, + SAYTSuggester, + SemanticRetrieverSpec, +) + +# %% +############# toy example to verify SAYT artifact build/load works ############# +small_corpus = [ + ("Car wash", "Car Wash"), + ("Car wash", "CAR WASH (duplicate)"), + ("Car waxing", "Car Waxing"), + ("Waxing car", "Car Waxing"), + ("Carpentry services", "Carpentry services"), + ("Dog grooming", "Dog grooming"), + ("Cat grooming", "Cat grooming"), + ("USed car sales", "Used car sales"), + ("Car rental", "Car rental"), + ("Car repair", "Car repair"), + ("Car servicing", "Car servicing"), +] + +retrievers = [ + PrefixRetrieverSpec(), + NgramRetrieverSpec(max_df=0.8), + SemanticRetrieverSpec(), +] + +# Keep the temporary directory alive across notebook cells. +# pylint: disable-next=consider-using-with +temp_dir = TemporaryDirectory(prefix="sayt_artifact_demo_") +artifact_dir = Path(temp_dir.name) / "car_services_sayt" +print("artifact will be written to:", artifact_dir) + +# %% +# Semantic artifact builds may take longer the first time if the model cache +# needs to be created locally. +artifact_path = SAYTBuilder( + small_corpus, + retrievers=retrievers, + min_chars=3, + max_suggestions=5, +).build_artifact(artifact_dir, overwrite=True) + +print("artifact saved to:", artifact_path) +print("artifact files:") +for path in sorted(artifact_path.rglob("*")): + if path.is_file(): + print("-", path.relative_to(artifact_path)) + +# %% +manifest = json.loads((artifact_path / "manifest.json").read_text(encoding="utf-8")) +print(json.dumps(manifest, indent=2)) + +# %% +live_suggester = SAYTSuggester( + small_corpus, + retrievers=retrievers, + min_chars=3, + max_suggestions=5, +) +loaded_suggester = SAYTSuggester.from_artifact(artifact_path) + +for query in ["car", "cars", "waxi", "grom", "wash", "duplicate", "auto"]: + live_suggestions = live_suggester.suggest(query, 5) + loaded_suggestions = loaded_suggester.suggest(query, 5) + + print("searching for:", query) + print("live", "->", live_suggestions) + print("loaded", "->", loaded_suggestions) + print("loaded_scores", "->", loaded_suggester.suggest_with_scores(query, 5)) + if live_suggestions != loaded_suggestions: + raise RuntimeError("Loaded suggester results did not match live build") + print() + +# %% +# Run `temp_dir.cleanup()` when you are finished exploring the saved files. +print("artifact ready for inspection:", artifact_path) + +# %% +temp_dir.cleanup() + +# %% diff --git a/demos/sayt/sayt_example.py b/demos/sayt/sayt_example.py index 35d207a..03130a7 100644 --- a/demos/sayt/sayt_example.py +++ b/demos/sayt/sayt_example.py @@ -11,7 +11,7 @@ SAYTSuggester, SemanticRetrieverSpec, ) -from industrial_classification_utils.sayt.sayt_core import _normalise +from industrial_classification_utils.sayt.core import _normalise logger = get_logger(__name__) # %% diff --git a/pyproject.toml b/pyproject.toml index ff0bbaa..02b3c5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,8 @@ ignore = [ "E501", # indentation contains tabs "W191", + # Allow the use of older TypeVar syntax for generic functions + "UP047", ] [tool.ruff.lint.pydocstyle] diff --git a/src/industrial_classification_utils/sayt/__init__.py b/src/industrial_classification_utils/sayt/__init__.py index 8c5e03c..66180d9 100644 --- a/src/industrial_classification_utils/sayt/__init__.py +++ b/src/industrial_classification_utils/sayt/__init__.py @@ -1,7 +1,9 @@ """Public SAYT interfaces and built-in retriever components.""" -from .sayt import SAYTSuggester -from .sayt_retriever_specs import ( +from .builder import SAYTBuilder +from .core import SaytConfiguration +from .retriever_specs import ( + ArtifactRetrieverSpec, NgramRetrieverSpec, PrefixRetrieverSpec, Retriever, @@ -9,16 +11,20 @@ SemanticRetrieverSpec, default_retriever_specs, ) -from .sayt_retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever +from .retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever +from .suggester import SAYTSuggester __all__ = [ + "ArtifactRetrieverSpec", "NgramRetriever", "NgramRetrieverSpec", "PrefixRetriever", "PrefixRetrieverSpec", "Retriever", "RetrieverSpec", + "SAYTBuilder", "SAYTSuggester", + "SaytConfiguration", "SemanticRetriever", "SemanticRetrieverSpec", "default_retriever_specs", diff --git a/src/industrial_classification_utils/sayt/_base.py b/src/industrial_classification_utils/sayt/_base.py new file mode 100644 index 0000000..cbe181b --- /dev/null +++ b/src/industrial_classification_utils/sayt/_base.py @@ -0,0 +1,59 @@ +"""Shared bootstrap helpers for corpus-bound SAYT classes.""" + +import os +from collections.abc import Iterable, Sequence + +from .core import CleanCorpus, validate_max_suggestions, validate_min_chars +from .retriever_specs import RetrieverSpec, default_retriever_specs +from .storage import load_corpus_from_csv + + +class BaseCorpusBound: # pylint: disable=too-few-public-methods + """Shared corpus/retriever bootstrap for SAYT runtime classes.""" + + _corpus: CleanCorpus + _min_chars: int + _max_suggestions: int + _retriever_specs: tuple[RetrieverSpec, ...] + + def __init__( + self, + corpus: Iterable[tuple[object, object]] | Iterable[str], + *, + retrievers: Sequence[RetrieverSpec] | None = None, + min_chars: int = 4, + max_suggestions: int = 10, + ) -> None: + """Validate and store the shared corpus-bound SAYT configuration.""" + self._corpus = CleanCorpus.model_validate(corpus) + self._min_chars = validate_min_chars(min_chars) + self._max_suggestions = validate_max_suggestions(max_suggestions) + self._retriever_specs = tuple( + default_retriever_specs() if retrievers is None else retrievers + ) + + @classmethod + def from_csv[ # pylint: disable=too-many-arguments # noqa: PLR0913 + CorpusBoundT: "BaseCorpusBound" + ]( + cls: type[CorpusBoundT], + file_path: str | os.PathLike, + *, + search_text_col: str = "title", + display_text_col: str | None = None, + retrievers: Sequence[RetrieverSpec] | None = None, + min_chars: int = 4, + max_suggestions: int = 10, + ) -> CorpusBoundT: + """Build a corpus-bound SAYT object from CSV input.""" + corpus_rows = load_corpus_from_csv( + file_path, + search_text_col=search_text_col, + display_text_col=display_text_col, + ) + return cls( + corpus_rows, + retrievers=retrievers, + min_chars=min_chars, + max_suggestions=max_suggestions, + ) diff --git a/src/industrial_classification_utils/sayt/builder.py b/src/industrial_classification_utils/sayt/builder.py new file mode 100644 index 0000000..41d2c54 --- /dev/null +++ b/src/industrial_classification_utils/sayt/builder.py @@ -0,0 +1,85 @@ +"""Offline artifact builder for persisted SAYT runtime assets.""" + +import os +import shutil +import tempfile +from pathlib import Path +from uuid import uuid4 + +from ._base import BaseCorpusBound +from .storage import ( + build_artifact_manifest, + build_retriever_artifact, + write_artifact_corpus, + write_artifact_manifest, +) + + +def _remove_path(path: Path) -> None: + if not path.exists(): + return + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink() + + +class SAYTBuilder(BaseCorpusBound): + """Build a persisted SAYT artifact for later runtime loading.""" + + def build_artifact( + self, + output_dir: str | os.PathLike, + *, + overwrite: bool = False, + ) -> Path: + """Persist the current SAYT configuration and dense stores to disk.""" + artifact_dir = Path(output_dir) + if artifact_dir.exists() and not overwrite: + raise FileExistsError("Artifact directory already exists") + + artifact_dir.parent.mkdir(parents=True, exist_ok=True) + staged_dir = Path( + tempfile.mkdtemp( + prefix=f".{artifact_dir.name}.tmp-", + dir=artifact_dir.parent, + ) + ) + + try: + manifest = build_artifact_manifest( + corpus=self._corpus, + min_chars=self._min_chars, + max_suggestions=self._max_suggestions, + retriever_specs=self._retriever_specs, + ) + + write_artifact_corpus(self._corpus, artifact_dir=staged_dir) + for stored_retriever in manifest.retrievers: + build_retriever_artifact( + corpus=self._corpus, + min_chars=self._min_chars, + stored_retriever=stored_retriever, + artifact_dir=staged_dir, + ) + + write_artifact_manifest(manifest, artifact_dir=staged_dir) + + if artifact_dir.exists(): + backup_dir = ( + artifact_dir.parent / f".{artifact_dir.name}.bak-{uuid4().hex}" + ) + artifact_dir.rename(backup_dir) + try: + staged_dir.rename(artifact_dir) + except Exception: + backup_dir.rename(artifact_dir) + raise + _remove_path(backup_dir) + else: + staged_dir.rename(artifact_dir) + except Exception: + _remove_path(staged_dir) + raise + + return artifact_dir diff --git a/src/industrial_classification_utils/sayt/core.py b/src/industrial_classification_utils/sayt/core.py new file mode 100644 index 0000000..7be4855 --- /dev/null +++ b/src/industrial_classification_utils/sayt/core.py @@ -0,0 +1,308 @@ +"""Core SAYT data models, corpus cleaning, and ranking helpers.""" + +# ruff: noqa: PLR2004 + +import re +import warnings +from collections.abc import Iterable +from dataclasses import dataclass +from typing import Any, cast +from uuid import NAMESPACE_URL, uuid5 + +import pandas as pd +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator + +_WS_RE = re.compile(r"\s+") +_NON_ALNUM_SPACE_RE = re.compile(r"[^a-z ]+") + + +def _normalise(text: object) -> str: + if not isinstance(text, str): + return "" + text = text.strip().lower() + text = _NON_ALNUM_SPACE_RE.sub(" ", text) + text = _WS_RE.sub(" ", text).strip() + return text + + +def _row_uid(index: int, search_text: str, display_text: str) -> str: + return str(uuid5(NAMESPACE_URL, f"{index}\0{search_text}\0{display_text}")) + + +@dataclass(frozen=True, slots=True) +class PersistedCorpusRow: + """Represent a persisted SAYT corpus row restored from artifact storage.""" + + row_id: str + search_text: str + display_text: str + + +class CleanCorpus(BaseModel): + """Store cleaned SAYT rows and their derived lookup tables. + + Instances are created from raw strings or ``(search_text, display_text)`` + pairs and retain stable row identifiers for downstream score aggregation. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + rows: list[tuple[str, str, str]] = Field(default_factory=list) + size: int = 0 + _id_to_search: dict[str, str] = PrivateAttr(default_factory=dict) + _id_to_display: dict[str, str] = PrivateAttr(default_factory=dict) + _display_text_count: dict[str, int] = PrivateAttr(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def _coerce_input(cls, data: object) -> object: + if isinstance(data, cls): + return data + if isinstance(data, dict): + if "rows" in data: + return data + if "corpus" in data: + data = data["corpus"] + + return { + "rows": cls._clean_corpus( + cast(Iterable[str] | Iterable[tuple[object, object]], data) + ) + } + + @property + def id_to_search(self) -> dict[str, str]: + """Return the row-id to normalised-search-text lookup.""" + return self._id_to_search + + @property + def id_to_display(self) -> dict[str, str]: + """Return the row-id to display-text lookup.""" + return self._id_to_display + + @property + def display_text_count(self) -> dict[str, int]: + """Return per-display-text occurrence counts for the cleaned corpus.""" + return self._display_text_count + + # Pylint does not understand Pydantic's model_post_init signature here. + def model_post_init( # pylint: disable=arguments-differ + self, __context: Any + ) -> None: + self._populate_indexes() + + def _populate_indexes(self) -> "CleanCorpus": + """Rebuild lookup tables from the current cleaned rows.""" + self._id_to_search = {rid: search for rid, search, _ in self.rows} + self._id_to_display = {rid: display for rid, _, display in self.rows} + self._display_text_count = {} + for _, _, display in self.rows: + self._display_text_count[display] = ( + self._display_text_count.get(display, 0) + 1 + ) + self.size = len(self.rows) + return self + + @classmethod + def from_persisted_rows( + cls, + rows: Iterable[PersistedCorpusRow | tuple[str, str, str]], + ) -> "CleanCorpus": + """Restore a cleaned corpus from persisted row identifiers and text. + + Args: + rows: Persisted ``(row_id, search_text, display_text)`` triples or + ``PersistedCorpusRow`` objects. + + Returns: + A ``CleanCorpus`` whose row identifiers and lookup maps match the + persisted artifact data exactly. + + Raises: + ValueError: If no persisted rows are supplied. + """ + restored_rows = [cls._coerce_persisted_row(row) for row in rows] + if not restored_rows: + raise ValueError("corpus is empty after filtering") + + return cls.model_construct(rows=restored_rows) + + @staticmethod + def _coerce_persisted_row( + row: PersistedCorpusRow | tuple[str, str, str], + ) -> tuple[str, str, str]: + """Convert persisted row data into the internal tuple format.""" + if isinstance(row, PersistedCorpusRow): + return (row.row_id, row.search_text, row.display_text) + row_id, search_text, display_text = row + return (str(row_id), str(search_text), str(display_text)) + + @staticmethod + def _clean_corpus( + corpus: Iterable[str] | Iterable[tuple[object, object]], + ) -> list[tuple[str, str, str]]: + if not isinstance(corpus, Iterable): + raise TypeError( + "corpus must be an iterable of strings or (string, original) tuples" + ) + cleaned: list[tuple[str, str]] = [] + for item in corpus: + item_tuple = item if isinstance(item, tuple) else (item, item) + text = _normalise(item_tuple[0]) + if not text or text == "-9": + warnings.warn( + f"Skipping empty or invalid corpus item: {item!r}", + stacklevel=2, + ) + continue + display = str(item_tuple[1]).strip() + if pd.isna(item_tuple[1]) or not display: + warnings.warn( + f"Empty display value for item: {item!r}, using search text as display", + stacklevel=2, + ) + display = str(item_tuple[0]).strip() + cleaned.append((text, display)) + if not cleaned: + raise ValueError("corpus is empty after filtering") + return [ + (_row_uid(i, norm, display), norm, display) + for i, (norm, display) in enumerate(cleaned) + ] + + +def _coerce_sayt_int_setting(value: object, *, field_name: str) -> int: + if isinstance(value, bool) or not isinstance(value, int | str): + raise TypeError(f"{field_name} must be an integer") + try: + return int(value) + except ValueError as exc: + raise TypeError(f"{field_name} must be an integer") from exc + + +def validate_min_chars(value: object) -> int: + """Validate the global SAYT minimum query length setting.""" + min_chars = _coerce_sayt_int_setting(value, field_name="min_chars") + if min_chars < 3: + raise ValueError("min_chars must be >= 3") + return min_chars + + +def validate_max_suggestions(value: object) -> int: + """Validate the global SAYT maximum suggestion count setting.""" + max_suggestions = _coerce_sayt_int_setting(value, field_name="max_suggestions") + if not 1 <= max_suggestions <= 100: + raise ValueError("max_suggestions must be between 1 and 100") + return max_suggestions + + +class SaytGlobalSettings(BaseModel): + """Describe suggester-wide runtime settings.""" + + model_config = ConfigDict(extra="forbid") + + min_chars: int + max_suggestions: int + + +class SaytCorpusSummary(BaseModel): + """Summarise the cleaned corpus bound to a suggester.""" + + model_config = ConfigDict(extra="forbid") + + size: int + unique_display_texts: int + max_duplication: int + + +class SaytRetrieverArtifactProvenance(BaseModel): + """Capture persisted artifact details for one retriever entry.""" + + model_config = ConfigDict(extra="forbid") + + artifact_type: str + path: str | None = None + config: dict[str, Any] = Field(default_factory=dict) + + +class SaytArtifactProvenance(BaseModel): + """Describe the artifact source of a suggester restored from disk.""" + + model_config = ConfigDict(extra="forbid") + + artifact_dir: str + artifact_type: str + artifact_version: int + corpus_file: str + corpus_size: int + + +class SaytRetrieverSummary(BaseModel): + """Summarise one configured retriever within a suggester.""" + + model_config = ConfigDict(extra="forbid") + + name: str + spec_type: str + retriever_type: str + configured_weight: float + normalised_weight: float + config: dict[str, Any] = Field(default_factory=dict) + artifact_provenance: SaytRetrieverArtifactProvenance | None = None + + +class SaytConfiguration(BaseModel): + """Return a rich runtime summary of a configured suggester.""" + + model_config = ConfigDict(extra="forbid") + + settings: SaytGlobalSettings + corpus: SaytCorpusSummary + retrievers: list[SaytRetrieverSummary] = Field(default_factory=list) + artifact_provenance: SaytArtifactProvenance | None = None + + +@dataclass(frozen=True, slots=True) +class Suggestion: + """Represent a SAYT match with score and row metadata. + + The meaning of ``score`` depends on the producer. Concrete retrievers emit + strategy-local scores, while ``SAYTSuggester.suggest_with_scores`` returns + the combined weighted score. + """ + + display_text: str + score: float + search_text: str = "" + row_id: str = "" + + +def take_with_ties( + items: list[tuple[str, float]], + limit: int, +) -> list[tuple[str, float]]: + """Return the first ``limit`` items and any later items tied on score. + + Args: + items: Scored ``(key, score)`` pairs to rank. + limit: Maximum number of leading items before tie extension is applied. + + Returns: + The highest-scoring items up to ``limit``, plus any later items that are + tied with the cutoff score. + """ + if limit < 1 or not items: + return [] + + items = sorted( + items, + key=lambda kv: -kv[1], + ) + + if limit >= len(items): + return items + + cutoff_score = float(items[limit - 1][1]) + end = limit + while end < len(items) and float(items[end][1]) == cutoff_score: + end += 1 + return items[:end] diff --git a/src/industrial_classification_utils/sayt/sayt_indexes.py b/src/industrial_classification_utils/sayt/indexes.py similarity index 57% rename from src/industrial_classification_utils/sayt/sayt_indexes.py rename to src/industrial_classification_utils/sayt/indexes.py index d923140..619e5ea 100644 --- a/src/industrial_classification_utils/sayt/sayt_indexes.py +++ b/src/industrial_classification_utils/sayt/indexes.py @@ -7,6 +7,7 @@ import tempfile from contextlib import contextmanager from dataclasses import dataclass +from pathlib import Path from typing import cast import classifai.indexers.main as classifai_indexers_main @@ -16,7 +17,7 @@ from scipy.sparse import csr_matrix from sklearn.feature_extraction.text import CountVectorizer -from .sayt_core import CleanCorpus, take_with_ties +from .core import CleanCorpus, take_with_ties def _silent_tqdm(iterable, **_kwargs): @@ -49,6 +50,8 @@ def from_corpus( *, corpus: CleanCorpus, vectoriser: VectoriserBase, + output_dir: str | os.PathLike[str] | None = None, + overwrite: bool = True, ) -> "DenseVectorIndex": """Build a dense index from a cleaned corpus. @@ -56,20 +59,24 @@ def from_corpus( corpus: Cleaned corpus whose normalised search text should be indexed. vectoriser: Vectoriser used to embed corpus rows and future queries. + output_dir: Optional persistent filespace directory for the + underlying ClassifAI vector store. When omitted, a temporary + directory is used. + overwrite: Whether to allow ClassifAI to replace an existing + filespace when ``output_dir`` is provided. Returns: A ``DenseVectorIndex`` backed by ClassifAI's ``VectorStore``. """ - with tempfile.TemporaryDirectory(prefix="sayt_") as tmp_dir: - csv_path = os.path.join(tmp_dir, "corpus.csv") - - with open(csv_path, "w", newline="", encoding="utf-8") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=["label", "text"]) - writer.writeheader() - writer.writerows( - {"label": row_id, "text": search_text} - for row_id, search_text, _ in corpus.rows - ) + with tempfile.TemporaryDirectory(prefix="sayt_") as temp_dir: + csv_path = os.path.join(temp_dir, "corpus.csv") + classifai_output_dir = ( + os.fspath(output_dir) + if output_dir is not None + else os.path.join(temp_dir, "vector_store") + ) + + cls._write_corpus_csv(corpus, csv_path) with _silence_classifai_tqdm(): vector_store = VectorStore( @@ -77,8 +84,8 @@ def from_corpus( data_type="csv", vectoriser=vectoriser, batch_size=64, - output_dir=None, - overwrite=True, + output_dir=classifai_output_dir, + overwrite=overwrite, hooks=None, ) @@ -88,6 +95,54 @@ def from_corpus( _corpus=corpus, ) + @classmethod + def from_filespace( + cls, + *, + corpus: CleanCorpus, + folder_path: str | os.PathLike[str], + vectoriser: VectoriserBase, + ) -> "DenseVectorIndex": + """Load a dense index from a persisted ClassifAI filespace. + + Args: + corpus: Cleaned corpus whose row metadata should back query results. + folder_path: Filesystem directory containing ``metadata.json`` and + ``vectors.parquet``. + vectoriser: Vectoriser used to embed future query text. + + Returns: + A ``DenseVectorIndex`` backed by a loaded ``VectorStore``. + """ + with _silence_classifai_tqdm(): + vector_store = VectorStore.from_filespace( + folder_path=os.fspath(folder_path), + vectoriser=vectoriser, + hooks=None, + ) + + return cls( + _vector_store=vector_store, + _num_vectors=int(vector_store.num_vectors or 0), + _corpus=corpus, + ) + + @staticmethod + def _write_corpus_csv( + corpus: CleanCorpus, + csv_path: str | os.PathLike[str], + ) -> None: + """Write the row-id and search-text schema expected by ClassifAI.""" + csv_file = Path(csv_path) + csv_file.parent.mkdir(parents=True, exist_ok=True) + with open(csv_file, "w", newline="", encoding="utf-8") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=["label", "text"]) + writer.writeheader() + writer.writerows( + {"label": row_id, "text": search_text} + for row_id, search_text, _ in corpus.rows + ) + def query(self, q_norm: str, num_suggestions: int) -> list[tuple[str, float]]: """Query the dense index with a normalised string. @@ -157,6 +212,8 @@ def build_ngram_index( *, n: int, max_df: float, + output_dir: str | os.PathLike[str] | None = None, + overwrite: bool = True, ) -> DenseVectorIndex: """Build a dense index backed by character n-gram vectors. @@ -164,6 +221,10 @@ def build_ngram_index( corpus: Cleaned corpus to index. n: Character n-gram size. max_df: Maximum document frequency passed to ``CountVectorizer``. + output_dir: Optional persistent filespace directory for the generated + vector store. + overwrite: Whether to allow ClassifAI to replace an existing filespace + when ``output_dir`` is provided. Returns: A dense index using character n-gram embeddings. @@ -175,6 +236,27 @@ def build_ngram_index( n=n, max_df=max_df, ), + output_dir=output_dir, + overwrite=overwrite, + ) + + +def load_ngram_index( + corpus: CleanCorpus, + *, + n: int, + max_df: float, + folder_path: str | os.PathLike[str], +) -> DenseVectorIndex: + """Load a persisted dense index backed by character n-gram vectors.""" + return DenseVectorIndex.from_filespace( + corpus=corpus, + folder_path=folder_path, + vectoriser=_CharNgramVectoriser( + [search for _, search, _ in corpus.rows], + n=n, + max_df=max_df, + ), ) @@ -182,12 +264,18 @@ def build_semantic_index( corpus: CleanCorpus, *, model: str, + output_dir: str | os.PathLike[str] | None = None, + overwrite: bool = True, ) -> DenseVectorIndex: """Build a dense index backed by sentence-transformer embeddings. Args: corpus: Cleaned corpus to index. model: Sentence-transformer model name without the repository prefix. + output_dir: Optional persistent filespace directory for the generated + vector store. + overwrite: Whether to allow ClassifAI to replace an existing filespace + when ``output_dir`` is provided. Returns: A dense index using semantic embeddings. @@ -199,4 +287,24 @@ def build_semantic_index( return DenseVectorIndex.from_corpus( corpus=corpus, vectoriser=semantic_vectoriser, + output_dir=output_dir, + overwrite=overwrite, + ) + + +def load_semantic_index( + corpus: CleanCorpus, + *, + model: str, + folder_path: str | os.PathLike[str], +) -> DenseVectorIndex: + """Load a persisted dense index backed by semantic embeddings.""" + base_vectoriser: VectoriserBase = HuggingFaceVectoriser( + f"sentence-transformers/{model}" + ) + semantic_vectoriser: VectoriserBase = _L2NormalisingVectoriser(base_vectoriser) + return DenseVectorIndex.from_filespace( + corpus=corpus, + folder_path=folder_path, + vectoriser=semantic_vectoriser, ) diff --git a/src/industrial_classification_utils/sayt/retriever_specs.py b/src/industrial_classification_utils/sayt/retriever_specs.py new file mode 100644 index 0000000..e6288dc --- /dev/null +++ b/src/industrial_classification_utils/sayt/retriever_specs.py @@ -0,0 +1,315 @@ +# pylint: disable=too-few-public-methods + +"""Public retriever protocols and configuration objects for SAYT.""" + +import math +from dataclasses import dataclass, field +from pathlib import Path +from typing import Protocol, runtime_checkable + +from .core import CleanCorpus, Suggestion +from .indexes import ( + build_ngram_index, + build_semantic_index, + load_ngram_index, + load_semantic_index, +) +from .retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever + +_MIN_NGRAM_SIZE = 2 +_MAX_NGRAM_SIZE = 5 + + +class Retriever(Protocol): + """Query contract used by the SAYT orchestrator.""" + + def suggest_with_scores( + self, q_norm: str, num_suggestions: int + ) -> list[Suggestion]: + """Return scored suggestions for a normalised query string. + + Args: + q_norm: Normalised query text. + num_suggestions: Maximum number of scored suggestions to return. + + Returns: + Ranked ``Suggestion`` objects for the query. + """ + + +class RetrieverSpec(Protocol): + """Configuration plus builder for a corpus-bound retriever instance.""" + + @property + def name(self) -> str: + """Return the stable identifier for this retriever configuration.""" + + @property + def weight(self) -> float: + """Return the finite positive weight applied during score combination.""" + + def build( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None = None, + overwrite: bool = True, + ) -> Retriever: + """Build a corpus-bound retriever instance from this configuration. + + Args: + corpus: Cleaned corpus to bind to the retriever. + min_chars: Minimum query length required before retrieval runs. + filespace_path: Optional persisted filespace directory used when a + caller wants the build outputs written to disk. + overwrite: Whether an existing filespace may be replaced when + ``filespace_path`` is provided. + + Returns: + A configured retriever instance bound to ``corpus``. + """ + + +@runtime_checkable +class ArtifactRetrieverSpec(RetrieverSpec, Protocol): + """Optional artifact-loading capability for persisted retriever specs.""" + + def load_from_artifact( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None, + ) -> Retriever: + """Restore a runtime retriever from persisted artifact state.""" + + +def _validate_retriever_weight(weight: float) -> None: + if not math.isfinite(weight) or weight <= 0: + raise ValueError("retriever weight must be a finite value > 0") + + +def _require_filespace_path( + filespace_path: str | Path | None, + *, + spec_name: str, +) -> str | Path: + if filespace_path is None: + raise ValueError(f"Retriever '{spec_name}' does not have a stored filespace") + return filespace_path + + +@dataclass(frozen=True, slots=True) +class PrefixRetrieverSpec: + """Configuration for building a prefix retriever.""" + + weight: float = 1.0 + name: str = field(init=False, default="prefix") + + def __post_init__(self) -> None: + """Validate configuration values after dataclass initialisation.""" + _validate_retriever_weight(self.weight) + + def build( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None = None, + overwrite: bool = True, + ) -> Retriever: + """Build a prefix retriever for the provided cleaned corpus. + + Args: + corpus: Cleaned corpus to search. + min_chars: Minimum query length required before retrieval runs. + filespace_path: Optional persisted filespace path. Prefix + retrievers ignore this because they do not store dense state. + overwrite: Ignored for prefix retrievers. + + Returns: + A configured ``PrefixRetriever``. + """ + _ = (filespace_path, overwrite) + return PrefixRetriever(corpus, min_chars=min_chars) + + def load_from_artifact( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None, + ) -> Retriever: + """Restore a prefix retriever directly from its runtime config.""" + _ = filespace_path + return self.build(corpus, min_chars=min_chars) + + +@dataclass(frozen=True, slots=True) +class NgramRetrieverSpec: + """Configuration for building a character n-gram retriever.""" + + weight: float = 1.0 + n: int = 3 + max_df: float = 0.2 + name: str = field(init=False, default="ngram") + + def __post_init__(self) -> None: + """Validate n-gram configuration values after initialisation.""" + _validate_retriever_weight(self.weight) + if not _MIN_NGRAM_SIZE <= self.n <= _MAX_NGRAM_SIZE: + raise ValueError("ngram n must be between 2 and 5") + if not 0.0 < self.max_df <= 1.0: + raise ValueError("ngram max_df must be in (0, 1]") + + def build( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None = None, + overwrite: bool = True, + ) -> Retriever: + """Build a character n-gram retriever for the provided corpus. + + Args: + corpus: Cleaned corpus to search. + min_chars: Minimum query length required before retrieval runs. + filespace_path: Optional filespace directory. When provided, the + built dense index is persisted there. + overwrite: Whether an existing filespace may be replaced when + ``filespace_path`` is provided. + + Returns: + A configured ``NgramRetriever``. + + Raises: + ValueError: If ``max_df`` would remove every n-gram feature from the + provided corpus. + """ + if self.max_df * corpus.size < 1: + raise ValueError("ngram max_df is too low for the given corpus") + if filespace_path is None: + return NgramRetriever( + corpus, + n=self.n, + max_df=self.max_df, + min_chars=min_chars, + ) + + index = build_ngram_index( + corpus, + n=self.n, + max_df=self.max_df, + output_dir=_require_filespace_path(filespace_path, spec_name=self.name), + overwrite=overwrite, + ) + return NgramRetriever.from_index( + corpus, + min_chars=min_chars, + index=index, + ) + + def load_from_artifact( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None, + ) -> Retriever: + """Restore an n-gram retriever from its persisted dense filespace.""" + index = load_ngram_index( + corpus, + n=self.n, + max_df=self.max_df, + folder_path=_require_filespace_path(filespace_path, spec_name=self.name), + ) + return NgramRetriever.from_index( + corpus, + min_chars=min_chars, + index=index, + ) + + +@dataclass(frozen=True, slots=True) +class SemanticRetrieverSpec: + """Configuration for building a semantic retriever.""" + + weight: float = 1.0 + model: str = "all-MiniLM-L6-v2" + name: str = field(init=False, default="semantic") + + def __post_init__(self) -> None: + """Validate semantic retriever configuration after initialisation.""" + _validate_retriever_weight(self.weight) + if not isinstance(self.model, str) or not self.model.strip(): + raise ValueError("semantic model must be a non-empty string") + + def build( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None = None, + overwrite: bool = True, + ) -> Retriever: + """Build a semantic retriever for the provided cleaned corpus. + + Args: + corpus: Cleaned corpus to search. + min_chars: Minimum query length required before retrieval runs. + filespace_path: Optional filespace directory. When provided, the + built dense index is persisted there. + overwrite: Whether an existing filespace may be replaced when + ``filespace_path`` is provided. + + Returns: + A configured ``SemanticRetriever``. + """ + if filespace_path is None: + return SemanticRetriever(corpus, model=self.model, min_chars=min_chars) + + index = build_semantic_index( + corpus, + model=self.model, + output_dir=_require_filespace_path(filespace_path, spec_name=self.name), + overwrite=overwrite, + ) + return SemanticRetriever.from_index( + corpus, + min_chars=min_chars, + index=index, + ) + + def load_from_artifact( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None, + ) -> Retriever: + """Restore a semantic retriever from its persisted dense filespace.""" + index = load_semantic_index( + corpus, + model=self.model, + folder_path=_require_filespace_path(filespace_path, spec_name=self.name), + ) + return SemanticRetriever.from_index( + corpus, + min_chars=min_chars, + index=index, + ) + + +def default_retriever_specs() -> list[RetrieverSpec]: + """Return the standard runtime retriever set used by SAYT. + + Returns: + The default prefix, character n-gram, and semantic retriever specs. + """ + return [ + PrefixRetrieverSpec(), + NgramRetrieverSpec(), + SemanticRetrieverSpec(), + ] diff --git a/src/industrial_classification_utils/sayt/sayt_retrievers.py b/src/industrial_classification_utils/sayt/retrievers.py similarity index 91% rename from src/industrial_classification_utils/sayt/sayt_retrievers.py rename to src/industrial_classification_utils/sayt/retrievers.py index 6bb1f1d..e8d35c2 100644 --- a/src/industrial_classification_utils/sayt/sayt_retrievers.py +++ b/src/industrial_classification_utils/sayt/retrievers.py @@ -6,8 +6,8 @@ from dataclasses import dataclass from difflib import SequenceMatcher -from .sayt_core import CleanCorpus, Suggestion, take_with_ties -from .sayt_indexes import DenseVectorIndex, build_ngram_index, build_semantic_index +from .core import CleanCorpus, Suggestion, take_with_ties +from .indexes import DenseVectorIndex, build_ngram_index, build_semantic_index _FUZZY_PREFIX_MIN_RATIO = 0.75 @@ -109,6 +109,21 @@ class _DenseRetriever: _min_chars: int _index: DenseVectorIndex + @classmethod + def from_index( + cls, + corpus: CleanCorpus, + *, + min_chars: int, + index: DenseVectorIndex, + ) -> "_DenseRetriever": + """Restore a dense retriever from an already-built dense index.""" + retriever = cls.__new__(cls) + retriever._corpus = corpus + retriever._min_chars = min_chars + retriever._index = index + return retriever + def suggest_with_scores( self, q_norm: str, num_suggestions: int ) -> list[Suggestion]: diff --git a/src/industrial_classification_utils/sayt/sayt.py b/src/industrial_classification_utils/sayt/sayt.py deleted file mode 100644 index ce41021..0000000 --- a/src/industrial_classification_utils/sayt/sayt.py +++ /dev/null @@ -1,322 +0,0 @@ -"""Search-as-you-type (SAYT) orchestration. - -This module provides the public suggester API that coordinates configured -retrievers and combines their scores into ranked suggestions. -""" - -import math -import os -from collections.abc import Iterable, Sequence -from dataclasses import dataclass - -import pandas as pd -from survey_assist_utils.logging import get_logger - -from .sayt_core import ( - CleanCorpus, - SaytConfig, - Suggestion, - _normalise, - take_with_ties, -) -from .sayt_retriever_specs import ( - Retriever, - RetrieverSpec, - default_retriever_specs, -) - -logger = get_logger(__name__) - - -@dataclass(frozen=True, slots=True) -class _ConfiguredRetriever: - """Runtime retriever binding with its configured contribution weight.""" - - name: str - weight: float - retriever: Retriever - - -class SAYTSuggester: - """Suggest free-text responses as a user types. - - The suggester: - - validates and cleans the supplied corpus - - builds the configured retrievers for that corpus - - combines retriever-local scores into a shared weighted ranking - - By default it uses the standard prefix, n-gram, and semantic retriever - specifications. Use ``retrievers=`` to override that mix. - - Suggester-wide settings are currently passed as keyword arguments and - validated by ``SaytConfig``. At present these include: - - ``min_chars``: minimum query length before retrieval runs - - ``max_suggestions``: default maximum number of ranked suggestions to return - - Examples: - Basic usage with an in-memory corpus: - - ```python - from industrial_classification_utils.sayt import SAYTSuggester - - suggester = SAYTSuggester( - corpus=[ - ("Car wash", "Car Wash"), - ("Dog grooming", "Dog grooming"), - ], - min_chars=3, - max_suggestions=5, - ) - - results = suggester.suggest("car") - ``` - - Usage with custom retriever specifications: - - ```python - from industrial_classification_utils.sayt import ( - PrefixRetrieverSpec, - SAYTSuggester, - ) - - suggester = SAYTSuggester( - corpus=[("Car wash", "Car Wash")], - retrievers=[PrefixRetrieverSpec()], - min_chars=3, - ) - ``` - """ - - def __init__( - self, - corpus: Iterable[tuple[str, str]] | Iterable[str], - *, - retrievers: Sequence[RetrieverSpec] | None = None, - **kwargs: object, - ) -> None: - """Initialise a suggester for a cleaned response corpus. - - Args: - corpus: Iterable of search strings or ``(search_text, display_text)`` - pairs. - retrievers: Optional retriever specifications. When omitted, the - standard prefix, n-gram, and semantic spec set is used. - **kwargs: Suggester-wide keyword arguments validated by - ``SaytConfig``, currently including ``min_chars`` and - ``max_suggestions``. - """ - self._corpus = CleanCorpus.model_validate(corpus) - self._config = SaytConfig.model_validate(kwargs) - self._max_duplication = max(self._corpus.display_text_count.values(), default=0) - - self._retriever_specs = tuple( - default_retriever_specs() if retrievers is None else retrievers - ) - self._retrievers = self._build_retrievers(self._retriever_specs) - logger.info(f"SAYT suggester initialized with config: {self.get_config()}") - - def _build_retrievers( - self, retriever_specs: Sequence[RetrieverSpec] - ) -> list[_ConfiguredRetriever]: - if not retriever_specs: - raise ValueError("At least one retriever must be configured") - - validated_specs: list[tuple[RetrieverSpec, float]] = [] - for spec in retriever_specs: - weight = float(spec.weight) - if not math.isfinite(weight) or weight <= 0: - raise ValueError( - f"Retriever '{spec.name}' weight must be a finite value > 0" - ) - validated_specs.append((spec, weight)) - - total_weight = sum(weight for _, weight in validated_specs) - - return [ - _ConfiguredRetriever( - name=spec.name, - weight=weight / total_weight, - retriever=spec.build( - self._corpus, - min_chars=self._config.min_chars, - ), - ) - for spec, weight in validated_specs - ] - - @classmethod - def from_csv( - cls, - file_path: str | os.PathLike, - *, - search_text_col: str = "title", - display_text_col: str | None = None, - retrievers: Sequence[RetrieverSpec] | None = None, - **kwargs: object, - ) -> "SAYTSuggester": - """Build a suggester from CSV input. - - Args: - file_path: Path to the CSV file containing suggestion rows. - search_text_col: Column containing the searchable text. - display_text_col: Optional column containing display text. When - omitted, the search column is reused for display values. - retrievers: Optional retriever specifications. When omitted, the - standard retriever set is used. - **kwargs: Keyword arguments validated by ``SaytConfig``. - - Returns: - A configured ``SAYTSuggester`` instance. - - Raises: - ValueError: If the requested search or display column is missing. - """ - df = pd.read_csv(file_path) - if search_text_col not in df.columns: - raise ValueError(f"Column '{search_text_col}' not found in CSV") - if display_text_col is None: - display_text_col = search_text_col - if display_text_col not in df.columns: - raise ValueError(f"Column '{display_text_col}' not found in CSV") - return cls( - list(zip(df[search_text_col], df[display_text_col], strict=False)), - retrievers=retrievers, - **kwargs, - ) - - def _dedup_suggestions( - self, suggestions: list[Suggestion] - ) -> list[tuple[str, float]]: - # sort by score and deduplicate by display text, keeping the highest-scoring variant. - sorted_suggestions = sorted( - suggestions, - key=lambda s: ( - -s.score, - -self._corpus.display_text_count.get(s.display_text, 0), - s.display_text.lower(), - s.row_id, - ), - ) - seen: set[str] = set() - deduped: list[tuple[str, float]] = [] - for s in sorted_suggestions: - display_text = s.display_text # lower? normalised? - if display_text not in seen: - deduped.append((display_text, s.score)) - seen.add(display_text) - return deduped - - def _combine_suggestions( - self, - result_groups: Iterable[tuple[float, list[Suggestion]]], - ) -> list[tuple[str, float]]: - def normalise_scores( - items: list[Suggestion], weight: float - ) -> dict[str, float]: - if not items: - return {} - max_score = max((float(s.score) for s in items), default=0.0) - if max_score <= 0: - return {} - out: dict[str, float] = {} - for s in items: - if not s.row_id: - continue - out[s.row_id] = max( - out.get(s.row_id, 0.0), float(s.score) / max_score * weight - ) - return out - - combined_scores: dict[str, float] = {} - for weight, suggestions in result_groups: - d = normalise_scores(suggestions, weight) - for k, v in d.items(): - combined_scores[k] = combined_scores.get(k, 0.0) + v - - return [(row_id, float(score)) for row_id, score in combined_scores.items()] - - def _collect_retriever_results( - self, q_norm: str, num_suggestions: int - ) -> list[tuple[float, list[Suggestion]]]: - return [ - ( - configured_retriever.weight, - configured_retriever.retriever.suggest_with_scores( - q_norm, - num_suggestions=num_suggestions, - ), - ) - for configured_retriever in self._retrievers - ] - - def suggest_with_scores( - self, query: str | None, num_suggestions: int | None = None - ) -> list[Suggestion]: - """Return ranked suggestions and their combined scores. - - Args: - query: Raw user query text. - num_suggestions: Optional maximum number of ranked suggestions to - return. When omitted, the configured default is used. - - Returns: - A list of combined suggestions ordered by descending score. Returns - an empty list when the normalised query is shorter than - ``SaytConfig.min_chars``. - """ - if num_suggestions is None: - num_suggestions = self._config.max_suggestions - q_norm = _normalise(query) - if len(q_norm) < self._config.min_chars: - return [] - - # Ask for more suggestions, as some may be filtered out after deduplication - results_by_kind = self._collect_retriever_results( - q_norm, - num_suggestions=10 * num_suggestions, - ) - - combined_result = self._combine_suggestions(results_by_kind) - ranked_results = take_with_ties(combined_result, num_suggestions) - out = [ - Suggestion( - row_id=row_id, - display_text=self._corpus.id_to_display.get(row_id, ""), - score=score, - search_text=self._corpus.id_to_search.get(row_id, ""), - ) - for row_id, score in ranked_results - ] - - return out - - def suggest( - self, query: str | None, num_suggestions: int | None = None - ) -> list[str]: - """Return deduplicated display-text suggestions. - - Args: - query: Raw user query text. - num_suggestions: Optional maximum number of display values to - return. When omitted, the configured default is used. - - Returns: - A list of display-text suggestions ordered by descending combined - score, while preserving ties at the cutoff. - """ - if num_suggestions is None: - num_suggestions = self._config.max_suggestions - results = self.suggest_with_scores( - query, num_suggestions=num_suggestions * self._max_duplication - ) - dedup_results = self._dedup_suggestions(results) - ranked_results = take_with_ties(dedup_results, num_suggestions) - return [result[0] for result in ranked_results] - - def get_config(self) -> SaytConfig: - """Return a copy of the validated suggester configuration. - - Returns: - A deep copy of the ``SaytConfig`` used by this suggester. - """ - return self._config.model_copy(deep=True) diff --git a/src/industrial_classification_utils/sayt/sayt_core.py b/src/industrial_classification_utils/sayt/sayt_core.py deleted file mode 100644 index b0b40f0..0000000 --- a/src/industrial_classification_utils/sayt/sayt_core.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Core SAYT data models, corpus cleaning, and ranking helpers.""" - -# ruff: noqa: PLR2004 - -import re -import warnings -from collections.abc import Iterable -from dataclasses import dataclass -from typing import cast -from uuid import NAMESPACE_URL, uuid5 - -import pandas as pd -from pydantic import BaseModel, ConfigDict, Field, model_validator - -_WS_RE = re.compile(r"\s+") -_NON_ALNUM_SPACE_RE = re.compile(r"[^a-z ]+") - - -def _normalise(text: str | None) -> str: - if not isinstance(text, str): - return "" - text = text.strip().lower() - text = _NON_ALNUM_SPACE_RE.sub(" ", text) - text = _WS_RE.sub(" ", text).strip() - return text - - -def _row_uid(index: int, search_text: str, display_text: str) -> str: - return str(uuid5(NAMESPACE_URL, f"{index}\0{search_text}\0{display_text}")) - - -class CleanCorpus(BaseModel): - """Store cleaned SAYT rows and their derived lookup tables. - - Instances are created from raw strings or ``(search_text, display_text)`` - pairs and retain stable row identifiers for downstream score aggregation. - """ - - model_config = ConfigDict(arbitrary_types_allowed=True) - corpus: object - rows: list[tuple[str, str, str]] = Field(default_factory=list) - id_to_search: dict[str, str] = Field(default_factory=dict) - id_to_display: dict[str, str] = Field(default_factory=dict) - display_text_count: dict[str, int] = Field(default_factory=dict) - size: int = 0 - - @model_validator(mode="before") - @classmethod - def _coerce_input(cls, data: object) -> object: - if isinstance(data, cls | dict): - return data - return {"corpus": data} - - @model_validator(mode="after") - def _build_indexes(self) -> "CleanCorpus": - self.rows = self._clean_corpus( - cast(Iterable[str] | Iterable[tuple[str, str]], self.corpus) - ) - self.id_to_search = {rid: search for rid, search, _ in self.rows} - self.id_to_display = {rid: display for rid, _, display in self.rows} - self.display_text_count = {} - for _, _, display in self.rows: - self.display_text_count[display] = ( - self.display_text_count.get(display, 0) + 1 - ) - self.size = len(self.rows) - return self - - @staticmethod - def _clean_corpus( - corpus: Iterable[str] | Iterable[tuple[str, str]], - ) -> list[tuple[str, str, str]]: - if not isinstance(corpus, Iterable): - raise TypeError( - "corpus must be an iterable of strings or (string, original) tuples" - ) - cleaned: list[tuple[str, str]] = [] - for item in corpus: - item_tuple = item if isinstance(item, tuple) else (item, item) - text = _normalise(item_tuple[0]) - if not text or text == "-9": - warnings.warn( - f"Skipping empty or invalid corpus item: {item!r}", - stacklevel=2, - ) - continue - display = str(item_tuple[1]).strip() - if pd.isna(item_tuple[1]) or not display: - warnings.warn( - f"Empty display value for item: {item!r}, using search text as display", - stacklevel=2, - ) - display = str(item_tuple[0]).strip() - cleaned.append((text, display)) - if not cleaned: - raise ValueError("corpus is empty after filtering") - return [ - (_row_uid(i, norm, display), norm, display) - for i, (norm, display) in enumerate(cleaned) - ] - - -class SaytConfig(BaseModel): - """Validated configuration for a SAYT suggester instance. - - This model contains only suggester-wide settings. Retriever-specific - configuration lives on individual ``RetrieverSpec`` objects. - """ - - model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - - min_chars: int = 4 - max_suggestions: int = 10 - - @model_validator(mode="after") - def _validate_ranges(self) -> "SaytConfig": - """Enforce the supported numeric ranges for SAYT settings.""" - if self.min_chars < 3: - raise ValueError("min_chars must be >= 3") - if not 1 <= self.max_suggestions <= 100: - raise ValueError("max_suggestions must be between 1 and 100") - return self - - -@dataclass(frozen=True, slots=True) -class Suggestion: - """Represent a SAYT match with score and row metadata. - - The meaning of ``score`` depends on the producer. Concrete retrievers emit - strategy-local scores, while ``SAYTSuggester.suggest_with_scores`` returns - the combined weighted score. - """ - - display_text: str - score: float - search_text: str = "" - row_id: str = "" - - -def take_with_ties( - items: list[tuple[str, float]], - limit: int, -) -> list[tuple[str, float]]: - """Return the first ``limit`` items and any later items tied on score. - - Args: - items: Scored ``(key, score)`` pairs to rank. - limit: Maximum number of leading items before tie extension is applied. - - Returns: - The highest-scoring items up to ``limit``, plus any later items that are - tied with the cutoff score. - """ - if limit < 1 or not items: - return [] - - items = sorted( - items, - key=lambda kv: (-kv[1],), - ) - - if limit >= len(items): - return items - - cutoff_score = float(items[limit - 1][1]) - end = limit - while end < len(items) and float(items[end][1]) == cutoff_score: - end += 1 - return items[:end] diff --git a/src/industrial_classification_utils/sayt/sayt_retriever_specs.py b/src/industrial_classification_utils/sayt/sayt_retriever_specs.py deleted file mode 100644 index a45df4d..0000000 --- a/src/industrial_classification_utils/sayt/sayt_retriever_specs.py +++ /dev/null @@ -1,163 +0,0 @@ -# pylint: disable=too-few-public-methods - -"""Public retriever protocols and configuration objects for SAYT.""" - -import math -from dataclasses import dataclass, field -from typing import Protocol - -from .sayt_core import CleanCorpus, Suggestion -from .sayt_retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever - -_MIN_NGRAM_SIZE = 2 -_MAX_NGRAM_SIZE = 5 - - -class Retriever(Protocol): - """Query contract used by the SAYT orchestrator.""" - - def suggest_with_scores( - self, q_norm: str, num_suggestions: int - ) -> list[Suggestion]: - """Return scored suggestions for a normalised query string. - - Args: - q_norm: Normalised query text. - num_suggestions: Maximum number of scored suggestions to return. - - Returns: - Ranked ``Suggestion`` objects for the query. - """ - - -class RetrieverSpec(Protocol): - """Configuration plus builder for a corpus-bound retriever instance.""" - - @property - def name(self) -> str: - """Return the stable identifier for this retriever configuration.""" - - @property - def weight(self) -> float: - """Return the finite positive weight applied during score combination.""" - - def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: - """Build a corpus-bound retriever instance from this configuration. - - Args: - corpus: Cleaned corpus to bind to the retriever. - min_chars: Minimum query length required before retrieval runs. - - Returns: - A configured retriever instance bound to ``corpus``. - """ - - -def _validate_retriever_weight(weight: float) -> None: - if not math.isfinite(weight) or weight <= 0: - raise ValueError("retriever weight must be a finite value > 0") - - -@dataclass(frozen=True, slots=True) -class PrefixRetrieverSpec: - """Configuration for building a prefix retriever.""" - - weight: float = 1.0 - name: str = field(init=False, default="prefix") - - def __post_init__(self) -> None: - """Validate configuration values after dataclass initialisation.""" - _validate_retriever_weight(self.weight) - - def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: - """Build a prefix retriever for the provided cleaned corpus. - - Args: - corpus: Cleaned corpus to search. - min_chars: Minimum query length required before retrieval runs. - - Returns: - A configured ``PrefixRetriever``. - """ - return PrefixRetriever(corpus, min_chars=min_chars) - - -@dataclass(frozen=True, slots=True) -class NgramRetrieverSpec: - """Configuration for building a character n-gram retriever.""" - - weight: float = 1.0 - n: int = 3 - max_df: float = 0.2 - name: str = field(init=False, default="ngram") - - def __post_init__(self) -> None: - """Validate n-gram configuration values after initialisation.""" - _validate_retriever_weight(self.weight) - if not _MIN_NGRAM_SIZE <= self.n <= _MAX_NGRAM_SIZE: - raise ValueError("ngram n must be between 2 and 5") - if not 0.0 < self.max_df <= 1.0: - raise ValueError("ngram max_df must be in (0, 1]") - - def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: - """Build a character n-gram retriever for the provided corpus. - - Args: - corpus: Cleaned corpus to search. - min_chars: Minimum query length required before retrieval runs. - - Returns: - A configured ``NgramRetriever``. - - Raises: - ValueError: If ``max_df`` would remove every n-gram feature from the - provided corpus. - """ - if self.max_df * corpus.size < 1: - raise ValueError("ngram max_df is too low for the given corpus") - return NgramRetriever( - corpus, - n=self.n, - max_df=self.max_df, - min_chars=min_chars, - ) - - -@dataclass(frozen=True, slots=True) -class SemanticRetrieverSpec: - """Configuration for building a semantic retriever.""" - - weight: float = 1.0 - model: str = "all-MiniLM-L6-v2" - name: str = field(init=False, default="semantic") - - def __post_init__(self) -> None: - """Validate semantic retriever configuration after initialisation.""" - _validate_retriever_weight(self.weight) - if not isinstance(self.model, str) or not self.model.strip(): - raise ValueError("semantic model must be a non-empty string") - - def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: - """Build a semantic retriever for the provided cleaned corpus. - - Args: - corpus: Cleaned corpus to search. - min_chars: Minimum query length required before retrieval runs. - - Returns: - A configured ``SemanticRetriever``. - """ - return SemanticRetriever(corpus, model=self.model, min_chars=min_chars) - - -def default_retriever_specs() -> list[RetrieverSpec]: - """Return the standard runtime retriever set used by SAYT. - - Returns: - The default prefix, character n-gram, and semantic retriever specs. - """ - return [ - PrefixRetrieverSpec(), - NgramRetrieverSpec(), - SemanticRetrieverSpec(), - ] diff --git a/src/industrial_classification_utils/sayt/storage.py b/src/industrial_classification_utils/sayt/storage.py new file mode 100644 index 0000000..37a2377 --- /dev/null +++ b/src/industrial_classification_utils/sayt/storage.py @@ -0,0 +1,359 @@ +"""Artifact and storage helpers for SAYT builder and loader paths.""" + +import csv +import json +import os +import shutil +from dataclasses import dataclass, fields, is_dataclass +from pathlib import Path + +import pandas as pd + +from .core import ( + CleanCorpus, + PersistedCorpusRow, + validate_max_suggestions, + validate_min_chars, +) +from .retriever_specs import ( + ArtifactRetrieverSpec, + NgramRetrieverSpec, + PrefixRetrieverSpec, + Retriever, + RetrieverSpec, + SemanticRetrieverSpec, +) + +SAYT_ARTIFACT_TYPE = "sayt" +SAYT_ARTIFACT_VERSION = 2 +MANIFEST_FILE_NAME = "manifest.json" +CORPUS_FILE_NAME = "corpus.csv" +_ARTIFACT_CORPUS_FIELDS = ["row_id", "search_text", "display_text"] + + +@dataclass(frozen=True, slots=True) +class StoredRetrieverSpec: + """Persisted retriever spec plus its optional filespace path.""" + + spec: ArtifactRetrieverSpec + path: str | None = None + + +@dataclass(frozen=True, slots=True) +class SaytArtifactManifest: + """Structured manifest data for a persisted SAYT artifact.""" + + min_chars: int + max_suggestions: int + corpus_file: str + corpus_size: int + retrievers: tuple[StoredRetrieverSpec, ...] + + +def load_corpus_from_csv( + file_path: str | os.PathLike[str], + *, + search_text_col: str = "title", + display_text_col: str | None = None, +) -> list[tuple[object, object]]: + """Load raw corpus tuples from a CSV file. + + Args: + file_path: Path to the CSV file containing suggestion rows. + search_text_col: Column containing the searchable text. + display_text_col: Optional column containing display text. When + omitted, the search column is reused for display values. + + Returns: + Raw ``(search_text, display_text)`` tuples suitable for ``CleanCorpus``. + + Raises: + ValueError: If the requested search or display column is missing. + """ + df = pd.read_csv(file_path) + if search_text_col not in df.columns: + raise ValueError(f"Column '{search_text_col}' not found in CSV") + if display_text_col is None: + display_text_col = search_text_col + if display_text_col not in df.columns: + raise ValueError(f"Column '{display_text_col}' not found in CSV") + return list(zip(df[search_text_col], df[display_text_col], strict=False)) + + +def prepare_artifact_dir( + artifact_dir: str | os.PathLike[str], + *, + overwrite: bool = False, +) -> Path: + """Create or replace the output directory for a SAYT artifact.""" + path = Path(artifact_dir) + if path.exists(): + if not overwrite: + raise FileExistsError("Artifact directory already exists") + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink() + path.mkdir(parents=True, exist_ok=True) + return path + + +def write_artifact_corpus(corpus: CleanCorpus, *, artifact_dir: str | Path) -> Path: + """Persist cleaned SAYT rows as the artifact corpus source of truth.""" + output_path = Path(artifact_dir) / CORPUS_FILE_NAME + with open(output_path, "w", newline="", encoding="utf-8") as csv_file: + writer = csv.DictWriter(csv_file, fieldnames=_ARTIFACT_CORPUS_FIELDS) + writer.writeheader() + writer.writerows( + { + "row_id": row_id, + "search_text": search_text, + "display_text": display_text, + } + for row_id, search_text, display_text in corpus.rows + ) + return output_path + + +def read_artifact_corpus( + *, + artifact_dir: str | Path, + corpus_file: str = CORPUS_FILE_NAME, +) -> list[PersistedCorpusRow]: + """Read persisted corpus rows from a SAYT artifact.""" + corpus_path = Path(artifact_dir) / corpus_file + if not corpus_path.exists(): + raise FileNotFoundError(f"Artifact corpus file not found: {corpus_path}") + + with open(corpus_path, encoding="utf-8") as csv_file: + reader = csv.DictReader(csv_file) + return [ + PersistedCorpusRow( + row_id=row["row_id"], + search_text=row["search_text"], + display_text=row["display_text"], + ) + for row in reader + ] + + +def build_artifact_manifest( + *, + corpus: CleanCorpus, + min_chars: int, + max_suggestions: int, + retriever_specs: tuple[RetrieverSpec, ...], +) -> SaytArtifactManifest: + """Build the structured manifest payload for a SAYT artifact.""" + return SaytArtifactManifest( + min_chars=min_chars, + max_suggestions=max_suggestions, + corpus_file=CORPUS_FILE_NAME, + corpus_size=corpus.size, + retrievers=tuple( + _build_stored_retriever(index, spec) + for index, spec in enumerate(retriever_specs) + ), + ) + + +def write_artifact_manifest( + manifest: SaytArtifactManifest, + *, + artifact_dir: str | Path, +) -> Path: + """Write the manifest for a SAYT artifact.""" + manifest_path = Path(artifact_dir) / MANIFEST_FILE_NAME + manifest_path.write_text( + json.dumps(_serialise_manifest(manifest), indent=2), + encoding="utf-8", + ) + return manifest_path + + +def read_artifact_manifest(*, artifact_dir: str | Path) -> SaytArtifactManifest: + """Read and validate a SAYT artifact manifest.""" + manifest_path = Path(artifact_dir) / MANIFEST_FILE_NAME + if not manifest_path.exists(): + raise FileNotFoundError(f"Artifact manifest not found: {manifest_path}") + + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + if payload.get("artifact_type") != SAYT_ARTIFACT_TYPE: + raise ValueError("Unsupported artifact type") + if payload.get("artifact_version") != SAYT_ARTIFACT_VERSION: + raise ValueError("Unsupported artifact version") + + try: + return SaytArtifactManifest( + min_chars=validate_min_chars(payload["min_chars"]), + max_suggestions=validate_max_suggestions(payload["max_suggestions"]), + corpus_file=str(payload["corpus_file"]), + corpus_size=int(payload["corpus_size"]), + retrievers=tuple( + _deserialise_stored_retriever(item) for item in payload["retrievers"] + ), + ) + except KeyError as exc: + raise ValueError(f"Malformed artifact manifest: missing {exc.args[0]}") from exc + + +def retriever_filespace_path( + artifact_dir: str | Path, + stored_retriever: StoredRetrieverSpec, +) -> Path: + """Resolve the persisted filespace for a dense retriever entry.""" + if stored_retriever.path is None: + raise ValueError( + f"Retriever '{stored_retriever.spec.name}' does not have a stored filespace" + ) + return Path(artifact_dir) / stored_retriever.path + + +def _require_artifact_spec(spec: RetrieverSpec) -> ArtifactRetrieverSpec: + if not isinstance(spec, ArtifactRetrieverSpec): + raise TypeError( + "Only artifact-aware retriever specs can be persisted; " + f"got {type(spec).__name__}" + ) + return spec + + +def build_retriever_artifact( + *, + corpus: CleanCorpus, + min_chars: int, + stored_retriever: StoredRetrieverSpec, + artifact_dir: str | Path, +) -> None: + """Persist built-in retriever assets required by a SAYT artifact.""" + if stored_retriever.path is None: + return + + stored_retriever.spec.build( + corpus, + min_chars=min_chars, + filespace_path=retriever_filespace_path(artifact_dir, stored_retriever), + overwrite=True, + ) + + +def load_retriever_from_artifact( + *, + corpus: CleanCorpus, + min_chars: int, + stored_retriever: StoredRetrieverSpec, + artifact_dir: str | Path, +) -> Retriever: + """Restore a built-in runtime retriever from persisted artifact state.""" + return stored_retriever.spec.load_from_artifact( + corpus, + min_chars=min_chars, + filespace_path=( + retriever_filespace_path(artifact_dir, stored_retriever) + if stored_retriever.path is not None + else None + ), + ) + + +def _build_stored_retriever( + index: int, + spec: RetrieverSpec, +) -> StoredRetrieverSpec: + artifact_spec = _require_artifact_spec(spec) + + return StoredRetrieverSpec( + spec=artifact_spec, + path=( + None + if isinstance(artifact_spec, PrefixRetrieverSpec) + else f"retrievers/{index:02d}-{artifact_spec.name}" + ), + ) + + +def _serialise_manifest(manifest: SaytArtifactManifest) -> dict[str, object]: + return { + "artifact_type": SAYT_ARTIFACT_TYPE, + "artifact_version": SAYT_ARTIFACT_VERSION, + "min_chars": manifest.min_chars, + "max_suggestions": manifest.max_suggestions, + "corpus_file": manifest.corpus_file, + "corpus_size": manifest.corpus_size, + "retrievers": [ + _serialise_stored_retriever(stored_retriever) + for stored_retriever in manifest.retrievers + ], + } + + +def _serialise_stored_retriever( + stored_retriever: StoredRetrieverSpec, +) -> dict[str, object]: + spec = stored_retriever.spec + + if is_dataclass(spec): + config: dict[str, object] = { + field.name: getattr(spec, field.name) + for field in fields(spec) + if field.name not in {"name", "weight"} + } + else: + raw_config = getattr(spec, "__dict__", None) + config = ( + { + str(key): value + for key, value in raw_config.items() + if key not in {"name", "weight"} + } + if isinstance(raw_config, dict) + else {} + ) + + return { + "type": spec.name, + "weight": spec.weight, + "path": stored_retriever.path, + "config": config, + } + + +def _deserialise_stored_retriever(payload: dict[str, object]) -> StoredRetrieverSpec: + retriever_type = str(payload["type"]) + weight = _coerce_float(payload["weight"], field_name="weight") + path = payload.get("path") + config = payload.get("config", {}) + if not isinstance(config, dict): + raise ValueError(f"Malformed retriever config for type: {retriever_type}") + spec: RetrieverSpec + if retriever_type == "prefix": + spec = PrefixRetrieverSpec(weight=weight) + elif retriever_type == "ngram": + spec = NgramRetrieverSpec( + weight=weight, + n=_coerce_int(config["n"], field_name="n"), + max_df=_coerce_float(config["max_df"], field_name="max_df"), + ) + elif retriever_type == "semantic": + spec = SemanticRetrieverSpec( + weight=weight, + model=str(config["model"]), + ) + else: + raise ValueError(f"Unsupported stored retriever type: {retriever_type}") + + return StoredRetrieverSpec( + spec=spec, path=str(path) if isinstance(path, str) else None + ) + + +def _coerce_int(value: object, *, field_name: str) -> int: + if isinstance(value, bool) or not isinstance(value, int | str): + raise ValueError(f"Malformed integer value for retriever field: {field_name}") + return int(value) + + +def _coerce_float(value: object, *, field_name: str) -> float: + if isinstance(value, bool) or not isinstance(value, int | float | str): + raise ValueError(f"Malformed float value for retriever field: {field_name}") + return float(value) diff --git a/src/industrial_classification_utils/sayt/suggester.py b/src/industrial_classification_utils/sayt/suggester.py new file mode 100644 index 0000000..80a424f --- /dev/null +++ b/src/industrial_classification_utils/sayt/suggester.py @@ -0,0 +1,501 @@ +"""Search-as-you-type (SAYT) orchestration. + +This module provides the public suggester API that coordinates configured +retrievers and combines their scores into ranked suggestions. +""" + +import math +import os +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass, fields, is_dataclass +from pathlib import Path +from typing import Any + +from survey_assist_utils.logging import get_logger + +from ._base import BaseCorpusBound +from .core import ( + CleanCorpus, + SaytArtifactProvenance, + SaytConfiguration, + SaytCorpusSummary, + SaytGlobalSettings, + SaytRetrieverArtifactProvenance, + SaytRetrieverSummary, + Suggestion, + _normalise, + take_with_ties, +) +from .retriever_specs import ( + Retriever, + RetrieverSpec, +) +from .storage import ( + SAYT_ARTIFACT_TYPE, + SAYT_ARTIFACT_VERSION, + StoredRetrieverSpec, + load_retriever_from_artifact, + read_artifact_corpus, + read_artifact_manifest, +) + +logger = get_logger(__name__) + + +@dataclass(frozen=True, slots=True) +class _ConfiguredRetriever: + """Runtime retriever binding with its configured contribution weight.""" + + name: str + weight: float + retriever: Retriever + + +class SAYTSuggester(BaseCorpusBound): # pylint: disable=too-many-instance-attributes + """Suggest free-text responses as a user types. + + The suggester: + - validates and cleans the supplied corpus + - builds the configured retrievers for that corpus + - combines retriever-local scores into a shared weighted ranking + + By default it uses the standard prefix, n-gram, and semantic retriever + specifications. Use ``retrievers=`` to override that mix. + + Suggester-wide settings are configured directly on the suggester. At + present these include: + - ``min_chars``: minimum query length before retrieval runs + - ``max_suggestions``: default maximum number of ranked suggestions to return + + Examples: + Basic usage with an in-memory corpus: + + ```python + from industrial_classification_utils.sayt import SAYTSuggester + + suggester = SAYTSuggester( + corpus=[ + ("Car wash", "Car Wash"), + ("Dog grooming", "Dog grooming"), + ], + min_chars=3, + max_suggestions=5, + ) + + results = suggester.suggest("car") + ``` + + Usage with custom retriever specifications: + + ```python + from industrial_classification_utils.sayt import ( + PrefixRetrieverSpec, + SAYTSuggester, + ) + + suggester = SAYTSuggester( + corpus=[("Car wash", "Car Wash")], + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + ) + ``` + """ + + def __init__( + self, + corpus: Iterable[tuple[object, object]] | Iterable[str], + *, + retrievers: Sequence[RetrieverSpec] | None = None, + min_chars: int = 4, + max_suggestions: int = 10, + ) -> None: + """Initialise a suggester for a cleaned response corpus. + + Args: + corpus: Iterable of search strings or ``(search_text, display_text)`` + pairs. + retrievers: Optional retriever specifications. When omitted, the + standard prefix, n-gram, and semantic spec set is used. + min_chars: Minimum query length before retrieval runs. + max_suggestions: Default maximum number of ranked suggestions to + return. + """ + super().__init__( + corpus, + retrievers=retrievers, + min_chars=min_chars, + max_suggestions=max_suggestions, + ) + self._retrievers = self._build_retrievers(self._retriever_specs) + self._max_duplication = max(self._corpus.display_text_count.values(), default=0) + self._stored_retrievers: tuple[StoredRetrieverSpec, ...] | None = None + self._artifact_provenance: SaytArtifactProvenance | None = None + logger.info( + "SAYT suggester initialized", + config=self.get_config().model_dump(mode="json"), + ) + + @classmethod + def _from_state( # pylint: disable=too-many-arguments # noqa: PLR0913 + cls, + *, + corpus: CleanCorpus, + min_chars: int, + max_suggestions: int, + retriever_specs: Sequence[RetrieverSpec], + retrievers: list[_ConfiguredRetriever], + stored_retrievers: Sequence[StoredRetrieverSpec] | None = None, + artifact_provenance: SaytArtifactProvenance | None = None, + ) -> "SAYTSuggester": + """Construct a suggester from already-validated runtime state.""" + suggester = cls.__new__(cls) + suggester._corpus = corpus + suggester._min_chars = min_chars + suggester._max_suggestions = max_suggestions + suggester._max_duplication = max(corpus.display_text_count.values(), default=0) + suggester._retriever_specs = tuple(retriever_specs) + suggester._retrievers = retrievers + suggester._stored_retrievers = ( + tuple(stored_retrievers) if stored_retrievers is not None else None + ) + suggester._artifact_provenance = artifact_provenance + logger.info( + "SAYT suggester initialized", + config=suggester.get_config().model_dump(mode="json"), + ) + return suggester + + @classmethod + def from_artifact(cls, artifact_dir: str | os.PathLike) -> "SAYTSuggester": + """Load a suggester from a persisted SAYT artifact directory.""" + artifact_path = Path(artifact_dir) + manifest = read_artifact_manifest(artifact_dir=artifact_path) + persisted_rows = read_artifact_corpus( + artifact_dir=artifact_path, + corpus_file=manifest.corpus_file, + ) + corpus = CleanCorpus.from_persisted_rows(persisted_rows) + if corpus.size != manifest.corpus_size: + raise ValueError("Artifact corpus size does not match manifest") + + retrievers = _load_retrievers_from_artifact( + corpus=corpus, + min_chars=manifest.min_chars, + stored_retrievers=manifest.retrievers, + artifact_dir=artifact_path, + ) + artifact_provenance = SaytArtifactProvenance( + artifact_dir=str(artifact_path), + artifact_type=SAYT_ARTIFACT_TYPE, + artifact_version=SAYT_ARTIFACT_VERSION, + corpus_file=manifest.corpus_file, + corpus_size=manifest.corpus_size, + ) + return cls._from_state( + corpus=corpus, + min_chars=manifest.min_chars, + max_suggestions=manifest.max_suggestions, + retriever_specs=[ + stored_retriever.spec for stored_retriever in manifest.retrievers + ], + retrievers=retrievers, + stored_retrievers=manifest.retrievers, + artifact_provenance=artifact_provenance, + ) + + def _build_retrievers( + self, retriever_specs: Sequence[RetrieverSpec] + ) -> list[_ConfiguredRetriever]: + return [ + _ConfiguredRetriever( + name=spec.name, + weight=weight, + retriever=spec.build( + self._corpus, + min_chars=self._min_chars, + ), + ) + for spec, weight in _normalised_retriever_specs(retriever_specs) + ] + + def _dedup_suggestions( + self, suggestions: list[Suggestion] + ) -> list[tuple[str, float]]: + # sort by score and deduplicate by display text, keeping the highest-scoring variant. + sorted_suggestions = sorted( + suggestions, + key=lambda s: ( + -s.score, + -self._corpus.display_text_count.get(s.display_text, 0), + s.display_text.lower(), + s.row_id, + ), + ) + seen: set[str] = set() + deduped: list[tuple[str, float]] = [] + for s in sorted_suggestions: + display_text = s.display_text # lower? normalised? + if display_text not in seen: + deduped.append((display_text, s.score)) + seen.add(display_text) + return deduped + + def _combine_suggestions( + self, + result_groups: Iterable[tuple[float, list[Suggestion]]], + ) -> list[tuple[str, float]]: + def normalise_scores( + items: list[Suggestion], weight: float + ) -> dict[str, float]: + if not items: + return {} + max_score = max((float(s.score) for s in items), default=0.0) + if max_score <= 0: + return {} + out: dict[str, float] = {} + for s in items: + if not s.row_id: + continue + out[s.row_id] = max( + out.get(s.row_id, 0.0), float(s.score) / max_score * weight + ) + return out + + combined_scores: dict[str, float] = {} + for weight, suggestions in result_groups: + d = normalise_scores(suggestions, weight) + for k, v in d.items(): + combined_scores[k] = combined_scores.get(k, 0.0) + v + + return [(row_id, float(score)) for row_id, score in combined_scores.items()] + + def _collect_retriever_results( + self, q_norm: str, num_suggestions: int + ) -> list[tuple[float, list[Suggestion]]]: + return [ + ( + configured_retriever.weight, + configured_retriever.retriever.suggest_with_scores( + q_norm, + num_suggestions=num_suggestions, + ), + ) + for configured_retriever in self._retrievers + ] + + def suggest_with_scores( + self, query: str | None, num_suggestions: int | None = None + ) -> list[Suggestion]: + """Return ranked suggestions and their combined scores. + + Args: + query: Raw user query text. + num_suggestions: Optional maximum number of ranked suggestions to + return. When omitted, the configured default is used. + + Returns: + A list of combined suggestions ordered by descending score. Returns + an empty list when the normalised query is shorter than + ``min_chars``. + """ + if num_suggestions is None: + num_suggestions = self._max_suggestions + q_norm = _normalise(query) + if len(q_norm) < self._min_chars: + return [] + + # Ask for more suggestions, as some may be filtered out after deduplication + results_by_kind = self._collect_retriever_results( + q_norm, + num_suggestions=10 * num_suggestions, + ) + + combined_result = self._combine_suggestions(results_by_kind) + ranked_results = take_with_ties(combined_result, num_suggestions) + out = [ + Suggestion( + row_id=row_id, + display_text=self._corpus.id_to_display.get(row_id, ""), + score=score, + search_text=self._corpus.id_to_search.get(row_id, ""), + ) + for row_id, score in ranked_results + ] + + return out + + def suggest( + self, query: str | None, num_suggestions: int | None = None + ) -> list[str]: + """Return deduplicated display-text suggestions. + + Args: + query: Raw user query text. + num_suggestions: Optional maximum number of display values to + return. When omitted, the configured default is used. + + Returns: + A list of display-text suggestions ordered by descending combined + score, while preserving ties at the cutoff. + """ + if num_suggestions is None: + num_suggestions = self._max_suggestions + results = self.suggest_with_scores( + query, num_suggestions=num_suggestions * self._max_duplication + ) + dedup_results = self._dedup_suggestions(results) + ranked_results = take_with_ties(dedup_results, num_suggestions) + return [result[0] for result in ranked_results] + + def get_config(self) -> SaytConfiguration: + """Return a rich runtime summary of this suggester. + + Returns: + A summary of global settings, corpus details, retriever + configuration, and any artifact provenance available for this + suggester. + """ + stored_retrievers: Sequence[StoredRetrieverSpec | None] + if self._stored_retrievers is None: + stored_retrievers = [None] * len(self._retriever_specs) + else: + stored_retrievers = list(self._stored_retrievers) + + retrievers = [ + _build_retriever_summary( + spec=spec, + configured_retriever=configured_retriever, + stored_retriever=stored_retriever, + ) + for spec, configured_retriever, stored_retriever in zip( + self._retriever_specs, + self._retrievers, + stored_retrievers, + strict=True, + ) + ] + + return SaytConfiguration( + settings=SaytGlobalSettings( + min_chars=self._min_chars, + max_suggestions=self._max_suggestions, + ), + corpus=SaytCorpusSummary( + size=self._corpus.size, + unique_display_texts=len(self._corpus.display_text_count), + max_duplication=self._max_duplication, + ), + retrievers=retrievers, + artifact_provenance=( + self._artifact_provenance.model_copy(deep=True) + if self._artifact_provenance is not None + else None + ), + ) + + +def _normalised_retriever_specs( + retriever_specs: Sequence[RetrieverSpec], +) -> list[tuple[RetrieverSpec, float]]: + """Validate and normalise configured retriever weights.""" + if not retriever_specs: + raise ValueError("At least one retriever must be configured") + + validated_specs: list[tuple[RetrieverSpec, float]] = [] + for spec in retriever_specs: + weight = float(spec.weight) + if not math.isfinite(weight) or weight <= 0: + raise ValueError( + f"Retriever '{spec.name}' weight must be a finite value > 0" + ) + validated_specs.append((spec, weight)) + + total_weight = sum(weight for _, weight in validated_specs) + return [(spec, weight / total_weight) for spec, weight in validated_specs] + + +def _load_retrievers_from_artifact( + *, + corpus: CleanCorpus, + min_chars: int, + stored_retrievers: Sequence[StoredRetrieverSpec], + artifact_dir: Path, +) -> list[_ConfiguredRetriever]: + """Restore runtime retrievers from a persisted SAYT artifact.""" + normalised_specs = _normalised_retriever_specs( + [stored_retriever.spec for stored_retriever in stored_retrievers] + ) + return [ + _ConfiguredRetriever( + name=stored_retriever.spec.name, + weight=weight, + retriever=load_retriever_from_artifact( + corpus=corpus, + min_chars=min_chars, + stored_retriever=stored_retriever, + artifact_dir=artifact_dir, + ), + ) + for (_, weight), stored_retriever in zip( + normalised_specs, + stored_retrievers, + strict=True, + ) + ] + + +def _jsonable_value(value: Any) -> Any: + if isinstance(value, str | int | float | bool) or value is None: + return value + if isinstance(value, os.PathLike): + return os.fspath(value) + if isinstance(value, Mapping): + return {str(key): _jsonable_value(item) for key, item in value.items()} + if isinstance(value, list | tuple): + return [_jsonable_value(item) for item in value] + return str(value) + + +def _summarise_retriever_config(spec: RetrieverSpec) -> dict[str, Any]: + if is_dataclass(spec): + items = ( + (field.name, getattr(spec, field.name)) + for field in fields(spec) + if field.name not in {"name", "weight"} + ) + return {key: _jsonable_value(value) for key, value in items} + + raw_config = getattr(spec, "__dict__", None) + if isinstance(raw_config, dict): + return { + str(key): _jsonable_value(value) + for key, value in raw_config.items() + if key not in {"name", "weight"} + } + return {} + + +def _build_retriever_summary( + *, + spec: RetrieverSpec, + configured_retriever: _ConfiguredRetriever, + stored_retriever: StoredRetrieverSpec | None, +) -> SaytRetrieverSummary: + artifact_provenance = None + config = _summarise_retriever_config(spec) + if stored_retriever is not None: + artifact_provenance = SaytRetrieverArtifactProvenance( + artifact_type=stored_retriever.spec.name, + path=stored_retriever.path, + config=_summarise_retriever_config(stored_retriever.spec), + ) + + return SaytRetrieverSummary( + name=spec.name, + spec_type=type(spec).__name__, + retriever_type=type(configured_retriever.retriever).__name__, + configured_weight=float(spec.weight), + normalised_weight=configured_retriever.weight, + config=config, + artifact_provenance=artifact_provenance, + ) diff --git a/tests/sayt/test_sayt.py b/tests/sayt/test_sayt.py index bf473e9..4eb6f73 100644 --- a/tests/sayt/test_sayt.py +++ b/tests/sayt/test_sayt.py @@ -2,24 +2,33 @@ # pylint: disable=protected-access,redefined-outer-name,too-few-public-methods,C0116,W0613 +import json +from collections import Counter from dataclasses import dataclass +from pathlib import Path from uuid import UUID import pandas as pd import pytest -from pydantic import ValidationError from industrial_classification_utils.sayt import ( + NgramRetrieverSpec, PrefixRetrieverSpec, + SAYTBuilder, + SaytConfiguration, ) -from industrial_classification_utils.sayt.sayt import SAYTSuggester -from industrial_classification_utils.sayt.sayt_core import CleanCorpus, Suggestion +from industrial_classification_utils.sayt.core import ( + CleanCorpus, + PersistedCorpusRow, + Suggestion, +) +from industrial_classification_utils.sayt.suggester import SAYTSuggester def test_constructor_rejects_unknown_kwargs(small_corpus): """Reject unknown constructor kwargs during config validation.""" - with pytest.raises(ValidationError, match="Extra inputs are not permitted"): - SAYTSuggester(small_corpus, does_not_exist=True) + with pytest.raises(TypeError, match="unexpected keyword argument"): + SAYTSuggester(small_corpus, does_not_exist=True) # pylint: disable=E1123 def test_empty_corpus_after_filtering_raises(): @@ -50,6 +59,30 @@ def test_clean_corpus_accepts_existing_instance_and_dict_input(small_corpus): assert dict_corpus.rows == corpus.rows +def test_clean_corpus_model_dump_excludes_derived_lookup_dicts(small_corpus): + """Keep derived lookup dictionaries out of the public model fields.""" + corpus = CleanCorpus.model_validate(small_corpus) + + dumped = corpus.model_dump() + + assert "id_to_search" not in dumped + assert "id_to_display" not in dumped + assert "display_text_count" not in dumped + assert dumped["rows"] == corpus.rows + + +def test_clean_corpus_restores_persisted_rows(small_corpus): + """Restore cleaned corpus rows without regenerating row identifiers.""" + corpus = CleanCorpus.model_validate(small_corpus) + + restored = CleanCorpus.from_persisted_rows( + [PersistedCorpusRow(*row) for row in corpus.rows] + ) + + assert restored.rows == corpus.rows + assert restored.model_dump() == corpus.model_dump() + + def test_clean_corpus_rejects_non_iterable_input(): """Reject scalar corpus values before attempting to clean them.""" with pytest.raises(TypeError, match="corpus must be an iterable"): @@ -132,6 +165,195 @@ def test_from_csv_rejects_missing_display_column(tmp_path, small_corpus): ) +def test_from_artifact_restores_prefix_suggester(tmp_path, small_corpus): + """Round-trip a prefix-only artifact into a working suggester.""" + artifact_dir = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ).build_artifact(tmp_path / "artifact") + + restored = SAYTSuggester.from_artifact(artifact_dir) + expected = SAYTSuggester( + small_corpus, + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ) + restored_config = restored.get_config() + expected_config = expected.get_config() + + assert restored.suggest("car") == expected.suggest("car") + assert restored_config.settings == expected_config.settings + assert restored_config.corpus == expected_config.corpus + assert [ + retriever.model_dump(exclude={"artifact_provenance"}) + for retriever in restored_config.retrievers + ] == [ + retriever.model_dump(exclude={"artifact_provenance"}) + for retriever in expected_config.retrievers + ] + assert restored_config.artifact_provenance is not None + assert restored_config.artifact_provenance.artifact_dir == str(artifact_dir) + assert restored_config.retrievers[0].artifact_provenance is not None + assert expected_config.artifact_provenance is None + + +def test_from_artifact_rejects_manifest_corpus_size_mismatch(tmp_path, small_corpus): + """Reject artifacts whose manifest corpus size disagrees with stored rows.""" + artifact_dir = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ).build_artifact(tmp_path / "artifact") + manifest_path = artifact_dir / "manifest.json" + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + manifest["corpus_size"] += 1 + manifest_path.write_text(json.dumps(manifest), encoding="utf-8") + + with pytest.raises( + ValueError, match="Artifact corpus size does not match manifest" + ): + SAYTSuggester.from_artifact(artifact_dir) + + +def test_get_config_returns_rich_runtime_summary(small_corpus): + """Expose runtime settings, corpus stats, and retriever summaries.""" + suggester = SAYTSuggester( + small_corpus, + min_chars=3, + max_suggestions=5, + retrievers=[ + PrefixRetrieverSpec(weight=2.0), + NgramRetrieverSpec(weight=1.0, n=4, max_df=1.0), + ], + ) + + config = suggester.get_config() + display_counts = Counter(display for _, _, display in suggester._corpus.rows) + + assert isinstance(config, SaytConfiguration) + assert config.settings.model_dump() == { + "min_chars": 3, + "max_suggestions": 5, + } + assert config.corpus.model_dump() == { + "size": suggester._corpus.size, + "unique_display_texts": len(display_counts), + "max_duplication": max(display_counts.values(), default=0), + } + assert [retriever.name for retriever in config.retrievers] == ["prefix", "ngram"] + assert config.retrievers[0].config == {} + assert config.retrievers[1].config == {"n": 4, "max_df": 1.0} + assert config.retrievers[0].configured_weight == pytest.approx(2.0) + assert config.retrievers[0].normalised_weight == pytest.approx(2.0 / 3.0) + assert config.retrievers[1].normalised_weight == pytest.approx(1.0 / 3.0) + assert config.retrievers[1].retriever_type == "NgramRetriever" + assert config.artifact_provenance is None + + +def test_get_config_supports_custom_specs_without_artifact_handlers(small_corpus): + """Summarise custom runtime-only specs without requiring persistence hooks.""" + + class _StubRetriever: + def suggest_with_scores(self, q_norm, num_suggestions): + _ = (q_norm, num_suggestions) + return [] + + class _CustomSpec: + def __init__(self, *, trigger: str, weight: float = 1.0): + self.trigger = trigger + self.weight = weight + self.name = "custom" + + def build(self, corpus, *, min_chars): + _ = (corpus, min_chars) + return _StubRetriever() + + config = SAYTSuggester( + small_corpus, + min_chars=3, + retrievers=[_CustomSpec(trigger="groom")], + ).get_config() + + assert config.retrievers[0].config == {"trigger": "groom"} + assert config.retrievers[0].artifact_provenance is None + + +def test_get_config_serialises_nested_custom_spec_values(small_corpus): + """Convert non-JSON-native custom spec config values into JSON-safe forms.""" + + class _StubRetriever: + def suggest_with_scores(self, q_norm, num_suggestions): + _ = (q_norm, num_suggestions) + return [] + + class _Marker: + def __str__(self): + return "marker-object" + + class _CustomSpec: + def __init__(self): + self.name = "custom" + self.weight = 1.0 + self.folder = Path("artifacts/model") + self.options = { + "labels": ["car", Path("cache/index"), _Marker()], + "metadata": {"marker": _Marker()}, + } + self.values = (Path("weights.bin"), _Marker()) + + def build(self, corpus, *, min_chars): + _ = (corpus, min_chars) + return _StubRetriever() + + config = SAYTSuggester( + small_corpus, + min_chars=3, + retrievers=[_CustomSpec()], + ).get_config() + + assert config.retrievers[0].config == { + "folder": "artifacts/model", + "options": { + "labels": ["car", "cache/index", "marker-object"], + "metadata": {"marker": "marker-object"}, + }, + "values": ["weights.bin", "marker-object"], + } + + +def test_get_config_returns_empty_config_for_slots_only_custom_spec(small_corpus): + """Return an empty config summary when a custom spec exposes no __dict__.""" + + class _StubRetriever: + def suggest_with_scores(self, q_norm, num_suggestions): + _ = (q_norm, num_suggestions) + return [] + + class _SlotsOnlySpec: + __slots__ = ("name", "trigger", "weight") + + def __init__(self, *, trigger: str, weight: float = 1.0): + self.name = "slots-only" + self.weight = weight + self.trigger = trigger + + def build(self, corpus, *, min_chars): + _ = (corpus, min_chars) + return _StubRetriever() + + config = SAYTSuggester( + small_corpus, + min_chars=3, + retrievers=[_SlotsOnlySpec(trigger="groom")], + ).get_config() + + assert config.retrievers[0].config == {} + + def test_suggest_returns_empty_for_short_or_non_string_query(small_corpus): """Return no suggestions for short or non-string queries.""" s = SAYTSuggester(small_corpus, min_chars=4, retrievers=[PrefixRetrieverSpec()]) @@ -349,7 +571,7 @@ def build(self, corpus, *, min_chars): return _StubRetriever() monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt.default_retriever_specs", + "industrial_classification_utils.sayt._base.default_retriever_specs", lambda: [ _StubRetrieverSpec(name="prefix"), _StubRetrieverSpec(name="ngram"), @@ -426,3 +648,16 @@ def build(self, corpus, *, min_chars): ) assert not build_calls + + +def test_clean_corpus_rejects_empty_persisted_rows(): + """Reject empty persisted row collections during artifact restore.""" + with pytest.raises(ValueError, match="corpus is empty after filtering"): + CleanCorpus.from_persisted_rows([]) + + +def test_clean_corpus_coerces_persisted_tuple_values_to_strings(): + """Coerce tuple-based persisted rows to strings before rebuilding indexes.""" + restored = CleanCorpus.from_persisted_rows([(123, 456, 789)]) + + assert restored.rows == [("123", "456", "789")] diff --git a/tests/sayt/test_sayt_builder.py b/tests/sayt/test_sayt_builder.py new file mode 100644 index 0000000..09ed854 --- /dev/null +++ b/tests/sayt/test_sayt_builder.py @@ -0,0 +1,307 @@ +"""Tests for SAYT artifact building and loading.""" + +# pylint: disable=too-few-public-methods,missing-function-docstring,too-many-arguments,duplicate-code + +import csv +import json +from dataclasses import dataclass +from pathlib import Path + +import pandas as pd +import pytest + +from industrial_classification_utils.sayt import ( + NgramRetrieverSpec, + PrefixRetrieverSpec, + SAYTBuilder, +) +from industrial_classification_utils.sayt.core import CleanCorpus +from industrial_classification_utils.sayt.suggester import SAYTSuggester + + +class _CustomRetrieverSpec: + def __init__(self, *, trigger: str, weight: float = 1.0): + self.trigger = trigger + self.weight = weight + self.name = "custom-trigger" + + def build(self, corpus, *, min_chars): + _ = (corpus, min_chars) + + +@dataclass(frozen=True) +class _FailingArtifactRetrieverSpec: + name: str = "failing" + weight: float = 1.0 + + def build( + self, + corpus, + *, + min_chars, + filespace_path=None, + overwrite=True, + ): + _ = (corpus, min_chars, overwrite) + if filespace_path is not None: + Path(filespace_path).mkdir(parents=True, exist_ok=True) + Path(filespace_path, "partial.txt").write_text("partial", encoding="utf-8") + raise RuntimeError("boom") + + def load_from_artifact(self, corpus, *, min_chars, filespace_path): + _ = (corpus, min_chars, filespace_path) + raise NotImplementedError + + +def test_builder_from_csv_loads_columns_and_persists_artifact(tmp_path, small_corpus): + """Build an artifact from CSV input using the configured search/display columns.""" + csv_path = tmp_path / "responses.csv" + pd.DataFrame( + { + "search": [row[0] for row in small_corpus], + "display": [row[1] for row in small_corpus], + } + ).to_csv(csv_path, index=False) + + artifact_dir = SAYTBuilder.from_csv( + csv_path, + search_text_col="search", + display_text_col="display", + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ).build_artifact(tmp_path / "artifact") + + manifest = json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) + + assert manifest["min_chars"] == 3 + assert manifest["max_suggestions"] == 5 + assert manifest["corpus_size"] == len(small_corpus) + + +def test_builder_writes_manifest_and_corpus(tmp_path, small_corpus): + """Persist manifest metadata and cleaned corpus rows for an artifact.""" + artifact_dir = tmp_path / "artifact" + + result = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ).build_artifact(artifact_dir) + + manifest = json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) + with open(artifact_dir / "corpus.csv", encoding="utf-8") as corpus_file: + rows = list(csv.DictReader(corpus_file)) + + assert result == artifact_dir + assert manifest == { + "artifact_type": "sayt", + "artifact_version": 2, + "min_chars": 3, + "max_suggestions": 5, + "corpus_file": "corpus.csv", + "corpus_size": len(small_corpus), + "retrievers": [ + {"type": "prefix", "weight": 1.0, "path": None, "config": {}}, + ], + } + assert rows == [ + { + "row_id": row_id, + "search_text": search_text, + "display_text": display_text, + } + for row_id, search_text, display_text in CleanCorpus.model_validate( + small_corpus + ).rows + ] + + +def test_builder_writes_ngram_filespace(monkeypatch, tmp_path, small_corpus): + """Persist the configured dense retriever filespace inside the artifact.""" + captured = {} + artifact_dir = tmp_path / "artifact" + + class _StubPersistentVectorStore: + def __init__( # noqa: PLR0913 + self, + *, + file_name, + data_type, + vectoriser, + batch_size, + output_dir, + overwrite, + hooks, + ): + captured["file_name"] = file_name + captured["data_type"] = data_type + captured["vectoriser_type"] = type(vectoriser).__name__ + captured["batch_size"] = batch_size + captured["output_dir"] = output_dir + captured["overwrite"] = overwrite + captured["hooks"] = hooks + Path(output_dir).mkdir(parents=True, exist_ok=True) + Path(output_dir, "metadata.json").write_text("{}", encoding="utf-8") + Path(output_dir, "vectors.parquet").write_text("dummy", encoding="utf-8") + self.num_vectors = 1 + + monkeypatch.setattr( + "industrial_classification_utils.sayt.indexes.VectorStore", + _StubPersistentVectorStore, + ) + + SAYTBuilder( + small_corpus, + retrievers=[NgramRetrieverSpec(max_df=1.0)], + min_chars=3, + ).build_artifact(artifact_dir) + + manifest = json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) + filespace_path = artifact_dir / manifest["retrievers"][0]["path"] + + assert Path(captured["output_dir"]).name == filespace_path.name + assert (filespace_path / "metadata.json").exists() + assert (filespace_path / "vectors.parquet").exists() + + +def test_from_artifact_loads_persisted_ngram_filespace( + monkeypatch, tmp_path, small_corpus +): + """Load persisted dense retrievers from their artifact filespaces.""" + captured = {} + artifact_dir = tmp_path / "artifact" + target_row_id, _, target_display = CleanCorpus.model_validate(small_corpus).rows[-1] + + class _StubPersistentVectorStore: + def __init__( # noqa: PLR0913 + self, + *, + file_name, + data_type, + vectoriser, + batch_size, + output_dir, + overwrite, + hooks, + ): + _ = (file_name, data_type, vectoriser, batch_size, overwrite, hooks) + Path(output_dir).mkdir(parents=True, exist_ok=True) + Path(output_dir, "metadata.json").write_text("{}", encoding="utf-8") + Path(output_dir, "vectors.parquet").write_text("dummy", encoding="utf-8") + self.num_vectors = 1 + + @classmethod + def from_filespace(cls, *, folder_path, vectoriser, hooks): + captured["folder_path"] = folder_path + captured["vectoriser_type"] = type(vectoriser).__name__ + captured["hooks"] = hooks + return _StubLoadedVectorStore() + + class _StubSearchResults: + def to_dict(self, orient="records"): + assert orient == "records" + return [{"doc_label": target_row_id, "score": 1.0}] + + class _StubLoadedVectorStore: + num_vectors = 1 + + def search(self, query, n_results=10): + _ = query + captured["n_results"] = n_results + return _StubSearchResults() + + monkeypatch.setattr( + "industrial_classification_utils.sayt.indexes.VectorStore", + _StubPersistentVectorStore, + ) + + builder = SAYTBuilder( + small_corpus, + retrievers=[NgramRetrieverSpec(max_df=1.0)], + min_chars=3, + ) + builder.build_artifact(artifact_dir) + + suggester = SAYTSuggester.from_artifact(artifact_dir) + manifest = json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) + + assert suggester.suggest("groom") == [target_display] + assert captured == { + "folder_path": str(artifact_dir / manifest["retrievers"][0]["path"]), + "vectoriser_type": "_CharNgramVectoriser", + "hooks": None, + "n_results": 1, + } + + +def test_builder_rejects_custom_runtime_only_retriever_specs(tmp_path, small_corpus): + """Persisted artifacts currently support only the built-in retriever specs.""" + builder = SAYTBuilder( + small_corpus, + retrievers=[_CustomRetrieverSpec(trigger="groom", weight=1.5)], + min_chars=3, + max_suggestions=4, + ) + + with pytest.raises( + TypeError, + match="Only artifact-aware retriever specs can be persisted; got _CustomRetrieverSpec", + ): + builder.build_artifact(tmp_path / "artifact") + + +def test_builder_cleans_up_staged_artifact_when_later_retriever_fails( + tmp_path, small_corpus +): + """A failed staged build should not leave a partial artifact behind.""" + artifact_dir = tmp_path / "artifact" + + builder = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec(), _FailingArtifactRetrieverSpec()], + min_chars=3, + ) + + with pytest.raises(RuntimeError, match="boom"): + builder.build_artifact(artifact_dir) + + assert not artifact_dir.exists() + assert not list(tmp_path.glob(".artifact.tmp-*")) + + +def test_builder_preserves_existing_artifact_when_staged_overwrite_fails( + tmp_path, small_corpus +): + """A failed overwrite should leave the previous complete artifact intact.""" + artifact_dir = tmp_path / "artifact" + original_builder = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ) + original_builder.build_artifact(artifact_dir) + + original_manifest = json.loads( + (artifact_dir / "manifest.json").read_text(encoding="utf-8") + ) + original_corpus = (artifact_dir / "corpus.csv").read_text(encoding="utf-8") + + failing_builder = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec(), _FailingArtifactRetrieverSpec()], + min_chars=3, + max_suggestions=7, + ) + + with pytest.raises(RuntimeError, match="boom"): + failing_builder.build_artifact(artifact_dir, overwrite=True) + + assert json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) == ( + original_manifest + ) + assert (artifact_dir / "corpus.csv").read_text(encoding="utf-8") == original_corpus + assert not list(tmp_path.glob(".artifact.tmp-*")) + assert not list(tmp_path.glob(".artifact.bak-*")) diff --git a/tests/sayt/test_sayt_config.py b/tests/sayt/test_sayt_config.py index 229e7c7..2f2a22b 100644 --- a/tests/sayt/test_sayt_config.py +++ b/tests/sayt/test_sayt_config.py @@ -3,31 +3,77 @@ # pylint: disable=too-few-public-methods, R0801 import pytest -from pydantic import ValidationError from industrial_classification_utils.sayt import ( NgramRetrieverSpec, PrefixRetrieverSpec, + SAYTBuilder, + SAYTSuggester, SemanticRetrieverSpec, default_retriever_specs, ) -from industrial_classification_utils.sayt.sayt_core import CleanCorpus, SaytConfig +from industrial_classification_utils.sayt.core import CleanCorpus @pytest.mark.parametrize( - "kwargs, exc_type", + "factory, kwargs, exc_type, match", [ - ({"min_chars": 2}, ValidationError), - ({"min_chars": True}, ValidationError), - ({"max_suggestions": 0}, ValidationError), - ({"max_suggestions": 101}, ValidationError), - ({"ngram_enable": True}, ValidationError), + (SAYTSuggester, {"min_chars": 2}, ValueError, "min_chars must be >= 3"), + ( + SAYTSuggester, + {"min_chars": True}, + TypeError, + "min_chars must be an integer", + ), + ( + SAYTSuggester, + {"max_suggestions": 0}, + ValueError, + "max_suggestions must be between 1 and 100", + ), + ( + SAYTSuggester, + {"max_suggestions": 101}, + ValueError, + "max_suggestions must be between 1 and 100", + ), + ( + SAYTSuggester, + {"min_chars": "abc"}, + TypeError, + "min_chars must be an integer", + ), + (SAYTBuilder, {"min_chars": 2}, ValueError, "min_chars must be >= 3"), + ( + SAYTBuilder, + {"min_chars": True}, + TypeError, + "min_chars must be an integer", + ), + ( + SAYTBuilder, + {"max_suggestions": 0}, + ValueError, + "max_suggestions must be between 1 and 100", + ), + ( + SAYTBuilder, + {"max_suggestions": 101}, + ValueError, + "max_suggestions must be between 1 and 100", + ), + ( + SAYTBuilder, + {"max_suggestions": "abc"}, + TypeError, + "max_suggestions must be an integer", + ), ], ) -def test_config_validation(kwargs, exc_type): - """Reject unsupported SAYT config values and types.""" - with pytest.raises(exc_type): - SaytConfig.model_validate(kwargs) +def test_runtime_setting_validation(factory, kwargs, exc_type, match): + """Reject unsupported global SAYT settings on public entry points.""" + with pytest.raises(exc_type, match=match): + factory([("car wash", "Car Wash")], **kwargs) def test_default_retriever_specs_returns_standard_set(): @@ -86,7 +132,7 @@ def test_ngram_retriever_spec_validates_against_corpus_size(): def test_retriever_specs_keep_their_config(): - """Expose per-retriever settings on the spec object rather than SaytConfig.""" + """Expose per-retriever settings on the spec object.""" n = 4 max_df = 0.8 spec = NgramRetrieverSpec(weight=2.0, n=n, max_df=max_df) @@ -108,7 +154,7 @@ def __init__(self, corpus_arg, *, model, min_chars): captured["min_chars"] = min_chars monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_retriever_specs.SemanticRetriever", + "industrial_classification_utils.sayt.retriever_specs.SemanticRetriever", _StubSemanticRetriever, ) diff --git a/tests/sayt/test_sayt_retrievers.py b/tests/sayt/test_sayt_retrievers.py index 43d0642..e59e318 100644 --- a/tests/sayt/test_sayt_retrievers.py +++ b/tests/sayt/test_sayt_retrievers.py @@ -1,26 +1,30 @@ """Tests for SAYT retrieval and ranking behavior.""" -# ruff: noqa: PLR2004 # pylint: disable=protected-access,redefined-outer-name,too-few-public-methods,C0116,W0613 +import csv +import shutil +from pathlib import Path + import numpy as np import pytest from classifai.vectorisers import VectoriserBase from industrial_classification_utils.sayt import NgramRetrieverSpec, PrefixRetrieverSpec -from industrial_classification_utils.sayt.sayt import SAYTSuggester -from industrial_classification_utils.sayt.sayt_core import CleanCorpus -from industrial_classification_utils.sayt.sayt_indexes import ( +from industrial_classification_utils.sayt.core import CleanCorpus +from industrial_classification_utils.sayt.indexes import ( DenseVectorIndex, _CharNgramVectoriser, _L2NormalisingVectoriser, + load_semantic_index, ) -from industrial_classification_utils.sayt.sayt_retrievers import ( +from industrial_classification_utils.sayt.retrievers import ( NgramRetriever, PrefixRetriever, SemanticRetriever, _PrefixIndex, ) +from industrial_classification_utils.sayt.suggester import SAYTSuggester def test_prefix_full_string_match_ranks_expected_terms(small_corpus): @@ -233,6 +237,108 @@ def test_dense_retriever_keeps_ties_at_cutoff(small_corpus): ] +def test_dense_vector_index_builds_persistent_filespace( + monkeypatch, tmp_path, small_corpus +): + """Persist dense indexes even when the output folder is replaced first.""" + captured = {} + corpus = CleanCorpus.model_validate(small_corpus) + output_dir = tmp_path / "ngram" + output_dir.mkdir() + + class _StubPersistentVectorStore: + # pylint: disable=too-many-arguments + def __init__( # noqa: PLR0913 + self, + *, + file_name, + data_type, + vectoriser, + batch_size, + output_dir, + overwrite, + hooks, + ): + captured["file_name"] = file_name + captured["data_type"] = data_type + captured["vectoriser_type"] = type(vectoriser).__name__ + captured["batch_size"] = batch_size + captured["output_dir"] = output_dir + captured["overwrite"] = overwrite + captured["hooks"] = hooks + output_path = Path(output_dir) + if output_path.is_dir() and overwrite: + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + with open(file_name, encoding="utf-8") as input_file: + captured["rows"] = list(csv.DictReader(input_file)) + (output_path / "metadata.json").write_text("{}", encoding="utf-8") + (output_path / "vectors.parquet").write_text("dummy", encoding="utf-8") + self.num_vectors = len(captured["rows"]) + + monkeypatch.setattr( + "industrial_classification_utils.sayt.indexes.VectorStore", + _StubPersistentVectorStore, + ) + + index = DenseVectorIndex.from_corpus( + corpus=corpus, + vectoriser=_StubVectoriser(np.array([[1.0, 0.0]])), + output_dir=output_dir, + overwrite=True, + ) + + assert index._num_vectors == len(corpus.rows) + assert Path(captured["file_name"]).name == "corpus.csv" + assert Path(captured["file_name"]).parent != output_dir + assert captured["data_type"] == "csv" + assert captured["vectoriser_type"] == "_StubVectoriser" + assert captured["batch_size"] == 64 + assert captured["output_dir"] == str(output_dir) + assert captured["overwrite"] is True + assert captured["hooks"] is None + assert captured["rows"] == [ + {"label": row_id, "text": search_text} for row_id, search_text, _ in corpus.rows + ] + + +def test_dense_vector_index_loads_existing_filespace( + monkeypatch, tmp_path, small_corpus +): + """Load a persisted dense index via ClassifAI's filespace API.""" + captured = {} + corpus = CleanCorpus.model_validate(small_corpus) + folder_path = tmp_path / "existing-ngram" + + class _StubLoadedVectorStore: + num_vectors = 7 + + def _fake_from_filespace(*, folder_path, vectoriser, hooks): + captured["folder_path"] = folder_path + captured["vectoriser_type"] = type(vectoriser).__name__ + captured["hooks"] = hooks + return _StubLoadedVectorStore() + + monkeypatch.setattr( + "industrial_classification_utils.sayt.indexes.VectorStore.from_filespace", + _fake_from_filespace, + ) + + index = DenseVectorIndex.from_filespace( + corpus=corpus, + folder_path=folder_path, + vectoriser=_StubVectoriser(np.array([[1.0, 0.0]])), + ) + + assert index._num_vectors == 7 + assert index._corpus is corpus + assert captured == { + "folder_path": str(folder_path), + "vectoriser_type": "_StubVectoriser", + "hooks": None, + } + + def test_semantic_retriever_builds_index_with_wrapped_vectoriser( monkeypatch, small_corpus ): @@ -247,8 +353,16 @@ def __init__(self, model_name): def transform(self, texts): return np.array([[1.0, 0.0]]) - def _fake_build_dense_vector_index(*, corpus, vectoriser): + def _fake_build_dense_vector_index( + *, + corpus, + vectoriser, + output_dir=None, + overwrite=True, + ): captured["vectoriser_type"] = type(vectoriser).__name__ + captured["output_dir"] = output_dir + captured["overwrite"] = overwrite return DenseVectorIndex( _vector_store=_StubVectorStore([]), _num_vectors=1, @@ -256,11 +370,11 @@ def _fake_build_dense_vector_index(*, corpus, vectoriser): ) monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_indexes.HuggingFaceVectoriser", + "industrial_classification_utils.sayt.indexes.HuggingFaceVectoriser", _StubHFVectoriser, ) monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_indexes.DenseVectorIndex.from_corpus", + "industrial_classification_utils.sayt.indexes.DenseVectorIndex.from_corpus", _fake_build_dense_vector_index, ) @@ -269,10 +383,62 @@ def _fake_build_dense_vector_index(*, corpus, vectoriser): assert captured == { "model_name": "sentence-transformers/all-MiniLM-L6-v2", "vectoriser_type": "_L2NormalisingVectoriser", + "output_dir": None, + "overwrite": True, } assert retriever._min_chars == 3 +def test_load_semantic_index_loads_existing_filespace_with_wrapped_vectoriser( + monkeypatch, tmp_path, small_corpus +): + """Wrap the embedding vectoriser before loading a persisted semantic index.""" + captured = {} + corpus = CleanCorpus.model_validate(small_corpus) + folder_path = tmp_path / "existing-semantic" + + class _StubHFVectoriser: + def __init__(self, model_name): + captured["model_name"] = model_name + + def transform(self, texts): + _ = texts + return np.array([[1.0, 0.0]]) + + def _fake_load_dense_vector_index(*, corpus, folder_path, vectoriser): + captured["corpus"] = corpus + captured["folder_path"] = folder_path + captured["vectoriser_type"] = type(vectoriser).__name__ + return DenseVectorIndex( + _vector_store=_StubVectorStore([]), + _num_vectors=2, + _corpus=corpus, + ) + + monkeypatch.setattr( + "industrial_classification_utils.sayt.indexes.HuggingFaceVectoriser", + _StubHFVectoriser, + ) + monkeypatch.setattr( + "industrial_classification_utils.sayt.indexes.DenseVectorIndex.from_filespace", + _fake_load_dense_vector_index, + ) + + index = load_semantic_index( + corpus, + model="all-MiniLM-L6-v2", + folder_path=folder_path, + ) + + assert index._num_vectors == 2 + assert captured == { + "model_name": "sentence-transformers/all-MiniLM-L6-v2", + "corpus": corpus, + "folder_path": folder_path, + "vectoriser_type": "_L2NormalisingVectoriser", + } + + def test_semantic_retriever_returns_empty_for_short_queries(): """Stop before vectorisation when the semantic query is too short.""" retriever = SemanticRetriever.__new__(SemanticRetriever) diff --git a/tests/sayt/test_sayt_storage.py b/tests/sayt/test_sayt_storage.py new file mode 100644 index 0000000..091f5bf --- /dev/null +++ b/tests/sayt/test_sayt_storage.py @@ -0,0 +1,218 @@ +"""Tests for SAYT storage helper validation and artifact edge cases.""" + +# pylint: disable=protected-access,too-few-public-methods,missing-function-docstring + +import json + +import pytest + +from industrial_classification_utils.sayt import ( + PrefixRetrieverSpec, + SemanticRetrieverSpec, + retriever_specs, + storage, +) +from industrial_classification_utils.sayt.core import CleanCorpus + + +def test_prepare_artifact_dir_handles_existing_paths(tmp_path): + """Reject accidental reuse, then replace existing directories or files.""" + artifact_dir = tmp_path / "artifact" + artifact_dir.mkdir() + stale_file = artifact_dir / "stale.txt" + stale_file.write_text("stale", encoding="utf-8") + + with pytest.raises(FileExistsError, match="Artifact directory already exists"): + storage.prepare_artifact_dir(artifact_dir) + + result = storage.prepare_artifact_dir(artifact_dir, overwrite=True) + + assert result == artifact_dir + assert artifact_dir.is_dir() + assert not stale_file.exists() + + artifact_file = tmp_path / "artifact-file" + artifact_file.write_text("stale", encoding="utf-8") + + replaced = storage.prepare_artifact_dir(artifact_file, overwrite=True) + + assert replaced == artifact_file + assert artifact_file.is_dir() + + +def test_read_artifact_inputs_validate_missing_and_malformed_state(tmp_path): + """Raise clear errors for missing files and malformed manifest payloads.""" + with pytest.raises(FileNotFoundError, match="Artifact corpus file not found"): + storage.read_artifact_corpus(artifact_dir=tmp_path) + + with pytest.raises(FileNotFoundError, match="Artifact manifest not found"): + storage.read_artifact_manifest(artifact_dir=tmp_path) + + manifest_path = tmp_path / "manifest.json" + manifest_path.write_text( + json.dumps({"artifact_type": "other", "artifact_version": 2}), + encoding="utf-8", + ) + with pytest.raises(ValueError, match="Unsupported artifact type"): + storage.read_artifact_manifest(artifact_dir=tmp_path) + + manifest_path.write_text( + json.dumps({"artifact_type": "sayt", "artifact_version": 999}), + encoding="utf-8", + ) + with pytest.raises(ValueError, match="Unsupported artifact version"): + storage.read_artifact_manifest(artifact_dir=tmp_path) + + manifest_path.write_text( + json.dumps( + { + "artifact_type": "sayt", + "artifact_version": 2, + "min_chars": 3, + "corpus_file": "corpus.csv", + "corpus_size": 1, + "retrievers": [], + } + ), + encoding="utf-8", + ) + with pytest.raises( + ValueError, match="Malformed artifact manifest: missing max_suggestions" + ): + storage.read_artifact_manifest(artifact_dir=tmp_path) + + +def test_storage_helper_validation_errors(): + """Guard helper APIs against invalid types, paths, and unsupported specs.""" + stored_retriever = storage.StoredRetrieverSpec( + spec=PrefixRetrieverSpec(), + path=None, + ) + + with pytest.raises(ValueError, match="does not have a stored filespace"): + storage.retriever_filespace_path("artifact", stored_retriever) + + with pytest.raises(ValueError, match="Malformed retriever config for type: prefix"): + storage._deserialise_stored_retriever( + {"type": "prefix", "weight": 1.0, "config": []} + ) + + with pytest.raises(ValueError, match="Unsupported stored retriever type: missing"): + storage._deserialise_stored_retriever( + {"type": "missing", "weight": 1.0, "config": {}} + ) + + class _UnknownSpec: + name = "unknown" + weight = 1.0 + + def build(self, corpus, *, min_chars): + _ = (corpus, min_chars) + + with pytest.raises( + TypeError, + match="Only artifact-aware retriever specs can be persisted; got _UnknownSpec", + ): + storage._build_stored_retriever(0, _UnknownSpec()) + + with pytest.raises( + ValueError, match="Malformed integer value for retriever field: n" + ): + storage._coerce_int(True, field_name="n") + + with pytest.raises( + ValueError, match="Malformed float value for retriever field: weight" + ): + storage._coerce_float(True, field_name="weight") + + +def test_semantic_retriever_artifact_round_trips_and_loads( + monkeypatch, tmp_path, small_corpus +): + """Round-trip semantic artifact state and delegate dense index load/build calls.""" + captured = {} + corpus = CleanCorpus.model_validate(small_corpus) + spec = SemanticRetrieverSpec(model="all-MiniLM-L6-v2", weight=2.5) + stored_retriever = storage._build_stored_retriever(2, spec) + path = tmp_path / stored_retriever.path + + def _fake_build_semantic_index(corpus_arg, *, model, output_dir, overwrite): + captured["build"] = { + "corpus": corpus_arg, + "model": model, + "output_dir": output_dir, + "overwrite": overwrite, + } + + def _fake_load_semantic_index(corpus_arg, *, model, folder_path): + captured["load"] = { + "corpus": corpus_arg, + "model": model, + "folder_path": folder_path, + } + return "loaded-index" + + class _StubSemanticRetriever: + @classmethod + def from_index(cls, corpus_arg, *, min_chars, index): + captured["from_index"] = { + "corpus": corpus_arg, + "min_chars": min_chars, + "index": index, + } + return {"index": index, "min_chars": min_chars} + + monkeypatch.setattr( + retriever_specs, "build_semantic_index", _fake_build_semantic_index + ) + monkeypatch.setattr( + retriever_specs, "load_semantic_index", _fake_load_semantic_index + ) + monkeypatch.setattr(retriever_specs, "SemanticRetriever", _StubSemanticRetriever) + + rebuilt = storage._deserialise_stored_retriever( + { + "type": stored_retriever.spec.name, + "weight": spec.weight, + "path": stored_retriever.path, + "config": {"model": "all-MiniLM-L6-v2"}, + } + ) + + assert stored_retriever.spec.name == "semantic" + assert stored_retriever.path == "retrievers/02-semantic" + assert isinstance(rebuilt.spec, SemanticRetrieverSpec) + assert rebuilt.spec.weight == pytest.approx(2.5) + + storage.build_retriever_artifact( + corpus=corpus, + min_chars=3, + stored_retriever=stored_retriever, + artifact_dir=tmp_path, + ) + retriever = storage.load_retriever_from_artifact( + corpus=corpus, + min_chars=3, + stored_retriever=stored_retriever, + artifact_dir=tmp_path, + ) + + assert retriever == {"index": "loaded-index", "min_chars": 3} + assert captured == { + "build": { + "corpus": corpus, + "model": "all-MiniLM-L6-v2", + "output_dir": path, + "overwrite": True, + }, + "load": { + "corpus": corpus, + "model": "all-MiniLM-L6-v2", + "folder_path": path, + }, + "from_index": { + "corpus": corpus, + "min_chars": 3, + "index": "loaded-index", + }, + }