Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default_stages: [pre-commit]

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v6.0.0
Comment thread
ivyONS marked this conversation as resolved.
hooks:
- id: check-ast
- id: check-vcs-permalinks
Expand Down
94 changes: 94 additions & 0 deletions demos/sayt/sayt_artifact_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Build and reload a persisted SAYT artifact with the built-in retrievers."""

# pylint: disable=duplicate-code

# %%
import json
from pathlib import Path
from tempfile import TemporaryDirectory

from industrial_classification_utils.sayt import (
NgramRetrieverSpec,
PrefixRetrieverSpec,
SAYTBuilder,
SAYTSuggester,
SemanticRetrieverSpec,
)

# %%
############# toy example to verify SAYT artifact build/load works #############
# Small hand-written corpus of text pairs. It intentionally contains repeated
# entries and inconsistent casing.
small_corpus = [
    ("Car wash", "Car Wash"),
    ("Car wash", "CAR WASH (duplicate)"),
    ("Car waxing", "Car Waxing"),
    ("Waxing car", "Car Waxing"),
    ("Carpentry services", "Carpentry services"),
    ("Dog grooming", "Dog grooming"),
    ("Cat grooming", "Cat grooming"),
    ("USed car sales", "Used car sales"),
    ("Car rental", "Car rental"),
    ("Car repair", "Car repair"),
    ("Car servicing", "Car servicing"),
]

# One spec per built-in retriever; the same list is used for the artifact
# build and for the in-memory suggester so the two can be compared.
retrievers = [PrefixRetrieverSpec(), NgramRetrieverSpec(max_df=0.8), SemanticRetrieverSpec()]

# The temporary directory is not wrapped in ``with`` so that it survives
# across notebook cells; it is cleaned up explicitly in the last cell.
# pylint: disable-next=consider-using-with
temp_dir = TemporaryDirectory(prefix="sayt_artifact_demo_")
artifact_dir = Path(temp_dir.name, "car_services_sayt")
print("artifact will be written to:", artifact_dir)

# %%
# Building the semantic retriever can be slow on the first run if the model
# cache has to be created locally.
builder = SAYTBuilder(
    small_corpus,
    retrievers=retrievers,
    min_chars=3,
    max_suggestions=5,
)
artifact_path = builder.build_artifact(artifact_dir, overwrite=True)

print("artifact saved to:", artifact_path)
print("artifact files:")
# List every regular file in the artifact, relative to its root, in sorted order.
artifact_files = sorted(entry for entry in artifact_path.rglob("*") if entry.is_file())
for artifact_file in artifact_files:
    print("-", artifact_file.relative_to(artifact_path))

# %%
# Pretty-print the manifest that was written alongside the artifact data.
manifest_text = (artifact_path / "manifest.json").read_text(encoding="utf-8")
manifest = json.loads(manifest_text)
print(json.dumps(manifest, indent=2))

# %%
# Build one suggester directly from the corpus and one from the saved
# artifact, then check that both return the same suggestions.
live_suggester = SAYTSuggester(
    small_corpus,
    retrievers=retrievers,
    min_chars=3,
    max_suggestions=5,
)
loaded_suggester = SAYTSuggester.from_artifact(artifact_path)

for query in ("car", "cars", "waxi", "grom", "wash", "duplicate", "auto"):
    live_suggestions = live_suggester.suggest(query, 5)
    loaded_suggestions = loaded_suggester.suggest(query, 5)

    print("searching for:", query)
    print("live", "->", live_suggestions)
    print("loaded", "->", loaded_suggestions)
    print("loaded_scores", "->", loaded_suggester.suggest_with_scores(query, 5))
    # Any difference means the persisted artifact did not round-trip.
    if live_suggestions != loaded_suggestions:
        raise RuntimeError("Loaded suggester results did not match live build")
    print()

# %%
# The saved files stay on disk until the cleanup cell below is run.
print("artifact ready for inspection:", artifact_path)

# %%
temp_dir.cleanup()

# %%
2 changes: 1 addition & 1 deletion demos/sayt/sayt_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
SAYTSuggester,
SemanticRetrieverSpec,
)
from industrial_classification_utils.sayt.sayt_core import _normalise
from industrial_classification_utils.sayt.core import _normalise

logger = get_logger(__name__)
# %%
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ ignore = [
"E501",
# indentation contains tabs
"W191",
# Allow the use of older TypeVar syntax for generic functions
"UP047",
]

[tool.ruff.lint.pydocstyle]
Expand Down
12 changes: 9 additions & 3 deletions src/industrial_classification_utils/sayt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
"""Public SAYT interfaces and built-in retriever components."""

from .sayt import SAYTSuggester
from .sayt_retriever_specs import (
from .builder import SAYTBuilder
from .core import SaytConfiguration
from .retriever_specs import (
ArtifactRetrieverSpec,
NgramRetrieverSpec,
PrefixRetrieverSpec,
Retriever,
RetrieverSpec,
SemanticRetrieverSpec,
default_retriever_specs,
)
from .sayt_retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever
from .retrievers import NgramRetriever, PrefixRetriever, SemanticRetriever
from .suggester import SAYTSuggester

__all__ = [
"ArtifactRetrieverSpec",
"NgramRetriever",
"NgramRetrieverSpec",
"PrefixRetriever",
"PrefixRetrieverSpec",
"Retriever",
"RetrieverSpec",
"SAYTBuilder",
"SAYTSuggester",
"SaytConfiguration",
"SemanticRetriever",
"SemanticRetrieverSpec",
"default_retriever_specs",
Expand Down
59 changes: 59 additions & 0 deletions src/industrial_classification_utils/sayt/_base.py
Comment thread
ivyONS marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Shared bootstrap helpers for corpus-bound SAYT classes."""

