@@ -7,15 +7,16 @@ The primal-dual system arises when we introduce slack variable `s` and set
77 G(x, y; θ) = 0
88 H(x, y; θ) - s = 0
99 s ⦿ y - ϵ = 0
10- for some ϵ > 0. Define the function `F(x, y, s; θ, ϵ)` to return the left
11- hand side of this system of equations.
10+ for some ϵ > 0. Define the function `F(x, y, s; θ, ϵ, [η])` to return the left
11+ hand side of this system of equations. Here, `η` is an optional nonnegative
12+ regularization parameter defined by "internally-regularized" problems.
1213"""
1314struct PrimalDualMCP{T1,T2,T3}
14- " A callable `F!(result, x, y, s; θ, ϵ)` which computes the KKT error in-place."
15+ " A callable `F!(result, x, y, s; θ, ϵ, [η] )` to the KKT error in-place."
1516 F!:: T1
16- " A callable `∇F_z!(result, x, y, s; θ, ϵ)` to compute ∇F wrt z in-place."
17+ " A callable `∇F_z!(result, x, y, s; θ, ϵ, [η] )` to compute ∇F wrt z in-place."
1718 ∇F_z!:: T2
18- " A callable `∇F_θ!(result, x, y, s; θ, ϵ)` to compute ∇F wrt θ in-place."
19+ " A callable `∇F_θ!(result, x, y, s; θ, ϵ, [η] )` to compute ∇F wrt θ in-place."
1920 ∇F_θ!:: T3
2021 " Dimension of unconstrained variable."
2122 unconstrained_dimension:: Int
@@ -57,7 +58,8 @@ function PrimalDualMCP(
5758 H_symbolic:: Vector{T} ,
5859 x_symbolic:: Vector{T} ,
5960 y_symbolic:: Vector{T} ,
60- θ_symbolic:: Vector{T} ;
61+ θ_symbolic:: Vector{T} ,
62+ η_symbolic:: Union{Nothing,T} = nothing ;
6163 compute_sensitivities = false ,
6264 backend_options = (;),
6365) where {T<: Union{SymbolicTracingUtils.FD.Node,SymbolicTracingUtils.Symbolics.Num} }
@@ -79,7 +81,7 @@ function PrimalDualMCP(
7981 s_symbolic .* y_symbolic .- ϵ_symbolic
8082 ]
8183
82- F! = let
84+ F! = if isnothing (η_symbolic)
8385 _F! = SymbolicTracingUtils. build_function (
8486 F_symbolic,
8587 x_symbolic,
@@ -92,37 +94,28 @@ function PrimalDualMCP(
9294 )
9395
9496 (result, x, y, s; θ, ϵ) -> _F! (result, x, y, s, θ, ϵ)
95- end
96-
97- ∇F_z! = let
98- ∇F_symbolic = SymbolicTracingUtils. sparse_jacobian (F_symbolic, z_symbolic)
99- _∇F! = SymbolicTracingUtils. build_function (
100- ∇F_symbolic,
97+ else
98+ _F! = SymbolicTracingUtils. build_function (
99+ F_symbolic,
101100 x_symbolic,
102101 y_symbolic,
103102 s_symbolic,
104103 θ_symbolic,
105- ϵ_symbolic;
104+ ϵ_symbolic,
105+ η_symbolic;
106106 in_place = true ,
107107 backend_options,
108108 )
109109
110- rows, cols, _ = SparseArrays. findnz (∇F_symbolic)
111- constant_entries =
112- SymbolicTracingUtils. get_constant_entries (∇F_symbolic, z_symbolic)
113- SymbolicTracingUtils. SparseFunction (
114- (result, x, y, s; θ, ϵ) -> _∇F! (result, x, y, s, θ, ϵ),
115- rows,
116- cols,
117- size (∇F_symbolic),
118- constant_entries,
119- )
110+ (result, x, y, s; θ, ϵ, η = 0.0 ) -> _F! (result, x, y, s, θ, ϵ, η)
120111 end
121112
122- ∇F_θ! =
123- ! compute_sensitivities ? nothing :
124- let
125- ∇F_symbolic = SymbolicTracingUtils. sparse_jacobian (F_symbolic, θ_symbolic)
113+ function process_∇F (F, var)
114+ ∇F_symbolic = SymbolicTracingUtils. sparse_jacobian (F, var)
115+ rows, cols, _ = SparseArrays. findnz (∇F_symbolic)
116+ constant_entries = SymbolicTracingUtils. get_constant_entries (∇F_symbolic, var)
117+
118+ if isnothing (η_symbolic)
126119 _∇F! = SymbolicTracingUtils. build_function (
127120 ∇F_symbolic,
128121 x_symbolic,
@@ -134,17 +127,38 @@ function PrimalDualMCP(
134127 backend_options,
135128 )
136129
137- rows, cols, _ = SparseArrays. findnz (∇F_symbolic)
138- constant_entries =
139- SymbolicTracingUtils. get_constant_entries (∇F_symbolic, θ_symbolic)
140- SymbolicTracingUtils. SparseFunction (
130+ return SymbolicTracingUtils. SparseFunction (
141131 (result, x, y, s; θ, ϵ) -> _∇F! (result, x, y, s, θ, ϵ),
142132 rows,
143133 cols,
144134 size (∇F_symbolic),
145135 constant_entries,
146136 )
137+ else
138+ _∇F! = SymbolicTracingUtils. build_function (
139+ ∇F_symbolic,
140+ x_symbolic,
141+ y_symbolic,
142+ s_symbolic,
143+ θ_symbolic,
144+ ϵ_symbolic,
145+ η_symbolic;
146+ in_place = true ,
147+ backend_options,
148+ )
149+
150+ return SymbolicTracingUtils. SparseFunction (
151+ (result, x, y, s; θ, ϵ, η = 0.0 ) -> _∇F! (result, x, y, s, θ, ϵ, η),
152+ rows,
153+ cols,
154+ size (∇F_symbolic),
155+ constant_entries,
156+ )
147157 end
158+ end
159+
160+ ∇F_z! = process_∇F (F_symbolic, z_symbolic)
161+ ∇F_θ! = ! compute_sensitivities ? nothing : process_∇F (F_symbolic, θ_symbolic)
148162
149163 PrimalDualMCP (F!, ∇F_z!, ∇F_θ!, length (x_symbolic), length (y_symbolic))
150164end
@@ -157,6 +171,7 @@ function PrimalDualMCP(
157171 lower_bounds:: Vector ,
158172 upper_bounds:: Vector ;
159173 parameter_dimension,
174+ internally_regularized = false ,
160175 compute_sensitivities = false ,
161176 backend = SymbolicTracingUtils. SymbolicsBackend (),
162177 backend_options = (;),
@@ -165,6 +180,21 @@ function PrimalDualMCP(
165180 θ_symbolic = SymbolicTracingUtils. make_variables (backend, :θ , parameter_dimension)
166181 K_symbolic = K (z_symbolic; θ = θ_symbolic)
167182
183+ if internally_regularized
184+ η_symbolic = only (SymbolicTracingUtils. make_variables (backend, :η , 1 ))
185+
186+ return PrimalDualMCP (
187+ K_symbolic,
188+ z_symbolic,
189+ θ_symbolic,
190+ lower_bounds,
191+ upper_bounds;
192+ η_symbolic,
193+ compute_sensitivities,
194+ backend_options,
195+ )
196+ end
197+
168198 PrimalDualMCP (
169199 K_symbolic,
170200 z_symbolic,
@@ -185,6 +215,7 @@ function PrimalDualMCP(
185215 θ_symbolic:: Vector{T} ,
186216 lower_bounds:: Vector ,
187217 upper_bounds:: Vector ;
218+ η_symbolic:: Union{Nothing,T} = nothing ,
188219 compute_sensitivities = false ,
189220 backend_options = (;),
190221) where {T<: Union{SymbolicTracingUtils.FD.Node,SymbolicTracingUtils.Symbolics.Num} }
@@ -203,7 +234,8 @@ function PrimalDualMCP(
203234 H_symbolic,
204235 x_symbolic,
205236 y_symbolic,
206- θ_symbolic;
237+ θ_symbolic,
238+ η_symbolic;
207239 compute_sensitivities,
208240 backend_options,
209241 )
0 commit comments