Skip to content

Commit 522ebb0

Browse files
author
Andrey Oskin
committed
refactoring with multiple dispatch
1 parent 0db1ef9 commit 522ebb0

File tree

3 files changed

+107
-63
lines changed

3 files changed

+107
-63
lines changed

src/ParallelKMeans.jl

Lines changed: 80 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,52 @@
11
module ParallelKMeans
22

3-
43
using StatsBase
54
import Base.Threads: @spawn, @threads
65

76
export kmeans
87

8+
abstract type CalculationMode end
9+
10+
struct SingleThread <: CalculationMode end
11+
12+
struct MultiThread <: CalculationMode
13+
n::Int
14+
end
15+
MultiThread() = MultiThread(Threads.nthreads())
16+
17+
"""
18+
pairwise!(target, x, y, mode)
19+
20+
Let X and Y respectively have m and n columns. Then the `pairwise!` function
21+
computes distances between each pair of columns in X and Y and stores the result
22+
in `target` array. Argument `mode` defines calculation mode, currently
23+
the following modes are supported:
24+
- SingleThread()
25+
- MultiThread()
26+
"""
27+
pairwise!(target, x, y) = pairwise!(target, x, y, SingleThread())
28+
29+
function pairwise!(target, x, y, mode::SingleThread)
30+
ncol = size(x, 2)
31+
32+
@inbounds for k in axes(y, 1)
33+
for i in axes(x, 1)
34+
target[i, k] = (x[i, 1] - y[k, 1])^2
35+
end
36+
37+
for j in 2:ncol
38+
for i in axes(x, 1)
39+
target[i, k] += (x[i, j] - y[k, j])^2
40+
end
41+
end
42+
end
43+
target
44+
end
45+
946
"""
10-
TODO 1: Document function
47+
divider(n, k)
48+
49+
Utility function; splits the `1:n` sequence into `k` chunks of approximately the same size.
1150
"""
1251
function divider(n, k)
1352
d = div(n, k)
@@ -16,14 +55,11 @@ function divider(n, k)
1655
end
1756

1857

19-
"""
20-
TODO 2: Document function
21-
"""
22-
function pl_pairwise!(target, x, y, nth = Threads.nthreads())
58+
function pairwise!(target, x, y, mode::MultiThread)
2359
ncol = size(x, 2)
2460
nrow = size(x, 1)
2561

26-
ranges = divider(nrow, nth)
62+
ranges = divider(nrow, mode.n)
2763
waiting_list = Task[]
2864

2965
for i in 1:length(ranges) - 1
@@ -41,7 +77,11 @@ end
4177

4278

4379
"""
44-
TODO 3: Document function
80+
inner_pairwise!(target, x, y, r)
81+
82+
Utility function for calculation of [pairwise!(target, x, y, mode)](@ref) function.
83+
UnitRange argument `r` select subarray of original design matrix `x` that is going
84+
to be processed.
4585
"""
4686
function inner_pairwise!(target, x, y, r)
4787
ncol = size(x, 2)
@@ -61,40 +101,21 @@ function inner_pairwise!(target, x, y, r)
61101
end
62102

63103

64-
"""
65-
TODO 4: Document function
66-
"""
67-
function pairwise!(target, x, y)
68-
ncol = size(x, 2)
69-
70-
@inbounds for k in axes(y, 1)
71-
for i in axes(x, 1)
72-
target[i, k] = (x[i, 1] - y[k, 1])^2
73-
end
74-
75-
for j in 2:ncol
76-
for i in axes(x, 1)
77-
target[i, k] += (x[i, j] - y[k, j])^2
78-
end
79-
end
80-
end
81-
target
82-
end
83-
84-
85104
"""
86105
smart_init(X, k; init="k-means++")
87106
88-
This function handles the random initialisation of the centroids from the
89-
design matrix (X) and desired groups (k) that a user supplies.
107+
This function handles the random initialisation of the centroids from the
108+
design matrix (X) and desired groups (k) that a user supplies.
90109
91-
`k-means++` algorithm is used by default with the normal random selection
92-
of centroids from X used if any other string is attempted.
110+
`k-means++` algorithm is used by default with the normal random selection
111+
of centroids from X used if any other string is attempted.
93112
94-
A tuple representing the centroids, number of rows, & columns respectively
95-
is returned.
113+
A tuple representing the centroids, number of rows, & columns respectively
114+
is returned.
96115
"""
97-
function smart_init(X::Array{Float64, 2}, k::Int; init::String="k-means++")
116+
function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
117+
init::String="k-means++") where {T <: CalculationMode}
118+
98119
n_row, n_col = size(X)
99120

