Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ext/MathOptInterfaceExt/MathOptInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,28 @@ mutable struct ForwardModeData{T}
param_perturbations::Dict{MOI.ConstraintIndex, T}
primal_sensitivities::Dict{MOI.VariableIndex, T}
dual_sensitivities::Dict{MOI.ConstraintIndex, T}
vector_dual_sensitivities::Dict{MOI.ConstraintIndex, Vector{T}}
objective_sensitivity::T
end
ForwardModeData{T}() where {T} = ForwardModeData{T}(
Dict{MOI.ConstraintIndex, T}(),
Dict{MOI.VariableIndex, T}(),
Dict{MOI.ConstraintIndex, T}(),
Dict{MOI.ConstraintIndex, Vector{T}}(),
zero(T),
)

mutable struct ReverseModeData{T}
primal_seeds::Dict{MOI.VariableIndex, T}
dual_seeds::Dict{MOI.ConstraintIndex, T}
vector_dual_seeds::Dict{MOI.ConstraintIndex, Vector{T}}
param_outputs::Dict{MOI.ConstraintIndex, T}
dobj::Union{Nothing, T}
end
ReverseModeData{T}() where {T} = ReverseModeData{T}(
Dict{MOI.VariableIndex, T}(),
Dict{MOI.ConstraintIndex, T}(),
Dict{MOI.ConstraintIndex, Vector{T}}(),
Dict{MOI.ConstraintIndex, T}(),
nothing,
)
Expand Down
32 changes: 30 additions & 2 deletions ext/MathOptInterfaceExt/forward_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function _forward_differentiate_impl!(model::Optimizer{OT, T}) where {OT, T}
dy = _get_dy_cache!(model, n_con)
dy .= (.-obj_sign) .* dy_cpu

_store_dual_sensitivities!(model.forward.dual_sensitivities, inner, dy)
_store_dual_sensitivities!(model.forward.dual_sensitivities, model.forward.vector_dual_sensitivities, inner, dy)
_store_bound_dual_sensitivities!(model, sens, result, inner)
model.forward.objective_sensitivity = result.dobj[]
return
Expand All @@ -68,10 +68,30 @@ function _constraint_row(inner, ci::MOI.ConstraintIndex{F, S}) where {F, S}
end
end

function _store_dual_sensitivities!(dual_sensitivities, inner, dy)
function _vno_rows(
inner,
ci::MOI.ConstraintIndex{MOI.VectorOfVariables, MOI.VectorNonlinearOracle{Float64}},
)
offset = length(inner.qp_data)
for i in 1:(ci.value - 1)
_, s = inner.vector_nonlinear_oracle_constraints[i]
offset += s.set.output_dimension
end
_, s = inner.vector_nonlinear_oracle_constraints[ci.value]
return offset .+ (1:s.set.output_dimension)
end

