@@ -52,19 +52,23 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
5252 Examples
5353 --------
5454 >>> import ot
55+ >>> np.random.seed(0)
5556 >>> n_source = 7
5657 >>> n_target = 4
57- >>> reg = 1
58- >>> numItermax = 300000
5958 >>> a = ot.utils.unif(n_source)
6059 >>> b = ot.utils.unif(n_target)
61- >>> rng = np.random.RandomState(0)
62- >>> X_source = rng.randn(n_source, 2)
63- >>> Y_target = rng.randn(n_target, 2)
60+ >>> X_source = np.random.randn(n_source, 2)
61+ >>> Y_target = np.random.randn(n_target, 2)
6462 >>> M = ot.dist(X_source, Y_target)
65- >>> method = "ASGD"
66- >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
67- >>> print(asgd_pi)
63+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
64+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
65+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
66+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
67+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
68+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
69+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
70+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
71+
6872
6973 References
7074 ----------
@@ -133,19 +137,22 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
133137 Examples
134138 --------
135139 >>> import ot
140+ >>> np.random.seed(0)
136141 >>> n_source = 7
137142 >>> n_target = 4
138- >>> reg = 1
139- >>> numItermax = 300000
140143 >>> a = ot.utils.unif(n_source)
141144 >>> b = ot.utils.unif(n_target)
142- >>> rng = np.random.RandomState(0)
143- >>> X_source = rng.randn(n_source, 2)
144- >>> Y_target = rng.randn(n_target, 2)
145+ >>> X_source = np.random.randn(n_source, 2)
146+ >>> Y_target = np.random.randn(n_target, 2)
145147 >>> M = ot.dist(X_source, Y_target)
146- >>> method = "ASGD"
147- >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
148- >>> print(asgd_pi)
148+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
149+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
150+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
151+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
152+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
153+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
154+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
155+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
149156
150157 References
151158 ----------
@@ -222,19 +229,22 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
222229 Examples
223230 --------
224231 >>> import ot
232+ >>> np.random.seed(0)
225233 >>> n_source = 7
226234 >>> n_target = 4
227- >>> reg = 1
228- >>> numItermax = 300000
229235 >>> a = ot.utils.unif(n_source)
230236 >>> b = ot.utils.unif(n_target)
231- >>> rng = np.random.RandomState(0)
232- >>> X_source = rng.randn(n_source, 2)
233- >>> Y_target = rng.randn(n_target, 2)
237+ >>> X_source = np.random.randn(n_source, 2)
238+ >>> Y_target = np.random.randn(n_target, 2)
234239 >>> M = ot.dist(X_source, Y_target)
235- >>> method = "ASGD"
236- >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
237- >>> print(asgd_pi)
240+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
241+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
242+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
243+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
244+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
245+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
246+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
247+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
238248
239249 References
240250 ----------
@@ -301,19 +311,22 @@ def c_transform_entropic(b, M, reg, beta):
301311 Examples
302312 --------
303313 >>> import ot
314+ >>> np.random.seed(0)
304315 >>> n_source = 7
305316 >>> n_target = 4
306- >>> reg = 1
307- >>> numItermax = 300000
308317 >>> a = ot.utils.unif(n_source)
309318 >>> b = ot.utils.unif(n_target)
310- >>> rng = np.random.RandomState(0)
311- >>> X_source = rng.randn(n_source, 2)
312- >>> Y_target = rng.randn(n_target, 2)
319+ >>> X_source = np.random.randn(n_source, 2)
320+ >>> Y_target = np.random.randn(n_target, 2)
313321 >>> M = ot.dist(X_source, Y_target)
314- >>> method = "ASGD"
315- >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
316- >>> print(asgd_pi)
322+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
323+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
324+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
325+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
326+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
327+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
328+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
329+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
317330
318331 References
319332 ----------
@@ -395,19 +408,22 @@ def solve_semi_dual_entropic(a, b, M, reg, method, numItermax=10000, lr=None,
395408 Examples
396409 --------
397410 >>> import ot
411+ >>> np.random.seed(0)
398412 >>> n_source = 7
399413 >>> n_target = 4
400- >>> reg = 1
401- >>> numItermax = 300000
402414 >>> a = ot.utils.unif(n_source)
403415 >>> b = ot.utils.unif(n_target)
404- >>> rng = np.random.RandomState(0)
405- >>> X_source = rng.randn(n_source, 2)
406- >>> Y_target = rng.randn(n_target, 2)
416+ >>> X_source = np.random.randn(n_source, 2)
417+ >>> Y_target = np.random.randn(n_target, 2)
407418 >>> M = ot.dist(X_source, Y_target)
408- >>> method = "ASGD"
409- >>> asgd_pi = ot.stochastic.solve_semi_dual_entropic(a, b, M, reg, method, numItermax)
410- >>> print(asgd_pi)
419+ >>> ot.stochastic.solve_semi_dual_entropic(a, b, M, reg=1, method="ASGD", numItermax=300000)
420+ array([[2.53942342e-02, 9.98640673e-02, 1.75945647e-02, 4.27664307e-06],
421+ [1.21556999e-01, 1.26350515e-02, 1.30491795e-03, 7.36017394e-03],
422+ [3.54070702e-03, 7.63581358e-02, 6.29581672e-02, 1.32812798e-07],
423+ [2.60578198e-02, 3.35916645e-02, 8.28023223e-02, 4.05336238e-04],
424+ [9.86808864e-03, 7.59774324e-04, 1.08702729e-02, 1.21359007e-01],
425+ [2.17218856e-02, 9.12931802e-04, 1.87962526e-03, 1.18342700e-01],
426+ [4.14237512e-02, 2.67487857e-02, 7.23016955e-02, 2.38291052e-03]])
411427
412428 References
413429 ----------
@@ -502,22 +518,28 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
502518 Examples
503519 --------
504520 >>> import ot
521+ >>> np.random.seed(0)
505522 >>> n_source = 7
506523 >>> n_target = 4
507- >>> reg = 1
508- >>> numItermax = 20000
509- >>> lr = 0.1
510- >>> batch_size = 3
511- >>> log = True
512524 >>> a = ot.utils.unif(n_source)
513525 >>> b = ot.utils.unif(n_target)
514- >>> rng = np.random.RandomState(0)
515- >>> X_source = rng.randn(n_source, 2)
516- >>> Y_target = rng.randn(n_target, 2)
526+ >>> X_source = np.random.randn(n_source, 2)
527+ >>> Y_target = np.random.randn(n_target, 2)
517528 >>> M = ot.dist(X_source, Y_target)
518- >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
519- >>> print(log['alpha'], log['beta'])
520- >>> print(sgd_dual_pi)
529+ >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg=1, batch_size=3, numItermax=30000, lr=0.1, log=True)
530+ >>> log['alpha']
531+ array([0.71759102, 1.57057384, 0.85576566, 0.1208211 , 0.59190466,
532+ 1.197148 , 0.17805133])
533+ >>> log['beta']
534+ array([0.49741367, 0.57478564, 1.40075528, 2.75890102])
535+ >>> sgd_dual_pi
536+ array([[2.09730063e-02, 8.38169324e-02, 7.50365455e-03, 8.72731415e-09],
537+ [5.58432437e-03, 5.89881299e-04, 3.09558411e-05, 8.35469849e-07],
538+ [3.26489515e-03, 7.15536035e-02, 2.99778211e-02, 3.02601593e-10],
539+ [4.05390622e-02, 5.31085068e-02, 6.65191787e-02, 1.55812785e-06],
540+ [7.82299812e-02, 6.12099102e-03, 4.44989098e-02, 2.37719187e-03],
541+ [5.06266486e-02, 2.16230494e-03, 2.26215141e-03, 6.81514609e-04],
542+ [6.06713990e-02, 3.98139808e-02, 5.46829338e-02, 8.62371424e-06]])
521543
522544 References
523545 ----------
@@ -526,7 +548,6 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
526548 International Conference on Learning Representation (2018),
527549 arXiv preprint arxiv:1711.02283.
528550 '''
529-
530551 G = - (np .exp ((alpha [batch_alpha , None ] + beta [None , batch_beta ] -
531552 M [batch_alpha , :][:, batch_beta ]) / reg ) *
532553 a [batch_alpha , None ] * b [None , batch_beta ])
@@ -605,8 +626,19 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
605626 >>> Y_target = rng.randn(n_target, 2)
606627 >>> M = ot.dist(X_source, Y_target)
607628 >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
608- >>> print(log['alpha'], log['beta'])
609- >>> print(sgd_dual_pi)
629+ >>> log['alpha']
630+ array([0.64171798, 1.27932201, 0.78132257, 0.15638935, 0.54888354,
631+ 1.03663469, 0.20595781])
632+ >>> log['beta']
633+ array([0.51207194, 0.58033189, 1.28922676, 2.26859736])
634+ >>> sgd_dual_pi
635+ array([[1.97276541e-02, 7.81248547e-02, 6.22136048e-03, 4.95442423e-09],
636+ [4.23494310e-03, 4.43286263e-04, 2.06927079e-05, 3.82389139e-07],
637+ [3.07542414e-03, 6.67897769e-02, 2.48904999e-02, 1.72030247e-10],
638+ [4.26271990e-02, 5.53375455e-02, 6.16535024e-02, 9.88812650e-07],
639+ [7.60423265e-02, 5.89585256e-03, 3.81267087e-02, 1.39458256e-03],
640+ [4.37557504e-02, 1.85189176e-03, 1.72335760e-03, 3.55491279e-04],
641+ [6.33096109e-02, 4.11683954e-02, 5.02962051e-02, 5.43097516e-06]])
610642
611643 References
612644 ----------
@@ -701,8 +733,19 @@ def solve_dual_entropic(a, b, M, reg, batch_size, numItermax=10000, lr=1,
701733 >>> Y_target = rng.randn(n_target, 2)
702734 >>> M = ot.dist(X_source, Y_target)
703735 >>> sgd_dual_pi, log = ot.stochastic.solve_dual_entropic(a, b, M, reg, batch_size, numItermax, lr, log)
704- >>> print(log['alpha'], log['beta'])
705- >>> print(sgd_dual_pi)
736+ >>> log['alpha']
737+ array([0.64057733, 1.2683513 , 0.75610161, 0.16024284, 0.54926534,
738+ 1.0514201 , 0.19958936])
739+ >>> log['beta']
740+ array([0.51372571, 0.58843489, 1.27993921, 2.24344807])
741+ >>> sgd_dual_pi
742+ array([[1.97377795e-02, 7.86706853e-02, 6.15682001e-03, 4.82586997e-09],
743+ [4.19566963e-03, 4.42016865e-04, 2.02777272e-05, 3.68823708e-07],
744+ [3.00379244e-03, 6.56562018e-02, 2.40462171e-02, 1.63579656e-10],
745+ [4.28626062e-02, 5.60031599e-02, 6.13193826e-02, 9.67977735e-07],
746+ [7.61972739e-02, 5.94609051e-03, 3.77886693e-02, 1.36046648e-03],
747+ [4.44810042e-02, 1.89476742e-03, 1.73285847e-03, 3.51826036e-04],
748+ [6.30118293e-02, 4.12398660e-02, 4.95148998e-02, 5.26247246e-06]])
706749
707750 References
708751 ----------
0 commit comments