100121
if init == "k-means++"
@@ -111,7 +132,7 @@ function smart_init(X::Array{Float64, 2}, k::Int; init::String="k-means++")
111132

112133
# flatten distances
113134
# distances = vec(pairwise(SqEuclidean(), X, first_centroid_matrix, dims = 1))
114-
pairwise!(distances, X, first_centroid_matrix)
135+
pairwise!(distances, X, first_centroid_matrix, mode)
115136

116137
for i = 2:k
117138
# choose the next centroid, the probability for each data point to be chosen
@@ -127,7 +148,7 @@ function smart_init(X::Array{Float64, 2}, k::Int; init::String="k-means++")
127148
# compute distances from the centroids to all data points
128149
current_centroid_matrix = convert(Matrix, centroids[i, :]')
129150
# new_distances = vec(pairwise(SqEuclidean(), X, current_centroid_matrix, dims = 1))
130-
pairwise!(new_distances, X, first_centroid_matrix)
151+
pairwise!(new_distances, X, first_centroid_matrix, mode)
131152

132153
# and update the squared distance as the minimum distance to all centroid
133154
# distances = minimum([distances, new_distances])
@@ -140,7 +161,6 @@ function smart_init(X::Array{Float64, 2}, k::Int; init::String="k-means++")
140161
# randomly select points from the design matrix as the initial centroids
141162
rand_indices = rand(1:n_row, k)
142163
centroids = X[rand_indices, :]
143-
144164
end
145165

146166
return centroids, n_row, n_col
@@ -150,10 +170,10 @@ end
150170
"""
151171
sum_of_squares(x, labels, centre, k)
152172
153-
This function computes the total sum of squares based on the assigned (labels)
154-
design matrix (x), centroids (centre), and the number of desired groups (k).
173+
This function computes the total sum of squares based on the assigned (labels)
174+
design matrix (x), centroids (centre), and the number of desired groups (k).
155175
156-
A Float type representing the computed metric is returned.
176+
A Float type representing the computed metric is returned.
157177
"""
158178
function sum_of_squares(x::Array{Float64,2}, labels::Array{Int64,1}, centre::Array)
159179
s = 0.0
@@ -171,24 +191,24 @@ end
171191
"""
172192
Kmeans(design_matrix, k; k_init="k-means++", max_iters=300, tol=1e-4, verbose=true)
173193
174-
This main function employs the K-means algorithm to cluster all examples
175-
in the training data (design_matrix) into k groups using either the
176-
`k-means++` or random initialisation technique for selecting the initial
177-
centroids.
178-
179-
At the end of the number of iterations specified (max_iters), convergence is
180-
achieved if difference between the current and last cost objective is
181-
less than the tolerance level (tol). An error is thrown if convergence fails.
194+
This main function employs the K-means algorithm to cluster all examples
195+
in the training data (design_matrix) into k groups using either the
196+
`k-means++` or random initialisation technique for selecting the initial
197+
centroids.
182198
183-
Details of operations can be either printed or not by setting verbose accordingly.
199+
At the end of the number of iterations specified (max_iters), convergence is
200+
achieved if difference between the current and last cost objective is
201+
less than the tolerance level (tol). An error is thrown if convergence fails.
184202
185-
A tuple representing labels, centroids, and sum_squares respectively is returned.
203+
Details of operations can be either printed or not by setting verbose accordingly.
186204
205+
A tuple representing labels, centroids, and sum_squares respectively is returned.
187206
"""
188-
function kmeans(design_matrix::Array{Float64, 2}, k::Int; k_init::String = "k-means++",
189-
max_iters::Int = 300, tol = 1e-4, verbose::Bool = true)
207+
function kmeans(design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread();
208+
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4,
209+
verbose::Bool = true) where {T <: CalculationMode}
190210

191-
centroids, n_row, n_col = smart_init(design_matrix, k, init=k_init)
211+
centroids, n_row, n_col = smart_init(design_matrix, k, mode, init=k_init)
192212

193213
labels = Vector{Int}(undef, n_row)
194214
distances = Vector{Float64}(undef, n_row)
@@ -199,7 +219,7 @@ function kmeans(design_matrix::Array{Float64, 2}, k::Int; k_init::String = "k-me
199219
nearest_neighbour = Array{Float64, 2}(undef, size(design_matrix, 1), size(centroids, 1))
200220
# Update centroids & labels with closest members until convergence
201221
for iter = 1:max_iters
202-
pairwise!(nearest_neighbour, design_matrix, centroids)
222+
pairwise!(nearest_neighbour, design_matrix, centroids, mode)
203223

204224
@inbounds for i in axes(nearest_neighbour, 1)
205225
labels[i] = 1
@@ -230,11 +250,11 @@ function kmeans(design_matrix::Array{Float64, 2}, k::Int; k_init::String = "k-me
230250

231251
if verbose
232252
# Show progress and terminate if J stopped decreasing.
233-
println("Iteration ", iter, ": Jclust = ", J, ".")
253+
println("Iteration $iter: Jclust = $J.")
234254
end
235255

236256
# Final Step: Check for convergence
237-
if iter > 1 && abs(J - J_previous) < (tol * J)
257+
if (iter > 1) & (abs(J - J_previous) < (tol * J))
238258

239259
sum_squares = sum_of_squares(design_matrix, labels, centroids)
240260

@@ -243,16 +263,15 @@ function kmeans(design_matrix::Array{Float64, 2}, k::Int; k_init::String = "k-me
243263
println("Successfully terminated with convergence.")
244264
end
245265

246-
return labels, centroids, sum_squares
266+
return (labels=labels, centroids=centroids, sum_squares=sum_squares)
247267

248-
elseif iter == max_iters && abs(J - J_previous) > (tol * J)
268+
elseif (iter == max_iters) & (abs(J - J_previous) > (tol * J))
249269
throw(error("Failed to converge. Check data and/or implementation or increase max_iter."))
250270

251271
end
252272

253273
J_previous = J
254274
end
255-
256275
end
257276

258277
end # module

test/test01_distance.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
module TestDistance
2-
using ParallelKMeans: pairwise!, pl_pairwise!
2+
using ParallelKMeans: pairwise!, SingleThread, MultiThread
33
using Test
44

55
@testset "naive singlethread pairwise" begin
@@ -11,4 +11,14 @@ using Test
1111
@test all(r .≈ [0.0, 13.0, 25.0])
1212
end
1313

14+
@testset "multithread pairwise" begin
15+
X = [1.0 2.0; 3.0 5.0; 4.0 6.0]
16+
y = [1.0 2.0; ]
17+
r = Array{Float64, 2}(undef, 3, 1)
18+
19+
pairwise!(r, X, y, MultiThread())
20+
@test all(r .≈ [0.0, 13.0, 25.0])
21+
end
22+
23+
1424
end # module

test/test02_kmeans.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
module TestKMeans
2+
23
using ParallelKMeans
4+
using ParallelKMeans: MultiThread
35
using Test
46
using Random
57

6-
@testset "linear separation" begin
8+
@testset "singlethread linear separation" begin
79
Random.seed!(2020)
810

911
X = rand(100, 3)
@@ -15,4 +17,17 @@ using Random
1517
@test sum_squares ≈ 15.314823028363763
1618
end
1719

20+
21+
@testset "multithread linear separation" begin
22+
Random.seed!(2020)
23+
24+
X = rand(100, 3)
25+
labels, centroids, sum_squares = kmeans(X, 3, MultiThread(); tol = 1e-10, verbose = false)
26+
27+
# for future reference: Clustering shows here 14.964882850452984
28+
# guess they use better initialisation. For now we will use own
29+
# value
30+
@test sum_squares ≈ 15.314823028363763
31+
end
32+
1833
end # module

0 commit comments

Comments
 (0)