|
6 | 6 |
|
7 | 7 |
|
8 | 8 | def unif(n): |
9 | | - """ return a uniform histogram (simplex) """ |
| 9 | + """ return a uniform histogram of length n (simplex) """ |
10 | 10 | return np.ones((n,))/n |
11 | 11 |
|
12 | 12 |
|
13 | 13 | def dist(x1,x2=None,metric='sqeuclidean'): |
14 | | - """Compute distance between samples in x1 and x2""" |
| 14 | + """Compute distance between samples in x1 and x2 using function scipy.spatial.distance.cdist |
| 15 | + |
| 16 | + Parameters |
| 17 | + ---------- |
| 18 | +
|
| 19 | + x1 : np.array (n1,d) |
| 20 | + matrix with n1 samples of size d |
| 21 | + x2 : np.array (n2,d), optional |
| 22 | + matrix with n2 samples of size d (if None then x2=x1) |
| 23 | + metric : str, fun, optional |
| 24 | + name of the metric to be computed (full list in the doc of scipy), If a string, |
| 25 | + the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’, |
| 26 | + ‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’, |
| 27 | + ‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’, |
| 28 | + ‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’. |
| 29 | +
|
| 30 | + |
| 31 | + Returns |
| 32 | + ------- |
| 33 | + M : np.array (n1,n2) |
| 34 | + distance matrix computed with given metric |
| 35 | + |
| 36 | + """ |
15 | 37 | if x2 is None: |
16 | 38 | return cdist(x1,x1,metric=metric) |
17 | 39 | else: |
18 | 40 | return cdist(x1,x2,metric=metric) |
19 | 41 |
|
20 | | -def dist0(n,method='linear'): |
21 | | - """Compute stardard cos matrices for OT problems""" |
| 42 | +def dist0(n,method='lin_square'): |
| 43 | + """Compute standard cost matrices of size (n,n) for OT problems |
| 44 | + |
| 45 | + Parameters |
| 46 | + ---------- |
| 47 | +
|
| 48 | + n : int |
| 49 | + size of the cost matrix |
| 50 | + method : str, optional |
| 51 | + Type of loss matrix chosen from: |
| 52 | +
|
| 53 | + * 'lin_square' : linear sampling between 0 and n-1, quadratic loss |
| 54 | +
|
| 55 | + |
| 56 | + Returns |
| 57 | + ------- |
| 58 | + M : np.array (n1,n2) |
| 59 | + distance matrix computed with given metric |
| 60 | + |
| 61 | + |
| 62 | + """ |
22 | 63 | res=0 |
23 | | - if method=='linear': |
| 64 | + if method=='lin_square': |
24 | 65 | x=np.arange(n,dtype=np.float64).reshape((n,1)) |
25 | 66 | res=dist(x,x) |
26 | 67 | return res |
|
0 commit comments