77import multiprocessing
88
99import time
10- __time_tic_toc = time .time ()
10+ __time_tic_toc = time .time ()
11+
1112
1213def tic ():
1314 """ Python implementation of Matlab tic() function """
1415 global __time_tic_toc
15- __time_tic_toc = time .time ()
16+ __time_tic_toc = time .time ()
17+
1618
1719def toc (message = 'Elapsed time : {} s' ):
1820 """ Python implementation of Matlab toc() function """
19- t = time .time ()
20- print (message .format (t - __time_tic_toc ))
21- return t - __time_tic_toc
21+ t = time .time ()
22+ print (message .format (t - __time_tic_toc ))
23+ return t - __time_tic_toc
24+
2225
2326def toq ():
2427 """ Python implementation of Julia toc() function """
25- t = time .time ()
26- return t - __time_tic_toc
28+ t = time .time ()
29+ return t - __time_tic_toc
2730
2831
29- def kernel (x1 ,x2 ,method = 'gaussian' ,sigma = 1 ,** kwargs ):
32+ def kernel (x1 , x2 , method = 'gaussian' , sigma = 1 , ** kwargs ):
3033 """Compute kernel matrix"""
31- if method .lower () in ['gaussian' ,'gauss' ,'rbf' ]:
32- K = np .exp (- dist (x1 ,x2 )/ ( 2 * sigma ** 2 ))
34+ if method .lower () in ['gaussian' , 'gauss' , 'rbf' ]:
35+ K = np .exp (- dist (x1 , x2 ) / ( 2 * sigma ** 2 ))
3336 return K
3437
38+
3539def unif (n ):
3640 """ return a uniform histogram of length n (simplex)
3741
@@ -48,17 +52,19 @@ def unif(n):
4852
4953
5054 """
51- return np .ones ((n ,))/ n
55+ return np .ones ((n ,)) / n
5256
53- def clean_zeros (a ,b ,M ):
54- """ Remove all components with zeros weights in a and b
57+
58+ def clean_zeros (a , b , M ):
59+ """ Remove all components with zeros weights in a and b
5560 """
56- M2 = M [a > 0 ,:][:,b > 0 ].copy () # copy force c style matrix (froemd)
57- a2 = a [a > 0 ]
58- b2 = b [b > 0 ]
59- return a2 ,b2 ,M2
61+ M2 = M [a > 0 , :][:, b > 0 ].copy () # copy force c style matrix (froemd)
62+ a2 = a [a > 0 ]
63+ b2 = b [b > 0 ]
64+ return a2 , b2 , M2
65+
6066
61- def dist (x1 ,x2 = None ,metric = 'sqeuclidean' ):
67+ def dist (x1 , x2 = None , metric = 'sqeuclidean' ):
6268 """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
6369
6470 Parameters
@@ -84,12 +90,12 @@ def dist(x1,x2=None,metric='sqeuclidean'):
8490
8591 """
8692 if x2 is None :
87- x2 = x1
93+ x2 = x1
8894
89- return cdist (x1 ,x2 ,metric = metric )
95+ return cdist (x1 , x2 , metric = metric )
9096
9197
92- def dist0 (n ,method = 'lin_square' ):
98+ def dist0 (n , method = 'lin_square' ):
9399 """Compute standard cost matrices of size (n,n) for OT problems
94100
95101 Parameters
@@ -111,16 +117,17 @@ def dist0(n,method='lin_square'):
111117
112118
113119 """
114- res = 0
115- if method == 'lin_square' :
116- x = np .arange (n ,dtype = np .float64 ).reshape ((n ,1 ))
117- res = dist (x ,x )
120+ res = 0
121+ if method == 'lin_square' :
122+ x = np .arange (n , dtype = np .float64 ).reshape ((n , 1 ))
123+ res = dist (x , x )
118124 return res
119125
120126
121127def dots (* args ):
122128 """ dots function for multiple matrix multiply """
123- return reduce (np .dot ,args )
129+ return reduce (np .dot , args )
130+
124131
125132def fun (f , q_in , q_out ):
126133 """ Utility function for parmap with no serializing problems """
@@ -130,6 +137,7 @@ def fun(f, q_in, q_out):
130137 break
131138 q_out .put ((i , f (x )))
132139
140+
133141def parmap (f , X , nprocs = multiprocessing .cpu_count ()):
134142 """ paralell map for multiprocessing """
135143 q_in = multiprocessing .Queue (1 )
@@ -147,4 +155,4 @@ def parmap(f, X, nprocs=multiprocessing.cpu_count()):
147155
148156 [p .join () for p in proc ]
149157
150- return [x for i , x in sorted (res )]
158+ return [x for i , x in sorted (res )]
0 commit comments