diff --git a/medcat-v2/medcat/__init__.py b/medcat-v2/medcat/__init__.py index 63e20aacd..4971dade1 100644 --- a/medcat-v2/medcat/__init__.py +++ b/medcat-v2/medcat/__init__.py @@ -1,17 +1,9 @@ -from importlib.metadata import version as __version_method -from importlib.metadata import PackageNotFoundError as __PackageNotFoundError - +from medcat.version import __version__ from medcat.utils.check_for_updates import ( check_for_updates as __check_for_updates) from medcat.plugins import load_plugins as __load_plugins -try: - __version__ = __version_method("medcat") -except __PackageNotFoundError: - __version__ = "0.0.0-dev" - - # NOTE: this will not always actually do the check # it will only (by default) check once a week __check_for_updates("medcat", __version__) diff --git a/medcat-v2/medcat/cat.py b/medcat-v2/medcat/cat.py index 2072db054..1706ca87e 100644 --- a/medcat-v2/medcat/cat.py +++ b/medcat-v2/medcat/cat.py @@ -27,8 +27,10 @@ from medcat.tokenizing.tokens import MutableDocument, MutableEntity from medcat.tokenizing.tokenizers import SaveableTokenizer, TOKENIZER_PREFIX from medcat.data.entities import Entity, Entities, OnlyCUIEntities -from medcat.data.model_card import ModelCard +from medcat.data.model_card import ModelCard, PipelineDescription +from medcat.data.model_card import RequiredPluginDescription from medcat.components.types import AbstractCoreComponent, HashableComponent +from medcat.components.types import CoreComponent from medcat.components.addons.addons import AddonComponent from medcat.utils.legacy.identifier import is_legacy_model_pack from medcat.utils.defaults import avoid_legacy_conversion @@ -36,6 +38,10 @@ from medcat.utils.defaults import LegacyConversionDisabledError from medcat.utils.usage_monitoring import UsageMonitor, _NoDelUM from medcat.utils.import_utils import MissingDependenciesError +from medcat.plugins.registry import plugin_registry, find_provider +import importlib.util +from medcat.utils.exceptions import MissingPluginError, MissingPluginInfo + logger = logging.getLogger(__name__) @@ -655,6 +661,11 @@ def trainer(self): self._trainer = Trainer(self.cdb, self.__call__, self._pipeline) return self._trainer + def save_model_card(self, model_card_path: str) -> None: + model_card: str = self.get_model_card(as_dict=False) + with open(model_card_path, 'w') as f: + f.write(model_card) + def save_model_pack( self, target_folder: str, pack_name: str = DEFAULT_PACK_NAME, serialiser_type: Union[str, AvailableSerialisers] = 'dill', @@ -705,10 +716,7 @@ def save_model_pack( self.config.general.nlp.modelname = internals_path # serialise serialise(serialiser_type, self, model_pack_path) - model_card: str = self.get_model_card(as_dict=False) - model_card_path = os.path.join(model_pack_path, "model_card.json") - with open(model_card_path, 'w') as f: - f.write(model_card) + self.save_model_card(os.path.join(model_pack_path, "model_card.json")) # components components_folder = os.path.join( model_pack_path, COMPONENTS_FOLDER) @@ -777,6 +785,33 @@ def attempt_unpack(cls, zip_path: str) -> str: shutil.unpack_archive(zip_path, extract_dir=model_pack_path) return model_pack_path + @classmethod + def _get_missing_plugins(cls, model_pack_path: str) -> list[MissingPluginInfo]: + model_card = cls.load_model_card_off_disk(model_pack_path, as_dict=True) + required_plugins: list[ + RequiredPluginDescription] = model_card.get("Required Plugins", []) + missing_plugins: list[MissingPluginInfo] = [] + + for plugin_info in required_plugins: + # Check if the plugin module can be 
imported + if importlib.util.find_spec(plugin_info["name"]) is None: + # Cast to str for safety + provided = [(str(p[0]), str(p[1])) for p in plugin_info["provides"]] + missing_plugins.append(MissingPluginInfo( + name=plugin_info["name"], + provides=provided, + author=plugin_info.get("author"), + url=plugin_info.get("url"), + )) + + if missing_plugins: + logger.warning( + "Missing required plugins for this model pack. " + "Attempting to load anyway, but it may fail. " + f"Missing: {[p['name'] for p in missing_plugins]}" + ) + return missing_plugins + @classmethod def load_model_pack(cls, model_pack_path: str, config_dict: Optional[dict] = None, @@ -796,6 +831,7 @@ def load_model_pack(cls, model_pack_path: str, Raises: ValueError: If the saved data does not represent a model pack. + MissingPluginError: If required plugins are missing for this model pack. Returns: CAT: The loaded model pack. @@ -812,22 +848,32 @@ def load_model_pack(cls, model_pack_path: str, return Converter(model_pack_path, None).convert() elif is_legacy and avoid_legacy: raise LegacyConversionDisabledError("CAT") - # NOTE: ignoring addons since they will be loaded later / separately - cat = deserialise(model_pack_path, model_load_path=model_pack_path, - ignore_folders_prefix={ - AddonComponent.NAME_PREFIX, - # NOTE: will be loaded manually - AbstractCoreComponent.NAME_PREFIX, - # tokenizer stuff internals are loaded separately - # if appropraite - TOKENIZER_PREFIX, - # components will be loaded semi-manually - # within the creation of pipe - COMPONENTS_FOLDER, - # ignore hidden files/folders - '.'}, - config_dict=config_dict, - addon_config_dict=addon_config_dict) + + # Load model card to check for required plugins + missing_plugins = cls._get_missing_plugins(model_pack_path) + + try: + # NOTE: ignoring addons since they will be loaded later / separately + cat = deserialise(model_pack_path, model_load_path=model_pack_path, + ignore_folders_prefix={ + AddonComponent.NAME_PREFIX, + # NOTE: will be loaded manually + AbstractCoreComponent.NAME_PREFIX, + # tokenizer stuff internals are loaded separately + # if appropraite + TOKENIZER_PREFIX, + # components will be loaded semi-manually + # within the creation of pipe + COMPONENTS_FOLDER, + # ignore hidden files/folders + '.'}, + config_dict=config_dict, + addon_config_dict=addon_config_dict) + except ImportError as e: + if missing_plugins: + raise MissingPluginError(missing_plugins) from e + raise + # NOTE: deserialising of components that need serialised # will be dealt with upon pipeline creation automatically if not isinstance(cat, CAT): @@ -924,6 +970,13 @@ def get_model_card(self, as_dict: bool = False) -> Union[str, ModelCard]: else: met_cat_model_cards = [] cdb_info = self.cdb.get_basic_info() + + # Pipeline Description + pipeline_description = self.describe_pipeline() + + # Required Plugins + required_plugins = self.get_required_plugins() + model_card: ModelCard = { 'Model ID': self.config.meta.hash, 'Last Modified On': self.config.meta.last_saved.isoformat(), @@ -931,6 +984,8 @@ def get_model_card(self, as_dict: bool = False) -> Union[str, ModelCard]: 'Description': self.config.meta.description, 'Source Ontology': self.config.meta.ontology, 'Location': self.config.meta.location, + 'Pipeline Description': pipeline_description, + 'Required Plugins': required_plugins, 'MetaCAT models': met_cat_model_cards, 'Basic CDB Stats': cdb_info, 'Performance': {}, # TODO @@ -943,6 +998,55 @@ def get_model_card(self, as_dict: bool = False) -> Union[str, ModelCard]: return model_card 
return json.dumps(model_card, indent=2, sort_keys=False) + + def describe_pipeline(self) -> PipelineDescription: + pipeline_description: PipelineDescription = {"core": {}, "addons": []} + + for component in self._pipeline.iter_all_components(): + provider = find_provider(component) + + if component.is_core(): + core_comp = cast(CoreComponent, component) + pipeline_description["core"][core_comp.get_type().name] = { + "name": component.name, + "provider": provider, + } + else: + pipeline_description["addons"].append({ + "name": component.name, + "provider": provider, + }) + return pipeline_description + + def get_required_plugins(self) -> list[RequiredPluginDescription]: + # get plugins based on pipe + req_plugins: dict[str, list[tuple[str, str]]] = {} + pipe_descr = self.describe_pipeline() + core_comps = list(pipe_descr["core"].items()) + addons = [("addon", addon) for addon in pipe_descr["addons"]] + for comp_type, comp in core_comps + addons: + provider = comp["provider"] + if provider == "medcat": + continue + if provider not in req_plugins: + req_plugins[provider] = [] + req_plugins[provider].append((comp_type, comp["name"])) + # map to plugin info + out_plugins: list[RequiredPluginDescription] = [] + for plugin_name, comp_names in req_plugins.items(): + plugin_info = plugin_registry.get_plugin_info(plugin_name) + if plugin_info is None: + continue + out_plugins.append( + { + "name": plugin_name, + "provides": comp_names, + "author": plugin_info.author, + "url": plugin_info.url, + } + ) + return out_plugins + @overload @classmethod def load_model_card_off_disk(cls, model_pack_path: str, diff --git a/medcat-v2/medcat/components/addons/addons.py b/medcat-v2/medcat/components/addons/addons.py index d32c4ce73..c173fc19b 100644 --- a/medcat-v2/medcat/components/addons/addons.py +++ b/medcat-v2/medcat/components/addons/addons.py @@ -106,3 +106,7 @@ def create_addon( """ return get_addon_creator(addon_name)( cnf, tokenizer, cdb, vocab, model_load_path) + + +def get_registered_addons() -> list[tuple[str, str]]: + return _ADDON_REGISTRY.list_components() diff --git a/medcat-v2/medcat/config/config.py b/medcat-v2/medcat/config/config.py index 0976e788f..a8bd8aa08 100644 --- a/medcat-v2/medcat/config/config.py +++ b/medcat-v2/medcat/config/config.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, Field, ValidationError, ConfigDict -from medcat import __version__ as medcat_version +from medcat.version import __version__ as medcat_version from medcat.utils.defaults import workers from medcat.utils.envsnapshot import Environment, get_environment_info from medcat.utils.iterutils import callback_iterator diff --git a/medcat-v2/medcat/data/model_card.py b/medcat-v2/medcat/data/model_card.py index 7867be877..448270ec4 100644 --- a/medcat-v2/medcat/data/model_card.py +++ b/medcat-v2/medcat/data/model_card.py @@ -14,6 +14,21 @@ } ) +class ComponentDescription(TypedDict): + name: str + provider: str + +class PipelineDescription(TypedDict): + core: dict[str, ComponentDescription] + addons: list[ComponentDescription] + + +class RequiredPluginDescription(TypedDict): + name: str + provides: list[tuple[str, str]] + author: str | None + url: str | None + ModelCard = TypedDict( "ModelCard", { @@ -23,6 +38,8 @@ 'Description': str, 'Source Ontology': list[str], 'Location': str, + 'Pipeline Description': PipelineDescription, + 'Required Plugins': list[RequiredPluginDescription], 'MetaCAT models': list[dict], 'Basic CDB Stats': CDBInfo, 'Performance': dict[str, Any], diff --git 
a/medcat-v2/medcat/model_creation/preprocess_snomed.py b/medcat-v2/medcat/model_creation/preprocess_snomed.py index 70bffac11..f42556c4c 100644 --- a/medcat-v2/medcat/model_creation/preprocess_snomed.py +++ b/medcat-v2/medcat/model_creation/preprocess_snomed.py @@ -3,7 +3,7 @@ import re import hashlib import pandas as pd -from typing import Dict, List, Optional, Tuple +from typing import Optional from dataclasses import dataclass, field from enum import Enum, auto @@ -189,12 +189,12 @@ class SupportedExtension(Enum): @dataclass class BundleDescriptor: - extensions: List[SupportedExtension] - ignores: Dict[RefSetFileType, List[SupportedExtension]] = field( + extensions: list[SupportedExtension] + ignores: dict[RefSetFileType, list[SupportedExtension]] = field( default_factory=dict) def has_invalid(self, ext: SupportedExtension, - file_types: Tuple[RefSetFileType]) -> bool: + file_types: tuple[RefSetFileType]) -> bool: for ft in file_types: if ft not in self.ignores: continue @@ -217,8 +217,8 @@ class SupportedBundles(Enum): ) -def match_partials_with_folders(exp_names: List[Tuple[str, Optional[str]]], - folder_names: List[str], +def match_partials_with_folders(exp_names: list[tuple[str, Optional[str]]], + folder_names: list[str], _group_nr1: int = 1, _group_nr2: int = 2 ) -> bool: if len(exp_names) > len(folder_names): diff --git a/medcat-v2/medcat/plugins/loader.py b/medcat-v2/medcat/plugins/loader.py index 6d4ba81d9..0821784e8 100644 --- a/medcat-v2/medcat/plugins/loader.py +++ b/medcat-v2/medcat/plugins/loader.py @@ -1,11 +1,86 @@ -from importlib.metadata import entry_points +from importlib.metadata import EntryPoint, entry_points, metadata + +from medcat.plugins.registry import PluginInfo, plugin_registry, RegisteredComponents +from medcat.plugins.registry import create_empty_reg_comps +from medcat.components.types import get_registered_components, CoreComponentType +from medcat.components.addons.addons import get_registered_addons +from medcat.utils.import_utils import get_module_base_name ENTRY_POINT_PATH = "medcat.plugins" +def _get_registered_components() -> RegisteredComponents: + registered: RegisteredComponents = create_empty_reg_comps() + for comp_type in CoreComponentType: + registered["core"][comp_type.name] = get_registered_components( + comp_type).copy() + registered["addons"].extend(get_registered_addons().copy()) + return registered + + +def _get_changes(before_load_components: RegisteredComponents, + after_load_components: RegisteredComponents + ) -> RegisteredComponents: + newly_registered: RegisteredComponents = create_empty_reg_comps() + for comp_type, components in after_load_components["core"].items(): + diff = set(components) - set(before_load_components["core"].get(comp_type, [])) + if diff: + newly_registered["core"][comp_type] = list(diff) + + diff = set(after_load_components["addons"]) - set( + before_load_components["addons"]) + if diff: + newly_registered["addons"] = list(diff) + return newly_registered + + +def _load_plugin(ep: EntryPoint) -> None: + # Get components before plugin load + before_load_components = _get_registered_components() + + # this will init the addon + ep.load() + + # Get components after plugin load + after_load_components = _get_registered_components() + + # Identify newly registered components + newly_registered: RegisteredComponents = _get_changes( + before_load_components, after_load_components) + + # Extract package metadata + # The entry point name is not necessarily the distribution name, + # so we use ep.dist.name + # if 
available (Python 3.10+). Otherwise, we fall back to ep.name. + # See: https://docs.python.org/3/library/importlib.metadata.html#entry-points + distribution_name = ep.dist.name if hasattr(ep, 'dist') and ep.dist else ep.name + pkg_metadata = metadata(distribution_name) + # NOTE: the .get method isn't visible to mypy prior to 3.12 though it is + # available (from Message) so just ignoring the typing stuff for now + plugin_name = pkg_metadata.get("Name", distribution_name) # type: ignore + plugin_version = pkg_metadata.get("Version") # type: ignore + plugin_author = pkg_metadata.get("Author") # type: ignore + if plugin_author is None: + plugin_author = pkg_metadata.get("Author-email") # type: ignore + plugin_url = pkg_metadata.get("Home-page") # type: ignore + if plugin_url is None: + plugin_url = pkg_metadata.get("Project-URL") # type: ignore + + # Create PluginInfo and register + plugin_info = PluginInfo( + name=plugin_name, + version=plugin_version, + author=plugin_author, + url=plugin_url, + module_paths=[get_module_base_name(ep.value)], + registered_components=newly_registered, + metadata={key: pkg_metadata[key] for key in pkg_metadata}, + ) + plugin_registry.register_plugin(plugin_info) + + def load_plugins(): eps = entry_points(group=ENTRY_POINT_PATH) for ep in eps: - # this will init the addon - ep.load() + _load_plugin(ep) diff --git a/medcat-v2/medcat/plugins/registry.py b/medcat-v2/medcat/plugins/registry.py new file mode 100644 index 000000000..adc1383c3 --- /dev/null +++ b/medcat-v2/medcat/plugins/registry.py @@ -0,0 +1,127 @@ +from typing import Any, TypedDict, cast +from dataclasses import dataclass, field +import logging + +from medcat.components.types import BaseComponent, CoreComponent + + +logger = logging.getLogger(__name__) + +class RegisteredComponents(TypedDict): + core: dict[str, list[tuple[str, str]]] + addons: list[tuple[str, str]] + + +def create_empty_reg_comps() -> RegisteredComponents: + return {"core": {}, "addons": []} + + +@dataclass +class PluginInfo: + name: str + version: str | None = None + author: str | None = None + url: str | None = None + module_paths: list[str] = field(default_factory=list) + registered_components: RegisteredComponents = field( + default_factory=create_empty_reg_comps) + metadata: dict[str, Any] = field(default_factory=dict) + + +class PluginRegistry: + def __init__(self): + self._plugins: dict[str, PluginInfo] = {} + + def register_plugin(self, plugin_info: PluginInfo): + self._plugins[plugin_info.name] = plugin_info + + def get_plugin_info(self, name: str) -> PluginInfo | None: + return self._plugins.get(name) + + def get_all_plugins(self) -> dict[str, PluginInfo]: + return self._plugins.copy() + + +plugin_registry = PluginRegistry() + + +def _late_register(component: BaseComponent, plugin_info: PluginInfo): + module_name = component.__module__ + cls_name = component.__class__.__name__ + create_new_component = component.create_new_component.__name__ + comp_descr = (module_name, f"{cls_name}.{create_new_component}") + logger.warning( + "Registering %s component '%s' (%s, %s) for plugin %s " + "during a later stage when plugin registrations are expected " + "to have already been done. This is most likely because the " + "component registration was done outside the loading of the " + "plugin by the core library. 
Normally it is better to import " + "medcat before registration so that things can be kept track " + "of consisetntly.", + "core" if component.is_core() else 'addon', + component.full_name, *comp_descr, plugin_info.name, + ) + if component.is_core(): + core_comp = cast(CoreComponent, component) + component_type = core_comp.get_type().name + if component_type not in plugin_info.registered_components["core"]: + plugin_info.registered_components["core"][component_type] = [] + plugin_info.registered_components[ + "core"][component_type].append(comp_descr) + else: + plugin_info.registered_components[ + "addons"].append(comp_descr) + + +def find_provider(component: BaseComponent) -> str: + all_plugins = plugin_registry.get_all_plugins() + provider = "medcat" # Default provider + component_identifier = "" + if component.is_core(): + core_comp = cast(CoreComponent, component) + component_type = core_comp.get_type().name + component_name = component.name + component_identifier = f"core:{component_type}:{component_name}" + else: + component_name = component.name + component_identifier = f"addon:{component_name}" + + # Check if this component is provided by a plugin via direct registration + for plugin_info in all_plugins.values(): + found = False + # Check core components registered by the plugin + core_comps = plugin_info.registered_components["core"].items() + for c_type, registered_comps in core_comps: + for reg_comp_name, _ in registered_comps: + if component_identifier == f"core:{c_type}:{reg_comp_name}": + provider = plugin_info.name + found = True + break + if found: + break + if found: + break + + # If not found in core, check addon components registered by the plugin + if not found: + for reg_comp_name, _ in plugin_info.registered_components["addons"]: + if component_identifier == f"addon:{reg_comp_name}": + provider = plugin_info.name + found = True + break + if found: + break + + # Fallback: If not found by explicit registration, check module paths + if provider == "medcat": + component_module = component.__class__.__module__ + for plugin_info in all_plugins.values(): + for module_path in plugin_info.module_paths: + if component_module.startswith(module_path): + provider = plugin_info.name + # register component with plugin + _late_register(component, plugin_info) + break + if provider != "medcat": # If a provider was found, break outer loop + break + return provider diff --git a/medcat-v2/medcat/utils/exceptions.py b/medcat-v2/medcat/utils/exceptions.py new file mode 100644 index 000000000..655503ec1 --- /dev/null +++ b/medcat-v2/medcat/utils/exceptions.py @@ -0,0 +1,35 @@ +from typing import TypedDict + + +class MissingPluginInfo(TypedDict): + name: str + provides: list[tuple[str, str]] + author: str | None + url: str | None + + +class MissingPluginError(ImportError): + """Custom exception raised when required plugins are missing.""" + + def __init__(self, missing_plugins: list[MissingPluginInfo], + message: str | None = None) -> None: + self.missing_plugins = missing_plugins + if message is None: + message = self._generate_message() + super().__init__(message) + + def _generate_message(self) -> str: + msg = "The following required plugins are missing:\n" + for plugin in self.missing_plugins: + msg += f" - Plugin: {plugin['name']}\n" + provided_components = ', '.join( + [f'{c_type}:{c_name}' for c_type, c_name in plugin['provides']]) + msg += f" Provides components: {provided_components}\n" + if plugin['author']: + msg += f" Author: {plugin['author']}\n" + if plugin['url']: + msg += f" 
URL: {plugin['url']}\n" + msg += "\n" + msg += "Please install the missing plugins to load this model pack." + return msg + diff --git a/medcat-v2/medcat/utils/import_utils.py b/medcat-v2/medcat/utils/import_utils.py index 669021341..13aac37e1 100644 --- a/medcat-v2/medcat/utils/import_utils.py +++ b/medcat-v2/medcat/utils/import_utils.py @@ -14,10 +14,24 @@ def __missing__(self, key): # Map the project name to the package needed to be imported where appropraite. # Default to the package name itself. _DEP_NAME_MAPPER = KeyDefaultDict({ - "pyahocorasick": "ahocorasick" + "pyahocorasick": "ahocorasick", + "scikit-learn": "sklearn", }) +def get_module_base_name(entry_point_value: str) -> str: + """Extracts the base module name from an entry point value string. + + Args: + entry_point_value (str): The value string of an EntryPoint object, + e.g., "my_plugin.module:load_func". + + Returns: + str: The base module name, e.g., "my_plugin.module". + """ + return entry_point_value.split(':')[0] + + def get_all_extra_deps_raw(package_name: str) -> list[str]: """Get all the dependencies for a pcakge that are for an extra component. diff --git a/medcat-v2/medcat/utils/ner/data_collator.py b/medcat-v2/medcat/utils/ner/data_collator.py index bb4489d92..488b63297 100644 --- a/medcat-v2/medcat/utils/ner/data_collator.py +++ b/medcat-v2/medcat/utils/ner/data_collator.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any import torch @@ -6,7 +6,7 @@ class CollateAndPadNER(object): def __init__(self, pad_id): self.pad_id = pad_id - def __call__(self, features: List[Any]) -> Dict[str, torch.Tensor]: + def __call__(self, features: list[Any]) -> dict[str, torch.Tensor]: batch = {} max_len = max([len(f['input_ids']) for f in features]) diff --git a/medcat-v2/medcat/version.py b/medcat-v2/medcat/version.py new file mode 100644 index 000000000..aab8cc876 --- /dev/null +++ b/medcat-v2/medcat/version.py @@ -0,0 +1,8 @@ +from importlib.metadata import version as __version_method +from importlib.metadata import PackageNotFoundError as __PackageNotFoundError + +try: + __version__ = __version_method("medcat") +except __PackageNotFoundError: + __version__ = "0.0.0-dev" + diff --git a/medcat-v2/tests/components/addons/test_addons.py b/medcat-v2/tests/components/addons/test_addons.py index 30ff12cd0..f1d234b75 100644 --- a/medcat-v2/tests/components/addons/test_addons.py +++ b/medcat-v2/tests/components/addons/test_addons.py @@ -20,6 +20,9 @@ def __init__(self, cnf: ComponentConfig): assert cnf.comp_name == self.name self.config = cnf + def is_core(self) -> bool: + return False + def __call__(self, doc): return doc @@ -62,6 +65,9 @@ def __init__(self, cnf: ComponentConfig, self._cdb = cdb self.config = cnf + def is_core(self) -> bool: + return False + def __call__(self, doc): return doc diff --git a/medcat-v2/tests/plugins/__init__.py b/medcat-v2/tests/plugins/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/medcat-v2/tests/plugins/test_loader.py b/medcat-v2/tests/plugins/test_loader.py new file mode 100644 index 000000000..5be929b29 --- /dev/null +++ b/medcat-v2/tests/plugins/test_loader.py @@ -0,0 +1,117 @@ +import unittest +from unittest.mock import patch, MagicMock +from importlib.metadata import EntryPoint + +from medcat.plugins.loader import load_plugins, _load_plugin, ENTRY_POINT_PATH, _get_changes +from medcat.plugins.registry import plugin_registry, PluginInfo, RegisteredComponents, create_empty_reg_comps +from medcat.components.types import CoreComponentType + 
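+# NOTE: plugin_registry is module-level state shared across tests, so each
+# test case below clears plugin_registry._plugins in setUp to keep the
+# registrations made in one test from leaking into another.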
+class TestPluginLoader(unittest.TestCase): + + def setUp(self): + # Clear the registry before each test + plugin_registry._plugins = {} + + @patch('medcat.plugins.loader.entry_points') + @patch('medcat.plugins.loader.metadata') + @patch('medcat.components.types.get_registered_components') + @patch('medcat.components.addons.addons.get_registered_addons') + def test_load_plugins_empty(self, + mock_get_registered_addons, + mock_get_registered_components, + mock_metadata, + mock_entry_points): + mock_entry_points.return_value = [] + load_plugins() + self.assertEqual(len(plugin_registry.get_all_plugins()), 0) + + @patch('medcat.plugins.loader._load_plugin') + @patch('medcat.plugins.loader.entry_points') + def test_load_plugins_multiple(self, mock_entry_points, mock_load_plugin): + mock_ep1 = MagicMock(spec=EntryPoint) + mock_ep1.name = "mock-plugin-1" + mock_ep1.value = "mock_plugin_module1:load" + mock_ep1.group = ENTRY_POINT_PATH + + mock_ep2 = MagicMock(spec=EntryPoint) + mock_ep2.name = "mock-plugin-2" + mock_ep2.value = "mock_plugin_module2:load" + mock_ep2.group = ENTRY_POINT_PATH + + mock_entry_points.return_value = [mock_ep1, mock_ep2] + load_plugins() + + self.assertEqual(mock_load_plugin.call_count, 2) + mock_load_plugin.assert_any_call(mock_ep1) + mock_load_plugin.assert_any_call(mock_ep2) + + def test_get_changes_identifies_new_components(self): + before_comps: RegisteredComponents = { + "core": { + CoreComponentType.ner.name: [] + }, + "addons": [] + } + + after_comps: RegisteredComponents = { + "core": { + CoreComponentType.ner.name: [("mock_ner", "mock_module.MockNER.create")] + }, + "addons": [("mock_addon", "mock_addon_module.MockAddon.create")] + } + + newly_registered = _get_changes(before_comps, after_comps) + + self.assertIn(CoreComponentType.ner.name, newly_registered["core"]) + self.assertEqual(newly_registered["core"][CoreComponentType.ner.name], + [("mock_ner", "mock_module.MockNER.create")]) + self.assertEqual(newly_registered["addons"], + [("mock_addon", "mock_addon_module.MockAddon.create")]) + + @patch('medcat.plugins.loader.metadata') + @patch('medcat.plugins.loader.EntryPoint.load') + @patch('medcat.plugins.loader._get_changes') + def test_load_plugin_with_different_entrypoint_and_distribution_name( + self, + mock_get_changes, + mock_ep_load, + mock_metadata): + mock_get_changes.return_value = { + "core": { + "ner": [("test_ner_comp", "test_module.TestNER.create")] + }, + "addons": [("test_addon_comp", "test_module.TestAddon.create")]} + + # Mock EntryPoint with different name and dist.name + mock_ep = MagicMock(spec=EntryPoint) + mock_ep.name = "my-plugin-entrypoint" + mock_ep.value = "my_plugin.module:load_func" + mock_ep.group = ENTRY_POINT_PATH + mock_ep.dist.name = "my-plugin-package" # Actual distribution name + + # Mock metadata to return info for the distribution name + mock_metadata.return_value = { + "Name": "My Awesome Plugin", + "Version": "0.0.1", + "Author": "Plugin Author", + "Home-page": "http://plugin.com" + } + + _load_plugin(mock_ep) + + # Assert metadata was called with the distribution name + mock_metadata.assert_called_once_with("my-plugin-package") + + # Assert the plugin was registered correctly + all_plugins = plugin_registry.get_all_plugins() + self.assertEqual(len(all_plugins), 1) + registered_plugin = all_plugins["My Awesome Plugin"] + + self.assertEqual(registered_plugin.name, "My Awesome Plugin") + self.assertEqual(registered_plugin.version, "0.0.1") + self.assertEqual(registered_plugin.author, "Plugin Author") + 
self.assertEqual(registered_plugin.url, "http://plugin.com") + self.assertIn(("test_ner_comp", "test_module.TestNER.create"), + registered_plugin.registered_components["core"][CoreComponentType.ner.name]) + self.assertIn(("test_addon_comp", "test_module.TestAddon.create"), + registered_plugin.registered_components["addons"]) diff --git a/medcat-v2/tests/test_cat.py b/medcat-v2/tests/test_cat.py index b3a081312..26733d2ea 100644 --- a/medcat-v2/tests/test_cat.py +++ b/medcat-v2/tests/test_cat.py @@ -20,8 +20,12 @@ from medcat.utils.defaults import AVOID_LEGACY_CONVERSION_ENVIRON from medcat.utils.defaults import LegacyConversionDisabledError from medcat.utils.config_utils import temp_changed_config +from medcat.plugins.registry import create_empty_reg_comps, plugin_registry, PluginInfo +from medcat.components.types import CoreComponentType, AbstractCoreComponent +from medcat.components.addons.addons import AddonComponent import unittest +from unittest.mock import MagicMock import tempfile import pickle import shutil @@ -1083,6 +1087,7 @@ def setUpClass(cls): cls.temp_folder = tempfile.TemporaryDirectory() cls.saved_path = cls.cat.save_model_pack( cls.temp_folder.name, change_description=cls.DESCRIPTION) + cls.model_card_path = os.path.join(cls.saved_path, "model_card.json") @classmethod def tearDownClass(cls): @@ -1095,6 +1100,29 @@ def test_can_save_model_pack(self): def test_model_adds_description(self): self.assertIn(self.DESCRIPTION, self.cat.config.meta.description) + def test_saved_has_model_card(self): + self.assertTrue(os.path.exists(self.model_card_path)) + + def test_model_card_is_json(self): + with open(self.model_card_path) as f: + mc = json.load(f) + self.assertIsInstance(mc, dict) + + def test_model_card_has_pipe_description(self): + with open(self.model_card_path) as f: + mc = json.load(f) + self.assertIn('Pipeline Description', mc) + core_descr = mc["Pipeline Description"]["core"] + for cct in CoreComponentType: + with self.subTest(f"Core component {cct.name}"): + self.assertIn(cct.name, core_descr) + + def test_model_card_has_empty_required_plugins_setion(self): + with open(self.model_card_path) as f: + mc = json.load(f) + self.assertIn('Required Plugins', mc) + self.assertFalse(mc['Required Plugins']) + class BatchingTests(unittest.TestCase): NUM_TEXTS = 100 @@ -1205,3 +1233,230 @@ def test_can_set_batch_size_per_doc(self): # has same number of texts in each batch -> doc based self.assertEqual(max(batch_lens), min(batch_lens)) self.assertEqual(max(batch_lens), exp_batches) + +class TestModelCardEnhancements(unittest.TestCase): + + def setUp(self): + # Clear the plugin registry before each test + plugin_registry._plugins = {} + + self.mock_config = MagicMock(spec=Config()) + self.mock_config.general.nlp.provider = 'regex' + self.mock_config.meta.hash = "testhash123" + self.mock_config.meta.last_saved.isoformat.return_value = "2025-12-19T12:00:00" + self.mock_config.meta.history = ["testhash123"] + self.mock_config.meta.description = "Test description" + self.mock_config.meta.ontology = ["SNOMEDCT"] + self.mock_config.meta.location = "/path/to/model" + self.mock_config.meta.medcat_version = "1.0.0" + # these will be used in model card so need values + self.mock_config.components.ner.min_name_len = 3 + self.mock_config.components.ner.upper_case_limit_len = 3 + self.mock_config.components.linking.similarity_threshold = 0.3 + self.mock_config.components.linking.filters.cuis = {} + self.mock_config.general.spell_check = True + self.mock_config.general.spell_check_len_limit = 3 + 
+ self.mock_cdb = MagicMock(spec=CDB) + self.mock_cdb.get_basic_info.return_value = {"Number of concepts": 100} + + self.mock_pipeline = MagicMock(spec=cat.Pipeline) + self.mock_pipeline.iter_all_components.return_value = [] # Default empty + + self.mock_cat = cat.CAT(self.mock_cdb, config=self.mock_config) + self.mock_cat._pipeline = self.mock_pipeline # Override with our mock pipeline + + def test_describe_pipeline_core_components(self): + mock_core_comp = MagicMock(spec=AbstractCoreComponent) + mock_core_comp.is_core.return_value = True + mock_core_comp.get_type.return_value = CoreComponentType.ner + mock_core_comp.name = "my_ner_component" + mock_core_comp.full_name = "core:ner:my_ner_component" + + self.mock_pipeline.iter_all_components.return_value = [mock_core_comp] + + pipeline_desc = self.mock_cat.describe_pipeline() + + self.assertIn(CoreComponentType.ner.name, pipeline_desc["core"]) + self.assertEqual(pipeline_desc["core"][CoreComponentType.ner.name]["name"], "my_ner_component") + self.assertEqual(pipeline_desc["core"][CoreComponentType.ner.name]["provider"], "medcat") + + def test_describe_pipeline_addons(self): + mock_addon_comp = MagicMock(spec=AddonComponent) + mock_addon_comp.is_core.return_value = False + mock_addon_comp.name = "my_addon_component" + mock_addon_comp.full_name = "addon:my_addon_component" + + self.mock_pipeline.iter_all_components.return_value = [mock_addon_comp] + + pipeline_desc = self.mock_cat.describe_pipeline() + + self.assertEqual(len(pipeline_desc["addons"]), 1) + self.assertEqual(pipeline_desc["addons"][0]["name"], "my_addon_component") + self.assertEqual(pipeline_desc["addons"][0]["provider"], "medcat") + + def test_get_required_plugins(self): + # Mock a plugin in the registry + mock_plugin_info = PluginInfo( + name="MockPlugin", + version="1.0", + author="Mock Author", + url="http://mock.com", + registered_components={ + "core": {CoreComponentType.ner.name: [("my_ner_component", "MockNER.create")]}, + "addons": [("my_addon_component", "MockAddon.create")] + } + ) + plugin_registry.register_plugin(mock_plugin_info) + + # Mock pipeline components that this plugin provides + mock_core_comp = MagicMock(spec=AbstractCoreComponent) + mock_core_comp.is_core.return_value = True + mock_core_comp.get_type.return_value = CoreComponentType.ner + mock_core_comp.name = "my_ner_component" + mock_core_comp.full_name = "core:ner:my_ner_component" + + mock_addon_comp = MagicMock(spec=AddonComponent) + mock_addon_comp.is_core.return_value = False + mock_addon_comp.name = "my_addon_component" + mock_addon_comp.full_name = "addon:my_addon_component" + + self.mock_pipeline.iter_all_components.return_value = [mock_core_comp, mock_addon_comp] + + required_plugins = self.mock_cat.get_required_plugins() + + self.assertEqual(len(required_plugins), 1) + self.assertEqual(required_plugins[0]["name"], "MockPlugin") + self.assertIn(("ner", "my_ner_component"), required_plugins[0]["provides"]) + self.assertIn(("addon", "my_addon_component"), required_plugins[0]["provides"]) + self.assertEqual(required_plugins[0]["author"], "Mock Author") + self.assertEqual(required_plugins[0]["url"], "http://mock.com") + + @unittest.mock.patch('medcat.cat.CAT.describe_pipeline') + @unittest.mock.patch('medcat.cat.CAT.get_required_plugins') + def test_get_model_card_with_pipeline_and_plugins(self, mock_get_required_plugins, mock_describe_pipeline): + mock_describe_pipeline.return_value = {"core": {CoreComponentType.ner.name: {"name": "test_ner", "provider": "medcat"}}, "addons": []} + 
mock_get_required_plugins.return_value = [{"name": "TestPlugin", "provides": [("ner", "test_ner")], "author": "Test Author", "url": "http://test.com"}] + + model_card = self.mock_cat.get_model_card(as_dict=True) + + self.assertIn("Pipeline Description", model_card) + self.assertEqual(model_card["Pipeline Description"], {"core": {CoreComponentType.ner.name: {"name": "test_ner", "provider": "medcat"}}, "addons": []}) + self.assertIn("Required Plugins", model_card) + self.assertEqual(model_card["Required Plugins"], [{"name": "TestPlugin", "provides": [("ner", "test_ner")], "author": "Test Author", "url": "http://test.com"}]) + + @unittest.mock.patch('medcat.cat.CAT.describe_pipeline') + @unittest.mock.patch('medcat.cat.CAT.get_required_plugins') + def test_model_card_saved_and_loaded_from_disk(self, mock_get_required_plugins, mock_describe_pipeline): + # Setup mocks for content to be in the model card + mock_describe_pipeline.return_value = {"core": {CoreComponentType.ner.name: {"name": "test_ner_disk", "provider": "medcat"}}, "addons": []} + mock_get_required_plugins.return_value = [{"name": "TestPluginDisk", "provides": [("ner", "test_ner_disk")], "author": "Test Author Disk", "url": "http://test-disk.com"}] + + with tempfile.TemporaryDirectory() as temp_dir: + model_card_path = os.path.join(temp_dir, "model_card.json") + # Save the model pack + self.mock_cat.save_model_card(model_card_path) + + # Load the model card from disk + loaded_model_card = cat.CAT.load_model_card_off_disk(temp_dir, as_dict=True) + + self.assertIn("Pipeline Description", loaded_model_card) + self.assertEqual(loaded_model_card["Pipeline Description"], {"core": {CoreComponentType.ner.name: {"name": "test_ner_disk", "provider": "medcat"}}, "addons": []}) + self.assertIn("Required Plugins", loaded_model_card) + # NOTE: tuples get loaded as lists + self.assertEqual(loaded_model_card["Required Plugins"], [{"name": "TestPluginDisk", "provides": [["ner", "test_ner_disk"]], "author": "Test Author Disk", "url": "http://test-disk.com"}]) + + def test_describe_pipeline_with_module_path_fallback(self): + # Define a mock component class with a specific module path + class MockComponentWithModule(AbstractCoreComponent): + def is_core(self): return True + def get_type(self): return CoreComponentType.ner + name = "fallback_ner_component" + full_name = "core:ner:fallback_ner_component" + __module__ = "my_plugin_package.some_module" + + # Register a plugin with a matching module path, but no explicit registration + mock_plugin_info = PluginInfo( + name="MyPluginPackage", + version="1.0", + author="Module Author", + url="http://module-plugin.com", + module_paths=["my_plugin_package"], + registered_components=create_empty_reg_comps(), + ) + plugin_registry.register_plugin(mock_plugin_info) + + # Mock the pipeline to return an instance of our component + mock_comp_instance = MockComponentWithModule() + self.mock_pipeline.iter_all_components.return_value = [mock_comp_instance] + + pipeline_desc = self.mock_cat.describe_pipeline() + + self.assertIn(CoreComponentType.ner.name, pipeline_desc["core"]) + self.assertEqual(pipeline_desc["core"][CoreComponentType.ner.name]["name"], "fallback_ner_component") + self.assertEqual(pipeline_desc["core"][CoreComponentType.ner.name]["provider"], "MyPluginPackage") + + @unittest.mock.patch("medcat.cat.deserialise") + @unittest.mock.patch('importlib.util.find_spec') + def test_load_model_pack_with_missing_plugin_raises_error(self, mock_find_spec, mock_deserialise): + mock_find_spec.return_value = None # 
Simulate plugin not found + mock_deserialise.side_effect = ImportError + with tempfile.TemporaryDirectory() as temp_dir: + model_card_path = os.path.join(temp_dir, "model_card.json") + model_card_content = { + "Required Plugins": [{ + "name": "MissingPlugin", + "provides": [["core", "test_core_comp"]], + "author": "Missing Author", + "url": "http://missing.com" + }] + } + # overwrite model card + with open(model_card_path, "w") as f: + json.dump(model_card_content, f) + + with self.assertRaises(cat.MissingPluginError) as cm: + cat.CAT.load_model_pack(temp_dir) + + self.assertEqual(len(cm.exception.missing_plugins), 1) + self.assertEqual(cm.exception.missing_plugins[0]["name"], "MissingPlugin") + self.assertIn("MissingPlugin", str(cm.exception)) + + @unittest.mock.patch('importlib.util.find_spec') + def test_load_model_pack_with_available_plugin_succeeds(self, mock_find_spec): + mock_find_spec.return_value = MagicMock() # Simulate plugin found + with tempfile.TemporaryDirectory() as temp_dir: + model_card_path = os.path.join(temp_dir, "model_card.json") + model_card_content = { + "Required Plugins": [{ + "name": "AvailablePlugin", + "provides": [["core", "test_core_comp"]], + "author": "Available Author", + "url": "http://available.com" + }] + } + with open(model_card_path, "w") as f: + json.dump(model_card_content, f) + + # Mock deserialise to return a valid CAT object to avoid deeper loading issues + with unittest.mock.patch('medcat.cat.deserialise') as mock_deserialise: + mock_deserialise.return_value = self.mock_cat + loaded_cat = cat.CAT.load_model_pack(temp_dir) + self.assertIs(loaded_cat, self.mock_cat) + + @unittest.mock.patch('importlib.util.find_spec') + def test_load_model_pack_with_no_required_plugins_succeeds(self, mock_find_spec): + mock_find_spec.return_value = None # Should not be called if no required plugins + with tempfile.TemporaryDirectory() as temp_dir: + model_card_path = os.path.join(temp_dir, "model_card.json") + model_card_content = { + "Required Plugins": [] + } + with open(model_card_path, "w") as f: + json.dump(model_card_content, f) + + with unittest.mock.patch('medcat.cat.deserialise') as mock_deserialise: + mock_deserialise.return_value = self.mock_cat + loaded_cat = cat.CAT.load_model_pack(temp_dir) + self.assertIs(loaded_cat, self.mock_cat) +
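
Plugin-side sketch (illustrative addition, not part of the diff above): load_plugins() scans the "medcat.plugins" entry-point group, imports whatever each entry point refers to via ep.load(), and attributes any components registered during that import to a PluginInfo stored in plugin_registry. Assuming a hypothetical distribution medcat-example-plugin exposing a package medcat_example_plugin (both names are invented here purely for illustration), the plugin side of the contract would look roughly like this:

    # pyproject.toml of the hypothetical plugin distribution:
    #
    #   [project.entry-points."medcat.plugins"]
    #   example-plugin = "medcat_example_plugin"

    # medcat_example_plugin/__init__.py
    #
    # ep.load() imports this package. _load_plugin() snapshots the registered
    # core components and addons before and after that import, records the
    # difference (together with the distribution's Name/Version/Author/
    # Home-page metadata) in a PluginInfo and registers it with
    # plugin_registry; components registered as an import side effect here
    # are what later appear under "Required Plugins" in the model card.
    from . import components  # noqa: F401  (imported for its registration side effects)

On the consuming side, CAT.load_model_pack() reads "Required Plugins" from model_card.json, checks each listed plugin with importlib.util.find_spec(), logs a warning for any that are absent, and, if deserialisation then fails with an ImportError, re-raises it as a MissingPluginError listing the missing plugins together with the components they provide, their author and URL.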