Skip to content
7 changes: 7 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@
This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.

#### New features
- Add `ot.utils.DataScaler` class for backend-aware joint normalization of input
distributions, with sklearn-compatible `fit`/`transform`/`fit_transform` API and
support for `'standard'`, `'minmax'`, and `'l2'` methods (PR #808)
- Add `ot.utils.apply_scaler` helper that dispatches preprocessing to a scaler object,
a callable, or a no-op (PR #808)
- Add optional `scaler` parameter to `sliced_wasserstein_distance` and
`max_sliced_wasserstein_distance` (PR #808)
- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788)
- Add Warmstart feature to the EMD solver for existing potentials (PR #793)
- Add Warmstart potentials feature to the EMD solver for lazy and sparse solver (PR #795)
Expand Down
114 changes: 114 additions & 0 deletions examples/sliced-wasserstein/plot_sliced_wasserstein_scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
"""
============================================================
Sliced Wasserstein Distance with input scaling (DataScaler)
============================================================

.. note::
Example added in release: 0.9.7.

This example illustrates why input scaling matters when computing the Sliced
Wasserstein Distance (SWD) between distributions whose features have very
different magnitudes. Without scaling, the SWD is dominated by high-magnitude
features and may miss meaningful differences in low-magnitude features.

The :class:`ot.utils.DataScaler` class fits normalization statistics once on a
representative sample and applies the same fixed transformation on every call.
This is preferred over re-normalizing inside each SWD call because the
transformation stays consistent across mini-batches during optimization.

"""

# Author: Harguna Sood <harguna.sood@gmail.com>
#
# License: MIT License

import matplotlib.pylab as pl
import numpy as np

import ot

##############################################################################
# Generate two 2D distributions with mismatched feature scales
# ------------------------------------------------------------
#
# Feature 1 is on the scale of 1000 with random noise.
# Feature 2 is on the scale of 1 with a meaningful 5-sigma shift between
# source and target distributions.

# %% parameters and data generation

rng = np.random.RandomState(0)
n = 500

X_s = np.column_stack(
[
rng.normal(1000, 100, n), # feature 1: large scale, no real signal
rng.normal(0, 1, n), # feature 2: small scale, no shift
]
)
X_t = np.column_stack(
[
rng.normal(1000, 100, n), # feature 1: same distribution as source
rng.normal(5, 1, n), # feature 2: shifted by 5 std
]
)

##############################################################################
# SWD without scaling
# -------------------
#
# Because feature 1 has values ~1000x larger than feature 2, the random
# projections used in SWD are dominated by feature 1. The meaningful shift
# in feature 2 is buried.

# %% SWD without scaling

swd_raw = ot.sliced_wasserstein_distance(X_s, X_t, n_projections=200, seed=0)
print("SWD without scaling: {:.4f}".format(swd_raw))

##############################################################################
# SWD with DataScaler
# -------------------
#
# Fit a standard scaler jointly on both distributions, then pass it to SWD.
# The same fixed statistics are reused on every call, giving a stable loss
# across mini-batches.

# %% SWD with DataScaler

scaler = ot.utils.DataScaler(norm="standard").fit([X_s, X_t])
swd_scaled = ot.sliced_wasserstein_distance(
X_s, X_t, n_projections=200, seed=0, scaler=scaler
)
print("SWD with DataScaler: {:.4f}".format(swd_scaled))

##############################################################################
# Visualize raw vs. scaled distributions
# ---------------------------------------

# %% plot distributions

X_s_n = scaler.transform(X_s)
X_t_n = scaler.transform(X_t)

pl.figure(1, figsize=(12, 5))

pl.subplot(1, 2, 1)
pl.scatter(X_s[:, 0], X_s[:, 1], alpha=0.5, label="$X_s$", s=10)
pl.scatter(X_t[:, 0], X_t[:, 1], alpha=0.5, label="$X_t$", s=10)
pl.title("Raw distributions\n(feature 2 signal hidden by feature 1 scale)")
pl.xlabel("Feature 1 (large scale)")
pl.ylabel("Feature 2 (small scale)")
pl.legend()

pl.subplot(1, 2, 2)
pl.scatter(X_s_n[:, 0], X_s_n[:, 1], alpha=0.5, label="$X_s$ normalized", s=10)
pl.scatter(X_t_n[:, 0], X_t_n[:, 1], alpha=0.5, label="$X_t$ normalized", s=10)
pl.title("Normalized distributions\n(feature 2 shift clearly visible)")
pl.xlabel("Feature 1 (normalized)")
pl.ylabel("Feature 2 (normalized)")
pl.legend()

pl.tight_layout()
pl.show()
36 changes: 35 additions & 1 deletion ot/sliced.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import numpy as np
from .backend import get_backend, NumpyBackend
from .utils import list_to_array, get_coordinate_circle
from .utils import list_to_array, get_coordinate_circle, apply_scaler
from .lp import (
wasserstein_circle,
semidiscrete_wasserstein2_unif_circle,
Expand Down Expand Up @@ -76,6 +76,7 @@ def sliced_wasserstein_distance(
projections=None,
seed=None,
log=False,
scaler=None,
):
r"""
Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance
Expand Down Expand Up @@ -109,6 +110,20 @@ def sliced_wasserstein_distance(
Seed used for random number generator
log: bool, optional
if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
scaler: None, object with .transform(), or callable, optional
Preprocessing applied to X_s and X_t before computing the distance.
Useful for normalizing inputs when features have very different scales.

- ``None`` : no preprocessing (default)
- Object with ``.transform()`` method : e.g. an :class:`ot.utils.DataScaler`
fitted on a representative sample. This is the recommended way to get
stable, consistent normalization across multiple calls (e.g. when
using SWD as a loss in mini-batch training).
- Callable : any function, lambda, or PyTorch transform applied
directly as ``scaler(X_s)`` and ``scaler(X_t)``.

See :class:`ot.utils.DataScaler` for a backend-aware scaler that supports
joint fitting on multiple distributions.

Returns
-------
Expand Down Expand Up @@ -136,6 +151,8 @@ def sliced_wasserstein_distance(

nx = get_backend(X_s, X_t, a, b, projections)

X_s, X_t = apply_scaler(X_s, X_t, scaler)

n = X_s.shape[0]
m = X_t.shape[0]

Expand Down Expand Up @@ -181,6 +198,7 @@ def max_sliced_wasserstein_distance(
projections=None,
seed=None,
log=False,
scaler=None,
):
r"""
Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance
Expand Down Expand Up @@ -215,6 +233,20 @@ def max_sliced_wasserstein_distance(
Seed used for random number generator
log: bool, optional
if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
scaler : None, object with .transform(), or callable, optional
Preprocessing applied to X_s and X_t before computing the distance.
Useful for normalizing inputs when features have very different scales.

- ``None`` : no preprocessing (default)
- Object with ``.transform()`` method : e.g. an :class:`ot.utils.DataScaler`
fitted on a representative sample. This is the recommended way to get
stable, consistent normalization across multiple calls (e.g. when
using SWD as a loss in mini-batch training).
- Callable : any function, lambda, or PyTorch transform applied
directly as ``scaler(X_s)`` and ``scaler(X_t)``.

See :class:`ot.utils.DataScaler` for a backend-aware scaler that supports
joint fitting on multiple distributions.

Returns
-------
Expand Down Expand Up @@ -242,6 +274,8 @@ def max_sliced_wasserstein_distance(

nx = get_backend(X_s, X_t, a, b, projections)

X_s, X_t = apply_scaler(X_s, X_t, scaler)

n = X_s.shape[0]
m = X_t.shape[0]

Expand Down
Loading
Loading