Skip to content

Commit 971b456

Browse files
author
Andrey Oskin
committed
add init
1 parent 0b2cb43 commit 971b456

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/ParallelKMeans.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
127127

128128
# randonmly select the first centroid from the data (X)
129129
centroids = zeros(k, n_col)
130+
rand_indices = Vector{Int}(undef, k)
130131
rand_idx = rand(1:n_row)
132+
rand_indices[1] = rand_idx
131133
centroids[1, :] .= X[rand_idx, :]
132134
distances = Array{Float64}(undef, n_row, 1)
133135
new_distances = Array{Float64}(undef, n_row, 1)
@@ -143,6 +145,7 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
143145
# choose the next centroid, the probability for each data point to be chosen
144146
# is directly proportional to its squared distance from the nearest centroid
145147
r_idx = sample(1:n_row, ProbabilityWeights(vec(distances)))
148+
rand_indices[i] = r_idx
146149
centroids[i, :] .= X[r_idx, :]
147150

148151
# Ignore setting the last centroid to help the separation of centroids
@@ -168,7 +171,7 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
168171
centroids = X[rand_indices, :]
169172
end
170173

171-
return centroids, n_row, n_col
174+
return (centroids = centroids, indices = rand_indices)
172175
end
173176

174177

@@ -210,10 +213,10 @@ Details of operations can be either printed or not by setting verbose accordingl
210213
A tuple representing labels, centroids, and sum_squares respectively is returned.
211214
"""
212215
function kmeans(design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread();
213-
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4,
214-
verbose::Bool = true) where {T <: CalculationMode}
216+
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4, verbose::Bool = true, init = nothing) where {T <: CalculationMode}
215217

216-
centroids, n_row, n_col = smart_init(design_matrix, k, mode, init=k_init)
218+
n_row, n_col = size(design_matrix)
219+
centroids = init == nothing ? smart_init(design_matrix, k, mode, init=k_init).centroids : init
217220

218221
labels = Vector{Int}(undef, n_row)
219222
distances = Vector{Float64}(undef, n_row)

0 commit comments

Comments
 (0)