Skip to content

Commit 09a222b

Browse files
committed
Refactor expression to node conversion
Introduces _to_nodes methods for Expr, PolynomialExpr, and UnaryExpr to convert expressions into node lists for SCIP construction. Refactors Model's constraint creation to use the new node format, simplifying and clarifying the mapping from expression trees to SCIP nonlinear constraints.
1 parent 69737c0 commit 09a222b

File tree

2 files changed

+104
-98
lines changed

2 files changed

+104
-98
lines changed

src/pyscipopt/expr.pxi

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,23 @@ cdef class Expr:
210210
def _normalize(self) -> Expr:
211211
return self
212212

213+
def _to_nodes(self, start: int = 0) -> list[tuple]:
214+
"""Convert expression to list of nodes for SCIP expression construction"""
215+
nodes, indices = [], []
216+
for i in self:
217+
nodes.extend(i._to_nodes(start + len(nodes)))
218+
indices.append(start + len(nodes) - 1)
219+
220+
if type(self) is PowExpr:
221+
nodes.append((ConstExpr, self.expo))
222+
indices.append(start + len(nodes) - 1)
223+
elif type(self) is ProdExpr and self.coef != 1:
224+
nodes.append((ConstExpr, self.coef))
225+
indices.append(start + len(nodes) - 1)
226+
227+
nodes.append((type(self), indices))
228+
return nodes
229+
213230

214231
cdef class SumExpr(Expr):
215232
"""Expression like `expression1 + expression2 + constant`."""
@@ -301,6 +318,24 @@ class PolynomialExpr(SumExpr):
301318
{k: v for k, v in self.children.items() if v != 0.0}
302319
)
303320

321+
def _to_nodes(self, start: int = 0) -> list[tuple]:
322+
"""Convert expression to list of nodes for SCIP expression construction"""
323+
nodes = []
324+
for child, coef in self.children.items():
325+
if coef != 0:
326+
if child == CONST:
327+
nodes.append((ConstExpr, coef))
328+
else:
329+
ind = start + len(nodes)
330+
nodes.extend([(Term, i) for i in child.vars])
331+
if coef != 1:
332+
nodes.append((ConstExpr, coef))
333+
if len(child) > 1:
334+
nodes.append((ProdExpr, list(range(ind, len(nodes)))))
335+
if len(nodes) > 1:
336+
nodes.append((SumExpr, list(range(start, start + len(nodes)))))
337+
return nodes
338+
304339

305340
class ConstExpr(PolynomialExpr):
306341
"""Expression representing for `constant`."""
@@ -411,6 +446,15 @@ cdef class UnaryExpr(FuncExpr):
411446
def __repr__(self):
412447
return f"{type(self).__name__}({tuple(self)[0]})"
413448

449+
def _to_nodes(self, start: int = 0) -> list[tuple]:
450+
"""Convert expression to list of nodes for SCIP expression construction"""
451+
nodes = []
452+
for i in self:
453+
nodes.extend(i._to_nodes(start + len(nodes)))
454+
455+
nodes.append((type(self), start + len(nodes) - 1))
456+
return nodes
457+
414458
cdef float _evaluate(self, SCIP* scip, SCIP_SOL* sol):
415459
return self.op(_evaluate(self.children, scip, sol))
416460

src/pyscipopt/scip.pxi

