Skip to content

Commit bec2181

Browse files
clbonetrflamary
andauthored
[MRG] Linear Circular OT (#736)
* Linear Circular OT * update LCOT * update LCOT * Add doc * typo example * Add LSSOT * tests lcot and lssw * Tests LCOT * skip test tensorflow * catch warnings divide by 0 * Update examples/plot_compute_wasserstein_circle.py * fix doc * fix number citations * fix comment warning --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent 1f6d2df commit bec2181

File tree

10 files changed

+731
-83
lines changed

10 files changed

+731
-83
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,3 +440,7 @@ Artificial Intelligence.
440440
[76] Chapel, L., Tavenard, R. (2025). [One for all and all for one: Efficient computation of partial Wasserstein distances on the line](https://iclr.cc/virtual/2025/poster/28547). In International Conference on Learning Representations.
441441

442442
[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024)
443+
444+
[78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). [LCOT: Linear Circular Optimal Transport](https://openreview.net/forum?id=49z97Y9lMq). International Conference on Learning Representations.
445+
446+
[79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations.

RELEASES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
- Backend implementation of `ot.dist` for (PR #701)
2121
- Updated documentation Quickstart guide and User guide with new API (PR #726)
2222
- Fix jax version for auto-grad (PR #732)
23+
- Added `ot.solver_1d.linear_circular_ot` and `ot.sliced.linear_sliced_wasserstein_sphere` (PR #736)
2324
- Implement 1d solver for partial optimal transport (PR #741)
2425
- Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731)
2526
- Added to each example in the examples gallery the information about the release version in which it was introduced (PR #743)
@@ -36,6 +37,8 @@
3637
- Clean documentation for `ot.gromov.gromov_wasserstein` (PR #737)
3738
- Debug wheels building (PR #739)
3839
- Fix doc for projection sparse simplex (PR #734, PR #746)
40+
- Changed the default behavior of `ot.lp.solver_1d.wasserstein_circle` (Issue #738)
41+
- Avoid raising unnecessary warnings in `ot.lp.solver_1d.binary_search_circle` (Issue #738)
3942

4043
## 0.9.5
4144

examples/backends/plot_ssw_unif_torch.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
.. math::
1111
\min_{x} SSW_2(\nu, \frac{1}{n}\sum_{i=1}^n \delta_{x_i})
1212
13-
where :math:`\nu=\mathrm{Unif}(S^1)`.
13+
where :math:`\nu=\mathrm{Unif}(S^{d-1})`.
1414
1515
"""
1616

@@ -46,15 +46,18 @@
4646

4747

4848
def plot_sphere(ax):
49-
xlist = np.linspace(-1.0, 1.0, 50)
50-
ylist = np.linspace(-1.0, 1.0, 50)
51-
r = np.linspace(1.0, 1.0, 50)
52-
X, Y = np.meshgrid(xlist, ylist)
49+
# Create a sphere using spherical coordinates
50+
phi = np.linspace(0, 2 * np.pi, 100)
51+
theta = np.linspace(0, np.pi, 100)
52+
phi, theta = np.meshgrid(phi, theta)
5353

54-
Z = np.sqrt(np.maximum(r**2 - X**2 - Y**2, 0))
54+
# Compute the spherical coordinates
55+
X = np.sin(theta) * np.cos(phi)
56+
Y = np.sin(theta) * np.sin(phi)
57+
Z = np.cos(theta)
5558

59+
# Plot the wireframe
5660
ax.plot_wireframe(X, Y, Z, color="gray", alpha=0.3)
57-
ax.plot_wireframe(X, Y, -Z, color="gray", alpha=0.3) # Now plot the bottom half
5861

5962

6063
# plot the distributions

examples/sliced-wasserstein/plot_compute_wasserstein_circle.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,17 +102,23 @@ def pdf_von_Mises(theta, mu, kappa):
102102

103103
L_w2_circle = np.zeros((n_try, 200))
104104
L_w2 = np.zeros((n_try, 200))
105+
L_lcot = np.zeros((n_try, 200))
105106

106107
for i in range(n_try):
107108
w2_circle = ot.wasserstein_circle(xs2.T, xts2[i].T, p=2)
108109
w2 = ot.wasserstein_1d(xs2.T, xts2[i].T, p=2)
110+
w_lcot = ot.linear_circular_ot(xs2.T, xts2[i].T)
109111

110112
L_w2_circle[i] = w2_circle
111113
L_w2[i] = w2
114+
L_lcot[i] = w_lcot
112115

113116
m_w2_circle = np.mean(L_w2_circle, axis=0)
114117
std_w2_circle = np.std(L_w2_circle, axis=0)
115118

119+
m_w2_lcot = np.mean(L_lcot, axis=0)
120+
std_w2_lcot = np.std(L_lcot, axis=0)
121+
116122
m_w2 = np.mean(L_w2, axis=0)
117123
std_w2 = np.std(L_w2, axis=0)
118124

@@ -128,6 +134,13 @@ def pdf_von_Mises(theta, mu, kappa):
128134
pl.fill_between(
129135
mu_targets / (2 * np.pi), m_w2 - 2 * std_w2, m_w2 + 2 * std_w2, alpha=0.5
130136
)
137+
pl.plot(mu_targets / (2 * np.pi), m_w2_lcot, label="Linear COT")
138+
pl.fill_between(
139+
mu_targets / (2 * np.pi),
140+
m_w2_lcot - 2 * std_w2_lcot,
141+
m_w2_lcot + 2 * std_w2_lcot,
142+
alpha=0.5,
143+
)
131144
pl.vlines(
132145
x=[mu1 / (2 * np.pi)],
133146
ymin=0,
@@ -159,15 +172,23 @@ def pdf_von_Mises(theta, mu, kappa):
159172
xts[i, k] = xt / (2 * np.pi)
160173

161174
L_w2 = np.zeros((n_try, 100))
175+
L_lcot = np.zeros((n_try, 100))
162176
for i in range(n_try):
163177
L_w2[i] = ot.semidiscrete_wasserstein2_unif_circle(xts[i].T)
178+
L_lcot[i] = ot.linear_circular_ot(xts[i].T)
164179

165180
m_w2 = np.mean(L_w2, axis=0)
166181
std_w2 = np.std(L_w2, axis=0)
167182

183+
m_lcot = np.mean(L_lcot, axis=0)
184+
std_lcot = np.std(L_lcot, axis=0)
185+
168186
pl.figure(1)
169-
pl.plot(kappas, m_w2)
187+
pl.plot(kappas, m_w2, label="Wasserstein")
170188
pl.fill_between(kappas, m_w2 - std_w2, m_w2 + std_w2, alpha=0.5)
189+
pl.plot(kappas, m_lcot, label="LCOT")
190+
pl.fill_between(kappas, m_lcot - std_lcot, m_lcot + std_lcot, alpha=0.5)
191+
pl.legend()
171192
pl.title(r"Evolution of $W_2^2(vM(0,\kappa), Unif(S^1))$")
172193
pl.xlabel(r"$\kappa$")
173194
pl.show()

ot/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
binary_search_circle,
4949
wasserstein_circle,
5050
semidiscrete_wasserstein2_unif_circle,
51+
linear_circular_ot,
5152
)
5253
from .bregman import sinkhorn, sinkhorn2, barycenter
5354
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
@@ -57,6 +58,7 @@
5758
max_sliced_wasserstein_distance,
5859
sliced_wasserstein_sphere,
5960
sliced_wasserstein_sphere_unif,
61+
linear_sliced_wasserstein_sphere,
6062
)
6163
from .gromov import (
6264
gromov_wasserstein,
@@ -105,6 +107,7 @@
105107
"sinkhorn_unbalanced2",
106108
"sliced_wasserstein_distance",
107109
"sliced_wasserstein_sphere",
110+
"linear_sliced_wasserstein_sphere",
108111
"gromov_wasserstein",
109112
"gromov_wasserstein2",
110113
"gromov_barycenters",
@@ -129,6 +132,7 @@
129132
"binary_search_circle",
130133
"wasserstein_circle",
131134
"semidiscrete_wasserstein2_unif_circle",
135+
"linear_circular_ot",
132136
"sliced_wasserstein_sphere_unif",
133137
"lowrank_sinkhorn",
134138
"lowrank_gromov_wasserstein_samples",

ot/lp/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
binary_search_circle,
3030
wasserstein_circle,
3131
semidiscrete_wasserstein2_unif_circle,
32+
linear_circular_ot,
3233
)
3334

3435
__all__ = [
@@ -45,6 +46,7 @@
4546
"binary_search_circle",
4647
"wasserstein_circle",
4748
"semidiscrete_wasserstein2_unif_circle",
49+
"linear_circular_ot",
4850
"dmmot_monge_1dgrid_loss",
4951
"dmmot_monge_1dgrid_optimize",
5052
"check_number_threads",

0 commit comments

Comments
 (0)