import os
from collections.abc import Iterable, Sequence

from .core import CleanCorpus, validate_max_suggestions, validate_min_chars
from .retriever_specs import RetrieverSpec, default_retriever_specs
from .storage import load_corpus_from_csv


class BaseCorpusBound: # pylint: disable=too-few-public-methods
"""Shared corpus/retriever bootstrap for SAYT runtime classes."""

_corpus: CleanCorpus
_min_chars: int
_max_suggestions: int
_retriever_specs: tuple[RetrieverSpec, ...]

def __init__(
self,
corpus: Iterable[tuple[object, object]] | Iterable[str],
*,
retrievers: Sequence[RetrieverSpec] | None = None,
min_chars: int = 4,
max_suggestions: int = 10,
) -> None:
"""Validate and store the shared corpus-bound SAYT configuration."""
self._corpus = CleanCorpus.model_validate(corpus)
self._min_chars = validate_min_chars(min_chars)
self._max_suggestions = validate_max_suggestions(max_suggestions)
self._retriever_specs = tuple(
default_retriever_specs() if retrievers is None else retrievers
)

@classmethod
def from_csv[ # pylint: disable=too-many-arguments # noqa: PLR0913
CorpusBoundT: "BaseCorpusBound"
](
cls: type[CorpusBoundT],
file_path: str | os.PathLike,
*,
search_text_col: str = "title",
display_text_col: str | None = None,
retrievers: Sequence[RetrieverSpec] | None = None,
min_chars: int = 4,
max_suggestions: int = 10,
) -> CorpusBoundT:
"""Build a corpus-bound SAYT object from CSV input."""
corpus_rows = load_corpus_from_csv(
file_path,
search_text_col=search_text_col,
display_text_col=display_text_col,
)
return cls(
corpus_rows,
retrievers=retrievers,
min_chars=min_chars,
max_suggestions=max_suggestions,
)
85 changes: 85 additions & 0 deletions src/industrial_classification_utils/sayt/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Offline artifact builder for persisted SAYT runtime assets."""

import os
import shutil
import tempfile
from pathlib import Path
from uuid import uuid4

from ._base import BaseCorpusBound
from .storage import (
build_artifact_manifest,
build_retriever_artifact,
write_artifact_corpus,
write_artifact_manifest,
)


def _remove_path(path: Path) -> None:
    """Delete ``path`` whether it is a file, a symlink or a directory tree.

    A path that does not exist is ignored. Symbolic links are unlinked
    themselves and never followed, so the link target is left untouched.

    Args:
        path: Filesystem entry to remove.
    """
    # Test for a symlink before anything else: ``exists()`` and ``is_dir()``
    # follow links, so a dangling link would be skipped and a link to a
    # directory would reach ``shutil.rmtree``, which refuses symlinks.
    if path.is_symlink():
        path.unlink()
    elif not path.exists():
        return
    elif path.is_dir():
        shutil.rmtree(path)
    else:
        path.unlink()


class SAYTBuilder(BaseCorpusBound):
    """Write the configured SAYT corpus and retriever data to an artifact directory."""

    def build_artifact(
        self,
        output_dir: str | os.PathLike,
        *,
        overwrite: bool = False,
    ) -> Path:
        """Build the artifact in a staging directory, then move it into place.

        Args:
            output_dir: Directory that will hold the finished artifact.
            overwrite: When true, an existing ``output_dir`` is replaced;
                otherwise its presence raises ``FileExistsError``.

        Returns:
            The artifact directory as a ``Path``.

        Raises:
            FileExistsError: If ``output_dir`` exists and ``overwrite`` is
                false.
        """
        target = Path(output_dir)
        if target.exists() and not overwrite:
            raise FileExistsError("Artifact directory already exists")

        parent = target.parent
        parent.mkdir(parents=True, exist_ok=True)
        # Stage inside the target's parent so the final step is a rename
        # within one directory.
        staging = Path(tempfile.mkdtemp(prefix=f".{target.name}.tmp-", dir=parent))

        try:
            manifest = build_artifact_manifest(
                corpus=self._corpus,
                min_chars=self._min_chars,
                max_suggestions=self._max_suggestions,
                retriever_specs=self._retriever_specs,
            )

            # Corpus first, then one retriever artifact per manifest entry,
            # and the manifest itself last.
            write_artifact_corpus(self._corpus, artifact_dir=staging)
            for stored_retriever in manifest.retrievers:
                build_retriever_artifact(
                    corpus=self._corpus,
                    min_chars=self._min_chars,
                    stored_retriever=stored_retriever,
                    artifact_dir=staging,
                )
            write_artifact_manifest(manifest, artifact_dir=staging)

            if not target.exists():
                staging.rename(target)
            else:
                # Move the old artifact aside so it can be put back if the
                # staged directory cannot be renamed into place.
                backup = parent / f".{target.name}.bak-{uuid4().hex}"
                target.rename(backup)
                try:
                    staging.rename(target)
                except Exception:
                    backup.rename(target)
                    raise
                _remove_path(backup)
        except Exception:
            # Do not leave a half-built staging directory behind.
            _remove_path(staging)
            raise

        return target
Loading
Loading