@@ -40,12 +40,12 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
4040
4141 Parameters
4242 ----------
43- a : np. ndarray (ns,)
43+ a : ndarray, shape (ns,)
4444 samples weights in the source domain
45- b : np. ndarray (nt,) or np. ndarray (nt,nbb)
45+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
4646 samples in the target domain, compute sinkhorn with multiple targets
4747 and fixed M if b is a matrix (return OT loss + dual variables in log)
48- M : np. ndarray (ns,nt)
48+ M : ndarray, shape (ns, nt)
4949 loss matrix
5050 reg : float
5151 Regularization term >0
@@ -64,7 +64,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
6464
6565 Returns
6666 -------
67- gamma : (ns x nt) ndarray
67+ gamma : ndarray, shape (ns, nt)
6868 Optimal transportation matrix for the given parameters
6969 log : dict
7070 log dictionary return only if log==True in parameters
@@ -155,12 +155,12 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
155155
156156 Parameters
157157 ----------
158- a : np. ndarray (ns,)
158+ a : ndarray, shape (ns,)
159159 samples weights in the source domain
160- b : np. ndarray (nt,) or np. ndarray (nt,nbb)
160+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
161161 samples in the target domain, compute sinkhorn with multiple targets
162162 and fixed M if b is a matrix (return OT loss + dual variables in log)
163- M : np. ndarray (ns,nt)
163+ M : ndarray, shape (ns, nt)
164164 loss matrix
165165 reg : float
166166 Regularization term >0
@@ -176,7 +176,6 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
176176 log : bool, optional
177177 record log if True
178178
179-
180179 Returns
181180 -------
182181 W : (nt) ndarray or float
@@ -272,12 +271,12 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
272271
273272 Parameters
274273 ----------
275- a : np. ndarray (ns,)
274+ a : ndarray, shape (ns,)
276275 samples weights in the source domain
277- b : np. ndarray (nt,) or np. ndarray (nt,nbb)
276+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
278277 samples in the target domain, compute sinkhorn with multiple targets
279278 and fixed M if b is a matrix (return OT loss + dual variables in log)
280- M : np. ndarray (ns,nt)
279+ M : ndarray, shape (ns, nt)
281280 loss matrix
282281 reg : float
283282 Regularization term >0
@@ -290,10 +289,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
290289 log : bool, optional
291290 record log if True
292291
293-
294292 Returns
295293 -------
296- gamma : (ns x nt) ndarray
294+ gamma : ndarray, shape (ns, nt)
297295 Optimal transportation matrix for the given parameters
298296 log : dict
299297 log dictionary return only if log==True in parameters
@@ -453,12 +451,12 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
453451
454452 Parameters
455453 ----------
456- a : np. ndarray (ns,)
454+ a : ndarray, shape (ns,)
457455 samples weights in the source domain
458- b : np. ndarray (nt,) or np. ndarray (nt,nbb)
456+ b : ndarray, shape (nt,) or ndarray, shape (nt, nbb)
459457 samples in the target domain, compute sinkhorn with multiple targets
460458 and fixed M if b is a matrix (return OT loss + dual variables in log)
461- M : np. ndarray (ns,nt)
459+ M : ndarray, shape (ns, nt)
462460 loss matrix
463461 reg : float
464462 Regularization term >0
@@ -469,10 +467,9 @@ def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=
469467 log : bool, optional
470468 record log if True
471469
472-
473470 Returns
474471 -------
475- gamma : (ns x nt) ndarray
472+ gamma : ndarray, shape (ns, nt)
476473 Optimal transportation matrix for the given parameters
477474 log : dict
478475 log dictionary return only if log==True in parameters
@@ -602,11 +599,11 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
602599
603600 Parameters
604601 ----------
605- a : np. ndarray (ns,)
602+ a : ndarray, shape (ns,)
606603 samples weights in the source domain
607- b : np. ndarray (nt,)
604+ b : ndarray, shape (nt,)
608605 samples in the target domain
609- M : np. ndarray (ns,nt)
606+ M : ndarray, shape (ns, nt)
610607 loss matrix
611608 reg : float
612609 Regularization term >0
@@ -623,10 +620,9 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
623620 log : bool, optional
624621 record log if True
625622
626-
627623 Returns
628624 -------
629- gamma : (ns x nt) ndarray
625+ gamma : ndarray, shape (ns, nt)
630626 Optimal transportation matrix for the given parameters
631627 log : dict
632628 log dictionary return only if log==True in parameters
@@ -823,19 +819,19 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
823819
824820 Parameters
825821 ----------
826- a : np. ndarray (ns,)
822+ a : ndarray, shape (ns,)
827823 samples weights in the source domain
828- b : np. ndarray (nt,)
824+ b : ndarray, shape (nt,)
829825 samples in the target domain
830- M : np. ndarray (ns,nt)
826+ M : ndarray, shape (ns, nt)
831827 loss matrix
832828 reg : float
833829 Regularization term >0
834830 tau : float
835831 thershold for max value in u or v for log scaling
836832 tau : float
837833 thershold for max value in u or v for log scaling
838- warmstart : tible of vectors
834+ warmstart : tuple of vectors
839835 if given then sarting values for alpha an beta log scalings
840836 numItermax : int, optional
841837 Max number of iterations
@@ -850,10 +846,9 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne
850846 log : bool, optional
851847 record log if True
852848
853-
854849 Returns
855850 -------
856- gamma : (ns x nt) ndarray
851+ gamma : ndarray, shape (ns, nt)
857852 Optimal transportation matrix for the given parameters
858853 log : dict
859854 log dictionary return only if log==True in parameters
@@ -1006,13 +1001,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
10061001
10071002 Parameters
10081003 ----------
1009- A : np. ndarray (d,n)
1004+ A : ndarray, shape (d,n)
10101005 n training distributions a_i of size d
1011- M : np. ndarray (d,d)
1006+ M : ndarray, shape (d,d)
10121007 loss matrix for OT
10131008 reg : float
10141009 Regularization term >0
1015- weights : np. ndarray (n,)
1010+ weights : ndarray, shape (n,)
10161011 Weights of each histogram a_i on the simplex (barycentric coodinates)
10171012 numItermax : int, optional
10181013 Max number of iterations
@@ -1102,11 +1097,11 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
11021097
11031098 Parameters
11041099 ----------
1105- A : np. ndarray (n,w, h)
1100+ A : ndarray, shape (n, w, h)
11061101 n distributions (2D images) of size w x h
11071102 reg : float
11081103 Regularization term >0
1109- weights : np. ndarray (n,)
1104+ weights : ndarray, shape (n,)
11101105 Weights of each image on the simplex (barycentric coodinates)
11111106 numItermax : int, optional
11121107 Max number of iterations
@@ -1119,15 +1114,13 @@ def convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, stopThr=1
11191114 log : bool, optional
11201115 record log if True
11211116
1122-
11231117 Returns
11241118 -------
1125- a : (w,h) ndarray
1119+ a : ndarray, shape (w, h)
11261120 2D Wasserstein barycenter
11271121 log : dict
11281122 log dictionary return only if log==True in parameters
11291123
1130-
11311124 References
11321125 ----------
11331126
@@ -1217,15 +1210,15 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
12171210
12181211 Parameters
12191212 ----------
1220- a : np. ndarray (d)
1213+ a : ndarray, shape (d)
12211214 observed distribution
1222- D : np. ndarray (d,n)
1215+ D : ndarray, shape (d, n)
12231216 dictionary matrix
1224- M : np. ndarray (d,d)
1217+ M : ndarray, shape (d, d)
12251218 loss matrix
1226- M0 : np. ndarray (n,n)
1219+ M0 : ndarray, shape (n, n)
12271220 loss matrix
1228- h0 : np. ndarray (n,)
1221+ h0 : ndarray, shape (n,)
12291222 prior on h
12301223 reg : float
12311224 Regularization term >0 (Wasserstein data fitting)
@@ -1245,7 +1238,7 @@ def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
12451238
12461239 Returns
12471240 -------
1248- a : (d,) ndarray
1241+ a : ndarray, shape (d,)
12491242 Wasserstein barycenter
12501243 log : dict
12511244 log dictionary return only if log==True in parameters
@@ -1325,15 +1318,15 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
13251318
13261319 Parameters
13271320 ----------
1328- X_s : np. ndarray (ns, d)
1321+ X_s : ndarray, shape (ns, d)
13291322 samples in the source domain
1330- X_t : np. ndarray (nt, d)
1323+ X_t : ndarray, shape (nt, d)
13311324 samples in the target domain
13321325 reg : float
13331326 Regularization term >0
1334- a : np. ndarray (ns,)
1327+ a : ndarray, shape (ns,)
13351328 samples weights in the source domain
1336- b : np. ndarray (nt,)
1329+ b : ndarray, shape (nt,)
13371330 samples weights in the target domain
13381331 numItermax : int, optional
13391332 Max number of iterations
@@ -1347,7 +1340,7 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numI
13471340
13481341 Returns
13491342 -------
1350- gamma : (ns x nt) ndarray
1343+ gamma : ndarray, shape (ns, nt)
13511344 Regularized optimal transportation matrix for the given parameters
13521345 log : dict
13531346 log dictionary return only if log==True in parameters
@@ -1415,15 +1408,15 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
14151408
14161409 Parameters
14171410 ----------
1418- X_s : np. ndarray (ns, d)
1411+ X_s : ndarray, shape (ns, d)
14191412 samples in the source domain
1420- X_t : np. ndarray (nt, d)
1413+ X_t : ndarray, shape (nt, d)
14211414 samples in the target domain
14221415 reg : float
14231416 Regularization term >0
1424- a : np. ndarray (ns,)
1417+ a : ndarray, shape (ns,)
14251418 samples weights in the source domain
1426- b : np. ndarray (nt,)
1419+ b : ndarray, shape (nt,)
14271420 samples weights in the target domain
14281421 numItermax : int, optional
14291422 Max number of iterations
@@ -1437,7 +1430,7 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
14371430
14381431 Returns
14391432 -------
1440- gamma : (ns x nt) ndarray
1433+ gamma : ndarray, shape (ns, nt)
14411434 Regularized optimal transportation matrix for the given parameters
14421435 log : dict
14431436 log dictionary return only if log==True in parameters
@@ -1523,15 +1516,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15231516
15241517 Parameters
15251518 ----------
1526- X_s : np. ndarray (ns, d)
1519+ X_s : ndarray, shape (ns, d)
15271520 samples in the source domain
1528- X_t : np. ndarray (nt, d)
1521+ X_t : ndarray, shape (nt, d)
15291522 samples in the target domain
15301523 reg : float
15311524 Regularization term >0
1532- a : np. ndarray (ns,)
1525+ a : ndarray, shape (ns,)
15331526 samples weights in the source domain
1534- b : np. ndarray (nt,)
1527+ b : ndarray, shape (nt,)
15351528 samples weights in the target domain
15361529 numItermax : int, optional
15371530 Max number of iterations
@@ -1542,17 +1535,15 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15421535 log : bool, optional
15431536 record log if True
15441537
1545-
15461538 Returns
15471539 -------
1548- gamma : (ns x nt) ndarray
1540+ gamma : ndarray, shape (ns, nt)
15491541 Regularized optimal transportation matrix for the given parameters
15501542 log : dict
15511543 log dictionary return only if log==True in parameters
15521544
15531545 Examples
15541546 --------
1555-
15561547 >>> n_s = 2
15571548 >>> n_t = 4
15581549 >>> reg = 0.1
@@ -1564,7 +1555,6 @@ def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeucli
15641555
15651556 References
15661557 ----------
1567-
15681558 .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative Models with Sinkhorn Divergences, Proceedings of the Twenty-First International Conference on Artficial Intelligence and Statistics, (AISTATS) 21, 2018
15691559 '''
15701560 if log :
0 commit comments