function _store_dual_sensitivities!(dual_sensitivities, vector_dual_sensitivities, inner, dy)
for (F, S) in MOI.get(inner, MOI.ListOfConstraintTypesPresent())
F == MOI.VariableIndex && continue
S <: MOI.Parameter && continue
if F == MOI.VectorOfVariables && S == MOI.VectorNonlinearOracle{Float64}
for ci in MOI.get(inner, MOI.ListOfConstraintIndices{F, S}())
rows = _vno_rows(inner, ci)
vector_dual_sensitivities[ci] = dy[rows]
end
continue
end
for ci in MOI.get(inner, MOI.ListOfConstraintIndices{F, S}())
row = _constraint_row(inner, ci)
dual_sensitivities[ci] = dy[row]
Expand Down Expand Up @@ -127,6 +147,14 @@ function MOI.get(model::Optimizer, ::MadDiff.ForwardConstraintDual, ci::MOI.Cons
return model.forward.dual_sensitivities[ci]
end

function MOI.get(
model::Optimizer,
::MadDiff.ForwardConstraintDual,
ci::MOI.ConstraintIndex{MOI.VectorOfVariables, MOI.VectorNonlinearOracle{Float64}},
)
return model.forward.vector_dual_sensitivities[ci]
end

function MOI.get(model::Optimizer, ::MadDiff.ForwardObjectiveSensitivity)
return model.forward.objective_sensitivity
end
3 changes: 3 additions & 0 deletions ext/MathOptInterfaceExt/moi_wrapper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ function MOI.empty!(m::Optimizer)
empty!(m.forward.param_perturbations)
empty!(m.reverse.primal_seeds)
empty!(m.reverse.dual_seeds)
empty!(m.reverse.vector_dual_seeds)
m.reverse.dobj = nothing
return _invalidate_sensitivity!(m)
end
Expand All @@ -103,13 +104,15 @@ function MadDiff.empty_input_sensitivities!(model::Optimizer)
empty!(model.forward.param_perturbations)
empty!(model.reverse.primal_seeds)
empty!(model.reverse.dual_seeds)
empty!(model.reverse.vector_dual_seeds)
model.reverse.dobj = nothing
return _clear_outputs!(model)
end

function _clear_outputs!(m::Optimizer{OT, T}) where {OT, T}
empty!(m.forward.primal_sensitivities)
empty!(m.forward.dual_sensitivities)
empty!(m.forward.vector_dual_sensitivities)
m.forward.objective_sensitivity = zero(T)
empty!(m.reverse.param_outputs)
return m.diff_time = zero(T)
Expand Down
22 changes: 22 additions & 0 deletions ext/MathOptInterfaceExt/reverse_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ function MOI.set(
return _clear_outputs!(model) # keep KKT factorization
end

function MOI.set(
model::Optimizer,
::MadDiff.ReverseConstraintDual,
ci::MOI.ConstraintIndex{MOI.VectorOfVariables, MOI.VectorNonlinearOracle{Float64}},
value::AbstractVector,
)
model.reverse.vector_dual_seeds[ci] = value
return _clear_outputs!(model) # keep KKT factorization
end

function MOI.set(
model::Optimizer{OT, T},
::MadDiff.ReverseObjectiveSensitivity,
Expand Down Expand Up @@ -73,6 +83,15 @@ function _process_reverse_dual_input!(
dL_dy[row] = val
end

function _process_reverse_dual_input!(
ci::MOI.ConstraintIndex{MOI.VectorOfVariables, MOI.VectorNonlinearOracle{Float64}},
val::AbstractVector,
inner, dL_dy, dL_dzl, dL_dzu,
)
rows = _vno_rows(inner, ci)
dL_dy[rows] .= val
end

function _reverse_differentiate_impl!(model::Optimizer{OT, T}) where {OT, T}
inner = model.inner
solver = inner.solver
Expand All @@ -99,6 +118,9 @@ function _reverse_differentiate_impl!(model::Optimizer{OT, T}) where {OT, T}
for (ci, val) in model.reverse.dual_seeds
_process_reverse_dual_input!(ci, val, inner, dL_dy, dL_dzl, dL_dzu)
end
for (ci, val) in model.reverse.vector_dual_seeds
_process_reverse_dual_input!(ci, val, inner, dL_dy, dL_dzl, dL_dzu)
end

dL_dy .*= -solver.cb.obj_sign
dobj = model.reverse.dobj
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[sources]
DiffOpt = {rev = "mk/allow_obj_and_sol", url = "https://github.com/klamike/DiffOpt.jl.git"}
DiffOpt = {rev = "mk/vnodiff338", url = "https://github.com/klamike/DiffOpt.jl.git"}
ExaModels = {rev = "mk/param_ad", url = "https://github.com/klamike/ExaModels.jl"}
HybridKKT = {rev = "mk/latest", url = "https://github.com/klamike/HybridKKT.jl.git"}
MadDiff = {path = ".."}
Expand Down