# Default constructor: size the multithreaded mode by the number of threads
# available to the running Julia session.
function MultiThread()
    return MultiThread(Threads.nthreads())
end
2121
22+ # TODO: here we mimic the `Clustering` data structures; we should think about how to
23+ # integrate these two packages more closely.
24+
"""
    ClusteringResult

Base type for the output of a clustering algorithm.
"""
abstract type ClusteringResult end
30+
# C is the type of centers, an (abstract) matrix of size (d x k)
# D is the type of pairwise distance computation from points to cluster centers
# WC is the type of cluster weights, either Int (in the case where points are
# unweighted) or eltype(weights) (in the case where points are weighted).
"""
    KmeansResult{C,D<:Real,WC<:Real} <: ClusteringResult

The output of [`kmeans`](@ref) and [`kmeans!`](@ref).

# Type parameters
 * `C<:AbstractMatrix{<:AbstractFloat}`: type of the `centers` matrix
 * `D<:Real`: type of the assignment cost
 * `WC<:Real`: type of the cluster weight

# Notes
The `kmeans` function in this file currently passes empty vectors for `costs`,
`counts` and `wcounts`; only `centers`, `assignments`, `totalcost`, `iterations`
and `converged` are filled in.
"""
struct KmeansResult{C<:AbstractMatrix{<:AbstractFloat},D<:Real,WC<:Real} <: ClusteringResult
    centers::C                 # cluster centers (d x k)
    assignments::Vector{Int}   # assignments (n)
    costs::Vector{D}           # cost of the assignments (n)
    counts::Vector{Int}        # number of points assigned to each cluster (k)
    wcounts::Vector{WC}        # cluster weights (k)
    totalcost::D               # total cost (i.e. objective)
    iterations::Int            # number of elapsed iterations
    converged::Bool            # whether the procedure converged
end
53+
"""
    colwise!(target, x, y, mode)

Let `x` be a matrix of size `m x n` and `y` a vector of length `m`. Then the
`colwise!` function computes the squared Euclidean distance between each column
of `x` and `y` and stores the result in the `target` array. Argument `mode`
defines the calculation mode; currently the following modes are supported
- SingleThread()
- MultiThread()
"""
function colwise!(target, x, y)
    # Without an explicit mode, fall back to the single-threaded implementation.
    return colwise!(target, x, y, SingleThread())
end
3365
# Single-threaded kernel: for every column of `x`, accumulate the squared
# difference to `y` over the rows and store the sum in `target` at that column's index.
function colwise!(target, x, y, mode::SingleThread)
    rows = axes(x, 1)
    @inbounds for col in axes(x, 2)
        acc = 0.0
        for row in rows
            acc += (x[row, col] - y[row])^2
        end
        target[col] = acc
    end
    return nothing
end
5075
"""
    splitter(n, k)

Utility function, splits the `1:n` sequence into `k` chunks of approximately the
same size. Returns a vector of `k` consecutive ranges.
"""
function splitter(n, k)
    # k + 1 evenly spaced boundaries between 0 and n, rounded up to integers;
    # chunk i then covers (bounds[i] + 1):bounds[i + 1].
    bounds = [ceil(Int, b) for b in range(0, n, length = k + 1)]
    return map(i -> (bounds[i] + 1):bounds[i + 1], 1:k)
end
6185
6286
63- function pairwise ! (target, x, y, mode:: MultiThread )
87+ function colwise ! (target, x, y, mode:: MultiThread )
6488 ncol = size (x, 2 )
65- nrow = size (x, 1 )
6689
67- ranges = divider (nrow , mode. n)
90+ ranges = splitter (ncol , mode. n)
6891 waiting_list = Task[]
6992
7093 for i in 1 : length (ranges) - 1
71- push! (waiting_list, @spawn inner_pairwise ! (target, x, y, ranges[i]))
94+ push! (waiting_list, @spawn chunk_colwise ! (target, x, y, ranges[i]))
7295 end
7396
74- inner_pairwise ! (target, x, y, ranges[end ])
97+ chunk_colwise ! (target, x, y, ranges[end ])
7598
7699 for i in 1 : length (ranges) - 1
77100 wait (waiting_list[i])
82105
83106
"""
    chunk_colwise!(target, x, y, r)

Utility function for calculation of the `colwise!(target, x, y, mode)` function.
The range argument `r` selects the columns of the original design matrix `x`
that are going to be processed; only the corresponding entries of `target` are
written.
"""
function chunk_colwise!(target, x, y, r)
    rows = axes(x, 1)
    @inbounds for col in r
        acc = 0.0
        for row in rows
            acc += (x[row, col] - y[row])^2
        end
        target[col] = acc
    end
    return nothing
