Skip to content

Commit 6c79536

Browse files
authored
Merge pull request #44 from CLeARoboticsLab/feature/better-regularization
More flexible regularization of Newton step
2 parents 92d0a32 + 2873159 commit 6c79536

8 files changed

Lines changed: 142 additions & 62 deletions

File tree

benchmark/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ ParametricMCPs = "9b992ff8-05bb-4ea1-b9d2-5ef72d82f7ad"
99
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
12+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1213
TrajectoryGamesBase = "ac1ac542-73eb-4349-ae1b-660ab3609574"
1314
TrajectoryGamesExamples = "ff3fa34c-8d8f-519c-b5bc-31760c52507a"
1415

benchmark/SolverBenchmarks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Distributions: Distributions
1010
using LazySets: LazySets
1111
using PATHSolver: PATHSolver
1212
using ProgressMeter: @showprogress
13+
using Symbolics: Symbolics
1314

1415
abstract type BenchmarkType end
1516
struct QuadraticProgramBenchmark <: BenchmarkType end

benchmark/path.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ function benchmark(
3636
problem.z_symbolic,
3737
problem.θ_symbolic,
3838
problem.lower_bounds,
39-
problem.upper_bounds,
39+
problem.upper_bounds;
40+
η_symbolic = hasproperty(problem, :η_symbolic) ? problem.η_symbolic : nothing,
4041
)
4142
end
4243

