[WIP] Add a stablized function entropic_partial_wasserstein_logscale#724
[WIP] Add a stablized function entropic_partial_wasserstein_logscale#724wzm2256 wants to merge 5 commits into
Conversation
ot.partial. This function solves the same problem as entropic\_partial\_wasserstein but is computed in logscale, so it is more robust. 2. Test exampless are provided in compare_logscale_POT.py. Test data is in data\entropic_partial_OT_cost.txt
|
Hello @wzm2256 and thanks for the PR we are indeed missing a stabilized partial entropic ot solver. When we add a stabilized solver, we usually add the function and make the original function as a wrapper when we add the log implementation (see how it is done for function ot.sinkhorn ) so that with an additional parameter Finally your example is nice we usually try to add a visualization and we really try not to add raw data to the git. Could you design a simulated example instead of a new dataset? Thanks again for your work, we will do a proper code review as soon as possible but wold prefer to do that after the comments above are taken into account. |
|
Sure! I will work on that. |
|
Hello @wzm2256 , quick reminder that we are waiting for a few changes here. |
|
I tested the current implementation, for some reason its extremely slow at the first iteration itself. |
|
@wzm2256 @rflamary @cedricvincentcuaz — hello! I'm a downstream user of Why a separate PR instead of pushing here: I'm happy to do whichever you prefer:
Either way, the goal is to get the function in so issue #723 closes. Let me know which path you want. |
… (rescue of #724) (#811) Re-applies the function from PR #724 by wzm2256 on top of current master (the original PR is stuck at CONFLICTING since 2025-09; this takes the additive parts and skips the obsolete merges through the March-2025 single-file layout). Subject is [WIP] because the original PR is also [WIP] and maintainer review is still required. Changes vs master: ot/partial/partial_solvers.py + entropic_partial_wasserstein_logscale (function body verbatim from PR #724 modulo: (a) duplicate sphinx label removed to avoid build failure, (b) print warning -> warnings.warn(stacklevel=2) for convention). ot/partial/__init__.py + entropic_partial_wasserstein_logscale export test/test_partial.py + test_entropic_partial_wasserstein_logscale_matches_old_at_large_reg (machine-precision agreement at reg in {10.0, 1.0}: atol=1e-10) + test_entropic_partial_wasserstein_logscale_no_nan_at_small_reg (parametrised over reg in {0.1, 0.05, 0.01, 5e-3, 1e-3, 5e-4}) + test_entropic_partial_wasserstein_logscale_approaches_exact_at_small_reg (plan-cost gap vs exact partial OT at reg=1e-3) + test_entropic_partial_wasserstein_logscale_log_dict + test_entropic_partial_wasserstein_logscale_input_validation examples/unbalanced-partial/plot_entropic_partial_wasserstein_logscale.py + Sphinx-Gallery example reproducing issue #723 + the fix (MPLBACKEND=Agg-safe; narrative softened to acknowledge BLAS/platform-dependent underflow boundary). docs/source/user_guide.rst + one-paragraph mention next to entropic_partial_wasserstein. RELEASES.md + entry under 0.9.7.dev0 (phrased as "mitigated via new log-domain variant", not "fixed", since the standard solver itself is unchanged). Verified locally on master at 41a4d57: pytest test/ -> 1939 passed, 97 skipped, 6 xfailed (no regressions). pytest test/test_partial.py -> 19 passed (8 originals + 11 new parametrised cases for the logscale function). Example script runs end-to-end with MPLBACKEND=Agg. The new function agrees with the standard solver at reg >= 1.0 to ~1e-18 absolute (atol=1e-10 in tests is conservative) and stays finite at reg down to 5e-4 on a 50x50 cost-scale-~50 problem (the exact failure mode of issue #723); std solver returns NaN at reg ~ 0.05-0.01 on the same problem. References Issue #723. Maintainer review needed before merge — author attribution to wzm2256 retained via Co-authored-by trailer. Co-authored-by: wzm2256 <wzm2256@qq.com>
|
Closing this because #811 was merged |
add a new function called entropic_partial_wasserstein_logscale to ot.partial. This function solves the same problem as entropic_partial_wasserstein but is computed in logscale, so it is more robust.
Test exampless are provided in compare_logscale_POT.py. Test data is in data\entropic_partial_OT_cost.txt
Types of changes
I implement a new function
entropic_partial_wasserstein_logscalethat solves exactly the same problem as the one inentropic_partial_wassersteinin log scale. The new function is a line-to-line translation of the old one, and the input/output format is exactly the same.I do not remove the old function because the new function can be slower due to the use of the logsumexp trick. So when there is no Nan error, the old function is favored.
Motivation and context / Related issue
#723
How has this been tested (if it applies)
I test the new function
entropic_partial_wasserstein_logscaleagainst the old oneentropic_partial_wassersteinin the example file 'compare_logscale_POT.py` for both numpy and pytorch.PR checklist
I could not build the document in my laptop due to some errors:
so I am not completely sure whether the document is fine, although I only added a few sentences to the docs.
Also, I do not know how to use pytest to test my code. If this is necessary, I may need some help here.