Skip to content

Commit ba96983

Browse files
committed
public get_y, get_y_dual, y_shape
1 parent 345d9ce commit ba96983

5 files changed

Lines changed: 49 additions & 11 deletions

File tree

src/L2ODLL.jl

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,56 @@ function make_vector_data(cache::DLLCache; M=SparseArrays.SparseMatrixCSC{Float6
132132
completion_data = convert(VectorStandardFormData{M,V,T}, model_to_data(completion_model))
133133
return completion_data, y_sets, (p_ref, y_ref, ref_map)
134134
end
135-
function get_y(cache::DLLCache)
136-
return get_y(cache.dual_model, cache.decomposition)
135+
136+
"""
137+
get_y(model::JuMP.Model)
138+
139+
Get the primal constraints corresponding to the `y` variables in the decomposition.
140+
"""
141+
function get_y(model::JuMP.Model)
142+
return get_cache(model).decomposition.y_ref
143+
end
144+
145+
"""
146+
get_y_dual(model::JuMP.Model)
147+
148+
Get the dual variables corresponding to the `y` variables in the decomposition.
149+
These are VariableRefs belonging to the dual model, not the passed-in `model`.
150+
"""
151+
function get_y_dual(model::JuMP.Model)
152+
return get_y_dual(get_cache(model))
153+
end
154+
function get_y_dual(cache::DLLCache)
155+
return get_y_dual(cache.dual_model, cache.decomposition)
137156
end
138157

158+
"""
159+
y_shape(model::JuMP.Model)
160+
161+
Get the shape of the `y` variables in the decomposition.
162+
This is a Vector{Int} where each entry is the number of dual variables for that constraint.
163+
"""
164+
function y_shape(model::JuMP.Model)
165+
return y_shape(get_cache(model))
166+
end
139167
function y_shape(cache::DLLCache)
140-
return length.(get_y(cache.dual_model, cache.decomposition))
168+
return length.(get_y_dual(cache.dual_model, cache.decomposition))
141169
end
142170

171+
"""
172+
flatten_y(y::AbstractVector)
173+
174+
Flatten a vector of `y` variables into a single vector, i.e. Vector{Vector{Float64}} -> Vector{Float64}.
175+
"""
143176
function flatten_y(y::AbstractVector)
144177
return reduce(vcat, y)
145178
end
146179

180+
"""
181+
unflatten_y(y::AbstractVector, y_shape::AbstractVector{Int})
182+
183+
Unflatten a vector of flattened `y` variables into a vector of vectors, i.e. Vector{Float64} -> Vector{Vector{Float64}}.
184+
"""
147185
function unflatten_y(y::AbstractVector, y_shape::AbstractVector{Int})
148186
return [y[start_idx:start_idx + shape - 1] for (start_idx, shape) in enumerate(y_shape)]
149187
end

src/layers/bounded.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828

2929
function bounded_builder(decomposition::BoundDecomposition, proj_fn, dual_model::JuMP.Model; completion=:exact, μ=1.0)
3030
p_vars = get_p(dual_model, decomposition)
31-
y_vars = get_y(dual_model, decomposition)
31+
y_vars = get_y_dual(dual_model, decomposition)
3232
zl_vars = only.(get_zl(dual_model, decomposition))
3333
zu_vars = only.(get_zu(dual_model, decomposition))
3434
types = filter(

src/layers/convex_qp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ end
3333

3434
function convex_qp_builder(decomposition::ConvexQP, proj_fn, dual_model::JuMP.Model)
3535
p_vars = get_p(dual_model, decomposition)
36-
y_vars = get_y(dual_model, decomposition)
36+
y_vars = get_y_dual(dual_model, decomposition)
3737
x_vars = get_x(decomposition)
3838
if !all(x -> has_quadslack(dual_model, x), x_vars)
3939
@warn "Some primal variables do not have a quadratic objective term, " *

src/layers/generic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ function _make_completion_model(decomposition::AbstractDecomposition, dual_model
7878

7979
# mark y and p as parameters (optimizing over z only)
8080
p_ref = getindex.(ref_map, get_p(dual_model, decomposition))
81-
y_ref = getindex.(ref_map, get_y(dual_model, decomposition))
81+
y_ref = getindex.(ref_map, get_y_dual(dual_model, decomposition))
8282
y_ref_flat = reduce(vcat, y_ref)
8383
JuMP.@constraint(completion_model, y_ref_flat .∈ MOI.Parameter.(zeros(length(y_ref_flat))))
8484
JuMP.@constraint(completion_model, p_ref .∈ MOI.Parameter.(zeros(length(p_ref))))
@@ -93,7 +93,7 @@ function make_vector_data(decomposition::AbstractDecomposition, dual_model::JuMP
9393
return data, y_sets, (p_ref, y_ref, ref_map)
9494
end
9595

96-
function get_y(dual_model, decomposition::AbstractDecomposition)
96+
function get_y_dual(dual_model, decomposition::AbstractDecomposition)
9797
return Dualization._get_dual_variables.(dual_model, decomposition.y_ref)
9898
end
9999

test/runtests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ SOLVER = () -> ParametricOptInterface.Optimizer(HiGHS.Optimizer());
4848

4949
L2ODLL.decompose!(m);
5050

51-
cqp_y_pred = randn_like(L2ODLL.get_y(L2ODLL.get_cache(m)));
51+
cqp_y_pred = randn_like(L2ODLL.get_y_dual(m));
5252

5353
dobj = L2ODLL.dual_objective(m, cqp_y_pred, param_value)
5454
dobj_wrt_y = L2ODLL.dual_objective_gradient(m, cqp_y_pred, param_value)
@@ -65,7 +65,7 @@ SOLVER = () -> ParametricOptInterface.Optimizer(HiGHS.Optimizer());
6565
JuMP.set_optimizer(m, Clarabel.Optimizer);
6666
JuMP.set_silent(m);
6767
JuMP.optimize!(m)
68-
cqp_y_true = JuMP.dual.(L2ODLL.get_cache(m).decomposition.y_ref)
68+
cqp_y_true = JuMP.dual.(L2ODLL.get_y(m))
6969
dobj1 = L2ODLL.dual_objective(m, cqp_y_true, param_value)
7070
@test isapprox(dobj1, JuMP.objective_value(m), atol=1e-6)
7171
end
@@ -91,7 +91,7 @@ SOLVER = () -> ParametricOptInterface.Optimizer(HiGHS.Optimizer());
9191

9292
L2ODLL.decompose!(m);
9393

94-
blp_y_pred = randn_like(L2ODLL.get_y(L2ODLL.get_cache(m)));
94+
blp_y_pred = randn_like(L2ODLL.get_y_dual(m));
9595

9696
dobj = L2ODLL.dual_objective(m, blp_y_pred, param_value)
9797
dobj_wrt_y = L2ODLL.dual_objective_gradient(m, blp_y_pred, param_value)
@@ -113,7 +113,7 @@ SOLVER = () -> ParametricOptInterface.Optimizer(HiGHS.Optimizer());
113113
JuMP.set_optimizer(m, Clarabel.Optimizer);
114114
JuMP.set_silent(m);
115115
JuMP.optimize!(m)
116-
blp_y_true = JuMP.dual.(L2ODLL.get_cache(m).decomposition.y_ref)
116+
blp_y_true = JuMP.dual.(L2ODLL.get_y(m))
117117

118118
dobj1 = L2ODLL.dual_objective(m, blp_y_true, param_value)
119119
@test isapprox(dobj1, JuMP.objective_value(m), atol=1e-6)

0 commit comments

Comments
 (0)