diff --git a/CHANGELOG.md b/CHANGELOG.md index e3d9cc3e9..b7c3a7986 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Removed `Py_INCREF`/`Py_DECREF` on `Model` in `catchEvent`/`dropEvent` that caused memory leak for imbalanced usage - Used getIndex() instead of ptr() for sorting nonlinear expression terms to avoid nondeterministic behavior ### Changed +- Speed up `SumExpr.__neg__`, `ProdExpr.__neg__` and `Constant.__neg__` via C-level API ### Removed - Removed outdated warning about Make build system incompatibility diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index b3a95de48..1521c874d 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -347,8 +347,8 @@ cdef class Expr: else: raise TypeError(f"Unsupported base type {type(other)} for exponentiation.") - def __neg__(self): - return Expr({v:-c for v,c in self.terms.items()}) + def __neg__(self) -> Expr: + return -1.0 * self def __sub__(self, other): return self + (-other) @@ -712,6 +712,23 @@ cdef class SumExpr(GenExpr): self.coefs = [] self.children = [] self._op = Operator.add + + def __neg__(self) -> SumExpr: + cdef int i = 0, n = len(self.coefs) + cdef list coefs = [0.0] * n + cdef double[:] dest_view = coefs + cdef double[:] src_view = self.coefs + + for i in range(n): + dest_view[i] = -src_view[i] + + cdef SumExpr res = SumExpr.__new__(SumExpr) + res.coefs = coefs + res.children = self.children.copy() + res.constant = -self.constant + res._op = Operator.add + return res + def __repr__(self): return self._op + "(" + str(self.constant) + "," + ",".join(map(lambda child : child.__repr__(), self.children)) + ")" @@ -719,7 +736,7 @@ cdef class SumExpr(GenExpr): cdef double res = self.constant cdef int i = 0, n = len(self.children) cdef list children = self.children - cdef list coefs = self.coefs + cdef double[:] coefs = self.coefs for i in range(n): res += coefs[i] * (children[i])._evaluate(sol) return res @@ -735,6 +752,13 @@ cdef class ProdExpr(GenExpr): self.children = [] self._op = Operator.prod + def __neg__(self) -> ProdExpr: + cdef ProdExpr res = ProdExpr.__new__(ProdExpr) + res.constant = -self.constant + res.children = self.children.copy() + res._op = Operator.prod + return res + def __repr__(self): return self._op + "(" + str(self.constant) + "," + ",".join(map(lambda child : child.__repr__(), self.children)) + ")" @@ -804,11 +828,16 @@ cdef class UnaryExpr(GenExpr): # class for constant expressions cdef class Constant(GenExpr): + cdef public number + def __init__(self,number): self.number = number self._op = Operator.const + def __neg__(self) -> Constant: + return Constant(-self.number) + def __repr__(self): return str(self.number) diff --git a/src/pyscipopt/scip.pyi b/src/pyscipopt/scip.pyi index 2728cf513..de0dbdc85 100644 --- a/src/pyscipopt/scip.pyi +++ b/src/pyscipopt/scip.pyi @@ -343,7 +343,7 @@ class Expr: def __lt__(self, other: object) -> bool: ... def __mul__(self, other: Incomplete) -> Incomplete: ... def __ne__(self, other: object) -> bool: ... - def __neg__(self) -> Incomplete: ... + def __neg__(self) -> Expr: ... def __pow__(self, other: Incomplete, modulo: Incomplete = ...) -> Incomplete: ... def __radd__(self, other: Incomplete) -> Incomplete: ... def __rmul__(self, other: Incomplete) -> Incomplete: ... @@ -386,7 +386,7 @@ class GenExpr: def __lt__(self, other: object) -> bool: ... def __mul__(self, other: Incomplete) -> Incomplete: ... def __ne__(self, other: object) -> bool: ... - def __neg__(self) -> Incomplete: ... + def __neg__(self) -> GenExpr: ... def __pow__(self, other: Incomplete, modulo: Incomplete = ...) -> Incomplete: ... def __radd__(self, other: Incomplete) -> Incomplete: ... def __rmul__(self, other: Incomplete) -> Incomplete: ... diff --git a/tests/test_expr.py b/tests/test_expr.py index a4e739b76..06bd72631 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -2,8 +2,8 @@ import pytest -from pyscipopt import Model, sqrt, log, exp, sin, cos -from pyscipopt.scip import Expr, GenExpr, ExprCons, CONST +from pyscipopt import Model, cos, exp, log, sin, sqrt +from pyscipopt.scip import CONST, Constant, Expr, ExprCons, GenExpr, ProdExpr, SumExpr @pytest.fixture(scope="module") @@ -219,6 +219,36 @@ def test_getVal_with_GenExpr(): m.getVal(1 / z) +def test_neg(): + m = Model() + x = m.addVar(name="x") + + expr = (x + 1) ** 3 + neg_expr = -expr + assert isinstance(expr, Expr) + assert isinstance(neg_expr, Expr) + assert ( + str(neg_expr) + == "Expr({Term(x, x, x): -1.0, Term(x, x): -3.0, Term(x): -3.0, Term(): -1.0})" + ) + + base = sqrt(x) + expr = base * -1 + neg_expr = -expr + assert isinstance(expr, ProdExpr) + assert isinstance(neg_expr, ProdExpr) + assert str(neg_expr) == "prod(1.0,sqrt(sum(0.0,prod(1.0,x))))" + + expr = base + x - 1 + neg_expr = -expr + assert isinstance(expr, SumExpr) + assert isinstance(neg_expr, SumExpr) + assert str(neg_expr) == "sum(1.0,sqrt(sum(0.0,prod(1.0,x))),prod(1.0,x))" + assert list(neg_expr.coefs) == [-1, -1] + + assert str(-Constant(3.0)) == "-3.0" + + def test_mul(): m = Model() x = m.addVar(name="x")