Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ make run-docs
Pytest is used for testing alongside pytest-cov for coverage testing. [/tests/conftest.py](/tests/conftest.py) defines config used by the tests.

Unit tests for the embedding functions are in [/tests/test_embedding.py](./tests/test_embedding.py).
Unit testing for utility functions is added to the [/tests/test_sic_data_access.py](./tests/test_sic_data_access.py)
SIC workbook data access lives in `sic-classification-library` (`industrial_classification.data_access.sic_data_access`); its unit tests are in that library's `tests/test_data_access.py`.

```bash
make embed-tests
Expand Down
2 changes: 1 addition & 1 deletion docs/utils.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# SIC Data Access Module

::: industrial_classification_utils.utils.sic_data_access
::: industrial_classification.data_access.sic_data_access
9 changes: 5 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "sic-classification-utils"
version = "0.1.13"
version = "0.1.14"
description = "Utility functions used for SIC classification"
readme = "README.md"
requires-python = ">=3.12,<4.0"
Expand Down Expand Up @@ -35,7 +35,7 @@ pandas = "^2.3.0"
pydantic = "^2.11.1"
scikit-learn = "^1.3.0"
scipy = "^1.15.0"
sic-classification-library = {git = "https://github.com/ONSdigital/sic-classification-library.git", tag = "v0.1.4"}
sic-classification-library = {git = "https://github.com/ONSdigital/sic-classification-library.git", tag = "v0.1.5"}
survey-assist-utils = { git = "https://github.com/ONSdigital/survey-assist-utils.git", tag = "v0.0.8" }

[tool.poetry.group.dev.dependencies]
Expand Down
19 changes: 7 additions & 12 deletions src/industrial_classification_utils/embed/sic_specific_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,10 @@
import os
import tempfile

from industrial_classification.hierarchy.sic_hierarchy import load_hierarchy
from industrial_classification.data_access.sic_data_access import load_sic_hierarchy

from industrial_classification_utils.embed.embedding import EmbeddingHandler
from industrial_classification_utils.utils.constants import get_default_config
from industrial_classification_utils.utils.sic_data_access import (
load_sic_index,
load_sic_structure,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
Expand All @@ -43,13 +39,12 @@ def load_embedding_handler_from_sic_index_files(
Returns:
An instance of EmbeddingHandler initialized with the (published) SIC index data.
"""
logger.info("Loading SIC index file: %s", sic_index_file)
sic_index_df = load_sic_index(sic_index_file)

logger.info("Loading SIC structure file: %s", sic_structure_file)
sic_df = load_sic_structure(sic_structure_file)

sic = load_hierarchy(sic_df, sic_index_df)
logger.info(
"Loading SIC hierarchy from index=%s structure=%s",
sic_index_file,
sic_structure_file,
)
sic = load_sic_hierarchy(sic_index_file, sic_structure_file)

df = sic.all_leaf_text()
df["label"] = df["code"].apply(
Expand Down
20 changes: 10 additions & 10 deletions src/industrial_classification_utils/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@
from typing import Any

import numpy as np
from industrial_classification.hierarchy.sic_hierarchy import load_hierarchy
from industrial_classification.data_access.sic_data_access import (
load_sic_hierarchy,
)
from industrial_classification.hierarchy.sic_hierarchy import SIC
from industrial_classification.meta import sic_meta
from langchain_core.output_parsers import PydanticOutputParser
from langchain_google_vertexai import ChatVertexAI
Expand Down Expand Up @@ -54,10 +57,6 @@
get_default_config,
truncate_identifier,
)
from industrial_classification_utils.utils.sic_data_access import (
load_sic_index,
load_sic_structure,
)

logger = get_logger(__name__)
config = get_default_config()
Expand Down Expand Up @@ -130,7 +129,7 @@ def __init__( # noqa: PLR0913
self.sic_prompt_openfollowup = SIC_PROMPT_OPENFOLLOWUP
self.sic_prompt_closedfollowup = SIC_PROMPT_CLOSEDFOLLOWUP
self.sic_prompt_final = SIC_PROMPT_FINAL_ASSIGNMENT
self.sic = None
self.sic: SIC | None = None
self.verbose = verbose

@lru_cache # noqa: B019
Expand Down Expand Up @@ -199,11 +198,12 @@ def _prompt_candidate(
str: A formatted string containing the code, title, and example activities.
"""
if self.sic is None:
sic_index_df = load_sic_index(config["lookups"]["sic_index"])
sic_df = load_sic_structure(config["lookups"]["sic_structure"])
self.sic = load_hierarchy(sic_df, sic_index_df)
self.sic = load_sic_hierarchy(
config["lookups"]["sic_index"],
config["lookups"]["sic_structure"],
)

item = self.sic[code] # type: ignore # MyPy false positive
item = self.sic[code]
txt = "{" + f"Code: {item.numeric_string_padded()}, Title: {item.description}"
txt += f", Example activities: {', '.join(activities)}"
if include_all:
Expand Down
4 changes: 1 addition & 3 deletions src/industrial_classification_utils/llm/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

# pylint: disable=invalid-name # Need to clean up the code to remove this

from industrial_classification.data_access.sic_data_access import load_sic_index
from langchain.prompts.prompt import PromptTemplate
from langchain_core.output_parsers import PydanticOutputParser

Expand All @@ -40,9 +41,6 @@
UnambiguousResponse,
)
from industrial_classification_utils.utils.constants import get_default_config
from industrial_classification_utils.utils.sic_data_access import (
load_sic_index,
)

config = get_default_config()

Expand Down
108 changes: 0 additions & 108 deletions src/industrial_classification_utils/utils/sic_data_access.py

This file was deleted.

28 changes: 6 additions & 22 deletions tests/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,17 +532,9 @@ def test_load_embedding_handler_from_sic_index_files_builds(tmp_path: Path):

with (
patch(
"industrial_classification_utils.embed.sic_specific_embed.load_sic_index",
return_value=MagicMock(),
) as mock_load_index,
patch(
"industrial_classification_utils.embed.sic_specific_embed.load_sic_structure",
return_value=MagicMock(),
) as mock_load_structure,
patch(
"industrial_classification_utils.embed.sic_specific_embed.load_hierarchy",
"industrial_classification_utils.embed.sic_specific_embed.load_sic_hierarchy",
return_value=fake_sic,
) as mock_load_hierarchy,
) as mock_load_sic_hierarchy,
patch(
"industrial_classification_utils.embed.sic_specific_embed.EmbeddingHandler",
return_value=built_handler,
Expand All @@ -555,9 +547,9 @@ def test_load_embedding_handler_from_sic_index_files_builds(tmp_path: Path):
)

assert result is built_handler
mock_load_index.assert_called_once_with("sic-index.csv")
mock_load_structure.assert_called_once_with("sic-structure.csv")
mock_load_hierarchy.assert_called_once()
mock_load_sic_hierarchy.assert_called_once_with(
"sic-index.csv", "sic-structure.csv"
)
assert mock_handler_cls.called
call_kwargs = mock_handler_cls.call_args.kwargs
assert call_kwargs["db_dir"] == str(tmp_path / "vector_store")
Expand Down Expand Up @@ -800,15 +792,7 @@ def test_load_embedding_handler_from_sic_index_files_forwards_kwargs(tmp_path: P

with (
patch(
"industrial_classification_utils.embed.sic_specific_embed.load_sic_index",
return_value=MagicMock(),
),
patch(
"industrial_classification_utils.embed.sic_specific_embed.load_sic_structure",
return_value=MagicMock(),
),
patch(
"industrial_classification_utils.embed.sic_specific_embed.load_hierarchy",
"industrial_classification_utils.embed.sic_specific_embed.load_sic_hierarchy",
return_value=fake_sic,
),
patch(
Expand Down
Loading
Loading