end
107123
108-
109124"""
110125 smart_init(X, k; init="k-means++")
111126
@@ -126,37 +141,33 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
126141 if init == " k-means++"
127142
128143 # randomly select the first centroid from the data (X)
129- centroids = zeros (k, n_col )
144+ centroids = zeros (n_row, k )
130145 rand_indices = Vector {Int} (undef, k)
131- rand_idx = rand (1 : n_row )
146+ rand_idx = rand (1 : n_col )
132147 rand_indices[1 ] = rand_idx
133- centroids[1 , :] .= X[rand_idx, :]
134- centroids[k, :] .= 0.0
135- distances = Array {Float64} (undef, n_row, 1 )
136- new_distances = Array {Float64} (undef, n_row, 1 )
148+ centroids[:, 1 ] .= X[:, rand_idx]
149+ distances = Vector {Float64} (undef, n_col)
150+ new_distances = Vector {Float64} (undef, n_col)
137151
138152 # TODO : Add `colwise` function (or use it from `Distances` package)
139153 # compute distances from the first centroid chosen to all the other data points
140- first_centroid_matrix = convert (Matrix, centroids[1 , :]' )
141154
142155 # flatten distances
143- pairwise ! (distances, X, first_centroid_matrix , mode)
156+ colwise ! (distances, X, centroids[:, 1 ] , mode)
144157 distances[rand_idx] = 0.0
145158
146159 for i = 2 : k
147160 # choose the next centroid, the probability for each data point to be chosen
148161 # is directly proportional to its squared distance from the nearest centroid
149- r_idx = wsample (1 : n_row , vec (distances))
162+ r_idx = wsample (1 : n_col , vec (distances))
150163 rand_indices[i] = r_idx
151- centroids[i, : ] .= X[r_idx, : ]
164+ centroids[:, i ] .= X[:, r_idx ]
152165
153166 # no need for final distance update
154167 i == k && break
155168
156169 # compute distances from the centroids to all data points
157- current_centroid_matrix = convert (Matrix, centroids[i, :]' )
158- # new_distances = vec(pairwise(SqEuclidean(), X, current_centroid_matrix, dims = 1))
159- pairwise! (new_distances, X, first_centroid_matrix, mode)
170+ colwise! (new_distances, X, centroids[:, i], mode)
160171
161172 # and update the squared distance as the minimum distance to all centroid
162173 for i in 1 : n_row
@@ -167,8 +178,9 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
167178
168179 else
169180 # randomly select points from the design matrix as the initial centroids
170- rand_indices = rand (1 : n_row, k)
171- centroids = X[rand_indices, :]
181+ # TODO change rand to sample
182+ rand_indices = rand (1 : n_col, k)
183+ centroids = X[:, rand_indices]
172184 end
173185
174186 return (centroids = centroids, indices = rand_indices)
@@ -188,7 +200,7 @@ function sum_of_squares(x::Array{Float64,2}, labels::Array{Int64,1}, centre::Arr
188200
189201 @inbounds for j in axes (x, 2 )
190202 for i in axes (x, 1 )
191- s += (x[i, j] - centre[labels[i], j ])^ 2
203+ s += (x[i, j] - centre[i, labels[j] ])^ 2
192204 end
193205 end
194206
@@ -214,72 +226,116 @@ A tuple representing labels, centroids, and sum_squares respectively is returned
214226"""
function kmeans(design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread();
                k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4,
                verbose::Bool = true, init = nothing) where {T <: CalculationMode}
    # Without at least one iteration `labels` would be read uninitialised below.
    max_iters > 0 || throw(ArgumentError("max_iters must be positive, got $max_iters"))

    # Data points are stored column-wise: one column per observation.
    ncol = size(design_matrix, 2)

    # The centroids are overwritten in place on every iteration, so work on a copy
    # of user-supplied initial centroids instead of mutating the caller's array.
    centroids = init === nothing ?
        smart_init(design_matrix, k, mode, init = k_init).centroids : copy(init)
    new_centroids = similar(centroids)

    labels = Vector{Int}(undef, ncol)
    centroids_cnt = Vector{Int}(undef, k)

    J_previous = Inf64
    converged = false
    niters = 0

    # Update centroids & labels with closest members until convergence
    for iter = 1:max_iters
        niters = iter
        J = update_centroids!(centroids, new_centroids, centroids_cnt, labels,
                              design_matrix, mode)
        # Mean assignment cost per data point
        J /= ncol

        if verbose
            # Show progress
            println("Iteration $iter: Jclust = $J")
        end

        # Converged once the relative change of the objective drops below `tol`
        if iter > 1 && abs(J - J_previous) < tol * J
            converged = true
            break
        end

        J_previous = J
    end

    # The objective is reported for converged and non-converged runs alike, and a
    # result is always returned (no silent `nothing` when neither branch fired).
    totalcost = sum_of_squares(design_matrix, labels, centroids)

    if verbose && converged
        println("Successfully terminated with convergence.")
    end

    # TODO empty vectors should be calculated
    # TODO Float64 type definitions is too restrictive, should be relaxed
    # especially during GPU related development
    return KmeansResult(centroids, labels, Float64[], Int[], Float64[], totalcost,
                        niters, converged)
