@@ -127,7 +127,9 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
127127
128128 # randonmly select the first centroid from the data (X)
129129 centroids = zeros (k, n_col)
130+ rand_indices = Vector {Int} (undef, k)
130131 rand_idx = rand (1 : n_row)
132+ rand_indices[1 ] = rand_idx
131133 centroids[1 , :] .= X[rand_idx, :]
132134 distances = Array {Float64} (undef, n_row, 1 )
133135 new_distances = Array {Float64} (undef, n_row, 1 )
@@ -143,6 +145,7 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
143145 # choose the next centroid, the probability for each data point to be chosen
144146 # is directly proportional to its squared distance from the nearest centroid
145147 r_idx = sample (1 : n_row, ProbabilityWeights (vec (distances)))
148+ rand_indices[i] = r_idx
146149 centroids[i, :] .= X[r_idx, :]
147150
148151 # Ignore setting the last centroid to help the separation of centroids
@@ -168,7 +171,7 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
168171 centroids = X[rand_indices, :]
169172 end
170173
171- return centroids, n_row, n_col
174+ return ( centroids = centroids, indices = rand_indices)
172175end
173176
174177
@@ -210,10 +213,10 @@ Details of operations can be either printed or not by setting verbose accordingl
210213A tuple representing labels, centroids, and sum_squares respectively is returned.
211214"""
212215function kmeans (design_matrix:: Array{Float64, 2} , k:: Int , mode:: T = SingleThread ();
213- k_init:: String = " k-means++" , max_iters:: Int = 300 , tol = 1e-4 ,
214- verbose:: Bool = true ) where {T <: CalculationMode }
216+ k_init:: String = " k-means++" , max_iters:: Int = 300 , tol = 1e-4 , verbose:: Bool = true , init = nothing ) where {T <: CalculationMode }
215217
216- centroids, n_row, n_col = smart_init (design_matrix, k, mode, init= k_init)
218+ n_row, n_col = size (design_matrix)
219+ centroids = init == nothing ? smart_init (design_matrix, k, mode, init= k_init). centroids : init
217220
218221 labels = Vector {Int} (undef, n_row)
219222 distances = Vector {Float64} (undef, n_row)
0 commit comments