Skip to content

Commit 0b2cb43

Browse files
authored
Merge pull request #6 from PyDataBlog/experimental
Experimental
2 parents d70db4b + 5239c9e commit 0b2cb43

File tree

3 files changed

+117
-62
lines changed

3 files changed

+117
-62
lines changed

src/ParallelKMeans.jl

Lines changed: 90 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,57 @@
11
module ParallelKMeans
22

3-
43
using StatsBase
54
import Base.Threads: @spawn, @threads
65

76
export kmeans
87

8+
abstract type CalculationMode end
9+
10+
# Single thread class to control the calculation type based on the CalculationMode
11+
struct SingleThread <: CalculationMode
12+
end
13+
14+
# Multi threaded implementation to control the calculation type based on available threads
15+
struct MultiThread <: CalculationMode
16+
n::Int
17+
end
18+
19+
# Get the number of available threads for multithreading implementation
20+
MultiThread() = MultiThread(Threads.nthreads())
21+
922
"""
10-
TODO 1: Document function
23+
pairwise!(target, x, y, mode)
24+
25+
Let X and Y respectively have m and n columns. Then the `pairwise!` function
26+
computes distances between each pair of columns in X and Y and store result
27+
in `target` array. Argument `mode` defines calculation mode, currently
28+
following modes supported
29+
- SingleThread()
30+
- MultiThread()
31+
"""
32+
pairwise!(target, x, y) = pairwise!(target, x, y, SingleThread())
33+
34+
function pairwise!(target, x, y, mode::SingleThread)
35+
ncol = size(x, 2)
36+
37+
@inbounds for k in axes(y, 1)
38+
for i in axes(x, 1)
39+
target[i, k] = (x[i, 1] - y[k, 1])^2
40+
end
41+
42+
for j in 2:ncol
43+
for i in axes(x, 1)
44+
target[i, k] += (x[i, j] - y[k, j])^2
45+
end
46+
end
47+
end
48+
target
49+
end
50+
51+
"""
52+
divider(n, k)
53+
54+
Utility function, splits 1:n sequence to k chunks of approximately same size.
1155
"""
1256
function divider(n, k)
1357
d = div(n, k)
@@ -16,18 +60,19 @@ function divider(n, k)
1660
end
1761

1862

19-
"""
20-
TODO 2: Document function
21-
"""
22-
function pl_pairwise!(target, x, y, nth = Threads.nthreads())
63+
function pairwise!(target, x, y, mode::MultiThread)
2364
ncol = size(x, 2)
2465
nrow = size(x, 1)
25-
ranges = divider(nrow, nth)
66+
67+
ranges = divider(nrow, mode.n)
2668
waiting_list = Task[]
69+
2770
for i in 1:length(ranges) - 1
2871
push!(waiting_list, @spawn inner_pairwise!(target, x, y, ranges[i]))
2972
end
73+
3074
inner_pairwise!(target, x, y, ranges[end])
75+
3176
for i in 1:length(ranges) - 1
3277
wait(waiting_list[i])
3378
end
@@ -37,10 +82,15 @@ end
3782

3883

3984
"""
40-
TODO 3: Document function
85+
inner_pairwise!(target, x, y, r)
86+
87+
Utility function for calculation of [pairwise!(target, x, y, mode)](@ref) function.
88+
UnitRange argument `r` select subarray of original design matrix `x` that is going
89+
to be processed.
4190
"""
4291
function inner_pairwise!(target, x, y, r)
4392
ncol = size(x, 2)
93+
4494
@inbounds for k in axes(y, 1)
4595
for i in r
4696
target[i, k] = (x[i, 1] - y[k, 1])^2
@@ -56,39 +106,21 @@ function inner_pairwise!(target, x, y, r)
56106
end
57107

58108

