Skip to content

Commit b4938ac

Browse files
committed
Include more regularization options: handle the nonsquare case and add Lagrangian Hessian regularization for games
1 parent 92d0a32 commit b4938ac

File tree

4 files changed

+125
-57
lines changed

4 files changed

+125
-57
lines changed

examples/utils.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ function build_parametric_game(;
5959
K_symbolic,
6060
z_symbolic,
6161
θ_symbolic,
62+
η_symbolic,
6263
lower_bounds,
6364
upper_bounds,
6465
dims,
@@ -72,7 +73,8 @@ function build_parametric_game(;
7273
z_symbolic,
7374
θ_symbolic,
7475
lower_bounds,
75-
upper_bounds,
76+
upper_bounds;
77+
η_symbolic,
7678
)
7779
MixedComplementarityProblems.ParametricGame(
7880
problems,

src/game.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function ParametricGame(;
3030
shared_equality = nothing,
3131
shared_inequality = nothing,
3232
)
33-
(; K_symbolic, z_symbolic, θ_symbolic, lower_bounds, upper_bounds, dims) =
33+
(; K_symbolic, z_symbolic, θ_symbolic, η_symbolic, lower_bounds, upper_bounds, dims) =
3434
game_to_mcp(;
3535
test_point,
3636
test_parameter,
@@ -39,7 +39,15 @@ function ParametricGame(;
3939
shared_inequality,
4040
)
4141

42-
mcp = PrimalDualMCP(K_symbolic, z_symbolic, θ_symbolic, lower_bounds, upper_bounds)
42+
mcp = PrimalDualMCP(
43+
K_symbolic,
44+
z_symbolic,
45+
θ_symbolic,
46+
lower_bounds,
47+
upper_bounds;
48+
η_symbolic,
49+
)
50+
4351
ParametricGame(problems, shared_equality, shared_inequality, dims, mcp)
4452
end
4553

@@ -62,7 +70,7 @@ function game_to_mcp(;
6270
shared_inequality,
6371
)
6472

65-
# Define primal and dual variables for the game, and game parameters..
73+
# Define primal and dual variables for the game, and game parameters.
6674
# Note that BlockArrays can handle blocks of zero size.
6775
backend = SymbolicTracingUtils.SymbolicsBackend()
6876
x =
@@ -80,6 +88,10 @@ function game_to_mcp(;
8088
SymbolicTracingUtils.make_variables(backend, :θ, sum(dims.θ)) |>
8189
to_blockvector(dims.θ)
8290

91+
# Parameter for adding a scaled identity to the Hessian of each player's
92+
# Lagrangian wrt that player's variable.
93+
η = only(SymbolicTracingUtils.make_variables(backend, :η, 1))
94+
8395
# Build symbolic expressions for objectives and constraints.
8496
fs = map(problems, blocks(θ)) do p, θi
8597
p.objective(x, θi)
@@ -94,12 +106,12 @@ function game_to_mcp(;
94106
g̃ = isnothing(shared_equality) ? nothing : shared_equality(x, θ)
95107
h̃ = isnothing(shared_inequality) ? nothing : shared_inequality(x, θ)
96108

97-
# Build gradient of each player's Lagrangian.
109+
# Build gradient of each player's Lagrangian and include regularization.
98110
∇Ls = map(fs, gs, hs, blocks(x), blocks(λ), blocks(μ)) do f, g, h, xi, λi, μi
99111
L =
100112
f - (isnothing(g) ? 0 : sum(λi .* g)) - (isnothing(h) ? 0 : sum(μi .* h)) - (isnothing(g̃) ? 0 : sum(λ̃ .* g̃)) -
101113
(isnothing(h̃) ? 0 : sum(μ̃ .* h̃))
102-
SymbolicTracingUtils.gradient(L, xi)
114+
SymbolicTracingUtils.gradient(L, xi) + η * xi
103115
end
104116

105117
# Build MCP representation.
@@ -150,6 +162,7 @@ function game_to_mcp(;
150162
K_symbolic = collect(K),
151163
z_symbolic = collect(z),
152164
θ_symbolic = collect(θ),
165+
η_symbolic = η,
153166
lower_bounds,
154167
upper_bounds,
155168
dims,
@@ -183,13 +196,9 @@ function dimensions(
183196
end
184197

185198
"Solve a parametric game."
186-
function solve(
187-
game::ParametricGame,
188-
θ;
189-
solver_type = InteriorPoint(),
190-
kwargs...
191-
)
192-
(; x, y, s, kkt_error, status) = solve(solver_type, game.mcp, θ; kwargs...)
199+
function solve(game::ParametricGame, θ; solver_type = InteriorPoint(), kwargs...)
200+
(; x, y, s, kkt_error, status) =
201+
solve(solver_type, game.mcp, θ; regularize_linear_solve = :internal, kwargs...)
193202

194203
# Unpack primals per-player for ease of access later.
195204
end_dims = cumsum(game.dims.x)

src/mcp.jl

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@ The primal-dual system arises when we introduce slack variable `s` and set
77
G(x, y; θ) = 0
88
H(x, y; θ) - s = 0
99
s ⦿ y - ϵ = 0
10-
for some ϵ > 0. Define the function `F(x, y, s; θ, ϵ)` to return the left
11-
hand side of this system of equations.
10+
for some ϵ > 0. Define the function `F(x, y, s; θ, ϵ, [η])` to return the left
11+
hand side of this system of equations. Here, `η` is an optional nonnegative
12+
regularization parameter defined by "internally-regularized" problems.
1213
"""
1314
struct PrimalDualMCP{T1,T2,T3}
14-
"A callable `F!(result, x, y, s; θ, ϵ)` which computes the KKT error in-place."
15+
"A callable `F!(result, x, y, s; θ, ϵ, [η])` to compute the KKT error in-place."
1516
F!::T1
16-
"A callable `∇F_z!(result, x, y, s; θ, ϵ)` to compute ∇F wrt z in-place."
17+
"A callable `∇F_z!(result, x, y, s; θ, ϵ, [η])` to compute ∇F wrt z in-place."
1718
∇F_z!::T2
18-
"A callable `∇F_θ!(result, x, y, s; θ, ϵ)` to compute ∇F wrt θ in-place."
19+
"A callable `∇F_θ!(result, x, y, s; θ, ϵ, [η])` to compute ∇F wrt θ in-place."
1920
∇F_θ!::T3
2021
"Dimension of unconstrained variable."
2122
unconstrained_dimension::Int
@@ -57,7 +58,8 @@ function PrimalDualMCP(
5758
H_symbolic::Vector{T},
5859
x_symbolic::Vector{T},
5960
y_symbolic::Vector{T},
60-
θ_symbolic::Vector{T};
61+
θ_symbolic::Vector{T},
62+
η_symbolic::Union{Nothing,T} = nothing;
6163
compute_sensitivities = false,
6264
backend_options = (;),
6365
) where {T<:Union{SymbolicTracingUtils.FD.Node,SymbolicTracingUtils.Symbolics.Num}}
@@ -79,7 +81,7 @@ function PrimalDualMCP(
7981
s_symbolic .* y_symbolic .- ϵ_symbolic
8082
]
8183

82-
F! = let
84+
F! = if isnothing(η_symbolic)
8385
_F! = SymbolicTracingUtils.build_function(
8486
F_symbolic,
8587
x_symbolic,
@@ -92,37 +94,28 @@ function PrimalDualMCP(
9294
)
9395

9496
(result, x, y, s; θ, ϵ) -> _F!(result, x, y, s, θ, ϵ)
95-
end
96-
97-
∇F_z! = let
98-
∇F_symbolic = SymbolicTracingUtils.sparse_jacobian(F_symbolic, z_symbolic)
99-
_∇F! = SymbolicTracingUtils.build_function(
100-
∇F_symbolic,
97+
else
98+
_F! = SymbolicTracingUtils.build_function(
99+
F_symbolic,
101100
x_symbolic,
102101
y_symbolic,
103102
s_symbolic,
104103
θ_symbolic,
105-
ϵ_symbolic;
104+
ϵ_symbolic,
105+
η_symbolic;
106106
in_place = true,
107107
backend_options,
108108
)
109109

110-
rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
111-
constant_entries =
112-
SymbolicTracingUtils.get_constant_entries(∇F_symbolic, z_symbolic)
113-
SymbolicTracingUtils.SparseFunction(
114-
(result, x, y, s; θ, ϵ) -> _∇F!(result, x, y, s, θ, ϵ),
115-
rows,
116-
cols,
117-
size(∇F_symbolic),
118-
constant_entries,
119-
)
110+
(result, x, y, s; θ, ϵ, η = 0.0) -> _F!(result, x, y, s, θ, ϵ, η)
120111
end
121112

122-
∇F_θ! =
123-
!compute_sensitivities ? nothing :
124-
let
125-
∇F_symbolic = SymbolicTracingUtils.sparse_jacobian(F_symbolic, θ_symbolic)
113+
function process_∇F(F, var)
114+
∇F_symbolic = SymbolicTracingUtils.sparse_jacobian(F, var)
115+
rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
116+
constant_entries = SymbolicTracingUtils.get_constant_entries(∇F_symbolic, var)
117+
118+
if isnothing(η_symbolic)
126119
_∇F! = SymbolicTracingUtils.build_function(
127120
∇F_symbolic,
128121
x_symbolic,
@@ -134,17 +127,38 @@ function PrimalDualMCP(
134127
backend_options,
135128
)
136129

137-
rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
138-
constant_entries =
139-
SymbolicTracingUtils.get_constant_entries(∇F_symbolic, θ_symbolic)
140-
SymbolicTracingUtils.SparseFunction(
130+
return SymbolicTracingUtils.SparseFunction(
141131
(result, x, y, s; θ, ϵ) -> _∇F!(result, x, y, s, θ, ϵ),
142132
rows,
143133
cols,
144134
size(∇F_symbolic),
145135
constant_entries,
146136
)
137+
else
138+
_∇F! = SymbolicTracingUtils.build_function(
139+
∇F_symbolic,
140+
x_symbolic,
141+
y_symbolic,
142+
s_symbolic,
143+
θ_symbolic,
144+
ϵ_symbolic,
145+
η_symbolic;
146+
in_place = true,
147+
backend_options,
148+
)
149+
150+
return SymbolicTracingUtils.SparseFunction(
151+
(result, x, y, s; θ, ϵ, η = 0.0) -> _∇F!(result, x, y, s, θ, ϵ, η),
152+
rows,
153+
cols,
154+
size(∇F_symbolic),
155+
constant_entries,
156+
)
147157
end
158+
end
159+
160+
∇F_z! = process_∇F(F_symbolic, z_symbolic)
161+
∇F_θ! = !compute_sensitivities ? nothing : process_∇F(F_symbolic, θ_symbolic)
148162

149163
PrimalDualMCP(F!, ∇F_z!, ∇F_θ!, length(x_symbolic), length(y_symbolic))
150164
end
@@ -157,6 +171,7 @@ function PrimalDualMCP(
157171
lower_bounds::Vector,
158172
upper_bounds::Vector;
159173
parameter_dimension,
174+
internally_regularized = false,
160175
compute_sensitivities = false,
161176
backend = SymbolicTracingUtils.SymbolicsBackend(),
162177
backend_options = (;),
@@ -165,6 +180,21 @@ function PrimalDualMCP(
165180
θ_symbolic = SymbolicTracingUtils.make_variables(backend, :θ, parameter_dimension)
166181
K_symbolic = K(z_symbolic; θ = θ_symbolic)
167182

183+
if internally_regularized
184+
η_symbolic = only(SymbolicTracingUtils.make_variables(backend, :η, 1))
185+
186+
return PrimalDualMCP(
187+
K_symbolic,
188+
z_symbolic,
189+
θ_symbolic,
190+
lower_bounds,
191+
upper_bounds;
192+
η_symbolic,
193+
compute_sensitivities,
194+
backend_options,
195+
)
196+
end
197+
168198
PrimalDualMCP(
169199
K_symbolic,
170200
z_symbolic,
@@ -185,6 +215,7 @@ function PrimalDualMCP(
185215
θ_symbolic::Vector{T},
186216
lower_bounds::Vector,
187217
upper_bounds::Vector;
218+
η_symbolic::Union{Nothing,T} = nothing,
188219
compute_sensitivities = false,
189220
backend_options = (;),
190221
) where {T<:Union{SymbolicTracingUtils.FD.Node,SymbolicTracingUtils.Symbolics.Num}}
@@ -203,7 +234,8 @@ function PrimalDualMCP(
203234
H_symbolic,
204235
x_symbolic,
205236
y_symbolic,
206-
θ_symbolic;
237+
θ_symbolic,
238+
η_symbolic;
207239
compute_sensitivities,
208240
backend_options,
209241
)

src/solver.jl

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ Keyword arguments:
2727
- `tol::Real = 1e-4`: the tolerance for the KKT error.
2828
- `max_inner_iters::Int = 20`: the maximum number of inner iterations.
2929
- `max_outer_iters::Int = 50`: the maximum number of outer iterations.
30-
- `tightening_rate::Real = 0.1`: the rate at which to tighten the tolerance.
31-
- `loosening_rate::Real = 0.5`: the rate at which to loosen the tolerance.
30+
- `tightening_rate::Real = 0.1`: rate for tightening tolerance and regularization.
31+
- `loosening_rate::Real = 0.5`: rate for loosening tolerance and regularization.
3232
- `min_stepsize::Real = 1e-4`: the minimum step size for the linesearch.
3333
- `verbose::Bool = false`: whether to print debug information.
3434
- `linear_solve_algorithm::LinearSolve.SciMLLinearSolveAlgorithm`: the linear solve algorithm to use. Any solver from `LinearSolve.jl` can be used.
35+
- `regularize_linear_solve::Symbol = :identity`: scheme for regularizing the linear system matrix ∇F. Options are {:none, :identity, :internal}.
3536
"""
3637
function solve(
3738
::InteriorPoint,
@@ -49,6 +50,7 @@ function solve(
4950
min_stepsize = 1e-4,
5051
verbose = false,
5152
linear_solve_algorithm = UMFPACKFactorization(),
53+
regularize_linear_solve = :identity,
5254
)
5355
# Set up common memory.
5456
∇F = mcp.∇F_z!.result_buffer
@@ -61,11 +63,12 @@ function solve(
6163

6264
linsolve = init(LinearProblem(∇F, δz), linear_solve_algorithm)
6365

64-
# Main solver loop.
66+
# Initialize primal, dual, and slack variables.
6567
x = @something(x₀, zeros(mcp.unconstrained_dimension))
6668
y = @something(y₀, ones(mcp.constrained_dimension))
6769
s = @something(s₀, ones(mcp.constrained_dimension))
6870

71+
# Initialize IP relaxation parameter.
6972
if ϵ₀ === :auto
7073
is_warmstarted = !isnothing(x₀) && !isnothing(y₀) && !isnothing(s₀)
7174
if is_warmstarted
@@ -77,6 +80,10 @@ function solve(
7780
ϵ = ϵ₀
7881
end
7982

83+
# Initialize regularization parameter.
84+
η = tol
85+
86+
# Main solver loop.
8087
status = :solved
8188
total_iters = 0
8289
inner_iters = 1
@@ -88,14 +95,30 @@ function solve(
8895

8996
while kkt_error > ϵ && inner_iters < max_inner_iters
9097
total_iters += 1
91-
# Compute the Newton step.
92-
# TODO: Can add some adaptive regularization.
98+
99+
# Compute the (regularized) Newton step.
93100
# TODO: use a linear operator with a lazy gradient computation here.
94-
mcp.F!(F, x, y, s; θ, ϵ)
95-
mcp.∇F_z!(∇F, x, y, s; θ, ϵ)
96-
linsolve.A = ∇F + tol * I
101+
if regularize_linear_solve === :internal
102+
mcp.F!(F, x, y, s; θ, ϵ, η = 0.0)
103+
mcp.∇F_z!(∇F, x, y, s; θ, ϵ, η)
104+
else
105+
mcp.F!(F, x, y, s; θ, ϵ)
106+
mcp.∇F_z!(∇F, x, y, s; θ, ϵ)
107+
end
108+
109+
if regularize_linear_solve === :identity
110+
if size(∇F, 1) == size(∇F, 2)
111+
linsolve.A = ∇F + η * I
112+
else
113+
@warn "Cannot use identity regularization on a nonsquare problem."
114+
end
115+
else
116+
linsolve.A = ∇F
117+
end
118+
97119
linsolve.b = -F
98120
solution = solve!(linsolve)
121+
99122
if !SciMLBase.successful_retcode(solution) &&
100123
(solution.retcode !== SciMLBase.ReturnCode.Default)
101124
verbose &&
@@ -129,10 +152,12 @@ function solve(
129152
break
130153
end
131154

132-
ϵ *= if status === :solved
133-
1 - exp(-tightening_rate * inner_iters)
155+
if status === :solved
156+
ϵ *= 1 - exp(-tightening_rate * inner_iters)
157+
η *= 1 - exp(-tightening_rate * inner_iters)
134158
else
135-
1 + exp(-loosening_rate * inner_iters)
159+
ϵ *= 1 + exp(-loosening_rate * inner_iters)
160+
η *= 1 + exp(-loosening_rate * inner_iters)
136161
end
137162
ϵ = min(ϵ, one(ϵ))
138163
outer_iters += 1

0 commit comments

Comments
 (0)