Skip to content

Commit 121573c

Browse files
committed
MiniBatch algorithm draft
1 parent 9c221c9 commit 121573c

File tree

7 files changed

+179
-13
lines changed

7 files changed

+179
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ParallelKMeans"
22
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af"
33
authors = ["Bernard Brenyah", "Andrey Oskin"]
4-
version = "0.2.0"
4+
version = "0.2.1"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

docs/src/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ ________________________________________________________________________________
213213
- 0.1.7 Added `Yinyang` and `Coreset` support in MLJ interface; added `weights` support in MLJ; added RNG seed support in MLJ interface and through all algorithms; added metric support.
214214
- 0.1.8 Minor cleanup
215215
- 0.1.9 Added travis support for Julia 1.5
216+
- 0.2.0 Updated MLJ Interface
217+
- 0.2.1 Mini-batch implementation
216218

217219
## Contributing
218220

src/ParallelKMeans.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ include("hamerly.jl")
1515
include("elkan.jl")
1616
include("yinyang.jl")
1717
include("coreset.jl")
18+
include("mini_batch.jl")
1819
include("mlj_interface.jl")
1920

2021
export kmeans
21-
export Lloyd, Hamerly, Elkan, Yinyang, 阴阳, Coreset
22+
export Lloyd, Hamerly, Elkan, Yinyang, 阴阳, Coreset, MiniBatch
2223

2324
end # module

