1212
1313def coordinate_grad_semi_dual (b , M , reg , beta , i ):
1414 '''
15- Compute the coordinate gradient update for regularized discrete
16- distributions for (i, :)
15+ Compute the coordinate gradient update for regularized discrete distributions for (i, :)
1716
1817 The function computes the gradient of the semi dual problem:
1918
2019 .. math::
21- \W_\v arepsilon(a, b) = \max_\v \sum_i (\sum_j v_j * b_j
22- - \r eg log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)) * a_i
20+ \max_v \sum_i (\sum_j v_j * b_j - reg * log(\sum_j exp((v_j - M_{i,j})/reg) * b_j)) * a_i
21+
22+ Where :
2323
24- where :
2524 - M is the (ns,nt) metric cost matrix
2625 - v is a dual variable in R^J
2726 - reg is the regularization term
@@ -34,15 +33,15 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
3433 Parameters
3534 ----------
3635
37- b : np.ndarray(nt,),
36+ b : np.ndarray(nt,)
3837 target measure
39- M : np.ndarray(ns, nt),
38+ M : np.ndarray(ns, nt)
4039 cost matrix
41- reg : float nu,
40+ reg : float nu
4241 Regularization term > 0
43- v : np.ndarray(nt,),
44- optimization vector
45- i : number int,
42+ v : np.ndarray(nt,)
43+ dual variable
44+ i : number int
4645 picked number i
4746
4847 Returns
@@ -93,14 +92,19 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
9392
9493 .. math::
9594 \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
95+
9696 s.t. \gamma 1 = a
97- \gamma^T 1= b
97+
98+ \gamma^T 1 = b
99+
98100 \gamma \geq 0
99- where :
101+
102+ Where :
103+
100104 - M is the (ns,nt) metric cost matrix
101- - :math:`\Omega` is the entropic regularization term
102- :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
105+ - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
103106 - a and b are source and target weights (sum to 1)
107+
104108 The algorithm used for solving the problem is the SAG algorithm
105109 as proposed in [18]_ [alg.1]
106110
@@ -173,33 +177,37 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
173177
174178def averaged_sgd_entropic_transport (a , b , M , reg , numItermax = 300000 , lr = None ):
175179 '''
176- Compute the ASGD algorithm to solve the regularized semi contibous measures
177- optimal transport max problem
180+ Compute the ASGD algorithm to solve the regularized semi continous measures optimal transport max problem
178181
179182 The function solves the following optimization problem:
180183
181184 .. math::
182185 \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
186+
183187 s.t. \gamma 1 = a
188+
184189 \gamma^T 1= b
190+
185191 \gamma \geq 0
186- where :
192+
193+ Where :
194+
187195 - M is the (ns,nt) metric cost matrix
188- - :math:`\Omega` is the entropic regularization term
189- :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
196+ - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
190197 - a and b are source and target weights (sum to 1)
198+
191199 The algorithm used for solving the problem is the ASGD algorithm
192200 as proposed in [18]_ [alg.2]
193201
194202
195203 Parameters
196204 ----------
197205
198- b : np.ndarray(nt,),
206+ b : np.ndarray(nt,)
199207 target measure
200- M : np.ndarray(ns, nt),
208+ M : np.ndarray(ns, nt)
201209 cost matrix
202- reg : float number,
210+ reg : float number
203211 Regularization term > 0
204212 numItermax : int number
205213 number of iteration
@@ -211,7 +219,7 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
211219 -------
212220
213221 ave_v : np.ndarray(nt,)
214- optimization vector
222+ dual variable
215223
216224 Examples
217225 --------
@@ -265,7 +273,8 @@ def c_transform_entropic(b, M, reg, beta):
265273 .. math::
266274 u = v^{c,reg} = -reg \sum_j exp((v - M)/reg) b_j
267275
268- where :
276+ Where :
277+
269278 - M is the (ns,nt) metric cost matrix
270279 - u, v are dual variables in R^IxR^J
271280 - reg is the regularization term
@@ -290,6 +299,7 @@ def c_transform_entropic(b, M, reg, beta):
290299 -------
291300
292301 u : np.ndarray(ns,)
302+ dual variable
293303
294304 Examples
295305 --------
@@ -341,10 +351,11 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
341351 s.t. \gamma 1 = a
342352 \gamma^T 1= b
343353 \gamma \geq 0
344- where :
354+
355+ Where :
356+
345357 - M is the (ns,nt) metric cost matrix
346- - :math:`\Omega` is the entropic regularization term
347- :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
358+ - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
348359 - a and b are source and target weights (sum to 1)
349360 The algorithm used for solving the problem is the SAG or ASGD algorithms
350361 as proposed in [18]_
@@ -353,15 +364,15 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
353364 Parameters
354365 ----------
355366
356- a : np.ndarray(ns,),
367+ a : np.ndarray(ns,)
357368 source measure
358- b : np.ndarray(nt,),
369+ b : np.ndarray(nt,)
359370 target measure
360- M : np.ndarray(ns, nt),
371+ M : np.ndarray(ns, nt)
361372 cost matrix
362- reg : float number,
373+ reg : float number
363374 Regularization term > 0
364- methode : str,
375+ methode : str
365376 used method (SAG or ASGD)
366377 numItermax : int number
367378 number of iteration
@@ -438,40 +449,40 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
438449def batch_grad_dual (a , b , M , reg , alpha , beta , batch_size , batch_alpha ,
439450 batch_beta ):
440451 '''
441- Computes the partial gradient of F_\W_varepsilon
452+ Computes the partial gradient of the dual optimal transport problem.
453+
454+ For each (i,j) in a batch of coordinates, the partial gradients are :
442455
443- Compute the partial gradient of the dual problem:
456+ .. math::
457+ \partial_{u_i} F = u_i * b_s/l_{v} - \sum_{j \in B_v} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j
444458
445- ..math:
446- \f orall i in batch_alpha,
447- grad_alpha_i = alpha_i * batch_size/len(beta) -
448- sum_{j in batch_beta} exp((alpha_i + beta_j - M_{i,j})/reg)
449- * a_i * b_j
459+ \partial_{v_j} F = v_j * b_s/l_{u} - \sum_{i \in B_u} exp((u_i + v_j - M_{i,j})/reg) * a_i * b_j
460+
461+ Where :
450462
451- \f orall j in batch_alpha,
452- grad_beta_j = beta_j * batch_size/len(alpha) -
453- sum_{i in batch_alpha} exp((alpha_i + beta_j - M_{i,j})/reg)
454- * a_i * b_j
455- where :
456463 - M is the (ns,nt) metric cost matrix
457- - alpha, beta are dual variables in R^ixR^J
464+ - u, v are dual variables in R^ixR^J
458465 - reg is the regularization term
459- - batch_alpha and batch_beta are lists of index
466+ - :math:`B_u` and :math:`B_v` are lists of index
467+ - :math:`b_s` is the size of the batchs :math:`B_u` and :math:`B_v`
468+ - :math:`l_u` and :math:`l_v` are the lenghts of :math:`B_u` and :math:`B_v`
460469 - a and b are source and target weights (sum to 1)
461470
462471
463472 The algorithm used for solving the dual problem is the SGD algorithm
464473 as proposed in [19]_ [alg.1]
465474
475+
466476 Parameters
467477 ----------
468- a : np.ndarray(ns,),
478+
479+ a : np.ndarray(ns,)
469480 source measure
470- b : np.ndarray(nt,),
481+ b : np.ndarray(nt,)
471482 target measure
472- M : np.ndarray(ns, nt),
483+ M : np.ndarray(ns, nt)
473484 cost matrix
474- reg : float number,
485+ reg : float number
475486 Regularization term > 0
476487 alpha : np.ndarray(ns,)
477488 dual variable
@@ -542,24 +553,29 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
542553
543554 .. math::
544555 \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
556+
545557 s.t. \gamma 1 = a
558+
546559 \gamma^T 1= b
560+
547561 \gamma \geq 0
548- where :
562+
563+ Where :
564+
549565 - M is the (ns,nt) metric cost matrix
550- - :math:`\Omega` is the entropic regularization term
551- :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
566+ - :math:`\Omega` is the entropic regularization term with :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
552567 - a and b are source and target weights (sum to 1)
553568
554569 Parameters
555570 ----------
556- a : np.ndarray(ns,),
571+
572+ a : np.ndarray(ns,)
557573 source measure
558- b : np.ndarray(nt,),
574+ b : np.ndarray(nt,)
559575 target measure
560- M : np.ndarray(ns, nt),
576+ M : np.ndarray(ns, nt)
561577 cost matrix
562- reg : float number,
578+ reg : float number
563579 Regularization term > 0
564580 batch_size : int number
565581 size of the batch
@@ -633,25 +649,29 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
633649
634650 .. math::
635651 \gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)
652+
636653 s.t. \gamma 1 = a
654+
637655 \gamma^T 1= b
656+
638657 \gamma \geq 0
639- where :
658+
659+ Where :
660+
640661 - M is the (ns,nt) metric cost matrix
641- - :math:`\Omega` is the entropic regularization term
642- :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
662+ - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
643663 - a and b are source and target weights (sum to 1)
644664
645665 Parameters
646666 ----------
647667
648- a : np.ndarray(ns,),
668+ a : np.ndarray(ns,)
649669 source measure
650- b : np.ndarray(nt,),
670+ b : np.ndarray(nt,)
651671 target measure
652- M : np.ndarray(ns, nt),
672+ M : np.ndarray(ns, nt)
653673 cost matrix
654- reg : float number,
674+ reg : float number
655675 Regularization term > 0
656676 batch_size : int number
657677 size of the batch
0 commit comments