diff --git a/src/easyscience/fitting/calculators/__init__.py b/src/easyscience/fitting/calculators/__init__.py index a3ca5d43..e2c1ea89 100644 --- a/src/easyscience/fitting/calculators/__init__.py +++ b/src/easyscience/fitting/calculators/__init__.py @@ -2,6 +2,14 @@ # SPDX-License-Identifier: BSD-3-Clause # © 2021-2025 Contributors to the EasyScience project +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project None: + """ + Initialize the calculator with a model and instrumental parameters. + + Parameters + ---------- + model : ModelBase + The physical model to calculate from. This is typically a sample + or structure definition containing fittable parameters. + instrumental_parameters : ModelBase, optional + Instrumental parameters that affect the calculation, such as + resolution, wavelength, or detector settings. + unique_name : str, optional + Unique identifier for this calculator instance. + display_name : str, optional + Human-readable name for display purposes. + **kwargs : Any + Additional calculator-specific options. + """ + if not isinstance(model, ModelBase): + raise ValueError('Model must be an instance of ModelBase') + + # Initialize ModelBase with naming + super().__init__(unique_name=unique_name, display_name=display_name) + + self._model = model + self._instrumental_parameters = instrumental_parameters + self._additional_kwargs = kwargs + + @property + def model(self) -> ModelBase: + """ + Get the current physical model. + + Returns + ------- + ModelBase + The physical model used for calculations. + """ + return self._model + + @model.setter + def model(self, new_model: ModelBase) -> None: + """ + Set a new physical model. + + Parameters + ---------- + new_model : ModelBase + The new physical model to use for calculations. + + Raises + ------ + ValueError + If the new model is None. + """ + if new_model is None: + raise ValueError('Model cannot be None') + self._model = new_model + + @property + def instrumental_parameters(self) -> Optional[ModelBase]: + """ + Get the current instrumental parameters. + + Returns + ------- + ModelBase or None + The instrumental parameters, or None if not set. + """ + return self._instrumental_parameters + + @instrumental_parameters.setter + def instrumental_parameters(self, new_parameters: Optional[ModelBase]) -> None: + """ + Set new instrumental parameters. + + Parameters + ---------- + new_parameters : ModelBase or None + The new instrumental parameters to use for calculations. + Truly optional, since instrumental parameters may not always be needed. + """ + self._instrumental_parameters = new_parameters + + def update_model(self, new_model: ModelBase) -> None: + """ + Update the physical model used for calculations. + + This is an alternative to the `model` property setter that can be + overridden by subclasses to perform additional setup when the model changes. + + Parameters + ---------- + new_model : ModelBase + The new physical model to use. + + Raises + ------ + ValueError + If the new model is None. + """ + self.model = new_model + + def update_instrumental_parameters(self, new_parameters: Optional[ModelBase]) -> None: + """ + Update the instrumental parameters used for calculations. + + This is an alternative to the `instrumental_parameters` property setter + that can be overridden by subclasses to perform additional setup when + instrumental parameters change. + + Parameters + ---------- + new_parameters : ModelBase or None + The new instrumental parameters to use. + """ + self.instrumental_parameters = new_parameters + + @property + def additional_kwargs(self) -> dict: + """ + Get additional keyword arguments passed during initialization. + + Returns a copy to prevent external modification of internal state. + + Returns + ------- + dict + Copy of the dictionary of additional kwargs passed to __init__. + """ + return dict(self._additional_kwargs) + + @abstractmethod + def calculate(self, x: np.ndarray) -> np.ndarray: + """ + Calculate theoretical values at the given points. + + This is the main calculation method that must be implemented by all + concrete calculator classes. It uses the current model and instrumental + parameters to compute theoretical predictions. + + Parameters + ---------- + x : np.ndarray + The independent variable values (e.g., Q values, angles, energies) + at which to calculate the theoretical response. + + Returns + ------- + np.ndarray + The calculated theoretical values corresponding to the input x values. + + Notes + ----- + This method is called during fitting and should be thread-safe if + parallel fitting is to be supported. + """ + ... + + def __repr__(self) -> str: + """Return a string representation of the calculator.""" + model_name = getattr(self._model, 'name', type(self._model).__name__) + instr_info = '' + if self._instrumental_parameters is not None: + instr_name = getattr( + self._instrumental_parameters, + 'name', + type(self._instrumental_parameters).__name__, # default to class name if no 'name' attribute + ) + instr_info = f', instrumental_parameters={instr_name}' + return f'{self.__class__.__name__}(model={model_name}{instr_info})' diff --git a/src/easyscience/fitting/calculators/calculator_factory.py b/src/easyscience/fitting/calculators/calculator_factory.py new file mode 100644 index 00000000..accc05a7 --- /dev/null +++ b/src/easyscience/fitting/calculators/calculator_factory.py @@ -0,0 +1,353 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project None: + """Initialize the factory with an empty calculator registry.""" + self._available_calculators: Dict[str, Type[CalculatorBase]] = {} + + def _try_register_calculator( + self, + name: str, + module_path: str, + class_name: str, + ) -> bool: + """ + Attempt to import and register a calculator class. + + This method tries to import a calculator class from the given module path. + If the import succeeds, the calculator is added to the available calculators. + If the import fails (e.g., because a dependency is not installed), the + calculator is silently skipped. + + Parameters + ---------- + name : str + The name to register the calculator under. + module_path : str + The full module path to import from (e.g., 'easyreflectometry.calculators.refl1d'). + class_name : str + The name of the calculator class within the module. + + Returns + ------- + bool + True if the calculator was successfully registered, False otherwise. + + Examples + -------- + :: + + # In a subclass __init__: + self._try_register_calculator( + 'backend_a', + 'mypackage.calculators.backend_a', + 'BackendACalculator' + ) + """ + try: + import importlib + + module = importlib.import_module(module_path) + calculator_class = getattr(module, class_name) + self._available_calculators[name] = calculator_class + return True + except (ImportError, AttributeError): + # Package not installed or class not found - skip silently + return False + except Exception: + # Any other error during import - skip silently + return False + + @property + def available_calculators(self) -> List[str]: + """ + Return a list of available calculator names. + + Returns + ------- + List[str] + Names of all calculators that can be created by this factory. + Only includes calculators whose dependencies are installed. + """ + return list(self._available_calculators.keys()) + + @abstractmethod + def create( + self, + calculator_name: str, + model: NewBase, + instrumental_parameters: Optional[NewBase] = None, + **kwargs: Any, + ) -> CalculatorBase: + """ + Create a calculator instance. + + Parameters + ---------- + calculator_name : str + The name of the calculator to create. Must be one of the names + returned by `available_calculators`. + model : NewBase + The physical model (e.g., sample) to pass to the calculator. + instrumental_parameters : NewBase, optional + Instrumental parameters to pass to the calculator. + **kwargs : Any + Additional arguments to pass to the calculator constructor. + + Returns + ------- + CalculatorBase + A new calculator instance configured with the given model and + instrumental parameters. + + Raises + ------ + ValueError + If the requested calculator_name is not available. + """ + ... + + def __repr__(self) -> str: + """Return a string representation of the factory.""" + return f'{self.__class__.__name__}(available={self.available_calculators})' + + +class SimpleCalculatorFactory(CalculatorFactoryBase): + """ + A simple implementation of a calculator factory using a dictionary registry. + + This class provides a convenient base for creating calculator factories + where calculators are registered either via `_try_register_calculator` for + dynamic discovery or directly via the `register` method. + + Parameters + ---------- + calculators : Dict[str, Type[CalculatorBase]], optional + A dictionary mapping calculator names to calculator classes. + If provided, these are added to the registry immediately. + + Examples + -------- + Using dynamic registration in a subclass:: + + class MyFactory(SimpleCalculatorFactory): + def __init__(self): + super().__init__() + self._try_register_calculator('fast', 'mypackage.fast', 'FastCalculator') + self._try_register_calculator('accurate', 'mypackage.accurate', 'AccurateCalculator') + + factory = MyFactory() + calc = factory.create('fast', model, instrument) # Only if 'fast' is installed + + Using instance-level registration:: + + factory = SimpleCalculatorFactory({ + 'custom': CustomCalculator, + }) + calc = factory.create('custom', model, instrument) + """ + + def __init__( + self, + calculators: Optional[Dict[str, Type[CalculatorBase]]] = None, + ) -> None: + """ + Initialize the factory with optional calculator registry. + + Parameters + ---------- + calculators : Dict[str, Type[CalculatorBase]], optional + A dictionary mapping calculator names to calculator classes. + If provided, these calculators are added to the registry. + """ + super().__init__() + if calculators is not None: + self._available_calculators.update(calculators) + + def create( + self, + calculator_name: str, + model: NewBase, + instrumental_parameters: Optional[NewBase] = None, + **kwargs: Any, + ) -> CalculatorBase: + """ + Create a calculator instance from the registered calculators. + + Parameters + ---------- + calculator_name : str + The name of the calculator to create. + model : NewBase + The physical model to pass to the calculator. + instrumental_parameters : NewBase, optional + Instrumental parameters to pass to the calculator. + **kwargs : Any + Additional arguments to pass to the calculator constructor. + + Returns + ------- + CalculatorBase + A new calculator instance. + + Raises + ------ + ValueError + If the calculator_name is not in the registry or is not a string. + TypeError + If model is None or instrumental_parameters has wrong type. + """ + if not isinstance(calculator_name, str): + raise ValueError(f'calculator_name must be a string, got {type(calculator_name).__name__}') + + if calculator_name not in self._available_calculators: + available = ', '.join(self.available_calculators) if self.available_calculators else 'none' + raise ValueError(f"Unknown calculator '{calculator_name}'. Available calculators: {available}") + + if model is None: + raise TypeError('Model cannot be None') + + calculator_class = self._available_calculators[calculator_name] + try: + return calculator_class(model, instrumental_parameters, **kwargs) + except Exception as e: + raise type(e)(f"Failed to create calculator '{calculator_name}': {e}") from e + + def register(self, name: str, calculator_class: Type[CalculatorBase]) -> None: + """ + Register a new calculator class with the factory. + + Parameters + ---------- + name : str + The name to register the calculator under. + calculator_class : Type[CalculatorBase] + The calculator class to register. + + Raises + ------ + TypeError + If calculator_class is not a subclass of CalculatorBase. + ValueError + If name is empty or not a string. + + Warnings + -------- + If overwriting an existing calculator, a warning is issued. + """ + # Import here to avoid circular imports at module level + import warnings + + from .calculator_base import CalculatorBase + + if not isinstance(name, str) or not name: + raise ValueError('Calculator name must be a non-empty string') + + if not (isinstance(calculator_class, type) and issubclass(calculator_class, CalculatorBase)): + raise TypeError(f'calculator_class must be a subclass of CalculatorBase, got {type(calculator_class).__name__}') + + if name in self._available_calculators: + warnings.warn(f"Overwriting existing calculator '{name}' in {self.__class__.__name__}", UserWarning, stacklevel=2) + + self._available_calculators[name] = calculator_class + + def unregister(self, name: str) -> None: + """ + Remove a calculator from the registry. + + Parameters + ---------- + name : str + The name of the calculator to remove. + + Raises + ------ + KeyError + If the calculator name is not in the registry. + """ + if name not in self._available_calculators: + raise KeyError(f"Calculator '{name}' is not registered") + del self._available_calculators[name] diff --git a/src/easyscience/fitting/calculators/interface_factory.py b/src/easyscience/fitting/calculators/interface_factory.py index ca4713fd..35956ad7 100644 --- a/src/easyscience/fitting/calculators/interface_factory.py +++ b/src/easyscience/fitting/calculators/interface_factory.py @@ -3,6 +3,7 @@ # SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause # © 2021-2025 Contributors to the EasyScience project bool: class Map: def __init__(self): # A dictionary of object names and their corresponding objects - self._store = weakref.WeakValueDictionary() + self.__store = weakref.WeakValueDictionary() # A dict with object names as keys and a list of their object types as values, with weak references self.__type_dict = {} @@ -82,7 +82,7 @@ def vertices(self) -> List[str]: """ while True: try: - return list(self._store) + return list(self.__store) except RuntimeError: # Dictionary changed size during iteration, retry continue @@ -112,8 +112,8 @@ def _nested_get(self, obj_type: str) -> List[str]: return [key for key, item in self.__type_dict.items() if obj_type in item.type] def get_item_by_key(self, item_id: str) -> object: - if item_id in self._store: - return self._store[item_id] + if item_id in self.__store: + return self.__store[item_id] raise ValueError('Item not in map.') def is_known(self, vertex: object) -> bool: @@ -121,7 +121,7 @@ def is_known(self, vertex: object) -> bool: All objects should have a 'unique_name' attribute. """ - return vertex.unique_name in self._store + return vertex.unique_name in self.__store def find_type(self, vertex: object) -> List[str]: if self.is_known(vertex): @@ -137,15 +137,15 @@ def change_type(self, obj, new_type: str): def add_vertex(self, obj: object, obj_type: str = None): name = obj.unique_name - if name in self._store: + if name in self.__store: raise ValueError(f'Object name {name} already exists in the graph.') # Clean up stale entry in __type_dict if the weak reference was collected # but the finalizer hasn't run yet if name in self.__type_dict: del self.__type_dict[name] - self._store[name] = obj + self.__store[name] = obj self.__type_dict[name] = _EntryList() # Add objects type to the list of types - self.__type_dict[name].finalizer = weakref.finalize(self._store[name], self.prune, name) + self.__type_dict[name].finalizer = weakref.finalize(self.__store[name], self.prune, name) self.__type_dict[name].type = obj_type def add_edge(self, start_obj: object, end_obj: object): @@ -185,8 +185,8 @@ def prune_vertex_from_edge(self, parent_obj, child_obj): def prune(self, key: str): if key in self.__type_dict: del self.__type_dict[key] - if key in self._store: - del self._store[key] + if key in self.__store: + del self.__store[key] def find_isolated_vertices(self) -> list: """returns a list of isolated vertices.""" @@ -279,9 +279,9 @@ def is_connected(self, vertices_encountered=None, start_vertex=None) -> bool: def _clear(self): """Reset the map to an empty state. Only to be used for testing""" - self._store.clear() + self.__store.clear() self.__type_dict.clear() gc.collect() def __repr__(self) -> str: - return f'Map object of {len(self._store)} vertices.' + return f'Map object of {len(self.__store)} vertices.' diff --git a/src/easyscience/variable/parameter.py b/src/easyscience/variable/parameter.py index 55787ad4..0df57d1f 100644 --- a/src/easyscience/variable/parameter.py +++ b/src/easyscience/variable/parameter.py @@ -1029,7 +1029,8 @@ def resolve_pending_dependencies(self) -> None: def _find_parameter_by_serializer_id(self, serializer_id: str) -> Optional['DescriptorNumber']: """Find a parameter by its serializer_id from all parameters in the global map.""" - for obj in self._global_object.map._store.values(): + for key in self._global_object.map.vertices(): + obj = self._global_object.map.get_item_by_key(key) if isinstance(obj, DescriptorNumber) and hasattr(obj, '_DescriptorNumber__serializer_id'): if obj._DescriptorNumber__serializer_id == serializer_id: return obj diff --git a/tests/unit_tests/fitting/calculators/test_calculator_base.py b/tests/unit_tests/fitting/calculators/test_calculator_base.py new file mode 100644 index 00000000..c98f0b69 --- /dev/null +++ b/tests/unit_tests/fitting/calculators/test_calculator_base.py @@ -0,0 +1,291 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project np.ndarray: + # Simple identity function for testing + return x * 2.0 + + return ConcreteCalculator + + @pytest.fixture + def calculator(self, clear, concrete_calculator_class, mock_model, mock_instrumental_parameters): + """Create a calculator instance for testing.""" + return concrete_calculator_class( + mock_model, mock_instrumental_parameters, unique_name="test_calc", display_name="TestCalc" + ) + + # Initialization tests + def test_init_with_model_only(self, clear, concrete_calculator_class, mock_model): + """Test initialization with only a model.""" + calc = concrete_calculator_class(mock_model, unique_name="test_1", display_name="Test1") + assert calc.model is mock_model + assert calc.instrumental_parameters is None + + def test_init_with_model_and_instrumental_parameters( + self, clear, concrete_calculator_class, mock_model, mock_instrumental_parameters + ): + """Test initialization with model and instrumental parameters.""" + calc = concrete_calculator_class( + mock_model, mock_instrumental_parameters, unique_name="test_2", display_name="Test2" + ) + assert calc.model is mock_model + assert calc.instrumental_parameters is mock_instrumental_parameters + + def test_init_with_kwargs(self, clear, concrete_calculator_class, mock_model): + """Test initialization with additional kwargs.""" + calc = concrete_calculator_class( + mock_model, unique_name="test_3", display_name="Test3", custom_option="value" + ) + assert calc.additional_kwargs == {"custom_option": "value"} + + def test_init_with_none_model_raises_error(self, clear, concrete_calculator_class): + """Test that initialization with None model raises ValueError.""" + with pytest.raises(ValueError, match="Model must be an instance of ModelBase"): + concrete_calculator_class(None, unique_name="test_4", display_name="Test4") + + # Model property tests + def test_model_getter(self, calculator, mock_model): + """Test model getter property.""" + assert calculator.model is mock_model + + def test_model_setter(self, calculator): + """Test model setter property.""" + new_model = create_mock_model("NewModel") + calculator.model = new_model + assert calculator.model is new_model + + def test_model_setter_with_none_raises_error(self, calculator): + """Test that setting model to None raises ValueError.""" + with pytest.raises(ValueError, match="Model cannot be None"): + calculator.model = None + + # Instrumental parameters property tests + def test_instrumental_parameters_getter(self, calculator, mock_instrumental_parameters): + """Test instrumental_parameters getter property.""" + assert calculator.instrumental_parameters is mock_instrumental_parameters + + def test_instrumental_parameters_setter(self, calculator): + """Test instrumental_parameters setter property.""" + new_params = create_mock_model("NewInstrument") + calculator.instrumental_parameters = new_params + assert calculator.instrumental_parameters is new_params + + def test_instrumental_parameters_setter_with_none(self, calculator): + """Test that instrumental_parameters can be set to None.""" + calculator.instrumental_parameters = None + assert calculator.instrumental_parameters is None + + # Update methods tests + def test_update_model(self, calculator): + """Test update_model method.""" + new_model = create_mock_model("UpdatedModel") + calculator.update_model(new_model) + assert calculator.model is new_model + + def test_update_model_with_none_raises_error(self, calculator): + """Test that update_model with None raises ValueError.""" + with pytest.raises(ValueError, match="Model cannot be None"): + calculator.update_model(None) + + def test_update_instrumental_parameters(self, calculator): + """Test update_instrumental_parameters method.""" + new_params = create_mock_model("UpdatedInstrument") + calculator.update_instrumental_parameters(new_params) + assert calculator.instrumental_parameters is new_params + + def test_update_instrumental_parameters_with_none(self, calculator): + """Test that update_instrumental_parameters accepts None.""" + calculator.update_instrumental_parameters(None) + assert calculator.instrumental_parameters is None + + # Calculate method tests + def test_calculate_returns_array(self, calculator): + """Test that calculate returns an array.""" + x = np.array([1.0, 2.0, 3.0]) + result = calculator.calculate(x) + assert isinstance(result, np.ndarray) + np.testing.assert_array_equal(result, np.array([2.0, 4.0, 6.0])) + + def test_calculate_with_empty_array(self, calculator): + """Test calculate with empty array.""" + x = np.array([]) + result = calculator.calculate(x) + assert len(result) == 0 + + # Abstract method enforcement tests + def test_cannot_instantiate_abstract_class(self, clear): + """Test that CalculatorBase cannot be instantiated directly.""" + mock_model = create_mock_model() + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + CalculatorBase(mock_model) + + def test_subclass_must_implement_calculate(self, clear): + """Test that subclasses must implement calculate method.""" + mock_model = create_mock_model() + + class IncompleteCalculator(CalculatorBase): + pass # Does not implement calculate + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompleteCalculator(mock_model) + + # Representation tests + def test_repr_with_model_only(self, clear, concrete_calculator_class, mock_model): + """Test __repr__ with only model.""" + calc = concrete_calculator_class(mock_model, unique_name="test_5", display_name="Test5") + repr_str = repr(calc) + assert "ConcreteCalculator" in repr_str + assert "model=MockModel" in repr_str + assert "instrumental_parameters" not in repr_str + + def test_repr_with_model_and_instrumental_parameters( + self, clear, concrete_calculator_class, mock_model, mock_instrumental_parameters + ): + """Test __repr__ with model and instrumental parameters.""" + calc = concrete_calculator_class(mock_model, mock_instrumental_parameters, unique_name="test_6", display_name="Test6") + repr_str = repr(calc) + assert "ConcreteCalculator" in repr_str + assert "model=" in repr_str + assert "instrumental_parameters=" in repr_str + + def test_repr_with_model_without_name_attribute(self, clear, concrete_calculator_class): + """Test __repr__ when model has no explicit name attribute (uses class name).""" + model = create_mock_model() # ModelBase without explicit name + calc = concrete_calculator_class(model, unique_name="test_7", display_name="Test7") + repr_str = repr(calc) + assert "ConcreteCalculator" in repr_str + # ModelBase subclass name will appear + assert "MockModel" in repr_str or "model=" in repr_str + + # Name attribute tests + def test_calculator_name_attribute(self, calculator): + """Test that calculator has name attribute.""" + assert calculator.name == "test_calculator" + + def test_default_name_is_base(self): + """Test that default name is 'base'.""" + assert CalculatorBase.name == "base" + + # Additional kwargs property tests + def test_additional_kwargs_with_init(self, clear, concrete_calculator_class, mock_model): + """Test additional_kwargs property with kwargs in init.""" + calc = concrete_calculator_class( + mock_model, + unique_name="test_8", + display_name="Test8", + custom_option="value", + numeric_param=42 + ) + assert calc.additional_kwargs == {"custom_option": "value", "numeric_param": 42} + + def test_additional_kwargs_empty_by_default(self, clear, concrete_calculator_class, mock_model): + """Test that additional_kwargs is empty dict when no kwargs provided.""" + calc = concrete_calculator_class(mock_model, unique_name="test_9", display_name="Test9") + assert calc.additional_kwargs == {} + + +class TestCalculatorBaseWithRealModel: + """Integration-style tests using actual EasyScience objects.""" + + @pytest.fixture + def clear(self): + """Clear global map to avoid test contamination.""" + global_object.map._clear() + yield + global_object.map._clear() + + @pytest.fixture + def real_parameter(self, clear): + """Create a real Parameter object.""" + from easyscience.variable import Parameter + return Parameter("test_param", value=5.0, unit="m") + + @pytest.fixture + def concrete_calculator_class(self): + """Create a concrete implementation that uses model parameters.""" + + class ParameterAwareCalculator(CalculatorBase): + name = "param_aware" + + def calculate(self, x: np.ndarray) -> np.ndarray: + # Access parameter from model if available + if hasattr(self._model, 'get_parameters'): + params = self._model.get_parameters() + if params: + scale = params[0].value + return x * scale + return x + + return ParameterAwareCalculator + + def test_calculator_can_access_model_parameters( + self, clear, concrete_calculator_class, real_parameter + ): + """Test that calculator can access parameters from model.""" + # Create a model that returns our real parameter + class TestModel(ModelBase): + def __init__(self, param): + super().__init__(display_name="TestModel") + self._param = param + + def get_parameters(self): + return [self._param] + + model = TestModel(real_parameter) + + calc = concrete_calculator_class(model, unique_name="test_10", display_name="Test10") + x = np.array([1.0, 2.0, 3.0]) + result = calc.calculate(x) + + # Should multiply by parameter value (5.0) + np.testing.assert_array_equal(result, np.array([5.0, 10.0, 15.0])) diff --git a/tests/unit_tests/fitting/calculators/test_calculator_factory.py b/tests/unit_tests/fitting/calculators/test_calculator_factory.py new file mode 100644 index 00000000..561699ac --- /dev/null +++ b/tests/unit_tests/fitting/calculators/test_calculator_factory.py @@ -0,0 +1,771 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project np.ndarray: + return x * 2.0 + + return TestCalculator + + @pytest.fixture + def concrete_factory_class(self, concrete_calculator_class): + """Create a concrete factory implementation.""" + + calc_class = concrete_calculator_class + + class TestFactory(CalculatorFactoryBase): + def __init__(self): + super().__init__() + self._available_calculators["test"] = calc_class + + def create(self, calculator_name, model, instrumental_parameters=None, **kwargs): + if calculator_name not in self._available_calculators: + raise ValueError(f"Unknown calculator: {calculator_name}") + return self._available_calculators[calculator_name](model, instrumental_parameters, **kwargs) + + return TestFactory + + # Abstract class enforcement tests + def test_cannot_instantiate_abstract_factory(self): + """Test that CalculatorFactoryBase cannot be instantiated directly.""" + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + CalculatorFactoryBase() + + def test_subclass_must_implement_create(self): + """Test that subclasses must implement create method.""" + + class IncompleteFactory(CalculatorFactoryBase): + pass + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompleteFactory() + + # Concrete factory tests + def test_factory_available_calculators(self, concrete_factory_class): + """Test available_calculators property.""" + factory = concrete_factory_class() + assert factory.available_calculators == ["test"] + + def test_factory_create_calculator(self, concrete_factory_class, mock_model, mock_instrumental_parameters): + """Test creating a calculator via factory.""" + factory = concrete_factory_class() + calculator = factory.create("test", mock_model, mock_instrumental_parameters) + assert isinstance(calculator, CalculatorBase) + assert calculator.model is mock_model + assert calculator.instrumental_parameters is mock_instrumental_parameters + + def test_factory_create_with_model_only(self, concrete_factory_class, mock_model): + """Test creating calculator with only model.""" + factory = concrete_factory_class() + calculator = factory.create("test", mock_model) + assert calculator.model is mock_model + assert calculator.instrumental_parameters is None + + def test_factory_create_unknown_calculator_raises_error(self, concrete_factory_class, mock_model): + """Test that creating unknown calculator raises ValueError.""" + factory = concrete_factory_class() + with pytest.raises(ValueError, match="Unknown calculator"): + factory.create("unknown", mock_model) + + # Repr tests + def test_factory_repr(self, concrete_factory_class): + """Test factory __repr__.""" + factory = concrete_factory_class() + repr_str = repr(factory) + assert "TestFactory" in repr_str + assert "test" in repr_str + + +class TestSimpleCalculatorFactory: + """Tests for SimpleCalculatorFactory class.""" + + @pytest.fixture + def clear(self): + """Clear global map to avoid test contamination.""" + global_object.map._clear() + yield + global_object.map._clear() + + @pytest.fixture + def mock_model(self, clear): + """Create a mock model object.""" + + class MockModel(ModelBase): + pass + + return MockModel() + + @pytest.fixture + def mock_instrumental_parameters(self, clear): + """Create mock instrumental parameters.""" + + class MockInstrument(ModelBase): + pass + + return MockInstrument() + + @pytest.fixture + def calculator_class_a(self): + """Create first concrete calculator class.""" + + class CalculatorA(CalculatorBase): + name = "calc_a" + + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 2.0 + + return CalculatorA + + @pytest.fixture + def calculator_class_b(self): + """Create second concrete calculator class.""" + + class CalculatorB(CalculatorBase): + name = "calc_b" + + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 3.0 + + return CalculatorB + + # Initialization tests + def test_init_empty(self): + """Test initialization with no calculators.""" + factory = SimpleCalculatorFactory() + assert factory.available_calculators == [] + + def test_init_with_calculators_dict(self, calculator_class_a, calculator_class_b): + """Test initialization with calculators dictionary.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + assert set(factory.available_calculators) == {"a", "b"} + + # Available calculators tests + def test_available_calculators_returns_list(self, calculator_class_a): + """Test that available_calculators returns a list.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + result = factory.available_calculators + assert isinstance(result, list) + assert "a" in result + + # Create tests + def test_create_calculator(self, calculator_class_a, mock_model, mock_instrumental_parameters): + """Test creating a calculator.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + calculator = factory.create("a", mock_model, mock_instrumental_parameters) + assert isinstance(calculator, CalculatorBase) + assert calculator.model is mock_model + assert calculator.instrumental_parameters is mock_instrumental_parameters + + def test_create_with_kwargs(self, calculator_class_a, mock_model): + """Test creating calculator with additional kwargs.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + calculator = factory.create("a", mock_model, custom_option="value") + assert calculator.additional_kwargs == {"custom_option": "value"} + + def test_create_unknown_calculator_raises_error(self, calculator_class_a, mock_model): + """Test that creating unknown calculator raises ValueError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + with pytest.raises(ValueError, match="Unknown calculator 'unknown'"): + factory.create("unknown", mock_model) + + def test_create_error_message_includes_available(self, calculator_class_a, calculator_class_b, mock_model): + """Test that error message includes available calculators.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + with pytest.raises(ValueError) as exc_info: + factory.create("unknown", mock_model) + assert "a" in str(exc_info.value) or "b" in str(exc_info.value) + + # Register tests + def test_register_calculator(self, calculator_class_a, calculator_class_b): + """Test registering a new calculator.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + factory.register("b", calculator_class_b) + assert "b" in factory.available_calculators + + def test_register_overwrites_existing(self, calculator_class_a, calculator_class_b, clear): + """Test that registering with existing name overwrites.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + factory.register("a", calculator_class_b) + # Now "a" should create CalculatorB + calc = factory.create("a", create_mock_model()) + assert calc.name == "calc_b" + + def test_register_invalid_class_raises_error(self, calculator_class_a): + """Test that registering non-CalculatorBase raises TypeError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + + class NotACalculator: + pass + + with pytest.raises(TypeError, match="must be a subclass of CalculatorBase"): + factory.register("bad", NotACalculator) + + def test_register_non_class_raises_error(self, calculator_class_a): + """Test that registering a non-class raises TypeError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + with pytest.raises(TypeError, match="must be a subclass of CalculatorBase"): + factory.register("bad", "not a class") + + # Unregister tests + def test_unregister_calculator(self, calculator_class_a, calculator_class_b): + """Test unregistering a calculator.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + factory.unregister("a") + assert "a" not in factory.available_calculators + assert "b" in factory.available_calculators + + def test_unregister_unknown_raises_error(self, calculator_class_a): + """Test that unregistering unknown calculator raises KeyError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + with pytest.raises(KeyError, match="Calculator 'unknown' is not registered"): + factory.unregister("unknown") + + # Repr tests + def test_repr_empty_factory(self): + """Test __repr__ with empty factory.""" + factory = SimpleCalculatorFactory() + repr_str = repr(factory) + assert "SimpleCalculatorFactory" in repr_str + assert "available=[]" in repr_str + + def test_repr_with_calculators(self, calculator_class_a, calculator_class_b): + """Test __repr__ with calculators.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + repr_str = repr(factory) + assert "SimpleCalculatorFactory" in repr_str + assert "a" in repr_str or "b" in repr_str + + # Integration tests + def test_created_calculator_works(self, calculator_class_a, mock_model): + """Test that created calculator actually works.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + calculator = factory.create("a", mock_model) + x = np.array([1.0, 2.0, 3.0]) + result = calculator.calculate(x) + np.testing.assert_array_equal(result, np.array([2.0, 4.0, 6.0])) + + def test_create_multiple_calculators_independently( + self, calculator_class_a, calculator_class_b, clear + ): + """Test creating multiple independent calculators.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + + model_a = create_mock_model() + model_b = create_mock_model() + + calc_a = factory.create("a", model_a) + calc_b = factory.create("b", model_b) + + # They should be independent + assert calc_a.model is model_a + assert calc_b.model is model_b + assert calc_a is not calc_b + + # And calculate differently + x = np.array([1.0, 2.0]) + np.testing.assert_array_equal(calc_a.calculate(x), np.array([2.0, 4.0])) + np.testing.assert_array_equal(calc_b.calculate(x), np.array([3.0, 6.0])) + + +class TestFactoryStatelessness: + """Tests to verify that the factory is truly stateless.""" + + @pytest.fixture + def clear(self): + """Clear global map to avoid test contamination.""" + global_object.map._clear() + yield + global_object.map._clear() + + @pytest.fixture + def calculator_class(self): + """Create a calculator class with counter for instances.""" + + class CountingCalculator(CalculatorBase): + name = "counting" + instance_count = 0 + + def __init__(self, model, instrumental_parameters=None, **kwargs): + super().__init__(model, instrumental_parameters, **kwargs) + CountingCalculator.instance_count += 1 + self.instance_id = CountingCalculator.instance_count + + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + + # Reset counter before each test + CountingCalculator.instance_count = 0 + return CountingCalculator + + def test_factory_does_not_store_calculator_instances(self, calculator_class, clear): + """Test that factory doesn't store references to created calculators.""" + factory = SimpleCalculatorFactory({"calc": calculator_class}) + mock_model = create_mock_model() + + calc1 = factory.create("calc", mock_model) + calc2 = factory.create("calc", mock_model) + + # Each create should produce a new instance + assert calc1 is not calc2 + assert calc1.instance_id == 1 + assert calc2.instance_id == 2 + + def test_factory_has_no_current_calculator_attribute(self, calculator_class): + """Test that factory has no 'current' calculator state.""" + factory = SimpleCalculatorFactory({"calc": calculator_class}) + + # Should not have any attributes tracking current state + assert not hasattr(factory, "_current_calculator") + assert not hasattr(factory, "current_calculator") + assert not hasattr(factory, "_current") + + def test_multiple_factories_are_independent(self, calculator_class, clear): + """Test that multiple factory instances are independent.""" + factory1 = SimpleCalculatorFactory({"calc": calculator_class}) + factory2 = SimpleCalculatorFactory({"calc": calculator_class}) + + mock_model = create_mock_model() + + calc1 = factory1.create("calc", mock_model) + calc2 = factory2.create("calc", mock_model) + + # Each factory creates independent calculators + assert calc1 is not calc2 + + +class TestFactoryIsolation: + """Tests to ensure calculator registries don't bleed between factory instances or subclasses.""" + + @pytest.fixture + def calculator_class_x(self): + """First test calculator class.""" + class CalculatorX(CalculatorBase): + name = "x" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + return CalculatorX + + @pytest.fixture + def calculator_class_y(self): + """Second test calculator class.""" + class CalculatorY(CalculatorBase): + name = "y" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 2 + return CalculatorY + + @pytest.fixture + def calculator_class_z(self): + """Third test calculator class.""" + class CalculatorZ(CalculatorBase): + name = "z" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 3 + return CalculatorZ + + def test_instance_registration_does_not_affect_other_instances( + self, calculator_class_x, calculator_class_y, calculator_class_z + ): + """Test that registering to one instance doesn't affect others.""" + factory1 = SimpleCalculatorFactory({"x": calculator_class_x}) + factory2 = SimpleCalculatorFactory({"y": calculator_class_y}) + + # Register z to factory1 only + factory1.register("z", calculator_class_z) + + # factory1 should have both x and z + assert "x" in factory1.available_calculators + assert "z" in factory1.available_calculators + assert "y" not in factory1.available_calculators + + # factory2 should only have y + assert "y" in factory2.available_calculators + assert "x" not in factory2.available_calculators + assert "z" not in factory2.available_calculators + + def test_subclass_registration_does_not_affect_parent_or_siblings( + self, calculator_class_x, calculator_class_y + ): + """Test that subclass registries are independent.""" + + calc_x = calculator_class_x + calc_y = calculator_class_y + + class FactoryA(SimpleCalculatorFactory): + def __init__(self): + super().__init__() + self._available_calculators["x"] = calc_x + + class FactoryB(SimpleCalculatorFactory): + def __init__(self): + super().__init__() + self._available_calculators["y"] = calc_y + + factory_a = FactoryA() + factory_b = FactoryB() + + # Each should have their own calculators + assert "x" in factory_a.available_calculators + assert "y" not in factory_a.available_calculators + + assert "y" in factory_b.available_calculators + assert "x" not in factory_b.available_calculators + + def test_class_level_registry_not_modified_by_instance_register( + self, calculator_class_x, calculator_class_y + ): + """Test that instance.register() doesn't modify other instances.""" + + calc_x = calculator_class_x + + class MyFactory(SimpleCalculatorFactory): + def __init__(self): + super().__init__() + self._available_calculators["x"] = calc_x + + # Create instance and register to it + factory = MyFactory() + factory.register("y", calculator_class_y) + + # Instance should have both + assert "x" in factory.available_calculators + assert "y" in factory.available_calculators + + # Create new instance - should NOT have y + factory2 = MyFactory() + assert "x" in factory2.available_calculators + assert "y" not in factory2.available_calculators + + def test_unregister_from_one_instance_does_not_affect_others( + self, calculator_class_x + ): + """Test that unregistering from one instance doesn't affect others.""" + factory1 = SimpleCalculatorFactory({"x": calculator_class_x}) + factory2 = SimpleCalculatorFactory({"x": calculator_class_x}) + + # Unregister from factory1 + factory1.unregister("x") + + # factory1 should not have x + assert "x" not in factory1.available_calculators + + # factory2 should still have x + assert "x" in factory2.available_calculators + + +class TestFactoryErrorHandling: + """Tests for improved error handling and validation.""" + + @pytest.fixture + def clear(self): + """Clear global map to avoid test contamination.""" + global_object.map._clear() + yield + global_object.map._clear() + + @pytest.fixture + def calculator_class(self): + """Simple test calculator.""" + class TestCalc(CalculatorBase): + name = "test" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + return TestCalc + + def test_register_with_empty_name_raises_error(self, calculator_class): + """Test that empty calculator name raises ValueError.""" + factory = SimpleCalculatorFactory() + with pytest.raises(ValueError, match="non-empty string"): + factory.register("", calculator_class) + + def test_register_with_non_string_name_raises_error(self, calculator_class): + """Test that non-string calculator name raises ValueError.""" + factory = SimpleCalculatorFactory() + with pytest.raises(ValueError, match="non-empty string"): + factory.register(123, calculator_class) + + def test_register_overwrites_with_warning(self, calculator_class): + """Test that overwriting existing calculator issues warning.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + + class NewCalc(CalculatorBase): + name = "new" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 2 + + with pytest.warns(UserWarning, match="Overwriting existing calculator 'test'"): + factory.register("test", NewCalc) + + def test_create_with_non_string_name_raises_error(self, calculator_class, clear): + """Test that create with non-string name raises ValueError.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + with pytest.raises(ValueError, match="must be a string"): + factory.create(123, create_mock_model()) + + def test_create_with_none_model_raises_error(self, calculator_class): + """Test that create with None model raises TypeError.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + with pytest.raises(TypeError, match="Model cannot be None"): + factory.create("test", None) + + def test_create_unknown_calculator_shows_available_in_error(self, calculator_class, clear): + """Test that error message includes available calculators.""" + factory = SimpleCalculatorFactory({"calc1": calculator_class}) + with pytest.raises(ValueError, match="calc1") as exc_info: + factory.create("unknown", create_mock_model()) + assert "Available calculators" in str(exc_info.value) + + def test_create_empty_factory_error_shows_none_available(self, clear): + """Test error message when factory has no calculators.""" + factory = SimpleCalculatorFactory() + with pytest.raises(ValueError, match="none") as exc_info: + factory.create("anything", create_mock_model()) + assert "Available calculators: none" in str(exc_info.value) + + def test_create_wraps_calculator_init_errors(self, calculator_class, clear): + """Test that calculator initialization errors are wrapped.""" + + class BrokenCalc(CalculatorBase): + name = "broken" + def __init__(self, model, instrumental_parameters=None, **kwargs): + raise RuntimeError("Something went wrong") + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + + factory = SimpleCalculatorFactory({"broken": BrokenCalc}) + with pytest.raises(RuntimeError, match="Failed to create calculator 'broken'"): + factory.create("broken", create_mock_model()) + + +class TestCalculatorKwargsProperty: + """Tests for the additional_kwargs property on CalculatorBase.""" + + @pytest.fixture + def clear(self): + """Clear global map to avoid test contamination.""" + global_object.map._clear() + yield + global_object.map._clear() + + @pytest.fixture + def calculator_class(self): + """Simple calculator class for testing.""" + class TestCalc(CalculatorBase): + name = "test" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + return TestCalc + + def test_additional_kwargs_accessible(self, calculator_class, clear): + """Test that additional_kwargs property is accessible.""" + calc = calculator_class( + create_mock_model(), + custom_param="value", + another_option=42 + ) + kwargs = calc.additional_kwargs + assert isinstance(kwargs, dict) + assert kwargs["custom_param"] == "value" + assert kwargs["another_option"] == 42 + + def test_additional_kwargs_empty_when_none_provided(self, calculator_class, clear): + """Test that additional_kwargs is empty dict when no kwargs provided.""" + calc = calculator_class(create_mock_model()) + assert calc.additional_kwargs == {} + + def test_additional_kwargs_via_factory(self, calculator_class, clear): + """Test that kwargs passed through factory are accessible.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + calc = factory.create( + "test", + create_mock_model(), + option1="value1", + option2=123 + ) + assert calc.additional_kwargs["option1"] == "value1" + assert calc.additional_kwargs["option2"] == 123 + + +class TestTryRegisterCalculator: + """Tests for the _try_register_calculator method.""" + + @pytest.fixture + def calculator_class(self): + """Simple calculator class for testing.""" + + class TestCalc(CalculatorBase): + name = "test" + + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + + return TestCalc + + @pytest.fixture + def concrete_factory(self, calculator_class): + """Create a concrete factory for testing.""" + calc_class = calculator_class + + class TestFactory(CalculatorFactoryBase): + def __init__(self): + super().__init__() + + def create(self, calculator_name, model, instrumental_parameters=None, **kwargs): + if calculator_name not in self._available_calculators: + raise ValueError(f"Unknown calculator: {calculator_name}") + return self._available_calculators[calculator_name](model, instrumental_parameters, **kwargs) + + return TestFactory + + def test_try_register_existing_package_succeeds(self, concrete_factory): + """Test that registering from an existing package works.""" + factory = concrete_factory() + # json is always available in Python + result = factory._try_register_calculator("json_encoder", "json", "JSONEncoder") + assert result is True + assert "json_encoder" in factory.available_calculators + + def test_try_register_nonexistent_package_returns_false(self, concrete_factory): + """Test that registering from non-existent package returns False.""" + factory = concrete_factory() + result = factory._try_register_calculator( + "nonexistent", "this_package_does_not_exist_12345", "SomeClass" + ) + assert result is False + assert "nonexistent" not in factory.available_calculators + + def test_try_register_nonexistent_class_returns_false(self, concrete_factory): + """Test that registering non-existent class returns False.""" + factory = concrete_factory() + result = factory._try_register_calculator( + "bad_class", "json", "ThisClassDoesNotExist12345" + ) + assert result is False + assert "bad_class" not in factory.available_calculators + + def test_try_register_multiple_calculators(self, concrete_factory): + """Test registering multiple calculators with mixed success.""" + factory = concrete_factory() + + # This should succeed + result1 = factory._try_register_calculator("encoder", "json", "JSONEncoder") + # This should fail + result2 = factory._try_register_calculator("fake", "nonexistent_pkg", "FakeClass") + # This should succeed + result3 = factory._try_register_calculator("decoder", "json", "JSONDecoder") + + assert result1 is True + assert result2 is False + assert result3 is True + + assert "encoder" in factory.available_calculators + assert "fake" not in factory.available_calculators + assert "decoder" in factory.available_calculators + assert len(factory.available_calculators) == 2 + + def test_try_register_does_not_affect_other_instances(self, concrete_factory): + """Test that _try_register on one instance doesn't affect others.""" + factory1 = concrete_factory() + factory2 = concrete_factory() + + factory1._try_register_calculator("encoder", "json", "JSONEncoder") + + assert "encoder" in factory1.available_calculators + assert "encoder" not in factory2.available_calculators + + def test_try_register_in_subclass_init(self, calculator_class): + """Test using _try_register_calculator in subclass __init__.""" + calc_class = calculator_class + + class DynamicFactory(CalculatorFactoryBase): + def __init__(self): + super().__init__() + # Register one that exists + self._try_register_calculator("encoder", "json", "JSONEncoder") + # Register one that doesn't exist - should be silently skipped + self._try_register_calculator("fake", "no_such_package", "NoClass") + + def create(self, calculator_name, model, instrumental_parameters=None, **kwargs): + if calculator_name not in self._available_calculators: + raise ValueError(f"Unknown calculator: {calculator_name}") + return self._available_calculators[calculator_name](model, instrumental_parameters, **kwargs) + + factory = DynamicFactory() + assert "encoder" in factory.available_calculators + assert "fake" not in factory.available_calculators diff --git a/tests/unit_tests/fitting/calculators/test_interface_factory.py b/tests/unit_tests/fitting/calculators/test_interface_factory.py index ce7b8542..fc102664 100644 --- a/tests/unit_tests/fitting/calculators/test_interface_factory.py +++ b/tests/unit_tests/fitting/calculators/test_interface_factory.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause # © 2021-2025 Contributors to the EasyScience project