Skip to content

Commit 9523b1e

Browse files
committed
doc datasets.py
1 parent f33087d commit 9523b1e

File tree

3 files changed

+83
-25
lines changed

3 files changed

+83
-25
lines changed

ot/bregman.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -276,14 +276,6 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal
276276
277277
The optimization problem is solved using the algorithm described in [4]
278278
279-
280-
distrib : distribution to unmix
281-
D : Dictionnary
282-
M : Metric matrix in the space of the distributions to unmix
283-
M0 : Metric matrix in the space of the 'abundance values' to solve for
284-
h0 : prior on solution (generally uniform distribution)
285-
reg,reg0 : transport regularizations
286-
alpha : how much should we trust the prior ? ([0,1])
287279
288280
Parameters
289281
----------
@@ -300,7 +292,9 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal
300292
reg: float
301293
Regularization term >0 (Wasserstein data fitting)
302294
reg0: float
303-
Regularization term >0 (Wasserstein reg with h0)
295+
Regularization term >0 (Wasserstein reg with h0)
296+
alpha: float
297+
How much should we trust the prior ([0,1])
304298
numItermax: int, optional
305299
Max number of iterations
306300
stopThr: float, optional
@@ -318,7 +312,7 @@ def unmix(a,D,M,M0,h0,reg,reg0,alpha,numItermax = 1000, stopThr=1e-3,verbose=Fal
318312
log: dict
319313
log dictionary return only if log==True in parameters
320314
321-
References
315+
References
322316
----------
323317
324318
.. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, Supervised planetary unmixing with optimal transport, Workshop on Hyperspectral Image and Signal Processing : Evolution in Remote Sensing (WHISPERS), 2016.

ot/datasets.py

Lines changed: 60 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,50 @@
88

99

1010
def get_1D_gauss(n,m,s):
11-
"return a 1D histogram for a gaussian distribution (n bins, mean m and std s) "
11+
"""return a 1D histogram for a gaussian distribution (n bins, mean m and std s)
12+
13+
Parameters
14+
----------
15+
16+
n : int
17+
number of bins in the histogram
18+
m : float
19+
mean value of the gaussian distribution
20+
s : float
21+
standard deviation of the gaussian distribution
22+
23+
24+
Returns
25+
-------
26+
h : np.array (n,)
27+
1D histogram for a gaussian distribution
28+
29+
"""
1230
x=np.arange(n,dtype=np.float64)
1331
h=np.exp(-(x-m)**2/(2*s^2))
1432
return h/h.sum()
1533

1634

1735
def get_2D_samples_gauss(n,m,sigma):
18-
"return samples from 2D gaussian (n samples, mean m and cov sigma) "
36+
"""return n samples drawn from 2D gaussian N(m,sigma)
37+
38+
Parameters
39+
----------
40+
41+
n : int
42+
number of samples to draw
43+
m : np.array (2,)
44+
mean value of the gaussian distribution
45+
sigma : np.array (2,2)
46+
covariance matrix of the gaussian distribution
47+
48+
49+
Returns
50+
-------
51+
X : np.array (n,2)
52+
n samples drawn from N(m,sigma)
53+
54+
"""
1955
if np.isscalar(sigma):
2056
sigma=np.array([sigma,])
2157
if len(sigma)>1:
@@ -26,8 +62,26 @@ def get_2D_samples_gauss(n,m,sigma):
2662
return res
2763

2864
def get_data_classif(dataset,n,nz=.5,**kwargs):
29-
"""
30-
dataset generation
65+
""" dataset generation for classification problems
66+
67+
Parameters
68+
----------
69+
70+
dataset : str
71+
type of classification problem (see code)
72+
n : int
73+
number of training samples
74+
nz : float
75+
noise level (>0)
76+
77+
78+
Returns
79+
-------
80+
X : np.array (n,d)
81+
n observation of size d
82+
y : np.array (n,)
83+
labels of the samples
84+
3185
"""
3286
if dataset.lower()=='3gauss':
3387
y=np.floor((np.arange(n)*1.0/n*3))+1
@@ -50,15 +104,10 @@ def get_data_classif(dataset,n,nz=.5,**kwargs):
50104
x[y==3,0]=2. ; x[y==3,1]=0
51105

52106
x[y!=3,:]+=nz*np.random.randn(sum(y!=3),2)
53-
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
54-
# elif dataset.lower()=='sinreg':
55-
#
56-
# x=np.random.rand(n,1)
57-
# y=4*x+np.sin(2*np.pi*x)+nz*np.random.randn(n,1)
58-
107+
x[y==3,:]+=2*nz*np.random.randn(sum(y==3),2)
59108
else:
60109
x=0
61110
y=0
62111
print("unknown dataset")
63112

64-
return x,y
113+
return x,y.astype(int)

ot/utils.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# -*- coding: utf-8 -*-
12
"""
23
Various functions that can be useful
34
"""
@@ -6,7 +7,21 @@
67

78

89
def unif(n):
9-
""" return a uniform histogram of length n (simplex) """
10+
""" return a uniform histogram of length n (simplex)
11+
12+
Parameters
13+
----------
14+
15+
n : int
16+
number of bins in the histogram
17+
18+
Returns
19+
-------
20+
h : np.array (n,)
21+
histogram of length n such that h_i=1/n for all i
22+
23+
24+
"""
1025
return np.ones((n,))/n
1126

1227

@@ -22,9 +37,9 @@ def dist(x1,x2=None,metric='sqeuclidean'):
2237
matrix with n2 samples of size d (if None then x2=x1)
2338
metric : str, fun, optional
2439
name of the metric to be computed (full list in the doc of scipy), If a string,
25-
the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’,
40+
the distance function can be ‘braycurtis’, ‘canberra’, ‘chebyshev’, ‘cityblock’,
2641
‘correlation’, ‘cosine’, ‘dice’, ‘euclidean’, ‘hamming’, ‘jaccard’, ‘kulsinski’,
27-
‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’,
42+
‘mahalanobis’, ‘matching’, ‘minkowski’, ‘rogerstanimoto’, ‘russellrao’, ‘seuclidean’,
2843
‘sokalmichener’, ‘sokalsneath’, ‘sqeuclidean’, ‘wminkowski’, ‘yule’.
2944
3045
@@ -68,5 +83,5 @@ def dist0(n,method='lin_square'):
6883

6984

7085
def dots(*args):
71-
""" Stupid but nice dots function for multiple matrix multiply """
86+
""" dots function for multiple matrix multiply """
7287
return reduce(np.dot,args)

0 commit comments

Comments
 (0)