 
 This module contains class implementations of various optimisation algorithms.
 
-:Author: Samuel Farrens <samuel.farrens@cea.fr>, Zaccharie Ramzi <zaccharie.ramzi@cea.fr>
+:Author: Samuel Farrens <samuel.farrens@cea.fr>,
+    Zaccharie Ramzi <zaccharie.ramzi@cea.fr>
 
 NOTES
 -----
@@ -260,58 +261,43 @@ class FISTA(object):
         None,  # no restarting
     ]
 
-    def __init__(
-        self,
-        restart_strategy=None,
-        min_beta=None,
-        s_greedy=None,
-        xi_restart=None,
-        a_cd=None,
-        p_lazy=1,
-        q_lazy=1,
-        r_lazy=4,
-    ):
+    def __init__(self, restart_strategy=None, min_beta=None, s_greedy=None,
+                 xi_restart=None, a_cd=None, p_lazy=1, q_lazy=1, r_lazy=4):
+
         if isinstance(a_cd, type(None)):
             self.mode = 'regular'
             self.p_lazy = p_lazy
             self.q_lazy = q_lazy
             self.r_lazy = r_lazy
+
         elif a_cd > 2:
             self.mode = 'CD'
             self.a_cd = a_cd
             self._n = 0
+
         else:
-            raise ValueError(
-                "a_cd must either be None (for regular mode) or a number > 2",
-            )
+            raise ValueError('a_cd must either be None (for regular mode) or '
+                             'a number > 2')
+
         if restart_strategy in self.__class__.__restarting_strategies__:
-            self._check_restart_params(
-                restart_strategy,
-                min_beta,
-                s_greedy,
-                xi_restart,
-            )
+            self._check_restart_params(restart_strategy, min_beta, s_greedy,
+                                       xi_restart)
             self.restart_strategy = restart_strategy
             self.min_beta = min_beta
             self.s_greedy = s_greedy
             self.xi_restart = xi_restart
+
         else:
-            raise ValueError(
-                "Restarting strategy must be one of %s." %
-                ", ".join(self.__class__.__restarting_strategies__)
-            )
+            raise ValueError('Restarting strategy must be one of {}.'.format(
+                ', '.join(
+                    self.__class__.__restarting_strategies__)))
         self._t_now = 1.0
         self._t_prev = 1.0
         self._delta_0 = None
         self._safeguard = False
 
-    def _check_restart_params(
-        self,
-        restart_strategy,
-        min_beta,
-        s_greedy,
-        xi_restart,
-    ):
+    def _check_restart_params(self, restart_strategy, min_beta, s_greedy,
+                              xi_restart):
         r""" Check restarting parameters
 
         This method checks that the restarting parameters are set and satisfy
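The condensed signature keeps the same keyword defaults, so existing call sites are unaffected. As a rough usage sketch (the `modopt.opt.algorithms` import path is assumed from the repository layout, and all parameter values below are illustrative only):

```python
from modopt.opt.algorithms import FISTA

# Regular (lazy) mode: p_lazy, q_lazy and r_lazy control the t-sequence.
fista_plain = FISTA()

# Greedy restarting needs min_beta, an s_greedy > 1 and a xi_restart < 1,
# otherwise _check_restart_params raises a ValueError.
fista_greedy = FISTA(
    restart_strategy='greedy',
    min_beta=1e-4,
    s_greedy=1.1,
    xi_restart=0.96,
)
```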
@@ -346,23 +332,24 @@ def _check_restart_params(
             When a parameter that should be set isn't or doesn't verify the
             correct assumptions.
         """
+
         if restart_strategy is None:
             return True
+
         if self.mode != 'regular':
-            raise ValueError(
-                "Restarting strategies can only be used with regular mode."
-            )
-        greedy_params_check = (
-            min_beta is None or s_greedy is None or s_greedy <= 1
-        )
+            raise ValueError('Restarting strategies can only be used with '
+                             'regular mode.')
+
+        greedy_params_check = (min_beta is None or s_greedy is None or
+                               s_greedy <= 1)
+
         if restart_strategy == 'greedy' and greedy_params_check:
-            raise ValueError(
-                "You need a min_beta and an s_greedy > 1 for greedy restart."
-            )
+            raise ValueError('You need a min_beta and an s_greedy > 1 for '
+                             'greedy restart.')
+
         if xi_restart is None or xi_restart >= 1:
-            raise ValueError(
-                "You need a xi_restart < 1 for restart."
-            )
+            raise ValueError('You need a xi_restart < 1 for restart.')
+
         return True
 
     def is_restart(self, z_old, x_new, x_old):
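These guards make the restarting strategies fail fast on inconsistent parameters. A hedged illustration of both outcomes (values are arbitrary):

```python
from modopt.opt.algorithms import FISTA

# Valid: adaptive restarting only needs a xi_restart < 1.
FISTA(restart_strategy='adaptive', xi_restart=0.96)

# Invalid: greedy restarting also needs min_beta and s_greedy > 1.
try:
    FISTA(restart_strategy='greedy', xi_restart=0.96)
except ValueError as err:
    print(err)  # You need a min_beta and an s_greedy > 1 for greedy restart.
```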
@@ -393,18 +380,22 @@ def is_restart(self, z_old, x_new, x_old):
393380 """
394381 if self .restart_strategy is None :
395382 return False
383+
396384 criterion = np .vdot (z_old - x_new , x_new - x_old ) >= 0
385+
397386 if criterion :
398387 if 'adaptive' in self .restart_strategy :
399388 self .r_lazy *= self .xi_restart
400389 if self .restart_strategy in ['adaptive-ii' , 'adaptive-2' ]:
401390 self ._t_now = 1
391+
402392 if self .restart_strategy == 'greedy' :
403393 cur_delta = np .linalg .norm (x_new - x_old )
404394 if self ._delta_0 is None :
405395 self ._delta_0 = self .s_greedy * cur_delta
406396 else :
407397 self ._safeguard = cur_delta >= self ._delta_0
398+
408399 return criterion
409400
410401 def update_beta (self , beta ):
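`is_restart` implements the gradient-free restart test ⟨z_old − x_new, x_new − x_old⟩ ≥ 0, with strategy-specific side effects (shrinking `r_lazy` for the adaptive variants, arming the safeguard for greedy). A small numeric sketch with made-up iterates:

```python
import numpy as np

from modopt.opt.algorithms import FISTA

fista = FISTA(restart_strategy='adaptive', xi_restart=0.96)

# The step direction reverses between iterates, so the inner product
# (z_old - x_new) . (x_new - x_old) = 0.02 >= 0 and a restart is flagged.
z_old = np.array([0.3, 0.3])
x_old = np.array([0.5, 0.5])
x_new = np.array([0.4, 0.4])

print(fista.is_restart(z_old, x_new, x_old))  # True
```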
@@ -422,9 +413,11 @@ def update_beta(self, beta):
         -------
         float: the new value for the beta parameter
         """
+
         if self._safeguard:
             beta *= self.xi_restart
             beta = max(beta, self.min_beta)
+
         return beta
 
     def update_lambda(self, *args, **kwargs):
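When the greedy safeguard has been triggered, `update_beta` decays the step size geometrically by `xi_restart` but never below `min_beta`. A sketch (setting `_safeguard` by hand, which `is_restart` would normally do):

```python
from modopt.opt.algorithms import FISTA

fista = FISTA(restart_strategy='greedy', min_beta=1e-4,
              s_greedy=1.1, xi_restart=0.96)
fista._safeguard = True  # normally set inside is_restart()

beta = 1.0
for _ in range(3):
    beta = fista.update_beta(beta)
print(beta)  # 1.0 * 0.96 ** 3 ~= 0.885, still above min_beta
```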
@@ -441,12 +434,17 @@ def update_lambda(self, *args, **kwargs):
         Implements steps 3 and 4 from algorithm 10.7 in [B2011]_
 
         """
+
         if self.restart_strategy == 'greedy':
             return 2
+
         # Steps 3 and 4 from alg.10.7.
         self._t_prev = self._t_now
+
         if self.mode == 'regular':
-            self._t_now = (self.p_lazy + np.sqrt(self.r_lazy * self._t_prev ** 2 + self.q_lazy)) * 0.5
+            self._t_now = (self.p_lazy + np.sqrt(self.r_lazy *
+                           self._t_prev ** 2 + self.q_lazy)) * 0.5
+
         elif self.mode == 'CD':
             self._t_now = (self._n + self.a_cd - 1) / self.a_cd
             self._n += 1
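With the defaults `p_lazy = q_lazy = 1` and `r_lazy = 4`, the regular-mode update above reduces to the classic FISTA sequence t_k = (1 + sqrt(1 + 4 t_{k-1}^2)) / 2. A standalone check of that recurrence:

```python
import numpy as np

# Default lazy parameters recover the classic FISTA t-sequence.
p_lazy, q_lazy, r_lazy = 1, 1, 4
t_now = 1.0

for _ in range(4):
    t_prev = t_now
    t_now = (p_lazy + np.sqrt(r_lazy * t_prev ** 2 + q_lazy)) * 0.5
    print(round(t_now, 4))  # 1.618, 2.1935, 2.7498, 3.2949
```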
@@ -538,7 +536,7 @@ def __init__(self, x, grad, prox, cost='auto', beta_param=1.0,
         else:
             self._check_param_update(lambda_update)
             self._lambda_update = lambda_update
-            self._is_restart = lambda *args, **kwargs:False
+            self._is_restart = lambda *args, **kwargs: False
 
         # Automatically run the algorithm
         if auto_iterate:
@@ -688,8 +686,8 @@ def __init__(self, x, grad, prox_list, cost='auto', gamma_param=1.0,
         self._x_old = np.copy(x)
 
         # Set the algorithm operators
-        (self._check_operator(operator) for operator in [grad, cost]
-         + prox_list)
+        (self._check_operator(operator) for operator in [grad, cost] +
+         prox_list)
         self._grad = grad
         self._prox_list = np.array(prox_list)
         self._linear = linear
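One caveat worth flagging (an editorial observation, not part of the change): both the old and new forms wrap the checks in a bare generator expression, which Python never consumes, so `_check_operator` is never actually called here. A minimal standalone demonstration of the pitfall:

```python
# A generator expression is lazy: its body runs only when consumed.
checks = (print(item) for item in [1, 2, 3])  # prints nothing

list(checks)  # now prints 1, 2, 3
```

A plain `for` loop over `[grad, cost] + prox_list` would run the checks eagerly. The same applies to the `_check_input_data` and `_check_param` generator expressions further down.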
@@ -910,7 +908,7 @@ class Condat(SetUp):
910908 """
911909
912910 def __init__ (self , x , y , grad , prox , prox_dual , linear = None , cost = 'auto' ,
913- reweight = None , rho = 0.5 , sigma = 1.0 , tau = 1.0 , rho_update = None ,
911+ reweight = None , rho = 0.5 , sigma = 1.0 , tau = 1.0 , rho_update = None ,
914912 sigma_update = None , tau_update = None , auto_iterate = True ,
915913 max_iter = 150 , n_rewightings = 1 , metric_call_period = 5 ,
916914 metrics = {}):
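Apart from whitespace, the Condat signature is unchanged: a primal variable `x`, a dual variable `y`, and separate primal and dual proximity operators. A rough wiring sketch (the operator imports and signatures below are assumptions about ModOpt's API; verify them against the installed version):

```python
import numpy as np

from modopt.opt.algorithms import Condat
from modopt.opt.gradient import GradBasic
from modopt.opt.linear import Identity
from modopt.opt.proximity import Positivity

obs = np.abs(np.random.randn(16))
grad = GradBasic(obs, lambda x: x, lambda x: x)  # identity forward model

condat = Condat(
    x=np.zeros(16), y=np.zeros(16),
    grad=grad, prox=Positivity(), prox_dual=Positivity(),
    linear=Identity(), rho=0.5, sigma=1.0, tau=1.0,
    auto_iterate=False,
)
condat.iterate(max_iter=50)
```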
@@ -1070,6 +1068,7 @@ def retrieve_outputs(self):
             metrics[obs.name] = obs.retrieve_metrics()
         self.metrics = metrics
 
+
 class POGM(SetUp):
     r"""Proximal Optimised Gradient Method
 
@@ -1103,28 +1102,13 @@ class POGM(SetUp):
         Option to automatically begin iterations upon initialisation (default
         is 'True')
     """
-    def __init__(
-        self,
-        u,
-        x,
-        y,
-        z,
-        grad,
-        prox,
-        cost='auto',
-        linear=None,
-        beta_param=1.0,
-        sigma_bar=1.0,
-        auto_iterate=True,
-        metric_call_period=5,
-        metrics={},
-    ):
+    def __init__(self, u, x, y, z, grad, prox, cost='auto', linear=None,
+                 beta_param=1.0, sigma_bar=1.0, auto_iterate=True,
+                 metric_call_period=5, metrics={}):
+
         # Set default algorithm properties
-        super(POGM, self).__init__(
-            metric_call_period=metric_call_period,
-            metrics=metrics,
-            linear=linear,
-        )
+        super(POGM, self).__init__(metric_call_period=metric_call_period,
+                                   metrics=metrics, linear=linear)
 
         # set the initial variable values
         (self._check_input_data(data) for data in (u, x, y, z))
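The flattened POGM signature takes four starting points (`u`, `x`, `y`, `z`), typically copies of the same initial guess. A hedged set-up sketch under the same ModOpt-API assumptions as the Condat example above:

```python
import numpy as np

from modopt.opt.algorithms import POGM
from modopt.opt.gradient import GradBasic
from modopt.opt.proximity import Positivity

x0 = np.zeros(16)
grad = GradBasic(np.random.randn(16), lambda x: x, lambda x: x)

pogm = POGM(
    u=x0.copy(), x=x0.copy(), y=x0.copy(), z=x0.copy(),
    grad=grad, prox=Positivity(),
    beta_param=1.0, sigma_bar=1.0,
    auto_iterate=False,
)
pogm.iterate(max_iter=100)
```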
@@ -1145,7 +1129,7 @@ def __init__(
 
         # Set the algorithm parameters
         (self._check_param(param) for param in (beta_param, sigma_bar))
-        if not (0 <= sigma_bar <= 1):
+        if not (0 <= sigma_bar <= 1):
             raise ValueError('The sigma bar parameter needs to be in [0, 1]')
         self._beta = beta_param
         self._sigma_bar = sigma_bar
@@ -1169,7 +1153,7 @@ def _update(self):
11691153 """
11701154 # Step 4 from alg. 3
11711155 self ._grad .get_grad (self ._x_old )
1172- self ._u_new = self ._x_old - self ._beta * self ._grad .grad
1156+ self ._u_new = self ._x_old - self ._beta * self ._grad .grad
11731157
11741158 # Step 5 from alg. 3
11751159 self ._t_new = 0.5 * (1 + np .sqrt (1 + 4 * self ._t_old ** 2 ))
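Steps 4 and 5 quoted above are a plain gradient step followed by the FISTA-style momentum scalar update. In isolation, with illustrative values:

```python
import numpy as np

beta = 1.0
x_old = np.array([1.0, -2.0])
grad_of_x_old = 2 * x_old  # e.g. the gradient of ||x||**2

u_new = x_old - beta * grad_of_x_old           # step 4: gradient step
t_new = 0.5 * (1 + np.sqrt(1 + 4 * 1.0 ** 2))  # step 5, with t_old = 1
print(u_new, t_new)  # [-1.  2.] 1.618...
```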
@@ -1218,7 +1202,6 @@ def _update(self):
         self.converge = self.any_convergence_flag() or \
             self._cost_func.get_cost(self._x_new)
 
-
     def iterate(self, max_iter=150):
         r"""Iterate
 