59-
"""
60-
TODO 4: Document function
61-
"""
62-
function pairwise!(target, x, y)
63-
ncol = size(x, 2)
64-
@inbounds for k in axes(y, 1)
65-
for i in axes(x, 1)
66-
target[i, k] = (x[i, 1] - y[k, 1])^2
67-
end
68-
69-
for j in 2:ncol
70-
for i in axes(x, 1)
71-
target[i, k] += (x[i, j] - y[k, j])^2
72-
end
73-
end
74-
end
75-
target
76-
end
77-
78-
79109
"""
80110
smart_init(X, k; init="k-means++")
81111
82-
This function handles the random initialisation of the centroids from the
83-
design matrix (X) and desired groups (k) that a user supplies.
112+
This function handles the random initialisation of the centroids from the
113+
design matrix (X) and desired groups (k) that a user supplies.
84114
85-
`k-means++` algorithm is used by default with the normal random selection
86-
of centroids from X used if any other string is attempted.
115+
`k-means++` algorithm is used by default with the normal random selection
116+
of centroids from X used if any other string is attempted.
87117
88-
A tuple representing the centroids, number of rows, & columns respectively
89-
is returned.
118+
A tuple representing the centroids, number of rows, & columns respectively
119+
is returned.
90120
"""
91-
function smart_init(X::Array{Float64, 2}, k::Int; init::String="k-means++")
121+
function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
122+
init::String="k-means++") where {T <: CalculationMode}
123+
92124
n_row, n_col = size(X)
93125

94126
if init == "k-means++"
@@ -105,7 +137,7 @@ function smart_init(X::Array{Float64, 2}, k::Int; init::String="k-means++")
105137

106138
# flatten distances
107139
# distances = vec(pairwise(SqEuclidean(), X, first_centroid_matrix, dims = 1))
108-
pairwise!(distances, X, first_centroid_matrix)
140+
pairwise!(distances, X, first_centroid_matrix, mode)
109141

