From d5ecf20be98027c01fb5d6b28cf540c9c8d7d7ac Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Thu, 28 May 2026 12:16:00 +0100 Subject: [PATCH 01/11] feat(sayt): Add SAYTBuilder that constructs runtime artifacts for later use --- demos/sayt/sayt_artifact_example.py | 94 +++ pyproject.toml | 2 + .../sayt/__init__.py | 10 + .../sayt/sayt.py | 137 ++++- .../sayt/sayt_builder.py | 80 +++ .../sayt/sayt_core.py | 64 +- .../sayt/sayt_indexes.py | 132 +++- .../sayt/sayt_retriever_specs.py | 53 +- .../sayt/sayt_retrievers.py | 15 + .../sayt/sayt_storage.py | 565 ++++++++++++++++++ tests/sayt/test_sayt.py | 42 +- tests/sayt/test_sayt_builder.py | 263 ++++++++ tests/sayt/test_sayt_retrievers.py | 118 +++- 13 files changed, 1541 insertions(+), 34 deletions(-) create mode 100644 demos/sayt/sayt_artifact_example.py create mode 100644 src/industrial_classification_utils/sayt/sayt_builder.py create mode 100644 src/industrial_classification_utils/sayt/sayt_storage.py create mode 100644 tests/sayt/test_sayt_builder.py diff --git a/demos/sayt/sayt_artifact_example.py b/demos/sayt/sayt_artifact_example.py new file mode 100644 index 0000000..bbf0b37 --- /dev/null +++ b/demos/sayt/sayt_artifact_example.py @@ -0,0 +1,94 @@ +"""Build and reload a persisted SAYT artifact with the built-in retrievers.""" + +# pylint: disable=duplicate-code + +# %% +import json +from pathlib import Path +from tempfile import TemporaryDirectory + +from industrial_classification_utils.sayt import ( + NgramRetrieverSpec, + PrefixRetrieverSpec, + SAYTBuilder, + SAYTSuggester, + SemanticRetrieverSpec, +) + +# %% +############# toy example to verify SAYT artifact build/load works ############# +small_corpus = [ + ("Car wash", "Car Wash"), + ("Car wash", "CAR WASH (duplicate)"), + ("Car waxing", "Car Waxing"), + ("Waxing car", "Car Waxing"), + ("Carpentry services", "Carpentry services"), + ("Dog grooming", "Dog grooming"), + ("Cat grooming", "Cat grooming"), + ("USed car sales", "Used car sales"), + ("Car rental", "Car rental"), + ("Car repair", "Car repair"), + ("Car servicing", "Car servicing"), +] + +retrievers = [ + PrefixRetrieverSpec(), + NgramRetrieverSpec(max_df=0.8), + SemanticRetrieverSpec(), +] + +# Keep the temporary directory alive across notebook cells. +# pylint: disable-next=consider-using-with +temp_dir = TemporaryDirectory(prefix="sayt_artifact_demo_") +artifact_dir = Path(temp_dir.name) / "car_services_sayt" +print("artifact will be written to:", artifact_dir) + +# %% +# Semantic artifact builds may take longer the first time if the model cache +# needs to be created locally. +artifact_path = SAYTBuilder( + small_corpus, + retrievers=retrievers, + min_chars=3, + max_suggestions=5, +).build_artifact(artifact_dir, overwrite=True) + +print("artifact saved to:", artifact_path) +print("artifact files:") +for path in sorted(artifact_path.rglob("*")): + if path.is_file(): + print("-", path.relative_to(artifact_path)) + +# %% +manifest = json.loads((artifact_path / "manifest.json").read_text(encoding="utf-8")) +print(json.dumps(manifest, indent=2)) + +# %% +live_suggester = SAYTSuggester( + small_corpus, + retrievers=retrievers, + min_chars=3, + max_suggestions=5, +) +loaded_suggester = SAYTSuggester.from_artifact(artifact_path) + +for query in ["car", "cars", "waxi", "grom", "wash", "duplicate", "auto"]: + live_suggestions = live_suggester.suggest(query, 5) + loaded_suggestions = loaded_suggester.suggest(query, 5) + + print("searching for:", query) + print("live", "->", live_suggestions) + print("loaded", "->", loaded_suggestions) + print("loaded_scores", "->", loaded_suggester.suggest_with_scores(query, 5)) + if live_suggestions != loaded_suggestions: + raise RuntimeError("Loaded suggester results did not match live build") + print() + +# %% +# Run `temp_dir.cleanup()` when you are finished exploring the saved files. +print("artifact ready for inspection:", artifact_path) + +# %% +temp_dir.cleanup() + +# %% diff --git a/pyproject.toml b/pyproject.toml index ff0bbaa..02b3c5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,8 @@ ignore = [ "E501", # indentation contains tabs "W191", + # Allow the use of older TypeVar syntax for generic functions + "UP047", ] [tool.ruff.lint.pydocstyle] diff --git a/src/industrial_classification_utils/sayt/__init__.py b/src/industrial_classification_utils/sayt/__init__.py index 8c5e03c..07f2bdf 100644 --- a/src/industrial_classification_utils/sayt/__init__.py +++ b/src/industrial_classification_utils/sayt/__init__.py @@ -1,15 +1,21 @@ """Public SAYT interfaces and built-in retriever components.""" from .sayt import SAYTSuggester +from .sayt_builder import SAYTBuilder from .sayt_retriever_specs import ( NgramRetrieverSpec, PrefixRetrieverSpec, Retriever, + RetrieverArtifactHandler, RetrieverSpec, SemanticRetrieverSpec, default_retriever_specs, ) from .sayt_retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever +from .sayt_storage import ( + register_retriever_artifact_handler, + unregister_retriever_artifact_handler, +) __all__ = [ "NgramRetriever", @@ -17,9 +23,13 @@ "PrefixRetriever", "PrefixRetrieverSpec", "Retriever", + "RetrieverArtifactHandler", "RetrieverSpec", + "SAYTBuilder", "SAYTSuggester", "SemanticRetriever", "SemanticRetrieverSpec", "default_retriever_specs", + "register_retriever_artifact_handler", + "unregister_retriever_artifact_handler", ] diff --git a/src/industrial_classification_utils/sayt/sayt.py b/src/industrial_classification_utils/sayt/sayt.py index ce41021..aa108ca 100644 --- a/src/industrial_classification_utils/sayt/sayt.py +++ b/src/industrial_classification_utils/sayt/sayt.py @@ -8,8 +8,8 @@ import os from collections.abc import Iterable, Sequence from dataclasses import dataclass +from pathlib import Path -import pandas as pd from survey_assist_utils.logging import get_logger from .sayt_core import ( @@ -24,6 +24,13 @@ RetrieverSpec, default_retriever_specs, ) +from .sayt_storage import ( + StoredRetrieverSpec, + load_corpus_from_csv, + load_retriever_from_artifact, + read_artifact_corpus, + read_artifact_manifest, +) logger = get_logger(__name__) @@ -87,9 +94,28 @@ class SAYTSuggester: ``` """ + @classmethod + def _from_state( + cls, + *, + corpus: CleanCorpus, + config: SaytConfig, + retriever_specs: Sequence[RetrieverSpec], + retrievers: list[_ConfiguredRetriever], + ) -> "SAYTSuggester": + """Construct a suggester from already-validated runtime state.""" + suggester = cls.__new__(cls) + suggester._corpus = corpus + suggester._config = config + suggester._max_duplication = max(corpus.display_text_count.values(), default=0) + suggester._retriever_specs = tuple(retriever_specs) + suggester._retrievers = retrievers + logger.info(f"SAYT suggester initialized with config: {suggester.get_config()}") + return suggester + def __init__( self, - corpus: Iterable[tuple[str, str]] | Iterable[str], + corpus: Iterable[tuple[object, object]] | Iterable[str], *, retrievers: Sequence[RetrieverSpec] | None = None, **kwargs: object, @@ -107,17 +133,19 @@ def __init__( """ self._corpus = CleanCorpus.model_validate(corpus) self._config = SaytConfig.model_validate(kwargs) - self._max_duplication = max(self._corpus.display_text_count.values(), default=0) self._retriever_specs = tuple( default_retriever_specs() if retrievers is None else retrievers ) self._retrievers = self._build_retrievers(self._retriever_specs) + self._max_duplication = max(self._corpus.display_text_count.values(), default=0) logger.info(f"SAYT suggester initialized with config: {self.get_config()}") - def _build_retrievers( - self, retriever_specs: Sequence[RetrieverSpec] - ) -> list[_ConfiguredRetriever]: + @staticmethod + def _normalised_retriever_specs( + retriever_specs: Sequence[RetrieverSpec], + ) -> list[tuple[RetrieverSpec, float]]: + """Validate and normalise configured retriever weights.""" if not retriever_specs: raise ValueError("At least one retriever must be configured") @@ -131,17 +159,21 @@ def _build_retrievers( validated_specs.append((spec, weight)) total_weight = sum(weight for _, weight in validated_specs) + return [(spec, weight / total_weight) for spec, weight in validated_specs] + def _build_retrievers( + self, retriever_specs: Sequence[RetrieverSpec] + ) -> list[_ConfiguredRetriever]: return [ _ConfiguredRetriever( name=spec.name, - weight=weight / total_weight, + weight=weight, retriever=spec.build( self._corpus, min_chars=self._config.min_chars, ), ) - for spec, weight in validated_specs + for spec, weight in self._normalised_retriever_specs(retriever_specs) ] @classmethod @@ -171,19 +203,92 @@ def from_csv( Raises: ValueError: If the requested search or display column is missing. """ - df = pd.read_csv(file_path) - if search_text_col not in df.columns: - raise ValueError(f"Column '{search_text_col}' not found in CSV") - if display_text_col is None: - display_text_col = search_text_col - if display_text_col not in df.columns: - raise ValueError(f"Column '{display_text_col}' not found in CSV") + corpus_rows = load_corpus_from_csv( + file_path, + search_text_col=search_text_col, + display_text_col=display_text_col, + ) return cls( - list(zip(df[search_text_col], df[display_text_col], strict=False)), + corpus_rows, retrievers=retrievers, **kwargs, ) + @classmethod + def from_artifact(cls, artifact_dir: str | os.PathLike) -> "SAYTSuggester": + """Load a suggester from a persisted SAYT artifact directory.""" + artifact_path = Path(artifact_dir) + manifest = read_artifact_manifest(artifact_dir=artifact_path) + persisted_rows = read_artifact_corpus( + artifact_dir=artifact_path, + corpus_file=manifest.corpus_file, + ) + corpus = CleanCorpus.from_persisted_rows(persisted_rows) + if corpus.size != manifest.corpus_size: + raise ValueError("Artifact corpus size does not match manifest") + + retrievers = cls._load_retrievers_from_artifact( + corpus=corpus, + config=manifest.config, + stored_retrievers=manifest.retrievers, + artifact_dir=artifact_path, + ) + return cls._from_state( + corpus=corpus, + config=manifest.config, + retriever_specs=[ + stored_retriever.spec for stored_retriever in manifest.retrievers + ], + retrievers=retrievers, + ) + + @classmethod + def _load_retrievers_from_artifact( + cls, + *, + corpus: CleanCorpus, + config: SaytConfig, + stored_retrievers: Sequence[StoredRetrieverSpec], + artifact_dir: Path, + ) -> list[_ConfiguredRetriever]: + """Restore runtime retrievers from a persisted SAYT artifact.""" + normalised_specs = cls._normalised_retriever_specs( + [stored_retriever.spec for stored_retriever in stored_retrievers] + ) + return [ + _ConfiguredRetriever( + name=stored_retriever.spec.name, + weight=weight, + retriever=cls._load_retriever_from_artifact( + corpus=corpus, + config=config, + stored_retriever=stored_retriever, + artifact_dir=artifact_dir, + ), + ) + for (_, weight), stored_retriever in zip( + normalised_specs, + stored_retrievers, + strict=True, + ) + ] + + @staticmethod + def _load_retriever_from_artifact( + *, + corpus: CleanCorpus, + config: SaytConfig, + stored_retriever: StoredRetrieverSpec, + artifact_dir: Path, + ) -> Retriever: + """Restore a runtime retriever from persisted artifact state.""" + return load_retriever_from_artifact( + corpus=corpus, + config=config, + stored_retriever=stored_retriever, + artifact_dir=artifact_dir, + ) + def _dedup_suggestions( self, suggestions: list[Suggestion] ) -> list[tuple[str, float]]: diff --git a/src/industrial_classification_utils/sayt/sayt_builder.py b/src/industrial_classification_utils/sayt/sayt_builder.py new file mode 100644 index 0000000..367680d --- /dev/null +++ b/src/industrial_classification_utils/sayt/sayt_builder.py @@ -0,0 +1,80 @@ +"""Offline artifact builder for persisted SAYT runtime assets.""" + +import os +from collections.abc import Iterable, Sequence +from pathlib import Path + +from .sayt_core import CleanCorpus, SaytConfig +from .sayt_retriever_specs import ( + RetrieverSpec, + default_retriever_specs, +) +from .sayt_storage import ( + build_artifact_manifest, + build_retriever_artifact, + load_corpus_from_csv, + prepare_artifact_dir, + write_artifact_corpus, + write_artifact_manifest, +) + + +class SAYTBuilder: + """Build a persisted SAYT artifact for later runtime loading.""" + + def __init__( + self, + corpus: Iterable[tuple[object, object]] | Iterable[str], + *, + retrievers: Sequence[RetrieverSpec] | None = None, + **kwargs: object, + ) -> None: + """Initialise an artifact builder from raw corpus input.""" + self._corpus = CleanCorpus.model_validate(corpus) + self._config = SaytConfig.model_validate(kwargs) + self._retriever_specs = tuple( + default_retriever_specs() if retrievers is None else retrievers + ) + + @classmethod + def from_csv( + cls, + file_path: str | os.PathLike, + *, + search_text_col: str = "title", + display_text_col: str | None = None, + retrievers: Sequence[RetrieverSpec] | None = None, + **kwargs: object, + ) -> "SAYTBuilder": + """Initialise an artifact builder from CSV input.""" + corpus_rows = load_corpus_from_csv( + file_path, + search_text_col=search_text_col, + display_text_col=display_text_col, + ) + return cls(corpus_rows, retrievers=retrievers, **kwargs) + + def build_artifact( + self, + output_dir: str | os.PathLike, + *, + overwrite: bool = False, + ) -> Path: + """Persist the current SAYT configuration and dense stores to disk.""" + artifact_dir = prepare_artifact_dir(output_dir, overwrite=overwrite) + manifest = build_artifact_manifest( + corpus=self._corpus, + config=self._config, + retriever_specs=self._retriever_specs, + ) + + write_artifact_corpus(self._corpus, artifact_dir=artifact_dir) + for stored_retriever in manifest.retrievers: + build_retriever_artifact( + corpus=self._corpus, + stored_retriever=stored_retriever, + artifact_dir=artifact_dir, + ) + + write_artifact_manifest(manifest, artifact_dir=artifact_dir) + return artifact_dir diff --git a/src/industrial_classification_utils/sayt/sayt_core.py b/src/industrial_classification_utils/sayt/sayt_core.py index b0b40f0..64a77e2 100644 --- a/src/industrial_classification_utils/sayt/sayt_core.py +++ b/src/industrial_classification_utils/sayt/sayt_core.py @@ -16,7 +16,7 @@ _NON_ALNUM_SPACE_RE = re.compile(r"[^a-z ]+") -def _normalise(text: str | None) -> str: +def _normalise(text: object) -> str: if not isinstance(text, str): return "" text = text.strip().lower() @@ -29,6 +29,15 @@ def _row_uid(index: int, search_text: str, display_text: str) -> str: return str(uuid5(NAMESPACE_URL, f"{index}\0{search_text}\0{display_text}")) +@dataclass(frozen=True, slots=True) +class PersistedCorpusRow: + """Represent a persisted SAYT corpus row restored from artifact storage.""" + + row_id: str + search_text: str + display_text: str + + class CleanCorpus(BaseModel): """Store cleaned SAYT rows and their derived lookup tables. @@ -54,8 +63,12 @@ def _coerce_input(cls, data: object) -> object: @model_validator(mode="after") def _build_indexes(self) -> "CleanCorpus": self.rows = self._clean_corpus( - cast(Iterable[str] | Iterable[tuple[str, str]], self.corpus) + cast(Iterable[str] | Iterable[tuple[object, object]], self.corpus) ) + return self._populate_indexes() + + def _populate_indexes(self) -> "CleanCorpus": + """Rebuild lookup tables from the current cleaned rows.""" self.id_to_search = {rid: search for rid, search, _ in self.rows} self.id_to_display = {rid: display for rid, _, display in self.rows} self.display_text_count = {} @@ -66,9 +79,54 @@ def _build_indexes(self) -> "CleanCorpus": self.size = len(self.rows) return self + @classmethod + def from_persisted_rows( + cls, + rows: Iterable[PersistedCorpusRow | tuple[str, str, str]], + ) -> "CleanCorpus": + """Restore a cleaned corpus from persisted row identifiers and text. + + Args: + rows: Persisted ``(row_id, search_text, display_text)`` triples or + ``PersistedCorpusRow`` objects. + + Returns: + A ``CleanCorpus`` whose row identifiers and lookup maps match the + persisted artifact data exactly. + + Raises: + ValueError: If no persisted rows are supplied. + """ + restored_rows = [cls._coerce_persisted_row(row) for row in rows] + if not restored_rows: + raise ValueError("corpus is empty after filtering") + + corpus = cls.model_construct( + corpus=[ + (search_text, display_text) + for _, search_text, display_text in restored_rows + ], + rows=restored_rows, + id_to_search={}, + id_to_display={}, + display_text_count={}, + size=0, + ) + return corpus._populate_indexes() + + @staticmethod + def _coerce_persisted_row( + row: PersistedCorpusRow | tuple[str, str, str], + ) -> tuple[str, str, str]: + """Convert persisted row data into the internal tuple format.""" + if isinstance(row, PersistedCorpusRow): + return (row.row_id, row.search_text, row.display_text) + row_id, search_text, display_text = row + return (str(row_id), str(search_text), str(display_text)) + @staticmethod def _clean_corpus( - corpus: Iterable[str] | Iterable[tuple[str, str]], + corpus: Iterable[str] | Iterable[tuple[object, object]], ) -> list[tuple[str, str, str]]: if not isinstance(corpus, Iterable): raise TypeError( diff --git a/src/industrial_classification_utils/sayt/sayt_indexes.py b/src/industrial_classification_utils/sayt/sayt_indexes.py index d923140..a435ab2 100644 --- a/src/industrial_classification_utils/sayt/sayt_indexes.py +++ b/src/industrial_classification_utils/sayt/sayt_indexes.py @@ -7,6 +7,7 @@ import tempfile from contextlib import contextmanager from dataclasses import dataclass +from pathlib import Path from typing import cast import classifai.indexers.main as classifai_indexers_main @@ -49,6 +50,8 @@ def from_corpus( *, corpus: CleanCorpus, vectoriser: VectoriserBase, + output_dir: str | os.PathLike[str] | None = None, + overwrite: bool = True, ) -> "DenseVectorIndex": """Build a dense index from a cleaned corpus. @@ -56,20 +59,24 @@ def from_corpus( corpus: Cleaned corpus whose normalised search text should be indexed. vectoriser: Vectoriser used to embed corpus rows and future queries. + output_dir: Optional persistent filespace directory for the + underlying ClassifAI vector store. When omitted, a temporary + directory is used. + overwrite: Whether to allow ClassifAI to replace an existing + filespace when ``output_dir`` is provided. Returns: A ``DenseVectorIndex`` backed by ClassifAI's ``VectorStore``. """ - with tempfile.TemporaryDirectory(prefix="sayt_") as tmp_dir: - csv_path = os.path.join(tmp_dir, "corpus.csv") - - with open(csv_path, "w", newline="", encoding="utf-8") as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=["label", "text"]) - writer.writeheader() - writer.writerows( - {"label": row_id, "text": search_text} - for row_id, search_text, _ in corpus.rows - ) + with tempfile.TemporaryDirectory(prefix="sayt_") as temp_dir: + csv_path = os.path.join(temp_dir, "corpus.csv") + classifai_output_dir = ( + os.fspath(output_dir) + if output_dir is not None + else os.path.join(temp_dir, "vector_store") + ) + + cls._write_corpus_csv(corpus, csv_path) with _silence_classifai_tqdm(): vector_store = VectorStore( @@ -77,8 +84,8 @@ def from_corpus( data_type="csv", vectoriser=vectoriser, batch_size=64, - output_dir=None, - overwrite=True, + output_dir=classifai_output_dir, + overwrite=overwrite, hooks=None, ) @@ -88,6 +95,54 @@ def from_corpus( _corpus=corpus, ) + @classmethod + def from_filespace( + cls, + *, + corpus: CleanCorpus, + folder_path: str | os.PathLike[str], + vectoriser: VectoriserBase, + ) -> "DenseVectorIndex": + """Load a dense index from a persisted ClassifAI filespace. + + Args: + corpus: Cleaned corpus whose row metadata should back query results. + folder_path: Filesystem directory containing ``metadata.json`` and + ``vectors.parquet``. + vectoriser: Vectoriser used to embed future query text. + + Returns: + A ``DenseVectorIndex`` backed by a loaded ``VectorStore``. + """ + with _silence_classifai_tqdm(): + vector_store = VectorStore.from_filespace( + folder_path=os.fspath(folder_path), + vectoriser=vectoriser, + hooks=None, + ) + + return cls( + _vector_store=vector_store, + _num_vectors=int(vector_store.num_vectors or 0), + _corpus=corpus, + ) + + @staticmethod + def _write_corpus_csv( + corpus: CleanCorpus, + csv_path: str | os.PathLike[str], + ) -> None: + """Write the row-id and search-text schema expected by ClassifAI.""" + csv_file = Path(csv_path) + csv_file.parent.mkdir(parents=True, exist_ok=True) + with open(csv_file, "w", newline="", encoding="utf-8") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=["label", "text"]) + writer.writeheader() + writer.writerows( + {"label": row_id, "text": search_text} + for row_id, search_text, _ in corpus.rows + ) + def query(self, q_norm: str, num_suggestions: int) -> list[tuple[str, float]]: """Query the dense index with a normalised string. @@ -157,6 +212,8 @@ def build_ngram_index( *, n: int, max_df: float, + output_dir: str | os.PathLike[str] | None = None, + overwrite: bool = True, ) -> DenseVectorIndex: """Build a dense index backed by character n-gram vectors. @@ -164,6 +221,10 @@ def build_ngram_index( corpus: Cleaned corpus to index. n: Character n-gram size. max_df: Maximum document frequency passed to ``CountVectorizer``. + output_dir: Optional persistent filespace directory for the generated + vector store. + overwrite: Whether to allow ClassifAI to replace an existing filespace + when ``output_dir`` is provided. Returns: A dense index using character n-gram embeddings. @@ -175,6 +236,27 @@ def build_ngram_index( n=n, max_df=max_df, ), + output_dir=output_dir, + overwrite=overwrite, + ) + + +def load_ngram_index( + corpus: CleanCorpus, + *, + n: int, + max_df: float, + folder_path: str | os.PathLike[str], +) -> DenseVectorIndex: + """Load a persisted dense index backed by character n-gram vectors.""" + return DenseVectorIndex.from_filespace( + corpus=corpus, + folder_path=folder_path, + vectoriser=_CharNgramVectoriser( + [search for _, search, _ in corpus.rows], + n=n, + max_df=max_df, + ), ) @@ -182,12 +264,18 @@ def build_semantic_index( corpus: CleanCorpus, *, model: str, + output_dir: str | os.PathLike[str] | None = None, + overwrite: bool = True, ) -> DenseVectorIndex: """Build a dense index backed by sentence-transformer embeddings. Args: corpus: Cleaned corpus to index. model: Sentence-transformer model name without the repository prefix. + output_dir: Optional persistent filespace directory for the generated + vector store. + overwrite: Whether to allow ClassifAI to replace an existing filespace + when ``output_dir`` is provided. Returns: A dense index using semantic embeddings. @@ -199,4 +287,24 @@ def build_semantic_index( return DenseVectorIndex.from_corpus( corpus=corpus, vectoriser=semantic_vectoriser, + output_dir=output_dir, + overwrite=overwrite, + ) + + +def load_semantic_index( + corpus: CleanCorpus, + *, + model: str, + folder_path: str | os.PathLike[str], +) -> DenseVectorIndex: + """Load a persisted dense index backed by semantic embeddings.""" + base_vectoriser: VectoriserBase = HuggingFaceVectoriser( + f"sentence-transformers/{model}" + ) + semantic_vectoriser: VectoriserBase = _L2NormalisingVectoriser(base_vectoriser) + return DenseVectorIndex.from_filespace( + corpus=corpus, + folder_path=folder_path, + vectoriser=semantic_vectoriser, ) diff --git a/src/industrial_classification_utils/sayt/sayt_retriever_specs.py b/src/industrial_classification_utils/sayt/sayt_retriever_specs.py index a45df4d..e760d8a 100644 --- a/src/industrial_classification_utils/sayt/sayt_retriever_specs.py +++ b/src/industrial_classification_utils/sayt/sayt_retriever_specs.py @@ -3,10 +3,12 @@ """Public retriever protocols and configuration objects for SAYT.""" import math +from collections.abc import Mapping from dataclasses import dataclass, field +from pathlib import Path from typing import Protocol -from .sayt_core import CleanCorpus, Suggestion +from .sayt_core import CleanCorpus, SaytConfig, Suggestion from .sayt_retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever _MIN_NGRAM_SIZE = 2 @@ -53,6 +55,55 @@ def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: """ +class RetrieverArtifactHandler(Protocol): + """Persistence hooks for storing and restoring a retriever spec. + + This optional protocol extends the runtime-only ``RetrieverSpec`` contract + for artifact build/load flows. Handlers are registered separately so custom + retriever specs remain lightweight unless they need persistence support. + """ + + @property + def artifact_type(self) -> str: + """Return the stable manifest identifier for this handler.""" + + def can_handle(self, spec: RetrieverSpec) -> bool: + """Return whether this handler can persist the supplied spec.""" + + def serialise_spec(self, spec: RetrieverSpec) -> dict[str, object]: + """Return spec-specific manifest configuration excluding weight/path.""" + + def deserialise_spec( + self, + *, + weight: float, + config: Mapping[str, object], + ) -> RetrieverSpec: + """Rebuild a retriever spec from persisted manifest data.""" + + def default_path(self, *, index: int, spec: RetrieverSpec) -> str | None: + """Return the default relative artifact path for persisted assets.""" + + def build_artifact( + self, + *, + spec: RetrieverSpec, + corpus: CleanCorpus, + path: Path | None, + ) -> None: + """Write any persisted retriever assets needed for later loading.""" + + def load_retriever( + self, + *, + spec: RetrieverSpec, + corpus: CleanCorpus, + config: SaytConfig, + path: Path | None, + ) -> Retriever: + """Restore a runtime retriever from persisted artifact state.""" + + def _validate_retriever_weight(weight: float) -> None: if not math.isfinite(weight) or weight <= 0: raise ValueError("retriever weight must be a finite value > 0") diff --git a/src/industrial_classification_utils/sayt/sayt_retrievers.py b/src/industrial_classification_utils/sayt/sayt_retrievers.py index 6bb1f1d..dcbf8e5 100644 --- a/src/industrial_classification_utils/sayt/sayt_retrievers.py +++ b/src/industrial_classification_utils/sayt/sayt_retrievers.py @@ -109,6 +109,21 @@ class _DenseRetriever: _min_chars: int _index: DenseVectorIndex + @classmethod + def from_index( + cls, + corpus: CleanCorpus, + *, + min_chars: int, + index: DenseVectorIndex, + ) -> "_DenseRetriever": + """Restore a dense retriever from an already-built dense index.""" + retriever = cls.__new__(cls) + retriever._corpus = corpus + retriever._min_chars = min_chars + retriever._index = index + return retriever + def suggest_with_scores( self, q_norm: str, num_suggestions: int ) -> list[Suggestion]: diff --git a/src/industrial_classification_utils/sayt/sayt_storage.py b/src/industrial_classification_utils/sayt/sayt_storage.py new file mode 100644 index 0000000..62ec41a --- /dev/null +++ b/src/industrial_classification_utils/sayt/sayt_storage.py @@ -0,0 +1,565 @@ +"""Artifact and storage helpers for SAYT builder and loader paths.""" + +import csv +import json +import os +import shutil +from collections.abc import Mapping +from dataclasses import dataclass +from pathlib import Path +from typing import TypeVar + +import pandas as pd + +from .sayt_core import CleanCorpus, PersistedCorpusRow, SaytConfig +from .sayt_indexes import ( + build_ngram_index, + build_semantic_index, + load_ngram_index, + load_semantic_index, +) +from .sayt_retriever_specs import ( + NgramRetrieverSpec, + PrefixRetrieverSpec, + Retriever, + RetrieverArtifactHandler, + RetrieverSpec, + SemanticRetrieverSpec, +) +from .sayt_retrievers import NgramRetriever, SemanticRetriever + +SAYT_ARTIFACT_TYPE = "sayt" +SAYT_ARTIFACT_VERSION = 1 +MANIFEST_FILE_NAME = "manifest.json" +CORPUS_FILE_NAME = "corpus.csv" +_ARTIFACT_CORPUS_FIELDS = ["row_id", "search_text", "display_text"] +_RETRIEVERS_DIR_NAME = "retrievers" + +_RETRIEVER_ARTIFACT_HANDLERS: dict[str, RetrieverArtifactHandler] = {} +SpecT = TypeVar("SpecT", bound=RetrieverSpec) + + +@dataclass(frozen=True, slots=True) +class StoredRetrieverSpec: + """Persisted retriever spec plus its optional filespace path.""" + + artifact_type: str + spec: RetrieverSpec + config: dict[str, object] + path: str | None = None + + +@dataclass(frozen=True, slots=True) +class SaytArtifactManifest: + """Structured manifest data for a persisted SAYT artifact.""" + + config: SaytConfig + corpus_file: str + corpus_size: int + retrievers: tuple[StoredRetrieverSpec, ...] + + +def load_corpus_from_csv( + file_path: str | os.PathLike[str], + *, + search_text_col: str = "title", + display_text_col: str | None = None, +) -> list[tuple[object, object]]: + """Load raw corpus tuples from a CSV file. + + Args: + file_path: Path to the CSV file containing suggestion rows. + search_text_col: Column containing the searchable text. + display_text_col: Optional column containing display text. When + omitted, the search column is reused for display values. + + Returns: + Raw ``(search_text, display_text)`` tuples suitable for ``CleanCorpus``. + + Raises: + ValueError: If the requested search or display column is missing. + """ + df = pd.read_csv(file_path) + if search_text_col not in df.columns: + raise ValueError(f"Column '{search_text_col}' not found in CSV") + if display_text_col is None: + display_text_col = search_text_col + if display_text_col not in df.columns: + raise ValueError(f"Column '{display_text_col}' not found in CSV") + return list(zip(df[search_text_col], df[display_text_col], strict=False)) + + +def prepare_artifact_dir( + artifact_dir: str | os.PathLike[str], + *, + overwrite: bool = False, +) -> Path: + """Create or replace the output directory for a SAYT artifact.""" + path = Path(artifact_dir) + if path.exists(): + if not overwrite: + raise FileExistsError("Artifact directory already exists") + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink() + path.mkdir(parents=True, exist_ok=True) + return path + + +def write_artifact_corpus(corpus: CleanCorpus, *, artifact_dir: str | Path) -> Path: + """Persist cleaned SAYT rows as the artifact corpus source of truth.""" + output_path = Path(artifact_dir) / CORPUS_FILE_NAME + with open(output_path, "w", newline="", encoding="utf-8") as csv_file: + writer = csv.DictWriter(csv_file, fieldnames=_ARTIFACT_CORPUS_FIELDS) + writer.writeheader() + writer.writerows( + { + "row_id": row_id, + "search_text": search_text, + "display_text": display_text, + } + for row_id, search_text, display_text in corpus.rows + ) + return output_path + + +def read_artifact_corpus( + *, + artifact_dir: str | Path, + corpus_file: str = CORPUS_FILE_NAME, +) -> list[PersistedCorpusRow]: + """Read persisted corpus rows from a SAYT artifact.""" + corpus_path = Path(artifact_dir) / corpus_file + if not corpus_path.exists(): + raise FileNotFoundError(f"Artifact corpus file not found: {corpus_path}") + + with open(corpus_path, encoding="utf-8") as csv_file: + reader = csv.DictReader(csv_file) + return [ + PersistedCorpusRow( + row_id=row["row_id"], + search_text=row["search_text"], + display_text=row["display_text"], + ) + for row in reader + ] + + +def build_artifact_manifest( + *, + corpus: CleanCorpus, + config: SaytConfig, + retriever_specs: tuple[RetrieverSpec, ...], +) -> SaytArtifactManifest: + """Build the structured manifest payload for a SAYT artifact.""" + return SaytArtifactManifest( + config=config.model_copy(deep=True), + corpus_file=CORPUS_FILE_NAME, + corpus_size=corpus.size, + retrievers=tuple( + _build_stored_retriever(index, spec) + for index, spec in enumerate(retriever_specs) + ), + ) + + +def write_artifact_manifest( + manifest: SaytArtifactManifest, + *, + artifact_dir: str | Path, +) -> Path: + """Write the manifest for a SAYT artifact.""" + manifest_path = Path(artifact_dir) / MANIFEST_FILE_NAME + manifest_path.write_text( + json.dumps(_serialise_manifest(manifest), indent=2), + encoding="utf-8", + ) + return manifest_path + + +def read_artifact_manifest(*, artifact_dir: str | Path) -> SaytArtifactManifest: + """Read and validate a SAYT artifact manifest.""" + manifest_path = Path(artifact_dir) / MANIFEST_FILE_NAME + if not manifest_path.exists(): + raise FileNotFoundError(f"Artifact manifest not found: {manifest_path}") + + payload = json.loads(manifest_path.read_text(encoding="utf-8")) + if payload.get("artifact_type") != SAYT_ARTIFACT_TYPE: + raise ValueError("Unsupported artifact type") + if payload.get("artifact_version") != SAYT_ARTIFACT_VERSION: + raise ValueError("Unsupported artifact version") + + try: + return SaytArtifactManifest( + config=SaytConfig.model_validate(payload["config"]), + corpus_file=str(payload["corpus_file"]), + corpus_size=int(payload["corpus_size"]), + retrievers=tuple( + _deserialise_stored_retriever(item) for item in payload["retrievers"] + ), + ) + except KeyError as exc: + raise ValueError(f"Malformed artifact manifest: missing {exc.args[0]}") from exc + + +def retriever_filespace_path( + artifact_dir: str | Path, + stored_retriever: StoredRetrieverSpec, +) -> Path: + """Resolve the persisted filespace for a dense retriever entry.""" + if stored_retriever.path is None: + raise ValueError( + f"Retriever '{stored_retriever.spec.name}' does not have a stored filespace" + ) + return Path(artifact_dir) / stored_retriever.path + + +def register_retriever_artifact_handler( + handler: RetrieverArtifactHandler, + *, + replace: bool = False, +) -> None: + """Register a handler for artifact persistence of retriever specs.""" + artifact_type = handler.artifact_type + if artifact_type in _RETRIEVER_ARTIFACT_HANDLERS and not replace: + raise ValueError( + f"Retriever artifact handler already registered for type: {artifact_type}" + ) + _RETRIEVER_ARTIFACT_HANDLERS[artifact_type] = handler + + +def unregister_retriever_artifact_handler(artifact_type: str) -> None: + """Remove a previously registered retriever artifact handler.""" + _RETRIEVER_ARTIFACT_HANDLERS.pop(artifact_type, None) + + +def build_retriever_artifact( + *, + corpus: CleanCorpus, + stored_retriever: StoredRetrieverSpec, + artifact_dir: str | Path, +) -> None: + """Persist retriever-specific artifact state using its registered handler.""" + handler = _get_retriever_artifact_handler(stored_retriever.artifact_type) + path = ( + retriever_filespace_path(artifact_dir, stored_retriever) + if stored_retriever.path is not None + else None + ) + handler.build_artifact( + spec=stored_retriever.spec, + corpus=corpus, + path=path, + ) + + +def load_retriever_from_artifact( + *, + corpus: CleanCorpus, + config: SaytConfig, + stored_retriever: StoredRetrieverSpec, + artifact_dir: str | Path, +) -> Retriever: + """Restore a runtime retriever using its registered artifact handler.""" + handler = _get_retriever_artifact_handler(stored_retriever.artifact_type) + path = ( + retriever_filespace_path(artifact_dir, stored_retriever) + if stored_retriever.path is not None + else None + ) + return handler.load_retriever( + spec=stored_retriever.spec, + corpus=corpus, + config=config, + path=path, + ) + + +def _build_stored_retriever( + index: int, + spec: RetrieverSpec, +) -> StoredRetrieverSpec: + handler = _get_retriever_artifact_handler_for_spec(spec) + return StoredRetrieverSpec( + artifact_type=handler.artifact_type, + spec=spec, + config=handler.serialise_spec(spec), + path=handler.default_path(index=index, spec=spec), + ) + + +def _serialise_manifest(manifest: SaytArtifactManifest) -> dict[str, object]: + return { + "artifact_type": SAYT_ARTIFACT_TYPE, + "artifact_version": SAYT_ARTIFACT_VERSION, + "config": manifest.config.model_dump(mode="json"), + "corpus_file": manifest.corpus_file, + "corpus_size": manifest.corpus_size, + "retrievers": [ + _serialise_stored_retriever(stored_retriever) + for stored_retriever in manifest.retrievers + ], + } + + +def _serialise_stored_retriever( + stored_retriever: StoredRetrieverSpec, +) -> dict[str, object]: + return { + "type": stored_retriever.artifact_type, + "weight": stored_retriever.spec.weight, + "path": stored_retriever.path, + "config": stored_retriever.config, + } + + +def _deserialise_stored_retriever(payload: dict[str, object]) -> StoredRetrieverSpec: + retriever_type = str(payload["type"]) + weight = _coerce_float(payload["weight"], field_name="weight") + path = payload.get("path") + config = payload.get("config", {}) + if not isinstance(config, dict): + raise ValueError(f"Malformed retriever config for type: {retriever_type}") + handler = _get_retriever_artifact_handler(retriever_type) + spec = handler.deserialise_spec(weight=weight, config=config) + return StoredRetrieverSpec( + artifact_type=retriever_type, + spec=spec, + config=dict(config), + path=str(path) if isinstance(path, str) else None, + ) + + +def _get_retriever_artifact_handler(artifact_type: str) -> RetrieverArtifactHandler: + try: + return _RETRIEVER_ARTIFACT_HANDLERS[artifact_type] + except KeyError as exc: + raise ValueError( + f"No retriever artifact handler registered for type: {artifact_type}" + ) from exc + + +def _get_retriever_artifact_handler_for_spec( + spec: RetrieverSpec, +) -> RetrieverArtifactHandler: + for handler in reversed(tuple(_RETRIEVER_ARTIFACT_HANDLERS.values())): + if handler.can_handle(spec): + return handler + raise TypeError( + f"No retriever artifact handler registered for spec type: {type(spec).__name__}" + ) + + +class _PrefixRetrieverArtifactHandler: # pylint: disable=missing-function-docstring,useless-return + """Artifact handler for the built-in prefix retriever spec.""" + + artifact_type = "prefix" + + def can_handle(self, spec: RetrieverSpec) -> bool: + return isinstance(spec, PrefixRetrieverSpec) + + def serialise_spec(self, spec: RetrieverSpec) -> dict[str, object]: + _ = spec + return {} + + def deserialise_spec( + self, + *, + weight: float, + config: Mapping[str, object], + ) -> RetrieverSpec: + _ = config + return PrefixRetrieverSpec(weight=weight) + + def default_path(self, *, index: int, spec: RetrieverSpec) -> str | None: + _ = (index, spec) + return None + + def build_artifact( + self, + *, + spec: RetrieverSpec, + corpus: CleanCorpus, + path: Path | None, + ) -> None: + _ = (spec, corpus, path) + + def load_retriever( + self, + *, + spec: RetrieverSpec, + corpus: CleanCorpus, + config: SaytConfig, + path: Path | None, + ) -> Retriever: + _ = path + return spec.build(corpus, min_chars=config.min_chars) + + +class _NgramRetrieverArtifactHandler: # pylint: disable=missing-function-docstring + """Artifact handler for the built-in n-gram retriever spec.""" + + artifact_type = "ngram" + + def can_handle(self, spec: RetrieverSpec) -> bool: + return isinstance(spec, NgramRetrieverSpec) + + def serialise_spec(self, spec: RetrieverSpec) -> dict[str, object]: + typed_spec = _require_spec_type(spec, NgramRetrieverSpec) + return {"n": typed_spec.n, "max_df": typed_spec.max_df} + + def deserialise_spec( + self, + *, + weight: float, + config: Mapping[str, object], + ) -> RetrieverSpec: + return NgramRetrieverSpec( + weight=weight, + n=_coerce_int(config["n"], field_name="n"), + max_df=_coerce_float(config["max_df"], field_name="max_df"), + ) + + def default_path(self, *, index: int, spec: RetrieverSpec) -> str | None: + return f"{_RETRIEVERS_DIR_NAME}/{index:02d}-{spec.name}" + + def build_artifact( + self, + *, + spec: RetrieverSpec, + corpus: CleanCorpus, + path: Path | None, + ) -> None: + typed_spec = _require_spec_type(spec, NgramRetrieverSpec) + build_ngram_index( + corpus, + n=typed_spec.n, + max_df=typed_spec.max_df, + output_dir=_require_path(path, typed_spec.name), + overwrite=True, + ) + + def load_retriever( + self, + *, + spec: RetrieverSpec, + corpus: CleanCorpus, + config: SaytConfig, + path: Path | None, + ) -> Retriever: + typed_spec = _require_spec_type(spec, NgramRetrieverSpec) + index = load_ngram_index( + corpus, + n=typed_spec.n, + max_df=typed_spec.max_df, + folder_path=_require_path(path, typed_spec.name), + ) + return NgramRetriever.from_index( + corpus, + min_chars=config.min_chars, + index=index, + ) + + +class _SemanticRetrieverArtifactHandler: # pylint: disable=missing-function-docstring + """Artifact handler for the built-in semantic retriever spec.""" + + artifact_type = "semantic" + + def can_handle(self, spec: RetrieverSpec) -> bool: + return isinstance(spec, SemanticRetrieverSpec) + + def serialise_spec(self, spec: RetrieverSpec) -> dict[str, object]: + typed_spec = _require_spec_type(spec, SemanticRetrieverSpec) + return {"model": typed_spec.model} + + def deserialise_spec( + self, + *, + weight: float, + config: Mapping[str, object], + ) -> RetrieverSpec: + return SemanticRetrieverSpec( + weight=weight, + model=str(config["model"]), + ) + + def default_path(self, *, index: int, spec: RetrieverSpec) -> str | None: + return f"{_RETRIEVERS_DIR_NAME}/{index:02d}-{spec.name}" + + def build_artifact( + self, + *, + spec: RetrieverSpec, + corpus: CleanCorpus, + path: Path | None, + ) -> None: + typed_spec = _require_spec_type(spec, SemanticRetrieverSpec) + build_semantic_index( + corpus, + model=typed_spec.model, + output_dir=_require_path(path, typed_spec.name), + overwrite=True, + ) + + def load_retriever( + self, + *, + spec: RetrieverSpec, + corpus: CleanCorpus, + config: SaytConfig, + path: Path | None, + ) -> Retriever: + typed_spec = _require_spec_type(spec, SemanticRetrieverSpec) + index = load_semantic_index( + corpus, + model=typed_spec.model, + folder_path=_require_path(path, typed_spec.name), + ) + return SemanticRetriever.from_index( + corpus, + min_chars=config.min_chars, + index=index, + ) + + +def _require_path(path: Path | None, retriever_name: str) -> Path: + if path is None: + raise ValueError( + f"Retriever '{retriever_name}' requires a persisted filespace path" + ) + return path + + +def _coerce_int(value: object, *, field_name: str) -> int: + if isinstance(value, bool) or not isinstance(value, int | str): + raise ValueError(f"Malformed integer value for retriever field: {field_name}") + return int(value) + + +def _coerce_float(value: object, *, field_name: str) -> float: + if isinstance(value, bool) or not isinstance(value, int | float | str): + raise ValueError(f"Malformed float value for retriever field: {field_name}") + return float(value) + + +def _require_spec_type(spec: RetrieverSpec, spec_type: type[SpecT]) -> SpecT: + if not isinstance(spec, spec_type): + raise TypeError( + f"Expected spec of type {spec_type.__name__}, got {type(spec).__name__}" + ) + return spec + + +def _register_builtin_retriever_artifact_handlers() -> None: + """Seed the artifact handler registry with the built-in retriever types.""" + for handler in ( + _PrefixRetrieverArtifactHandler(), + _NgramRetrieverArtifactHandler(), + _SemanticRetrieverArtifactHandler(), + ): + register_retriever_artifact_handler(handler) + + +_register_builtin_retriever_artifact_handlers() diff --git a/tests/sayt/test_sayt.py b/tests/sayt/test_sayt.py index bf473e9..3280f5d 100644 --- a/tests/sayt/test_sayt.py +++ b/tests/sayt/test_sayt.py @@ -11,9 +11,14 @@ from industrial_classification_utils.sayt import ( PrefixRetrieverSpec, + SAYTBuilder, ) from industrial_classification_utils.sayt.sayt import SAYTSuggester -from industrial_classification_utils.sayt.sayt_core import CleanCorpus, Suggestion +from industrial_classification_utils.sayt.sayt_core import ( + CleanCorpus, + PersistedCorpusRow, + Suggestion, +) def test_constructor_rejects_unknown_kwargs(small_corpus): @@ -50,6 +55,20 @@ def test_clean_corpus_accepts_existing_instance_and_dict_input(small_corpus): assert dict_corpus.rows == corpus.rows +def test_clean_corpus_restores_persisted_rows(small_corpus): + """Restore cleaned corpus rows without regenerating row identifiers.""" + corpus = CleanCorpus.model_validate(small_corpus) + + restored = CleanCorpus.from_persisted_rows( + [PersistedCorpusRow(*row) for row in corpus.rows] + ) + + assert restored.rows == corpus.rows + assert restored.id_to_search == corpus.id_to_search + assert restored.id_to_display == corpus.id_to_display + assert restored.display_text_count == corpus.display_text_count + + def test_clean_corpus_rejects_non_iterable_input(): """Reject scalar corpus values before attempting to clean them.""" with pytest.raises(TypeError, match="corpus must be an iterable"): @@ -132,6 +151,27 @@ def test_from_csv_rejects_missing_display_column(tmp_path, small_corpus): ) +def test_from_artifact_restores_prefix_suggester(tmp_path, small_corpus): + """Round-trip a prefix-only artifact into a working suggester.""" + artifact_dir = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ).build_artifact(tmp_path / "artifact") + + restored = SAYTSuggester.from_artifact(artifact_dir) + expected = SAYTSuggester( + small_corpus, + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ) + + assert restored.suggest("car") == expected.suggest("car") + assert restored.get_config().model_dump() == expected.get_config().model_dump() + + def test_suggest_returns_empty_for_short_or_non_string_query(small_corpus): """Return no suggestions for short or non-string queries.""" s = SAYTSuggester(small_corpus, min_chars=4, retrievers=[PrefixRetrieverSpec()]) diff --git a/tests/sayt/test_sayt_builder.py b/tests/sayt/test_sayt_builder.py new file mode 100644 index 0000000..b6c8be1 --- /dev/null +++ b/tests/sayt/test_sayt_builder.py @@ -0,0 +1,263 @@ +"""Tests for SAYT artifact building and loading.""" + +# pylint: disable=too-few-public-methods,missing-function-docstring,too-many-arguments,duplicate-code + +import csv +import json +from pathlib import Path + +from industrial_classification_utils.sayt import ( + NgramRetrieverSpec, + PrefixRetrieverSpec, + RetrieverArtifactHandler, + SAYTBuilder, + register_retriever_artifact_handler, + unregister_retriever_artifact_handler, +) +from industrial_classification_utils.sayt.sayt import SAYTSuggester +from industrial_classification_utils.sayt.sayt_core import CleanCorpus, Suggestion + + +class _CustomRetriever: + def __init__(self, row, *, trigger: str, min_chars: int): + self._row = row + self._trigger = trigger + self._min_chars = min_chars + + def suggest_with_scores(self, q_norm, num_suggestions): + _ = num_suggestions + if len(q_norm) < self._min_chars or self._trigger not in q_norm: + return [] + return [ + Suggestion( + display_text=self._row[2], + score=1.0, + search_text=self._row[1], + row_id=self._row[0], + ) + ] + + +class _CustomRetrieverSpec: + def __init__(self, *, trigger: str, weight: float = 1.0): + self.trigger = trigger + self.weight = weight + self.name = "custom-trigger" + + def build(self, corpus, *, min_chars): + return _CustomRetriever( + corpus.rows[-1], trigger=self.trigger, min_chars=min_chars + ) + + +class _CustomRetrieverArtifactHandlerImpl: + artifact_type = "custom-trigger" + + def can_handle(self, spec): + return isinstance(spec, _CustomRetrieverSpec) + + def serialise_spec(self, spec): + return {"trigger": spec.trigger} + + def deserialise_spec(self, *, weight, config): + return _CustomRetrieverSpec(trigger=str(config["trigger"]), weight=weight) + + def default_path(self, *, index, spec): + _ = (index, spec) + + def build_artifact(self, *, spec, corpus, path): + _ = (spec, corpus, path) + + def load_retriever(self, *, spec, corpus, config, path): + _ = (config, path) + return spec.build(corpus, min_chars=3) + + +def test_builder_writes_manifest_and_corpus(tmp_path, small_corpus): + """Persist manifest metadata and cleaned corpus rows for an artifact.""" + artifact_dir = tmp_path / "artifact" + + result = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ).build_artifact(artifact_dir) + + manifest = json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) + with open(artifact_dir / "corpus.csv", encoding="utf-8") as corpus_file: + rows = list(csv.DictReader(corpus_file)) + + assert result == artifact_dir + assert manifest == { + "artifact_type": "sayt", + "artifact_version": 1, + "config": {"min_chars": 3, "max_suggestions": 5}, + "corpus_file": "corpus.csv", + "corpus_size": len(small_corpus), + "retrievers": [ + {"type": "prefix", "weight": 1.0, "path": None, "config": {}}, + ], + } + assert rows == [ + { + "row_id": row_id, + "search_text": search_text, + "display_text": display_text, + } + for row_id, search_text, display_text in CleanCorpus.model_validate( + small_corpus + ).rows + ] + + +def test_builder_writes_ngram_filespace(monkeypatch, tmp_path, small_corpus): + """Persist the configured dense retriever filespace inside the artifact.""" + captured = {} + artifact_dir = tmp_path / "artifact" + + class _StubPersistentVectorStore: + def __init__( # noqa: PLR0913 + self, + *, + file_name, + data_type, + vectoriser, + batch_size, + output_dir, + overwrite, + hooks, + ): + captured["file_name"] = file_name + captured["data_type"] = data_type + captured["vectoriser_type"] = type(vectoriser).__name__ + captured["batch_size"] = batch_size + captured["output_dir"] = output_dir + captured["overwrite"] = overwrite + captured["hooks"] = hooks + Path(output_dir).mkdir(parents=True, exist_ok=True) + Path(output_dir, "metadata.json").write_text("{}", encoding="utf-8") + Path(output_dir, "vectors.parquet").write_text("dummy", encoding="utf-8") + self.num_vectors = 1 + + monkeypatch.setattr( + "industrial_classification_utils.sayt.sayt_indexes.VectorStore", + _StubPersistentVectorStore, + ) + + SAYTBuilder( + small_corpus, + retrievers=[NgramRetrieverSpec(max_df=1.0)], + min_chars=3, + ).build_artifact(artifact_dir) + + manifest = json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) + filespace_path = artifact_dir / manifest["retrievers"][0]["path"] + + assert captured["output_dir"] == str(filespace_path) + assert (filespace_path / "metadata.json").exists() + assert (filespace_path / "vectors.parquet").exists() + + +def test_from_artifact_loads_persisted_ngram_filespace( + monkeypatch, tmp_path, small_corpus +): + """Load persisted dense retrievers from their artifact filespaces.""" + captured = {} + artifact_dir = tmp_path / "artifact" + target_row_id, _, target_display = CleanCorpus.model_validate(small_corpus).rows[-1] + + class _StubPersistentVectorStore: + def __init__( # noqa: PLR0913 + self, + *, + file_name, + data_type, + vectoriser, + batch_size, + output_dir, + overwrite, + hooks, + ): + _ = (file_name, data_type, vectoriser, batch_size, overwrite, hooks) + Path(output_dir).mkdir(parents=True, exist_ok=True) + Path(output_dir, "metadata.json").write_text("{}", encoding="utf-8") + Path(output_dir, "vectors.parquet").write_text("dummy", encoding="utf-8") + self.num_vectors = 1 + + @classmethod + def from_filespace(cls, *, folder_path, vectoriser, hooks): + captured["folder_path"] = folder_path + captured["vectoriser_type"] = type(vectoriser).__name__ + captured["hooks"] = hooks + return _StubLoadedVectorStore() + + class _StubSearchResults: + def to_dict(self, orient="records"): + assert orient == "records" + return [{"doc_label": target_row_id, "score": 1.0}] + + class _StubLoadedVectorStore: + num_vectors = 1 + + def search(self, query, n_results=10): + _ = query + captured["n_results"] = n_results + return _StubSearchResults() + + monkeypatch.setattr( + "industrial_classification_utils.sayt.sayt_indexes.VectorStore", + _StubPersistentVectorStore, + ) + + builder = SAYTBuilder( + small_corpus, + retrievers=[NgramRetrieverSpec(max_df=1.0)], + min_chars=3, + ) + builder.build_artifact(artifact_dir) + + suggester = SAYTSuggester.from_artifact(artifact_dir) + manifest = json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) + + assert suggester.suggest("groom") == [target_display] + assert captured == { + "folder_path": str(artifact_dir / manifest["retrievers"][0]["path"]), + "vectoriser_type": "_CharNgramVectoriser", + "hooks": None, + "n_results": 1, + } + + +def test_custom_retriever_artifact_handler_round_trips(tmp_path, small_corpus): + """Allow custom retriever specs to participate in artifact build and load.""" + artifact_dir = tmp_path / "artifact" + handler: RetrieverArtifactHandler = _CustomRetrieverArtifactHandlerImpl() + register_retriever_artifact_handler(handler) + try: + spec = _CustomRetrieverSpec(trigger="groom", weight=1.5) + builder = SAYTBuilder( + small_corpus, + retrievers=[spec], + min_chars=3, + max_suggestions=4, + ) + + builder.build_artifact(artifact_dir) + + manifest = json.loads( + (artifact_dir / "manifest.json").read_text(encoding="utf-8") + ) + suggester = SAYTSuggester.from_artifact(artifact_dir) + + assert manifest["retrievers"] == [ + { + "type": "custom-trigger", + "weight": 1.5, + "path": None, + "config": {"trigger": "groom"}, + } + ] + assert suggester.suggest("groom") == ["Dog grooming"] + finally: + unregister_retriever_artifact_handler("custom-trigger") diff --git a/tests/sayt/test_sayt_retrievers.py b/tests/sayt/test_sayt_retrievers.py index 43d0642..14f3fb6 100644 --- a/tests/sayt/test_sayt_retrievers.py +++ b/tests/sayt/test_sayt_retrievers.py @@ -3,6 +3,10 @@ # ruff: noqa: PLR2004 # pylint: disable=protected-access,redefined-outer-name,too-few-public-methods,C0116,W0613 +import csv +import shutil +from pathlib import Path + import numpy as np import pytest from classifai.vectorisers import VectoriserBase @@ -233,6 +237,108 @@ def test_dense_retriever_keeps_ties_at_cutoff(small_corpus): ] +def test_dense_vector_index_builds_persistent_filespace( + monkeypatch, tmp_path, small_corpus +): + """Persist dense indexes even when the output folder is replaced first.""" + captured = {} + corpus = CleanCorpus.model_validate(small_corpus) + output_dir = tmp_path / "ngram" + output_dir.mkdir() + + class _StubPersistentVectorStore: + # pylint: disable=too-many-arguments + def __init__( # noqa: PLR0913 + self, + *, + file_name, + data_type, + vectoriser, + batch_size, + output_dir, + overwrite, + hooks, + ): + captured["file_name"] = file_name + captured["data_type"] = data_type + captured["vectoriser_type"] = type(vectoriser).__name__ + captured["batch_size"] = batch_size + captured["output_dir"] = output_dir + captured["overwrite"] = overwrite + captured["hooks"] = hooks + output_path = Path(output_dir) + if output_path.is_dir() and overwrite: + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + with open(file_name, encoding="utf-8") as input_file: + captured["rows"] = list(csv.DictReader(input_file)) + (output_path / "metadata.json").write_text("{}", encoding="utf-8") + (output_path / "vectors.parquet").write_text("dummy", encoding="utf-8") + self.num_vectors = len(captured["rows"]) + + monkeypatch.setattr( + "industrial_classification_utils.sayt.sayt_indexes.VectorStore", + _StubPersistentVectorStore, + ) + + index = DenseVectorIndex.from_corpus( + corpus=corpus, + vectoriser=_StubVectoriser(np.array([[1.0, 0.0]])), + output_dir=output_dir, + overwrite=True, + ) + + assert index._num_vectors == len(corpus.rows) + assert Path(captured["file_name"]).name == "corpus.csv" + assert Path(captured["file_name"]).parent != output_dir + assert captured["data_type"] == "csv" + assert captured["vectoriser_type"] == "_StubVectoriser" + assert captured["batch_size"] == 64 + assert captured["output_dir"] == str(output_dir) + assert captured["overwrite"] is True + assert captured["hooks"] is None + assert captured["rows"] == [ + {"label": row_id, "text": search_text} for row_id, search_text, _ in corpus.rows + ] + + +def test_dense_vector_index_loads_existing_filespace( + monkeypatch, tmp_path, small_corpus +): + """Load a persisted dense index via ClassifAI's filespace API.""" + captured = {} + corpus = CleanCorpus.model_validate(small_corpus) + folder_path = tmp_path / "existing-ngram" + + class _StubLoadedVectorStore: + num_vectors = 7 + + def _fake_from_filespace(*, folder_path, vectoriser, hooks): + captured["folder_path"] = folder_path + captured["vectoriser_type"] = type(vectoriser).__name__ + captured["hooks"] = hooks + return _StubLoadedVectorStore() + + monkeypatch.setattr( + "industrial_classification_utils.sayt.sayt_indexes.VectorStore.from_filespace", + _fake_from_filespace, + ) + + index = DenseVectorIndex.from_filespace( + corpus=corpus, + folder_path=folder_path, + vectoriser=_StubVectoriser(np.array([[1.0, 0.0]])), + ) + + assert index._num_vectors == 7 + assert index._corpus is corpus + assert captured == { + "folder_path": str(folder_path), + "vectoriser_type": "_StubVectoriser", + "hooks": None, + } + + def test_semantic_retriever_builds_index_with_wrapped_vectoriser( monkeypatch, small_corpus ): @@ -247,8 +353,16 @@ def __init__(self, model_name): def transform(self, texts): return np.array([[1.0, 0.0]]) - def _fake_build_dense_vector_index(*, corpus, vectoriser): + def _fake_build_dense_vector_index( + *, + corpus, + vectoriser, + output_dir=None, + overwrite=True, + ): captured["vectoriser_type"] = type(vectoriser).__name__ + captured["output_dir"] = output_dir + captured["overwrite"] = overwrite return DenseVectorIndex( _vector_store=_StubVectorStore([]), _num_vectors=1, @@ -269,6 +383,8 @@ def _fake_build_dense_vector_index(*, corpus, vectoriser): assert captured == { "model_name": "sentence-transformers/all-MiniLM-L6-v2", "vectoriser_type": "_L2NormalisingVectoriser", + "output_dir": None, + "overwrite": True, } assert retriever._min_chars == 3 From adc75007ea8f3fe36165bfc252fd09b7d48a5bea Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Tue, 9 Jun 2026 12:28:49 +0100 Subject: [PATCH 02/11] chore: remove unused ruff noqa --- tests/sayt/test_sayt_retrievers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/sayt/test_sayt_retrievers.py b/tests/sayt/test_sayt_retrievers.py index 14f3fb6..6b447ec 100644 --- a/tests/sayt/test_sayt_retrievers.py +++ b/tests/sayt/test_sayt_retrievers.py @@ -1,6 +1,5 @@ """Tests for SAYT retrieval and ranking behavior.""" -# ruff: noqa: PLR2004 # pylint: disable=protected-access,redefined-outer-name,too-few-public-methods,C0116,W0613 import csv From f5c859cc63675fa6e2782a905695501ca0c49ab2 Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Tue, 9 Jun 2026 13:13:01 +0100 Subject: [PATCH 03/11] refactor(sayt): remove SAYTConfig --- .../sayt/sayt.py | 76 +++++++++++-------- .../sayt/sayt_builder.py | 25 ++++-- .../sayt/sayt_core.py | 43 ++++++----- .../sayt/sayt_retriever_specs.py | 4 +- .../sayt/sayt_storage.py | 40 ++++++---- tests/sayt/test_sayt.py | 7 +- tests/sayt/test_sayt_builder.py | 11 +-- tests/sayt/test_sayt_config.py | 60 +++++++++++---- 8 files changed, 170 insertions(+), 96 deletions(-) diff --git a/src/industrial_classification_utils/sayt/sayt.py b/src/industrial_classification_utils/sayt/sayt.py index aa108ca..e3fa74a 100644 --- a/src/industrial_classification_utils/sayt/sayt.py +++ b/src/industrial_classification_utils/sayt/sayt.py @@ -4,6 +4,8 @@ retrievers and combines their scores into ranked suggestions. """ +# pylint: disable=duplicate-code + import math import os from collections.abc import Iterable, Sequence @@ -14,10 +16,11 @@ from .sayt_core import ( CleanCorpus, - SaytConfig, Suggestion, _normalise, take_with_ties, + validate_max_suggestions, + validate_min_chars, ) from .sayt_retriever_specs import ( Retriever, @@ -55,8 +58,8 @@ class SAYTSuggester: By default it uses the standard prefix, n-gram, and semantic retriever specifications. Use ``retrievers=`` to override that mix. - Suggester-wide settings are currently passed as keyword arguments and - validated by ``SaytConfig``. At present these include: + Suggester-wide settings are configured directly on the suggester. At + present these include: - ``min_chars``: minimum query length before retrieval runs - ``max_suggestions``: default maximum number of ranked suggestions to return @@ -95,18 +98,20 @@ class SAYTSuggester: """ @classmethod - def _from_state( + def _from_state( # pylint: disable=too-many-arguments cls, *, corpus: CleanCorpus, - config: SaytConfig, + min_chars: int, + max_suggestions: int, retriever_specs: Sequence[RetrieverSpec], retrievers: list[_ConfiguredRetriever], ) -> "SAYTSuggester": """Construct a suggester from already-validated runtime state.""" suggester = cls.__new__(cls) suggester._corpus = corpus - suggester._config = config + suggester._min_chars = min_chars + suggester._max_suggestions = max_suggestions suggester._max_duplication = max(corpus.display_text_count.values(), default=0) suggester._retriever_specs = tuple(retriever_specs) suggester._retrievers = retrievers @@ -118,7 +123,8 @@ def __init__( corpus: Iterable[tuple[object, object]] | Iterable[str], *, retrievers: Sequence[RetrieverSpec] | None = None, - **kwargs: object, + min_chars: int = 4, + max_suggestions: int = 10, ) -> None: """Initialise a suggester for a cleaned response corpus. @@ -127,12 +133,13 @@ def __init__( pairs. retrievers: Optional retriever specifications. When omitted, the standard prefix, n-gram, and semantic spec set is used. - **kwargs: Suggester-wide keyword arguments validated by - ``SaytConfig``, currently including ``min_chars`` and - ``max_suggestions``. + min_chars: Minimum query length before retrieval runs. + max_suggestions: Default maximum number of ranked suggestions to + return. """ self._corpus = CleanCorpus.model_validate(corpus) - self._config = SaytConfig.model_validate(kwargs) + self._min_chars = validate_min_chars(min_chars) + self._max_suggestions = validate_max_suggestions(max_suggestions) self._retriever_specs = tuple( default_retriever_specs() if retrievers is None else retrievers @@ -170,21 +177,22 @@ def _build_retrievers( weight=weight, retriever=spec.build( self._corpus, - min_chars=self._config.min_chars, + min_chars=self._min_chars, ), ) for spec, weight in self._normalised_retriever_specs(retriever_specs) ] @classmethod - def from_csv( + def from_csv( # pylint: disable=too-many-arguments # noqa: PLR0913 cls, file_path: str | os.PathLike, *, search_text_col: str = "title", display_text_col: str | None = None, retrievers: Sequence[RetrieverSpec] | None = None, - **kwargs: object, + min_chars: int = 4, + max_suggestions: int = 10, ) -> "SAYTSuggester": """Build a suggester from CSV input. @@ -195,7 +203,9 @@ def from_csv( omitted, the search column is reused for display values. retrievers: Optional retriever specifications. When omitted, the standard retriever set is used. - **kwargs: Keyword arguments validated by ``SaytConfig``. + min_chars: Minimum query length before retrieval runs. + max_suggestions: Default maximum number of ranked suggestions to + return. Returns: A configured ``SAYTSuggester`` instance. @@ -211,7 +221,8 @@ def from_csv( return cls( corpus_rows, retrievers=retrievers, - **kwargs, + min_chars=min_chars, + max_suggestions=max_suggestions, ) @classmethod @@ -229,13 +240,14 @@ def from_artifact(cls, artifact_dir: str | os.PathLike) -> "SAYTSuggester": retrievers = cls._load_retrievers_from_artifact( corpus=corpus, - config=manifest.config, + min_chars=manifest.min_chars, stored_retrievers=manifest.retrievers, artifact_dir=artifact_path, ) return cls._from_state( corpus=corpus, - config=manifest.config, + min_chars=manifest.min_chars, + max_suggestions=manifest.max_suggestions, retriever_specs=[ stored_retriever.spec for stored_retriever in manifest.retrievers ], @@ -247,7 +259,7 @@ def _load_retrievers_from_artifact( cls, *, corpus: CleanCorpus, - config: SaytConfig, + min_chars: int, stored_retrievers: Sequence[StoredRetrieverSpec], artifact_dir: Path, ) -> list[_ConfiguredRetriever]: @@ -261,7 +273,7 @@ def _load_retrievers_from_artifact( weight=weight, retriever=cls._load_retriever_from_artifact( corpus=corpus, - config=config, + min_chars=min_chars, stored_retriever=stored_retriever, artifact_dir=artifact_dir, ), @@ -277,14 +289,14 @@ def _load_retrievers_from_artifact( def _load_retriever_from_artifact( *, corpus: CleanCorpus, - config: SaytConfig, + min_chars: int, stored_retriever: StoredRetrieverSpec, artifact_dir: Path, ) -> Retriever: """Restore a runtime retriever from persisted artifact state.""" return load_retriever_from_artifact( corpus=corpus, - config=config, + min_chars=min_chars, stored_retriever=stored_retriever, artifact_dir=artifact_dir, ) @@ -367,12 +379,12 @@ def suggest_with_scores( Returns: A list of combined suggestions ordered by descending score. Returns an empty list when the normalised query is shorter than - ``SaytConfig.min_chars``. + ``min_chars``. """ if num_suggestions is None: - num_suggestions = self._config.max_suggestions + num_suggestions = self._max_suggestions q_norm = _normalise(query) - if len(q_norm) < self._config.min_chars: + if len(q_norm) < self._min_chars: return [] # Ask for more suggestions, as some may be filtered out after deduplication @@ -410,7 +422,7 @@ def suggest( score, while preserving ties at the cutoff. """ if num_suggestions is None: - num_suggestions = self._config.max_suggestions + num_suggestions = self._max_suggestions results = self.suggest_with_scores( query, num_suggestions=num_suggestions * self._max_duplication ) @@ -418,10 +430,14 @@ def suggest( ranked_results = take_with_ties(dedup_results, num_suggestions) return [result[0] for result in ranked_results] - def get_config(self) -> SaytConfig: - """Return a copy of the validated suggester configuration. + def get_config(self) -> dict[str, int]: + """Return the validated global suggester settings. Returns: - A deep copy of the ``SaytConfig`` used by this suggester. + A copy of the current global ``min_chars`` and ``max_suggestions`` + settings. """ - return self._config.model_copy(deep=True) + return { + "min_chars": self._min_chars, + "max_suggestions": self._max_suggestions, + } diff --git a/src/industrial_classification_utils/sayt/sayt_builder.py b/src/industrial_classification_utils/sayt/sayt_builder.py index 367680d..a416603 100644 --- a/src/industrial_classification_utils/sayt/sayt_builder.py +++ b/src/industrial_classification_utils/sayt/sayt_builder.py @@ -1,10 +1,12 @@ """Offline artifact builder for persisted SAYT runtime assets.""" +# pylint: disable=duplicate-code + import os from collections.abc import Iterable, Sequence from pathlib import Path -from .sayt_core import CleanCorpus, SaytConfig +from .sayt_core import CleanCorpus, validate_max_suggestions, validate_min_chars from .sayt_retriever_specs import ( RetrieverSpec, default_retriever_specs, @@ -27,24 +29,27 @@ def __init__( corpus: Iterable[tuple[object, object]] | Iterable[str], *, retrievers: Sequence[RetrieverSpec] | None = None, - **kwargs: object, + min_chars: int = 4, + max_suggestions: int = 10, ) -> None: """Initialise an artifact builder from raw corpus input.""" self._corpus = CleanCorpus.model_validate(corpus) - self._config = SaytConfig.model_validate(kwargs) + self._min_chars = validate_min_chars(min_chars) + self._max_suggestions = validate_max_suggestions(max_suggestions) self._retriever_specs = tuple( default_retriever_specs() if retrievers is None else retrievers ) @classmethod - def from_csv( + def from_csv( # pylint: disable=too-many-arguments # noqa: PLR0913 cls, file_path: str | os.PathLike, *, search_text_col: str = "title", display_text_col: str | None = None, retrievers: Sequence[RetrieverSpec] | None = None, - **kwargs: object, + min_chars: int = 4, + max_suggestions: int = 10, ) -> "SAYTBuilder": """Initialise an artifact builder from CSV input.""" corpus_rows = load_corpus_from_csv( @@ -52,7 +57,12 @@ def from_csv( search_text_col=search_text_col, display_text_col=display_text_col, ) - return cls(corpus_rows, retrievers=retrievers, **kwargs) + return cls( + corpus_rows, + retrievers=retrievers, + min_chars=min_chars, + max_suggestions=max_suggestions, + ) def build_artifact( self, @@ -64,7 +74,8 @@ def build_artifact( artifact_dir = prepare_artifact_dir(output_dir, overwrite=overwrite) manifest = build_artifact_manifest( corpus=self._corpus, - config=self._config, + min_chars=self._min_chars, + max_suggestions=self._max_suggestions, retriever_specs=self._retriever_specs, ) diff --git a/src/industrial_classification_utils/sayt/sayt_core.py b/src/industrial_classification_utils/sayt/sayt_core.py index 64a77e2..dcccbc0 100644 --- a/src/industrial_classification_utils/sayt/sayt_core.py +++ b/src/industrial_classification_utils/sayt/sayt_core.py @@ -158,26 +158,29 @@ def _clean_corpus( ] -class SaytConfig(BaseModel): - """Validated configuration for a SAYT suggester instance. - - This model contains only suggester-wide settings. Retriever-specific - configuration lives on individual ``RetrieverSpec`` objects. - """ - - model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) - - min_chars: int = 4 - max_suggestions: int = 10 - - @model_validator(mode="after") - def _validate_ranges(self) -> "SaytConfig": - """Enforce the supported numeric ranges for SAYT settings.""" - if self.min_chars < 3: - raise ValueError("min_chars must be >= 3") - if not 1 <= self.max_suggestions <= 100: - raise ValueError("max_suggestions must be between 1 and 100") - return self +def _coerce_sayt_int_setting(value: object, *, field_name: str) -> int: + if isinstance(value, bool) or not isinstance(value, int | str): + raise TypeError(f"{field_name} must be an integer") + try: + return int(value) + except ValueError as exc: + raise TypeError(f"{field_name} must be an integer") from exc + + +def validate_min_chars(value: object) -> int: + """Validate the global SAYT minimum query length setting.""" + min_chars = _coerce_sayt_int_setting(value, field_name="min_chars") + if min_chars < 3: + raise ValueError("min_chars must be >= 3") + return min_chars + + +def validate_max_suggestions(value: object) -> int: + """Validate the global SAYT maximum suggestion count setting.""" + max_suggestions = _coerce_sayt_int_setting(value, field_name="max_suggestions") + if not 1 <= max_suggestions <= 100: + raise ValueError("max_suggestions must be between 1 and 100") + return max_suggestions @dataclass(frozen=True, slots=True) diff --git a/src/industrial_classification_utils/sayt/sayt_retriever_specs.py b/src/industrial_classification_utils/sayt/sayt_retriever_specs.py index e760d8a..955b803 100644 --- a/src/industrial_classification_utils/sayt/sayt_retriever_specs.py +++ b/src/industrial_classification_utils/sayt/sayt_retriever_specs.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Protocol -from .sayt_core import CleanCorpus, SaytConfig, Suggestion +from .sayt_core import CleanCorpus, Suggestion from .sayt_retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever _MIN_NGRAM_SIZE = 2 @@ -98,7 +98,7 @@ def load_retriever( *, spec: RetrieverSpec, corpus: CleanCorpus, - config: SaytConfig, + min_chars: int, path: Path | None, ) -> Retriever: """Restore a runtime retriever from persisted artifact state.""" diff --git a/src/industrial_classification_utils/sayt/sayt_storage.py b/src/industrial_classification_utils/sayt/sayt_storage.py index 62ec41a..87f95c8 100644 --- a/src/industrial_classification_utils/sayt/sayt_storage.py +++ b/src/industrial_classification_utils/sayt/sayt_storage.py @@ -11,7 +11,12 @@ import pandas as pd -from .sayt_core import CleanCorpus, PersistedCorpusRow, SaytConfig +from .sayt_core import ( + CleanCorpus, + PersistedCorpusRow, + validate_max_suggestions, + validate_min_chars, +) from .sayt_indexes import ( build_ngram_index, build_semantic_index, @@ -29,7 +34,7 @@ from .sayt_retrievers import NgramRetriever, SemanticRetriever SAYT_ARTIFACT_TYPE = "sayt" -SAYT_ARTIFACT_VERSION = 1 +SAYT_ARTIFACT_VERSION = 2 MANIFEST_FILE_NAME = "manifest.json" CORPUS_FILE_NAME = "corpus.csv" _ARTIFACT_CORPUS_FIELDS = ["row_id", "search_text", "display_text"] @@ -53,7 +58,8 @@ class StoredRetrieverSpec: class SaytArtifactManifest: """Structured manifest data for a persisted SAYT artifact.""" - config: SaytConfig + min_chars: int + max_suggestions: int corpus_file: str corpus_size: int retrievers: tuple[StoredRetrieverSpec, ...] @@ -149,12 +155,14 @@ def read_artifact_corpus( def build_artifact_manifest( *, corpus: CleanCorpus, - config: SaytConfig, + min_chars: int, + max_suggestions: int, retriever_specs: tuple[RetrieverSpec, ...], ) -> SaytArtifactManifest: """Build the structured manifest payload for a SAYT artifact.""" return SaytArtifactManifest( - config=config.model_copy(deep=True), + min_chars=min_chars, + max_suggestions=max_suggestions, corpus_file=CORPUS_FILE_NAME, corpus_size=corpus.size, retrievers=tuple( @@ -192,7 +200,8 @@ def read_artifact_manifest(*, artifact_dir: str | Path) -> SaytArtifactManifest: try: return SaytArtifactManifest( - config=SaytConfig.model_validate(payload["config"]), + min_chars=validate_min_chars(payload["min_chars"]), + max_suggestions=validate_max_suggestions(payload["max_suggestions"]), corpus_file=str(payload["corpus_file"]), corpus_size=int(payload["corpus_size"]), retrievers=tuple( @@ -257,7 +266,7 @@ def build_retriever_artifact( def load_retriever_from_artifact( *, corpus: CleanCorpus, - config: SaytConfig, + min_chars: int, stored_retriever: StoredRetrieverSpec, artifact_dir: str | Path, ) -> Retriever: @@ -271,7 +280,7 @@ def load_retriever_from_artifact( return handler.load_retriever( spec=stored_retriever.spec, corpus=corpus, - config=config, + min_chars=min_chars, path=path, ) @@ -293,7 +302,8 @@ def _serialise_manifest(manifest: SaytArtifactManifest) -> dict[str, object]: return { "artifact_type": SAYT_ARTIFACT_TYPE, "artifact_version": SAYT_ARTIFACT_VERSION, - "config": manifest.config.model_dump(mode="json"), + "min_chars": manifest.min_chars, + "max_suggestions": manifest.max_suggestions, "corpus_file": manifest.corpus_file, "corpus_size": manifest.corpus_size, "retrievers": [ @@ -390,11 +400,11 @@ def load_retriever( *, spec: RetrieverSpec, corpus: CleanCorpus, - config: SaytConfig, + min_chars: int, path: Path | None, ) -> Retriever: _ = path - return spec.build(corpus, min_chars=config.min_chars) + return spec.build(corpus, min_chars=min_chars) class _NgramRetrieverArtifactHandler: # pylint: disable=missing-function-docstring @@ -445,7 +455,7 @@ def load_retriever( *, spec: RetrieverSpec, corpus: CleanCorpus, - config: SaytConfig, + min_chars: int, path: Path | None, ) -> Retriever: typed_spec = _require_spec_type(spec, NgramRetrieverSpec) @@ -457,7 +467,7 @@ def load_retriever( ) return NgramRetriever.from_index( corpus, - min_chars=config.min_chars, + min_chars=min_chars, index=index, ) @@ -508,7 +518,7 @@ def load_retriever( *, spec: RetrieverSpec, corpus: CleanCorpus, - config: SaytConfig, + min_chars: int, path: Path | None, ) -> Retriever: typed_spec = _require_spec_type(spec, SemanticRetrieverSpec) @@ -519,7 +529,7 @@ def load_retriever( ) return SemanticRetriever.from_index( corpus, - min_chars=config.min_chars, + min_chars=min_chars, index=index, ) diff --git a/tests/sayt/test_sayt.py b/tests/sayt/test_sayt.py index 3280f5d..a8c4f2b 100644 --- a/tests/sayt/test_sayt.py +++ b/tests/sayt/test_sayt.py @@ -7,7 +7,6 @@ import pandas as pd import pytest -from pydantic import ValidationError from industrial_classification_utils.sayt import ( PrefixRetrieverSpec, @@ -23,8 +22,8 @@ def test_constructor_rejects_unknown_kwargs(small_corpus): """Reject unknown constructor kwargs during config validation.""" - with pytest.raises(ValidationError, match="Extra inputs are not permitted"): - SAYTSuggester(small_corpus, does_not_exist=True) + with pytest.raises(TypeError, match="unexpected keyword argument"): + SAYTSuggester(small_corpus, does_not_exist=True) # pylint: disable=E1123 def test_empty_corpus_after_filtering_raises(): @@ -169,7 +168,7 @@ def test_from_artifact_restores_prefix_suggester(tmp_path, small_corpus): ) assert restored.suggest("car") == expected.suggest("car") - assert restored.get_config().model_dump() == expected.get_config().model_dump() + assert restored.get_config() == expected.get_config() def test_suggest_returns_empty_for_short_or_non_string_query(small_corpus): diff --git a/tests/sayt/test_sayt_builder.py b/tests/sayt/test_sayt_builder.py index b6c8be1..bf51895 100644 --- a/tests/sayt/test_sayt_builder.py +++ b/tests/sayt/test_sayt_builder.py @@ -68,9 +68,9 @@ def default_path(self, *, index, spec): def build_artifact(self, *, spec, corpus, path): _ = (spec, corpus, path) - def load_retriever(self, *, spec, corpus, config, path): - _ = (config, path) - return spec.build(corpus, min_chars=3) + def load_retriever(self, *, spec, corpus, min_chars, path): + _ = path + return spec.build(corpus, min_chars=min_chars) def test_builder_writes_manifest_and_corpus(tmp_path, small_corpus): @@ -91,8 +91,9 @@ def test_builder_writes_manifest_and_corpus(tmp_path, small_corpus): assert result == artifact_dir assert manifest == { "artifact_type": "sayt", - "artifact_version": 1, - "config": {"min_chars": 3, "max_suggestions": 5}, + "artifact_version": 2, + "min_chars": 3, + "max_suggestions": 5, "corpus_file": "corpus.csv", "corpus_size": len(small_corpus), "retrievers": [ diff --git a/tests/sayt/test_sayt_config.py b/tests/sayt/test_sayt_config.py index 229e7c7..23843c8 100644 --- a/tests/sayt/test_sayt_config.py +++ b/tests/sayt/test_sayt_config.py @@ -3,31 +3,65 @@ # pylint: disable=too-few-public-methods, R0801 import pytest -from pydantic import ValidationError from industrial_classification_utils.sayt import ( NgramRetrieverSpec, PrefixRetrieverSpec, + SAYTBuilder, + SAYTSuggester, SemanticRetrieverSpec, default_retriever_specs, ) -from industrial_classification_utils.sayt.sayt_core import CleanCorpus, SaytConfig +from industrial_classification_utils.sayt.sayt_core import CleanCorpus @pytest.mark.parametrize( - "kwargs, exc_type", + "factory, kwargs, exc_type, match", [ - ({"min_chars": 2}, ValidationError), - ({"min_chars": True}, ValidationError), - ({"max_suggestions": 0}, ValidationError), - ({"max_suggestions": 101}, ValidationError), - ({"ngram_enable": True}, ValidationError), + (SAYTSuggester, {"min_chars": 2}, ValueError, "min_chars must be >= 3"), + ( + SAYTSuggester, + {"min_chars": True}, + TypeError, + "min_chars must be an integer", + ), + ( + SAYTSuggester, + {"max_suggestions": 0}, + ValueError, + "max_suggestions must be between 1 and 100", + ), + ( + SAYTSuggester, + {"max_suggestions": 101}, + ValueError, + "max_suggestions must be between 1 and 100", + ), + (SAYTBuilder, {"min_chars": 2}, ValueError, "min_chars must be >= 3"), + ( + SAYTBuilder, + {"min_chars": True}, + TypeError, + "min_chars must be an integer", + ), + ( + SAYTBuilder, + {"max_suggestions": 0}, + ValueError, + "max_suggestions must be between 1 and 100", + ), + ( + SAYTBuilder, + {"max_suggestions": 101}, + ValueError, + "max_suggestions must be between 1 and 100", + ), ], ) -def test_config_validation(kwargs, exc_type): - """Reject unsupported SAYT config values and types.""" - with pytest.raises(exc_type): - SaytConfig.model_validate(kwargs) +def test_runtime_setting_validation(factory, kwargs, exc_type, match): + """Reject unsupported global SAYT settings on public entry points.""" + with pytest.raises(exc_type, match=match): + factory([("car wash", "Car Wash")], **kwargs) def test_default_retriever_specs_returns_standard_set(): @@ -86,7 +120,7 @@ def test_ngram_retriever_spec_validates_against_corpus_size(): def test_retriever_specs_keep_their_config(): - """Expose per-retriever settings on the spec object rather than SaytConfig.""" + """Expose per-retriever settings on the spec object.""" n = 4 max_df = 0.8 spec = NgramRetrieverSpec(weight=2.0, n=n, max_df=max_df) From 9aa864e1bf9dc6930195b427fce5a6d4389be29f Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Tue, 9 Jun 2026 14:54:30 +0100 Subject: [PATCH 04/11] feat(sayt): export full config in get_config() method --- .../sayt/__init__.py | 2 + .../sayt/sayt.py | 334 ++++++++++++------ .../sayt/sayt_core.py | 68 +++- tests/sayt/test_sayt.py | 80 ++++- 4 files changed, 378 insertions(+), 106 deletions(-) diff --git a/src/industrial_classification_utils/sayt/__init__.py b/src/industrial_classification_utils/sayt/__init__.py index 07f2bdf..2bbed3f 100644 --- a/src/industrial_classification_utils/sayt/__init__.py +++ b/src/industrial_classification_utils/sayt/__init__.py @@ -2,6 +2,7 @@ from .sayt import SAYTSuggester from .sayt_builder import SAYTBuilder +from .sayt_core import SaytConfiguration from .sayt_retriever_specs import ( NgramRetrieverSpec, PrefixRetrieverSpec, @@ -27,6 +28,7 @@ "RetrieverSpec", "SAYTBuilder", "SAYTSuggester", + "SaytConfiguration", "SemanticRetriever", "SemanticRetrieverSpec", "default_retriever_specs", diff --git a/src/industrial_classification_utils/sayt/sayt.py b/src/industrial_classification_utils/sayt/sayt.py index e3fa74a..8111db5 100644 --- a/src/industrial_classification_utils/sayt/sayt.py +++ b/src/industrial_classification_utils/sayt/sayt.py @@ -8,14 +8,21 @@ import math import os -from collections.abc import Iterable, Sequence -from dataclasses import dataclass +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass, fields, is_dataclass from pathlib import Path +from typing import Any from survey_assist_utils.logging import get_logger from .sayt_core import ( CleanCorpus, + SaytArtifactProvenance, + SaytConfiguration, + SaytCorpusSummary, + SaytGlobalSettings, + SaytRetrieverArtifactProvenance, + SaytRetrieverSummary, Suggestion, _normalise, take_with_ties, @@ -28,6 +35,8 @@ default_retriever_specs, ) from .sayt_storage import ( + SAYT_ARTIFACT_TYPE, + SAYT_ARTIFACT_VERSION, StoredRetrieverSpec, load_corpus_from_csv, load_retriever_from_artifact, @@ -47,7 +56,7 @@ class _ConfiguredRetriever: retriever: Retriever -class SAYTSuggester: +class SAYTSuggester: # pylint: disable=too-many-instance-attributes """Suggest free-text responses as a user types. The suggester: @@ -97,27 +106,6 @@ class SAYTSuggester: ``` """ - @classmethod - def _from_state( # pylint: disable=too-many-arguments - cls, - *, - corpus: CleanCorpus, - min_chars: int, - max_suggestions: int, - retriever_specs: Sequence[RetrieverSpec], - retrievers: list[_ConfiguredRetriever], - ) -> "SAYTSuggester": - """Construct a suggester from already-validated runtime state.""" - suggester = cls.__new__(cls) - suggester._corpus = corpus - suggester._min_chars = min_chars - suggester._max_suggestions = max_suggestions - suggester._max_duplication = max(corpus.display_text_count.values(), default=0) - suggester._retriever_specs = tuple(retriever_specs) - suggester._retrievers = retrievers - logger.info(f"SAYT suggester initialized with config: {suggester.get_config()}") - return suggester - def __init__( self, corpus: Iterable[tuple[object, object]] | Iterable[str], @@ -146,42 +134,42 @@ def __init__( ) self._retrievers = self._build_retrievers(self._retriever_specs) self._max_duplication = max(self._corpus.display_text_count.values(), default=0) - logger.info(f"SAYT suggester initialized with config: {self.get_config()}") + self._stored_retrievers: tuple[StoredRetrieverSpec, ...] | None = None + self._artifact_provenance: SaytArtifactProvenance | None = None + logger.info( + "SAYT suggester initialized", + config=self.get_config().model_dump(mode="json"), + ) - @staticmethod - def _normalised_retriever_specs( + @classmethod + def _from_state( # pylint: disable=too-many-arguments # noqa: PLR0913 + cls, + *, + corpus: CleanCorpus, + min_chars: int, + max_suggestions: int, retriever_specs: Sequence[RetrieverSpec], - ) -> list[tuple[RetrieverSpec, float]]: - """Validate and normalise configured retriever weights.""" - if not retriever_specs: - raise ValueError("At least one retriever must be configured") - - validated_specs: list[tuple[RetrieverSpec, float]] = [] - for spec in retriever_specs: - weight = float(spec.weight) - if not math.isfinite(weight) or weight <= 0: - raise ValueError( - f"Retriever '{spec.name}' weight must be a finite value > 0" - ) - validated_specs.append((spec, weight)) - - total_weight = sum(weight for _, weight in validated_specs) - return [(spec, weight / total_weight) for spec, weight in validated_specs] - - def _build_retrievers( - self, retriever_specs: Sequence[RetrieverSpec] - ) -> list[_ConfiguredRetriever]: - return [ - _ConfiguredRetriever( - name=spec.name, - weight=weight, - retriever=spec.build( - self._corpus, - min_chars=self._min_chars, - ), - ) - for spec, weight in self._normalised_retriever_specs(retriever_specs) - ] + retrievers: list[_ConfiguredRetriever], + stored_retrievers: Sequence[StoredRetrieverSpec] | None = None, + artifact_provenance: SaytArtifactProvenance | None = None, + ) -> "SAYTSuggester": + """Construct a suggester from already-validated runtime state.""" + suggester = cls.__new__(cls) + suggester._corpus = corpus + suggester._min_chars = min_chars + suggester._max_suggestions = max_suggestions + suggester._max_duplication = max(corpus.display_text_count.values(), default=0) + suggester._retriever_specs = tuple(retriever_specs) + suggester._retrievers = retrievers + suggester._stored_retrievers = ( + tuple(stored_retrievers) if stored_retrievers is not None else None + ) + suggester._artifact_provenance = artifact_provenance + logger.info( + "SAYT suggester initialized", + config=suggester.get_config().model_dump(mode="json"), + ) + return suggester @classmethod def from_csv( # pylint: disable=too-many-arguments # noqa: PLR0913 @@ -238,12 +226,19 @@ def from_artifact(cls, artifact_dir: str | os.PathLike) -> "SAYTSuggester": if corpus.size != manifest.corpus_size: raise ValueError("Artifact corpus size does not match manifest") - retrievers = cls._load_retrievers_from_artifact( + retrievers = _load_retrievers_from_artifact( corpus=corpus, min_chars=manifest.min_chars, stored_retrievers=manifest.retrievers, artifact_dir=artifact_path, ) + artifact_provenance = SaytArtifactProvenance( + artifact_dir=str(artifact_path), + artifact_type=SAYT_ARTIFACT_TYPE, + artifact_version=SAYT_ARTIFACT_VERSION, + corpus_file=manifest.corpus_file, + corpus_size=manifest.corpus_size, + ) return cls._from_state( corpus=corpus, min_chars=manifest.min_chars, @@ -252,55 +247,25 @@ def from_artifact(cls, artifact_dir: str | os.PathLike) -> "SAYTSuggester": stored_retriever.spec for stored_retriever in manifest.retrievers ], retrievers=retrievers, + stored_retrievers=manifest.retrievers, + artifact_provenance=artifact_provenance, ) - @classmethod - def _load_retrievers_from_artifact( - cls, - *, - corpus: CleanCorpus, - min_chars: int, - stored_retrievers: Sequence[StoredRetrieverSpec], - artifact_dir: Path, + def _build_retrievers( + self, retriever_specs: Sequence[RetrieverSpec] ) -> list[_ConfiguredRetriever]: - """Restore runtime retrievers from a persisted SAYT artifact.""" - normalised_specs = cls._normalised_retriever_specs( - [stored_retriever.spec for stored_retriever in stored_retrievers] - ) return [ _ConfiguredRetriever( - name=stored_retriever.spec.name, + name=spec.name, weight=weight, - retriever=cls._load_retriever_from_artifact( - corpus=corpus, - min_chars=min_chars, - stored_retriever=stored_retriever, - artifact_dir=artifact_dir, + retriever=spec.build( + self._corpus, + min_chars=self._min_chars, ), ) - for (_, weight), stored_retriever in zip( - normalised_specs, - stored_retrievers, - strict=True, - ) + for spec, weight in _normalised_retriever_specs(retriever_specs) ] - @staticmethod - def _load_retriever_from_artifact( - *, - corpus: CleanCorpus, - min_chars: int, - stored_retriever: StoredRetrieverSpec, - artifact_dir: Path, - ) -> Retriever: - """Restore a runtime retriever from persisted artifact state.""" - return load_retriever_from_artifact( - corpus=corpus, - min_chars=min_chars, - stored_retriever=stored_retriever, - artifact_dir=artifact_dir, - ) - def _dedup_suggestions( self, suggestions: list[Suggestion] ) -> list[tuple[str, float]]: @@ -430,14 +395,175 @@ def suggest( ranked_results = take_with_ties(dedup_results, num_suggestions) return [result[0] for result in ranked_results] - def get_config(self) -> dict[str, int]: - """Return the validated global suggester settings. + def get_config(self) -> SaytConfiguration: + """Return a rich runtime summary of this suggester. Returns: - A copy of the current global ``min_chars`` and ``max_suggestions`` - settings. + A summary of global settings, corpus details, retriever + configuration, and any artifact provenance available for this + suggester. """ + stored_retrievers: Sequence[StoredRetrieverSpec | None] + if self._stored_retrievers is None: + stored_retrievers = [None] * len(self._retriever_specs) + else: + stored_retrievers = list(self._stored_retrievers) + + retrievers = [ + _build_retriever_summary( + spec=spec, + configured_retriever=configured_retriever, + stored_retriever=stored_retriever, + ) + for spec, configured_retriever, stored_retriever in zip( + self._retriever_specs, + self._retrievers, + stored_retrievers, + strict=True, + ) + ] + + return SaytConfiguration( + settings=SaytGlobalSettings( + min_chars=self._min_chars, + max_suggestions=self._max_suggestions, + ), + corpus=SaytCorpusSummary( + size=self._corpus.size, + unique_display_texts=len(self._corpus.display_text_count), + max_duplication=self._max_duplication, + ), + retrievers=retrievers, + artifact_provenance=( + self._artifact_provenance.model_copy(deep=True) + if self._artifact_provenance is not None + else None + ), + ) + + +def _normalised_retriever_specs( + retriever_specs: Sequence[RetrieverSpec], +) -> list[tuple[RetrieverSpec, float]]: + """Validate and normalise configured retriever weights.""" + if not retriever_specs: + raise ValueError("At least one retriever must be configured") + + validated_specs: list[tuple[RetrieverSpec, float]] = [] + for spec in retriever_specs: + weight = float(spec.weight) + if not math.isfinite(weight) or weight <= 0: + raise ValueError( + f"Retriever '{spec.name}' weight must be a finite value > 0" + ) + validated_specs.append((spec, weight)) + + total_weight = sum(weight for _, weight in validated_specs) + return [(spec, weight / total_weight) for spec, weight in validated_specs] + + +def _restore_retriever_from_artifact( + *, + corpus: CleanCorpus, + min_chars: int, + stored_retriever: StoredRetrieverSpec, + artifact_dir: Path, +) -> Retriever: + """Restore a runtime retriever from persisted artifact state.""" + return load_retriever_from_artifact( + corpus=corpus, + min_chars=min_chars, + stored_retriever=stored_retriever, + artifact_dir=artifact_dir, + ) + + +def _load_retrievers_from_artifact( + *, + corpus: CleanCorpus, + min_chars: int, + stored_retrievers: Sequence[StoredRetrieverSpec], + artifact_dir: Path, +) -> list[_ConfiguredRetriever]: + """Restore runtime retrievers from a persisted SAYT artifact.""" + normalised_specs = _normalised_retriever_specs( + [stored_retriever.spec for stored_retriever in stored_retrievers] + ) + return [ + _ConfiguredRetriever( + name=stored_retriever.spec.name, + weight=weight, + retriever=_restore_retriever_from_artifact( + corpus=corpus, + min_chars=min_chars, + stored_retriever=stored_retriever, + artifact_dir=artifact_dir, + ), + ) + for (_, weight), stored_retriever in zip( + normalised_specs, + stored_retrievers, + strict=True, + ) + ] + + +def _jsonable_value(value: Any) -> Any: + if isinstance(value, str | int | float | bool) or value is None: + return value + if isinstance(value, os.PathLike): + return os.fspath(value) + if isinstance(value, Mapping): + return {str(key): _jsonable_value(item) for key, item in value.items()} + if isinstance(value, list | tuple): + return [_jsonable_value(item) for item in value] + return str(value) + + +def _summarise_retriever_config(spec: RetrieverSpec) -> dict[str, Any]: + if is_dataclass(spec): + items = ( + (field.name, getattr(spec, field.name)) + for field in fields(spec) + if field.name not in {"name", "weight"} + ) + return {key: _jsonable_value(value) for key, value in items} + + raw_config = getattr(spec, "__dict__", None) + if isinstance(raw_config, dict): return { - "min_chars": self._min_chars, - "max_suggestions": self._max_suggestions, + str(key): _jsonable_value(value) + for key, value in raw_config.items() + if key not in {"name", "weight"} } + return {} + + +def _build_retriever_summary( + *, + spec: RetrieverSpec, + configured_retriever: _ConfiguredRetriever, + stored_retriever: StoredRetrieverSpec | None, +) -> SaytRetrieverSummary: + artifact_provenance = None + config = _summarise_retriever_config(spec) + if stored_retriever is not None: + artifact_config = { + str(key): _jsonable_value(value) + for key, value in stored_retriever.config.items() + } + artifact_provenance = SaytRetrieverArtifactProvenance( + artifact_type=stored_retriever.artifact_type, + path=stored_retriever.path, + config=artifact_config, + ) + + return SaytRetrieverSummary( + name=spec.name, + spec_type=type(spec).__name__, + retriever_type=type(configured_retriever.retriever).__name__, + configured_weight=float(spec.weight), + normalised_weight=configured_retriever.weight, + config=config, + artifact_provenance=artifact_provenance, + ) diff --git a/src/industrial_classification_utils/sayt/sayt_core.py b/src/industrial_classification_utils/sayt/sayt_core.py index dcccbc0..f1f101e 100644 --- a/src/industrial_classification_utils/sayt/sayt_core.py +++ b/src/industrial_classification_utils/sayt/sayt_core.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Iterable from dataclasses import dataclass -from typing import cast +from typing import Any, cast from uuid import NAMESPACE_URL, uuid5 import pandas as pd @@ -183,6 +183,72 @@ def validate_max_suggestions(value: object) -> int: return max_suggestions +class SaytGlobalSettings(BaseModel): + """Describe suggester-wide runtime settings.""" + + model_config = ConfigDict(extra="forbid") + + min_chars: int + max_suggestions: int + + +class SaytCorpusSummary(BaseModel): + """Summarise the cleaned corpus bound to a suggester.""" + + model_config = ConfigDict(extra="forbid") + + size: int + unique_display_texts: int + max_duplication: int + + +class SaytRetrieverArtifactProvenance(BaseModel): + """Capture persisted artifact details for one retriever entry.""" + + model_config = ConfigDict(extra="forbid") + + artifact_type: str + path: str | None = None + config: dict[str, Any] = Field(default_factory=dict) + + +class SaytArtifactProvenance(BaseModel): + """Describe the artifact source of a suggester restored from disk.""" + + model_config = ConfigDict(extra="forbid") + + artifact_dir: str + artifact_type: str + artifact_version: int + corpus_file: str + corpus_size: int + + +class SaytRetrieverSummary(BaseModel): + """Summarise one configured retriever within a suggester.""" + + model_config = ConfigDict(extra="forbid") + + name: str + spec_type: str + retriever_type: str + configured_weight: float + normalised_weight: float + config: dict[str, Any] = Field(default_factory=dict) + artifact_provenance: SaytRetrieverArtifactProvenance | None = None + + +class SaytConfiguration(BaseModel): + """Return a rich runtime summary of a configured suggester.""" + + model_config = ConfigDict(extra="forbid") + + settings: SaytGlobalSettings + corpus: SaytCorpusSummary + retrievers: list[SaytRetrieverSummary] = Field(default_factory=list) + artifact_provenance: SaytArtifactProvenance | None = None + + @dataclass(frozen=True, slots=True) class Suggestion: """Represent a SAYT match with score and row metadata. diff --git a/tests/sayt/test_sayt.py b/tests/sayt/test_sayt.py index a8c4f2b..357d6f4 100644 --- a/tests/sayt/test_sayt.py +++ b/tests/sayt/test_sayt.py @@ -9,8 +9,10 @@ import pytest from industrial_classification_utils.sayt import ( + NgramRetrieverSpec, PrefixRetrieverSpec, SAYTBuilder, + SaytConfiguration, ) from industrial_classification_utils.sayt.sayt import SAYTSuggester from industrial_classification_utils.sayt.sayt_core import ( @@ -166,9 +168,85 @@ def test_from_artifact_restores_prefix_suggester(tmp_path, small_corpus): min_chars=3, max_suggestions=5, ) + restored_config = restored.get_config() + expected_config = expected.get_config() assert restored.suggest("car") == expected.suggest("car") - assert restored.get_config() == expected.get_config() + assert restored_config.settings == expected_config.settings + assert restored_config.corpus == expected_config.corpus + assert [ + retriever.model_dump(exclude={"artifact_provenance"}) + for retriever in restored_config.retrievers + ] == [ + retriever.model_dump(exclude={"artifact_provenance"}) + for retriever in expected_config.retrievers + ] + assert restored_config.artifact_provenance is not None + assert restored_config.artifact_provenance.artifact_dir == str(artifact_dir) + assert restored_config.retrievers[0].artifact_provenance is not None + assert expected_config.artifact_provenance is None + + +def test_get_config_returns_rich_runtime_summary(small_corpus): + """Expose runtime settings, corpus stats, and retriever summaries.""" + suggester = SAYTSuggester( + small_corpus, + min_chars=3, + max_suggestions=5, + retrievers=[ + PrefixRetrieverSpec(weight=2.0), + NgramRetrieverSpec(weight=1.0, n=4, max_df=1.0), + ], + ) + + config = suggester.get_config() + + assert isinstance(config, SaytConfiguration) + assert config.settings.model_dump() == { + "min_chars": 3, + "max_suggestions": 5, + } + assert config.corpus.model_dump() == { + "size": suggester._corpus.size, + "unique_display_texts": len(suggester._corpus.display_text_count), + "max_duplication": suggester._max_duplication, + } + assert [retriever.name for retriever in config.retrievers] == ["prefix", "ngram"] + assert config.retrievers[0].config == {} + assert config.retrievers[1].config == {"n": 4, "max_df": 1.0} + assert config.retrievers[0].configured_weight == pytest.approx(2.0) + assert config.retrievers[0].normalised_weight == pytest.approx(2.0 / 3.0) + assert config.retrievers[1].normalised_weight == pytest.approx(1.0 / 3.0) + assert config.retrievers[1].retriever_type == "NgramRetriever" + assert config.artifact_provenance is None + + +def test_get_config_supports_custom_specs_without_artifact_handlers(small_corpus): + """Summarise custom runtime-only specs without requiring persistence hooks.""" + + class _StubRetriever: + def suggest_with_scores(self, q_norm, num_suggestions): + _ = (q_norm, num_suggestions) + return [] + + class _CustomSpec: + def __init__(self, *, trigger: str, weight: float = 1.0): + self.trigger = trigger + self.weight = weight + self.name = "custom" + + def build(self, corpus, *, min_chars): + _ = (corpus, min_chars) + return _StubRetriever() + + config = SAYTSuggester( + small_corpus, + min_chars=3, + retrievers=[_CustomSpec(trigger="groom")], + ).get_config() + + assert config.retrievers[0].config == {"trigger": "groom"} + assert config.retrievers[0].artifact_provenance is None def test_suggest_returns_empty_for_short_or_non_string_query(small_corpus): From 5a2f260ff2a520e8321723d24306cc4df4935808 Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Tue, 9 Jun 2026 15:04:47 +0100 Subject: [PATCH 05/11] test(sayt): ensure test coverage of new sayt code --- tests/sayt/test_sayt.py | 106 ++++++++++++ tests/sayt/test_sayt_builder.py | 28 ++++ tests/sayt/test_sayt_config.py | 12 ++ tests/sayt/test_sayt_retrievers.py | 51 ++++++ tests/sayt/test_sayt_storage.py | 257 +++++++++++++++++++++++++++++ 5 files changed, 454 insertions(+) create mode 100644 tests/sayt/test_sayt_storage.py diff --git a/tests/sayt/test_sayt.py b/tests/sayt/test_sayt.py index 357d6f4..34f3854 100644 --- a/tests/sayt/test_sayt.py +++ b/tests/sayt/test_sayt.py @@ -2,7 +2,9 @@ # pylint: disable=protected-access,redefined-outer-name,too-few-public-methods,C0116,W0613 +import json from dataclasses import dataclass +from pathlib import Path from uuid import UUID import pandas as pd @@ -187,6 +189,25 @@ def test_from_artifact_restores_prefix_suggester(tmp_path, small_corpus): assert expected_config.artifact_provenance is None +def test_from_artifact_rejects_manifest_corpus_size_mismatch(tmp_path, small_corpus): + """Reject artifacts whose manifest corpus size disagrees with stored rows.""" + artifact_dir = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ).build_artifact(tmp_path / "artifact") + manifest_path = artifact_dir / "manifest.json" + manifest = json.loads(manifest_path.read_text(encoding="utf-8")) + manifest["corpus_size"] += 1 + manifest_path.write_text(json.dumps(manifest), encoding="utf-8") + + with pytest.raises( + ValueError, match="Artifact corpus size does not match manifest" + ): + SAYTSuggester.from_artifact(artifact_dir) + + def test_get_config_returns_rich_runtime_summary(small_corpus): """Expose runtime settings, corpus stats, and retriever summaries.""" suggester = SAYTSuggester( @@ -249,6 +270,78 @@ def build(self, corpus, *, min_chars): assert config.retrievers[0].artifact_provenance is None +def test_get_config_serialises_nested_custom_spec_values(small_corpus): + """Convert non-JSON-native custom spec config values into JSON-safe forms.""" + + class _StubRetriever: + def suggest_with_scores(self, q_norm, num_suggestions): + _ = (q_norm, num_suggestions) + return [] + + class _Marker: + def __str__(self): + return "marker-object" + + class _CustomSpec: + def __init__(self): + self.name = "custom" + self.weight = 1.0 + self.folder = Path("artifacts/model") + self.options = { + "labels": ["car", Path("cache/index"), _Marker()], + "metadata": {"marker": _Marker()}, + } + self.values = (Path("weights.bin"), _Marker()) + + def build(self, corpus, *, min_chars): + _ = (corpus, min_chars) + return _StubRetriever() + + config = SAYTSuggester( + small_corpus, + min_chars=3, + retrievers=[_CustomSpec()], + ).get_config() + + assert config.retrievers[0].config == { + "folder": "artifacts/model", + "options": { + "labels": ["car", "cache/index", "marker-object"], + "metadata": {"marker": "marker-object"}, + }, + "values": ["weights.bin", "marker-object"], + } + + +def test_get_config_returns_empty_config_for_slots_only_custom_spec(small_corpus): + """Return an empty config summary when a custom spec exposes no __dict__.""" + + class _StubRetriever: + def suggest_with_scores(self, q_norm, num_suggestions): + _ = (q_norm, num_suggestions) + return [] + + class _SlotsOnlySpec: + __slots__ = ("name", "trigger", "weight") + + def __init__(self, *, trigger: str, weight: float = 1.0): + self.name = "slots-only" + self.weight = weight + self.trigger = trigger + + def build(self, corpus, *, min_chars): + _ = (corpus, min_chars) + return _StubRetriever() + + config = SAYTSuggester( + small_corpus, + min_chars=3, + retrievers=[_SlotsOnlySpec(trigger="groom")], + ).get_config() + + assert config.retrievers[0].config == {} + + def test_suggest_returns_empty_for_short_or_non_string_query(small_corpus): """Return no suggestions for short or non-string queries.""" s = SAYTSuggester(small_corpus, min_chars=4, retrievers=[PrefixRetrieverSpec()]) @@ -543,3 +636,16 @@ def build(self, corpus, *, min_chars): ) assert not build_calls + + +def test_clean_corpus_rejects_empty_persisted_rows(): + """Reject empty persisted row collections during artifact restore.""" + with pytest.raises(ValueError, match="corpus is empty after filtering"): + CleanCorpus.from_persisted_rows([]) + + +def test_clean_corpus_coerces_persisted_tuple_values_to_strings(): + """Coerce tuple-based persisted rows to strings before rebuilding indexes.""" + restored = CleanCorpus.from_persisted_rows([(123, 456, 789)]) + + assert restored.rows == [("123", "456", "789")] diff --git a/tests/sayt/test_sayt_builder.py b/tests/sayt/test_sayt_builder.py index bf51895..5487a79 100644 --- a/tests/sayt/test_sayt_builder.py +++ b/tests/sayt/test_sayt_builder.py @@ -6,6 +6,8 @@ import json from pathlib import Path +import pandas as pd + from industrial_classification_utils.sayt import ( NgramRetrieverSpec, PrefixRetrieverSpec, @@ -73,6 +75,32 @@ def load_retriever(self, *, spec, corpus, min_chars, path): return spec.build(corpus, min_chars=min_chars) +def test_builder_from_csv_loads_columns_and_persists_artifact(tmp_path, small_corpus): + """Build an artifact from CSV input using the configured search/display columns.""" + csv_path = tmp_path / "responses.csv" + pd.DataFrame( + { + "search": [row[0] for row in small_corpus], + "display": [row[1] for row in small_corpus], + } + ).to_csv(csv_path, index=False) + + artifact_dir = SAYTBuilder.from_csv( + csv_path, + search_text_col="search", + display_text_col="display", + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ).build_artifact(tmp_path / "artifact") + + manifest = json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) + + assert manifest["min_chars"] == 3 + assert manifest["max_suggestions"] == 5 + assert manifest["corpus_size"] == len(small_corpus) + + def test_builder_writes_manifest_and_corpus(tmp_path, small_corpus): """Persist manifest metadata and cleaned corpus rows for an artifact.""" artifact_dir = tmp_path / "artifact" diff --git a/tests/sayt/test_sayt_config.py b/tests/sayt/test_sayt_config.py index 23843c8..7ff10c1 100644 --- a/tests/sayt/test_sayt_config.py +++ b/tests/sayt/test_sayt_config.py @@ -37,6 +37,12 @@ ValueError, "max_suggestions must be between 1 and 100", ), + ( + SAYTSuggester, + {"min_chars": "abc"}, + TypeError, + "min_chars must be an integer", + ), (SAYTBuilder, {"min_chars": 2}, ValueError, "min_chars must be >= 3"), ( SAYTBuilder, @@ -56,6 +62,12 @@ ValueError, "max_suggestions must be between 1 and 100", ), + ( + SAYTBuilder, + {"max_suggestions": "abc"}, + TypeError, + "max_suggestions must be an integer", + ), ], ) def test_runtime_setting_validation(factory, kwargs, exc_type, match): diff --git a/tests/sayt/test_sayt_retrievers.py b/tests/sayt/test_sayt_retrievers.py index 6b447ec..bd6f4a4 100644 --- a/tests/sayt/test_sayt_retrievers.py +++ b/tests/sayt/test_sayt_retrievers.py @@ -17,6 +17,7 @@ DenseVectorIndex, _CharNgramVectoriser, _L2NormalisingVectoriser, + load_semantic_index, ) from industrial_classification_utils.sayt.sayt_retrievers import ( NgramRetriever, @@ -388,6 +389,56 @@ def _fake_build_dense_vector_index( assert retriever._min_chars == 3 +def test_load_semantic_index_loads_existing_filespace_with_wrapped_vectoriser( + monkeypatch, tmp_path, small_corpus +): + """Wrap the embedding vectoriser before loading a persisted semantic index.""" + captured = {} + corpus = CleanCorpus.model_validate(small_corpus) + folder_path = tmp_path / "existing-semantic" + + class _StubHFVectoriser: + def __init__(self, model_name): + captured["model_name"] = model_name + + def transform(self, texts): + _ = texts + return np.array([[1.0, 0.0]]) + + def _fake_load_dense_vector_index(*, corpus, folder_path, vectoriser): + captured["corpus"] = corpus + captured["folder_path"] = folder_path + captured["vectoriser_type"] = type(vectoriser).__name__ + return DenseVectorIndex( + _vector_store=_StubVectorStore([]), + _num_vectors=2, + _corpus=corpus, + ) + + monkeypatch.setattr( + "industrial_classification_utils.sayt.sayt_indexes.HuggingFaceVectoriser", + _StubHFVectoriser, + ) + monkeypatch.setattr( + "industrial_classification_utils.sayt.sayt_indexes.DenseVectorIndex.from_filespace", + _fake_load_dense_vector_index, + ) + + index = load_semantic_index( + corpus, + model="all-MiniLM-L6-v2", + folder_path=folder_path, + ) + + assert index._num_vectors == 2 + assert captured == { + "model_name": "sentence-transformers/all-MiniLM-L6-v2", + "corpus": corpus, + "folder_path": folder_path, + "vectoriser_type": "_L2NormalisingVectoriser", + } + + def test_semantic_retriever_returns_empty_for_short_queries(): """Stop before vectorisation when the semantic query is too short.""" retriever = SemanticRetriever.__new__(SemanticRetriever) diff --git a/tests/sayt/test_sayt_storage.py b/tests/sayt/test_sayt_storage.py new file mode 100644 index 0000000..3fe00cb --- /dev/null +++ b/tests/sayt/test_sayt_storage.py @@ -0,0 +1,257 @@ +"""Tests for SAYT storage helper validation and artifact edge cases.""" + +# pylint: disable=protected-access,too-few-public-methods,missing-function-docstring + +import json + +import pytest + +from industrial_classification_utils.sayt import ( + PrefixRetrieverSpec, + SemanticRetrieverSpec, + sayt_storage, +) +from industrial_classification_utils.sayt.sayt_core import CleanCorpus + + +class _DuplicateHandler: + artifact_type = "test-duplicate" + + def can_handle(self, spec): + _ = spec + return False + + def serialise_spec(self, spec): + _ = spec + return {} + + def deserialise_spec(self, *, weight, config): + _ = (weight, config) + return PrefixRetrieverSpec() + + def default_path(self, *, index, spec): + _ = (index, spec) + + def build_artifact(self, *, spec, corpus, path): + _ = (spec, corpus, path) + + def load_retriever(self, *, spec, corpus, min_chars, path): + _ = path + return spec.build(corpus, min_chars=min_chars) + + +def test_prepare_artifact_dir_handles_existing_paths(tmp_path): + """Reject accidental reuse, then replace existing directories or files.""" + artifact_dir = tmp_path / "artifact" + artifact_dir.mkdir() + stale_file = artifact_dir / "stale.txt" + stale_file.write_text("stale", encoding="utf-8") + + with pytest.raises(FileExistsError, match="Artifact directory already exists"): + sayt_storage.prepare_artifact_dir(artifact_dir) + + result = sayt_storage.prepare_artifact_dir(artifact_dir, overwrite=True) + + assert result == artifact_dir + assert artifact_dir.is_dir() + assert not stale_file.exists() + + artifact_file = tmp_path / "artifact-file" + artifact_file.write_text("stale", encoding="utf-8") + + replaced = sayt_storage.prepare_artifact_dir(artifact_file, overwrite=True) + + assert replaced == artifact_file + assert artifact_file.is_dir() + + +def test_read_artifact_inputs_validate_missing_and_malformed_state(tmp_path): + """Raise clear errors for missing files and malformed manifest payloads.""" + with pytest.raises(FileNotFoundError, match="Artifact corpus file not found"): + sayt_storage.read_artifact_corpus(artifact_dir=tmp_path) + + with pytest.raises(FileNotFoundError, match="Artifact manifest not found"): + sayt_storage.read_artifact_manifest(artifact_dir=tmp_path) + + manifest_path = tmp_path / "manifest.json" + manifest_path.write_text( + json.dumps({"artifact_type": "other", "artifact_version": 2}), + encoding="utf-8", + ) + with pytest.raises(ValueError, match="Unsupported artifact type"): + sayt_storage.read_artifact_manifest(artifact_dir=tmp_path) + + manifest_path.write_text( + json.dumps({"artifact_type": "sayt", "artifact_version": 999}), + encoding="utf-8", + ) + with pytest.raises(ValueError, match="Unsupported artifact version"): + sayt_storage.read_artifact_manifest(artifact_dir=tmp_path) + + manifest_path.write_text( + json.dumps( + { + "artifact_type": "sayt", + "artifact_version": 2, + "min_chars": 3, + "corpus_file": "corpus.csv", + "corpus_size": 1, + "retrievers": [], + } + ), + encoding="utf-8", + ) + with pytest.raises( + ValueError, match="Malformed artifact manifest: missing max_suggestions" + ): + sayt_storage.read_artifact_manifest(artifact_dir=tmp_path) + + +def test_storage_helper_validation_errors(): + """Guard helper APIs against invalid types, paths, and missing handlers.""" + stored_retriever = sayt_storage.StoredRetrieverSpec( + artifact_type="prefix", + spec=PrefixRetrieverSpec(), + config={}, + path=None, + ) + + with pytest.raises(ValueError, match="does not have a stored filespace"): + sayt_storage.retriever_filespace_path("artifact", stored_retriever) + + with pytest.raises(ValueError, match="Malformed retriever config for type: prefix"): + sayt_storage._deserialise_stored_retriever( + {"type": "prefix", "weight": 1.0, "config": []} + ) + + with pytest.raises( + ValueError, match="No retriever artifact handler registered for type: missing" + ): + sayt_storage._get_retriever_artifact_handler("missing") + + class _UnknownSpec: + name = "unknown" + weight = 1.0 + + def build(self, corpus, *, min_chars): + _ = (corpus, min_chars) + + with pytest.raises( + TypeError, + match="No retriever artifact handler registered for spec type: _UnknownSpec", + ): + sayt_storage._get_retriever_artifact_handler_for_spec(_UnknownSpec()) + + with pytest.raises( + ValueError, match="Retriever 'semantic' requires a persisted filespace path" + ): + sayt_storage._require_path(None, "semantic") + + with pytest.raises( + ValueError, match="Malformed integer value for retriever field: n" + ): + sayt_storage._coerce_int(True, field_name="n") + + with pytest.raises( + ValueError, match="Malformed float value for retriever field: weight" + ): + sayt_storage._coerce_float(True, field_name="weight") + + with pytest.raises( + TypeError, + match="Expected spec of type SemanticRetrieverSpec, got PrefixRetrieverSpec", + ): + sayt_storage._require_spec_type(PrefixRetrieverSpec(), SemanticRetrieverSpec) + + +def test_register_retriever_artifact_handler_rejects_duplicate_registration(): + """Require replace=True before reusing an artifact type registration.""" + handler = _DuplicateHandler() + sayt_storage.register_retriever_artifact_handler(handler) + try: + with pytest.raises( + ValueError, + match="Retriever artifact handler already registered for type: test-duplicate", + ): + sayt_storage.register_retriever_artifact_handler(handler) + finally: + sayt_storage.unregister_retriever_artifact_handler(handler.artifact_type) + + +def test_semantic_artifact_handler_round_trips_and_loads( + monkeypatch, tmp_path, small_corpus +): + """Round-trip semantic spec state and delegate dense index load/build calls.""" + captured = {} + corpus = CleanCorpus.model_validate(small_corpus) + handler = sayt_storage._SemanticRetrieverArtifactHandler() + spec = SemanticRetrieverSpec(model="all-MiniLM-L6-v2", weight=2.5) + path = tmp_path / "retrievers" / "02-semantic" + + def _fake_build_semantic_index(corpus_arg, *, model, output_dir, overwrite): + captured["build"] = { + "corpus": corpus_arg, + "model": model, + "output_dir": output_dir, + "overwrite": overwrite, + } + + def _fake_load_semantic_index(corpus_arg, *, model, folder_path): + captured["load"] = { + "corpus": corpus_arg, + "model": model, + "folder_path": folder_path, + } + return "loaded-index" + + class _StubSemanticRetriever: + @classmethod + def from_index(cls, corpus_arg, *, min_chars, index): + captured["from_index"] = { + "corpus": corpus_arg, + "min_chars": min_chars, + "index": index, + } + return {"index": index, "min_chars": min_chars} + + monkeypatch.setattr( + sayt_storage, "build_semantic_index", _fake_build_semantic_index + ) + monkeypatch.setattr(sayt_storage, "load_semantic_index", _fake_load_semantic_index) + monkeypatch.setattr(sayt_storage, "SemanticRetriever", _StubSemanticRetriever) + + rebuilt = handler.deserialise_spec(weight=2.5, config={"model": "all-MiniLM-L6-v2"}) + + assert handler.serialise_spec(spec) == {"model": "all-MiniLM-L6-v2"} + assert isinstance(rebuilt, SemanticRetrieverSpec) + assert rebuilt.weight == pytest.approx(2.5) + assert rebuilt.model == "all-MiniLM-L6-v2" + assert handler.default_path(index=2, spec=spec) == "retrievers/02-semantic" + + handler.build_artifact(spec=spec, corpus=corpus, path=path) + retriever = handler.load_retriever( + spec=spec, + corpus=corpus, + min_chars=3, + path=path, + ) + + assert retriever == {"index": "loaded-index", "min_chars": 3} + assert captured == { + "build": { + "corpus": corpus, + "model": "all-MiniLM-L6-v2", + "output_dir": path, + "overwrite": True, + }, + "load": { + "corpus": corpus, + "model": "all-MiniLM-L6-v2", + "folder_path": path, + }, + "from_index": { + "corpus": corpus, + "min_chars": 3, + "index": "loaded-index", + }, + } From 70cf7287e36edc8d3ba5d271fe78f6c87f83a52a Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Fri, 12 Jun 2026 09:35:05 +0100 Subject: [PATCH 06/11] chore(sayt): rename sayt modules to remove sayt_ prefix --- demos/sayt/sayt_example.py | 2 +- .../sayt/__init__.py | 12 ++--- .../sayt/{sayt_builder.py => builder.py} | 6 +-- .../sayt/{sayt_core.py => core.py} | 0 .../sayt/{sayt_indexes.py => indexes.py} | 2 +- ..._retriever_specs.py => retriever_specs.py} | 4 +- .../{sayt_retrievers.py => retrievers.py} | 4 +- .../sayt/{sayt_storage.py => storage.py} | 8 +-- .../sayt/{sayt.py => suggester.py} | 6 +-- tests/sayt/test_sayt.py | 6 +-- tests/sayt/test_sayt_builder.py | 8 +-- tests/sayt/test_sayt_config.py | 4 +- tests/sayt/test_sayt_retrievers.py | 20 +++---- tests/sayt/test_sayt_storage.py | 54 +++++++++---------- 14 files changed, 67 insertions(+), 69 deletions(-) rename src/industrial_classification_utils/sayt/{sayt_builder.py => builder.py} (95%) rename src/industrial_classification_utils/sayt/{sayt_core.py => core.py} (100%) rename src/industrial_classification_utils/sayt/{sayt_indexes.py => indexes.py} (99%) rename src/industrial_classification_utils/sayt/{sayt_retriever_specs.py => retriever_specs.py} (98%) rename src/industrial_classification_utils/sayt/{sayt_retrievers.py => retrievers.py} (97%) rename src/industrial_classification_utils/sayt/{sayt_storage.py => storage.py} (99%) rename src/industrial_classification_utils/sayt/{sayt.py => suggester.py} (99%) diff --git a/demos/sayt/sayt_example.py b/demos/sayt/sayt_example.py index 35d207a..03130a7 100644 --- a/demos/sayt/sayt_example.py +++ b/demos/sayt/sayt_example.py @@ -11,7 +11,7 @@ SAYTSuggester, SemanticRetrieverSpec, ) -from industrial_classification_utils.sayt.sayt_core import _normalise +from industrial_classification_utils.sayt.core import _normalise logger = get_logger(__name__) # %% diff --git a/src/industrial_classification_utils/sayt/__init__.py b/src/industrial_classification_utils/sayt/__init__.py index 2bbed3f..d1bdb8c 100644 --- a/src/industrial_classification_utils/sayt/__init__.py +++ b/src/industrial_classification_utils/sayt/__init__.py @@ -1,9 +1,8 @@ """Public SAYT interfaces and built-in retriever components.""" -from .sayt import SAYTSuggester -from .sayt_builder import SAYTBuilder -from .sayt_core import SaytConfiguration -from .sayt_retriever_specs import ( +from .builder import SAYTBuilder +from .core import SaytConfiguration +from .retriever_specs import ( NgramRetrieverSpec, PrefixRetrieverSpec, Retriever, @@ -12,11 +11,12 @@ SemanticRetrieverSpec, default_retriever_specs, ) -from .sayt_retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever -from .sayt_storage import ( +from .retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever +from .storage import ( register_retriever_artifact_handler, unregister_retriever_artifact_handler, ) +from .suggester import SAYTSuggester __all__ = [ "NgramRetriever", diff --git a/src/industrial_classification_utils/sayt/sayt_builder.py b/src/industrial_classification_utils/sayt/builder.py similarity index 95% rename from src/industrial_classification_utils/sayt/sayt_builder.py rename to src/industrial_classification_utils/sayt/builder.py index a416603..57770d9 100644 --- a/src/industrial_classification_utils/sayt/sayt_builder.py +++ b/src/industrial_classification_utils/sayt/builder.py @@ -6,12 +6,12 @@ from collections.abc import Iterable, Sequence from pathlib import Path -from .sayt_core import CleanCorpus, validate_max_suggestions, validate_min_chars -from .sayt_retriever_specs import ( +from .core import CleanCorpus, validate_max_suggestions, validate_min_chars +from .retriever_specs import ( RetrieverSpec, default_retriever_specs, ) -from .sayt_storage import ( +from .storage import ( build_artifact_manifest, build_retriever_artifact, load_corpus_from_csv, diff --git a/src/industrial_classification_utils/sayt/sayt_core.py b/src/industrial_classification_utils/sayt/core.py similarity index 100% rename from src/industrial_classification_utils/sayt/sayt_core.py rename to src/industrial_classification_utils/sayt/core.py diff --git a/src/industrial_classification_utils/sayt/sayt_indexes.py b/src/industrial_classification_utils/sayt/indexes.py similarity index 99% rename from src/industrial_classification_utils/sayt/sayt_indexes.py rename to src/industrial_classification_utils/sayt/indexes.py index a435ab2..619e5ea 100644 --- a/src/industrial_classification_utils/sayt/sayt_indexes.py +++ b/src/industrial_classification_utils/sayt/indexes.py @@ -17,7 +17,7 @@ from scipy.sparse import csr_matrix from sklearn.feature_extraction.text import CountVectorizer -from .sayt_core import CleanCorpus, take_with_ties +from .core import CleanCorpus, take_with_ties def _silent_tqdm(iterable, **_kwargs): diff --git a/src/industrial_classification_utils/sayt/sayt_retriever_specs.py b/src/industrial_classification_utils/sayt/retriever_specs.py similarity index 98% rename from src/industrial_classification_utils/sayt/sayt_retriever_specs.py rename to src/industrial_classification_utils/sayt/retriever_specs.py index 955b803..e03bb84 100644 --- a/src/industrial_classification_utils/sayt/sayt_retriever_specs.py +++ b/src/industrial_classification_utils/sayt/retriever_specs.py @@ -8,8 +8,8 @@ from pathlib import Path from typing import Protocol -from .sayt_core import CleanCorpus, Suggestion -from .sayt_retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever +from .core import CleanCorpus, Suggestion +from .retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever _MIN_NGRAM_SIZE = 2 _MAX_NGRAM_SIZE = 5 diff --git a/src/industrial_classification_utils/sayt/sayt_retrievers.py b/src/industrial_classification_utils/sayt/retrievers.py similarity index 97% rename from src/industrial_classification_utils/sayt/sayt_retrievers.py rename to src/industrial_classification_utils/sayt/retrievers.py index dcbf8e5..e8d35c2 100644 --- a/src/industrial_classification_utils/sayt/sayt_retrievers.py +++ b/src/industrial_classification_utils/sayt/retrievers.py @@ -6,8 +6,8 @@ from dataclasses import dataclass from difflib import SequenceMatcher -from .sayt_core import CleanCorpus, Suggestion, take_with_ties -from .sayt_indexes import DenseVectorIndex, build_ngram_index, build_semantic_index +from .core import CleanCorpus, Suggestion, take_with_ties +from .indexes import DenseVectorIndex, build_ngram_index, build_semantic_index _FUZZY_PREFIX_MIN_RATIO = 0.75 diff --git a/src/industrial_classification_utils/sayt/sayt_storage.py b/src/industrial_classification_utils/sayt/storage.py similarity index 99% rename from src/industrial_classification_utils/sayt/sayt_storage.py rename to src/industrial_classification_utils/sayt/storage.py index 87f95c8..eafc576 100644 --- a/src/industrial_classification_utils/sayt/sayt_storage.py +++ b/src/industrial_classification_utils/sayt/storage.py @@ -11,19 +11,19 @@ import pandas as pd -from .sayt_core import ( +from .core import ( CleanCorpus, PersistedCorpusRow, validate_max_suggestions, validate_min_chars, ) -from .sayt_indexes import ( +from .indexes import ( build_ngram_index, build_semantic_index, load_ngram_index, load_semantic_index, ) -from .sayt_retriever_specs import ( +from .retriever_specs import ( NgramRetrieverSpec, PrefixRetrieverSpec, Retriever, @@ -31,7 +31,7 @@ RetrieverSpec, SemanticRetrieverSpec, ) -from .sayt_retrievers import NgramRetriever, SemanticRetriever +from .retrievers import NgramRetriever, SemanticRetriever SAYT_ARTIFACT_TYPE = "sayt" SAYT_ARTIFACT_VERSION = 2 diff --git a/src/industrial_classification_utils/sayt/sayt.py b/src/industrial_classification_utils/sayt/suggester.py similarity index 99% rename from src/industrial_classification_utils/sayt/sayt.py rename to src/industrial_classification_utils/sayt/suggester.py index 8111db5..e688e32 100644 --- a/src/industrial_classification_utils/sayt/sayt.py +++ b/src/industrial_classification_utils/sayt/suggester.py @@ -15,7 +15,7 @@ from survey_assist_utils.logging import get_logger -from .sayt_core import ( +from .core import ( CleanCorpus, SaytArtifactProvenance, SaytConfiguration, @@ -29,12 +29,12 @@ validate_max_suggestions, validate_min_chars, ) -from .sayt_retriever_specs import ( +from .retriever_specs import ( Retriever, RetrieverSpec, default_retriever_specs, ) -from .sayt_storage import ( +from .storage import ( SAYT_ARTIFACT_TYPE, SAYT_ARTIFACT_VERSION, StoredRetrieverSpec, diff --git a/tests/sayt/test_sayt.py b/tests/sayt/test_sayt.py index 34f3854..f93bad5 100644 --- a/tests/sayt/test_sayt.py +++ b/tests/sayt/test_sayt.py @@ -16,12 +16,12 @@ SAYTBuilder, SaytConfiguration, ) -from industrial_classification_utils.sayt.sayt import SAYTSuggester -from industrial_classification_utils.sayt.sayt_core import ( +from industrial_classification_utils.sayt.core import ( CleanCorpus, PersistedCorpusRow, Suggestion, ) +from industrial_classification_utils.sayt.suggester import SAYTSuggester def test_constructor_rejects_unknown_kwargs(small_corpus): @@ -559,7 +559,7 @@ def build(self, corpus, *, min_chars): return _StubRetriever() monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt.default_retriever_specs", + "industrial_classification_utils.sayt.suggester.default_retriever_specs", lambda: [ _StubRetrieverSpec(name="prefix"), _StubRetrieverSpec(name="ngram"), diff --git a/tests/sayt/test_sayt_builder.py b/tests/sayt/test_sayt_builder.py index 5487a79..5bca314 100644 --- a/tests/sayt/test_sayt_builder.py +++ b/tests/sayt/test_sayt_builder.py @@ -16,8 +16,8 @@ register_retriever_artifact_handler, unregister_retriever_artifact_handler, ) -from industrial_classification_utils.sayt.sayt import SAYTSuggester -from industrial_classification_utils.sayt.sayt_core import CleanCorpus, Suggestion +from industrial_classification_utils.sayt.core import CleanCorpus, Suggestion +from industrial_classification_utils.sayt.suggester import SAYTSuggester class _CustomRetriever: @@ -170,7 +170,7 @@ def __init__( # noqa: PLR0913 self.num_vectors = 1 monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_indexes.VectorStore", + "industrial_classification_utils.sayt.indexes.VectorStore", _StubPersistentVectorStore, ) @@ -235,7 +235,7 @@ def search(self, query, n_results=10): return _StubSearchResults() monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_indexes.VectorStore", + "industrial_classification_utils.sayt.indexes.VectorStore", _StubPersistentVectorStore, ) diff --git a/tests/sayt/test_sayt_config.py b/tests/sayt/test_sayt_config.py index 7ff10c1..2f2a22b 100644 --- a/tests/sayt/test_sayt_config.py +++ b/tests/sayt/test_sayt_config.py @@ -12,7 +12,7 @@ SemanticRetrieverSpec, default_retriever_specs, ) -from industrial_classification_utils.sayt.sayt_core import CleanCorpus +from industrial_classification_utils.sayt.core import CleanCorpus @pytest.mark.parametrize( @@ -154,7 +154,7 @@ def __init__(self, corpus_arg, *, model, min_chars): captured["min_chars"] = min_chars monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_retriever_specs.SemanticRetriever", + "industrial_classification_utils.sayt.retriever_specs.SemanticRetriever", _StubSemanticRetriever, ) diff --git a/tests/sayt/test_sayt_retrievers.py b/tests/sayt/test_sayt_retrievers.py index bd6f4a4..e59e318 100644 --- a/tests/sayt/test_sayt_retrievers.py +++ b/tests/sayt/test_sayt_retrievers.py @@ -11,20 +11,20 @@ from classifai.vectorisers import VectoriserBase from industrial_classification_utils.sayt import NgramRetrieverSpec, PrefixRetrieverSpec -from industrial_classification_utils.sayt.sayt import SAYTSuggester -from industrial_classification_utils.sayt.sayt_core import CleanCorpus -from industrial_classification_utils.sayt.sayt_indexes import ( +from industrial_classification_utils.sayt.core import CleanCorpus +from industrial_classification_utils.sayt.indexes import ( DenseVectorIndex, _CharNgramVectoriser, _L2NormalisingVectoriser, load_semantic_index, ) -from industrial_classification_utils.sayt.sayt_retrievers import ( +from industrial_classification_utils.sayt.retrievers import ( NgramRetriever, PrefixRetriever, SemanticRetriever, _PrefixIndex, ) +from industrial_classification_utils.sayt.suggester import SAYTSuggester def test_prefix_full_string_match_ranks_expected_terms(small_corpus): @@ -277,7 +277,7 @@ def __init__( # noqa: PLR0913 self.num_vectors = len(captured["rows"]) monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_indexes.VectorStore", + "industrial_classification_utils.sayt.indexes.VectorStore", _StubPersistentVectorStore, ) @@ -320,7 +320,7 @@ def _fake_from_filespace(*, folder_path, vectoriser, hooks): return _StubLoadedVectorStore() monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_indexes.VectorStore.from_filespace", + "industrial_classification_utils.sayt.indexes.VectorStore.from_filespace", _fake_from_filespace, ) @@ -370,11 +370,11 @@ def _fake_build_dense_vector_index( ) monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_indexes.HuggingFaceVectoriser", + "industrial_classification_utils.sayt.indexes.HuggingFaceVectoriser", _StubHFVectoriser, ) monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_indexes.DenseVectorIndex.from_corpus", + "industrial_classification_utils.sayt.indexes.DenseVectorIndex.from_corpus", _fake_build_dense_vector_index, ) @@ -416,11 +416,11 @@ def _fake_load_dense_vector_index(*, corpus, folder_path, vectoriser): ) monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_indexes.HuggingFaceVectoriser", + "industrial_classification_utils.sayt.indexes.HuggingFaceVectoriser", _StubHFVectoriser, ) monkeypatch.setattr( - "industrial_classification_utils.sayt.sayt_indexes.DenseVectorIndex.from_filespace", + "industrial_classification_utils.sayt.indexes.DenseVectorIndex.from_filespace", _fake_load_dense_vector_index, ) diff --git a/tests/sayt/test_sayt_storage.py b/tests/sayt/test_sayt_storage.py index 3fe00cb..a1c8336 100644 --- a/tests/sayt/test_sayt_storage.py +++ b/tests/sayt/test_sayt_storage.py @@ -9,9 +9,9 @@ from industrial_classification_utils.sayt import ( PrefixRetrieverSpec, SemanticRetrieverSpec, - sayt_storage, + storage, ) -from industrial_classification_utils.sayt.sayt_core import CleanCorpus +from industrial_classification_utils.sayt.core import CleanCorpus class _DuplicateHandler: @@ -48,9 +48,9 @@ def test_prepare_artifact_dir_handles_existing_paths(tmp_path): stale_file.write_text("stale", encoding="utf-8") with pytest.raises(FileExistsError, match="Artifact directory already exists"): - sayt_storage.prepare_artifact_dir(artifact_dir) + storage.prepare_artifact_dir(artifact_dir) - result = sayt_storage.prepare_artifact_dir(artifact_dir, overwrite=True) + result = storage.prepare_artifact_dir(artifact_dir, overwrite=True) assert result == artifact_dir assert artifact_dir.is_dir() @@ -59,7 +59,7 @@ def test_prepare_artifact_dir_handles_existing_paths(tmp_path): artifact_file = tmp_path / "artifact-file" artifact_file.write_text("stale", encoding="utf-8") - replaced = sayt_storage.prepare_artifact_dir(artifact_file, overwrite=True) + replaced = storage.prepare_artifact_dir(artifact_file, overwrite=True) assert replaced == artifact_file assert artifact_file.is_dir() @@ -68,10 +68,10 @@ def test_prepare_artifact_dir_handles_existing_paths(tmp_path): def test_read_artifact_inputs_validate_missing_and_malformed_state(tmp_path): """Raise clear errors for missing files and malformed manifest payloads.""" with pytest.raises(FileNotFoundError, match="Artifact corpus file not found"): - sayt_storage.read_artifact_corpus(artifact_dir=tmp_path) + storage.read_artifact_corpus(artifact_dir=tmp_path) with pytest.raises(FileNotFoundError, match="Artifact manifest not found"): - sayt_storage.read_artifact_manifest(artifact_dir=tmp_path) + storage.read_artifact_manifest(artifact_dir=tmp_path) manifest_path = tmp_path / "manifest.json" manifest_path.write_text( @@ -79,14 +79,14 @@ def test_read_artifact_inputs_validate_missing_and_malformed_state(tmp_path): encoding="utf-8", ) with pytest.raises(ValueError, match="Unsupported artifact type"): - sayt_storage.read_artifact_manifest(artifact_dir=tmp_path) + storage.read_artifact_manifest(artifact_dir=tmp_path) manifest_path.write_text( json.dumps({"artifact_type": "sayt", "artifact_version": 999}), encoding="utf-8", ) with pytest.raises(ValueError, match="Unsupported artifact version"): - sayt_storage.read_artifact_manifest(artifact_dir=tmp_path) + storage.read_artifact_manifest(artifact_dir=tmp_path) manifest_path.write_text( json.dumps( @@ -104,12 +104,12 @@ def test_read_artifact_inputs_validate_missing_and_malformed_state(tmp_path): with pytest.raises( ValueError, match="Malformed artifact manifest: missing max_suggestions" ): - sayt_storage.read_artifact_manifest(artifact_dir=tmp_path) + storage.read_artifact_manifest(artifact_dir=tmp_path) def test_storage_helper_validation_errors(): """Guard helper APIs against invalid types, paths, and missing handlers.""" - stored_retriever = sayt_storage.StoredRetrieverSpec( + stored_retriever = storage.StoredRetrieverSpec( artifact_type="prefix", spec=PrefixRetrieverSpec(), config={}, @@ -117,17 +117,17 @@ def test_storage_helper_validation_errors(): ) with pytest.raises(ValueError, match="does not have a stored filespace"): - sayt_storage.retriever_filespace_path("artifact", stored_retriever) + storage.retriever_filespace_path("artifact", stored_retriever) with pytest.raises(ValueError, match="Malformed retriever config for type: prefix"): - sayt_storage._deserialise_stored_retriever( + storage._deserialise_stored_retriever( {"type": "prefix", "weight": 1.0, "config": []} ) with pytest.raises( ValueError, match="No retriever artifact handler registered for type: missing" ): - sayt_storage._get_retriever_artifact_handler("missing") + storage._get_retriever_artifact_handler("missing") class _UnknownSpec: name = "unknown" @@ -140,42 +140,42 @@ def build(self, corpus, *, min_chars): TypeError, match="No retriever artifact handler registered for spec type: _UnknownSpec", ): - sayt_storage._get_retriever_artifact_handler_for_spec(_UnknownSpec()) + storage._get_retriever_artifact_handler_for_spec(_UnknownSpec()) with pytest.raises( ValueError, match="Retriever 'semantic' requires a persisted filespace path" ): - sayt_storage._require_path(None, "semantic") + storage._require_path(None, "semantic") with pytest.raises( ValueError, match="Malformed integer value for retriever field: n" ): - sayt_storage._coerce_int(True, field_name="n") + storage._coerce_int(True, field_name="n") with pytest.raises( ValueError, match="Malformed float value for retriever field: weight" ): - sayt_storage._coerce_float(True, field_name="weight") + storage._coerce_float(True, field_name="weight") with pytest.raises( TypeError, match="Expected spec of type SemanticRetrieverSpec, got PrefixRetrieverSpec", ): - sayt_storage._require_spec_type(PrefixRetrieverSpec(), SemanticRetrieverSpec) + storage._require_spec_type(PrefixRetrieverSpec(), SemanticRetrieverSpec) def test_register_retriever_artifact_handler_rejects_duplicate_registration(): """Require replace=True before reusing an artifact type registration.""" handler = _DuplicateHandler() - sayt_storage.register_retriever_artifact_handler(handler) + storage.register_retriever_artifact_handler(handler) try: with pytest.raises( ValueError, match="Retriever artifact handler already registered for type: test-duplicate", ): - sayt_storage.register_retriever_artifact_handler(handler) + storage.register_retriever_artifact_handler(handler) finally: - sayt_storage.unregister_retriever_artifact_handler(handler.artifact_type) + storage.unregister_retriever_artifact_handler(handler.artifact_type) def test_semantic_artifact_handler_round_trips_and_loads( @@ -184,7 +184,7 @@ def test_semantic_artifact_handler_round_trips_and_loads( """Round-trip semantic spec state and delegate dense index load/build calls.""" captured = {} corpus = CleanCorpus.model_validate(small_corpus) - handler = sayt_storage._SemanticRetrieverArtifactHandler() + handler = storage._SemanticRetrieverArtifactHandler() spec = SemanticRetrieverSpec(model="all-MiniLM-L6-v2", weight=2.5) path = tmp_path / "retrievers" / "02-semantic" @@ -214,11 +214,9 @@ def from_index(cls, corpus_arg, *, min_chars, index): } return {"index": index, "min_chars": min_chars} - monkeypatch.setattr( - sayt_storage, "build_semantic_index", _fake_build_semantic_index - ) - monkeypatch.setattr(sayt_storage, "load_semantic_index", _fake_load_semantic_index) - monkeypatch.setattr(sayt_storage, "SemanticRetriever", _StubSemanticRetriever) + monkeypatch.setattr(storage, "build_semantic_index", _fake_build_semantic_index) + monkeypatch.setattr(storage, "load_semantic_index", _fake_load_semantic_index) + monkeypatch.setattr(storage, "SemanticRetriever", _StubSemanticRetriever) rebuilt = handler.deserialise_spec(weight=2.5, config={"model": "all-MiniLM-L6-v2"}) From dce6a3342857425c5476c8b80cb1a0c56d75893c Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Fri, 12 Jun 2026 14:16:19 +0100 Subject: [PATCH 07/11] refactor(sayt): use built-in artifact persistence only for now --- .../sayt/__init__.py | 8 - .../sayt/retriever_specs.py | 51 --- .../sayt/storage.py | 376 +++++------------- tests/sayt/test_sayt_builder.py | 99 +---- tests/sayt/test_sayt_storage.py | 105 ++--- 5 files changed, 150 insertions(+), 489 deletions(-) diff --git a/src/industrial_classification_utils/sayt/__init__.py b/src/industrial_classification_utils/sayt/__init__.py index d1bdb8c..edab23b 100644 --- a/src/industrial_classification_utils/sayt/__init__.py +++ b/src/industrial_classification_utils/sayt/__init__.py @@ -6,16 +6,11 @@ NgramRetrieverSpec, PrefixRetrieverSpec, Retriever, - RetrieverArtifactHandler, RetrieverSpec, SemanticRetrieverSpec, default_retriever_specs, ) from .retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever -from .storage import ( - register_retriever_artifact_handler, - unregister_retriever_artifact_handler, -) from .suggester import SAYTSuggester __all__ = [ @@ -24,7 +19,6 @@ "PrefixRetriever", "PrefixRetrieverSpec", "Retriever", - "RetrieverArtifactHandler", "RetrieverSpec", "SAYTBuilder", "SAYTSuggester", @@ -32,6 +26,4 @@ "SemanticRetriever", "SemanticRetrieverSpec", "default_retriever_specs", - "register_retriever_artifact_handler", - "unregister_retriever_artifact_handler", ] diff --git a/src/industrial_classification_utils/sayt/retriever_specs.py b/src/industrial_classification_utils/sayt/retriever_specs.py index e03bb84..c144649 100644 --- a/src/industrial_classification_utils/sayt/retriever_specs.py +++ b/src/industrial_classification_utils/sayt/retriever_specs.py @@ -3,9 +3,7 @@ """Public retriever protocols and configuration objects for SAYT.""" import math -from collections.abc import Mapping from dataclasses import dataclass, field -from pathlib import Path from typing import Protocol from .core import CleanCorpus, Suggestion @@ -55,55 +53,6 @@ def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: """ -class RetrieverArtifactHandler(Protocol): - """Persistence hooks for storing and restoring a retriever spec. - - This optional protocol extends the runtime-only ``RetrieverSpec`` contract - for artifact build/load flows. Handlers are registered separately so custom - retriever specs remain lightweight unless they need persistence support. - """ - - @property - def artifact_type(self) -> str: - """Return the stable manifest identifier for this handler.""" - - def can_handle(self, spec: RetrieverSpec) -> bool: - """Return whether this handler can persist the supplied spec.""" - - def serialise_spec(self, spec: RetrieverSpec) -> dict[str, object]: - """Return spec-specific manifest configuration excluding weight/path.""" - - def deserialise_spec( - self, - *, - weight: float, - config: Mapping[str, object], - ) -> RetrieverSpec: - """Rebuild a retriever spec from persisted manifest data.""" - - def default_path(self, *, index: int, spec: RetrieverSpec) -> str | None: - """Return the default relative artifact path for persisted assets.""" - - def build_artifact( - self, - *, - spec: RetrieverSpec, - corpus: CleanCorpus, - path: Path | None, - ) -> None: - """Write any persisted retriever assets needed for later loading.""" - - def load_retriever( - self, - *, - spec: RetrieverSpec, - corpus: CleanCorpus, - min_chars: int, - path: Path | None, - ) -> Retriever: - """Restore a runtime retriever from persisted artifact state.""" - - def _validate_retriever_weight(weight: float) -> None: if not math.isfinite(weight) or weight <= 0: raise ValueError("retriever weight must be a finite value > 0") diff --git a/src/industrial_classification_utils/sayt/storage.py b/src/industrial_classification_utils/sayt/storage.py index eafc576..fa59605 100644 --- a/src/industrial_classification_utils/sayt/storage.py +++ b/src/industrial_classification_utils/sayt/storage.py @@ -4,10 +4,8 @@ import json import os import shutil -from collections.abc import Mapping from dataclasses import dataclass from pathlib import Path -from typing import TypeVar import pandas as pd @@ -27,7 +25,6 @@ NgramRetrieverSpec, PrefixRetrieverSpec, Retriever, - RetrieverArtifactHandler, RetrieverSpec, SemanticRetrieverSpec, ) @@ -40,9 +37,6 @@ _ARTIFACT_CORPUS_FIELDS = ["row_id", "search_text", "display_text"] _RETRIEVERS_DIR_NAME = "retrievers" -_RETRIEVER_ARTIFACT_HANDLERS: dict[str, RetrieverArtifactHandler] = {} -SpecT = TypeVar("SpecT", bound=RetrieverSpec) - @dataclass(frozen=True, slots=True) class StoredRetrieverSpec: @@ -224,42 +218,38 @@ def retriever_filespace_path( return Path(artifact_dir) / stored_retriever.path -def register_retriever_artifact_handler( - handler: RetrieverArtifactHandler, - *, - replace: bool = False, -) -> None: - """Register a handler for artifact persistence of retriever specs.""" - artifact_type = handler.artifact_type - if artifact_type in _RETRIEVER_ARTIFACT_HANDLERS and not replace: - raise ValueError( - f"Retriever artifact handler already registered for type: {artifact_type}" - ) - _RETRIEVER_ARTIFACT_HANDLERS[artifact_type] = handler - - -def unregister_retriever_artifact_handler(artifact_type: str) -> None: - """Remove a previously registered retriever artifact handler.""" - _RETRIEVER_ARTIFACT_HANDLERS.pop(artifact_type, None) - - def build_retriever_artifact( *, corpus: CleanCorpus, stored_retriever: StoredRetrieverSpec, artifact_dir: str | Path, ) -> None: - """Persist retriever-specific artifact state using its registered handler.""" - handler = _get_retriever_artifact_handler(stored_retriever.artifact_type) - path = ( - retriever_filespace_path(artifact_dir, stored_retriever) - if stored_retriever.path is not None - else None - ) - handler.build_artifact( - spec=stored_retriever.spec, - corpus=corpus, - path=path, + """Persist built-in retriever assets required by a SAYT artifact.""" + spec = stored_retriever.spec + if isinstance(spec, PrefixRetrieverSpec): + return + + if isinstance(spec, NgramRetrieverSpec): + build_ngram_index( + corpus, + n=spec.n, + max_df=spec.max_df, + output_dir=retriever_filespace_path(artifact_dir, stored_retriever), + overwrite=True, + ) + return + + if isinstance(spec, SemanticRetrieverSpec): + build_semantic_index( + corpus, + model=spec.model, + output_dir=retriever_filespace_path(artifact_dir, stored_retriever), + overwrite=True, + ) + return + + raise TypeError( + "Only built-in retriever specs can be persisted; " f"got {type(spec).__name__}" ) @@ -270,18 +260,39 @@ def load_retriever_from_artifact( stored_retriever: StoredRetrieverSpec, artifact_dir: str | Path, ) -> Retriever: - """Restore a runtime retriever using its registered artifact handler.""" - handler = _get_retriever_artifact_handler(stored_retriever.artifact_type) - path = ( - retriever_filespace_path(artifact_dir, stored_retriever) - if stored_retriever.path is not None - else None - ) - return handler.load_retriever( - spec=stored_retriever.spec, - corpus=corpus, - min_chars=min_chars, - path=path, + """Restore a built-in runtime retriever from persisted artifact state.""" + spec = stored_retriever.spec + if isinstance(spec, PrefixRetrieverSpec): + return spec.build(corpus, min_chars=min_chars) + + if isinstance(spec, NgramRetrieverSpec): + index = load_ngram_index( + corpus, + n=spec.n, + max_df=spec.max_df, + folder_path=retriever_filespace_path(artifact_dir, stored_retriever), + ) + return NgramRetriever.from_index( + corpus, + min_chars=min_chars, + index=index, + ) + + if isinstance(spec, SemanticRetrieverSpec): + index = load_semantic_index( + corpus, + model=spec.model, + folder_path=retriever_filespace_path(artifact_dir, stored_retriever), + ) + return SemanticRetriever.from_index( + corpus, + min_chars=min_chars, + index=index, + ) + + raise TypeError( + "Only built-in retriever specs can be restored from artifacts; " + f"got {type(spec).__name__}" ) @@ -289,12 +300,32 @@ def _build_stored_retriever( index: int, spec: RetrieverSpec, ) -> StoredRetrieverSpec: - handler = _get_retriever_artifact_handler_for_spec(spec) - return StoredRetrieverSpec( - artifact_type=handler.artifact_type, - spec=spec, - config=handler.serialise_spec(spec), - path=handler.default_path(index=index, spec=spec), + if isinstance(spec, PrefixRetrieverSpec): + return StoredRetrieverSpec( + artifact_type="prefix", + spec=spec, + config={}, + path=None, + ) + + if isinstance(spec, NgramRetrieverSpec): + return StoredRetrieverSpec( + artifact_type="ngram", + spec=spec, + config={"n": spec.n, "max_df": spec.max_df}, + path=f"{_RETRIEVERS_DIR_NAME}/{index:02d}-{spec.name}", + ) + + if isinstance(spec, SemanticRetrieverSpec): + return StoredRetrieverSpec( + artifact_type="semantic", + spec=spec, + config={"model": spec.model}, + path=f"{_RETRIEVERS_DIR_NAME}/{index:02d}-{spec.name}", + ) + + raise TypeError( + "Only built-in retriever specs can be persisted; " f"got {type(spec).__name__}" ) @@ -331,215 +362,29 @@ def _deserialise_stored_retriever(payload: dict[str, object]) -> StoredRetriever config = payload.get("config", {}) if not isinstance(config, dict): raise ValueError(f"Malformed retriever config for type: {retriever_type}") - handler = _get_retriever_artifact_handler(retriever_type) - spec = handler.deserialise_spec(weight=weight, config=config) - return StoredRetrieverSpec( - artifact_type=retriever_type, - spec=spec, - config=dict(config), - path=str(path) if isinstance(path, str) else None, - ) - - -def _get_retriever_artifact_handler(artifact_type: str) -> RetrieverArtifactHandler: - try: - return _RETRIEVER_ARTIFACT_HANDLERS[artifact_type] - except KeyError as exc: - raise ValueError( - f"No retriever artifact handler registered for type: {artifact_type}" - ) from exc - - -def _get_retriever_artifact_handler_for_spec( - spec: RetrieverSpec, -) -> RetrieverArtifactHandler: - for handler in reversed(tuple(_RETRIEVER_ARTIFACT_HANDLERS.values())): - if handler.can_handle(spec): - return handler - raise TypeError( - f"No retriever artifact handler registered for spec type: {type(spec).__name__}" - ) - - -class _PrefixRetrieverArtifactHandler: # pylint: disable=missing-function-docstring,useless-return - """Artifact handler for the built-in prefix retriever spec.""" - - artifact_type = "prefix" - - def can_handle(self, spec: RetrieverSpec) -> bool: - return isinstance(spec, PrefixRetrieverSpec) - - def serialise_spec(self, spec: RetrieverSpec) -> dict[str, object]: - _ = spec - return {} - - def deserialise_spec( - self, - *, - weight: float, - config: Mapping[str, object], - ) -> RetrieverSpec: - _ = config - return PrefixRetrieverSpec(weight=weight) - - def default_path(self, *, index: int, spec: RetrieverSpec) -> str | None: - _ = (index, spec) - return None - - def build_artifact( - self, - *, - spec: RetrieverSpec, - corpus: CleanCorpus, - path: Path | None, - ) -> None: - _ = (spec, corpus, path) - - def load_retriever( - self, - *, - spec: RetrieverSpec, - corpus: CleanCorpus, - min_chars: int, - path: Path | None, - ) -> Retriever: - _ = path - return spec.build(corpus, min_chars=min_chars) - - -class _NgramRetrieverArtifactHandler: # pylint: disable=missing-function-docstring - """Artifact handler for the built-in n-gram retriever spec.""" - - artifact_type = "ngram" - - def can_handle(self, spec: RetrieverSpec) -> bool: - return isinstance(spec, NgramRetrieverSpec) - - def serialise_spec(self, spec: RetrieverSpec) -> dict[str, object]: - typed_spec = _require_spec_type(spec, NgramRetrieverSpec) - return {"n": typed_spec.n, "max_df": typed_spec.max_df} - - def deserialise_spec( - self, - *, - weight: float, - config: Mapping[str, object], - ) -> RetrieverSpec: - return NgramRetrieverSpec( + spec: RetrieverSpec + if retriever_type == "prefix": + spec = PrefixRetrieverSpec(weight=weight) + elif retriever_type == "ngram": + spec = NgramRetrieverSpec( weight=weight, n=_coerce_int(config["n"], field_name="n"), max_df=_coerce_float(config["max_df"], field_name="max_df"), ) - - def default_path(self, *, index: int, spec: RetrieverSpec) -> str | None: - return f"{_RETRIEVERS_DIR_NAME}/{index:02d}-{spec.name}" - - def build_artifact( - self, - *, - spec: RetrieverSpec, - corpus: CleanCorpus, - path: Path | None, - ) -> None: - typed_spec = _require_spec_type(spec, NgramRetrieverSpec) - build_ngram_index( - corpus, - n=typed_spec.n, - max_df=typed_spec.max_df, - output_dir=_require_path(path, typed_spec.name), - overwrite=True, - ) - - def load_retriever( - self, - *, - spec: RetrieverSpec, - corpus: CleanCorpus, - min_chars: int, - path: Path | None, - ) -> Retriever: - typed_spec = _require_spec_type(spec, NgramRetrieverSpec) - index = load_ngram_index( - corpus, - n=typed_spec.n, - max_df=typed_spec.max_df, - folder_path=_require_path(path, typed_spec.name), - ) - return NgramRetriever.from_index( - corpus, - min_chars=min_chars, - index=index, - ) - - -class _SemanticRetrieverArtifactHandler: # pylint: disable=missing-function-docstring - """Artifact handler for the built-in semantic retriever spec.""" - - artifact_type = "semantic" - - def can_handle(self, spec: RetrieverSpec) -> bool: - return isinstance(spec, SemanticRetrieverSpec) - - def serialise_spec(self, spec: RetrieverSpec) -> dict[str, object]: - typed_spec = _require_spec_type(spec, SemanticRetrieverSpec) - return {"model": typed_spec.model} - - def deserialise_spec( - self, - *, - weight: float, - config: Mapping[str, object], - ) -> RetrieverSpec: - return SemanticRetrieverSpec( + elif retriever_type == "semantic": + spec = SemanticRetrieverSpec( weight=weight, model=str(config["model"]), ) + else: + raise ValueError(f"Unsupported stored retriever type: {retriever_type}") - def default_path(self, *, index: int, spec: RetrieverSpec) -> str | None: - return f"{_RETRIEVERS_DIR_NAME}/{index:02d}-{spec.name}" - - def build_artifact( - self, - *, - spec: RetrieverSpec, - corpus: CleanCorpus, - path: Path | None, - ) -> None: - typed_spec = _require_spec_type(spec, SemanticRetrieverSpec) - build_semantic_index( - corpus, - model=typed_spec.model, - output_dir=_require_path(path, typed_spec.name), - overwrite=True, - ) - - def load_retriever( - self, - *, - spec: RetrieverSpec, - corpus: CleanCorpus, - min_chars: int, - path: Path | None, - ) -> Retriever: - typed_spec = _require_spec_type(spec, SemanticRetrieverSpec) - index = load_semantic_index( - corpus, - model=typed_spec.model, - folder_path=_require_path(path, typed_spec.name), - ) - return SemanticRetriever.from_index( - corpus, - min_chars=min_chars, - index=index, - ) - - -def _require_path(path: Path | None, retriever_name: str) -> Path: - if path is None: - raise ValueError( - f"Retriever '{retriever_name}' requires a persisted filespace path" - ) - return path + return StoredRetrieverSpec( + artifact_type=retriever_type, + spec=spec, + config=dict(config), + path=str(path) if isinstance(path, str) else None, + ) def _coerce_int(value: object, *, field_name: str) -> int: @@ -552,24 +397,3 @@ def _coerce_float(value: object, *, field_name: str) -> float: if isinstance(value, bool) or not isinstance(value, int | float | str): raise ValueError(f"Malformed float value for retriever field: {field_name}") return float(value) - - -def _require_spec_type(spec: RetrieverSpec, spec_type: type[SpecT]) -> SpecT: - if not isinstance(spec, spec_type): - raise TypeError( - f"Expected spec of type {spec_type.__name__}, got {type(spec).__name__}" - ) - return spec - - -def _register_builtin_retriever_artifact_handlers() -> None: - """Seed the artifact handler registry with the built-in retriever types.""" - for handler in ( - _PrefixRetrieverArtifactHandler(), - _NgramRetrieverArtifactHandler(), - _SemanticRetrieverArtifactHandler(), - ): - register_retriever_artifact_handler(handler) - - -_register_builtin_retriever_artifact_handlers() diff --git a/tests/sayt/test_sayt_builder.py b/tests/sayt/test_sayt_builder.py index 5bca314..8bc5eb9 100644 --- a/tests/sayt/test_sayt_builder.py +++ b/tests/sayt/test_sayt_builder.py @@ -7,39 +7,17 @@ from pathlib import Path import pandas as pd +import pytest from industrial_classification_utils.sayt import ( NgramRetrieverSpec, PrefixRetrieverSpec, - RetrieverArtifactHandler, SAYTBuilder, - register_retriever_artifact_handler, - unregister_retriever_artifact_handler, ) -from industrial_classification_utils.sayt.core import CleanCorpus, Suggestion +from industrial_classification_utils.sayt.core import CleanCorpus from industrial_classification_utils.sayt.suggester import SAYTSuggester -class _CustomRetriever: - def __init__(self, row, *, trigger: str, min_chars: int): - self._row = row - self._trigger = trigger - self._min_chars = min_chars - - def suggest_with_scores(self, q_norm, num_suggestions): - _ = num_suggestions - if len(q_norm) < self._min_chars or self._trigger not in q_norm: - return [] - return [ - Suggestion( - display_text=self._row[2], - score=1.0, - search_text=self._row[1], - row_id=self._row[0], - ) - ] - - class _CustomRetrieverSpec: def __init__(self, *, trigger: str, weight: float = 1.0): self.trigger = trigger @@ -47,32 +25,7 @@ def __init__(self, *, trigger: str, weight: float = 1.0): self.name = "custom-trigger" def build(self, corpus, *, min_chars): - return _CustomRetriever( - corpus.rows[-1], trigger=self.trigger, min_chars=min_chars - ) - - -class _CustomRetrieverArtifactHandlerImpl: - artifact_type = "custom-trigger" - - def can_handle(self, spec): - return isinstance(spec, _CustomRetrieverSpec) - - def serialise_spec(self, spec): - return {"trigger": spec.trigger} - - def deserialise_spec(self, *, weight, config): - return _CustomRetrieverSpec(trigger=str(config["trigger"]), weight=weight) - - def default_path(self, *, index, spec): - _ = (index, spec) - - def build_artifact(self, *, spec, corpus, path): - _ = (spec, corpus, path) - - def load_retriever(self, *, spec, corpus, min_chars, path): - _ = path - return spec.build(corpus, min_chars=min_chars) + _ = (corpus, min_chars) def test_builder_from_csv_loads_columns_and_persists_artifact(tmp_path, small_corpus): @@ -258,35 +211,17 @@ def search(self, query, n_results=10): } -def test_custom_retriever_artifact_handler_round_trips(tmp_path, small_corpus): - """Allow custom retriever specs to participate in artifact build and load.""" - artifact_dir = tmp_path / "artifact" - handler: RetrieverArtifactHandler = _CustomRetrieverArtifactHandlerImpl() - register_retriever_artifact_handler(handler) - try: - spec = _CustomRetrieverSpec(trigger="groom", weight=1.5) - builder = SAYTBuilder( - small_corpus, - retrievers=[spec], - min_chars=3, - max_suggestions=4, - ) - - builder.build_artifact(artifact_dir) - - manifest = json.loads( - (artifact_dir / "manifest.json").read_text(encoding="utf-8") - ) - suggester = SAYTSuggester.from_artifact(artifact_dir) - - assert manifest["retrievers"] == [ - { - "type": "custom-trigger", - "weight": 1.5, - "path": None, - "config": {"trigger": "groom"}, - } - ] - assert suggester.suggest("groom") == ["Dog grooming"] - finally: - unregister_retriever_artifact_handler("custom-trigger") +def test_builder_rejects_custom_runtime_only_retriever_specs(tmp_path, small_corpus): + """Persisted artifacts currently support only the built-in retriever specs.""" + builder = SAYTBuilder( + small_corpus, + retrievers=[_CustomRetrieverSpec(trigger="groom", weight=1.5)], + min_chars=3, + max_suggestions=4, + ) + + with pytest.raises( + TypeError, + match="Only built-in retriever specs can be persisted; got _CustomRetrieverSpec", + ): + builder.build_artifact(tmp_path / "artifact") diff --git a/tests/sayt/test_sayt_storage.py b/tests/sayt/test_sayt_storage.py index a1c8336..390c257 100644 --- a/tests/sayt/test_sayt_storage.py +++ b/tests/sayt/test_sayt_storage.py @@ -14,32 +14,6 @@ from industrial_classification_utils.sayt.core import CleanCorpus -class _DuplicateHandler: - artifact_type = "test-duplicate" - - def can_handle(self, spec): - _ = spec - return False - - def serialise_spec(self, spec): - _ = spec - return {} - - def deserialise_spec(self, *, weight, config): - _ = (weight, config) - return PrefixRetrieverSpec() - - def default_path(self, *, index, spec): - _ = (index, spec) - - def build_artifact(self, *, spec, corpus, path): - _ = (spec, corpus, path) - - def load_retriever(self, *, spec, corpus, min_chars, path): - _ = path - return spec.build(corpus, min_chars=min_chars) - - def test_prepare_artifact_dir_handles_existing_paths(tmp_path): """Reject accidental reuse, then replace existing directories or files.""" artifact_dir = tmp_path / "artifact" @@ -108,7 +82,7 @@ def test_read_artifact_inputs_validate_missing_and_malformed_state(tmp_path): def test_storage_helper_validation_errors(): - """Guard helper APIs against invalid types, paths, and missing handlers.""" + """Guard helper APIs against invalid types, paths, and unsupported specs.""" stored_retriever = storage.StoredRetrieverSpec( artifact_type="prefix", spec=PrefixRetrieverSpec(), @@ -124,10 +98,10 @@ def test_storage_helper_validation_errors(): {"type": "prefix", "weight": 1.0, "config": []} ) - with pytest.raises( - ValueError, match="No retriever artifact handler registered for type: missing" - ): - storage._get_retriever_artifact_handler("missing") + with pytest.raises(ValueError, match="Unsupported stored retriever type: missing"): + storage._deserialise_stored_retriever( + {"type": "missing", "weight": 1.0, "config": {}} + ) class _UnknownSpec: name = "unknown" @@ -138,14 +112,9 @@ def build(self, corpus, *, min_chars): with pytest.raises( TypeError, - match="No retriever artifact handler registered for spec type: _UnknownSpec", + match="Only built-in retriever specs can be persisted; got _UnknownSpec", ): - storage._get_retriever_artifact_handler_for_spec(_UnknownSpec()) - - with pytest.raises( - ValueError, match="Retriever 'semantic' requires a persisted filespace path" - ): - storage._require_path(None, "semantic") + storage._build_stored_retriever(0, _UnknownSpec()) with pytest.raises( ValueError, match="Malformed integer value for retriever field: n" @@ -157,36 +126,16 @@ def build(self, corpus, *, min_chars): ): storage._coerce_float(True, field_name="weight") - with pytest.raises( - TypeError, - match="Expected spec of type SemanticRetrieverSpec, got PrefixRetrieverSpec", - ): - storage._require_spec_type(PrefixRetrieverSpec(), SemanticRetrieverSpec) - -def test_register_retriever_artifact_handler_rejects_duplicate_registration(): - """Require replace=True before reusing an artifact type registration.""" - handler = _DuplicateHandler() - storage.register_retriever_artifact_handler(handler) - try: - with pytest.raises( - ValueError, - match="Retriever artifact handler already registered for type: test-duplicate", - ): - storage.register_retriever_artifact_handler(handler) - finally: - storage.unregister_retriever_artifact_handler(handler.artifact_type) - - -def test_semantic_artifact_handler_round_trips_and_loads( +def test_semantic_retriever_artifact_round_trips_and_loads( monkeypatch, tmp_path, small_corpus ): - """Round-trip semantic spec state and delegate dense index load/build calls.""" + """Round-trip semantic artifact state and delegate dense index load/build calls.""" captured = {} corpus = CleanCorpus.model_validate(small_corpus) - handler = storage._SemanticRetrieverArtifactHandler() spec = SemanticRetrieverSpec(model="all-MiniLM-L6-v2", weight=2.5) - path = tmp_path / "retrievers" / "02-semantic" + stored_retriever = storage._build_stored_retriever(2, spec) + path = tmp_path / stored_retriever.path def _fake_build_semantic_index(corpus_arg, *, model, output_dir, overwrite): captured["build"] = { @@ -218,20 +167,32 @@ def from_index(cls, corpus_arg, *, min_chars, index): monkeypatch.setattr(storage, "load_semantic_index", _fake_load_semantic_index) monkeypatch.setattr(storage, "SemanticRetriever", _StubSemanticRetriever) - rebuilt = handler.deserialise_spec(weight=2.5, config={"model": "all-MiniLM-L6-v2"}) + rebuilt = storage._deserialise_stored_retriever( + { + "type": stored_retriever.artifact_type, + "weight": spec.weight, + "path": stored_retriever.path, + "config": stored_retriever.config, + } + ) - assert handler.serialise_spec(spec) == {"model": "all-MiniLM-L6-v2"} - assert isinstance(rebuilt, SemanticRetrieverSpec) - assert rebuilt.weight == pytest.approx(2.5) - assert rebuilt.model == "all-MiniLM-L6-v2" - assert handler.default_path(index=2, spec=spec) == "retrievers/02-semantic" + assert stored_retriever.artifact_type == "semantic" + assert stored_retriever.config == {"model": "all-MiniLM-L6-v2"} + assert stored_retriever.path == "retrievers/02-semantic" + assert isinstance(rebuilt.spec, SemanticRetrieverSpec) + assert rebuilt.spec.weight == pytest.approx(2.5) + assert rebuilt.config == {"model": "all-MiniLM-L6-v2"} - handler.build_artifact(spec=spec, corpus=corpus, path=path) - retriever = handler.load_retriever( - spec=spec, + storage.build_retriever_artifact( + corpus=corpus, + stored_retriever=stored_retriever, + artifact_dir=tmp_path, + ) + retriever = storage.load_retriever_from_artifact( corpus=corpus, min_chars=3, - path=path, + stored_retriever=stored_retriever, + artifact_dir=tmp_path, ) assert retriever == {"index": "loaded-index", "min_chars": 3} From 33c825d63b510c538c1b535c4c290f775d66096b Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Mon, 15 Jun 2026 17:00:07 +0100 Subject: [PATCH 08/11] chore: remove tuple sorting --- src/industrial_classification_utils/sayt/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/industrial_classification_utils/sayt/core.py b/src/industrial_classification_utils/sayt/core.py index f1f101e..ffb16f9 100644 --- a/src/industrial_classification_utils/sayt/core.py +++ b/src/industrial_classification_utils/sayt/core.py @@ -283,7 +283,7 @@ def take_with_ties( items = sorted( items, - key=lambda kv: (-kv[1],), + key=lambda kv: -kv[1], ) if limit >= len(items): From 1d09299525adb368299c6c5589f6990f84224e4b Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Tue, 16 Jun 2026 14:35:43 +0100 Subject: [PATCH 09/11] refactor: build to temp directory first, then move on success refactor: remove repeated isinstance checks using slim Protocol --- .../sayt/__init__.py | 2 + .../sayt/builder.py | 68 +++++-- .../sayt/core.py | 22 +-- .../sayt/retriever_specs.py | 166 +++++++++++++++++- .../sayt/storage.py | 158 +++++++---------- .../sayt/suggester.py | 26 +-- tests/sayt/test_sayt_builder.py | 84 ++++++++- tests/sayt/test_sayt_storage.py | 24 +-- 8 files changed, 385 insertions(+), 165 deletions(-) diff --git a/src/industrial_classification_utils/sayt/__init__.py b/src/industrial_classification_utils/sayt/__init__.py index edab23b..66180d9 100644 --- a/src/industrial_classification_utils/sayt/__init__.py +++ b/src/industrial_classification_utils/sayt/__init__.py @@ -3,6 +3,7 @@ from .builder import SAYTBuilder from .core import SaytConfiguration from .retriever_specs import ( + ArtifactRetrieverSpec, NgramRetrieverSpec, PrefixRetrieverSpec, Retriever, @@ -14,6 +15,7 @@ from .suggester import SAYTSuggester __all__ = [ + "ArtifactRetrieverSpec", "NgramRetriever", "NgramRetrieverSpec", "PrefixRetriever", diff --git a/src/industrial_classification_utils/sayt/builder.py b/src/industrial_classification_utils/sayt/builder.py index 57770d9..a28304e 100644 --- a/src/industrial_classification_utils/sayt/builder.py +++ b/src/industrial_classification_utils/sayt/builder.py @@ -3,8 +3,11 @@ # pylint: disable=duplicate-code import os +import shutil +import tempfile from collections.abc import Iterable, Sequence from pathlib import Path +from uuid import uuid4 from .core import CleanCorpus, validate_max_suggestions, validate_min_chars from .retriever_specs import ( @@ -15,12 +18,20 @@ build_artifact_manifest, build_retriever_artifact, load_corpus_from_csv, - prepare_artifact_dir, write_artifact_corpus, write_artifact_manifest, ) +def _remove_path(path: Path) -> None: + if not path.exists(): + return + if path.is_dir(): + shutil.rmtree(path) + else: + path.unlink() + + class SAYTBuilder: """Build a persisted SAYT artifact for later runtime loading.""" @@ -71,21 +82,52 @@ def build_artifact( overwrite: bool = False, ) -> Path: """Persist the current SAYT configuration and dense stores to disk.""" - artifact_dir = prepare_artifact_dir(output_dir, overwrite=overwrite) - manifest = build_artifact_manifest( - corpus=self._corpus, - min_chars=self._min_chars, - max_suggestions=self._max_suggestions, - retriever_specs=self._retriever_specs, + artifact_dir = Path(output_dir) + if artifact_dir.exists() and not overwrite: + raise FileExistsError("Artifact directory already exists") + + artifact_dir.parent.mkdir(parents=True, exist_ok=True) + staged_dir = Path( + tempfile.mkdtemp( + prefix=f".{artifact_dir.name}.tmp-", + dir=artifact_dir.parent, + ) ) - write_artifact_corpus(self._corpus, artifact_dir=artifact_dir) - for stored_retriever in manifest.retrievers: - build_retriever_artifact( + try: + manifest = build_artifact_manifest( corpus=self._corpus, - stored_retriever=stored_retriever, - artifact_dir=artifact_dir, + min_chars=self._min_chars, + max_suggestions=self._max_suggestions, + retriever_specs=self._retriever_specs, ) - write_artifact_manifest(manifest, artifact_dir=artifact_dir) + write_artifact_corpus(self._corpus, artifact_dir=staged_dir) + for stored_retriever in manifest.retrievers: + build_retriever_artifact( + corpus=self._corpus, + min_chars=self._min_chars, + stored_retriever=stored_retriever, + artifact_dir=staged_dir, + ) + + write_artifact_manifest(manifest, artifact_dir=staged_dir) + + if artifact_dir.exists(): + backup_dir = ( + artifact_dir.parent / f".{artifact_dir.name}.bak-{uuid4().hex}" + ) + artifact_dir.rename(backup_dir) + try: + staged_dir.rename(artifact_dir) + except Exception: + backup_dir.rename(artifact_dir) + raise + _remove_path(backup_dir) + else: + staged_dir.rename(artifact_dir) + except Exception: + _remove_path(staged_dir) + raise + return artifact_dir diff --git a/src/industrial_classification_utils/sayt/core.py b/src/industrial_classification_utils/sayt/core.py index ffb16f9..cec3ae1 100644 --- a/src/industrial_classification_utils/sayt/core.py +++ b/src/industrial_classification_utils/sayt/core.py @@ -46,7 +46,6 @@ class CleanCorpus(BaseModel): """ model_config = ConfigDict(arbitrary_types_allowed=True) - corpus: object rows: list[tuple[str, str, str]] = Field(default_factory=list) id_to_search: dict[str, str] = Field(default_factory=dict) id_to_display: dict[str, str] = Field(default_factory=dict) @@ -56,15 +55,22 @@ class CleanCorpus(BaseModel): @model_validator(mode="before") @classmethod def _coerce_input(cls, data: object) -> object: - if isinstance(data, cls | dict): + if isinstance(data, cls): return data - return {"corpus": data} + if isinstance(data, dict): + if "rows" in data: + return data + if "corpus" in data: + data = data["corpus"] + + return { + "rows": cls._clean_corpus( + cast(Iterable[str] | Iterable[tuple[object, object]], data) + ) + } @model_validator(mode="after") def _build_indexes(self) -> "CleanCorpus": - self.rows = self._clean_corpus( - cast(Iterable[str] | Iterable[tuple[object, object]], self.corpus) - ) return self._populate_indexes() def _populate_indexes(self) -> "CleanCorpus": @@ -102,10 +108,6 @@ def from_persisted_rows( raise ValueError("corpus is empty after filtering") corpus = cls.model_construct( - corpus=[ - (search_text, display_text) - for _, search_text, display_text in restored_rows - ], rows=restored_rows, id_to_search={}, id_to_display={}, diff --git a/src/industrial_classification_utils/sayt/retriever_specs.py b/src/industrial_classification_utils/sayt/retriever_specs.py index c144649..e6288dc 100644 --- a/src/industrial_classification_utils/sayt/retriever_specs.py +++ b/src/industrial_classification_utils/sayt/retriever_specs.py @@ -4,9 +4,16 @@ import math from dataclasses import dataclass, field -from typing import Protocol +from pathlib import Path +from typing import Protocol, runtime_checkable from .core import CleanCorpus, Suggestion +from .indexes import ( + build_ngram_index, + build_semantic_index, + load_ngram_index, + load_semantic_index, +) from .retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever _MIN_NGRAM_SIZE = 2 @@ -41,23 +48,58 @@ def name(self) -> str: def weight(self) -> float: """Return the finite positive weight applied during score combination.""" - def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: + def build( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None = None, + overwrite: bool = True, + ) -> Retriever: """Build a corpus-bound retriever instance from this configuration. Args: corpus: Cleaned corpus to bind to the retriever. min_chars: Minimum query length required before retrieval runs. + filespace_path: Optional persisted filespace directory used when a + caller wants the build outputs written to disk. + overwrite: Whether an existing filespace may be replaced when + ``filespace_path`` is provided. Returns: A configured retriever instance bound to ``corpus``. """ +@runtime_checkable +class ArtifactRetrieverSpec(RetrieverSpec, Protocol): + """Optional artifact-loading capability for persisted retriever specs.""" + + def load_from_artifact( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None, + ) -> Retriever: + """Restore a runtime retriever from persisted artifact state.""" + + def _validate_retriever_weight(weight: float) -> None: if not math.isfinite(weight) or weight <= 0: raise ValueError("retriever weight must be a finite value > 0") +def _require_filespace_path( + filespace_path: str | Path | None, + *, + spec_name: str, +) -> str | Path: + if filespace_path is None: + raise ValueError(f"Retriever '{spec_name}' does not have a stored filespace") + return filespace_path + + @dataclass(frozen=True, slots=True) class PrefixRetrieverSpec: """Configuration for building a prefix retriever.""" @@ -69,18 +111,40 @@ def __post_init__(self) -> None: """Validate configuration values after dataclass initialisation.""" _validate_retriever_weight(self.weight) - def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: + def build( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None = None, + overwrite: bool = True, + ) -> Retriever: """Build a prefix retriever for the provided cleaned corpus. Args: corpus: Cleaned corpus to search. min_chars: Minimum query length required before retrieval runs. + filespace_path: Optional persisted filespace path. Prefix + retrievers ignore this because they do not store dense state. + overwrite: Ignored for prefix retrievers. Returns: A configured ``PrefixRetriever``. """ + _ = (filespace_path, overwrite) return PrefixRetriever(corpus, min_chars=min_chars) + def load_from_artifact( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None, + ) -> Retriever: + """Restore a prefix retriever directly from its runtime config.""" + _ = filespace_path + return self.build(corpus, min_chars=min_chars) + @dataclass(frozen=True, slots=True) class NgramRetrieverSpec: @@ -99,12 +163,23 @@ def __post_init__(self) -> None: if not 0.0 < self.max_df <= 1.0: raise ValueError("ngram max_df must be in (0, 1]") - def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: + def build( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None = None, + overwrite: bool = True, + ) -> Retriever: """Build a character n-gram retriever for the provided corpus. Args: corpus: Cleaned corpus to search. min_chars: Minimum query length required before retrieval runs. + filespace_path: Optional filespace directory. When provided, the + built dense index is persisted there. + overwrite: Whether an existing filespace may be replaced when + ``filespace_path`` is provided. Returns: A configured ``NgramRetriever``. @@ -115,11 +190,45 @@ def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: """ if self.max_df * corpus.size < 1: raise ValueError("ngram max_df is too low for the given corpus") - return NgramRetriever( + if filespace_path is None: + return NgramRetriever( + corpus, + n=self.n, + max_df=self.max_df, + min_chars=min_chars, + ) + + index = build_ngram_index( + corpus, + n=self.n, + max_df=self.max_df, + output_dir=_require_filespace_path(filespace_path, spec_name=self.name), + overwrite=overwrite, + ) + return NgramRetriever.from_index( + corpus, + min_chars=min_chars, + index=index, + ) + + def load_from_artifact( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None, + ) -> Retriever: + """Restore an n-gram retriever from its persisted dense filespace.""" + index = load_ngram_index( corpus, n=self.n, max_df=self.max_df, + folder_path=_require_filespace_path(filespace_path, spec_name=self.name), + ) + return NgramRetriever.from_index( + corpus, min_chars=min_chars, + index=index, ) @@ -137,17 +246,60 @@ def __post_init__(self) -> None: if not isinstance(self.model, str) or not self.model.strip(): raise ValueError("semantic model must be a non-empty string") - def build(self, corpus: CleanCorpus, *, min_chars: int) -> Retriever: + def build( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None = None, + overwrite: bool = True, + ) -> Retriever: """Build a semantic retriever for the provided cleaned corpus. Args: corpus: Cleaned corpus to search. min_chars: Minimum query length required before retrieval runs. + filespace_path: Optional filespace directory. When provided, the + built dense index is persisted there. + overwrite: Whether an existing filespace may be replaced when + ``filespace_path`` is provided. Returns: A configured ``SemanticRetriever``. """ - return SemanticRetriever(corpus, model=self.model, min_chars=min_chars) + if filespace_path is None: + return SemanticRetriever(corpus, model=self.model, min_chars=min_chars) + + index = build_semantic_index( + corpus, + model=self.model, + output_dir=_require_filespace_path(filespace_path, spec_name=self.name), + overwrite=overwrite, + ) + return SemanticRetriever.from_index( + corpus, + min_chars=min_chars, + index=index, + ) + + def load_from_artifact( + self, + corpus: CleanCorpus, + *, + min_chars: int, + filespace_path: str | Path | None, + ) -> Retriever: + """Restore a semantic retriever from its persisted dense filespace.""" + index = load_semantic_index( + corpus, + model=self.model, + folder_path=_require_filespace_path(filespace_path, spec_name=self.name), + ) + return SemanticRetriever.from_index( + corpus, + min_chars=min_chars, + index=index, + ) def default_retriever_specs() -> list[RetrieverSpec]: diff --git a/src/industrial_classification_utils/sayt/storage.py b/src/industrial_classification_utils/sayt/storage.py index fa59605..37a2377 100644 --- a/src/industrial_classification_utils/sayt/storage.py +++ b/src/industrial_classification_utils/sayt/storage.py @@ -4,7 +4,7 @@ import json import os import shutil -from dataclasses import dataclass +from dataclasses import dataclass, fields, is_dataclass from pathlib import Path import pandas as pd @@ -15,36 +15,27 @@ validate_max_suggestions, validate_min_chars, ) -from .indexes import ( - build_ngram_index, - build_semantic_index, - load_ngram_index, - load_semantic_index, -) from .retriever_specs import ( + ArtifactRetrieverSpec, NgramRetrieverSpec, PrefixRetrieverSpec, Retriever, RetrieverSpec, SemanticRetrieverSpec, ) -from .retrievers import NgramRetriever, SemanticRetriever SAYT_ARTIFACT_TYPE = "sayt" SAYT_ARTIFACT_VERSION = 2 MANIFEST_FILE_NAME = "manifest.json" CORPUS_FILE_NAME = "corpus.csv" _ARTIFACT_CORPUS_FIELDS = ["row_id", "search_text", "display_text"] -_RETRIEVERS_DIR_NAME = "retrievers" @dataclass(frozen=True, slots=True) class StoredRetrieverSpec: """Persisted retriever spec plus its optional filespace path.""" - artifact_type: str - spec: RetrieverSpec - config: dict[str, object] + spec: ArtifactRetrieverSpec path: str | None = None @@ -218,38 +209,31 @@ def retriever_filespace_path( return Path(artifact_dir) / stored_retriever.path +def _require_artifact_spec(spec: RetrieverSpec) -> ArtifactRetrieverSpec: + if not isinstance(spec, ArtifactRetrieverSpec): + raise TypeError( + "Only artifact-aware retriever specs can be persisted; " + f"got {type(spec).__name__}" + ) + return spec + + def build_retriever_artifact( *, corpus: CleanCorpus, + min_chars: int, stored_retriever: StoredRetrieverSpec, artifact_dir: str | Path, ) -> None: """Persist built-in retriever assets required by a SAYT artifact.""" - spec = stored_retriever.spec - if isinstance(spec, PrefixRetrieverSpec): - return - - if isinstance(spec, NgramRetrieverSpec): - build_ngram_index( - corpus, - n=spec.n, - max_df=spec.max_df, - output_dir=retriever_filespace_path(artifact_dir, stored_retriever), - overwrite=True, - ) - return - - if isinstance(spec, SemanticRetrieverSpec): - build_semantic_index( - corpus, - model=spec.model, - output_dir=retriever_filespace_path(artifact_dir, stored_retriever), - overwrite=True, - ) + if stored_retriever.path is None: return - raise TypeError( - "Only built-in retriever specs can be persisted; " f"got {type(spec).__name__}" + stored_retriever.spec.build( + corpus, + min_chars=min_chars, + filespace_path=retriever_filespace_path(artifact_dir, stored_retriever), + overwrite=True, ) @@ -261,38 +245,14 @@ def load_retriever_from_artifact( artifact_dir: str | Path, ) -> Retriever: """Restore a built-in runtime retriever from persisted artifact state.""" - spec = stored_retriever.spec - if isinstance(spec, PrefixRetrieverSpec): - return spec.build(corpus, min_chars=min_chars) - - if isinstance(spec, NgramRetrieverSpec): - index = load_ngram_index( - corpus, - n=spec.n, - max_df=spec.max_df, - folder_path=retriever_filespace_path(artifact_dir, stored_retriever), - ) - return NgramRetriever.from_index( - corpus, - min_chars=min_chars, - index=index, - ) - - if isinstance(spec, SemanticRetrieverSpec): - index = load_semantic_index( - corpus, - model=spec.model, - folder_path=retriever_filespace_path(artifact_dir, stored_retriever), - ) - return SemanticRetriever.from_index( - corpus, - min_chars=min_chars, - index=index, - ) - - raise TypeError( - "Only built-in retriever specs can be restored from artifacts; " - f"got {type(spec).__name__}" + return stored_retriever.spec.load_from_artifact( + corpus, + min_chars=min_chars, + filespace_path=( + retriever_filespace_path(artifact_dir, stored_retriever) + if stored_retriever.path is not None + else None + ), ) @@ -300,32 +260,15 @@ def _build_stored_retriever( index: int, spec: RetrieverSpec, ) -> StoredRetrieverSpec: - if isinstance(spec, PrefixRetrieverSpec): - return StoredRetrieverSpec( - artifact_type="prefix", - spec=spec, - config={}, - path=None, - ) - - if isinstance(spec, NgramRetrieverSpec): - return StoredRetrieverSpec( - artifact_type="ngram", - spec=spec, - config={"n": spec.n, "max_df": spec.max_df}, - path=f"{_RETRIEVERS_DIR_NAME}/{index:02d}-{spec.name}", - ) - - if isinstance(spec, SemanticRetrieverSpec): - return StoredRetrieverSpec( - artifact_type="semantic", - spec=spec, - config={"model": spec.model}, - path=f"{_RETRIEVERS_DIR_NAME}/{index:02d}-{spec.name}", - ) + artifact_spec = _require_artifact_spec(spec) - raise TypeError( - "Only built-in retriever specs can be persisted; " f"got {type(spec).__name__}" + return StoredRetrieverSpec( + spec=artifact_spec, + path=( + None + if isinstance(artifact_spec, PrefixRetrieverSpec) + else f"retrievers/{index:02d}-{artifact_spec.name}" + ), ) @@ -347,11 +290,31 @@ def _serialise_manifest(manifest: SaytArtifactManifest) -> dict[str, object]: def _serialise_stored_retriever( stored_retriever: StoredRetrieverSpec, ) -> dict[str, object]: + spec = stored_retriever.spec + + if is_dataclass(spec): + config: dict[str, object] = { + field.name: getattr(spec, field.name) + for field in fields(spec) + if field.name not in {"name", "weight"} + } + else: + raw_config = getattr(spec, "__dict__", None) + config = ( + { + str(key): value + for key, value in raw_config.items() + if key not in {"name", "weight"} + } + if isinstance(raw_config, dict) + else {} + ) + return { - "type": stored_retriever.artifact_type, - "weight": stored_retriever.spec.weight, + "type": spec.name, + "weight": spec.weight, "path": stored_retriever.path, - "config": stored_retriever.config, + "config": config, } @@ -380,10 +343,7 @@ def _deserialise_stored_retriever(payload: dict[str, object]) -> StoredRetriever raise ValueError(f"Unsupported stored retriever type: {retriever_type}") return StoredRetrieverSpec( - artifact_type=retriever_type, - spec=spec, - config=dict(config), - path=str(path) if isinstance(path, str) else None, + spec=spec, path=str(path) if isinstance(path, str) else None ) diff --git a/src/industrial_classification_utils/sayt/suggester.py b/src/industrial_classification_utils/sayt/suggester.py index e688e32..f3ae85f 100644 --- a/src/industrial_classification_utils/sayt/suggester.py +++ b/src/industrial_classification_utils/sayt/suggester.py @@ -462,22 +462,6 @@ def _normalised_retriever_specs( return [(spec, weight / total_weight) for spec, weight in validated_specs] -def _restore_retriever_from_artifact( - *, - corpus: CleanCorpus, - min_chars: int, - stored_retriever: StoredRetrieverSpec, - artifact_dir: Path, -) -> Retriever: - """Restore a runtime retriever from persisted artifact state.""" - return load_retriever_from_artifact( - corpus=corpus, - min_chars=min_chars, - stored_retriever=stored_retriever, - artifact_dir=artifact_dir, - ) - - def _load_retrievers_from_artifact( *, corpus: CleanCorpus, @@ -493,7 +477,7 @@ def _load_retrievers_from_artifact( _ConfiguredRetriever( name=stored_retriever.spec.name, weight=weight, - retriever=_restore_retriever_from_artifact( + retriever=load_retriever_from_artifact( corpus=corpus, min_chars=min_chars, stored_retriever=stored_retriever, @@ -548,14 +532,10 @@ def _build_retriever_summary( artifact_provenance = None config = _summarise_retriever_config(spec) if stored_retriever is not None: - artifact_config = { - str(key): _jsonable_value(value) - for key, value in stored_retriever.config.items() - } artifact_provenance = SaytRetrieverArtifactProvenance( - artifact_type=stored_retriever.artifact_type, + artifact_type=stored_retriever.spec.name, path=stored_retriever.path, - config=artifact_config, + config=_summarise_retriever_config(stored_retriever.spec), ) return SaytRetrieverSummary( diff --git a/tests/sayt/test_sayt_builder.py b/tests/sayt/test_sayt_builder.py index 8bc5eb9..09ed854 100644 --- a/tests/sayt/test_sayt_builder.py +++ b/tests/sayt/test_sayt_builder.py @@ -4,6 +4,7 @@ import csv import json +from dataclasses import dataclass from pathlib import Path import pandas as pd @@ -28,6 +29,30 @@ def build(self, corpus, *, min_chars): _ = (corpus, min_chars) +@dataclass(frozen=True) +class _FailingArtifactRetrieverSpec: + name: str = "failing" + weight: float = 1.0 + + def build( + self, + corpus, + *, + min_chars, + filespace_path=None, + overwrite=True, + ): + _ = (corpus, min_chars, overwrite) + if filespace_path is not None: + Path(filespace_path).mkdir(parents=True, exist_ok=True) + Path(filespace_path, "partial.txt").write_text("partial", encoding="utf-8") + raise RuntimeError("boom") + + def load_from_artifact(self, corpus, *, min_chars, filespace_path): + _ = (corpus, min_chars, filespace_path) + raise NotImplementedError + + def test_builder_from_csv_loads_columns_and_persists_artifact(tmp_path, small_corpus): """Build an artifact from CSV input using the configured search/display columns.""" csv_path = tmp_path / "responses.csv" @@ -136,7 +161,7 @@ def __init__( # noqa: PLR0913 manifest = json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) filespace_path = artifact_dir / manifest["retrievers"][0]["path"] - assert captured["output_dir"] == str(filespace_path) + assert Path(captured["output_dir"]).name == filespace_path.name assert (filespace_path / "metadata.json").exists() assert (filespace_path / "vectors.parquet").exists() @@ -222,6 +247,61 @@ def test_builder_rejects_custom_runtime_only_retriever_specs(tmp_path, small_cor with pytest.raises( TypeError, - match="Only built-in retriever specs can be persisted; got _CustomRetrieverSpec", + match="Only artifact-aware retriever specs can be persisted; got _CustomRetrieverSpec", ): builder.build_artifact(tmp_path / "artifact") + + +def test_builder_cleans_up_staged_artifact_when_later_retriever_fails( + tmp_path, small_corpus +): + """A failed staged build should not leave a partial artifact behind.""" + artifact_dir = tmp_path / "artifact" + + builder = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec(), _FailingArtifactRetrieverSpec()], + min_chars=3, + ) + + with pytest.raises(RuntimeError, match="boom"): + builder.build_artifact(artifact_dir) + + assert not artifact_dir.exists() + assert not list(tmp_path.glob(".artifact.tmp-*")) + + +def test_builder_preserves_existing_artifact_when_staged_overwrite_fails( + tmp_path, small_corpus +): + """A failed overwrite should leave the previous complete artifact intact.""" + artifact_dir = tmp_path / "artifact" + original_builder = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec()], + min_chars=3, + max_suggestions=5, + ) + original_builder.build_artifact(artifact_dir) + + original_manifest = json.loads( + (artifact_dir / "manifest.json").read_text(encoding="utf-8") + ) + original_corpus = (artifact_dir / "corpus.csv").read_text(encoding="utf-8") + + failing_builder = SAYTBuilder( + small_corpus, + retrievers=[PrefixRetrieverSpec(), _FailingArtifactRetrieverSpec()], + min_chars=3, + max_suggestions=7, + ) + + with pytest.raises(RuntimeError, match="boom"): + failing_builder.build_artifact(artifact_dir, overwrite=True) + + assert json.loads((artifact_dir / "manifest.json").read_text(encoding="utf-8")) == ( + original_manifest + ) + assert (artifact_dir / "corpus.csv").read_text(encoding="utf-8") == original_corpus + assert not list(tmp_path.glob(".artifact.tmp-*")) + assert not list(tmp_path.glob(".artifact.bak-*")) diff --git a/tests/sayt/test_sayt_storage.py b/tests/sayt/test_sayt_storage.py index 390c257..091f5bf 100644 --- a/tests/sayt/test_sayt_storage.py +++ b/tests/sayt/test_sayt_storage.py @@ -9,6 +9,7 @@ from industrial_classification_utils.sayt import ( PrefixRetrieverSpec, SemanticRetrieverSpec, + retriever_specs, storage, ) from industrial_classification_utils.sayt.core import CleanCorpus @@ -84,9 +85,7 @@ def test_read_artifact_inputs_validate_missing_and_malformed_state(tmp_path): def test_storage_helper_validation_errors(): """Guard helper APIs against invalid types, paths, and unsupported specs.""" stored_retriever = storage.StoredRetrieverSpec( - artifact_type="prefix", spec=PrefixRetrieverSpec(), - config={}, path=None, ) @@ -112,7 +111,7 @@ def build(self, corpus, *, min_chars): with pytest.raises( TypeError, - match="Only built-in retriever specs can be persisted; got _UnknownSpec", + match="Only artifact-aware retriever specs can be persisted; got _UnknownSpec", ): storage._build_stored_retriever(0, _UnknownSpec()) @@ -163,28 +162,31 @@ def from_index(cls, corpus_arg, *, min_chars, index): } return {"index": index, "min_chars": min_chars} - monkeypatch.setattr(storage, "build_semantic_index", _fake_build_semantic_index) - monkeypatch.setattr(storage, "load_semantic_index", _fake_load_semantic_index) - monkeypatch.setattr(storage, "SemanticRetriever", _StubSemanticRetriever) + monkeypatch.setattr( + retriever_specs, "build_semantic_index", _fake_build_semantic_index + ) + monkeypatch.setattr( + retriever_specs, "load_semantic_index", _fake_load_semantic_index + ) + monkeypatch.setattr(retriever_specs, "SemanticRetriever", _StubSemanticRetriever) rebuilt = storage._deserialise_stored_retriever( { - "type": stored_retriever.artifact_type, + "type": stored_retriever.spec.name, "weight": spec.weight, "path": stored_retriever.path, - "config": stored_retriever.config, + "config": {"model": "all-MiniLM-L6-v2"}, } ) - assert stored_retriever.artifact_type == "semantic" - assert stored_retriever.config == {"model": "all-MiniLM-L6-v2"} + assert stored_retriever.spec.name == "semantic" assert stored_retriever.path == "retrievers/02-semantic" assert isinstance(rebuilt.spec, SemanticRetrieverSpec) assert rebuilt.spec.weight == pytest.approx(2.5) - assert rebuilt.config == {"model": "all-MiniLM-L6-v2"} storage.build_retriever_artifact( corpus=corpus, + min_chars=3, stored_retriever=stored_retriever, artifact_dir=tmp_path, ) From 9a74c8beca6d73c2a3ee063b3886a8ad3908512b Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Tue, 16 Jun 2026 15:47:29 +0100 Subject: [PATCH 10/11] refactor: use PrivateAttrs for derived fields in CleanCorpus --- .../sayt/core.py | 50 +++++++++++-------- .../sayt/suggester.py | 4 +- tests/sayt/test_sayt.py | 22 ++++++-- 3 files changed, 50 insertions(+), 26 deletions(-) diff --git a/src/industrial_classification_utils/sayt/core.py b/src/industrial_classification_utils/sayt/core.py index cec3ae1..7be4855 100644 --- a/src/industrial_classification_utils/sayt/core.py +++ b/src/industrial_classification_utils/sayt/core.py @@ -10,7 +10,7 @@ from uuid import NAMESPACE_URL, uuid5 import pandas as pd -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator _WS_RE = re.compile(r"\s+") _NON_ALNUM_SPACE_RE = re.compile(r"[^a-z ]+") @@ -47,10 +47,10 @@ class CleanCorpus(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) rows: list[tuple[str, str, str]] = Field(default_factory=list) - id_to_search: dict[str, str] = Field(default_factory=dict) - id_to_display: dict[str, str] = Field(default_factory=dict) - display_text_count: dict[str, int] = Field(default_factory=dict) size: int = 0 + _id_to_search: dict[str, str] = PrivateAttr(default_factory=dict) + _id_to_display: dict[str, str] = PrivateAttr(default_factory=dict) + _display_text_count: dict[str, int] = PrivateAttr(default_factory=dict) @model_validator(mode="before") @classmethod @@ -69,18 +69,35 @@ def _coerce_input(cls, data: object) -> object: ) } - @model_validator(mode="after") - def _build_indexes(self) -> "CleanCorpus": - return self._populate_indexes() + @property + def id_to_search(self) -> dict[str, str]: + """Return the row-id to normalised-search-text lookup.""" + return self._id_to_search + + @property + def id_to_display(self) -> dict[str, str]: + """Return the row-id to display-text lookup.""" + return self._id_to_display + + @property + def display_text_count(self) -> dict[str, int]: + """Return per-display-text occurrence counts for the cleaned corpus.""" + return self._display_text_count + + # Pylint does not understand Pydantic's model_post_init signature here. + def model_post_init( # pylint: disable=arguments-differ + self, __context: Any + ) -> None: + self._populate_indexes() def _populate_indexes(self) -> "CleanCorpus": """Rebuild lookup tables from the current cleaned rows.""" - self.id_to_search = {rid: search for rid, search, _ in self.rows} - self.id_to_display = {rid: display for rid, _, display in self.rows} - self.display_text_count = {} + self._id_to_search = {rid: search for rid, search, _ in self.rows} + self._id_to_display = {rid: display for rid, _, display in self.rows} + self._display_text_count = {} for _, _, display in self.rows: - self.display_text_count[display] = ( - self.display_text_count.get(display, 0) + 1 + self._display_text_count[display] = ( + self._display_text_count.get(display, 0) + 1 ) self.size = len(self.rows) return self @@ -107,14 +124,7 @@ def from_persisted_rows( if not restored_rows: raise ValueError("corpus is empty after filtering") - corpus = cls.model_construct( - rows=restored_rows, - id_to_search={}, - id_to_display={}, - display_text_count={}, - size=0, - ) - return corpus._populate_indexes() + return cls.model_construct(rows=restored_rows) @staticmethod def _coerce_persisted_row( diff --git a/src/industrial_classification_utils/sayt/suggester.py b/src/industrial_classification_utils/sayt/suggester.py index f3ae85f..8b10a13 100644 --- a/src/industrial_classification_utils/sayt/suggester.py +++ b/src/industrial_classification_utils/sayt/suggester.py @@ -133,7 +133,9 @@ def __init__( default_retriever_specs() if retrievers is None else retrievers ) self._retrievers = self._build_retrievers(self._retriever_specs) - self._max_duplication = max(self._corpus.display_text_count.values(), default=0) + self._max_duplication = max( + self._corpus._display_text_count.values(), default=0 + ) self._stored_retrievers: tuple[StoredRetrieverSpec, ...] | None = None self._artifact_provenance: SaytArtifactProvenance | None = None logger.info( diff --git a/tests/sayt/test_sayt.py b/tests/sayt/test_sayt.py index f93bad5..ade05ad 100644 --- a/tests/sayt/test_sayt.py +++ b/tests/sayt/test_sayt.py @@ -3,6 +3,7 @@ # pylint: disable=protected-access,redefined-outer-name,too-few-public-methods,C0116,W0613 import json +from collections import Counter from dataclasses import dataclass from pathlib import Path from uuid import UUID @@ -58,6 +59,18 @@ def test_clean_corpus_accepts_existing_instance_and_dict_input(small_corpus): assert dict_corpus.rows == corpus.rows +def test_clean_corpus_model_dump_excludes_derived_lookup_dicts(small_corpus): + """Keep derived lookup dictionaries out of the public model fields.""" + corpus = CleanCorpus.model_validate(small_corpus) + + dumped = corpus.model_dump() + + assert "id_to_search" not in dumped + assert "id_to_display" not in dumped + assert "display_text_count" not in dumped + assert dumped["rows"] == corpus.rows + + def test_clean_corpus_restores_persisted_rows(small_corpus): """Restore cleaned corpus rows without regenerating row identifiers.""" corpus = CleanCorpus.model_validate(small_corpus) @@ -67,9 +80,7 @@ def test_clean_corpus_restores_persisted_rows(small_corpus): ) assert restored.rows == corpus.rows - assert restored.id_to_search == corpus.id_to_search - assert restored.id_to_display == corpus.id_to_display - assert restored.display_text_count == corpus.display_text_count + assert restored.model_dump() == corpus.model_dump() def test_clean_corpus_rejects_non_iterable_input(): @@ -221,6 +232,7 @@ def test_get_config_returns_rich_runtime_summary(small_corpus): ) config = suggester.get_config() + display_counts = Counter(display for _, _, display in suggester._corpus.rows) assert isinstance(config, SaytConfiguration) assert config.settings.model_dump() == { @@ -229,8 +241,8 @@ def test_get_config_returns_rich_runtime_summary(small_corpus): } assert config.corpus.model_dump() == { "size": suggester._corpus.size, - "unique_display_texts": len(suggester._corpus.display_text_count), - "max_duplication": suggester._max_duplication, + "unique_display_texts": len(display_counts), + "max_duplication": max(display_counts.values(), default=0), } assert [retriever.name for retriever in config.retrievers] == ["prefix", "ngram"] assert config.retrievers[0].config == {} From aa1c2aeaa9119d27b342b6e87326993e02912794 Mon Sep 17 00:00:00 2001 From: Thomas Owen Date: Tue, 16 Jun 2026 16:34:36 +0100 Subject: [PATCH 11/11] refactor: use a shared base class for duplicate init and from_csv --- .pre-commit-config.yaml | 2 +- .../sayt/_base.py | 59 +++++++++++++++++ .../sayt/builder.py | 52 +-------------- .../sayt/suggester.py | 66 +++---------------- tests/sayt/test_sayt.py | 2 +- 5 files changed, 71 insertions(+), 110 deletions(-) create mode 100644 src/industrial_classification_utils/sayt/_base.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d38ad55..c4f349e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ default_stages: [pre-commit] repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v6.0.0 hooks: - id: check-ast - id: check-vcs-permalinks diff --git a/src/industrial_classification_utils/sayt/_base.py b/src/industrial_classification_utils/sayt/_base.py new file mode 100644 index 0000000..cbe181b --- /dev/null +++ b/src/industrial_classification_utils/sayt/_base.py @@ -0,0 +1,59 @@ +"""Shared bootstrap helpers for corpus-bound SAYT classes.""" + +import os +from collections.abc import Iterable, Sequence + +from .core import CleanCorpus, validate_max_suggestions, validate_min_chars +from .retriever_specs import RetrieverSpec, default_retriever_specs +from .storage import load_corpus_from_csv + + +class BaseCorpusBound: # pylint: disable=too-few-public-methods + """Shared corpus/retriever bootstrap for SAYT runtime classes.""" + + _corpus: CleanCorpus + _min_chars: int + _max_suggestions: int + _retriever_specs: tuple[RetrieverSpec, ...] + + def __init__( + self, + corpus: Iterable[tuple[object, object]] | Iterable[str], + *, + retrievers: Sequence[RetrieverSpec] | None = None, + min_chars: int = 4, + max_suggestions: int = 10, + ) -> None: + """Validate and store the shared corpus-bound SAYT configuration.""" + self._corpus = CleanCorpus.model_validate(corpus) + self._min_chars = validate_min_chars(min_chars) + self._max_suggestions = validate_max_suggestions(max_suggestions) + self._retriever_specs = tuple( + default_retriever_specs() if retrievers is None else retrievers + ) + + @classmethod + def from_csv[ # pylint: disable=too-many-arguments # noqa: PLR0913 + CorpusBoundT: "BaseCorpusBound" + ]( + cls: type[CorpusBoundT], + file_path: str | os.PathLike, + *, + search_text_col: str = "title", + display_text_col: str | None = None, + retrievers: Sequence[RetrieverSpec] | None = None, + min_chars: int = 4, + max_suggestions: int = 10, + ) -> CorpusBoundT: + """Build a corpus-bound SAYT object from CSV input.""" + corpus_rows = load_corpus_from_csv( + file_path, + search_text_col=search_text_col, + display_text_col=display_text_col, + ) + return cls( + corpus_rows, + retrievers=retrievers, + min_chars=min_chars, + max_suggestions=max_suggestions, + ) diff --git a/src/industrial_classification_utils/sayt/builder.py b/src/industrial_classification_utils/sayt/builder.py index a28304e..41d2c54 100644 --- a/src/industrial_classification_utils/sayt/builder.py +++ b/src/industrial_classification_utils/sayt/builder.py @@ -1,23 +1,15 @@ """Offline artifact builder for persisted SAYT runtime assets.""" -# pylint: disable=duplicate-code - import os import shutil import tempfile -from collections.abc import Iterable, Sequence from pathlib import Path from uuid import uuid4 -from .core import CleanCorpus, validate_max_suggestions, validate_min_chars -from .retriever_specs import ( - RetrieverSpec, - default_retriever_specs, -) +from ._base import BaseCorpusBound from .storage import ( build_artifact_manifest, build_retriever_artifact, - load_corpus_from_csv, write_artifact_corpus, write_artifact_manifest, ) @@ -32,49 +24,9 @@ def _remove_path(path: Path) -> None: path.unlink() -class SAYTBuilder: +class SAYTBuilder(BaseCorpusBound): """Build a persisted SAYT artifact for later runtime loading.""" - def __init__( - self, - corpus: Iterable[tuple[object, object]] | Iterable[str], - *, - retrievers: Sequence[RetrieverSpec] | None = None, - min_chars: int = 4, - max_suggestions: int = 10, - ) -> None: - """Initialise an artifact builder from raw corpus input.""" - self._corpus = CleanCorpus.model_validate(corpus) - self._min_chars = validate_min_chars(min_chars) - self._max_suggestions = validate_max_suggestions(max_suggestions) - self._retriever_specs = tuple( - default_retriever_specs() if retrievers is None else retrievers - ) - - @classmethod - def from_csv( # pylint: disable=too-many-arguments # noqa: PLR0913 - cls, - file_path: str | os.PathLike, - *, - search_text_col: str = "title", - display_text_col: str | None = None, - retrievers: Sequence[RetrieverSpec] | None = None, - min_chars: int = 4, - max_suggestions: int = 10, - ) -> "SAYTBuilder": - """Initialise an artifact builder from CSV input.""" - corpus_rows = load_corpus_from_csv( - file_path, - search_text_col=search_text_col, - display_text_col=display_text_col, - ) - return cls( - corpus_rows, - retrievers=retrievers, - min_chars=min_chars, - max_suggestions=max_suggestions, - ) - def build_artifact( self, output_dir: str | os.PathLike, diff --git a/src/industrial_classification_utils/sayt/suggester.py b/src/industrial_classification_utils/sayt/suggester.py index 8b10a13..80a424f 100644 --- a/src/industrial_classification_utils/sayt/suggester.py +++ b/src/industrial_classification_utils/sayt/suggester.py @@ -4,8 +4,6 @@ retrievers and combines their scores into ranked suggestions. """ -# pylint: disable=duplicate-code - import math import os from collections.abc import Iterable, Mapping, Sequence @@ -15,6 +13,7 @@ from survey_assist_utils.logging import get_logger +from ._base import BaseCorpusBound from .core import ( CleanCorpus, SaytArtifactProvenance, @@ -26,19 +25,15 @@ Suggestion, _normalise, take_with_ties, - validate_max_suggestions, - validate_min_chars, ) from .retriever_specs import ( Retriever, RetrieverSpec, - default_retriever_specs, ) from .storage import ( SAYT_ARTIFACT_TYPE, SAYT_ARTIFACT_VERSION, StoredRetrieverSpec, - load_corpus_from_csv, load_retriever_from_artifact, read_artifact_corpus, read_artifact_manifest, @@ -56,7 +51,7 @@ class _ConfiguredRetriever: retriever: Retriever -class SAYTSuggester: # pylint: disable=too-many-instance-attributes +class SAYTSuggester(BaseCorpusBound): # pylint: disable=too-many-instance-attributes """Suggest free-text responses as a user types. The suggester: @@ -125,17 +120,14 @@ def __init__( max_suggestions: Default maximum number of ranked suggestions to return. """ - self._corpus = CleanCorpus.model_validate(corpus) - self._min_chars = validate_min_chars(min_chars) - self._max_suggestions = validate_max_suggestions(max_suggestions) - - self._retriever_specs = tuple( - default_retriever_specs() if retrievers is None else retrievers + super().__init__( + corpus, + retrievers=retrievers, + min_chars=min_chars, + max_suggestions=max_suggestions, ) self._retrievers = self._build_retrievers(self._retriever_specs) - self._max_duplication = max( - self._corpus._display_text_count.values(), default=0 - ) + self._max_duplication = max(self._corpus.display_text_count.values(), default=0) self._stored_retrievers: tuple[StoredRetrieverSpec, ...] | None = None self._artifact_provenance: SaytArtifactProvenance | None = None logger.info( @@ -173,48 +165,6 @@ def _from_state( # pylint: disable=too-many-arguments # noqa: PLR0913 ) return suggester - @classmethod - def from_csv( # pylint: disable=too-many-arguments # noqa: PLR0913 - cls, - file_path: str | os.PathLike, - *, - search_text_col: str = "title", - display_text_col: str | None = None, - retrievers: Sequence[RetrieverSpec] | None = None, - min_chars: int = 4, - max_suggestions: int = 10, - ) -> "SAYTSuggester": - """Build a suggester from CSV input. - - Args: - file_path: Path to the CSV file containing suggestion rows. - search_text_col: Column containing the searchable text. - display_text_col: Optional column containing display text. When - omitted, the search column is reused for display values. - retrievers: Optional retriever specifications. When omitted, the - standard retriever set is used. - min_chars: Minimum query length before retrieval runs. - max_suggestions: Default maximum number of ranked suggestions to - return. - - Returns: - A configured ``SAYTSuggester`` instance. - - Raises: - ValueError: If the requested search or display column is missing. - """ - corpus_rows = load_corpus_from_csv( - file_path, - search_text_col=search_text_col, - display_text_col=display_text_col, - ) - return cls( - corpus_rows, - retrievers=retrievers, - min_chars=min_chars, - max_suggestions=max_suggestions, - ) - @classmethod def from_artifact(cls, artifact_dir: str | os.PathLike) -> "SAYTSuggester": """Load a suggester from a persisted SAYT artifact directory.""" diff --git a/tests/sayt/test_sayt.py b/tests/sayt/test_sayt.py index ade05ad..4eb6f73 100644 --- a/tests/sayt/test_sayt.py +++ b/tests/sayt/test_sayt.py @@ -571,7 +571,7 @@ def build(self, corpus, *, min_chars): return _StubRetriever() monkeypatch.setattr( - "industrial_classification_utils.sayt.suggester.default_retriever_specs", + "industrial_classification_utils.sayt._base.default_retriever_specs", lambda: [ _StubRetrieverSpec(name="prefix"), _StubRetrieverSpec(name="ngram"),