Skip to content

Commit 28b549e

Browse files
author
Hicham Janati
committed
add test and example of UOT
1 parent 3c53834 commit 28b549e

File tree

2 files changed

+112
-0
lines changed

2 files changed

+112
-0
lines changed

examples/plot_UOT_1D.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
====================
4+
1D Unbalanced optimal transport
5+
====================
6+
7+
This example illustrates the computation of Unbalanced Optimal transport
8+
using a Kullback-Leibler relaxation.
9+
"""
10+
11+
# Author: Hicham Janati <hicham.janati@inria.fr>
12+
#
13+
# License: MIT License
14+
15+
import numpy as np
16+
import matplotlib.pylab as pl
17+
import ot
18+
import ot.plot
19+
from ot.datasets import make_1D_gauss as gauss
20+
21+
##############################################################################
22+
# Generate data
23+
# -------------
24+
25+
26+
#%% parameters
27+
28+
n = 100 # nb bins
29+
30+
# bin positions
31+
x = np.arange(n, dtype=np.float64)
32+
33+
# Gaussian distributions
34+
a = gauss(n, m=20, s=5) # m= mean, s= std
35+
b = gauss(n, m=60, s=10)
36+
37+
# make distributions unbalanced
38+
b *= 5.
39+
40+
# loss matrix
41+
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
42+
M /= M.max()
43+
44+
45+
##############################################################################
46+
# Plot distributions and loss matrix
47+
# ----------------------------------
48+
49+
#%% plot the distributions
50+
51+
pl.figure(1, figsize=(6.4, 3))
52+
pl.plot(x, a, 'b', label='Source distribution')
53+
pl.plot(x, b, 'r', label='Target distribution')
54+
pl.legend()
55+
56+
#%% plot distributions and loss matrix
57+
58+
pl.figure(2, figsize=(5, 5))
59+
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')
60+
61+
62+
##############################################################################
63+
# Solve Unbalanced Sinkhorn
64+
# --------------
65+
66+
67+
#%% Sinkhorn
68+
69+
lambd = 0.1
70+
alpha = 1.
71+
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, lambd, alpha, verbose=True)
72+
73+
pl.figure(4, figsize=(5, 5))
74+
ot.plot.plot1D_mat(a, b, Gs, 'UOT matrix Sinkhorn')
75+
76+
pl.show()

test/test_unbalanced.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""Tests for module Unbalanced OT with entropy regularization"""
2+
3+
# Author: Hicham Janati <hicham.janati@inria.fr>
4+
#
5+
# License: MIT License
6+
7+
import numpy as np
8+
import ot
9+
10+
11+
def test_unbalanced():
12+
# test generalized sinkhorn for unbalanced OT
13+
n = 100
14+
rng = np.random.RandomState(42)
15+
16+
x = rng.randn(n, 2)
17+
a = ot.utils.unif(n)
18+
b = ot.utils.unif(n) * 1.5
19+
20+
M = ot.dist(x, x)
21+
epsilon = 1.
22+
alpha = 1.
23+
K = np.exp(- M / epsilon)
24+
25+
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
26+
stopThr=1e-10, log=True)
27+
28+
# check fixed point equations
29+
fi = alpha / (alpha + epsilon)
30+
v_final = (b / K.T.dot(log["u"])) ** fi
31+
u_final = (a / K.dot(log["v"])) ** fi
32+
33+
np.testing.assert_allclose(
34+
u_final, log["u"], atol=1e-05)
35+
np.testing.assert_allclose(
36+
v_final, log["v"], atol=1e-05)

0 commit comments

Comments
 (0)