@@ -264,6 +264,112 @@ def test_emd_transport_class():
264264 assert n_unsup != n_semisup , "semisupervised mode not working"
265265
266266
267+ def test_mapping_transport_class ():
268+ """test_mapping_transport
269+ """
270+
271+ ns = 150
272+ nt = 200
273+
274+ Xs , ys = get_data_classif ('3gauss' , ns )
275+ Xt , yt = get_data_classif ('3gauss2' , nt )
276+ Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
277+
278+ ##########################################################################
279+ # kernel == linear mapping tests
280+ ##########################################################################
281+
282+ # check computation and dimensions if bias == False
283+ clf = ot .da .MappingTransport (kernel = "linear" , bias = False )
284+ clf .fit (Xs = Xs , Xt = Xt )
285+
286+ assert_equal (clf .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
287+ assert_equal (clf .mapping_ .shape , ((Xs .shape [1 ], Xt .shape [1 ])))
288+
289+ # test margin constraints
290+ mu_s = unif (ns )
291+ mu_t = unif (nt )
292+ assert_allclose (np .sum (clf .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
293+ assert_allclose (np .sum (clf .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
294+
295+ # test transform
296+ transp_Xs = clf .transform (Xs = Xs )
297+ assert_equal (transp_Xs .shape , Xs .shape )
298+
299+ transp_Xs_new = clf .transform (Xs_new )
300+
301+ # check that the oos method is working
302+ assert_equal (transp_Xs_new .shape , Xs_new .shape )
303+
304+ # check computation and dimensions if bias == True
305+ clf = ot .da .MappingTransport (kernel = "linear" , bias = True )
306+ clf .fit (Xs = Xs , Xt = Xt )
307+ assert_equal (clf .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
308+ assert_equal (clf .mapping_ .shape , ((Xs .shape [1 ] + 1 , Xt .shape [1 ])))
309+
310+ # test margin constraints
311+ mu_s = unif (ns )
312+ mu_t = unif (nt )
313+ assert_allclose (np .sum (clf .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
314+ assert_allclose (np .sum (clf .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
315+
316+ # test transform
317+ transp_Xs = clf .transform (Xs = Xs )
318+ assert_equal (transp_Xs .shape , Xs .shape )
319+
320+ transp_Xs_new = clf .transform (Xs_new )
321+
322+ # check that the oos method is working
323+ assert_equal (transp_Xs_new .shape , Xs_new .shape )
324+
325+ ##########################################################################
326+ # kernel == gaussian mapping tests
327+ ##########################################################################
328+
329+ # check computation and dimensions if bias == False
330+ clf = ot .da .MappingTransport (kernel = "gaussian" , bias = False )
331+ clf .fit (Xs = Xs , Xt = Xt )
332+
333+ assert_equal (clf .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
334+ assert_equal (clf .mapping_ .shape , ((Xs .shape [0 ], Xt .shape [1 ])))
335+
336+ # test margin constraints
337+ mu_s = unif (ns )
338+ mu_t = unif (nt )
339+ assert_allclose (np .sum (clf .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
340+ assert_allclose (np .sum (clf .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
341+
342+ # test transform
343+ transp_Xs = clf .transform (Xs = Xs )
344+ assert_equal (transp_Xs .shape , Xs .shape )
345+
346+ transp_Xs_new = clf .transform (Xs_new )
347+
348+ # check that the oos method is working
349+ assert_equal (transp_Xs_new .shape , Xs_new .shape )
350+
351+ # check computation and dimensions if bias == True
352+ clf = ot .da .MappingTransport (kernel = "gaussian" , bias = True )
353+ clf .fit (Xs = Xs , Xt = Xt )
354+ assert_equal (clf .coupling_ .shape , ((Xs .shape [0 ], Xt .shape [0 ])))
355+ assert_equal (clf .mapping_ .shape , ((Xs .shape [0 ] + 1 , Xt .shape [1 ])))
356+
357+ # test margin constraints
358+ mu_s = unif (ns )
359+ mu_t = unif (nt )
360+ assert_allclose (np .sum (clf .coupling_ , axis = 0 ), mu_t , rtol = 1e-3 , atol = 1e-3 )
361+ assert_allclose (np .sum (clf .coupling_ , axis = 1 ), mu_s , rtol = 1e-3 , atol = 1e-3 )
362+
363+ # test transform
364+ transp_Xs = clf .transform (Xs = Xs )
365+ assert_equal (transp_Xs .shape , Xs .shape )
366+
367+ transp_Xs_new = clf .transform (Xs_new )
368+
369+ # check that the oos method is working
370+ assert_equal (transp_Xs_new .shape , Xs_new .shape )
371+
372+
267373def test_otda ():
268374
269375 n_samples = 150 # nb samples
@@ -326,9 +432,10 @@ def test_otda():
326432 da_emd .predict (xs ) # interpolation of source samples
327433
328434
329- # if __name__ == "__main__":
435+ if __name__ == "__main__" :
330436
331- # test_sinkhorn_transport_class()
332- # test_emd_transport_class()
333- # test_sinkhorn_l1l2_transport_class()
334- # test_sinkhorn_lpl1_transport_class()
437+ # test_sinkhorn_transport_class()
438+ # test_emd_transport_class()
439+ # test_sinkhorn_l1l2_transport_class()
440+ # test_sinkhorn_lpl1_transport_class()
441+ test_mapping_transport_class ()
0 commit comments