diff --git a/src/easyscience/base_classes/__init__.py b/src/easyscience/base_classes/__init__.py index 9f3ba080..9a8dc2d3 100644 --- a/src/easyscience/base_classes/__init__.py +++ b/src/easyscience/base_classes/__init__.py @@ -1,6 +1,7 @@ from .based_base import BasedBase from .collection_base import CollectionBase from .model_base import ModelBase +from .model_collection import ModelCollection from .new_base import NewBase from .obj_base import ObjBase @@ -9,5 +10,6 @@ CollectionBase, ObjBase, ModelBase, + ModelCollection, NewBase, ] diff --git a/src/easyscience/base_classes/collection_base.py b/src/easyscience/base_classes/collection_base.py index 3cc0586a..e08ae26c 100644 --- a/src/easyscience/base_classes/collection_base.py +++ b/src/easyscience/base_classes/collection_base.py @@ -18,6 +18,7 @@ from ..variable.descriptor_base import DescriptorBase from .based_base import BasedBase +from .new_base import NewBase if TYPE_CHECKING: from ..fitting.calculators import InterfaceFactoryTemplate @@ -64,7 +65,7 @@ def __init__( _kwargs[key] = item kwargs = _kwargs for item in list(kwargs.values()) + _args: - if not issubclass(type(item), (DescriptorBase, BasedBase)): + if not issubclass(type(item), (DescriptorBase, BasedBase, NewBase)): raise AttributeError('A collection can only be formed from easyscience objects.') args = _args _kwargs = {} diff --git a/src/easyscience/base_classes/model_collection.py b/src/easyscience/base_classes/model_collection.py new file mode 100644 index 00000000..6f26085f --- /dev/null +++ b/src/easyscience/base_classes/model_collection.py @@ -0,0 +1,282 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project None: + """Add an item to the collection and set up graph edges. + + Note: Duplicate items (same object reference) are silently ignored. + """ + if not isinstance(item, NewBase): + raise TypeError(f'Items must be NewBase objects, got {type(item)}') + if item in self._data: + return # Skip duplicates to avoid multiple graph edges + self._data.append(item) + self._global_object.map.add_edge(self, item) + self._global_object.map.reset_type(item, 'created_internal') + if self._interface is not None and hasattr(item, 'interface'): + setattr(item, 'interface', self._interface) + + def _remove_item(self, item: NewBase) -> None: + """Remove an item from the collection and clean up graph edges.""" + self._global_object.map.prune_vertex_from_edge(self, item) + + @property + def interface(self) -> InterfaceType: + """Get the current interface of the collection.""" + return self._interface + + @interface.setter + def interface(self, new_interface: InterfaceType) -> None: + """Set the interface and propagate to all items. + + :param new_interface: The interface to set (must be InterfaceFactoryTemplate, CalculatorFactoryBase, or None) + :raises TypeError: If the interface is not a valid type + """ + # Import here to avoid circular imports + from ..fitting.calculators import CalculatorFactoryBase + from ..fitting.calculators import InterfaceFactoryTemplate + + if new_interface is not None and not isinstance(new_interface, (InterfaceFactoryTemplate, CalculatorFactoryBase)): + raise TypeError( + f'interface must be InterfaceFactoryTemplate, CalculatorFactoryBase, or None, ' + f'got {type(new_interface).__name__}' + ) + + self._interface = new_interface + for item in self._data: + if hasattr(item, 'interface'): + setattr(item, 'interface', new_interface) + + # MutableSequence abstract methods + + # Use @overload to provide precise type hints for different __getitem__ argument types + @overload + def __getitem__(self, idx: int) -> T: ... + @overload + def __getitem__(self, idx: slice) -> 'ModelCollection[T]': ... + @overload + def __getitem__(self, idx: str) -> T: ... + + def __getitem__(self, idx: int | slice | str) -> T | 'ModelCollection[T]': + """ + Get an item by index, slice, or name. + + :param idx: Index, slice, or name of the item + :return: The item or a new collection for slices + """ + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + return self.__class__(*[self._data[i] for i in range(start, stop, step)]) + if isinstance(idx, str): + # Search by name + for item in self._data: + if hasattr(item, 'name') and getattr(item, 'name') == idx: + return item # type: ignore[return-value] + if hasattr(item, 'unique_name') and item.unique_name == idx: + return item # type: ignore[return-value] + raise KeyError(f'No item with name "{idx}" found') + return self._data[idx] # type: ignore[return-value] + + @overload + def __setitem__(self, idx: int, value: T) -> None: ... + @overload + def __setitem__(self, idx: slice, value: Iterable[T]) -> None: ... + + def __setitem__(self, idx: int | slice, value: T | Iterable[T]) -> None: + """ + Set an item at an index. + + :param idx: Index to set + :param value: New value + """ + if isinstance(idx, slice): + # Handle slice assignment + values = list(value) # type: ignore[arg-type] + # Remove old items + start, stop, step = idx.indices(len(self)) + for i in range(start, stop, step): + self._remove_item(self._data[i]) + # Set new items + self._data[idx] = values # type: ignore[assignment] + for v in values: + self._global_object.map.add_edge(self, v) + self._global_object.map.reset_type(v, 'created_internal') + if self._interface is not None and hasattr(v, 'interface'): + setattr(v, 'interface', self._interface) + else: + if not isinstance(value, NewBase): + raise TypeError(f'Items must be NewBase objects, got {type(value)}') + + old_item = self._data[idx] + self._remove_item(old_item) + + self._data[idx] = value # type: ignore[assignment] + self._global_object.map.add_edge(self, value) + self._global_object.map.reset_type(value, 'created_internal') + if self._interface is not None and hasattr(value, 'interface'): + setattr(value, 'interface', self._interface) + + @overload + def __delitem__(self, idx: int) -> None: ... + @overload + def __delitem__(self, idx: slice) -> None: ... + @overload + def __delitem__(self, idx: str) -> None: ... + + def __delitem__(self, idx: int | slice | str) -> None: + """ + Delete an item by index, slice, or name. + + :param idx: Index, slice, or name of item to delete + """ + if isinstance(idx, slice): + start, stop, step = idx.indices(len(self)) + indices = list(range(start, stop, step)) + # Remove in reverse order to maintain indices + for i in reversed(indices): + item = self._data[i] + self._remove_item(item) + del self._data[i] + elif isinstance(idx, str): + for i, item in enumerate(self._data): + if hasattr(item, 'name') and getattr(item, 'name') == idx: + idx = i + break + if hasattr(item, 'unique_name') and item.unique_name == idx: + idx = i + break + else: + raise KeyError(f'No item with name "{idx}" found') + + item = self._data[idx] + self._remove_item(item) + del self._data[idx] + else: + item = self._data[idx] + self._remove_item(item) + del self._data[idx] + + def __len__(self) -> int: + """Return the number of items in the collection.""" + return len(self._data) + + def insert(self, index: int, value: T) -> None: + """ + Insert an item at an index. + + :param index: Index to insert at + :param value: Item to insert + """ + if not isinstance(value, NewBase): + raise TypeError(f'Items must be NewBase objects, got {type(value)}') + + self._data.insert(index, value) # type: ignore[arg-type] + self._global_object.map.add_edge(self, value) + self._global_object.map.reset_type(value, 'created_internal') + if self._interface is not None and hasattr(value, 'interface'): + setattr(value, 'interface', self._interface) + + # Additional utility methods + + @property + def data(self) -> tuple: + """Return the data as a tuple.""" + return tuple(self._data) + + def sort(self, mapping: Callable[[T], Any], reverse: bool = False) -> None: + """ + Sort the collection according to the given mapping. + + :param mapping: Mapping function to sort by + :param reverse: Whether to reverse the sort + """ + self._data.sort(key=mapping, reverse=reverse) # type: ignore[arg-type] + + def __repr__(self) -> str: + return f'{self.__class__.__name__} of length {len(self)}' + + def __iter__(self) -> Any: + return iter(self._data) + + # Serialization support + + def _convert_to_dict(self, in_dict: dict, encoder: Any, skip: Optional[List[str]] = None, **kwargs: Any) -> dict: + """Convert the collection to a dictionary for serialization.""" + if skip is None: + skip = [] + d: dict = {} + if hasattr(self, '_modify_dict'): + d = self._modify_dict(skip=skip, **kwargs) # type: ignore[attr-defined] + in_dict['data'] = [encoder._convert_to_dict(item, skip=skip, **kwargs) for item in self._data] + return {**in_dict, **d} + + def get_all_variables(self) -> List[Any]: + """Get all variables from all items in the collection.""" + variables: List[Any] = [] + for item in self._data: + if hasattr(item, 'get_all_variables'): + variables.extend(item.get_all_variables()) # type: ignore[attr-defined] + return variables diff --git a/src/easyscience/fitting/calculators/__init__.py b/src/easyscience/fitting/calculators/__init__.py index a3ca5d43..e2c1ea89 100644 --- a/src/easyscience/fitting/calculators/__init__.py +++ b/src/easyscience/fitting/calculators/__init__.py @@ -2,6 +2,14 @@ # SPDX-License-Identifier: BSD-3-Clause # © 2021-2025 Contributors to the EasyScience project +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project None: + """ + Initialize the calculator with a model and instrumental parameters. + + Parameters + ---------- + model : NewBase + The physical model to calculate from. This is typically a sample + or structure definition containing fittable parameters. + instrumental_parameters : NewBase, optional + Instrumental parameters that affect the calculation, such as + resolution, wavelength, or detector settings. + unique_name : str, optional + Unique identifier for this calculator instance. + display_name : str, optional + Human-readable name for display purposes. + **kwargs : Any + Additional calculator-specific options. + """ + if model is None: + raise ValueError('Model cannot be None') + + # Initialize NewBase with naming + super().__init__(unique_name=unique_name, display_name=display_name) + + self._model = model + self._instrumental_parameters = instrumental_parameters + self._additional_kwargs = kwargs + # Register this calculator and model in the global object map + if hasattr(model, 'unique_name'): + self._global_object.map.add_edge(self, model) + + @property + def model(self) -> NewBase: + """ + Get the current physical model. + + Returns + ------- + NewBase + The physical model used for calculations. + """ + return self._model + + @model.setter + def model(self, new_model: NewBase) -> None: + """ + Set a new physical model. + + Parameters + ---------- + new_model : NewBase + The new physical model to use for calculations. + + Raises + ------ + ValueError + If the new model is None. + """ + if new_model is None: + raise ValueError('Model cannot be None') + self._model = new_model + + @property + def instrumental_parameters(self) -> Optional[NewBase]: + """ + Get the current instrumental parameters. + + Returns + ------- + NewBase or None + The instrumental parameters, or None if not set. + """ + return self._instrumental_parameters + + @instrumental_parameters.setter + def instrumental_parameters(self, new_parameters: Optional[NewBase]) -> None: + """ + Set new instrumental parameters. + + Parameters + ---------- + new_parameters : NewBase or None + The new instrumental parameters to use for calculations. + Truly optional, since instrumental parameters may not always be needed. + """ + self._instrumental_parameters = new_parameters + + def update_model(self, new_model: NewBase) -> None: + """ + Update the physical model used for calculations. + + This is an alternative to the `model` property setter that can be + overridden by subclasses to perform additional setup when the model changes. + + Parameters + ---------- + new_model : NewBase + The new physical model to use. + + Raises + ------ + ValueError + If the new model is None. + """ + self.model = new_model + + def update_instrumental_parameters(self, new_parameters: Optional[NewBase]) -> None: + """ + Update the instrumental parameters used for calculations. + + This is an alternative to the `instrumental_parameters` property setter + that can be overridden by subclasses to perform additional setup when + instrumental parameters change. + + Parameters + ---------- + new_parameters : NewBase or None + The new instrumental parameters to use. + """ + self.instrumental_parameters = new_parameters + + @property + def additional_kwargs(self) -> dict: + """ + Get additional keyword arguments passed during initialization. + + Returns a copy to prevent external modification of internal state. + + Returns + ------- + dict + Copy of the dictionary of additional kwargs passed to __init__. + """ + return dict(self._additional_kwargs) + + @abstractmethod + def calculate(self, x: np.ndarray) -> np.ndarray: + """ + Calculate theoretical values at the given points. + + This is the main calculation method that must be implemented by all + concrete calculator classes. It uses the current model and instrumental + parameters to compute theoretical predictions. + + Parameters + ---------- + x : np.ndarray + The independent variable values (e.g., Q values, angles, energies) + at which to calculate the theoretical response. + + Returns + ------- + np.ndarray + The calculated theoretical values corresponding to the input x values. + + Notes + ----- + This method is called during fitting and should be thread-safe if + parallel fitting is to be supported. + """ + ... + + def __repr__(self) -> str: + """Return a string representation of the calculator.""" + model_name = getattr(self._model, 'name', type(self._model).__name__) + instr_info = '' + if self._instrumental_parameters is not None: + instr_name = getattr( + self._instrumental_parameters, + 'name', + type(self._instrumental_parameters).__name__, # default to class name if no 'name' attribute + ) + instr_info = f', instrumental_parameters={instr_name}' + return f'{self.__class__.__name__}(model={model_name}{instr_info})' diff --git a/src/easyscience/fitting/calculators/calculator_factory.py b/src/easyscience/fitting/calculators/calculator_factory.py new file mode 100644 index 00000000..69397d4d --- /dev/null +++ b/src/easyscience/fitting/calculators/calculator_factory.py @@ -0,0 +1,324 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project List[str]: + """ + Return a list of available calculator names. + + Returns + ------- + List[str] + Names of all calculators that can be created by this factory. + """ + ... + + @abstractmethod + def create( + self, + calculator_name: str, + model: NewBase, + instrumental_parameters: Optional[NewBase] = None, + **kwargs: Any, + ) -> CalculatorBase: + """ + Create a calculator instance. + + Parameters + ---------- + calculator_name : str + The name of the calculator to create. Must be one of the names + returned by `available_calculators`. + model : NewBase + The physical model (e.g., sample) to pass to the calculator. + instrumental_parameters : NewBase, optional + Instrumental parameters to pass to the calculator. + **kwargs : Any + Additional arguments to pass to the calculator constructor. + + Returns + ------- + CalculatorBase + A new calculator instance configured with the given model and + instrumental parameters. + + Raises + ------ + ValueError + If the requested calculator_name is not available. + """ + ... + + def __repr__(self) -> str: + """Return a string representation of the factory.""" + return f'{self.__class__.__name__}(available={self.available_calculators})' + + +class SimpleCalculatorFactory(CalculatorFactoryBase): + """ + A simple implementation of a calculator factory using a dictionary registry. + + This class provides a convenient base for creating calculator factories + where calculators are registered in a dictionary. Subclasses only need + to populate the `_calculators` class attribute. + + Parameters + ---------- + calculators : Dict[str, Type[CalculatorBase]], optional + A dictionary mapping calculator names to calculator classes. + If not provided, uses the class-level `_calculators` attribute. + + Attributes + ---------- + _calculators : Dict[str, Type[CalculatorBase]] + Class-level dictionary of registered calculators. Subclasses should + override this with their available calculators. + + Examples + -------- + Using class-level registration:: + + class MyFactory(SimpleCalculatorFactory): + _calculators = { + 'fast': FastCalculator, + 'accurate': AccurateCalculator, + } + + factory = MyFactory() + calc = factory.create('fast', model, instrument) + + Using instance-level registration:: + + factory = SimpleCalculatorFactory({ + 'custom': CustomCalculator, + }) + calc = factory.create('custom', model, instrument) + """ + + _calculators: Dict[str, Type[CalculatorBase]] = {} + + def __init__( + self, + calculators: Optional[Dict[str, Type[CalculatorBase]]] = None, + ) -> None: + """ + Initialize the factory with optional calculator registry. + + Parameters + ---------- + calculators : Dict[str, Type[CalculatorBase]], optional + A dictionary mapping calculator names to calculator classes. + If provided, overrides the class-level `_calculators` attribute. + """ + # Create instance-level copy to prevent bleeding between instances + if calculators is not None: + self._calculators = dict(calculators) + else: + # Create a copy of the class-level registry for this instance + self._calculators = dict(self.__class__._calculators) + + @property + def available_calculators(self) -> List[str]: + """ + Return a list of available calculator names. + + Returns + ------- + List[str] + Names of all registered calculators. + """ + return list(self._calculators.keys()) + + def create( + self, + calculator_name: str, + model: NewBase, + instrumental_parameters: Optional[NewBase] = None, + **kwargs: Any, + ) -> CalculatorBase: + """ + Create a calculator instance from the registered calculators. + + Parameters + ---------- + calculator_name : str + The name of the calculator to create. + model : NewBase + The physical model to pass to the calculator. + instrumental_parameters : NewBase, optional + Instrumental parameters to pass to the calculator. + **kwargs : Any + Additional arguments to pass to the calculator constructor. + + Returns + ------- + CalculatorBase + A new calculator instance. + + Raises + ------ + ValueError + If the calculator_name is not in the registry or is not a string. + TypeError + If model is None or instrumental_parameters has wrong type. + """ + if not isinstance(calculator_name, str): + raise ValueError(f'calculator_name must be a string, got {type(calculator_name).__name__}') + + if calculator_name not in self._calculators: + available = ', '.join(self.available_calculators) if self.available_calculators else 'none' + raise ValueError(f"Unknown calculator '{calculator_name}'. Available calculators: {available}") + + if model is None: + raise TypeError('Model cannot be None') + + calculator_class = self._calculators[calculator_name] + try: + return calculator_class(model, instrumental_parameters, **kwargs) + except Exception as e: + raise type(e)(f"Failed to create calculator '{calculator_name}': {e}") from e + + def register(self, name: str, calculator_class: Type[CalculatorBase]) -> None: + """ + Register a new calculator class with the factory. + + Parameters + ---------- + name : str + The name to register the calculator under. + calculator_class : Type[CalculatorBase] + The calculator class to register. + + Raises + ------ + TypeError + If calculator_class is not a subclass of CalculatorBase. + ValueError + If name is empty or not a string. + + Warnings + -------- + If overwriting an existing calculator, a warning is issued. + """ + # Import here to avoid circular imports at module level + import warnings + + from .calculator_base import CalculatorBase + + if not isinstance(name, str) or not name: + raise ValueError('Calculator name must be a non-empty string') + + if not (isinstance(calculator_class, type) and issubclass(calculator_class, CalculatorBase)): + raise TypeError(f'calculator_class must be a subclass of CalculatorBase, got {type(calculator_class).__name__}') + + if name in self._calculators: + warnings.warn(f"Overwriting existing calculator '{name}' in {self.__class__.__name__}", UserWarning, stacklevel=2) + + self._calculators[name] = calculator_class + + def unregister(self, name: str) -> None: + """ + Remove a calculator from the registry. + + Parameters + ---------- + name : str + The name of the calculator to remove. + + Raises + ------ + KeyError + If the calculator name is not in the registry. + """ + if name not in self._calculators: + raise KeyError(f"Calculator '{name}' is not registered") + del self._calculators[name] diff --git a/src/easyscience/fitting/calculators/interface_factory.py b/src/easyscience/fitting/calculators/interface_factory.py index ca4713fd..35956ad7 100644 --- a/src/easyscience/fitting/calculators/interface_factory.py +++ b/src/easyscience/fitting/calculators/interface_factory.py @@ -3,6 +3,7 @@ # SPDX-FileCopyrightText: 2025 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause # © 2021-2025 Contributors to the EasyScience project bool: class Map: def __init__(self): # A dictionary of object names and their corresponding objects - self._store = weakref.WeakValueDictionary() + self.__store = weakref.WeakValueDictionary() # A dict with object names as keys and a list of their object types as values, with weak references self.__type_dict = {} @@ -82,7 +82,7 @@ def vertices(self) -> List[str]: """ while True: try: - return list(self._store) + return list(self.__store) except RuntimeError: # Dictionary changed size during iteration, retry continue @@ -112,8 +112,8 @@ def _nested_get(self, obj_type: str) -> List[str]: return [key for key, item in self.__type_dict.items() if obj_type in item.type] def get_item_by_key(self, item_id: str) -> object: - if item_id in self._store: - return self._store[item_id] + if item_id in self.__store: + return self.__store[item_id] raise ValueError('Item not in map.') def is_known(self, vertex: object) -> bool: @@ -121,7 +121,7 @@ def is_known(self, vertex: object) -> bool: All objects should have a 'unique_name' attribute. """ - return vertex.unique_name in self._store + return vertex.unique_name in self.__store def find_type(self, vertex: object) -> List[str]: if self.is_known(vertex): @@ -137,15 +137,15 @@ def change_type(self, obj, new_type: str): def add_vertex(self, obj: object, obj_type: str = None): name = obj.unique_name - if name in self._store: + if name in self.__store: raise ValueError(f'Object name {name} already exists in the graph.') # Clean up stale entry in __type_dict if the weak reference was collected # but the finalizer hasn't run yet if name in self.__type_dict: del self.__type_dict[name] - self._store[name] = obj + self.__store[name] = obj self.__type_dict[name] = _EntryList() # Add objects type to the list of types - self.__type_dict[name].finalizer = weakref.finalize(self._store[name], self.prune, name) + self.__type_dict[name].finalizer = weakref.finalize(self.__store[name], self.prune, name) self.__type_dict[name].type = obj_type def add_edge(self, start_obj: object, end_obj: object): @@ -185,8 +185,8 @@ def prune_vertex_from_edge(self, parent_obj, child_obj): def prune(self, key: str): if key in self.__type_dict: del self.__type_dict[key] - if key in self._store: - del self._store[key] + if key in self.__store: + del self.__store[key] def find_isolated_vertices(self) -> list: """returns a list of isolated vertices.""" @@ -279,9 +279,9 @@ def is_connected(self, vertices_encountered=None, start_vertex=None) -> bool: def _clear(self): """Reset the map to an empty state. Only to be used for testing""" - self._store.clear() + self.__store.clear() self.__type_dict.clear() gc.collect() def __repr__(self) -> str: - return f'Map object of {len(self._store)} vertices.' + return f'Map object of {len(self.__store)} vertices.' diff --git a/src/easyscience/variable/parameter.py b/src/easyscience/variable/parameter.py index 55787ad4..0df57d1f 100644 --- a/src/easyscience/variable/parameter.py +++ b/src/easyscience/variable/parameter.py @@ -1029,7 +1029,8 @@ def resolve_pending_dependencies(self) -> None: def _find_parameter_by_serializer_id(self, serializer_id: str) -> Optional['DescriptorNumber']: """Find a parameter by its serializer_id from all parameters in the global map.""" - for obj in self._global_object.map._store.values(): + for key in self._global_object.map.vertices(): + obj = self._global_object.map.get_item_by_key(key) if isinstance(obj, DescriptorNumber) and hasattr(obj, '_DescriptorNumber__serializer_id'): if obj._DescriptorNumber__serializer_id == serializer_id: return obj diff --git a/tests/unit_tests/base_classes/test_model_collection.py b/tests/unit_tests/base_classes/test_model_collection.py new file mode 100644 index 00000000..b62ddb63 --- /dev/null +++ b/tests/unit_tests/base_classes/test_model_collection.py @@ -0,0 +1,671 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project str: + return self._name + + @property + def value(self) -> Parameter: + return self._value + + @value.setter + def value(self, new_value: float) -> None: + self._value.value = new_value + + +class DerivedModelCollection(ModelCollection): + """A derived class for testing inheritance.""" + pass + + +class_constructors = [ModelCollection, DerivedModelCollection] + + +@pytest.fixture +def clear_global(): + """Clear the global object map before each test.""" + global_object.map._clear() + yield + global_object.map._clear() + + +@pytest.fixture +def sample_items(): + """Create sample items for testing.""" + return [ + MockModelItem(name='item1', value=1.0), + MockModelItem(name='item2', value=2.0), + MockModelItem(name='item3', value=3.0), + ] + + +# ============================================================================= +# Constructor Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_empty(cls, clear_global): + """Test creating an empty collection.""" + coll = cls() + assert len(coll) == 0 + assert coll.interface is None + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_with_items(cls, clear_global, sample_items): + """Test creating a collection with initial items.""" + coll = cls(*sample_items) + assert len(coll) == 3 + for i, item in enumerate(coll): + assert item.name == sample_items[i].name + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_with_unique_name(cls, clear_global): + """Test creating a collection with a custom unique_name.""" + coll = cls(unique_name='custom_unique') + assert coll.unique_name == 'custom_unique' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_with_display_name(cls, clear_global): + """Test creating a collection with a custom display_name.""" + coll = cls(display_name='My Display Name') + assert coll.display_name == 'My Display Name' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_with_list_arg(cls, clear_global, sample_items): + """Test creating a collection with a list of items (should flatten).""" + coll = cls(sample_items) + assert len(coll) == 3 + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_init_type_error(cls, clear_global): + """Test that adding non-NewBase items raises TypeError.""" + with pytest.raises(TypeError): + cls('not_a_newbase_object') + + +# ============================================================================= +# Interface Property Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_interface_default(cls, clear_global): + """Test that interface defaults to None.""" + coll = cls() + assert coll.interface is None + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_interface_propagation(cls, clear_global, sample_items): + """Test that setting interface propagates to items.""" + # Add interface attribute to items for this test + for item in sample_items: + item.interface = None + + coll = cls(*sample_items) + + class MockInterface(CalculatorFactoryBase): + """Mock interface for testing.""" + @property + def available_calculators(self): + return [] + + def create(self, calculator_name, *args, **kwargs): + pass + + mock_interface = MockInterface() + coll.interface = mock_interface + + assert coll.interface is mock_interface + for item in coll: + assert item.interface is mock_interface + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_interface_type_error(cls, clear_global): + """Test that setting an invalid interface type raises TypeError.""" + coll = cls() + + with pytest.raises(TypeError, match='interface must be'): + coll.interface = 'not_an_interface' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_interface_type_error_with_object(cls, clear_global): + """Test that setting a plain object as interface raises TypeError.""" + coll = cls() + + class NotAnInterface: + pass + + with pytest.raises(TypeError, match='interface must be'): + coll.interface = NotAnInterface() + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_interface_accepts_none(cls, clear_global, sample_items): + """Test that setting interface to None is allowed.""" + coll = cls(*sample_items) + coll.interface = None + assert coll.interface is None + + +# ============================================================================= +# __getitem__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_int(cls, clear_global, sample_items): + """Test getting items by integer index.""" + coll = cls(*sample_items) + assert coll[0].name == 'item1' + assert coll[1].name == 'item2' + assert coll[2].name == 'item3' + assert coll[-1].name == 'item3' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_int_out_of_range(cls, clear_global, sample_items): + """Test that out of range index raises IndexError.""" + coll = cls(*sample_items) + with pytest.raises(IndexError): + _ = coll[100] + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_slice(cls, clear_global, sample_items): + """Test getting items by slice.""" + coll = cls(*sample_items) + sliced = coll[0:2] + assert isinstance(sliced, cls) + assert len(sliced) == 2 + assert sliced[0].name == 'item1' + assert sliced[1].name == 'item2' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_str_by_name(cls, clear_global, sample_items): + """Test getting items by name string.""" + coll = cls(*sample_items) + item = coll['item2'] + assert item.name == 'item2' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_str_by_unique_name(cls, clear_global, sample_items): + """Test getting items by unique_name string.""" + coll = cls(*sample_items) + unique_name = sample_items[1].unique_name + item = coll[unique_name] + assert item.unique_name == unique_name + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_getitem_str_not_found(cls, clear_global, sample_items): + """Test that getting non-existent name raises KeyError.""" + coll = cls(*sample_items) + with pytest.raises(KeyError): + _ = coll['nonexistent'] + + +# ============================================================================= +# __setitem__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_setitem_int(cls, clear_global, sample_items): + """Test setting items by integer index.""" + coll = cls(*sample_items) + new_item = MockModelItem(name='new_item', value=99.0) + old_item = coll[1] + + coll[1] = new_item + + assert len(coll) == 3 + assert coll[1].name == 'new_item' + assert coll[1].value.value == 99.0 + + # Check graph edges + edges = global_object.map.get_edges(coll) + assert new_item.unique_name in edges + assert old_item.unique_name not in edges + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_setitem_type_error(cls, clear_global, sample_items): + """Test that setting non-NewBase item raises TypeError.""" + coll = cls(*sample_items) + with pytest.raises(TypeError): + coll[0] = 'not_a_newbase_object' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_setitem_slice(cls, clear_global, sample_items): + """Test setting items by slice.""" + coll = cls(*sample_items) + new_items = [ + MockModelItem(name='new1', value=10.0), + MockModelItem(name='new2', value=20.0), + ] + + coll[0:2] = new_items + + assert len(coll) == 3 + assert coll[0].name == 'new1' + assert coll[1].name == 'new2' + assert coll[2].name == 'item3' + + +# ============================================================================= +# __delitem__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_delitem_int(cls, clear_global, sample_items): + """Test deleting items by integer index.""" + coll = cls(*sample_items) + deleted_item = coll[1] + + del coll[1] + + assert len(coll) == 2 + assert coll[0].name == 'item1' + assert coll[1].name == 'item3' + + # Check graph edges + edges = global_object.map.get_edges(coll) + assert deleted_item.unique_name not in edges + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_delitem_slice(cls, clear_global, sample_items): + """Test deleting items by slice.""" + coll = cls(*sample_items) + + del coll[0:2] + + assert len(coll) == 1 + assert coll[0].name == 'item3' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_delitem_str_by_name(cls, clear_global, sample_items): + """Test deleting items by name string.""" + coll = cls(*sample_items) + + del coll['item2'] + + assert len(coll) == 2 + assert 'item2' not in [item.name for item in coll] + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_delitem_str_not_found(cls, clear_global, sample_items): + """Test that deleting non-existent name raises KeyError.""" + coll = cls(*sample_items) + with pytest.raises(KeyError): + del coll['nonexistent'] + + +# ============================================================================= +# __len__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +@pytest.mark.parametrize('count', [0, 1, 3, 5]) +def test_ModelCollection_len(cls, clear_global, count): + """Test __len__ returns correct count.""" + items = [MockModelItem(name=f'item{i}', value=float(i)) for i in range(count)] + coll = cls(*items) + assert len(coll) == count + + +# ============================================================================= +# insert Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_insert(cls, clear_global, sample_items): + """Test inserting items at an index.""" + coll = cls(*sample_items) + new_item = MockModelItem(name='inserted', value=99.0) + + coll.insert(1, new_item) + + assert len(coll) == 4 + assert coll[0].name == 'item1' + assert coll[1].name == 'inserted' + assert coll[2].name == 'item2' + assert coll[3].name == 'item3' + + # Check graph edges + edges = global_object.map.get_edges(coll) + assert new_item.unique_name in edges + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_insert_type_error(cls, clear_global, sample_items): + """Test that inserting non-NewBase item raises TypeError.""" + coll = cls(*sample_items) + with pytest.raises(TypeError): + coll.insert(0, 'not_a_newbase_object') + + +# ============================================================================= +# append Tests (inherited from MutableSequence) +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_append(cls, clear_global, sample_items): + """Test appending items.""" + coll = cls(*sample_items) + new_item = MockModelItem(name='appended', value=99.0) + + coll.append(new_item) + + assert len(coll) == 4 + assert coll[-1].name == 'appended' + + # Check graph edges + edges = global_object.map.get_edges(coll) + assert new_item.unique_name in edges + + +# ============================================================================= +# data Property Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_data_property(cls, clear_global, sample_items): + """Test that data property returns tuple of items.""" + coll = cls(*sample_items) + data = coll.data + assert isinstance(data, tuple) + assert len(data) == 3 + for i, item in enumerate(data): + assert item.name == sample_items[i].name + + +# ============================================================================= +# sort Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_sort(cls, clear_global): + """Test sorting the collection.""" + items = [ + MockModelItem(name='c', value=3.0), + MockModelItem(name='a', value=1.0), + MockModelItem(name='b', value=2.0), + ] + coll = cls(*items) + + coll.sort(lambda x: x.value.value) + + assert coll[0].name == 'a' + assert coll[1].name == 'b' + assert coll[2].name == 'c' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_sort_reverse(cls, clear_global): + """Test sorting the collection in reverse.""" + items = [ + MockModelItem(name='a', value=1.0), + MockModelItem(name='c', value=3.0), + MockModelItem(name='b', value=2.0), + ] + coll = cls(*items) + + coll.sort(lambda x: x.value.value, reverse=True) + + assert coll[0].name == 'c' + assert coll[1].name == 'b' + assert coll[2].name == 'a' + + +# ============================================================================= +# __repr__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_repr(cls, clear_global, sample_items): + """Test string representation.""" + coll = cls(*sample_items) + repr_str = repr(coll) + assert cls.__name__ in repr_str + assert '3' in repr_str + + +# ============================================================================= +# __iter__ Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_iter(cls, clear_global, sample_items): + """Test iteration over collection.""" + coll = cls(*sample_items) + + names = [item.name for item in coll] + assert names == ['item1', 'item2', 'item3'] + + +# ============================================================================= +# get_all_variables Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_get_all_variables(cls, clear_global, sample_items): + """Test getting all variables from items.""" + coll = cls(*sample_items) + variables = coll.get_all_variables() + + # Each MockModelItem has one Parameter (value) + assert len(variables) == 3 + for var in variables: + assert isinstance(var, Parameter) + + +# ============================================================================= +# get_all_parameters Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_get_all_parameters(cls, clear_global, sample_items): + """Test getting all parameters from items.""" + coll = cls(*sample_items) + parameters = coll.get_all_parameters() + + assert len(parameters) == 3 + for param in parameters: + assert isinstance(param, Parameter) + + +# ============================================================================= +# get_fit_parameters Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_get_fit_parameters(cls, clear_global, sample_items): + """Test getting fit parameters from items.""" + # Fix one parameter so we can test filtering + sample_items[0].value.fixed = True + + coll = cls(*sample_items) + fit_params = coll.get_fit_parameters() + + # All 3 parameters should be returned (get_fit_parameters on items) + # since MockModelItem.get_fit_parameters returns free params + assert len(fit_params) == 2 + + +# ============================================================================= +# Graph Edge Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_graph_edges(cls, clear_global, sample_items): + """Test that graph edges are correctly maintained.""" + coll = cls(*sample_items) + + edges = global_object.map.get_edges(coll) + assert len(edges) == 3 + + for item in sample_items: + assert item.unique_name in edges + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_graph_edges_after_append(cls, clear_global, sample_items): + """Test graph edges are updated after append.""" + coll = cls(*sample_items) + new_item = MockModelItem(name='new', value=99.0) + + coll.append(new_item) + + edges = global_object.map.get_edges(coll) + assert len(edges) == 4 + assert new_item.unique_name in edges + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_graph_edges_after_delete(cls, clear_global, sample_items): + """Test graph edges are updated after delete.""" + coll = cls(*sample_items) + deleted_item = sample_items[1] + + del coll[1] + + edges = global_object.map.get_edges(coll) + assert len(edges) == 2 + assert deleted_item.unique_name not in edges + + +# ============================================================================= +# MutableSequence Interface Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_extend(cls, clear_global, sample_items): + """Test extend method (inherited from MutableSequence).""" + coll = cls(sample_items[0]) + coll.extend([sample_items[1], sample_items[2]]) + + assert len(coll) == 3 + assert coll[1].name == 'item2' + assert coll[2].name == 'item3' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_pop(cls, clear_global, sample_items): + """Test pop method (inherited from MutableSequence).""" + coll = cls(*sample_items) + + popped = coll.pop() + assert popped.name == 'item3' + assert len(coll) == 2 + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_pop_index(cls, clear_global, sample_items): + """Test pop method with index (inherited from MutableSequence).""" + coll = cls(*sample_items) + + popped = coll.pop(0) + assert popped.name == 'item1' + assert len(coll) == 2 + assert coll[0].name == 'item2' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_remove(cls, clear_global, sample_items): + """Test remove method (inherited from MutableSequence).""" + coll = cls(*sample_items) + item_to_remove = sample_items[1] + + coll.remove(item_to_remove) + + assert len(coll) == 2 + assert item_to_remove not in coll + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_clear(cls, clear_global, sample_items): + """Test clear method (inherited from MutableSequence).""" + coll = cls(*sample_items) + + coll.clear() + + assert len(coll) == 0 + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_reverse(cls, clear_global, sample_items): + """Test reverse method (inherited from MutableSequence).""" + coll = cls(*sample_items) + + coll.reverse() + + assert coll[0].name == 'item3' + assert coll[1].name == 'item2' + assert coll[2].name == 'item1' + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_count(cls, clear_global, sample_items): + """Test count method (inherited from MutableSequence).""" + coll = cls(*sample_items) + + count = coll.count(sample_items[0]) + assert count == 1 + + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_index(cls, clear_global, sample_items): + """Test index method (inherited from MutableSequence).""" + coll = cls(*sample_items) + + idx = coll.index(sample_items[1]) + assert idx == 1 + + +# ============================================================================= +# Contains Tests +# ============================================================================= + +@pytest.mark.parametrize('cls', class_constructors) +def test_ModelCollection_contains(cls, clear_global, sample_items): + """Test __contains__ (in operator).""" + coll = cls(*sample_items) + + assert sample_items[0] in coll + assert sample_items[1] in coll + + new_item = MockModelItem(name='not_in_collection', value=999.0) + assert new_item not in coll diff --git a/tests/unit_tests/fitting/calculators/test_calculator_base.py b/tests/unit_tests/fitting/calculators/test_calculator_base.py new file mode 100644 index 00000000..38b3c601 --- /dev/null +++ b/tests/unit_tests/fitting/calculators/test_calculator_base.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project np.ndarray: + # Simple identity function for testing + return x * 2.0 + + return ConcreteCalculator + + @pytest.fixture + def calculator(self, clear, concrete_calculator_class, mock_model, mock_instrumental_parameters): + """Create a calculator instance for testing.""" + return concrete_calculator_class( + mock_model, mock_instrumental_parameters, unique_name="test_calc", display_name="TestCalc" + ) + + # Initialization tests + def test_init_with_model_only(self, clear, concrete_calculator_class, mock_model): + """Test initialization with only a model.""" + calc = concrete_calculator_class(mock_model, unique_name="test_1", display_name="Test1") + assert calc.model is mock_model + assert calc.instrumental_parameters is None + + def test_init_with_model_and_instrumental_parameters( + self, clear, concrete_calculator_class, mock_model, mock_instrumental_parameters + ): + """Test initialization with model and instrumental parameters.""" + calc = concrete_calculator_class( + mock_model, mock_instrumental_parameters, unique_name="test_2", display_name="Test2" + ) + assert calc.model is mock_model + assert calc.instrumental_parameters is mock_instrumental_parameters + + def test_init_with_kwargs(self, clear, concrete_calculator_class, mock_model): + """Test initialization with additional kwargs.""" + calc = concrete_calculator_class( + mock_model, unique_name="test_3", display_name="Test3", custom_option="value" + ) + assert calc.additional_kwargs == {"custom_option": "value"} + + def test_init_with_none_model_raises_error(self, clear, concrete_calculator_class): + """Test that initialization with None model raises ValueError.""" + with pytest.raises(ValueError, match="Model cannot be None"): + concrete_calculator_class(None, unique_name="test_4", display_name="Test4") + + # Model property tests + def test_model_getter(self, calculator, mock_model): + """Test model getter property.""" + assert calculator.model is mock_model + + def test_model_setter(self, calculator): + """Test model setter property.""" + new_model = MagicMock() + new_model.name = "NewModel" + calculator.model = new_model + assert calculator.model is new_model + + def test_model_setter_with_none_raises_error(self, calculator): + """Test that setting model to None raises ValueError.""" + with pytest.raises(ValueError, match="Model cannot be None"): + calculator.model = None + + # Instrumental parameters property tests + def test_instrumental_parameters_getter(self, calculator, mock_instrumental_parameters): + """Test instrumental_parameters getter property.""" + assert calculator.instrumental_parameters is mock_instrumental_parameters + + def test_instrumental_parameters_setter(self, calculator): + """Test instrumental_parameters setter property.""" + new_params = MagicMock() + new_params.name = "NewInstrument" + calculator.instrumental_parameters = new_params + assert calculator.instrumental_parameters is new_params + + def test_instrumental_parameters_setter_with_none(self, calculator): + """Test that instrumental_parameters can be set to None.""" + calculator.instrumental_parameters = None + assert calculator.instrumental_parameters is None + + # Update methods tests + def test_update_model(self, calculator): + """Test update_model method.""" + new_model = MagicMock() + new_model.name = "UpdatedModel" + calculator.update_model(new_model) + assert calculator.model is new_model + + def test_update_model_with_none_raises_error(self, calculator): + """Test that update_model with None raises ValueError.""" + with pytest.raises(ValueError, match="Model cannot be None"): + calculator.update_model(None) + + def test_update_instrumental_parameters(self, calculator): + """Test update_instrumental_parameters method.""" + new_params = MagicMock() + new_params.name = "UpdatedInstrument" + calculator.update_instrumental_parameters(new_params) + assert calculator.instrumental_parameters is new_params + + def test_update_instrumental_parameters_with_none(self, calculator): + """Test that update_instrumental_parameters accepts None.""" + calculator.update_instrumental_parameters(None) + assert calculator.instrumental_parameters is None + + # Calculate method tests + def test_calculate_returns_array(self, calculator): + """Test that calculate returns an array.""" + x = np.array([1.0, 2.0, 3.0]) + result = calculator.calculate(x) + assert isinstance(result, np.ndarray) + np.testing.assert_array_equal(result, np.array([2.0, 4.0, 6.0])) + + def test_calculate_with_empty_array(self, calculator): + """Test calculate with empty array.""" + x = np.array([]) + result = calculator.calculate(x) + assert len(result) == 0 + + # Abstract method enforcement tests + def test_cannot_instantiate_abstract_class(self, mock_model): + """Test that CalculatorBase cannot be instantiated directly.""" + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + CalculatorBase(mock_model) + + def test_subclass_must_implement_calculate(self, mock_model): + """Test that subclasses must implement calculate method.""" + + class IncompleteCalculator(CalculatorBase): + pass # Does not implement calculate + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompleteCalculator(mock_model) + + # Representation tests + def test_repr_with_model_only(self, clear, concrete_calculator_class, mock_model): + """Test __repr__ with only model.""" + calc = concrete_calculator_class(mock_model, unique_name="test_5", display_name="Test5") + repr_str = repr(calc) + assert "ConcreteCalculator" in repr_str + assert "model=MockModel" in repr_str + assert "instrumental_parameters" not in repr_str + + def test_repr_with_model_and_instrumental_parameters( + self, clear, concrete_calculator_class, mock_model, mock_instrumental_parameters + ): + """Test __repr__ with model and instrumental parameters.""" + calc = concrete_calculator_class(mock_model, mock_instrumental_parameters, unique_name="test_6", display_name="Test6") + repr_str = repr(calc) + assert "ConcreteCalculator" in repr_str + assert "model=MockModel" in repr_str + assert "instrumental_parameters=MockInstrument" in repr_str + + def test_repr_with_model_without_name_attribute(self, clear, concrete_calculator_class): + """Test __repr__ when model has no name attribute.""" + model = MagicMock(spec=[]) # No name attribute + calc = concrete_calculator_class(model, unique_name="test_7", display_name="Test7") + repr_str = repr(calc) + assert "ConcreteCalculator" in repr_str + assert "MagicMock" in repr_str + + # Name attribute tests + def test_calculator_name_attribute(self, calculator): + """Test that calculator has name attribute.""" + assert calculator.name == "test_calculator" + + def test_default_name_is_base(self): + """Test that default name is 'base'.""" + assert CalculatorBase.name == "base" + + # Additional kwargs property tests + def test_additional_kwargs_with_init(self, clear, concrete_calculator_class, mock_model): + """Test additional_kwargs property with kwargs in init.""" + calc = concrete_calculator_class( + mock_model, + unique_name="test_8", + display_name="Test8", + custom_option="value", + numeric_param=42 + ) + assert calc.additional_kwargs == {"custom_option": "value", "numeric_param": 42} + + def test_additional_kwargs_empty_by_default(self, clear, concrete_calculator_class, mock_model): + """Test that additional_kwargs is empty dict when no kwargs provided.""" + calc = concrete_calculator_class(mock_model, unique_name="test_9", display_name="Test9") + assert calc.additional_kwargs == {} + + +class TestCalculatorBaseWithRealModel: + """Integration-style tests using actual EasyScience objects.""" + + @pytest.fixture + def clear(self): + """Clear global map to avoid test contamination.""" + global_object.map._clear() + yield + global_object.map._clear() + + @pytest.fixture + def real_parameter(self, clear): + """Create a real Parameter object.""" + from easyscience.variable import Parameter + return Parameter("test_param", value=5.0, unit="m") + + @pytest.fixture + def concrete_calculator_class(self): + """Create a concrete implementation that uses model parameters.""" + + class ParameterAwareCalculator(CalculatorBase): + name = "param_aware" + + def calculate(self, x: np.ndarray) -> np.ndarray: + # Access parameter from model if available + if hasattr(self._model, 'get_parameters'): + params = self._model.get_parameters() + if params: + scale = params[0].value + return x * scale + return x + + return ParameterAwareCalculator + + def test_calculator_can_access_model_parameters( + self, clear, concrete_calculator_class, real_parameter + ): + """Test that calculator can access parameters from model.""" + # Create a mock model that returns our real parameter + model = MagicMock() + model.name = "TestModel" + model.get_parameters.return_value = [real_parameter] + + calc = concrete_calculator_class(model, unique_name="test_10", display_name="Test10") + x = np.array([1.0, 2.0, 3.0]) + result = calc.calculate(x) + + # Should multiply by parameter value (5.0) + np.testing.assert_array_equal(result, np.array([5.0, 10.0, 15.0])) diff --git a/tests/unit_tests/fitting/calculators/test_calculator_factory.py b/tests/unit_tests/fitting/calculators/test_calculator_factory.py new file mode 100644 index 00000000..40a9fa61 --- /dev/null +++ b/tests/unit_tests/fitting/calculators/test_calculator_factory.py @@ -0,0 +1,644 @@ +# SPDX-FileCopyrightText: 2025 EasyScience contributors +# SPDX-License-Identifier: BSD-3-Clause +# © 2021-2025 Contributors to the EasyScience project np.ndarray: + return x * 2.0 + + return TestCalculator + + @pytest.fixture + def concrete_factory_class(self, concrete_calculator_class): + """Create a concrete factory implementation.""" + + class TestFactory(CalculatorFactoryBase): + def __init__(self): + self._calc_class = concrete_calculator_class + + @property + def available_calculators(self) -> List[str]: + return ["test"] + + def create(self, calculator_name, model, instrumental_parameters=None, **kwargs): + if calculator_name != "test": + raise ValueError(f"Unknown calculator: {calculator_name}") + return self._calc_class(model, instrumental_parameters, **kwargs) + + return TestFactory + + # Abstract class enforcement tests + def test_cannot_instantiate_abstract_factory(self): + """Test that CalculatorFactoryBase cannot be instantiated directly.""" + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + CalculatorFactoryBase() + + def test_subclass_must_implement_available_calculators(self, concrete_calculator_class, mock_model): + """Test that subclasses must implement available_calculators property.""" + + class IncompleteFactory(CalculatorFactoryBase): + def create(self, calculator_name, model, instrumental_parameters=None, **kwargs): + return concrete_calculator_class(model, instrumental_parameters, **kwargs) + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompleteFactory() + + def test_subclass_must_implement_create(self): + """Test that subclasses must implement create method.""" + + class IncompleteFactory(CalculatorFactoryBase): + @property + def available_calculators(self): + return ["test"] + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + IncompleteFactory() + + # Concrete factory tests + def test_factory_available_calculators(self, concrete_factory_class): + """Test available_calculators property.""" + factory = concrete_factory_class() + assert factory.available_calculators == ["test"] + + def test_factory_create_calculator(self, concrete_factory_class, mock_model, mock_instrumental_parameters): + """Test creating a calculator via factory.""" + factory = concrete_factory_class() + calculator = factory.create("test", mock_model, mock_instrumental_parameters) + assert isinstance(calculator, CalculatorBase) + assert calculator.model is mock_model + assert calculator.instrumental_parameters is mock_instrumental_parameters + + def test_factory_create_with_model_only(self, concrete_factory_class, mock_model): + """Test creating calculator with only model.""" + factory = concrete_factory_class() + calculator = factory.create("test", mock_model) + assert calculator.model is mock_model + assert calculator.instrumental_parameters is None + + def test_factory_create_unknown_calculator_raises_error(self, concrete_factory_class, mock_model): + """Test that creating unknown calculator raises ValueError.""" + factory = concrete_factory_class() + with pytest.raises(ValueError, match="Unknown calculator"): + factory.create("unknown", mock_model) + + # Repr tests + def test_factory_repr(self, concrete_factory_class): + """Test factory __repr__.""" + factory = concrete_factory_class() + repr_str = repr(factory) + assert "TestFactory" in repr_str + assert "test" in repr_str + + +class TestSimpleCalculatorFactory: + """Tests for SimpleCalculatorFactory class.""" + + @pytest.fixture + def clear(self): + """Clear global map to avoid test contamination.""" + global_object.map._clear() + yield + global_object.map._clear() + + @pytest.fixture + def mock_model(self): + """Create a mock model object.""" + model = MagicMock() + model.name = "MockModel" + model.unique_name = "MockModel" + model.display_name = "MockModel" + return model + + @pytest.fixture + def mock_instrumental_parameters(self): + """Create mock instrumental parameters.""" + params = MagicMock() + params.name = "MockInstrument" + params.unique_name = "MockInstrument" + params.display_name = "MockInstrument" + return params + + @pytest.fixture + def calculator_class_a(self): + """Create first concrete calculator class.""" + + class CalculatorA(CalculatorBase): + name = "calc_a" + + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 2.0 + + return CalculatorA + + @pytest.fixture + def calculator_class_b(self): + """Create second concrete calculator class.""" + + class CalculatorB(CalculatorBase): + name = "calc_b" + + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 3.0 + + return CalculatorB + + # Initialization tests + def test_init_empty(self): + """Test initialization with no calculators.""" + factory = SimpleCalculatorFactory() + assert factory.available_calculators == [] + + def test_init_with_calculators_dict(self, calculator_class_a, calculator_class_b): + """Test initialization with calculators dictionary.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + assert set(factory.available_calculators) == {"a", "b"} + + def test_class_level_calculators(self, calculator_class_a): + """Test using class-level _calculators attribute.""" + + class MyFactory(SimpleCalculatorFactory): + pass + + # Set class-level calculators + MyFactory._calculators = {"my_calc": calculator_class_a} + factory = MyFactory() + assert "my_calc" in factory.available_calculators + + # Available calculators tests + def test_available_calculators_returns_list(self, calculator_class_a): + """Test that available_calculators returns a list.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + result = factory.available_calculators + assert isinstance(result, list) + assert "a" in result + + # Create tests + def test_create_calculator(self, calculator_class_a, mock_model, mock_instrumental_parameters): + """Test creating a calculator.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + calculator = factory.create("a", mock_model, mock_instrumental_parameters) + assert isinstance(calculator, CalculatorBase) + assert calculator.model is mock_model + assert calculator.instrumental_parameters is mock_instrumental_parameters + + def test_create_with_kwargs(self, calculator_class_a, mock_model): + """Test creating calculator with additional kwargs.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + calculator = factory.create("a", mock_model, custom_option="value") + assert calculator.additional_kwargs == {"custom_option": "value"} + + def test_create_unknown_calculator_raises_error(self, calculator_class_a, mock_model): + """Test that creating unknown calculator raises ValueError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + with pytest.raises(ValueError, match="Unknown calculator 'unknown'"): + factory.create("unknown", mock_model) + + def test_create_error_message_includes_available(self, calculator_class_a, calculator_class_b, mock_model): + """Test that error message includes available calculators.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + with pytest.raises(ValueError) as exc_info: + factory.create("unknown", mock_model) + assert "a" in str(exc_info.value) or "b" in str(exc_info.value) + + # Register tests + def test_register_calculator(self, calculator_class_a, calculator_class_b): + """Test registering a new calculator.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + factory.register("b", calculator_class_b) + assert "b" in factory.available_calculators + + def test_register_overwrites_existing(self, calculator_class_a, calculator_class_b): + """Test that registering with existing name overwrites.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + factory.register("a", calculator_class_b) + # Now "a" should create CalculatorB + calc = factory.create("a", MagicMock()) + assert calc.name == "calc_b" + + def test_register_invalid_class_raises_error(self, calculator_class_a): + """Test that registering non-CalculatorBase raises TypeError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + + class NotACalculator: + pass + + with pytest.raises(TypeError, match="must be a subclass of CalculatorBase"): + factory.register("bad", NotACalculator) + + def test_register_non_class_raises_error(self, calculator_class_a): + """Test that registering a non-class raises TypeError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + with pytest.raises(TypeError, match="must be a subclass of CalculatorBase"): + factory.register("bad", "not a class") + + # Unregister tests + def test_unregister_calculator(self, calculator_class_a, calculator_class_b): + """Test unregistering a calculator.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + factory.unregister("a") + assert "a" not in factory.available_calculators + assert "b" in factory.available_calculators + + def test_unregister_unknown_raises_error(self, calculator_class_a): + """Test that unregistering unknown calculator raises KeyError.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + with pytest.raises(KeyError, match="Calculator 'unknown' is not registered"): + factory.unregister("unknown") + + # Repr tests + def test_repr_empty_factory(self): + """Test __repr__ with empty factory.""" + factory = SimpleCalculatorFactory() + repr_str = repr(factory) + assert "SimpleCalculatorFactory" in repr_str + assert "available=[]" in repr_str + + def test_repr_with_calculators(self, calculator_class_a, calculator_class_b): + """Test __repr__ with calculators.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + repr_str = repr(factory) + assert "SimpleCalculatorFactory" in repr_str + assert "a" in repr_str or "b" in repr_str + + # Integration tests + def test_created_calculator_works(self, calculator_class_a, mock_model): + """Test that created calculator actually works.""" + factory = SimpleCalculatorFactory({"a": calculator_class_a}) + calculator = factory.create("a", mock_model) + x = np.array([1.0, 2.0, 3.0]) + result = calculator.calculate(x) + np.testing.assert_array_equal(result, np.array([2.0, 4.0, 6.0])) + + def test_create_multiple_calculators_independently( + self, calculator_class_a, calculator_class_b, mock_model + ): + """Test creating multiple independent calculators.""" + factory = SimpleCalculatorFactory({ + "a": calculator_class_a, + "b": calculator_class_b, + }) + + model_a = MagicMock(name="ModelA") + model_b = MagicMock(name="ModelB") + + calc_a = factory.create("a", model_a) + calc_b = factory.create("b", model_b) + + # They should be independent + assert calc_a.model is model_a + assert calc_b.model is model_b + assert calc_a is not calc_b + + # And calculate differently + x = np.array([1.0, 2.0]) + np.testing.assert_array_equal(calc_a.calculate(x), np.array([2.0, 4.0])) + np.testing.assert_array_equal(calc_b.calculate(x), np.array([3.0, 6.0])) + + +class TestFactoryStatelessness: + """Tests to verify that the factory is truly stateless.""" + + @pytest.fixture + def calculator_class(self): + """Create a calculator class with counter for instances.""" + + class CountingCalculator(CalculatorBase): + name = "counting" + instance_count = 0 + + def __init__(self, model, instrumental_parameters=None, **kwargs): + super().__init__(model, instrumental_parameters, **kwargs) + CountingCalculator.instance_count += 1 + self.instance_id = CountingCalculator.instance_count + + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + + # Reset counter before each test + CountingCalculator.instance_count = 0 + return CountingCalculator + + def test_factory_does_not_store_calculator_instances(self, calculator_class): + """Test that factory doesn't store references to created calculators.""" + factory = SimpleCalculatorFactory({"calc": calculator_class}) + mock_model = MagicMock() + + calc1 = factory.create("calc", mock_model) + calc2 = factory.create("calc", mock_model) + + # Each create should produce a new instance + assert calc1 is not calc2 + assert calc1.instance_id == 1 + assert calc2.instance_id == 2 + + def test_factory_has_no_current_calculator_attribute(self, calculator_class): + """Test that factory has no 'current' calculator state.""" + factory = SimpleCalculatorFactory({"calc": calculator_class}) + + # Should not have any attributes tracking current state + assert not hasattr(factory, "_current_calculator") + assert not hasattr(factory, "current_calculator") + assert not hasattr(factory, "_current") + + def test_multiple_factories_are_independent(self, calculator_class): + """Test that multiple factory instances are independent.""" + factory1 = SimpleCalculatorFactory({"calc": calculator_class}) + factory2 = SimpleCalculatorFactory({"calc": calculator_class}) + + mock_model = MagicMock() + + calc1 = factory1.create("calc", mock_model) + calc2 = factory2.create("calc", mock_model) + + # Each factory creates independent calculators + assert calc1 is not calc2 + + +class TestFactoryIsolation: + """Tests to ensure calculator registries don't bleed between factory instances or subclasses.""" + + @pytest.fixture + def calculator_class_x(self): + """First test calculator class.""" + class CalculatorX(CalculatorBase): + name = "x" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + return CalculatorX + + @pytest.fixture + def calculator_class_y(self): + """Second test calculator class.""" + class CalculatorY(CalculatorBase): + name = "y" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 2 + return CalculatorY + + @pytest.fixture + def calculator_class_z(self): + """Third test calculator class.""" + class CalculatorZ(CalculatorBase): + name = "z" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 3 + return CalculatorZ + + def test_instance_registration_does_not_affect_other_instances( + self, calculator_class_x, calculator_class_y, calculator_class_z + ): + """Test that registering to one instance doesn't affect others.""" + factory1 = SimpleCalculatorFactory({"x": calculator_class_x}) + factory2 = SimpleCalculatorFactory({"y": calculator_class_y}) + + # Register z to factory1 only + factory1.register("z", calculator_class_z) + + # factory1 should have both x and z + assert "x" in factory1.available_calculators + assert "z" in factory1.available_calculators + assert "y" not in factory1.available_calculators + + # factory2 should only have y + assert "y" in factory2.available_calculators + assert "x" not in factory2.available_calculators + assert "z" not in factory2.available_calculators + + def test_subclass_registration_does_not_affect_parent_or_siblings( + self, calculator_class_x, calculator_class_y + ): + """Test that subclass registries are independent.""" + + class FactoryA(SimpleCalculatorFactory): + _calculators = {"x": calculator_class_x} + + class FactoryB(SimpleCalculatorFactory): + _calculators = {"y": calculator_class_y} + + factory_a = FactoryA() + factory_b = FactoryB() + + # Each should have their own calculators + assert "x" in factory_a.available_calculators + assert "y" not in factory_a.available_calculators + + assert "y" in factory_b.available_calculators + assert "x" not in factory_b.available_calculators + + def test_class_level_registry_not_modified_by_instance_register( + self, calculator_class_x, calculator_class_y + ): + """Test that instance.register() doesn't modify class-level registry.""" + + class MyFactory(SimpleCalculatorFactory): + _calculators = {"x": calculator_class_x} + + # Create instance and register to it + factory = MyFactory() + factory.register("y", calculator_class_y) + + # Instance should have both + assert "x" in factory.available_calculators + assert "y" in factory.available_calculators + + # Create new instance - should NOT have y + factory2 = MyFactory() + assert "x" in factory2.available_calculators + assert "y" not in factory2.available_calculators + + def test_unregister_from_one_instance_does_not_affect_others( + self, calculator_class_x + ): + """Test that unregistering from one instance doesn't affect others.""" + factory1 = SimpleCalculatorFactory({"x": calculator_class_x}) + factory2 = SimpleCalculatorFactory({"x": calculator_class_x}) + + # Unregister from factory1 + factory1.unregister("x") + + # factory1 should not have x + assert "x" not in factory1.available_calculators + + # factory2 should still have x + assert "x" in factory2.available_calculators + + +class TestFactoryErrorHandling: + """Tests for improved error handling and validation.""" + + @pytest.fixture + def calculator_class(self): + """Simple test calculator.""" + class TestCalc(CalculatorBase): + name = "test" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + return TestCalc + + def test_register_with_empty_name_raises_error(self, calculator_class): + """Test that empty calculator name raises ValueError.""" + factory = SimpleCalculatorFactory() + with pytest.raises(ValueError, match="non-empty string"): + factory.register("", calculator_class) + + def test_register_with_non_string_name_raises_error(self, calculator_class): + """Test that non-string calculator name raises ValueError.""" + factory = SimpleCalculatorFactory() + with pytest.raises(ValueError, match="non-empty string"): + factory.register(123, calculator_class) + + def test_register_overwrites_with_warning(self, calculator_class): + """Test that overwriting existing calculator issues warning.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + + class NewCalc(CalculatorBase): + name = "new" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x * 2 + + with pytest.warns(UserWarning, match="Overwriting existing calculator 'test'"): + factory.register("test", NewCalc) + + def test_create_with_non_string_name_raises_error(self, calculator_class): + """Test that create with non-string name raises ValueError.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + with pytest.raises(ValueError, match="must be a string"): + factory.create(123, MagicMock()) + + def test_create_with_none_model_raises_error(self, calculator_class): + """Test that create with None model raises TypeError.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + with pytest.raises(TypeError, match="Model cannot be None"): + factory.create("test", None) + + def test_create_unknown_calculator_shows_available_in_error(self, calculator_class): + """Test that error message includes available calculators.""" + factory = SimpleCalculatorFactory({"calc1": calculator_class}) + with pytest.raises(ValueError, match="calc1") as exc_info: + factory.create("unknown", MagicMock()) + assert "Available calculators" in str(exc_info.value) + + def test_create_empty_factory_error_shows_none_available(self): + """Test error message when factory has no calculators.""" + factory = SimpleCalculatorFactory() + with pytest.raises(ValueError, match="none") as exc_info: + factory.create("anything", MagicMock()) + assert "Available calculators: none" in str(exc_info.value) + + def test_create_wraps_calculator_init_errors(self, calculator_class): + """Test that calculator initialization errors are wrapped.""" + + class BrokenCalc(CalculatorBase): + name = "broken" + def __init__(self, model, instrumental_parameters=None, **kwargs): + raise RuntimeError("Something went wrong") + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + + factory = SimpleCalculatorFactory({"broken": BrokenCalc}) + with pytest.raises(RuntimeError, match="Failed to create calculator 'broken'"): + factory.create("broken", MagicMock()) + + +class TestCalculatorKwargsProperty: + """Tests for the additional_kwargs property on CalculatorBase.""" + + @pytest.fixture + def calculator_class(self): + """Simple calculator class for testing.""" + class TestCalc(CalculatorBase): + name = "test" + def calculate(self, x: np.ndarray) -> np.ndarray: + return x + return TestCalc + + def test_additional_kwargs_accessible(self, calculator_class): + """Test that additional_kwargs property is accessible.""" + calc = calculator_class( + MagicMock(), + custom_param="value", + another_option=42 + ) + kwargs = calc.additional_kwargs + assert isinstance(kwargs, dict) + assert kwargs["custom_param"] == "value" + assert kwargs["another_option"] == 42 + + def test_additional_kwargs_empty_when_none_provided(self, calculator_class): + """Test that additional_kwargs is empty dict when no kwargs provided.""" + calc = calculator_class(MagicMock()) + assert calc.additional_kwargs == {} + + def test_additional_kwargs_via_factory(self, calculator_class): + """Test that kwargs passed through factory are accessible.""" + factory = SimpleCalculatorFactory({"test": calculator_class}) + calc = factory.create( + "test", + MagicMock(), + option1="value1", + option2=123 + ) + assert calc.additional_kwargs["option1"] == "value1" + assert calc.additional_kwargs["option2"] == 123 diff --git a/tests/unit_tests/fitting/calculators/test_interface_factory.py b/tests/unit_tests/fitting/calculators/test_interface_factory.py index ce7b8542..fc102664 100644 --- a/tests/unit_tests/fitting/calculators/test_interface_factory.py +++ b/tests/unit_tests/fitting/calculators/test_interface_factory.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause # © 2021-2025 Contributors to the EasyScience project