
Commit 3a53dff

KrzakalaPaul (Paul Krzakala), rflamary, and cedricvincentcuaz authored
Batch OT losses (Sinkhorn + Gromov) (#755)
* linear ot implemented
* improve stopping criterion and asymmetric case
* Add recompute_const and simplify the pipeline for the symmetric = False case
* add tests
* update the examples and rename to follow the "ot.solve" naming conventions
* update RELEASES.md
* idem
* move set_grad_enabled to backend
* set_grad_enabled for quadratic solver
* update doc
* remove useless import in doc
* Update references
* update example
* Remove classes in quadratic, move examples to backend, add potentials, remove context managers for grads. To do: improve doc and tests
* update tests
* Massive improvement of the documentation for ot.batch
* cover (almost) all ot.batch with tests
* fix bug in the tests
* update docstring
* highlight that ot.batch is solving the entropic version
* removing yet another error in the docstring
* Add missing parameter recompute_const
* Remove png, add all backends and gradient mode to tests
* add the missing pytest
* change .sum() into nx.sum
* add missing backend
* yet another missing nx
* remove useless squeeze and add test for non-log bregman
* remove last_step from quadratic tests
* add missing tests and improve documentation
* proper unsqueeze test
* add unsqueeze to tensorflow
* solve double backprop issue in test_gradients_torch

---------

Co-authored-by: PaulKrzakala <paul.krzakala@gmail.com>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
Co-authored-by: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
1 parent 803d2ab commit 3a53dff
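
The commit message notes that ot.batch solves the entropic (Sinkhorn) version of the problem in parallel. As a quick orientation before the file-by-file diff, here is a minimal usage sketch assembled only from calls that appear in the example file added by this commit (ot.batch.dist_batch and ot.batch.solve_batch); the argument values are illustrative, not library defaults.

import numpy as np
import ot

# A batch of 4 small OT problems, 8 samples each, in 2 dimensions.
x = np.random.randn(4, 8, 2)
y = x + 0.1 * np.random.randn(4, 8, 2)

# All 4 cost matrices at once: array of shape (4, 8, 8).
M = ot.batch.dist_batch(x, y)

# Entropic (Sinkhorn) solve of the whole batch in parallel.
res = ot.batch.solve_batch(M=M, reg=1.0, reg_type="entropy", max_iter=100, tol=1e-3)
print(res.value_linear)  # one transport cost per problem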

File tree: 17 files changed (+1940, -22 lines)


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -123,3 +123,5 @@ debug
 
 # pytest cahche
 .pytest_cache
+
+docs/source/

README.md

Lines changed: 5 additions & 0 deletions
@@ -446,3 +446,8 @@ Artificial Intelligence.
 [79] Liu, X., Bai, Y., Martín, R. D., Shi, K., Shahbazi, A., Landman, B. A., Chang, C., & Kolouri, S. (2025). [Linear Spherical Sliced Optimal Transport: A Fast Metric for Comparing Spherical Data](https://openreview.net/forum?id=fgUFZAxywx). International Conference on Learning Representations.
 
 [80] Altschuler, J., Bach, F., Rudi, A., Niles-Weed, J., [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf), Advances in Neural Information Processing Systems, 2019.
+
+[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS).
+
+
+```

RELEASES.md

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 - Fix reg_div function compatibility with numpy in `ot.unbalanced.lbfgsb_unbalanced` via new function `ot.utils.fun_to_numpy` (PR #731)
 - Added to each example in the examples gallery the information about the release version in which it was introduced (PR #743)
 - Removed release information from quickstart guide (PR #744)
+- Implement batch parallel solvers in ot.batch (PR #745)
 - Update REAMDE with new API and reorganize examples (PR #754)
 
 #### Closed issues

docs/source/all.rst

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ API and modules
 
 
    backend
+   batch
    bregman
    coot
    da

docs/source/user_guide.rst

Lines changed: 3 additions & 0 deletions
@@ -1217,3 +1217,6 @@ References
     couplings <http://proceedings.mlr.press/v89/forrow19a/forrow19a.pdf>`_. In
     The 22nd International Conference on Artificial Intelligence and Statistics
     (pp. 2454-2465). PMLR.
+
+.. [41] Xu, H., Luo, D., & Carin, L. (2019). `Scalable Gromov-Wasserstein learning for graph partitioning and matching
+    <https://arxiv.org/abs/1906.03666>`_\ , Advances in neural information processing systems, 32.

examples/backends/plot_ot_batch.py

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@ (new file)

"""
=================================================
Solving Many Optimal Transport Problems in Parallel
=================================================

In some situations, one may want to solve many OT problems with the same
structure (same number of samples, same cost function, etc.) at the same time.

In that case using a for loop to solve the problems sequentially is inefficient.
This example shows how to use the batch solvers implemented in POT to solve
many problems in parallel on CPU or GPU (even more efficient on GPU).

"""

# Author: Paul Krzakala <paul.krzakala@gmail.com>
# License: MIT License

# sphinx_gallery_thumbnail_number = 1


#############################################################################
#
# Computing the Cost Matrices
# ---------------------------------------------
#
# We want to create a batch of optimal transport problems with
# :math:`n` samples in :math:`d` dimensions.
#
# To do this, we first need to compute the cost matrices for each problem.
#
# .. note::
#     A straightforward approach would be to use a Python loop and
#     :func:`ot.dist`.
#     However, this is inefficient when working with batches.
#
# Instead, you can directly use :func:`ot.batch.dist_batch`, which computes
# all cost matrices in parallel.

import ot
import numpy as np

n_problems = 4  # nb problems/batch size
n_samples = 8  # nb samples
dim = 2  # nb dimensions

np.random.seed(0)
samples_source = np.random.randn(n_problems, n_samples, dim)
samples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)

# Naive approach
M_list = []
for i in range(n_problems):
    M_list.append(
        ot.dist(samples_source[i], samples_target[i])
    )  # List of cost matrices n_samples x n_samples
# Batched approach
M_batch = ot.batch.dist_batch(
    samples_source, samples_target
)  # Array of cost matrices n_problems x n_samples x n_samples

for i in range(n_problems):
    assert np.allclose(M_list[i], M_batch[i])

#############################################################################
#
# Solving the Problems
# ---------------------------------------------
#
# Once the cost matrices are computed, we can solve the corresponding
# optimal transport problems.
#
# .. note::
#     One option is to solve them sequentially with a Python loop using
#     :func:`ot.solve`.
#     This is simple but inefficient for large batches.
#
# Instead, you can use :func:`ot.batch.solve_batch`, which solves all
# problems in parallel.

reg = 1.0
max_iter = 100
tol = 1e-3

# Naive approach
results_values_list = []
for i in range(n_problems):
    res = ot.solve(M_list[i], reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy")
    results_values_list.append(res.value_linear)

# Batched approach
results_batch = ot.batch.solve_batch(
    M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
)
results_values_batch = results_batch.value_linear

assert np.allclose(np.array(results_values_list), results_values_batch, atol=tol * 10)

#############################################################################
#
# Comparing Computation Time
# ---------------------------------------------
#
# We now compare the runtime of the two approaches on larger problems.
#
# .. note::
#     The speedup obtained with :mod:`ot.batch` can be even more
#     significant when computations are performed on a GPU.


from time import perf_counter

n_problems = 128
n_samples = 8
dim = 2
reg = 10.0
max_iter = 1000
tol = 1e-3

samples_source = np.random.randn(n_problems, n_samples, dim)
samples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)


def benchmark_naive(samples_source, samples_target):
    start = perf_counter()
    for i in range(n_problems):
        M = ot.dist(samples_source[i], samples_target[i])
        res = ot.solve(M, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy")
    end = perf_counter()
    return end - start


def benchmark_batch(samples_source, samples_target):
    start = perf_counter()
    M_batch = ot.batch.dist_batch(samples_source, samples_target)
    res_batch = ot.batch.solve_batch(
        M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
    )
    end = perf_counter()
    return end - start


time_naive = benchmark_naive(samples_source, samples_target)
time_batch = benchmark_batch(samples_source, samples_target)

print(f"Naive approach time: {time_naive:.4f} seconds")
print(f"Batched approach time: {time_batch:.4f} seconds")

#############################################################################
#
# Gromov-Wasserstein
# ---------------------------------------------
#
# The :mod:`ot.batch` module also provides a batched Gromov-Wasserstein solver.
#
# .. note::
#     This solver is **not** equivalent to calling :func:`ot.solve_gromov`
#     repeatedly in a loop.
#
# Key differences:
#
# - :func:`ot.solve_gromov`
#   Uses the conditional gradient algorithm. Each inner iteration relies on
#   an exact EMD solver.
#
# - :func:`ot.batch.solve_gromov_batch`
#   Uses a proximal variant, where each inner iteration applies entropic
#   regularization.
#
# As a result:
#
# - :func:`ot.solve_gromov` is usually faster on CPU
# - :func:`ot.batch.solve_gromov_batch` is slower on CPU, but provides
#   better objective values.
#
# .. tip::
#     If your data is on a GPU, :func:`ot.batch.solve_gromov_batch`
#     is significantly faster AND provides better objective values.

from ot import solve_gromov
from ot.batch import solve_gromov_batch


def benchmark_naive_gw(samples_source, samples_target):
    start = perf_counter()
    avg_value = 0
    for i in range(n_problems):
        C1 = ot.dist(samples_source[i], samples_source[i])
        C2 = ot.dist(samples_target[i], samples_target[i])
        res = solve_gromov(C1, C2, max_iter=1000, tol=tol)
        avg_value += res.value
    avg_value /= n_problems
    end = perf_counter()
    return end - start, avg_value


def benchmark_batch_gw(samples_source, samples_target):
    start = perf_counter()
    C1_batch = ot.batch.dist_batch(samples_source, samples_source)
    C2_batch = ot.batch.dist_batch(samples_target, samples_target)
    res_batch = solve_gromov_batch(
        C1_batch, C2_batch, reg=1, max_iter=100, max_iter_inner=50, tol=tol
    )
    avg_value = np.mean(res_batch.value)
    end = perf_counter()
    return end - start, avg_value


time_naive_gw, avg_value_naive_gw = benchmark_naive_gw(samples_source, samples_target)
time_batch_gw, avg_value_batch_gw = benchmark_batch_gw(samples_source, samples_target)

print(f"{'Method':<20}{'Time (s)':<15}{'Avg Value':<15}")
print(f"{'Naive GW':<20}{time_naive_gw:<15.4f}{avg_value_naive_gw:<15.4f}")
print(f"{'Batched GW':<20}{time_batch_gw:<15.4f}{avg_value_batch_gw:<15.4f}")

#############################################################################
#
# In summary: no more for loops!
# ---------------------------------------------

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 4))
ax.text(0.5, 0.5, "For", fontsize=160, ha="center", va="center", zorder=0)
ax.axis("off")
ax.plot([0, 1], [0, 1], color="red", linewidth=10, zorder=1)
ax.plot([0, 1], [1, 0], color="red", linewidth=10, zorder=1)
plt.show()
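
The tip in the example above points at GPU execution. The batch solvers dispatch through POT's backend mechanism, so one way to try this is to pass torch tensors that already live on a CUDA device; outputs then stay in the same framework. A minimal sketch, assuming PyTorch is installed (the device handling below is an illustrative assumption, not part of the committed example):

import numpy as np
import torch
import ot

# Illustrative assumption: use a CUDA device when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

rng = np.random.default_rng(0)
xs = torch.tensor(rng.standard_normal((128, 8, 2)), device=device)
xt = xs + 0.1 * torch.tensor(rng.standard_normal((128, 8, 2)), device=device)

# Cost matrices and the entropic solve run on the same device as the inputs.
M_batch = ot.batch.dist_batch(xs, xt)
res = ot.batch.solve_batch(M=M_batch, reg=1.0, reg_type="entropy", max_iter=1000, tol=1e-3)
print(res.value_linear.shape)  # one value per problem, returned in the torch backend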

examples/index.rst

Lines changed: 0 additions & 1 deletion
@@ -33,7 +33,6 @@ Differentiable OT with PyTorch
    ../../examples/gaussian_gmm/plot_GMM_flow.py
    ../../examples/gromov/plot_gnn_TFGW.py
 
-
 Gromov-Wasserstein (GW) and Fused GW
 ------------------------------------

ot/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -37,7 +37,6 @@
 from . import lowrank
 from . import gmm
 
-
 # OT functions
 from .lp import (
     emd,
@@ -73,6 +72,8 @@
 from .solvers import solve, solve_gromov, solve_sample
 from .lowrank import lowrank_sinkhorn
 
+from .batch import solve_batch, solve_gromov_batch
+
 # utils functions
 from .utils import dist, unif, tic, toc, toq
 
@@ -136,4 +137,6 @@
     "sliced_wasserstein_sphere_unif",
     "lowrank_sinkhorn",
     "lowrank_gromov_wasserstein_samples",
+    "solve_batch",
+    "solve_gromov_batch",
 ]
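
Since solve_batch and solve_gromov_batch are re-exported here and listed in __all__, they are reachable directly from the top-level ot namespace as well as from ot.batch. A minimal sketch whose argument values simply mirror the ones used in the example above:

import numpy as np
import ot

x = np.random.randn(4, 8, 2)
y = x + 0.1 * np.random.randn(4, 8, 2)

# Entropic OT over a batch of cost matrices, via the new top-level alias.
M = ot.batch.dist_batch(x, y)
res_lin = ot.solve_batch(M=M, reg=1.0, reg_type="entropy", max_iter=100, tol=1e-3)

# Batched Gromov-Wasserstein on intra-domain cost matrices.
C1 = ot.batch.dist_batch(x, x)
C2 = ot.batch.dist_batch(y, y)
res_gw = ot.solve_gromov_batch(C1, C2, reg=1.0, max_iter=100, max_iter_inner=50, tol=1e-3)

print(res_lin.value_linear, res_gw.value)  # per-problem objective values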
