Skip to content

Commit 20f6c02

Browse files
committed
Minibatch convergence done, metric support left
1 parent 121573c commit 20f6c02

File tree

2 files changed

+34
-18
lines changed

2 files changed

+34
-18
lines changed

src/mini_batch.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ MiniBatch() = MiniBatch(100)
1313
function kmeans!(alg::MiniBatch, X, k;
1414
weights = nothing, metric = Euclidean(), n_threads = Threads.nthreads(),
1515
k_init = "k-means++", init = nothing, max_iters = 300,
16-
tol = 0, max_no_improvement = 10, verbose = false, rng = Random.GLOBAL_RNG)
16+
tol = eltype(X)(1e-6), max_no_improvement = 10, verbose = false, rng = Random.GLOBAL_RNG)
1717

1818
# Get the type and dimensions of design matrix, X
1919
T = eltype(X)
@@ -27,15 +27,16 @@ function kmeans!(alg::MiniBatch, X, k;
2727

2828
# Initialize nearest centers
2929
labels = Vector{Int}(undef, alg.b)
30+
final_labels = Vector{Int}(undef, ncol)
3031

3132
converged = false
3233
niters = 0
34+
counter = 0
3335
J_previous = zero(T)
3436
J = zero(T)
3537

3638
# TODO: Main Steps. Batch update centroids until convergence
3739
while niters <= max_iters
38-
counter = 0
3940

4041
# b examples picked randomly from X (Step 5 in paper)
4142
batch_rand_idx = isnothing(weights) ? rand(rng, 1:ncol, alg.b) : wsample(rng, 1:ncol, weights, alg.b)
@@ -79,12 +80,27 @@ function kmeans!(alg::MiniBatch, X, k;
7980
end
8081

8182
# TODO: Check for early stopping convergence
82-
if (niters > 1) & abs(J - J_previous)
83+
if (niters > 1) & abs((J - J_previous) < (tol * J))
8384
counter += 1
8485

8586
# Declare convergence if max_no_improvement criterion is met
8687
if counter >= max_no_improvement
8788
converged = true
89+
# TODO: Compute label assignment for the complete dataset
90+
@inbounds for i in axes(X, 2)
91+
min_dist = distance(metric, X, centroids, i, 1)
92+
label = 1
93+
94+
for j in 2:size(centroids, 2)
95+
dist = distance(metric, X, centroids, i, j)
96+
label = dist < min_dist ? j : label
97+
min_dist = dist < min_dist ? dist : min_dist
98+
end
99+
100+
final_labels[i] = label
101+
end
102+
# TODO: Compute totalcost for the complete dataset
103+
J = sum_of_squares(X, final_labels, centroids) # just a placeholder for now
88104
break
89105
end
90106

@@ -94,7 +110,7 @@ function kmeans!(alg::MiniBatch, X, k;
94110
niters += 1
95111
end
96112

97-
return centroids, niters, converged, labels, J # TODO: push learned artifacts to KmeansResult
113+
return centroids, niters, converged, final_labels, J # TODO: push learned artifacts to KmeansResult
98114
#return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
99115
end
100116

test/test90_minibatch.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,30 @@ end
1313

1414

1515
@testset "MiniBatch convergence" begin
16-
X = [1 1 1 4 4 4 4 0 2 3 5 1; 2 4 0 2 0 4 5 1 2 2 5 -1.]
17-
1816
rng = StableRNG(2020)
19-
baseline = kmeans(Lloyd(), X, 2, rng = rng)
17+
X = rand(rng, 3, 100)
2018

21-
rng = StableRNG(2020)
22-
res = kmeans(MiniBatch(6), X, 2, rng = rng)
19+
baseline = [kmeans(Lloyd(), X, 2).totalcost for i in 1:1_000] |> mean |> round
20+
# TODO: Switch to kmeans after full implementation
21+
res = [ParallelKMeans.kmeans!(MiniBatch(50), X, 2)[end] for i in 1:1_000] |> mean |> round
2322

24-
@test baseline.totalcost res.totalcost
23+
@test baseline == res
2524
end
2625

2726

2827
@testset "MiniBatch metric support" begin
29-
X = [1 1 1 4 4 4 4 0 2 3 5 1; 2 4 0 2 0 4 5 1 2 2 5 -1.]
3028
rng = StableRNG(2020)
31-
rng_orig = deepcopy(rng)
32-
33-
baseline = kmeans(Lloyd(), X, 2, tol = 1e-16, metric=Cityblock(), rng = rng)
29+
X = rand(rng, 3, 100)
3430

35-
rng = deepcopy(rng_orig)
36-
res = kmeans(MiniBatch(6), X, 2; tol = 1e-16, metric=Cityblock(), rng = rng)
31+
baseline = [kmeans(Lloyd(), X, 2;
32+
tol=1e-6, metric=Cityblock(),
33+
max_iters=500).totalcost for i in 1:1000] |> mean |> floor
34+
# TODO: Switch to kmeans after full implementation
35+
res = [ParallelKMeans.kmeans!(MiniBatch(), X, 2;
36+
metric=Cityblock(), tol=1e-6,
37+
max_iters=500)[end] for i in 1:1000] |> mean |> floor
3738

38-
@test res.totalcost baseline.totalcost
39-
@test res.converged == baseline.converged
39+
@test baseline == res
4040
end
4141

4242

0 commit comments

Comments
 (0)