Skip to content

Commit 5587d6b

Browse files
committed
TODO cleanups and requests
1 parent e557d89 commit 5587d6b

File tree

2 files changed

+88
-16
lines changed

2 files changed

+88
-16
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ ________________________________________________________________________________
5050
_________________________________________________________________________________________________________
5151

5252
### Pending Features
53-
53+
- [X] Implementation of Triangle inequality based on [Elkan C. (2003) "Using the Triangle Inequality to Accelerate
54+
-Mean"](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf)
5455
- [ ] Support for DataFrame inputs.
5556
- [ ] Refactoring and finalizaiton of API desgin.
5657
- [ ] GPU support.
5758
- [ ] Even faster Kmeans implementation based on current literature.
5859
- [ ] Optimization of code base.
59-
- [X] Implementation of Triangle inequality based on [Elkan C. (2003) "Using the Triangle Inequality to Accelerate
60-
-Mean"](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf)
60+
6161
_________________________________________________________________________________________________________
6262

6363
### How To Use

src/ParallelKMeans.jl

Lines changed: 85 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@ export kmeans
88

99
# All Abstract types defined
1010
"""
11-
TODO: Docs
11+
AbstractKMeansAlg
12+
Abstract base type inherited by all sub-KMeans algorithms.
1213
"""
1314
abstract type AbstractKMeansAlg end
1415

1516
"""
16-
TODO: Docs
17+
CalculationMode
18+
Abstract base type inherited by various threading implementations.
1719
"""
1820
abstract type CalculationMode end
1921

@@ -28,6 +30,7 @@ abstract type ClusteringResult end
2830
# Here we mimic `Clustering` output structure
2931
"""
3032
KmeansResult{C,D<:Real,WC<:Real} <: ClusteringResult
33+
3134
The output of [`kmeans`](@ref) and [`kmeans!`](@ref).
3235
# Type parameters
3336
* `C<:AbstractMatrix{<:AbstractFloat}`: type of the `centers` matrix
@@ -57,21 +60,28 @@ struct Lloyd <: AbstractKMeansAlg end
5760

5861

5962
"""
60-
TODO: Docs
63+
LightElkan()
64+
TODO: Description of LightElkan algorithm here
6165
"""
6266
struct LightElkan <: AbstractKMeansAlg end
6367

64-
# Single thread class to control the calculation type based on the CalculationMode
68+
"""
69+
SingleThread()
70+
Single thread class to control the calculation type based on the CalculationMode
71+
"""
6572
struct SingleThread <: CalculationMode
6673
end
6774

68-
# Multi threaded implementation to control the calculation type based avaialble threads
75+
"""
76+
MultiThread()
77+
Multi threaded implementation to control the calculation type based avaialble threads
78+
"""
6979
struct MultiThread <: CalculationMode
7080
n::Int
7181
end
7282

7383
# Get the number of avaialble threads for multithreading implementation
74-
MultiThread() = MultiThread(Threads.nthreads())
84+
MultiThread() = MultiThread(Threads.nthreads()) # Uses all avaialble cores by default
7585

7686

7787

@@ -81,9 +91,12 @@ MultiThread() = MultiThread(Threads.nthreads())
8191
Let X is a matrix `m x n` and Y is a vector of the length `m`. Then the `colwise!` function
8292
computes distance between each column in X and Y and store result
8393
in `target` array. Argument `mode` defines calculation mode, currently
84-
following modes supported
94+
following modes supported:
95+
8596
- SingleThread()
8697
- MultiThread()
98+
99+
This dispatch handles the colwise calculation for single threads.
87100
"""
88101
colwise!(target, x, y) = colwise!(target, x, y, SingleThread())
89102

@@ -97,6 +110,7 @@ function colwise!(target, x, y, mode::SingleThread)
97110
end
98111
end
99112

113+
100114
"""
101115
spliiter(n, k)
102116
@@ -108,6 +122,19 @@ function splitter(n, k)
108122
end
109123

110124

125+
"""
126+
colwise!(target, x, y, mode)
127+
128+
Let X is a matrix `m x n` and Y is a vector of the length `m`. Then the `colwise!` function
129+
computes distance between each column in X and Y and store result
130+
in `target` array. Argument `mode` defines calculation mode, currently
131+
following modes supported:
132+
133+
- SingleThread()
134+
- MultiThread()
135+
136+
This dispatch handles the colwise calculation for multi-threads.
137+
"""
111138
function colwise!(target, x, y, mode::MultiThread)
112139
ncol = size(x, 2)
113140

@@ -145,6 +172,7 @@ function chunk_colwise!(target, x, y, r)
145172
end
146173
end
147174

175+
148176
"""
149177
smart_init(X, k; init="k-means++")
150178
@@ -231,14 +259,21 @@ function sum_of_squares(x::Array{Float64,2}, labels::Array{Int64,1}, centre::Arr
231259
return s
232260
end
233261

234-
# TODO generalize centroids type
262+
263+
"""
264+
# TODO generalize centroids type & Docs
265+
"""
235266
function create_containers(k, d, mode::SingleThread)
236267
new_centroids = Array{Float64, 2}(undef, d, k)
237268
centroids_cnt = Vector{Int}(undef, k)
238269

239270
return new_centroids, centroids_cnt
240271
end
241272

273+
274+
"""
275+
# TODO: Docs
276+
"""
242277
function create_containers(k, d, mode::MultiThread)
243278
new_centroids = Vector{Array{Float64, 2}}(undef, mode.n)
244279
centroids_cnt = Vector{Vector{Int}}(undef, mode.n)
@@ -251,8 +286,9 @@ function create_containers(k, d, mode::MultiThread)
251286
return new_centroids, centroids_cnt
252287
end
253288

289+
254290
"""
255-
Kmeans(design_matrix, k; k_init="k-means++", max_iters=300, tol=1e-4, verbose=true)
291+
Kmeans(design_matrix, k; k_init="k-means++", max_iters=300, tol=1e-6, verbose=true)
256292
257293
This main function employs the K-means algorithm to cluster all examples
258294
in the training data (design_matrix) into k groups using either the
@@ -268,7 +304,7 @@ Details of operations can be either printed or not by setting verbose accordingl
268304
A tuple representing labels, centroids, and sum_squares respectively is returned.
269305
"""
270306
function kmeans(design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread();
271-
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4, verbose::Bool = true, init = nothing) where {T <: CalculationMode}
307+
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-6, verbose::Bool = true, init = nothing) where {T <: CalculationMode}
272308
nrow, ncol = size(design_matrix)
273309
centroids = init == nothing ? smart_init(design_matrix, k, mode, init=k_init).centroids : init
274310
new_centroids, centroids_cnt = create_containers(k, nrow, mode)
@@ -314,14 +350,22 @@ function kmeans(design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread
314350
return KmeansResult(centroids, labels, Float64[], Int[], Float64[], totalcost, niters, converged)
315351
end
316352

353+
354+
"""
355+
# TODO: Docs
356+
"""
317357
kmeans(alg::Lloyd, design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread();
318-
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4,
358+
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-6,
319359
verbose::Bool = true, init = nothing) where {T <: CalculationMode} =
320360
kmeans(design_matrix, k, mode; k_init = k_init, max_iters = max_iters, tol = tol,
321361
verbose = verbose, init = init)
322362

363+
364+
"""
365+
# TODO: Docs
366+
"""
323367
function kmeans(alg::LightElkan, design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread();
324-
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4, verbose::Bool = true, init = nothing) where {T <: CalculationMode}
368+
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-6, verbose::Bool = true, init = nothing) where {T <: CalculationMode}
325369
nrow, ncol = size(design_matrix)
326370
centroids = init == nothing ? smart_init(design_matrix, k, mode, init=k_init).centroids : deepcopy(init)
327371
new_centroids, centroids_cnt = create_containers(k, nrow, mode)
@@ -371,6 +415,10 @@ function kmeans(alg::LightElkan, design_matrix::Array{Float64, 2}, k::Int, mode:
371415
return KmeansResult(centroids, labels, Float64[], Int[], Float64[], totalcost, niters, converged)
372416
end
373417

418+
419+
"""
420+
# TODO: Docs
421+
"""
374422
function update_centroids!(centroids, new_centroids, centroids_cnt, labels,
375423
design_matrix, mode::SingleThread)
376424

@@ -383,6 +431,10 @@ function update_centroids!(centroids, new_centroids, centroids_cnt, labels,
383431
return J
384432
end
385433

434+
435+
"""
436+
# TODO: Docs
437+
"""
386438
function update_centroids!(centroids, new_centroids, centroids_cnt, labels,
387439
design_matrix, mode::MultiThread)
388440
mode.n == 1 && return update_centroids!(centroids, new_centroids[1], centroids_cnt[1], labels,
@@ -414,7 +466,11 @@ function update_centroids!(centroids, new_centroids, centroids_cnt, labels,
414466
return J
415467
end
416468

417-
# Lots of copy paste. It should be cleaned after api settles down.
469+
470+
"""
471+
# TODO: Docs
472+
# Lots of copy paste. It should be cleaned after api settles down.
473+
"""
418474
function update_centroids_dist!(centroids_dist, centroids, mode = SingleThread())
419475
k = size(centroids_dist, 1) # number of clusters
420476
@inbounds for j in axes(centroids_dist, 2)
@@ -441,6 +497,10 @@ function update_centroids_dist!(centroids_dist, centroids, mode = SingleThread()
441497
centroids_dist
442498
end
443499

500+
501+
"""
502+
# TODO: Docs
503+
"""
444504
function update_centroids!(alg::LightElkan, centroids, centroids_dist, new_centroids, centroids_cnt, labels,
445505
design_matrix, mode::SingleThread)
446506

@@ -454,6 +514,10 @@ function update_centroids!(alg::LightElkan, centroids, centroids_dist, new_centr
454514
return J
455515
end
456516

517+
518+
"""
519+
# TODO: Docs
520+
"""
457521
function update_centroids!(alg::LightElkan, centroids, centroids_dist, new_centroids, centroids_cnt, labels,
458522
design_matrix, mode::MultiThread)
459523
mode.n == 1 && return update_centroids!(alg, centroids, centroids_dist, new_centroids[1], centroids_cnt[1], labels,
@@ -486,6 +550,10 @@ function update_centroids!(alg::LightElkan, centroids, centroids_dist, new_centr
486550
return J
487551
end
488552

553+
554+
"""
555+
# TODO: Docs
556+
"""
489557
function chunk_update_centroids!(centroids, new_centroids, centroids_cnt, labels,
490558
design_matrix, r)
491559

@@ -514,6 +582,10 @@ function chunk_update_centroids!(centroids, new_centroids, centroids_cnt, labels
514582
return J
515583
end
516584

585+
586+
"""
587+
# TODO: Docs
588+
"""
517589
function chunk_update_centroids!(alg::LightElkan, centroids, centroids_dist, new_centroids, centroids_cnt, labels,
518590
design_matrix, r)
519591

0 commit comments

Comments
 (0)