forked from optsuite/optlib
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNesterovSmooth.lean
More file actions
244 lines (233 loc) · 13.1 KB
/
NesterovSmooth.lean
File metadata and controls
244 lines (233 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
/-
Copyright (c) 2023 Chenyi Li, Ziyu Wang, Zaiwen Wen. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Chenyi Li, Ziyu Wang, Zaiwen Wen
-/
import Convex.Function.Lsmooth
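/-!
# NesterovSmooth
## Main results
This file studies Nesterov's accelerated gradient method for smooth convex optimization problems
and proves its O(1 / k ^ 2) convergence rate.
* `one_iter`: the one-iteration estimate for the scheme packaged in the class `Nesterov`.
* `nesterov_algorithm_smooth`: with weights `γ k = 2 / (k + 1)`,
  `f (x k) - f xm ≤ 2 * l / (k + 1) ^ 2 * ‖x₀ - xm‖ ^ 2` for every `k : ℕ+`.
-/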
-- Ambient space: a complete real inner product space (a real Hilbert space).
variable {E : Type*} [NormedAddCommGroup E] [InnerProductSpace ℝ E] [CompleteSpace E]
section
open Set
/-- Specification of the iterates of Nesterov's accelerated gradient scheme.

Parameters: the objective `f`, its gradient map `f'`, the weight sequence `γ`, and the
starting point `initial_point`. An instance packages the sequences `x`, `y`, `v` and a
constant `l` together with proofs that, for every `k : ℕ+`,

* `y k = (1 - γ k) • x (k - 1) + γ k • v (k - 1)`                (`update1`)
* `x k = y k - (1 / l) • f' (y k)`, a gradient step from `y k`    (`update2`)
* `v k = x (k - 1) + (1 / γ k) • (x k - x (k - 1))`               (`update3`)

with `v 0 = initial_point` and `γ 1 = 1`. In `x (k - 1)` and `v (k - 1)` the index is a
natural number, because `x` and `v` are indexed by `ℕ`.

NOTE(review): no field constrains `x 0`. Since `γ 1 = 1`, `update1` gives `y 1 = v 0`,
so `x 0` only ever enters with coefficient `1 - γ 1 = 0`. -/
class Nesterov (f : E → ℝ) (f' : E → E) (γ : ℕ+ → ℝ) (initial_point : E) :=
  (x : ℕ → E) (y : ℕ+ → E) (v : ℕ → E) (l : NNReal)
  -- `f'` is the gradient of `f` at every point
  (diff : ∀ x₁, HasGradientAt f (f' x₁) x₁)
  (update1 : ∀ (k : ℕ+), y k = (1 - γ k) • x (k - 1) + γ k • v (k - 1))
  -- gradient step of length `1 / l` taken from the extrapolated point `y k`
  (update2 : ∀ (k : ℕ+), x k = y k - (1 / l.1) • (f' (y k)))
  (update3 : ∀ (k : ℕ+), v k = x (k - 1) + (1 / (γ k)) • (x k - x (k - 1)))
  -- `l` is positive and `f'` is `l`-Lipschitz
  (hl : l > 0) (smooth : LipschitzWith l f')
  (initial1 : γ 1 = (1 : ℝ)) (initial2 : v 0 = initial_point)
-- Implicit context for the results below; `alg` is the algorithm instance being analyzed
-- and `x₀` is its starting point.
variable {f : E → ℝ} {f' : E → E} {xm x₀: E} {γ : ℕ+ → ℝ } {alg : Nesterov f f' γ x₀}
/-- One-iteration estimate for the scheme packaged in `Nesterov`.

For every `k : ℕ+` and every point `xm`,
`f (x k) - f xm - (1 - γ k) * (f (x (k - 1)) - f xm)` is bounded by
`l * (γ k) ^ 2 / 2 * (‖v (k - 1) - xm‖ ^ 2 - ‖v k - xm‖ ^ 2)`.
No optimality of `xm` is assumed here.

Hypotheses: `f` is convex on `Set.univ`, and the weights are `γ k = 2 / (k + 1)`.
The weight formula is used only to obtain `0 ≤ γ k ≤ 1` and `γ k > 0`. -/
lemma one_iter (hfun : ConvexOn ℝ Set.univ f) (hg : ∀ (k : ℕ+), γ k = 2 / (k + 1)) :
    ∀ (k : ℕ+), f (alg.x k) - f xm - (1 - γ k) * (f (alg.x (k - 1)) - f xm) ≤
      alg.l * (γ k) ^ 2 / 2 * (‖alg.v (k - 1) - xm‖ ^ 2 - ‖alg.v k - xm‖ ^ 2) := by
  -- h2: gradient-step inequality. For any `x'`,
  --   f (x k) - f x' ≤ l ⟪x k - y k, x' - x k⟫ + l / 2 ‖x k - y k‖².
  -- It combines three facts:
  --   * the upper bound `lipschitz_continuos_upper_bound'` taken at `y k`;
  --   * the first-order convexity bound `Convex_first_order_condition'`;
  --   * the identity f' (y k) = l • (y k - x k), obtained by rearranging `update2`.
  have h2 : ∀ (k : ℕ+), ∀ x' : E , f (alg.x k) - f x' ≤ alg.l *
      inner (alg.x k - alg.y k) (x' - alg.x k) + alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by
    intro k x'
    rw [sub_le_iff_le_add', ← add_assoc]
    have : (f' (alg.y k)) = alg.l.1 • (alg.y k - alg.x k) := by
      have update2 : ∀ (k : ℕ+), alg.x k = alg.y k - (1 / alg.l.1) • (f' (alg.y k)) := alg.update2
      specialize update2 k
      have : alg.l > 0 := alg.hl
      rw [eq_sub_iff_add_eq', ← eq_sub_iff_add_eq] at update2
      rw [← update2, smul_smul]
      field_simp
    have t1 : f (alg.y k) + inner (f' (alg.y k)) (x' - alg.y k) ≤ f x' := by
      exact Convex_first_order_condition' (alg.diff (alg.y k)) hfun (by trivial) x' (by trivial)
    calc
      _ ≤ f (alg.y k) + inner (f' (alg.y k)) (alg.x k - alg.y k) +
          alg.l.1 / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by
        exact lipschitz_continuos_upper_bound' alg.diff alg.smooth (alg.y k) (alg.x k)
      _ = f (alg.y k) + inner (f' (alg.y k)) (x' - alg.y k + (alg.x k - x')) +
          alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by
        rw [add_comm (x' - alg.y k), add_sub (alg.x k - x'), sub_add, sub_self, sub_zero]; simp
      _ = f (alg.y k) + inner (f' (alg.y k)) (x' - alg.y k) + inner (f' (alg.y k)) (alg.x k - x')
          + alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by rw [inner_add_right, ← add_assoc]
      _ ≤ f x' + inner (f' (alg.y k)) (alg.x k - x') + alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by
        rw [add_le_add_iff_right, add_le_add_iff_right]; exact t1
      _ = f x' + inner (alg.l.1 • (alg.y k - alg.x k)) (alg.x k - x') +
          alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by rw [this]
      _ = f x' + alg.l * inner (alg.x k - alg.y k) (x' - alg.x k) +
          alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by
        rw [real_inner_smul_left, ← inner_neg_neg, neg_sub, neg_sub]; simp
  -- h3: average h2 at `x' = xm` (weight γ k) and at `x' = x (k - 1)` (weight 1 - γ k).
  -- This needs 0 ≤ γ k ≤ 1, which follows from `hg`.
  have h3 : ∀ (k : ℕ+), f (alg.x k) - f xm - (1 - γ k) * (f (alg.x (k - 1)) - f xm) ≤
      alg.l * (inner (alg.x k - alg.y k) ((1 - γ k) • (alg.x (k - 1)) + ((γ k) • xm) -
      alg.x k)) + alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by
    intro k
    have : f (alg.x k) - f xm - (1 - γ k) * (f (alg.x (k - 1)) - f xm) = γ k *
        (f (alg.x k) - f xm) + (1 - γ k) * (f (alg.x k) - f (alg.x (k - 1))) := by ring_nf
    rw [this]
    have lzero: 0 < k + (1 : ℝ) := Nat.cast_add_one_pos k
    have hz : γ k ≥ (0 : ℝ) := by
      rw [hg k]
      apply div_nonneg (by norm_num); linarith
    have hl : γ k ≤ (1 : ℝ) := by
      rw [hg k, div_le_iff lzero, one_mul, ← sub_le_iff_le_add]
      ring_nf; simp
      exact Nat.succ_le_of_lt k.2
    -- recombine the two weighted directions into a single inner-product argument
    have : γ k • (xm - alg.x k) + (1 - γ k) • (alg.x (k - 1) - alg.x k)
        = (1 - γ k) • (alg.x (k - 1)) + ((γ k) • xm)- alg.x k := by
      rw [smul_sub, smul_sub, add_sub, ← add_sub_right_comm, sub_sub, ← add_smul]
      ring_nf; rw [one_smul, add_comm]
    calc
      _ ≤ γ k * (alg.l * (inner (alg.x k - alg.y k) (xm - alg.x k)) + alg.l / 2 *
          ‖alg.x k - alg.y k‖ ^ 2) + (1 - γ k) * (alg.l * (inner (alg.x k - alg.y k)
          (alg.x (k - 1) - alg.x k)) + alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2) := by
        apply add_le_add
        · exact mul_le_mul_of_nonneg_left (h2 k xm) hz
        · exact mul_le_mul_of_nonneg_left (h2 k (alg.x (k - 1))) (by linarith)
      _ = alg.l * (γ k * (inner (alg.x k - alg.y k) (xm - alg.x k))) + alg.l * ((1 - γ k) *
          (inner (alg.x k - alg.y k) (alg.x (k - 1) - alg.x k))) +
          alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by ring_nf
      _ = alg.l * inner (alg.x k - alg.y k) (γ k • (xm - alg.x k)) + alg.l *
          (inner (alg.x k - alg.y k) ((1 - γ k) •
          (alg.x (k - 1) - alg.x k))) + alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by
        rw [← inner_smul_right _ _ (γ k), ← inner_smul_right _ _ (1 - γ k)]
      _ = alg.l * inner (alg.x k - alg.y k) (γ k • (xm - alg.x k) + (1 - γ k) •
          (alg.x (k - 1) - alg.x k)) + alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by
        rw [← mul_add, ← inner_add_right (alg.x k - alg.y k)]
      _ = alg.l * inner (alg.x k - alg.y k) ((1 - γ k) • (alg.x (k - 1)) +
          ((γ k) • xm)- alg.x k) + alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by rw [this]
  -- Main argument: rewrite the right-hand side of the goal (via this1, then this2)
  -- until it is exactly the right-hand side of h3.
  intro k
  have lzero: 0 < k + (1 : ℝ) := Nat.cast_add_one_pos k
  have hz : γ k ≥ (0 : ℝ) := by
    rw [hg k]
    apply div_nonneg (by norm_num); linarith
  have hzs : γ k > (0 : ℝ) := by
    rw [hg k]
    apply div_pos (by norm_num); linarith
  -- this1: by `update1` and `update3`,
  --   γ k • (v (k - 1) - xm) = y k - (1 - γ k) • x (k - 1) - γ k • xm, and
  --   γ k • (v k - xm)       = x k - (1 - γ k) • x (k - 1) - γ k • xm.
  -- The second identity uses γ k ≠ 0, which comes from hzs.
  have this1 : alg.l * (γ k) ^ 2 / 2 * (‖alg.v (k-1) - xm‖ ^ 2 - ‖alg.v k - xm‖ ^ 2) =
      alg.l / 2 * (‖alg.y k - (1 - γ k) • (alg.x (k - 1)) - γ k • xm‖ ^ 2 -
      ‖alg.x k - (1 - γ k) • alg.x (k - 1) - γ k • xm‖ ^ 2) := by
    calc
      _ = alg.l / 2 * ((γ k) ^ 2 * ‖alg.v (k-1) - xm‖ ^ 2 -
          (γ k) ^ 2 * ‖alg.v k - xm‖ ^ 2) := by ring_nf
      _ = alg.l / 2 * ((‖γ k‖ * ‖(alg.v (k-1) - xm)‖) ^ 2 -
          (‖γ k‖ * ‖(alg.v k - xm)‖) ^ 2) := by
        simp; rw [abs_of_nonneg hz]; ring_nf; left; simp
      _ = alg.l / 2 * (‖γ k • (alg.v (k-1) - xm)‖ ^ 2 - ‖γ k • (alg.v k - xm)‖ ^ 2) := by
        rw [norm_smul, norm_smul]
      _ = alg.l / 2 * (‖alg.y k - (1 - γ k) • (alg.x (k - 1)) - γ k • xm‖ ^ 2 -
          ‖γ k • (alg.x (k - 1) + (1 / (γ k)) • (alg.x k - alg.x (k - 1)))
          - γ k • xm‖ ^ 2) := by
        have update1 : ∀ (k : ℕ+), alg.y k = (1 - γ k) •
            alg.x (k - 1) + γ k • alg.v (k - 1) := alg.update1
        specialize update1 k
        rw [← neg_add_eq_iff_eq_add, neg_add_eq_sub] at update1
        rw [smul_sub, smul_sub, update1, alg.update3 k]
      _ = alg.l / 2 * (‖alg.y k - (1 - γ k) • (alg.x (k - 1)) - γ k • xm‖ ^ 2 -
          ‖alg.x k - (1 - γ k) • alg.x (k - 1) - γ k • xm‖ ^ 2) := by
        rw [smul_add, smul_smul]; simp
        left; rw [mul_inv_cancel (by linarith), one_smul, sub_smul, one_smul, add_comm, sub_add]
  -- this2: expand the squared norms with `norm_sub_sq_real`.
  -- With z = (1 - γ k) • x (k - 1) + γ k • xm:
  --   l / 2 (‖y k - z‖² - ‖x k - z‖²) = l ⟪x k - y k, z - x k⟫ + l / 2 ‖x k - y k‖².
  have this2 : alg.l / 2 * (‖alg.y k - (1 - γ k) • (alg.x (k - 1)) - γ k • xm‖ ^ 2 -
      ‖alg.x k - (1 - γ k) • alg.x (k - 1) - γ k • xm‖ ^ 2) = alg.l *
      (inner (alg.x k - alg.y k) ((1 - γ k) • (alg.x (k - 1)) + ((γ k) • xm)- alg.x k))
      + alg.l / 2 * ‖alg.x k - alg.y k‖ ^ 2 := by
    rw [sub_sub, sub_sub, norm_sub_sq_real, norm_sub_sq_real, norm_sub_sq_real]
    calc
      _ = alg.l / 2 * (‖alg.y k‖ ^ 2 - ‖alg.x k‖ ^ 2) + alg.l / 2 * 2 * (inner (alg.x k)
          ((1 - γ k) • alg.x (↑k - 1) + γ k • xm) - inner (alg.y k)
          ((1 - γ k) • alg.x (↑k - 1) + γ k • xm)) := by ring_nf
      _ = alg.l / 2 * (‖alg.y k‖ ^ 2 - ‖alg.x k‖ ^ 2) + alg.l * inner (alg.x k - alg.y k)
          ((1 - γ k) • alg.x (↑k - 1) + γ k • xm) := by rw [← inner_sub_left]; ring_nf
      _ = alg.l / 2 * (‖alg.y k‖ ^ 2 - ‖alg.x k‖ ^ 2) + alg.l * inner (alg.x k - alg.y k)
          ((1 - γ k) • (alg.x (k - 1)) + ((γ k) • xm) - alg.x k + alg.x k) := by
        rw [sub_add, sub_self, sub_zero]
      _ = alg.l / 2 * (‖alg.y k‖ ^ 2 - ‖alg.x k‖ ^ 2) + alg.l * inner (alg.x k - alg.y k)
          (alg.x k) + alg.l * (inner (alg.x k - alg.y k) ((1 - γ k) • (alg.x (k - 1))
          + ((γ k) • xm) - alg.x k)) := by
        rw [inner_add_right, mul_add]; ring_nf
      _ = alg.l / 2 * (‖alg.y k‖ ^ 2 - ‖alg.x k‖ ^ 2) + alg.l * ‖alg.x k‖ ^ 2 -
          alg.l * inner (alg.x k) (alg.y k) + alg.l * (inner (alg.x k - alg.y k) ((1 - γ k)
          • (alg.x (k - 1)) + ((γ k) • xm) - alg.x k)) := by
        rw [inner_sub_left, mul_sub, mul_sub, real_inner_self_eq_norm_sq]
        rw [real_inner_comm, add_sub];
      _ = alg.l * (inner (alg.x k - alg.y k) ((1 - γ k) • (alg.x (k - 1)) + ((γ k) • xm)
          - alg.x k)) + alg.l / 2 * (‖alg.x k‖ ^ 2 - 2 *
          inner (alg.x k) (alg.y k) + ‖alg.y k‖ ^ 2) := by ring_nf
  rw [this1, this2]
  exact h3 k
/-- Convergence rate of the scheme packaged in `Nesterov` with weights `γ k = 2 / (k + 1)`.

Assume `f` is convex on `Set.univ` and `xm` minimizes `f` on `Set.univ`. Then for every `k : ℕ+`,
`f (alg.x k) - f xm ≤ 2 * alg.l / ((k + 1) ^ 2) * ‖x₀ - xm‖ ^ 2`, i.e. an `O(1 / k ^ 2)` rate.
Here `x₀` is the starting point, so `alg.v 0 = x₀` by `initial2`.

The hypothesis `con` is assumed explicitly.
NOTE(review): `γ (k - 1)` in `con` uses `ℕ+` subtraction, which truncates to `1`; at `k = 1`
it therefore compares against `γ 1`. Under `hg` the inequality appears to follow from `hg`
itself: it reduces to `(k ^ 2 - 1) / 4 ≤ k ^ 2 / 4` for `k ≥ 2`, and to `0 ≤ 1` at `k = 1`.
Consider deriving it instead of assuming it. -/
theorem nesterov_algorithm_smooth (hfun: ConvexOn ℝ Set.univ f)
    (hg : ∀ (k : ℕ+), γ k = 2 / (k + 1)) (min : IsMinOn f Set.univ xm)
    (con : ∀ k : ℕ+ , (1 - γ k) / (γ k) ^ 2 ≤ 1 / (γ (k - 1)) ^ 2):
    ∀ k : ℕ+, f (alg.x k) - f xm ≤ 2 * alg.l / ((k + 1) ^ 2) * ‖x₀ - xm‖ ^ 2 := by
  -- h4: the one-step estimate `one_iter`.
  have h4 : ∀ (k : ℕ+), f (alg.x k) - f xm - (1 - γ k) * (f (alg.x (k - 1)) - f xm) ≤
      alg.l.1 * (γ k) ^ 2 / 2 * (‖alg.v (k-1) - xm‖ ^ 2 - ‖alg.v k -xm‖ ^ 2) := by
    exact one_iter hfun hg
  -- h5: the potential  1 / (γ k)² * (f (x k) - f xm) + l / 2 * ‖v k - xm‖²  is nonincreasing.
  -- Divide h4 by (γ k)² > 0. Then bound (1 - γ k) / (γ k)² by `con`; this multiplies
  -- f (x (k - 1)) - f xm, which is nonnegative because `xm` is a minimizer.
  have h5 : ∀ (k : ℕ+), 1 / (γ k) ^ 2 * (f (alg.x k) - f xm) + alg.l.1 / 2 * ‖alg.v k - xm‖ ^ 2
      ≤ 1 / (γ (k - 1)) ^ 2 * (f (alg.x (k - 1)) - f xm) +
      alg.l.1 / 2 * ‖alg.v (k - 1) - xm‖ ^ 2 := by
    intro k
    specialize h4 k
    specialize con k
    have : (γ k) ^ 2 > 0 := by
      rw [hg k]
      simp only [Real.rpow_two, div_pow, gt_iff_lt]
      apply div_pos (by linarith)
      apply sq_pos_of_ne_zero
      exact Nat.cast_add_one_ne_zero ↑k
    rw [← div_le_div_right this, sub_div, mul_div_right_comm (1 - γ k)] at h4
    rw [← one_mul (f (alg.x k) - f xm), mul_div_right_comm 1] at h4
    rw [mul_div_right_comm (alg.l).1, mul_assoc, mul_comm (γ k ^ 2)] at h4
    rw [← mul_assoc, mul_div_assoc] at h4
    rw [div_self (by linarith), mul_one, mul_sub (alg.l.1 / 2)] at h4
    rw [tsub_le_iff_left, add_sub, le_sub_iff_add_le] at h4
    apply le_trans h4
    simp only [Real.rpow_two, ge_iff_le, add_le_add_iff_right, gt_iff_lt, sub_pos, sub_neg]
    have : f xm ≤ f (alg.x (k - 1)):= min (by trivial)
    apply mul_le_mul_of_nonneg_right _ (by linarith)
    exact con
  -- h6: by strong induction on k, the potential at k is at most the potential at 1.
  have h6 : ∀ (k : ℕ+), 1 / (γ k) ^ 2 * (f (alg.x k) - f xm) + alg.l / 2 * ‖alg.v k - xm‖ ^ 2
      ≤ 1 / (γ 1) ^ 2 * (f (alg.x 1) - f xm) + alg.l / 2 * ‖alg.v 1 - xm‖ ^ 2 := by
    intro k
    induction' k using PNat.caseStrongInductionOn with j IH
    · simp
    · specialize IH j (le_refl _)
      specialize h5 (j + 1)
      have y1: ↑(j + 1) - (1 : ℕ) = j := by simp
      have y2: j + 1 - 1 = j := by exact Iff.mp PNat.natPred_inj rfl
      apply le_trans h5 _
      rw [y1, y2]
      exact IH
  -- h7: h4 at k = 1, rearranged, bounds the potential at 1.
  -- The `x 0` term keeps a factor 1 - γ 1, which is 0 by `initial1`.
  have h7 : 1 / (γ 1) ^ 2 * (f (alg.x 1) - f xm) + alg.l / 2 * ‖alg.v 1 - xm‖ ^ 2
      ≤ (1 - γ 1) / ((γ 1) ^ 2 ) * (f (alg.x 0) - f xm)
      + alg.l / 2 * ‖alg.v 0 - xm‖ ^ 2 := by
    specialize h4 1
    rw [alg.initial1, sub_self, zero_mul, sub_zero] at h4
    rw [alg.initial1, sub_self, zero_div, zero_mul, zero_add]
    simp
    simp only [PNat.one_coe, Real.rpow_two, one_pow, mul_one, le_refl, tsub_eq_zero_of_le] at h4
    rw [← le_sub_iff_add_le, ← mul_sub]
    exact h4
  -- h8: chain h6 and h7, using γ 1 = 1 and v 0 = x₀.
  -- This bounds the potential at every k by l / 2 * ‖x₀ - xm‖².
  have h8 : ∀ (k : ℕ+), 1 / (γ k) ^ 2 * (f (alg.x k) - f xm) + alg.l / 2
      * ‖alg.v k - xm‖ ^ 2 ≤ alg.l / 2 * ‖x₀ - xm‖ ^ 2 := by
    rw [alg.initial1] at h6
    rw [alg.initial1, sub_self, zero_div, zero_mul, zero_add, alg.initial2] at h7
    intro k
    apply le_trans (h6 k) h7
  -- Final step:
  --   * h9 drops the nonnegative l / 2 * ‖v k - xm‖² term from h8;
  --   * substitute γ k = 2 / (k + 1) and divide by 1 / (γ k)² > 0;
  --   * h10 is the arithmetic identity that puts the bound in the stated form.
  intro k
  specialize h8 k
  have h9 : 1 / (γ k) ^ 2 * (f (alg.x k) - f xm) ≤ alg.l / 2 * ‖x₀ - xm‖ ^ 2 := by
    apply le_of_add_le_of_nonneg_left h8 _
    have : alg.l > 0 := alg.hl
    apply mul_nonneg _ _
    · positivity
    · simp only [Real.rpow_two, sq_nonneg]
  have h10 : alg.l / (2 : ℝ) * ‖x₀ - xm‖ ^ 2 / ((1 :ℝ) / (2 / (k + 1)) ^ 2)
      = 2 * alg.l / ((k + 1) ^ 2) * ‖x₀ - xm‖ ^ 2 := by
    simp [Nat.cast_add_one_ne_zero ↑k]
    rw [← div_mul]
    ring_nf
  rw [hg k] at h9
  rw [← le_div_iff'] at h9
  · rw [h10] at h9
    exact h9
  -- side condition of `le_div_iff'`: 1 / (2 / (k + 1)) ^ 2 > 0
  · simp only [Real.rpow_two, div_pow, one_div, inv_div]
    apply div_pos
    · apply sq_pos_of_ne_zero
      exact Nat.cast_add_one_ne_zero ↑k
    · simp only [gt_iff_lt, zero_lt_two, pow_pos]