@@ -45,8 +45,8 @@ def test_sinkhorn_lpl1_transport_class():
4545 Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
4646 transp_Xs_new = clf .transform (Xs_new )
4747
48- # check that the oos method is not working
49- assert_equal (transp_Xs_new , Xs_new )
48+ # check that the oos method is working
49+ assert_equal (transp_Xs_new . shape , Xs_new . shape )
5050
5151 # test inverse transform
5252 transp_Xt = clf .inverse_transform (Xt = Xt )
@@ -55,8 +55,8 @@ def test_sinkhorn_lpl1_transport_class():
5555 Xt_new , _ = get_data_classif ('3gauss2' , nt + 1 )
5656 transp_Xt_new = clf .inverse_transform (Xt = Xt_new )
5757
58- # check that the oos method is not working and returns the input data
59- assert_equal (transp_Xt_new , Xt_new )
58+ # check that the oos method is working
59+ assert_equal (transp_Xt_new . shape , Xt_new . shape )
6060
6161 # test fit_transform
6262 transp_Xs = clf .fit_transform (Xs = Xs , ys = ys , Xt = Xt )
@@ -108,8 +108,8 @@ def test_sinkhorn_l1l2_transport_class():
108108 Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
109109 transp_Xs_new = clf .transform (Xs_new )
110110
111- # check that the oos method is not working
112- assert_equal (transp_Xs_new , Xs_new )
111+ # check that the oos method is working
112+ assert_equal (transp_Xs_new . shape , Xs_new . shape )
113113
114114 # test inverse transform
115115 transp_Xt = clf .inverse_transform (Xt = Xt )
@@ -118,8 +118,8 @@ def test_sinkhorn_l1l2_transport_class():
118118 Xt_new , _ = get_data_classif ('3gauss2' , nt + 1 )
119119 transp_Xt_new = clf .inverse_transform (Xt = Xt_new )
120120
121- # check that the oos method is not working and returns the input data
122- assert_equal (transp_Xt_new , Xt_new )
121+ # check that the oos method is working
122+ assert_equal (transp_Xt_new . shape , Xt_new . shape )
123123
124124 # test fit_transform
125125 transp_Xs = clf .fit_transform (Xs = Xs , ys = ys , Xt = Xt )
@@ -171,8 +171,8 @@ def test_sinkhorn_transport_class():
171171 Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
172172 transp_Xs_new = clf .transform (Xs_new )
173173
174- # check that the oos method is not working
175- assert_equal (transp_Xs_new , Xs_new )
174+ # check that the oos method is working
175+ assert_equal (transp_Xs_new . shape , Xs_new . shape )
176176
177177 # test inverse transform
178178 transp_Xt = clf .inverse_transform (Xt = Xt )
@@ -181,8 +181,8 @@ def test_sinkhorn_transport_class():
181181 Xt_new , _ = get_data_classif ('3gauss2' , nt + 1 )
182182 transp_Xt_new = clf .inverse_transform (Xt = Xt_new )
183183
184- # check that the oos method is not working and returns the input data
185- assert_equal (transp_Xt_new , Xt_new )
184+ # check that the oos method is working
185+ assert_equal (transp_Xt_new . shape , Xt_new . shape )
186186
187187 # test fit_transform
188188 transp_Xs = clf .fit_transform (Xs = Xs , Xt = Xt )
@@ -234,8 +234,8 @@ def test_emd_transport_class():
234234 Xs_new , _ = get_data_classif ('3gauss' , ns + 1 )
235235 transp_Xs_new = clf .transform (Xs_new )
236236
237- # check that the oos method is not working
238- assert_equal (transp_Xs_new , Xs_new )
237+ # check that the oos method is working
238+ assert_equal (transp_Xs_new . shape , Xs_new . shape )
239239
240240 # test inverse transform
241241 transp_Xt = clf .inverse_transform (Xt = Xt )
@@ -244,8 +244,8 @@ def test_emd_transport_class():
244244 Xt_new , _ = get_data_classif ('3gauss2' , nt + 1 )
245245 transp_Xt_new = clf .inverse_transform (Xt = Xt_new )
246246
247- # check that the oos method is not working and returns the input data
248- assert_equal (transp_Xt_new , Xt_new )
247+ # check that the oos method is working
248+ assert_equal (transp_Xt_new . shape , Xt_new . shape )
249249
250250 # test fit_transform
251251 transp_Xs = clf .fit_transform (Xs = Xs , Xt = Xt )
0 commit comments