1313
1414from .bregman import sinkhorn
1515from .lp import emd
16- from .utils import unif , dist , kernel
16+ from .utils import unif , dist , kernel , cost_normalization
1717from .utils import check_params , deprecated , BaseEstimator
1818from .optim import cg
1919from .optim import gcg
@@ -673,7 +673,7 @@ def fit(self, xs, xt, ws=None, wt=None, norm=None, max_iter=100000):
673673 self .wt = wt
674674
675675 self .M = dist (xs , xt , metric = self .metric )
676- self .normalizeM ( norm )
676+ self .M = cost_normalization ( self . M , norm )
677677 self .G = emd (ws , wt , self .M , max_iter )
678678 self .computed = True
679679
@@ -741,26 +741,6 @@ def predict(self, x, direction=1):
741741 # aply the delta to the interpolation
742742 return xf [idx , :] + x - x0 [idx , :]
743743
744- def normalizeM (self , norm ):
745- """ Apply normalization to the loss matrix
746-
747-
748- Parameters
749- ----------
750- norm : str
751- type of normalization from 'median','max','log','loglog'
752-
753- """
754-
755- if norm == "median" :
756- self .M /= float (np .median (self .M ))
757- elif norm == "max" :
758- self .M /= float (np .max (self .M ))
759- elif norm == "log" :
760- self .M = np .log (1 + self .M )
761- elif norm == "loglog" :
762- self .M = np .log (1 + np .log (1 + self .M ))
763-
764744
765745@deprecated ("The class OTDA_sinkhorn is deprecated in 0.3.1 and will be"
766746 " removed in 0.5 \n Use class SinkhornTransport instead." )
@@ -787,7 +767,7 @@ def fit(self, xs, xt, reg=1, ws=None, wt=None, norm=None, **kwargs):
787767 self .wt = wt
788768
789769 self .M = dist (xs , xt , metric = self .metric )
790- self .normalizeM ( norm )
770+ self .M = cost_normalization ( self . M , norm )
791771 self .G = sinkhorn (ws , wt , self .M , reg , ** kwargs )
792772 self .computed = True
793773
@@ -816,7 +796,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
816796 self .wt = wt
817797
818798 self .M = dist (xs , xt , metric = self .metric )
819- self .normalizeM ( norm )
799+ self .M = cost_normalization ( self . M , norm )
820800 self .G = sinkhorn_lpl1_mm (ws , ys , wt , self .M , reg , eta , ** kwargs )
821801 self .computed = True
822802
@@ -845,7 +825,7 @@ def fit(self, xs, ys, xt, reg=1, eta=1, ws=None, wt=None, norm=None,
845825 self .wt = wt
846826
847827 self .M = dist (xs , xt , metric = self .metric )
848- self .normalizeM ( norm )
828+ self .M = cost_normalization ( self . M , norm )
849829 self .G = sinkhorn_l1l2_gl (ws , ys , wt , self .M , reg , eta , ** kwargs )
850830 self .computed = True
851831
@@ -1001,7 +981,7 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
1001981
1002982 # pairwise distance
1003983 self .cost_ = dist (Xs , Xt , metric = self .metric )
1004- self .normalizeCost_ ( self .norm )
984+ self .cost_ = cost_normalization ( self . cost_ , self .norm )
1005985
1006986 if (ys is not None ) and (yt is not None ):
1007987
@@ -1183,26 +1163,6 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11831163
11841164 return transp_Xt
11851165
1186- def normalizeCost_ (self , norm ):
1187- """ Apply normalization to the loss matrix
1188-
1189-
1190- Parameters
1191- ----------
1192- norm : str
1193- type of normalization from 'median','max','log','loglog'
1194-
1195- """
1196-
1197- if norm == "median" :
1198- self .cost_ /= float (np .median (self .cost_ ))
1199- elif norm == "max" :
1200- self .cost_ /= float (np .max (self .cost_ ))
1201- elif norm == "log" :
1202- self .cost_ = np .log (1 + self .cost_ )
1203- elif norm == "loglog" :
1204- self .cost_ = np .log (1 + np .log (1 + self .cost_ ))
1205-
12061166
12071167class SinkhornTransport (BaseTransport ):
12081168 """Domain Adapatation OT method based on Sinkhorn Algorithm
0 commit comments