@@ -46,19 +47,28 @@ function benchmark(
4647
elseif hasproperty(problem, :K)
4748
# Generated a callable problem.
4849
ParametricMCPs.ParametricMCP(
49-
problem.K,
50+
(z, θ) -> problem.K(z; θ),
5051
problem.lower_bounds,
5152
problem.upper_bounds,
5253
parameter_dimension,
5354
)
5455
else
5556
# Generated a symbolic problem.
57+
K_symbolic =
58+
!hasproperty(problem, :η_symbolic) ? problem.K_symbolic :
59+
Vector{Symbolics.Num}(
60+
Symbolics.substitute(
61+
problem.K_symbolic,
62+
Dict([problem.η_symbolic => 0.0]),
63+
),
64+
)
65+
5666
ParametricMCPs.ParametricMCP(
57-
problem.K_symbolic,
67+
K_symbolic,
5868
problem.z_symbolic,
5969
problem.θ_symbolic,
6070
problem.lower_bounds,
61-
problem.upper_bounds
71+
problem.upper_bounds;
6272
)
6373
end
6474

benchmark/quadratic_program_benchmark.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ function generate_test_problem(
3131
A * x - b
3232
end
3333

34-
K(z, θ) =
34+
K(z; θ) =
3535
let
3636
x = z[1:num_primals]
3737
y = z[(num_primals + 1):end]

examples/utils.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ function build_parametric_game(;
5959
K_symbolic,
6060
z_symbolic,
6161
θ_symbolic,
62+
η_symbolic,
6263
lower_bounds,
6364
upper_bounds,
6465
dims,
@@ -72,7 +73,8 @@ function build_parametric_game(;
7273
z_symbolic,
7374
θ_symbolic,
7475
lower_bounds,
75-
upper_bounds,
76+
upper_bounds;
77+
η_symbolic,
7678
)
7779
MixedComplementarityProblems.ParametricGame(
7880
problems,

src/game.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function ParametricGame(;
3030
shared_equality = nothing,
3131
shared_inequality = nothing,
3232
)
33-
(; K_symbolic, z_symbolic, θ_symbolic, lower_bounds, upper_bounds, dims) =
33+
(; K_symbolic, z_symbolic, θ_symbolic, η_symbolic, lower_bounds, upper_bounds, dims) =
3434
game_to_mcp(;
3535
test_point,
3636
test_parameter,
@@ -39,7 +39,15 @@ function ParametricGame(;
3939
shared_inequality,
4040
)
4141

42-
mcp = PrimalDualMCP(K_symbolic, z_symbolic, θ_symbolic, lower_bounds, upper_bounds)
42+
mcp = PrimalDualMCP(
43+
K_symbolic,
44+
z_symbolic,
45+
θ_symbolic,
46+
lower_bounds,
47+
upper_bounds;
48+
η_symbolic,
49+
)
50+
4351
ParametricGame(problems, shared_equality, shared_inequality, dims, mcp)
4452
end
4553

@@ -62,7 +70,7 @@ function game_to_mcp(;
6270
shared_inequality,
6371
)
6472

65-
# Define primal and dual variables for the game, and game parameters..
73+
# Define primal and dual variables for the game, and game parameters.
6674
# Note that BlockArrays can handle blocks of zero size.
6775
backend = SymbolicTracingUtils.SymbolicsBackend()
6876
x =
@@ -80,6 +88,10 @@ function game_to_mcp(;
8088
SymbolicTracingUtils.make_variables(backend, :θ, sum(dims.θ)) |>
8189
to_blockvector(dims.θ)
8290

91+
# Parameter for adding a scaled identity to the Hessian of each player's
92+
# Lagrangian wrt that player's variable.
93+
η = only(SymbolicTracingUtils.make_variables(backend, :η, 1))
94+
8395
# Build symbolic expressions for objectives and constraints.
8496
fs = map(problems, blocks(θ)) do p, θi
8597
p.objective(x, θi)
@@ -94,12 +106,12 @@ function game_to_mcp(;
94106
g̃ = isnothing(shared_equality) ? nothing : shared_equality(x, θ)
95107
h̃ = isnothing(shared_inequality) ? nothing : shared_inequality(x, θ)
96108

97-
# Build gradient of each player's Lagrangian.
109+
# Build gradient of each player's Lagrangian and include regularization.
98110
∇Ls = map(fs, gs, hs, blocks(x), blocks(λ), blocks(μ)) do f, g, h, xi, λi, μi
99111
L =
100112
f - (isnothing(g) ? 0 : sum(λi .* g)) - (isnothing(h) ? 0 : sum(μi .* h)) - (isnothing(g̃) ? 0 : sum(λ̃ .* g̃)) -
101113
(isnothing(h̃) ? 0 : sum(μ̃ .* h̃))
102-
SymbolicTracingUtils.gradient(L, xi)
114+
SymbolicTracingUtils.gradient(L, xi) + η * xi
103115
end
104116

105117
# Build MCP representation.
@@ -150,6 +162,7 @@ function game_to_mcp(;
150162
K_symbolic = collect(K),
151163
z_symbolic = collect(z),
152164
θ_symbolic = collect(θ),
165+
η_symbolic = η,
153166
lower_bounds,
154167
upper_bounds,
155168
dims,
@@ -183,13 +196,9 @@ function dimensions(
183196
end
184197

185198
"Solve a parametric game."
186-
function solve(
187-
game::ParametricGame,
188-
θ;
189-
solver_type = InteriorPoint(),
190-
kwargs...
191-
)
192-
(; x, y, s, kkt_error, status) = solve(solver_type, game.mcp, θ; kwargs...)
199+
function solve(game::ParametricGame, θ; solver_type = InteriorPoint(), kwargs...)
200+
(; x, y, s, kkt_error, status) =
201+
solve(solver_type, game.mcp, θ; regularize_linear_solve = :internal, kwargs...)
193202

194203
# Unpack primals per-player for ease of access later.
195204
end_dims = cumsum(game.dims.x)

src/mcp.jl

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@ The primal-dual system arises when we introduce slack variable `s` and set
77
G(x, y; θ) = 0
88
H(x, y; θ) - s = 0
99
s ⦿ y - ϵ = 0
10-
for some ϵ > 0. Define the function `F(x, y, s; θ, ϵ)` to return the left
11-
hand side of this system of equations.
10+
for some ϵ > 0. Define the function `F(x, y, s; θ, ϵ, [η])` to return the left
11+
hand side of this system of equations. Here, `η` is an optional nonnegative
12+
regularization parameter defined by "internally-regularized" problems.
1213
"""
1314
struct PrimalDualMCP{T1,T2,T3}
14-
"A callable `F!(result, x, y, s; θ, ϵ)` which computes the KKT error in-place."
15+
"A callable `F!(result, x, y, s; θ, ϵ, [η])` to compute the KKT error in-place."
1516
F!::T1
16-
"A callable `∇F_z!(result, x, y, s; θ, ϵ)` to compute ∇F wrt z in-place."
17+
"A callable `∇F_z!(result, x, y, s; θ, ϵ, [η])` to compute ∇F wrt z in-place."
1718
∇F_z!::T2
18-
"A callable `∇F_θ!(result, x, y, s; θ, ϵ)` to compute ∇F wrt θ in-place."
19+
"A callable `∇F_θ!(result, x, y, s; θ, ϵ, [η])` to compute ∇F wrt θ in-place."
1920
∇F_θ!::T3
2021
"Dimension of unconstrained variable."
2122
unconstrained_dimension::Int
@@ -57,7 +58,8 @@ function PrimalDualMCP(
5758
H_symbolic::Vector{T},
5859
x_symbolic::Vector{T},
5960
y_symbolic::Vector{T},
60-
θ_symbolic::Vector{T};
61+
θ_symbolic::Vector{T},
62+
η_symbolic::Union{Nothing,T} = nothing;
6163
compute_sensitivities = false,
6264
backend_options = (;),
6365
) where {T<:Union{SymbolicTracingUtils.FD.Node,SymbolicTracingUtils.Symbolics.Num}}
@@ -79,7 +81,7 @@ function PrimalDualMCP(
7981
s_symbolic .* y_symbolic .- ϵ_symbolic
8082
]
8183

82-
F! = let
84+
F! = if isnothing(η_symbolic)
8385
_F! = SymbolicTracingUtils.build_function(
8486
F_symbolic,
8587
x_symbolic,
@@ -92,37 +94,28 @@ function PrimalDualMCP(
9294
)
9395

9496
(result, x, y, s; θ, ϵ) -> _F!(result, x, y, s, θ, ϵ)
95-
end
96-
97-
∇F_z! = let
98-
∇F_symbolic = SymbolicTracingUtils.sparse_jacobian(F_symbolic, z_symbolic)
99-
_∇F! = SymbolicTracingUtils.build_function(
100-
∇F_symbolic,
97+
else
98+
_F! = SymbolicTracingUtils.build_function(
99+
F_symbolic,
101100
x_symbolic,
102101
y_symbolic,
103102
s_symbolic,
104103
θ_symbolic,
105-
ϵ_symbolic;
104+
ϵ_symbolic,
105+
η_symbolic;
106106
in_place = true,
107107
backend_options,
108108
)
109109

110-
rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
111-
constant_entries =
112-
SymbolicTracingUtils.get_constant_entries(∇F_symbolic, z_symbolic)
113-
SymbolicTracingUtils.SparseFunction(
114-
(result, x, y, s; θ, ϵ) -> _∇F!(result, x, y, s, θ, ϵ),
115-
rows,
116-
cols,
117-
size(∇F_symbolic),
118-
constant_entries,
119-
)
110+
(result, x, y, s; θ, ϵ, η = 0.0) -> _F!(result, x, y, s, θ, ϵ, η)
120111
end
121112

122-
∇F_θ! =
123-
!compute_sensitivities ? nothing :
124-
let
125-
∇F_symbolic = SymbolicTracingUtils.sparse_jacobian(F_symbolic, θ_symbolic)
113+
function process_∇F(F, var)
114+
∇F_symbolic = SymbolicTracingUtils.sparse_jacobian(F, var)
115+
rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
116+
constant_entries = SymbolicTracingUtils.get_constant_entries(∇F_symbolic, var)
117+
118+
if isnothing(η_symbolic)
126119
_∇F! = SymbolicTracingUtils.build_function(
127120
∇F_symbolic,
128121
x_symbolic,
@@ -134,17 +127,38 @@ function PrimalDualMCP(
134127
backend_options,
135128
)
136129

137-
rows, cols, _ = SparseArrays.findnz(∇F_symbolic)
138-
constant_entries =
139-
SymbolicTracingUtils.get_constant_entries(∇F_symbolic, θ_symbolic)
140-
SymbolicTracingUtils.SparseFunction(
130+
return SymbolicTracingUtils.SparseFunction(
141131
(result, x, y, s; θ, ϵ) -> _∇F!(result, x, y, s, θ, ϵ),
142132
rows,
143133
cols,
144134
size(∇F_symbolic),
145135
constant_entries,
146136
)
137+
else
138+
_∇F! = SymbolicTracingUtils.build_function(
139+
∇F_symbolic,
140+
x_symbolic,
141+
y_symbolic,
142+
s_symbolic,
143+
θ_symbolic,
144+
ϵ_symbolic,
145+
η_symbolic;
146+
in_place = true,
147+
backend_options,
148+
)
149+
150+
return SymbolicTracingUtils.SparseFunction(
151+
(result, x, y, s; θ, ϵ, η = 0.0) -> _∇F!(result, x, y, s, θ, ϵ, η),
152+
rows,
153+
cols,
154+
size(∇F_symbolic),
155+
constant_entries,
156+
)
147157
end
158+
end
159+
160+
∇F_z! = process_∇F(F_symbolic, z_symbolic)
161+
∇F_θ! = !compute_sensitivities ? nothing : process_∇F(F_symbolic, θ_symbolic)
148162

149163
PrimalDualMCP(F!, ∇F_z!, ∇F_θ!, length(x_symbolic), length(y_symbolic))
150164
end
@@ -157,6 +171,7 @@ function PrimalDualMCP(
157171
lower_bounds::Vector,
158172
upper_bounds::Vector;
159173
parameter_dimension,
174+
internally_regularized = false,
160175
compute_sensitivities = false,
161176
backend = SymbolicTracingUtils.SymbolicsBackend(),
162177
backend_options = (;),
@@ -165,6 +180,21 @@ function PrimalDualMCP(
165180
θ_symbolic = SymbolicTracingUtils.make_variables(backend, :θ, parameter_dimension)
166181
K_symbolic = K(z_symbolic; θ = θ_symbolic)
167182

183+
if internally_regularized
184+
η_symbolic = only(SymbolicTracingUtils.make_variables(backend, :η, 1))
185+
186+
return PrimalDualMCP(
187+
K_symbolic,
188+
z_symbolic,
189+
θ_symbolic,
190+
lower_bounds,
191+
upper_bounds;
192+
η_symbolic,
193+
compute_sensitivities,
194+
backend_options,
195+
)
196+
end
197+
168198
PrimalDualMCP(
169199
K_symbolic,
170200
z_symbolic,
@@ -185,6 +215,7 @@ function PrimalDualMCP(
185215
θ_symbolic::Vector{T},
186216
lower_bounds::Vector,
187217
upper_bounds::Vector;
218+
η_symbolic::Union{Nothing,T} = nothing,
188219
compute_sensitivities = false,
189220
backend_options = (;),
190221
) where {T<:Union{SymbolicTracingUtils.FD.Node,SymbolicTracingUtils.Symbolics.Num}}
@@ -203,7 +234,8 @@ function PrimalDualMCP(
203234
H_symbolic,
204235
x_symbolic,
205236
y_symbolic,
206-
θ_symbolic;
237+
θ_symbolic,
238+
η_symbolic;
207239
compute_sensitivities,
208240
backend_options,
209241
)

0 commit comments

Comments
 (0)