diff --git a/src/pyvmcon/problem.py b/src/pyvmcon/problem.py index bb3f305..e42c771 100644 --- a/src/pyvmcon/problem.py +++ b/src/pyvmcon/problem.py @@ -2,19 +2,22 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from typing import NamedTuple, TypeVar +from dataclasses import dataclass, field +from typing import TypeVar, cast import numpy as np +from numpy.typing import NDArray -ScalarType = TypeVar("ScalarType", np.ndarray, np.number, float) +ScalarType = TypeVar("ScalarType", NDArray, np.number, float) """A scalar variable e.g. a single number (which could be a 0D numpy array)""" -VectorType = TypeVar("VectorType", bound=np.ndarray) +VectorType = NDArray """A numpy array with only 1 dimension""" -MatrixType = TypeVar("MatrixType", bound=np.ndarray) +MatrixType = NDArray """A numpy array with 2 dimensions""" -class Result(NamedTuple): +@dataclass +class Result: """The data from calling a problem.""" f: ScalarType @@ -80,6 +83,7 @@ def total_constraints(self) -> int: _VectorReturnFunctionAlias = Callable[[VectorType], VectorType] +@dataclass class Problem(AbstractProblem): """A simple implementation of an AbstractProblem. @@ -89,42 +93,36 @@ class Problem(AbstractProblem): feasible when they return a value >= 0. """ - def __init__( - self, - f: _ScalarReturnFunctionAlias, - df: _VectorReturnFunctionAlias, - equality_constraints: list[_ScalarReturnFunctionAlias] | None = None, - inequality_constraints: list[_ScalarReturnFunctionAlias] | None = None, - dequality_constraints: list[_VectorReturnFunctionAlias] | None = None, - dinequality_constraints: list[_VectorReturnFunctionAlias] | None = None, - ) -> None: - """Construct the problem.""" - super().__init__() - - self._f = f - self._df = df - self._equality_constraints = equality_constraints or [] - self._inequality_constraints = inequality_constraints or [] - self._dequality_constraints = dequality_constraints or [] - self._dinequality_constraints = dinequality_constraints or [] + f: _ScalarReturnFunctionAlias + df: _VectorReturnFunctionAlias + equality_constraints: list[_ScalarReturnFunctionAlias] = field(default_factory=list) + inequality_constraints: list[_ScalarReturnFunctionAlias] = field( + default_factory=list + ) + dequality_constraints: list[_VectorReturnFunctionAlias] = field( + default_factory=list + ) + dinequality_constraints: list[_VectorReturnFunctionAlias] = field( + default_factory=list + ) def __call__(self, x: VectorType) -> Result: """Evaluate the problem at input point x.""" return Result( - self._f(x), - self._df(x), - np.array([c(x) for c in self._equality_constraints]), - np.array([c(x) for c in self._dequality_constraints]), - np.array([c(x) for c in self._inequality_constraints]), - np.array([c(x) for c in self._dinequality_constraints]), + self.f(x), + self.df(x), + cast("VectorType", np.array([c(x) for c in self.equality_constraints])), + cast("MatrixType", np.array([c(x) for c in self.dequality_constraints])), + cast("VectorType", np.array([c(x) for c in self.inequality_constraints])), + cast("MatrixType", np.array([c(x) for c in self.dinequality_constraints])), ) @property def num_equality(self) -> int: """The number of equality constraints this problem has.""" - return len(self._equality_constraints) + return len(self.equality_constraints) @property def num_inequality(self) -> int: """The number of inequality constraints this problem has.""" - return len(self._inequality_constraints) + return len(self.inequality_constraints) diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index ef69352..3dce4dd 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -2,18 +2,19 @@ import logging from collections.abc import Callable -from typing import Any +from typing import Any, cast import cvxpy as cp import numpy as np +from pyvmcon.problem import AbstractProblem, MatrixType, Result, ScalarType, VectorType + from .exceptions import ( LineSearchConvergenceException, QSPSolverException, VMCONConvergenceException, _QspSolveException, ) -from .problem import AbstractProblem, MatrixType, Result, ScalarType, VectorType logger = logging.getLogger(__name__) @@ -439,8 +440,10 @@ def perform_linesearch( mu_inequality = _calculate_mu_j(mu_inequality, lamda_inequality) def phi(result: Result) -> ScalarType: - sum_equality = (mu_equality * np.abs(result.eq)).sum() - sum_inequality = (mu_inequality * np.abs(np.minimum(0, result.ie))).sum() + sum_equality: ScalarType = (mu_equality * np.abs(result.eq)).sum() + sum_inequality: ScalarType = ( + mu_inequality * np.abs(np.minimum(0, result.ie)) + ).sum() return result.f + sum_equality + sum_inequality @@ -476,7 +479,7 @@ def phi(result: Result) -> ScalarType: lamda_inequality=lamda_inequality, ) - return alpha, mu_equality, mu_inequality, new_result + return cast("ScalarType", alpha), mu_equality, mu_inequality, new_result def _derivative_lagrangian(