@@ -207,6 +207,25 @@ function sum_of_squares(x::Array{Float64,2}, labels::Array{Int64,1}, centre::Arr
207207 return s
208208end
209209
210+ # TODO generalize centroids type
211+ function create_containers (k, d, mode:: SingleThread )
212+ new_centroids = Array {Float64, 2} (undef, d, k)
213+ centroids_cnt = Vector {Int} (undef, k)
214+
215+ return new_centroids, centroids_cnt
216+ end
217+
218+ function create_containers (k, d, mode:: MultiThread )
219+ new_centroids = Vector {Array{Float64, 2}} (undef, mode. n)
220+ centroids_cnt = Vector {Vector{Int}} (undef, mode. n)
221+
222+ for i in 1 : mode. n
223+ new_centroids[i] = Array {Float64, 2} (undef, d, k)
224+ centroids_cnt[i] = Vector {Int} (undef, k)
225+ end
226+
227+ return new_centroids, centroids_cnt
228+ end
210229
211230"""
212231 Kmeans(design_matrix, k; k_init="k-means++", max_iters=300, tol=1e-4, verbose=true)
@@ -228,15 +247,15 @@ function kmeans(design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread
228247 k_init:: String = " k-means++" , max_iters:: Int = 300 , tol = 1e-4 , verbose:: Bool = true , init = nothing ) where {T <: CalculationMode }
229248 nrow, ncol = size (design_matrix)
230249 centroids = init == nothing ? smart_init (design_matrix, k, mode, init= k_init). centroids : init
231- new_centroids = similar (centroids)
250+ new_centroids, centroids_cnt = create_containers (k, nrow, mode)
251+ # new_centroids = similar(centroids)
232252
233253 labels = Vector {Int} (undef, ncol)
234- centroids_cnt = Vector {Int} (undef, k)
254+ # centroids_cnt = Vector{Int}(undef, k)
235255
236256 J_previous = Inf64
237257 totalcost = Inf
238258
239- # nearest_neighbour = Array{Float64, 2}(undef, size(design_matrix, 1), size(centroids, 1))
240259 # Update centroids & labels with closest members until convergence
241260 for iter = 1 : max_iters
242261 J = update_centroids! (centroids, new_centroids, centroids_cnt, labels, design_matrix, mode)
@@ -294,16 +313,21 @@ function update_centroids!(centroids, new_centroids, centroids_cnt, labels,
294313 waiting_list = Vector {Task} (undef, mode. n - 1 )
295314
296315 for i in 1 : length (ranges) - 1
297- waiting_list[i] = @spawn chunk_update_centroids! (centroids, new_centroids, centroids_cnt, labels,
316+ waiting_list[i] = @spawn chunk_update_centroids! (centroids, new_centroids[i + 1 ] , centroids_cnt[i + 1 ] , labels,
298317 design_matrix, ranges[i], mode)
299318 end
300319
301- J = chunk_update_centroids! (centroids, new_centroids, centroids_cnt, labels,
320+ J = chunk_update_centroids! (centroids, new_centroids[ 1 ] , centroids_cnt[ 1 ] , labels,
302321 design_matrix, ranges[end ], mode)
303322
304323 J += sum (fetch .(waiting_list))
305324
306- centroids .= new_centroids ./ centroids_cnt'
325+ for i in 1 : length (ranges) - 1
326+ new_centroids[1 ] .+ = new_centroids[i + 1 ]
327+ centroids_cnt[1 ] .+ = centroids_cnt[i + 1 ]
328+ end
329+
330+ centroids .= new_centroids[1 ] ./ centroids_cnt[1 ]'
307331
308332 return J
309333end
0 commit comments