end
284272
# Single-threaded centroid update: process every column of `design_matrix` as one
# chunk, then overwrite `centroids` with the per-cluster sums divided by the
# per-cluster counts. Returns the cost accumulated by `chunk_update_centroids!`.
function update_centroids!(centroids, new_centroids, centroids_cnt, labels,
                           design_matrix, mode::SingleThread)
    all_columns = axes(design_matrix, 2)
    cost = chunk_update_centroids!(centroids, new_centroids, centroids_cnt, labels,
                                   design_matrix, all_columns, mode)

    # Row view of the counts so that column `c` of the sums is divided by count `c`.
    counts_row = centroids_cnt'
    centroids .= new_centroids ./ counts_row

    return cost
end
284+
# Multi-threaded centroid update: the columns of `design_matrix` are split into
# `mode.n` chunks, each chunk is processed by `chunk_update_centroids!` (all but
# the last in spawned tasks), and the partial results are combined afterwards.
# Returns the summed cost of all chunks.
function update_centroids!(centroids, new_centroids, centroids_cnt, labels,
                           design_matrix, mode::MultiThread)
    mode.n == 1 && return update_centroids!(centroids, new_centroids, centroids_cnt,
                                            labels, design_matrix, SingleThread())

    ncol = size(design_matrix, 2)
    ranges = splitter(ncol, mode.n)
    nchunks = length(ranges)

    # `chunk_update_centroids!` zeroes and then accumulates into the buffers it is
    # given. Handing the same `new_centroids`/`centroids_cnt` to concurrently
    # running tasks is a data race, so every chunk gets a private pair of buffers.
    partial_sums = [similar(new_centroids) for _ in 1:nchunks]
    partial_cnts = [similar(centroids_cnt) for _ in 1:nchunks]

    waiting_list = Vector{Task}(undef, nchunks - 1)

    for i in 1:nchunks - 1
        waiting_list[i] = @spawn chunk_update_centroids!(centroids, partial_sums[i],
                                                         partial_cnts[i], labels,
                                                         design_matrix, ranges[i], mode)
    end

    # The last chunk is processed on the current task. `labels` is shared, but the
    # chunks write disjoint index ranges; `centroids` is only read at this stage.
    J = chunk_update_centroids!(centroids, partial_sums[end], partial_cnts[end], labels,
                                design_matrix, ranges[end], mode)

    J += sum(fetch.(waiting_list))

    # Serial reduction of the per-chunk accumulators into the shared buffers.
    fill!(new_centroids, zero(eltype(new_centroids)))
    fill!(centroids_cnt, zero(eltype(centroids_cnt)))
    for c in 1:nchunks
        new_centroids .+= partial_sums[c]
        centroids_cnt .+= partial_cnts[c]
    end

    centroids .= new_centroids ./ centroids_cnt'

    return J
end
310+
311+
# Assignment step for the columns `r` of `design_matrix`.
#
# The accumulators `new_centroids` and `centroids_cnt` are reset on entry. Then,
# for every column in `r`, the centroid column with the smallest squared distance
# is found (ties keep the lowest index), its index is written to `labels`, its
# counter is incremented and the data column is added to its running sum.
# `centroids` is only read. Returns the sum of the minimal squared distances.
function chunk_update_centroids!(centroids, new_centroids, centroids_cnt, labels,
                                 design_matrix, r,
                                 mode::T = SingleThread()) where {T <: CalculationMode}
    fill!(new_centroids, 0.0)
    fill!(centroids_cnt, 0)

    features = axes(design_matrix, 1)
    total = 0.0
    @inbounds for col in r
        best = Inf
        nearest = 1
        for c in axes(centroids, 2)
            d = 0.0
            for f in features
                d += (design_matrix[f, col] - centroids[f, c])^2
            end
            # Strict comparison: an equally distant later centroid does not win.
            if best > d
                nearest = c
                best = d
            end
        end

        labels[col] = nearest
        centroids_cnt[nearest] += 1
        for f in features
            new_centroids[f, nearest] += design_matrix[f, col]
        end
        total += best
    end

    return total
end
340+
285341end # module
0 commit comments