Commit b9ed7b1
[MRG] Make partial_wasserstein, partial_wasserstein2 and entropic_partial_wasserstein work with backend (#449)
* add test of partial_wasserstein with torch tensors
* WIP: differentiable ot.partial.partial_wasserstein
* change test of torch partial
* make partial_wasserstein2 work with torch
* test backward through ot.partial.partial_wasserstein2
* add test of entropic_partial_wasserstein with torch tensors
* make entropic_partial_wasserstein work with torch tensors
* add test of backward through entropic_partial_wasserstein
* rm unused import
* test partial_wasserstein with all backends
* tests of partial fcts: check if torch is available
* partial: check if marginals are empty arrays
* add tests when marginals are empty arrays and/or m=None
* add PR to RELEASES.md
---------
Co-authored-by: Antoine Collas <22830806+antoinecollas@users.noreply.github.com>
Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>1 parent c48cd76 commit b9ed7b1
3 files changed
+148
-62
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
17 | | - | |
| 17 | + | |
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
120 | 120 | | |
121 | 121 | | |
122 | 122 | | |
123 | | - | |
| 123 | + | |
124 | 124 | | |
125 | 125 | | |
126 | 126 | | |
| |||
270 | 270 | | |
271 | 271 | | |
272 | 272 | | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
273 | 279 | | |
274 | 280 | | |
275 | 281 | | |
276 | 282 | | |
277 | 283 | | |
278 | | - | |
| 284 | + | |
279 | 285 | | |
280 | 286 | | |
281 | 287 | | |
282 | | - | |
283 | | - | |
284 | | - | |
285 | | - | |
286 | | - | |
287 | | - | |
288 | | - | |
289 | | - | |
290 | | - | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
291 | 298 | | |
292 | 299 | | |
293 | 300 | | |
294 | 301 | | |
295 | | - | |
| 302 | + | |
296 | 303 | | |
297 | 304 | | |
298 | 305 | | |
299 | 306 | | |
300 | | - | |
301 | | - | |
302 | | - | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
303 | 310 | | |
304 | 311 | | |
305 | 312 | | |
| |||
389 | 396 | | |
390 | 397 | | |
391 | 398 | | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
392 | 403 | | |
393 | 404 | | |
394 | 405 | | |
395 | 406 | | |
396 | 407 | | |
397 | | - | |
| 408 | + | |
398 | 409 | | |
399 | | - | |
| 410 | + | |
400 | 411 | | |
401 | 412 | | |
402 | 413 | | |
| |||
838 | 849 | | |
839 | 850 | | |
840 | 851 | | |
841 | | - | |
842 | | - | |
843 | | - | |
| 852 | + | |
| 853 | + | |
| 854 | + | |
844 | 855 | | |
845 | 856 | | |
846 | | - | |
847 | | - | |
| 857 | + | |
| 858 | + | |
848 | 859 | | |
849 | 860 | | |
850 | | - | |
| 861 | + | |
851 | 862 | | |
852 | | - | |
| 863 | + | |
853 | 864 | | |
854 | 865 | | |
855 | | - | |
| 866 | + | |
856 | 867 | | |
857 | 868 | | |
858 | 869 | | |
859 | | - | |
| 870 | + | |
860 | 871 | | |
861 | 872 | | |
862 | 873 | | |
863 | 874 | | |
864 | 875 | | |
865 | | - | |
866 | | - | |
867 | | - | |
868 | | - | |
869 | | - | |
| 876 | + | |
| 877 | + | |
| 878 | + | |
| 879 | + | |
| 880 | + | |
| 881 | + | |
| 882 | + | |
| 883 | + | |
| 884 | + | |
870 | 885 | | |
871 | 886 | | |
872 | | - | |
873 | | - | |
874 | | - | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
875 | 890 | | |
876 | 891 | | |
877 | 892 | | |
878 | 893 | | |
879 | | - | |
| 894 | + | |
880 | 895 | | |
881 | 896 | | |
882 | 897 | | |
883 | | - | |
| 898 | + | |
884 | 899 | | |
885 | 900 | | |
886 | 901 | | |
887 | | - | |
| 902 | + | |
888 | 903 | | |
889 | 904 | | |
890 | | - | |
| 905 | + | |
891 | 906 | | |
892 | 907 | | |
893 | 908 | | |
894 | | - | |
| 909 | + | |
895 | 910 | | |
896 | 911 | | |
897 | 912 | | |
| |||
901 | 916 | | |
902 | 917 | | |
903 | 918 | | |
904 | | - | |
| 919 | + | |
905 | 920 | | |
906 | 921 | | |
907 | 922 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
| 11 | + | |
11 | 12 | | |
12 | 13 | | |
13 | 14 | | |
| |||
82 | 83 | | |
83 | 84 | | |
84 | 85 | | |
85 | | - | |
| 86 | + | |
86 | 87 | | |
87 | 88 | | |
88 | 89 | | |
| |||
102 | 103 | | |
103 | 104 | | |
104 | 105 | | |
| 106 | + | |
| 107 | + | |
105 | 108 | | |
106 | | - | |
107 | | - | |
| 109 | + | |
108 | 110 | | |
109 | 111 | | |
110 | | - | |
111 | | - | |
112 | | - | |
113 | | - | |
114 | | - | |
115 | | - | |
116 | | - | |
117 | | - | |
| 112 | + | |
| 113 | + | |
| 114 | + | |
| 115 | + | |
118 | 116 | | |
119 | 117 | | |
120 | | - | |
121 | | - | |
122 | | - | |
123 | | - | |
| 118 | + | |
| 119 | + | |
124 | 120 | | |
125 | 121 | | |
126 | 122 | | |
| |||
130 | 126 | | |
131 | 127 | | |
132 | 128 | | |
133 | | - | |
134 | | - | |
135 | | - | |
136 | | - | |
137 | | - | |
138 | | - | |
| 129 | + | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
| 133 | + | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
| 146 | + | |
| 147 | + | |
| 148 | + | |
| 149 | + | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
| 158 | + | |
| 159 | + | |
| 160 | + | |
| 161 | + | |
| 162 | + | |
| 163 | + | |
| 164 | + | |
| 165 | + | |
| 166 | + | |
| 167 | + | |
| 168 | + | |
| 169 | + | |
| 170 | + | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
139 | 210 | | |
140 | 211 | | |
141 | 212 | | |
| |||
0 commit comments