diff --git a/README.md b/README.md index 1c2bd01..0a86906 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ make run-docs Pytest is used for testing alongside pytest-cov for coverage testing. [/tests/conftest.py](/tests/conftest.py) defines config used by the tests. Unit testing for embedding functions is added to the [/tests/test_embedding.py](./tests/test_embedding.py) -Unit testing for utility functions is added to the [/tests/test_sic_data_access.py](./tests/test_sic_data_access.py) +SIC workbook data access lives in `sic-classification-library` (`industrial_classification.data_access.sic_data_access`); see library `tests/test_data_access.py`. ```bash make embed-tests diff --git a/docs/utils.md b/docs/utils.md index 50616cd..0c77142 100644 --- a/docs/utils.md +++ b/docs/utils.md @@ -1,3 +1,3 @@ # Embeddings Module -::: industrial_classification_utils.utils.sic_data_access +::: industrial_classification.data_access.sic_data_access diff --git a/poetry.lock b/poetry.lock index b57a0e5..75c1d0d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -6879,7 +6879,7 @@ files = [ [[package]] name = "sic-classification-library" -version = "0.1.3" +version = "0.1.5" description = "Standard Industrial Classification library" optional = false python-versions = "^3.12" @@ -6888,14 +6888,15 @@ files = [] develop = false [package.dependencies] +openpyxl = "^3.1.5" pandas = "^2.3.0" pydantic = "^2.11.7" [package.source] type = "git" url = "https://github.com/ONSdigital/sic-classification-library.git" -reference = "v0.1.4" -resolved_reference = "2f2592dfc501ea48ee1fb0a9a452bbea2e3642c1" +reference = "v0.1.5" +resolved_reference = "92fb9a7a50d6b46cea0ed06fc0fd3082c8677828" [[package]] name = "six" @@ -8604,4 +8605,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0" -content-hash = "248e0bedf9df2fef3e88728ebab42f944f15427102d4f5fd43d463aebd40db04" +content-hash = "060c92f4d31f8c00a3cec0f6621d837c49e9923090365153cbbe139fb7e8a5ee" diff --git a/pyproject.toml b/pyproject.toml index ff0bbaa..ae3be38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sic-classification-utils" -version = "0.1.13" +version = "0.1.14" description = "Utility functions used for SIC classification" readme = "README.md" requires-python = ">=3.12,<4.0" @@ -35,7 +35,7 @@ pandas = "^2.3.0" pydantic = "^2.11.1" scikit-learn = "^1.3.0" scipy = "^1.15.0" -sic-classification-library = {git = "https://github.com/ONSdigital/sic-classification-library.git", tag = "v0.1.4"} +sic-classification-library = {git = "https://github.com/ONSdigital/sic-classification-library.git", tag = "v0.1.5"} survey-assist-utils = { git = "https://github.com/ONSdigital/survey-assist-utils.git", tag = "v0.0.8" } [tool.poetry.group.dev.dependencies] diff --git a/src/industrial_classification_utils/embed/sic_specific_embed.py b/src/industrial_classification_utils/embed/sic_specific_embed.py index 026a268..ee3e738 100644 --- a/src/industrial_classification_utils/embed/sic_specific_embed.py +++ b/src/industrial_classification_utils/embed/sic_specific_embed.py @@ -9,14 +9,10 @@ import os import tempfile -from industrial_classification.hierarchy.sic_hierarchy import load_hierarchy +from industrial_classification.data_access.sic_data_access import load_sic_hierarchy from industrial_classification_utils.embed.embedding import EmbeddingHandler from industrial_classification_utils.utils.constants import get_default_config -from industrial_classification_utils.utils.sic_data_access import ( - load_sic_index, - load_sic_structure, -) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -43,13 +39,12 @@ def load_embedding_handler_from_sic_index_files( Returns: An instance of EmbeddingHandler initialized with the (published) SIC index data. """ - logger.info("Loading SIC index file: %s", sic_index_file) - sic_index_df = load_sic_index(sic_index_file) - - logger.info("Loading SIC structure file: %s", sic_structure_file) - sic_df = load_sic_structure(sic_structure_file) - - sic = load_hierarchy(sic_df, sic_index_df) + logger.info( + "Loading SIC hierarchy from index=%s structure=%s", + sic_index_file, + sic_structure_file, + ) + sic = load_sic_hierarchy(sic_index_file, sic_structure_file) df = sic.all_leaf_text() df["label"] = df["code"].apply( diff --git a/src/industrial_classification_utils/llm/llm.py b/src/industrial_classification_utils/llm/llm.py index 8f2b55e..60cd038 100644 --- a/src/industrial_classification_utils/llm/llm.py +++ b/src/industrial_classification_utils/llm/llm.py @@ -21,7 +21,10 @@ from typing import Any import numpy as np -from industrial_classification.hierarchy.sic_hierarchy import load_hierarchy +from industrial_classification.data_access.sic_data_access import ( + load_sic_hierarchy, +) +from industrial_classification.hierarchy.sic_hierarchy import SIC from industrial_classification.meta import sic_meta from langchain_core.output_parsers import PydanticOutputParser from langchain_google_vertexai import ChatVertexAI @@ -54,10 +57,6 @@ get_default_config, truncate_identifier, ) -from industrial_classification_utils.utils.sic_data_access import ( - load_sic_index, - load_sic_structure, -) logger = get_logger(__name__) config = get_default_config() @@ -130,7 +129,7 @@ def __init__( # noqa: PLR0913 self.sic_prompt_openfollowup = SIC_PROMPT_OPENFOLLOWUP self.sic_prompt_closedfollowup = SIC_PROMPT_CLOSEDFOLLOWUP self.sic_prompt_final = SIC_PROMPT_FINAL_ASSIGNMENT - self.sic = None + self.sic: SIC | None = None self.verbose = verbose @lru_cache # noqa: B019 @@ -199,11 +198,12 @@ def _prompt_candidate( str: A formatted string containing the code, title, and example activities. """ if self.sic is None: - sic_index_df = load_sic_index(config["lookups"]["sic_index"]) - sic_df = load_sic_structure(config["lookups"]["sic_structure"]) - self.sic = load_hierarchy(sic_df, sic_index_df) + self.sic = load_sic_hierarchy( + config["lookups"]["sic_index"], + config["lookups"]["sic_structure"], + ) - item = self.sic[code] # type: ignore # MyPy false positive + item = self.sic[code] txt = "{" + f"Code: {item.numeric_string_padded()}, Title: {item.description}" txt += f", Example activities: {', '.join(activities)}" if include_all: diff --git a/src/industrial_classification_utils/llm/prompt.py b/src/industrial_classification_utils/llm/prompt.py index 2a7bfbc..bf3e263 100644 --- a/src/industrial_classification_utils/llm/prompt.py +++ b/src/industrial_classification_utils/llm/prompt.py @@ -28,6 +28,7 @@ # pylint: disable=invalid-name # Need to clean up the code to remove this +from industrial_classification.data_access.sic_data_access import load_sic_index from langchain.prompts.prompt import PromptTemplate from langchain_core.output_parsers import PydanticOutputParser @@ -40,9 +41,6 @@ UnambiguousResponse, ) from industrial_classification_utils.utils.constants import get_default_config -from industrial_classification_utils.utils.sic_data_access import ( - load_sic_index, -) config = get_default_config() diff --git a/src/industrial_classification_utils/utils/sic_data_access.py b/src/industrial_classification_utils/utils/sic_data_access.py deleted file mode 100644 index 3c2c0d4..0000000 --- a/src/industrial_classification_utils/utils/sic_data_access.py +++ /dev/null @@ -1,108 +0,0 @@ -"""Provides data access for key files. - -This module contains utility functions to load and process data from -SIC-related Excel files. The filepaths for these files are defined in -the configuration function in `embedding.py`. -""" - -import logging -from importlib.resources import files - -import pandas as pd - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def load_sic_index(resource_ref: tuple[str, str]) -> pd.DataFrame: - """Loads the SIC index from an Excel file. - - The SIC index provides a list of around 15,000 activities and their - associated 5-digit SIC codes. - - Args: - resource_ref (tuple): The path to the Excel file containing the SIC index. - - Returns: - pd.DataFrame: A DataFrame containing the SIC index with columns - `uk_sic_2007` and `activity`. - """ - pkg, filename = resource_ref - file_path = files(pkg).joinpath(filename) - - logger.debug("Loading SIC index from %s", file_path) - - sic_index_df = pd.read_excel( - file_path, - sheet_name="Alphabetical Index", - skiprows=2, - usecols=["UK SIC 2007", "Activity"], - dtype=str, - ) - - sic_index_df.columns = [ - col.lower().replace(" ", "_") for col in sic_index_df.columns - ] - - return sic_index_df - - -def load_sic_structure(resource_ref: tuple[str, str]) -> pd.DataFrame: - """Loads the SIC structure from an Excel file. - - This function loads a worksheet containing all the levels and names - of the UK SIC 2007 hierarchy. - - Args: - resource_ref (tuple): The path to the Excel file containing the SIC structure. - - Returns: - pd.DataFrame: A DataFrame containing the SIC structure with columns - `description`, `section`, `most_disaggregated_level`, and `level_headings`. - """ - pkg, filename = resource_ref - file_path = files(pkg).joinpath(filename) - - logger.debug("Loading SIC structure from %s", file_path) - - sic_df = pd.read_excel( - file_path, - sheet_name="reworked structure", - usecols=[ - "Description", - "SECTION", - "Most disaggregated level", - "Level headings", - ], - dtype=str, - ) - - sic_df.columns = [col.lower().replace(" ", "_") for col in sic_df.columns] - - for col in sic_df.columns: - sic_df[col] = sic_df[col].str.strip() - - return sic_df - - -def load_text_from_config(config_section: tuple[str, str]) -> str: - """Loads text content from a configuration file. - - This function reads the content of a text file specified by the given - configuration section and returns it as a string. - - Args: - config_section (tuple[str, str]): A tuple containing the package name - and the filename of the configuration file. - - Returns: - str: The content of the configuration file as a string. - - """ - pkg, filename = config_section - file_path = files(pkg).joinpath(filename) - - logger.debug("Loading text from %s", file_path) - - with file_path.open(encoding="utf-8") as f: - return f.read() diff --git a/tests/test_embedding.py b/tests/test_embedding.py index fe816b8..9d77cb0 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -532,17 +532,9 @@ def test_load_embedding_handler_from_sic_index_files_builds(tmp_path: Path): with ( patch( - "industrial_classification_utils.embed.sic_specific_embed.load_sic_index", - return_value=MagicMock(), - ) as mock_load_index, - patch( - "industrial_classification_utils.embed.sic_specific_embed.load_sic_structure", - return_value=MagicMock(), - ) as mock_load_structure, - patch( - "industrial_classification_utils.embed.sic_specific_embed.load_hierarchy", + "industrial_classification_utils.embed.sic_specific_embed.load_sic_hierarchy", return_value=fake_sic, - ) as mock_load_hierarchy, + ) as mock_load_sic_hierarchy, patch( "industrial_classification_utils.embed.sic_specific_embed.EmbeddingHandler", return_value=built_handler, @@ -555,9 +547,9 @@ def test_load_embedding_handler_from_sic_index_files_builds(tmp_path: Path): ) assert result is built_handler - mock_load_index.assert_called_once_with("sic-index.csv") - mock_load_structure.assert_called_once_with("sic-structure.csv") - mock_load_hierarchy.assert_called_once() + mock_load_sic_hierarchy.assert_called_once_with( + "sic-index.csv", "sic-structure.csv" + ) assert mock_handler_cls.called call_kwargs = mock_handler_cls.call_args.kwargs assert call_kwargs["db_dir"] == str(tmp_path / "vector_store") @@ -800,15 +792,7 @@ def test_load_embedding_handler_from_sic_index_files_forwards_kwargs(tmp_path: P with ( patch( - "industrial_classification_utils.embed.sic_specific_embed.load_sic_index", - return_value=MagicMock(), - ), - patch( - "industrial_classification_utils.embed.sic_specific_embed.load_sic_structure", - return_value=MagicMock(), - ), - patch( - "industrial_classification_utils.embed.sic_specific_embed.load_hierarchy", + "industrial_classification_utils.embed.sic_specific_embed.load_sic_hierarchy", return_value=fake_sic, ), patch( diff --git a/tests/test_sic_data_access.py b/tests/test_sic_data_access.py deleted file mode 100644 index bcfa2ff..0000000 --- a/tests/test_sic_data_access.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Unit tests for the SIC data access utility functions. - -This module contains tests for the `load_sic_index` and `load_sic_structure` -functions from the `industrial_classification_utils.utils.sic_data_access` module. -""" - -from unittest.mock import ANY, patch - -import pandas as pd -import pytest - -from industrial_classification_utils.utils.sic_data_access import ( - load_sic_index, - load_sic_structure, -) - -# pylint: disable=redefined-outer-name -# pylint: disable=duplicate-code - - -@pytest.fixture -def mock_sic_index_data(): - """Fixture for mock SIC index data. - - Returns: - pd.DataFrame: A DataFrame containing mock SIC index data. - """ - return pd.DataFrame( - {"uk_sic_2007": ["12345", "67890"], "activity": ["Manufacturing", "Retail"]} - ) - - -@pytest.fixture -def mock_sic_structure_data(): - """Fixture for mock SIC structure data. - - Returns: - pd.DataFrame: A DataFrame containing mock SIC structure data. - """ - return pd.DataFrame( - { - "description": ["Section A", "Section B"], - "section": ["A", "B"], - "most_disaggregated_level": ["Level 1", "Level 2"], - "level_headings": ["Heading 1", "Heading 2"], - } - ) - - -@pytest.mark.utils -@patch("pandas.read_excel") -def test_load_sic_index(mock_read_excel, mock_sic_index_data): - """Test the `load_sic_index` function. - - Args: - mock_read_excel (MagicMock): Mocked `pandas.read_excel` function. - mock_sic_index_data (pd.DataFrame): Mock SIC index data. - - Asserts: - - The `pandas.read_excel` function is called with the correct arguments. - - The returned DataFrame matches the mock SIC index data. - """ - mock_read_excel.return_value = mock_sic_index_data - result = load_sic_index( - ( - "industrial_classification_utils.data.sic_index", - "uksic2007indexeswithaddendumdecember2022.xlsx", - ) - ) - - mock_read_excel.assert_called_once_with( - ANY, - sheet_name="Alphabetical Index", - skiprows=2, - usecols=["UK SIC 2007", "Activity"], - dtype=str, - ) - - # Verify the path used in the call - called_args, _ = mock_read_excel.call_args - assert str(called_args[0]).endswith("uksic2007indexeswithaddendumdecember2022.xlsx") - assert result.equals(mock_sic_index_data) - - -@pytest.mark.utils -@patch("pandas.read_excel") -def test_load_sic_structure(mock_read_excel, mock_sic_structure_data): - """Test the `load_sic_structure` function. - - Args: - mock_read_excel (MagicMock): Mocked `pandas.read_excel` function. - mock_sic_structure_data (pd.DataFrame): Mock SIC structure data. - - Asserts: - - The `pandas.read_excel` function is called with the correct arguments. - - The returned DataFrame matches the mock SIC structure data. - """ - mock_read_excel.return_value = mock_sic_structure_data - result = load_sic_structure( - ( - "industrial_classification_utils.data.sic_index", - "publisheduksicsummaryofstructureworksheet.xlsx", - ) - ) - mock_read_excel.assert_called_once_with( - ANY, - sheet_name="reworked structure", - usecols=[ - "Description", - "SECTION", - "Most disaggregated level", - "Level headings", - ], - dtype=str, - ) - - # Verify the path used in the call - called_args, _ = mock_read_excel.call_args - assert str(called_args[0]).endswith( - "publisheduksicsummaryofstructureworksheet.xlsx" - ) - - assert result.equals(mock_sic_structure_data)