src/kmeans.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ Allocationless calculation of squared Euclidean distance between vectors X1[:, i1
115115
@inline function distance(metric::Euclidean, X1, X2, i1, i2)
116116
# here goes my definition
117117
d = zero(eltype(X1))
118-
# TODO: break of the loop if d is larger than threshold (known minimum disatnce)
118+
# TODO: break of the loop if d is larger than threshold (known minimum distance)
119119
@inbounds @simd for i in axes(X1, 1)
120120
d += (X1[i, i1] - X2[i, i2])^2
121121
end

src/mini_batch.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
"""
    MiniBatch(b::Int)
    MiniBatch()

Mini-batch k-means algorithm settings, following D. Sculley,
"Web-Scale K-Means Clustering" (WWW 2010).

`b` is the batch size, i.e. the number of samples drawn on every iteration.
The zero-argument constructor uses a batch size of `100`.

Throws an `ArgumentError` if `b` is not positive.
"""
struct MiniBatch <: AbstractKMeansAlg
    b::Int  # batch size

    # Validate at construction time so the solver never has to deal with an
    # empty or negative-sized batch. `new(b)` still converts `b` to `Int`.
    function MiniBatch(b)
        b > 0 || throw(ArgumentError("batch size must be positive, got $b"))
        return new(b)
    end
end


MiniBatch() = MiniBatch(100)
12+
13+
"""
    kmeans!(alg::MiniBatch, X, k; weights = nothing, metric = Euclidean(),
            n_threads = Threads.nthreads(), k_init = "k-means++", init = nothing,
            max_iters = 300, tol = 0, max_no_improvement = 10, verbose = false,
            rng = Random.GLOBAL_RNG)

Draft mini-batch k-means. On every iteration `alg.b` columns of `X` are drawn at
random, each drawn column is assigned to its nearest centroid, and that centroid is
moved towards the column with a per-centroid learning rate `1 / N[label]`.

# Arguments
- `alg`: `MiniBatch` settings; `alg.b` columns are drawn per iteration.
- `X`: design matrix; its columns are the items being sampled.
- `k`: number of clusters, passed to `smart_init` and used to size the per-centroid
  counters.

# Keywords
- `weights`: `nothing` for uniform sampling with replacement; otherwise per-column
  weights passed to `wsample` and accumulated into the per-centroid counters.
- `metric`: forwarded to `distance`.
- `n_threads`: forwarded to `smart_init` only.
- `k_init`: seeding method forwarded to `smart_init` when `init` is `nothing`.
- `init`: optional initial centroids; a deep copy is used so the argument is not mutated.
- `max_iters`: maximum number of mini-batches processed.
- `tol`: a batch counts as "no improvement" when `abs(J - J_previous) <= tol * J`.
- `max_no_improvement`: number of consecutive "no improvement" batches after which
  the run is declared converged.
- `verbose`: print the batch cost on every iteration.
- `rng`: random number generator used for seeding and batch sampling.

# Returns
The tuple `(centroids, niters, converged, labels, J)`. `labels` and `J` describe the
last processed batch only, not the whole of `X`.
"""
function kmeans!(alg::MiniBatch, X, k;
                 weights = nothing, metric = Euclidean(), n_threads = Threads.nthreads(),
                 k_init = "k-means++", init = nothing, max_iters = 300,
                 tol = 0, max_no_improvement = 10, verbose = false, rng = Random.GLOBAL_RNG)

    # Element type of the design matrix and number of columns to sample from
    T = eltype(X)
    ncol = size(X, 2)

    # Initiate cluster centers - (Step 2) in paper
    centroids = isnothing(init) ? smart_init(X, k, n_threads, weights, rng, init = k_init).centroids : deepcopy(init)

    # Initialize counter for the no. of data in each cluster - (Step 3) in paper
    N = zeros(T, k)

    # Nearest-center cache for the current batch
    labels = Vector{Int}(undef, alg.b)

    converged = false
    niters = 0
    # Consecutive batches without improvement. It must live outside the loop,
    # otherwise it is reset every iteration and can never reach the threshold.
    counter = 0
    J_previous = zero(T)
    J = zero(T)

    # Main steps: batch update centroids until convergence or `max_iters` batches
    while niters < max_iters

        # b examples picked randomly from X (Step 5 in paper)
        batch_rand_idx = isnothing(weights) ? rand(rng, 1:ncol, alg.b) : wsample(rng, 1:ncol, weights, alg.b)
        batch_sample = X[:, batch_rand_idx]

        # Cache/label the batch samples nearest to the centers (Step 6 & 7)
        @inbounds for i in axes(batch_sample, 2)
            min_dist = distance(metric, batch_sample, centroids, i, 1)
            label = 1

            for j in 2:size(centroids, 2)
                dist = distance(metric, batch_sample, centroids, i, j)
                label = dist < min_dist ? j : label
                min_dist = dist < min_dist ? dist : min_dist
            end

            labels[i] = label
        end

        # Batch gradient step
        for j in axes(batch_sample, 2)  # iterate over examples (Step 9)

            # Get cached center/label for this x => labels[j] (Step 10)
            label = labels[j]

            # Update per-center counts (Step 11). `j` is a position inside the
            # batch, so the weight has to be looked up through the sampled column.
            N[label] += isnothing(weights) ? 1 : weights[batch_rand_idx[j]]

            # Get per-center learning rate (Step 12)
            lr = 1 / N[label]

            # Take gradient step (Step 13) without allocating column slices
            for r in axes(centroids, 1)
                centroids[r, label] = (1 - lr) * centroids[r, label] + lr * batch_sample[r, j]
            end
        end

        # TODO: Calculate cost and check for convergence
        J = sum_of_squares(batch_sample, labels, centroids)  # just a placeholder for now

        if verbose
            # Show progress
            println("Iteration $niters: Jclust = $J")
        end

        # Early stopping: `J_previous` is only meaningful from the second batch on
        # (`niters` is incremented at the end of the loop body).
        if niters > 0 && abs(J - J_previous) <= tol * J
            counter += 1

            # Declare convergence if max_no_improvement criterion is met
            if counter >= max_no_improvement
                converged = true
                break
            end
        else
            counter = 0
        end

        J_previous = J
        niters += 1
    end

    # TODO: push learned artifacts to KmeansResult
    return centroids, niters, converged, labels, J
end
100+
101+
# TODO: Only used to test the generic implementation. Remove it afterwards!
102+
"""
    sum_of_squares(x, labels, centre)

Return the sum over all columns `i` of `x` of the squared differences between
`x[:, i]` and `centre[:, labels[i]]`, accumulated in a `Float64` starting at `0.0`.
"""
function sum_of_squares(x, labels, centre)
    total = 0.0

    # Column-major traversal: the row index varies fastest.
    for col in axes(x, 2), row in axes(x, 1)
        delta = x[row, col] - centre[row, labels[col]]
        total += delta^2
    end

    return total
end

src/mlj_interface.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ const MLJDICT = Dict(:Lloyd => Lloyd(),
1515
####
1616
#### MODEL DEFINITION
1717
####
18-
18+
"""
19+
ParallelKMeans model constructed by the user.
20+
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
21+
"""
1922
mutable struct KMeans <: MMI.Unsupervised
2023
algo::Union{Symbol, AbstractKMeansAlg}
2124
k_init::String
@@ -80,7 +83,7 @@ end
8083
#### FIT FUNCTION
8184
####
8285
"""
83-
Fit the specified ParaKMeans model constructed by the user.
86+
Fit the specified ParallelKMeans model constructed by the user.
8487
8588
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
8689
"""
@@ -187,21 +190,21 @@ end
187190
#### METADATA
188191
####
189192

190-
# TODO 4: metadata for the package and for each of the model interfaces
193+
# Metadata for the package and for each of the model interfaces
191194
MMI.metadata_pkg.(KMeans,
192-
name = "ParallelKMeans",
193-
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af",
194-
url = "https://github.com/PyDataBlog/ParallelKMeans.jl",
195-
julia = true,
196-
license = "MIT",
197-
is_wrapper = false)
195+
name = "ParallelKMeans",
196+
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af",
197+
url = "https://github.com/PyDataBlog/ParallelKMeans.jl",
198+
julia = true,
199+
license = "MIT",
200+
is_wrapper = false)
198201

199202

200203
# Metadata for ParallelKMeans model interface
201204
MMI.metadata_model(KMeans,
202205
input = MMI.Table(MMI.Continuous),
203206
output = MMI.Table(MMI.Continuous),
204-
target = AbstractArray{<:MMI.Multiclass},
207+
target = AbstractArray{<:MMI.Multiclass},
205208
weights = false,
206209
descr = ParallelKMeans_Desc,
207210
path = "ParallelKMeans.KMeans")

test/test90_minibatch.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Tests for the MiniBatch algorithm: default construction, agreement of the total
# cost with the Lloyd baseline, and support for a non-Euclidean metric.
module TestMiniBatch

using ParallelKMeans
using Test
using StableRNGs
using StatsBase
using Distances


@testset "MiniBatch default batch size" begin
    @test MiniBatch() == MiniBatch(100)
end


@testset "MiniBatch convergence" begin
    X = [1 1 1 4 4 4 4 0 2 3 5 1; 2 4 0 2 0 4 5 1 2 2 5 -1.]

    rng = StableRNG(2020)
    baseline = kmeans(Lloyd(), X, 2, rng = rng)

    rng = StableRNG(2020)
    res = kmeans(MiniBatch(6), X, 2, rng = rng)

    # Floating-point costs are compared approximately, never with `==`.
    @test baseline.totalcost ≈ res.totalcost
end


@testset "MiniBatch metric support" begin
    X = [1 1 1 4 4 4 4 0 2 3 5 1; 2 4 0 2 0 4 5 1 2 2 5 -1.]
    rng = StableRNG(2020)
    rng_orig = deepcopy(rng)

    baseline = kmeans(Lloyd(), X, 2, tol = 1e-16, metric=Cityblock(), rng = rng)

    # Restart from the same RNG state so both runs see the same seeding.
    rng = deepcopy(rng_orig)
    res = kmeans(MiniBatch(6), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)

    @test res.totalcost ≈ baseline.totalcost
    @test res.converged == baseline.converged
end

end # module

0 commit comments

Comments
 (0)