Skip to content

Commit f505869

Browse files
author
Andrey Oskin
committed
Fixed errors in multithreaded version
1 parent c82d790 commit f505869

File tree

3 files changed

+40
-16
lines changed

3 files changed

+40
-16
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@
1010
.benchmarkci/
1111
.idea/*
1212
.vscode/*
13-
.test/experiments.jl
13+
test/experiments.jl

benchmark/bench01_distance.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,19 @@ using Random
77
suite = BenchmarkGroup()
88

99
Random.seed!(2020)
10-
X = rand(100_000, 3)
11-
centroids = rand(2, 3)
12-
d = rand(100_000, 2)
13-
suite["100kx3"] = @benchmarkable ParallelKMeans.pairwise!($d, $X, $centroids)
10+
X = rand(3, 100_000)
11+
centroids = rand(3, 2)
12+
d = Vector{Float64}(undef, 100_000)
13+
suite["100kx3"] = @benchmarkable ParallelKMeans.colwise!($d, $X, $centroids)
1414

15-
X = rand(100_000, 10)
16-
centroids = rand(2, 10)
17-
d = rand(100_000, 2)
18-
suite["100kx10"] = @benchmarkable ParallelKMeans.pairwise!($d, $X, $centroids)
15+
X = rand(10, 100_000)
16+
centroids = rand(10, 2)
17+
d = Vector{Float64}(undef, 100_000)
18+
suite["100kx10"] = @benchmarkable ParallelKMeans.colwise!($d, $X, $centroids)
1919

2020
# for reference
2121
metric = SqEuclidean()
22-
suite["100kx10_distances"] = @benchmarkable Distances.pairwise!($d, $metric, $X, $centroids, dims = 1)
22+
suite["100kx10_distances"] = @benchmarkable Distances.colwise!($d, $metric, $X, $centroids, dims = 2)
2323

2424
end # module
2525

src/ParallelKMeans.jl

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,25 @@ function sum_of_squares(x::Array{Float64,2}, labels::Array{Int64,1}, centre::Arr
207207
return s
208208
end
209209

210+
# TODO generalize centroids type
211+
"""
    create_containers(k, d, mode::SingleThread)

Allocate the working buffers used while updating centroids in single-threaded
mode.

Returns a tuple `(new_centroids, centroids_cnt)` where `new_centroids` is a
`d`×`k` `Float64` matrix and `centroids_cnt` is a length-`k` `Int` vector.
Both are created with `undef`, so their contents are arbitrary until the
caller fills them.
"""
function create_containers(k, d, mode::SingleThread)
    # Allocation order matches the original: matrix first, then the counters.
    accumulator = Matrix{Float64}(undef, d, k)
    counters = Vector{Int}(undef, k)

    return accumulator, counters
end
217+
218+
"""
    create_containers(k, d, mode::MultiThread)

Allocate one pair of working buffers per slot in `1:mode.n` for the
multithreaded centroid update.

Returns a tuple `(new_centroids, centroids_cnt)`:
- `new_centroids` is a vector of `mode.n` matrices, each `d`×`k` of `Float64`;
- `centroids_cnt` is a vector of `mode.n` vectors, each of `k` `Int`s.

Every inner buffer is created with `undef`, so its contents are arbitrary
until the caller fills it.
"""
function create_containers(k, d, mode::MultiThread)
    slots = mode.n
    centroid_buffers = Vector{Matrix{Float64}}(undef, slots)
    count_buffers = Vector{Vector{Int}}(undef, slots)

    # Fill each slot with its own independent buffers (no aliasing between slots).
    for slot in eachindex(centroid_buffers, count_buffers)
        centroid_buffers[slot] = Matrix{Float64}(undef, d, k)
        count_buffers[slot] = Vector{Int}(undef, k)
    end

    return centroid_buffers, count_buffers
end
210229

211230
"""
212231
Kmeans(design_matrix, k; k_init="k-means++", max_iters=300, tol=1e-4, verbose=true)
@@ -228,15 +247,15 @@ function kmeans(design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread
228247
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4, verbose::Bool = true, init = nothing) where {T <: CalculationMode}
229248
nrow, ncol = size(design_matrix)
230249
centroids = init == nothing ? smart_init(design_matrix, k, mode, init=k_init).centroids : init
231-
new_centroids = similar(centroids)
250+
new_centroids, centroids_cnt = create_containers(k, nrow, mode)
251+
# new_centroids = similar(centroids)
232252

233253
labels = Vector{Int}(undef, ncol)
234-
centroids_cnt = Vector{Int}(undef, k)
254+
# centroids_cnt = Vector{Int}(undef, k)
235255

236256
J_previous = Inf64
237257
totalcost = Inf
238258

239-
# nearest_neighbour = Array{Float64, 2}(undef, size(design_matrix, 1), size(centroids, 1))
240259
# Update centroids & labels with closest members until convergence
241260
for iter = 1:max_iters
242261
J = update_centroids!(centroids, new_centroids, centroids_cnt, labels, design_matrix, mode)
@@ -294,16 +313,21 @@ function update_centroids!(centroids, new_centroids, centroids_cnt, labels,
294313
waiting_list = Vector{Task}(undef, mode.n - 1)
295314

296315
for i in 1:length(ranges) - 1
297-
waiting_list[i] = @spawn chunk_update_centroids!(centroids, new_centroids, centroids_cnt, labels,
316+
waiting_list[i] = @spawn chunk_update_centroids!(centroids, new_centroids[i + 1], centroids_cnt[i + 1], labels,
298317
design_matrix, ranges[i], mode)
299318
end
300319

301-
J = chunk_update_centroids!(centroids, new_centroids, centroids_cnt, labels,
320+
J = chunk_update_centroids!(centroids, new_centroids[1], centroids_cnt[1], labels,
302321
design_matrix, ranges[end], mode)
303322

304323
J += sum(fetch.(waiting_list))
305324

306-
centroids .= new_centroids ./ centroids_cnt'
325+
for i in 1:length(ranges) - 1
326+
new_centroids[1] .+= new_centroids[i + 1]
327+
centroids_cnt[1] .+= centroids_cnt[i + 1]
328+
end
329+
330+
centroids .= new_centroids[1] ./ centroids_cnt[1]'
307331

308332
return J
309333
end

0 commit comments

Comments
 (0)