1414from .bregman import sinkhorn
1515from .lp import emd
1616from .utils import unif , dist , kernel
17- from .utils import deprecated , BaseEstimator
17+ from .utils import check_params , deprecated , BaseEstimator
1818from .optim import cg
1919from .optim import gcg
2020
@@ -954,6 +954,26 @@ def distribution_estimation_uniform(X):
954954
955955
956956class BaseTransport (BaseEstimator ):
957+ """Base class for OTDA objects
958+
959+ Notes
960+ -----
961+ All estimators should specify all the parameters that can be set
962+ at the class level in their ``__init__`` as explicit keyword
963+ arguments (no ``*args`` or ``**kwargs``).
964+
965+ fit method should:
966+ - estimate a cost matrix and store it in a `cost_` attribute
967+ - estimate a coupling matrix and store it in a `coupling_`
968+ attribute
969+ - estimate distributions from source and target data and store them in
970+ mu_s and mu_t attributes
971+ - store Xs and Xt in attributes to be used later on in transform and
972+ inverse_transform methods
973+
974+ The transform method should always take an Xs parameter as input.
975+ The inverse_transform method should always take an Xt parameter as input.
976+ """
957977
958978 def fit (self , Xs = None , ys = None , Xt = None , yt = None ):
959979 """Build a coupling matrix from source and target sets of samples
@@ -976,7 +996,9 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
976996 Returns self.
977997 """
978998
979- if Xs is not None and Xt is not None :
999+ # check that the necessary input parameters are provided
1000+ if check_params (Xs = Xs , Xt = Xt ):
1001+
9801002 # pairwise distance
9811003 self .cost_ = dist (Xs , Xt , metric = self .metric )
9821004
@@ -1003,14 +1025,10 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
10031025 self .mu_t = self .distribution_estimation (Xt )
10041026
10051027 # store arrays of samples
1006- self .Xs = Xs
1007- self .Xt = Xt
1028+ self .xs_ = Xs
1029+ self .xt_ = Xt
10081030
1009- return self
1010- else :
1011- print ("POT-Warning" )
1012- print ("Please provide both Xs and Xt arguments when calling" )
1013- print ("fit method" )
1031+ return self
10141032
10151033 def fit_transform (self , Xs = None , ys = None , Xt = None , yt = None ):
10161034 """Build a coupling matrix from source and target sets of samples
@@ -1058,16 +1076,19 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
10581076 The transport source samples.
10591077 """
10601078
1061- if Xs is not None :
1062- if np .array_equal (self .Xs , Xs ):
1079+ # check that the necessary input parameters are provided
1080+ if check_params (Xs = Xs ):
1081+
1082+ if np .array_equal (self .xs_ , Xs ):
1083+
10631084 # perform standard barycentric mapping
10641085 transp = self .coupling_ / np .sum (self .coupling_ , 1 )[:, None ]
10651086
10661087 # set nans to 0
10671088 transp [~ np .isfinite (transp )] = 0
10681089
10691090 # compute transported samples
1070- transp_Xs = np .dot (transp , self .Xt )
1091+ transp_Xs = np .dot (transp , self .xt_ )
10711092 else :
10721093 # perform out of sample mapping
10731094 indices = np .arange (Xs .shape [0 ])
@@ -1079,26 +1100,23 @@ def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128):
10791100 for bi in batch_ind :
10801101
10811102 # get the nearest neighbor in the source domain
1082- D0 = dist (Xs [bi ], self .Xs )
1103+ D0 = dist (Xs [bi ], self .xs_ )
10831104 idx = np .argmin (D0 , axis = 1 )
10841105
10851106 # transport the source samples
10861107 transp = self .coupling_ / np .sum (
10871108 self .coupling_ , 1 )[:, None ]
10881109 transp [~ np .isfinite (transp )] = 0
1089- transp_Xs_ = np .dot (transp , self .Xt )
1110+ transp_Xs_ = np .dot (transp , self .xt_ )
10901111
10911112 # define the transported points
1092- transp_Xs_ = transp_Xs_ [idx , :] + Xs [bi ] - self .Xs [idx , :]
1113+ transp_Xs_ = transp_Xs_ [idx , :] + Xs [bi ] - self .xs_ [idx , :]
10931114
10941115 transp_Xs .append (transp_Xs_ )
10951116
10961117 transp_Xs = np .concatenate (transp_Xs , axis = 0 )
10971118
10981119 return transp_Xs
1099- else :
1100- print ("POT-Warning" )
1101- print ("Please provide Xs argument when calling transform method" )
11021120
11031121 def inverse_transform (self , Xs = None , ys = None , Xt = None , yt = None ,
11041122 batch_size = 128 ):
@@ -1123,16 +1141,19 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11231141 The transported target samples.
11241142 """
11251143
1126- if Xt is not None :
1127- if np .array_equal (self .Xt , Xt ):
1144+ # check that the necessary input parameters are provided
1145+ if check_params (Xt = Xt ):
1146+
1147+ if np .array_equal (self .xt_ , Xt ):
1148+
11281149 # perform standard barycentric mapping
11291150 transp_ = self .coupling_ .T / np .sum (self .coupling_ , 0 )[:, None ]
11301151
11311152 # set nans to 0
11321153 transp_ [~ np .isfinite (transp_ )] = 0
11331154
11341155 # compute transported samples
1135- transp_Xt = np .dot (transp_ , self .Xs )
1156+ transp_Xt = np .dot (transp_ , self .xs_ )
11361157 else :
11371158 # perform out of sample mapping
11381159 indices = np .arange (Xt .shape [0 ])
@@ -1143,26 +1164,23 @@ def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None,
11431164 transp_Xt = []
11441165 for bi in batch_ind :
11451166
1146- D0 = dist (Xt [bi ], self .Xt )
1167+ D0 = dist (Xt [bi ], self .xt_ )
11471168 idx = np .argmin (D0 , axis = 1 )
11481169
11491170 # transport the target samples
11501171 transp_ = self .coupling_ .T / np .sum (
11511172 self .coupling_ , 0 )[:, None ]
11521173 transp_ [~ np .isfinite (transp_ )] = 0
1153- transp_Xt_ = np .dot (transp_ , self .Xs )
1174+ transp_Xt_ = np .dot (transp_ , self .xs_ )
11541175
11551176 # define the transported points
1156- transp_Xt_ = transp_Xt_ [idx , :] + Xt [bi ] - self .Xt [idx , :]
1177+ transp_Xt_ = transp_Xt_ [idx , :] + Xt [bi ] - self .xt_ [idx , :]
11571178
11581179 transp_Xt .append (transp_Xt_ )
11591180
11601181 transp_Xt = np .concatenate (transp_Xt , axis = 0 )
11611182
11621183 return transp_Xt
1163- else :
1164- print ("POT-Warning" )
1165- print ("Please provide Xt argument when calling inverse_transform" )
11661184
11671185
11681186class SinkhornTransport (BaseTransport ):
@@ -1428,7 +1446,8 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
14281446 Returns self.
14291447 """
14301448
1431- if Xs is not None and Xt is not None and ys is not None :
1449+ # check that the necessary input parameters are provided
1450+ if check_params (Xs = Xs , Xt = Xt , ys = ys ):
14321451
14331452 super (SinkhornLpl1Transport , self ).fit (Xs , ys , Xt , yt )
14341453
@@ -1438,10 +1457,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
14381457 numInnerItermax = self .max_inner_iter , stopInnerThr = self .tol ,
14391458 verbose = self .verbose )
14401459
1441- return self
1442- else :
1443- print ("POT-Warning" )
1444- print ("Please provide both Xs, Xt, ys arguments to fit method" )
1460+ return self
14451461
14461462
14471463class SinkhornL1l2Transport (BaseTransport ):
@@ -1537,7 +1553,8 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
15371553 Returns self.
15381554 """
15391555
1540- if Xs is not None and Xt is not None and ys is not None :
1556+ # check that the necessary input parameters are provided
1557+ if check_params (Xs = Xs , Xt = Xt , ys = ys ):
15411558
15421559 super (SinkhornL1l2Transport , self ).fit (Xs , ys , Xt , yt )
15431560
@@ -1554,10 +1571,7 @@ def fit(self, Xs, ys=None, Xt=None, yt=None):
15541571 self .coupling_ = returned_
15551572 self .log_ = dict ()
15561573
1557- return self
1558- else :
1559- print ("POT-Warning" )
1560- print ("Please, provide both Xs, Xt and ys argument to fit method" )
1574+ return self
15611575
15621576
15631577class MappingTransport (BaseEstimator ):
@@ -1652,29 +1666,35 @@ def fit(self, Xs=None, ys=None, Xt=None, yt=None):
16521666 Returns self
16531667 """
16541668
1655- self .Xs = Xs
1656- self .Xt = Xt
1657-
1658- if self .kernel == "linear" :
1659- returned_ = joint_OT_mapping_linear (
1660- Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1661- verbose = self .verbose , verbose2 = self .verbose2 ,
1662- numItermax = self .max_iter , numInnerItermax = self .max_inner_iter ,
1663- stopThr = self .tol , stopInnerThr = self .inner_tol , log = self .log )
1669+ # check that the necessary input parameters are provided
1670+ if check_params (Xs = Xs , Xt = Xt ):
1671+
1672+ self .xs_ = Xs
1673+ self .xt_ = Xt
1674+
1675+ if self .kernel == "linear" :
1676+ returned_ = joint_OT_mapping_linear (
1677+ Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1678+ verbose = self .verbose , verbose2 = self .verbose2 ,
1679+ numItermax = self .max_iter ,
1680+ numInnerItermax = self .max_inner_iter , stopThr = self .tol ,
1681+ stopInnerThr = self .inner_tol , log = self .log )
1682+
1683+ elif self .kernel == "gaussian" :
1684+ returned_ = joint_OT_mapping_kernel (
1685+ Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1686+ sigma = self .sigma , verbose = self .verbose ,
1687+ verbose2 = self .verbose , numItermax = self .max_iter ,
1688+ numInnerItermax = self .max_inner_iter ,
1689+ stopInnerThr = self .inner_tol , stopThr = self .tol ,
1690+ log = self .log )
16641691
1665- elif self .kernel == "gaussian" :
1666- returned_ = joint_OT_mapping_kernel (
1667- Xs , Xt , mu = self .mu , eta = self .eta , bias = self .bias ,
1668- sigma = self .sigma , verbose = self .verbose , verbose2 = self .verbose ,
1669- numItermax = self .max_iter , numInnerItermax = self .max_inner_iter ,
1670- stopInnerThr = self .inner_tol , stopThr = self .tol , log = self .log )
1671-
1672- # deal with the value of log
1673- if self .log :
1674- self .coupling_ , self .mapping_ , self .log_ = returned_
1675- else :
1676- self .coupling_ , self .mapping_ = returned_
1677- self .log_ = dict ()
1692+ # deal with the value of log
1693+ if self .log :
1694+ self .coupling_ , self .mapping_ , self .log_ = returned_
1695+ else :
1696+ self .coupling_ , self .mapping_ = returned_
1697+ self .log_ = dict ()
16781698
16791699 return self
16801700
@@ -1692,22 +1712,26 @@ def transform(self, Xs):
16921712 The transport source samples.
16931713 """
16941714
1695- if np .array_equal (self .Xs , Xs ):
1696- # perform standard barycentric mapping
1697- transp = self .coupling_ / np .sum (self .coupling_ , 1 )[:, None ]
1715+ # check that the necessary input parameters are provided
1716+ if check_params (Xs = Xs ):
16981717
1699- # set nans to 0
1700- transp [~ np .isfinite (transp )] = 0
1718+ if np .array_equal (self .xs_ , Xs ):
1719+ # perform standard barycentric mapping
1720+ transp = self .coupling_ / np .sum (self .coupling_ , 1 )[:, None ]
17011721
1702- # compute transported samples
1703- transp_Xs = np .dot (transp , self .Xt )
1704- else :
1705- if self .kernel == "gaussian" :
1706- K = kernel (Xs , self .Xs , method = self .kernel , sigma = self .sigma )
1707- elif self .kernel == "linear" :
1708- K = Xs
1709- if self .bias :
1710- K = np .hstack ((K , np .ones ((Xs .shape [0 ], 1 ))))
1711- transp_Xs = K .dot (self .mapping_ )
1722+ # set nans to 0
1723+ transp [~ np .isfinite (transp )] = 0
17121724
1713- return transp_Xs
1725+ # compute transported samples
1726+ transp_Xs = np .dot (transp , self .xt_ )
1727+ else :
1728+ if self .kernel == "gaussian" :
1729+ K = kernel (Xs , self .xs_ , method = self .kernel ,
1730+ sigma = self .sigma )
1731+ elif self .kernel == "linear" :
1732+ K = Xs
1733+ if self .bias :
1734+ K = np .hstack ((K , np .ones ((Xs .shape [0 ], 1 ))))
1735+ transp_Xs = K .dot (self .mapping_ )
1736+
1737+ return transp_Xs
0 commit comments