From 155b1e96410e8ac343e95077b2cac7736d484267 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Wed, 18 Jun 2025 15:23:02 -0400 Subject: [PATCH 1/3] initial --- src/L2ODLL.jl | 3 +++ src/projection.jl | 32 +++++++++++++++++++++++++++++++- 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/src/L2ODLL.jl b/src/L2ODLL.jl index c3d4c32..9a3174e 100644 --- a/src/L2ODLL.jl +++ b/src/L2ODLL.jl @@ -176,6 +176,9 @@ end function y_shape(cache::DLLCache) return length.(get_y_dual(cache.dual_model, cache.decomposition)) end +function y_shape(dual_model::JuMP.Model, decomposition::AbstractDecomposition) + return length.(get_y_dual(dual_model, decomposition)) +end """ flatten_y(y::AbstractVector) diff --git a/src/projection.jl b/src/projection.jl index c088bd9..2694e9f 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -31,4 +31,34 @@ function get_y_sets(dual_model, decomposition) isnothing(set) ? nothing : MOI.get(dual_model, MOI.ConstraintSet(), set) for set in get_y_constraint(dual_model, decomposition) ] -end \ No newline at end of file +end + +function make_jump_proj_fn(decomposition::AbstractDecomposition, dual_model::JuMP.Model, optimizer) + sets = get_y_sets(dual_model, decomposition) + shapes = y_shape(dual_model, decomposition) + + proj_model = JuMP.Model(optimizer) + + idxs = [(i, ji) for (i,j) in enumerate(shapes) for ji in 1:j] + JuMP.@variable(proj_model, y[idxs]) + + for (i, set) in enumerate(sets) + isnothing(set) && continue + y_vars = filter(ij->first(ij)==i, idxs) + if length(y_vars) == 1 + JuMP.@constraint(proj_model, y[only(y_vars)] ∈ set) + else + JuMP.@constraint(proj_model, y[y_vars] ∈ set) + end + end + + return (y_prediction) -> begin + JuMP.set_objective_function(proj_model, sum((y .- reduce(vcat, y_prediction)).^2)) + # TODO: ensure reduce(vcat) and idxs are same order + JuMP.set_objective_sense(proj_model, MOI.MIN_SENSE) + JuMP.optimize!(proj_model) + return value.(y) + end +end + + \ No newline at end of file From ff35a14a40f6f72e5ef28981427654e672f71c17 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Thu, 19 Jun 2025 15:33:43 -0400 Subject: [PATCH 2/3] up --- src/projection.jl | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 2694e9f..0809f6e 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -33,7 +33,7 @@ function get_y_sets(dual_model, decomposition) ] end -function make_jump_proj_fn(decomposition::AbstractDecomposition, dual_model::JuMP.Model, optimizer) +function make_jump_proj_fn(decomposition::AbstractDecomposition, dual_model::JuMP.Model, optimizer; silent=true) sets = get_y_sets(dual_model, decomposition) shapes = y_shape(dual_model, decomposition) @@ -52,12 +52,21 @@ function make_jump_proj_fn(decomposition::AbstractDecomposition, dual_model::JuM end end + silent && JuMP.set_silent(proj_model) + proj_model.ext[:🔒] = ReentrantLock() + # TODO: define frule/rrule using Moreau return (y_prediction) -> begin - JuMP.set_objective_function(proj_model, sum((y .- reduce(vcat, y_prediction)).^2)) - # TODO: ensure reduce(vcat) and idxs are same order - JuMP.set_objective_sense(proj_model, MOI.MIN_SENSE) - JuMP.optimize!(proj_model) - return value.(y) + lock(proj_model.ext[:🔒]) + try + JuMP.set_objective_function(proj_model, sum((y .- reduce(vcat, y_prediction)).^2)) + JuMP.set_objective_sense(proj_model, MOI.MIN_SENSE) + JuMP.optimize!(proj_model) + JuMP.assert_is_solved_and_feasible(proj_model) + + value.(y) + finally + unlock(proj_model.ext[:🔒]) + end end end From 2321b842ea699000fdc94c3203d0dfca12d4d945 Mon Sep 17 00:00:00 2001 From: "Klamkin, Michael" Date: Mon, 23 Jun 2025 20:33:17 -0400 Subject: [PATCH 3/3] use flatten_y --- src/projection.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/projection.jl b/src/projection.jl index 0809f6e..15ad8ab 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -58,7 +58,7 @@ function make_jump_proj_fn(decomposition::AbstractDecomposition, dual_model::JuM return (y_prediction) -> begin lock(proj_model.ext[:🔒]) try - JuMP.set_objective_function(proj_model, sum((y .- reduce(vcat, y_prediction)).^2)) + JuMP.set_objective_function(proj_model, sum((y .- flatten_y(y_prediction)).^2)) JuMP.set_objective_sense(proj_model, MOI.MIN_SENSE) JuMP.optimize!(proj_model) JuMP.assert_is_solved_and_feasible(proj_model)