110142
for i = 2:k
111143
# choose the next centroid, the probability for each data point to be chosen
@@ -121,7 +153,7 @@ function smart_init(X::Array{Float64, 2}, k::Int; init::String="k-means++")
121153
# compute distances from the centroids to all data points
122154
current_centroid_matrix = convert(Matrix, centroids[i, :]')
123155
# new_distances = vec(pairwise(SqEuclidean(), X, current_centroid_matrix, dims = 1))
124-
pairwise!(new_distances, X, first_centroid_matrix)
156+
pairwise!(new_distances, X, first_centroid_matrix, mode)
125157

126158
# and update the squared distance as the minimum distance to all centroid
127159
# distances = minimum([distances, new_distances])
@@ -134,7 +166,6 @@ function smart_init(X::Array{Float64, 2}, k::Int; init::String="k-means++")
134166
# randomly select points from the design matrix as the initial centroids
135167
rand_indices = rand(1:n_row, k)
136168
centroids = X[rand_indices, :]
137-
138169
end
139170

140171
return centroids, n_row, n_col
@@ -144,10 +175,10 @@ end
144175
"""
145176
sum_of_squares(x, labels, centre, k)
146177
147-
This function computes the total sum of squares based on the assigned (labels)
148-
design matrix(x), centroids (centre), and the number of desired groups (k).
178+
This function computes the total sum of squares based on the assigned (labels)
179+
design matrix(x), centroids (centre), and the number of desired groups (k).
149180
150-
A Float type representing the computed metric is returned.
181+
A Float type representing the computed metric is returned.
151182
"""
152183
function sum_of_squares(x::Array{Float64,2}, labels::Array{Int64,1}, centre::Array)
153184
s = 0.0
@@ -165,24 +196,24 @@ end
165196
"""
166197
Kmeans(design_matrix, k; k_init="k-means++", max_iters=300, tol=1e-4, verbose=true)
167198
168-
This main function employs the K-means algorithm to cluster all examples
169-
in the training data (design_matrix) into k groups using either the
170-
`k-means++` or random initialisation technique for selecting the initial
171-
centroids.
172-
173-
At the end of the number of iterations specified (max_iters), convergence is
174-
achieved if difference between the current and last cost objective is
175-
less than the tolerance level (tol). An error is thrown if convergence fails.
199+
This main function employs the K-means algorithm to cluster all examples
200+
in the training data (design_matrix) into k groups using either the
201+
`k-means++` or random initialisation technique for selecting the initial
202+
centroids.
176203
177-
Details of operations can be either printed or not by setting verbose accordingly.
204+
At the end of the number of iterations specified (max_iters), convergence is
205+
achieved if difference between the current and last cost objective is
206+
less than the tolerance level (tol). An error is thrown if convergence fails.
178207
179-
A tuple representing labels, centroids, and sum_squares respectively is returned.
208+
Details of operations can be either printed or not by setting verbose accordingly.
180209
210+
A tuple representing labels, centroids, and sum_squares respectively is returned.
181211
"""
182-
function kmeans(design_matrix::Array{Float64, 2}, k::Int; k_init::String = "k-means++",
183-
max_iters::Int = 300, tol = 1e-4, verbose::Bool = true)
212+
function kmeans(design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread();
213+
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4,
214+
verbose::Bool = true) where {T <: CalculationMode}
184215

185-
centroids, n_row, n_col = smart_init(design_matrix, k, init=k_init)
216+
centroids, n_row, n_col = smart_init(design_matrix, k, mode, init=k_init)
186217

187218
labels = Vector{Int}(undef, n_row)
188219
distances = Vector{Float64}(undef, n_row)
@@ -193,7 +224,7 @@ function kmeans(design_matrix::Array{Float64, 2}, k::Int; k_init::String = "k-me
193224
nearest_neighbour = Array{Float64, 2}(undef, size(design_matrix, 1), size(centroids, 1))
194225
# Update centroids & labels with closest members until convergence
195226
for iter = 1:max_iters
196-
pairwise!(nearest_neighbour, design_matrix, centroids)
227+
pairwise!(nearest_neighbour, design_matrix, centroids, mode)
197228

198229
@inbounds for i in axes(nearest_neighbour, 1)
199230
labels[i] = 1
@@ -224,11 +255,11 @@ function kmeans(design_matrix::Array{Float64, 2}, k::Int; k_init::String = "k-me
224255

225256
if verbose
226257
# Show progress and terminate if J stopped decreasing.
227-
println("Iteration ", iter, ": Jclust = ", J, ".")
258+
println("Iteration $iter: Jclust = $J.")
228259
end
229260

230261
# Final Step: Check for convergence
231-
if iter > 1 && abs(J - J_previous) < (tol * J)
262+
if (iter > 1) & (abs(J - J_previous) < (tol * J))
232263

233264
sum_squares = sum_of_squares(design_matrix, labels, centroids)
234265

@@ -237,16 +268,15 @@ function kmeans(design_matrix::Array{Float64, 2}, k::Int; k_init::String = "k-me
237268
println("Successfully terminated with convergence.")
238269
end
239270

240-
return labels, centroids, sum_squares
271+
return (labels=labels, centroids=centroids, sum_squares=sum_squares)
241272

242-
elseif iter == max_iters && abs(J - J_previous) > (tol * J)
273+
elseif (iter == max_iters) & (abs(J - J_previous) > (tol * J))
243274
throw(error("Failed to converge Check data and/or implementation or increase max_iter."))
244275

245276
end
246277

247278
J_previous = J
248279
end
249-
250280
end
251281

252282
end # module

test/test01_distance.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
module TestDistance
2-
using ParallelKMeans: pairwise!, pl_pairwise!
2+
using ParallelKMeans: pairwise!, SingleThread, MultiThread
33
using Test
44

55
@testset "naive singlethread pairwise" begin
@@ -11,4 +11,14 @@ using Test
1111
@test all(r .≈ [0.0, 13.0, 25.0])
1212
end
1313

14+
@testset "multithread pairwise" begin
15+
X = [1.0 2.0; 3.0 5.0; 4.0 6.0]
16+
y = [1.0 2.0; ]
17+
r = Array{Float64, 2}(undef, 3, 1)
18+
19+
pairwise!(r, X, y, MultiThread())
20+
@test all(r .≈ [0.0, 13.0, 25.0])
21+
end
22+
23+
1424
end # module

test/test02_kmeans.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
module TestKMeans
2+
23
using ParallelKMeans
4+
using ParallelKMeans: MultiThread
35
using Test
46
using Random
57

6-
@testset "linear separation" begin
8+
@testset "singlethread linear separation" begin
79
Random.seed!(2020)
810

911
X = rand(100, 3)
@@ -15,4 +17,17 @@ using Random
1517
@test sum_squares ≈ 15.314823028363763
1618
end
1719

20+
21+
@testset "multithread linear separation" begin
22+
Random.seed!(2020)
23+
24+
X = rand(100, 3)
25+
labels, centroids, sum_squares = kmeans(X, 3, MultiThread(); tol = 1e-10, verbose = false)
26+
27+
# for future reference: Clustering shows here 14.964882850452984
28+
# guess they use better initialisation. For now we will use own
29+
# value
30+
@test sum_squares ≈ 15.314823028363763
31+
end
32+
1833
end # module

0 commit comments

Comments
 (0)