Skip to content

Commit f33087d

Browse files
committed
doc utils.py
1 parent c418ef4 commit f33087d

File tree

1 file changed

+46
-5
lines changed

1 file changed

+46
-5
lines changed

ot/utils.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,62 @@
66

77

88
def unif(n):
9-
""" return a uniform histogram (simplex) """
9+
""" return a uniform histogram of length n (simplex) """
1010
return np.ones((n,))/n
1111

1212

1313
def dist(x1,x2=None,metric='sqeuclidean'):
14-
"""Compute distance between samples in x1 and x2"""
14+
"""Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist
15+
16+
Parameters
17+
----------
18+
19+
x1 : np.array (n1,d)
20+
matrix with n1 samples of size d
21+
x2 : np.array (n2,d), optional
22+
matrix with n2 samples of size d (if None then x2=x1)
23+
metric : str, fun, optional
24+
name of the metric to be computed (full list in the doc of scipy), If a string,
25+
the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’,
26+
‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’,
27+
‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’,
28+
‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’.
29+
30+
31+
Returns
32+
-------
33+
M : np.array (n1,n2)
34+
distance matrix computed with given metric
35+
36+
"""
1537
if x2 is None:
1638
return cdist(x1,x1,metric=metric)
1739
else:
1840
return cdist(x1,x2,metric=metric)
1941

20-
def dist0(n,method='linear'):
21-
"""Compute stardard cos matrices for OT problems"""
42+
def dist0(n,method='lin_square'):
43+
"""Compute standard cost matrices of size (n,n) for OT problems
44+
45+
Parameters
46+
----------
47+
48+
n : int
49+
size of the cost matrix
50+
method : str, optional
51+
Type of loss matrix chosen from:
52+
53+
* 'lin_square' : linear sampling between 0 and n-1, quadratic loss
54+
55+
56+
Returns
57+
-------
58+
M : np.array (n1,n2)
59+
distance matrix computed with given metric
60+
61+
62+
"""
2263
res=0
23-
if method=='linear':
64+
if method=='lin_square':
2465
x=np.arange(n,dtype=np.float64).reshape((n,1))
2566
res=dist(x,x)
2667
return res

0 commit comments

Comments
 (0)