@@ -19,20 +19,24 @@ function kmeans!(alg::Hamerly, containers, design_matrix, k;
1919 nrow, ncol = size (design_matrix)
2020 centroids = init == nothing ? smart_init (design_matrix, k, n_threads, init= k_init). centroids : deepcopy (init)
2121
22- initialize ! (alg, containers, centroids, design_matrix, n_threads )
22+ @parallelize n_threads ncol chunk_initialize ! (alg, containers, centroids, design_matrix)
2323
2424 converged = false
2525 niters = 1
2626 J_previous = 0.0
27+ p = containers. p
2728
2829 # Update centroids & labels with closest members until convergence
29-
3030 while niters <= max_iters
3131 update_containers! (containers, alg, centroids, n_threads)
32- update_centroids! (centroids, containers, alg, design_matrix, n_threads)
32+ @parallelize n_threads ncol chunk_update_centroids! (centroids, containers, alg, design_matrix)
33+ collect_containers (alg, containers, n_threads)
34+
3335 J = sum (containers. ub)
3436 move_centers! (centroids, containers, alg)
35- update_bounds! (containers, n_threads)
37+
38+ r1, r2, pr1, pr2 = double_argmax (p)
39+ @parallelize n_threads ncol chunk_update_bounds! (containers, r1, r2, pr1, pr2)
3640
3741 if verbose
3842 # Show progress and terminate if J stopped decreasing.
@@ -49,7 +53,8 @@ function kmeans!(alg::Hamerly, containers, design_matrix, k;
4953 niters += 1
5054 end
5155
52- totalcost = sum_of_squares (design_matrix, containers. labels, centroids)
56+ @parallelize n_threads ncol sum_of_squares (containers, design_matrix, containers. labels, centroids)
57+ totalcost = sum (containers. sum_of_squares)
5358
5459 # Terminate algorithm with the assumption that K-means has converged
5560 if verbose & converged
@@ -101,6 +106,9 @@ function create_containers(alg::Hamerly, k, nrow, ncol, n_threads)
101106 # distance from the center to the closest other center
102107 s = Vector {Float64} (undef, k)
103108
109+ # per-chunk partial sums of squares, used to compute the total cost
110+ sum_of_squares = Vector {Float64} (undef, n_threads)
111+
104112 return (
105113 centroids_new = centroids_new,
106114 centroids_cnt = centroids_cnt,
@@ -109,31 +117,15 @@ function create_containers(alg::Hamerly, k, nrow, ncol, n_threads)
109117 lb = lb,
110118 p = p,
111119 s = s,
120+ sum_of_squares = sum_of_squares
112121 )
113122end
114123
115- function initialize! (alg:: Hamerly , containers, centroids, design_matrix, n_threads)
116- ncol = size (design_matrix, 2 )
117-
118- if n_threads == 1
119- r = axes (design_matrix, 2 )
120- chunk_initialize! (alg, containers, centroids, design_matrix, r, 1 )
121- else
122- ranges = splitter (ncol, n_threads)
123-
124- waiting_list = Vector {Task} (undef, n_threads - 1 )
125-
126- for i in 1 : n_threads - 1
127- waiting_list[i] = @spawn chunk_initialize! (alg, containers, centroids,
128- design_matrix, ranges[i], i + 1 )
129- end
130-
131- chunk_initialize! (alg, containers, centroids, design_matrix, ranges[end ], 1 )
132-
133- wait .(waiting_list)
134- end
135- end
124+ """
125+ chunk_initialize!(alg::Hamerly, containers, centroids, design_matrix, r, idx)
136126
127+ Initial calculation of all bounds and point labels.
128+ """
137129function chunk_initialize! (alg:: Hamerly , containers, centroids, design_matrix, r, idx)
138130 centroids_cnt = containers. centroids_cnt[idx]
139131 centroids_new = containers. centroids_new[idx]
@@ -147,6 +139,11 @@ function chunk_initialize!(alg::Hamerly, containers, centroids, design_matrix, r
147139 end
148140end
149141
142+ """
143+ update_containers!(containers, ::Hamerly, centroids, n_threads)
144+
145+ Calculates minimum distances from centers to each other.
146+ """
150147function update_containers! (containers, :: Hamerly , centroids, n_threads)
151148 s = containers. s
152149 s .= Inf
@@ -160,39 +157,14 @@ function update_containers!(containers, ::Hamerly, centroids, n_threads)
160157 end
161158end
162159
163- function update_centroids! (centroids, containers, alg:: Hamerly , design_matrix, n_threads)
164-
165- if n_threads == 1
166- r = axes (design_matrix, 2 )
167- chunk_update_centroids! (centroids, containers, alg, design_matrix, r, 1 )
168- else
169- ncol = size (design_matrix, 2 )
170- ranges = splitter (ncol, n_threads)
171-
172- waiting_list = Vector {Task} (undef, n_threads - 1 )
173-
174- for i in 1 : length (ranges) - 1
175- waiting_list[i] = @spawn chunk_update_centroids! (centroids, containers,
176- alg, design_matrix, ranges[i], i)
177- end
178-
179- chunk_update_centroids! (centroids, containers, alg, design_matrix, ranges[end ], n_threads)
180-
181- wait .(waiting_list)
182-
183- end
184-
185- collect_containers (alg, containers, n_threads)
186- end
160+ """
161+ chunk_update_centroids!(centroids, containers, alg::Hamerly, design_matrix, r, idx)
187162
188- function chunk_update_centroids! (
189- centroids,
190- containers,
191- alg:: Hamerly ,
192- design_matrix,
193- r,
194- idx,
195- )
163+ Detailed description of this function can be found in the original paper. It iterates through
164+ all points and tries to skip some calculation using known upper and lower bounds of distances
165+ from point to centers. If it fails to skip, then it falls back to the generic `point_all_centers!` function.
166+ """
167+ function chunk_update_centroids! (centroids, containers, alg:: Hamerly , design_matrix, r, idx)
196168
197169 # unpack containers for easier manipulations
198170 centroids_new = containers. centroids_new[idx]
@@ -227,6 +199,11 @@ function chunk_update_centroids!(
227199 end
228200end
229201
202+ """
203+ point_all_centers!(containers, centroids, design_matrix, i)
204+
205+ Calculates the new label and the upper and lower bounds for point `i`, checking it against all centers.
206+ """
230207function point_all_centers! (containers, centroids, design_matrix, i)
231208 ub = containers. ub
232209 lb = containers. lb
@@ -253,6 +230,12 @@ function point_all_centers!(containers, centroids, design_matrix, i)
253230 return label
254231end
255232
233+ """
234+ move_centers!(centroids, containers, ::Hamerly)
235+
236+ Calculates the new positions of the centers and the distances they have moved. Results are stored
237+ in `centroids` and `p` respectively.
238+ """
256239function move_centers! (centroids, containers, :: Hamerly )
257240 centroids_new = containers. centroids_new[end ]
258241 p = containers. p
@@ -267,35 +250,28 @@ function move_centers!(centroids, containers, ::Hamerly)
267250 end
268251end
269252
270- function update_bounds! (containers, n_threads)
271- p = containers . p
253+ """
254+ chunk_update_bounds!(containers, r1, r2, pr1, pr2, r, idx)
272255
273- r1, r2 = double_argmax (p)
274- pr1 = p[r1]
275- pr2 = p[r2]
256+ Updates the upper and lower bounds of the point distances to the centers, taking the movement of the centers into account.
257+ Since the bounds are squared distances, `sqrt` is used to make the corresponding estimation, unlike
258+ the original paper, where the usual metric is used.
276259
277- if n_threads == 1
278- r = axes (containers. ub, 1 )
279- chunk_update_bounds! (containers, r, r1, r2, pr1, pr2)
280- else
281- ncol = length (containers. ub)
282- ranges = splitter (ncol, n_threads)
260+ Using notation from original paper, `u` is upper bound and `a` is `labels`, so
283261
284- waiting_list = Vector {Task} (undef, n_threads - 1 )
262+ `u[i] -> u[i] + p[a[i]]`
285263
286- for i in 1 : n_threads - 1
287- waiting_list[i] = @spawn chunk_update_bounds! (containers, ranges[i], r1, r2, pr1, pr2)
288- end
264+ then squared distance is
289265
290- chunk_update_bounds! (containers, ranges[ end ], r1, r2, pr1, pr2)
266+ `u[i]^2 -> (u[i] + p[a[i]])^2 = u[i]^2 + 2 p[a[i]] u[i] + p[a[i]]^2`
291267
292- for i in 1 : n_threads - 1
293- wait (waiting_list[i])
294- end
295- end
296- end
268+ Taking into account that in our notation `p^2 -> p`, `u^2 -> ub`, we obtain
269+
270+ `ub[i] -> ub[i] + 2 sqrt(p[a[i]] ub[i]) + p[a[i]]`
297271
298- function chunk_update_bounds! (containers, r, r1, r2, pr1, pr2)
272+ The same applies to the lower bounds.
273+ """
274+ function chunk_update_bounds! (containers, r1, r2, pr1, pr2, r, idx)
299275 p = containers. p
300276 ub = containers. ub
301277 lb = containers. lb
@@ -312,6 +288,11 @@ function chunk_update_bounds!(containers, r, r1, r2, pr1, pr2)
312288 end
313289end
314290
291+ """
292+ double_argmax(p)
293+
294+ Finds the indices of the largest and second-largest elements of `p` and returns them together with the corresponding values.
295+ """
315296function double_argmax (p)
316297 r1, r2 = 1 , 1
317298 d1 = p[1 ]
@@ -328,19 +309,5 @@ function double_argmax(p)
328309 end
329310 end
330311
331- r1, r2
332- end
333-
334- """
335- distance(X1, X2, i1, i2)
336-
337- Allocation-free calculation of the squared Euclidean distance between vectors X1[:, i1] and X2[:, i2]
338- """
339- function distance (X1, X2, i1, i2)
340- d = 0.0
341- @inbounds for i in axes (X1, 1 )
342- d += (X1[i, i1] - X2[i, i2])^ 2
343- end
344-
345- return d
312+ r1, r2, d1, d2
346313end
0 commit comments