Lines changed: 60 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -5557,122 +5557,84 @@ cdef class Model:
55575557
Constraint
55585558
55595559
"""
5560-
cdef SCIP_EXPR** childrenexpr
5561-
cdef SCIP_EXPR** scipexprs
5560+
cdef SCIP_EXPR** children_expr
5561+
cdef SCIP_EXPR** scip_exprs
55625562
cdef SCIP_CONS* scip_cons
55635563
cdef _VarArray wrapper
55645564
cdef int nchildren
55655565
cdef int c
55665566
cdef int i
55675567

5568-
# get arrays from python's expression tree
5569-
nodes = expr_to_nodes(cons.expr)
5570-
5571-
# in nodes we have a list of tuples: each tuple is of the form
5572-
# (operator, [indices]) where indices are the indices of the tuples
5573-
# that are the children of this operator. This is sorted,
5574-
# so we are going to do is:
5575-
# loop over the nodes and create the expression of each
5576-
# Note1: when the operator is Operator.const, [indices] stores the value
5577-
# Note2: we need to compute the number of variable operators to find out
5578-
# how many variables are there.
5579-
nvars = 0
5580-
for node in nodes:
5581-
if node[0] == Operator.varidx:
5582-
nvars += 1
5583-
5584-
scipexprs = <SCIP_EXPR**> malloc(len(nodes) * sizeof(SCIP_EXPR*))
5585-
for i,node in enumerate(nodes):
5586-
opidx = node[0]
5587-
if opidx == Operator.varidx:
5588-
assert len(node[1]) == 1
5589-
pyvar = node[1][0] # for vars we store the actual var!
5590-
wrapper = _VarArray(pyvar)
5591-
PY_SCIP_CALL( SCIPcreateExprVar(self._scip, &scipexprs[i], wrapper.ptr[0], NULL, NULL) )
5592-
continue
5593-
if opidx == Operator.const:
5594-
assert len(node[1]) == 1
5595-
value = node[1][0]
5596-
PY_SCIP_CALL( SCIPcreateExprValue(self._scip, &scipexprs[i], <SCIP_Real>value, NULL, NULL) )
5597-
continue
5598-
if opidx == Operator.add:
5599-
nchildren = len(node[1])
5600-
childrenexpr = <SCIP_EXPR**> malloc(nchildren * sizeof(SCIP_EXPR*))
5568+
nodes = cons.expr._to_nodes()
5569+
scip_exprs = <SCIP_EXPR**> malloc(len(nodes) * sizeof(SCIP_EXPR*))
5570+
for i, (e_type, value) in enumerate(nodes):
5571+
if e_type is Term:
5572+
wrapper = _VarArray(value)
5573+
PY_SCIP_CALL(SCIPcreateExprVar(self._scip, &scip_exprs[i], wrapper.ptr[0], NULL, NULL))
5574+
elif e_type is ConstExpr:
5575+
PY_SCIP_CALL(SCIPcreateExprValue(self._scip, &scip_exprs[i], <SCIP_Real>value, NULL, NULL))
5576+
5577+
elif e_type is SumExpr:
5578+
nchildren = len(value)
5579+
children_expr = <SCIP_EXPR**> malloc(nchildren * sizeof(SCIP_EXPR*))
56015580
coefs = <SCIP_Real*> malloc(nchildren * sizeof(SCIP_Real))
5602-
for c, pos in enumerate(node[1]):
5603-
childrenexpr[c] = scipexprs[pos]
5581+
for c, pos in enumerate(value):
5582+
children_expr[c] = scip_exprs[pos]
56045583
coefs[c] = 1
5605-
PY_SCIP_CALL( SCIPcreateExprSum(self._scip, &scipexprs[i], nchildren, childrenexpr, coefs, 0, NULL, NULL))
5584+
5585+
PY_SCIP_CALL(SCIPcreateExprSum(self._scip, &scip_exprs[i], nchildren, children_expr, coefs, 0, NULL, NULL))
56065586
free(coefs)
5607-
free(childrenexpr)
5608-
continue
5609-
if opidx == Operator.prod:
5610-
nchildren = len(node[1])
5611-
childrenexpr = <SCIP_EXPR**> malloc(nchildren * sizeof(SCIP_EXPR*))
5612-
for c, pos in enumerate(node[1]):
5613-
childrenexpr[c] = scipexprs[pos]
5614-
PY_SCIP_CALL( SCIPcreateExprProduct(self._scip, &scipexprs[i], nchildren, childrenexpr, 1, NULL, NULL) )
5615-
free(childrenexpr)
5616-
continue
5617-
if opidx == Operator.power:
5618-
# the second child is the exponent which is a const
5619-
valuenode = nodes[node[1][1]]
5620-
assert valuenode[0] == Operator.const
5621-
exponent = valuenode[1][0]
5622-
PY_SCIP_CALL( SCIPcreateExprPow(self._scip, &scipexprs[i], scipexprs[node[1][0]], <SCIP_Real>exponent, NULL, NULL ))
5623-
continue
5624-
if opidx == Operator.exp:
5625-
assert len(node[1]) == 1
5626-
PY_SCIP_CALL( SCIPcreateExprExp(self._scip, &scipexprs[i], scipexprs[node[1][0]], NULL, NULL ))
5627-
continue
5628-
if opidx == Operator.log:
5629-
assert len(node[1]) == 1
5630-
PY_SCIP_CALL( SCIPcreateExprLog(self._scip, &scipexprs[i], scipexprs[node[1][0]], NULL, NULL ))
5631-
continue
5632-
if opidx == Operator.sqrt:
5633-
assert len(node[1]) == 1
5634-
PY_SCIP_CALL( SCIPcreateExprPow(self._scip, &scipexprs[i], scipexprs[node[1][0]], <SCIP_Real>0.5, NULL, NULL) )
5635-
continue
5636-
if opidx == Operator.sin:
5637-
assert len(node[1]) == 1
5638-
PY_SCIP_CALL( SCIPcreateExprSin(self._scip, &scipexprs[i], scipexprs[node[1][0]], NULL, NULL) )
5639-
continue
5640-
if opidx == Operator.cos:
5641-
assert len(node[1]) == 1
5642-
PY_SCIP_CALL( SCIPcreateExprCos(self._scip, &scipexprs[i], scipexprs[node[1][0]], NULL, NULL) )
5643-
continue
5644-
if opidx == Operator.fabs:
5645-
assert len(node[1]) == 1
5646-
PY_SCIP_CALL( SCIPcreateExprAbs(self._scip, &scipexprs[i], scipexprs[node[1][0]], NULL, NULL ))
5647-
continue
5648-
# default:
5649-
raise NotImplementedError
5587+
free(children_expr)
5588+
5589+
elif e_type is ProdExpr:
5590+
nchildren = len(value)
5591+
children_expr = <SCIP_EXPR**> malloc(nchildren * sizeof(SCIP_EXPR*))
5592+
for c, pos in enumerate(value):
5593+
children_expr[c] = scip_exprs[pos]
5594+
5595+
PY_SCIP_CALL(SCIPcreateExprProduct(self._scip, &scip_exprs[i], nchildren, children_expr, 1, NULL, NULL))
5596+
free(children_expr)
5597+
5598+
elif e_type is PowExpr:
5599+
PY_SCIP_CALL(SCIPcreateExprPow(self._scip, &scip_exprs[i], scip_exprs[value[0]], <SCIP_Real>nodes[value[1]][1], NULL, NULL))
5600+
elif e_type is ExpExpr:
5601+
PY_SCIP_CALL(SCIPcreateExprExp(self._scip, &scip_exprs[i], scip_exprs[value], NULL, NULL))
5602+
elif e_type is LogExpr:
5603+
PY_SCIP_CALL(SCIPcreateExprLog(self._scip, &scip_exprs[i], scip_exprs[value], NULL, NULL))
5604+
elif e_type is SqrtExpr:
5605+
PY_SCIP_CALL(SCIPcreateExprPow(self._scip, &scip_exprs[i], scip_exprs[value], <SCIP_Real>0.5, NULL, NULL))
5606+
elif e_type is SinExpr:
5607+
PY_SCIP_CALL(SCIPcreateExprSin(self._scip, &scip_exprs[i], scip_exprs[value], NULL, NULL))
5608+
elif e_type is CosExpr:
5609+
PY_SCIP_CALL(SCIPcreateExprCos(self._scip, &scip_exprs[i], scip_exprs[value], NULL, NULL))
5610+
elif e_type is AbsExpr:
5611+
PY_SCIP_CALL(SCIPcreateExprAbs(self._scip, &scip_exprs[i], scip_exprs[value], NULL, NULL))
5612+
else:
5613+
raise NotImplementedError(f"{e_type} not implemented yet")
56505614

56515615
# create nonlinear constraint for the expression root
56525616
PY_SCIP_CALL(SCIPcreateConsNonlinear(
56535617
self._scip,
56545618
&scip_cons,
5655-
str_conversion(kwargs['name']),
5656-
scipexprs[len(nodes) - 1],
5657-
kwargs['lhs'],
5658-
kwargs['rhs'],
5659-
kwargs['initial'],
5660-
kwargs['separate'],
5661-
kwargs['enforce'],
5662-
kwargs['check'],
5663-
kwargs['propagate'],
5664-
kwargs['local'],
5665-
kwargs['modifiable'],
5666-
kwargs['dynamic'],
5667-
kwargs['removable']),
5619+
str_conversion(kwargs["name"]),
5620+
scip_exprs[len(nodes) - 1],
5621+
kwargs["lhs"],
5622+
kwargs["rhs"],
5623+
kwargs["initial"],
5624+
kwargs["separate"],
5625+
kwargs["enforce"],
5626+
kwargs["check"],
5627+
kwargs["propagate"],
5628+
kwargs["local"],
5629+
kwargs["modifiable"],
5630+
kwargs["dynamic"],
5631+
kwargs["removable"]),
56685632
)
5669-
56705633
PyCons = Constraint.create(scip_cons)
56715634
for i in range(len(nodes)):
5672-
PY_SCIP_CALL( SCIPreleaseExpr(self._scip, &scipexprs[i]) )
5635+
PY_SCIP_CALL(SCIPreleaseExpr(self._scip, &scip_exprs[i]))
56735636

5674-
# free more memory
5675-
free(scipexprs)
5637+
free(scip_exprs)
56765638
return PyCons
56775639

56785640
def createConsFromExpr(self, cons, name='', initial=True, separate=True,

0 commit comments

Comments
 (0)