|
"""
    MiniBatch(b::Int)

Mini-batch k-means algorithm of Sculley et al. 2007, processing `b`
randomly drawn samples per iteration instead of the full dataset.
"""
struct MiniBatch <: AbstractKMeansAlg
    b::Int  # number of samples drawn for each batch update
end

# Default batch size of 100 samples, as commonly used in practice.
MiniBatch() = MiniBatch(100)
"""
    kmeans!(alg::MiniBatch, X, k; kwargs...)

Mini-batch k-means (Sculley et al. 2007). Each iteration draws `alg.b`
columns of `X`, labels them with their nearest centroid, and moves each
centroid toward its assigned samples with a per-centroid learning rate.

# Arguments
- `X`: design matrix with observations as columns.
- `k`: number of clusters.

# Keywords
- `weights`: optional per-observation sampling weights (length `size(X, 2)`).
- `metric`: distance metric (default `Euclidean()`).
- `k_init` / `init`: seeding strategy, or explicit initial centroids.
- `max_iters`: maximum number of batch iterations.
- `tol`: relative improvement threshold; `tol = 0` (the default) disables
  early stopping, so the loop runs until `max_iters`.
- `max_no_improvement`: consecutive low-improvement iterations required to
  declare convergence.
- `rng`: random number generator for batch sampling.

# Returns
`(centroids, niters, converged, labels, J)` where `labels` refers to the
last batch only and `J` is the last batch cost (placeholder metric).
"""
function kmeans!(alg::MiniBatch, X, k;
                 weights = nothing, metric = Euclidean(), n_threads = Threads.nthreads(),
                 k_init = "k-means++", init = nothing, max_iters = 300,
                 tol = 0, max_no_improvement = 10, verbose = false, rng = Random.GLOBAL_RNG)

    # Element type and number of observations in the design matrix X.
    T = eltype(X)
    ncol = size(X, 2)

    # Initiate cluster centers - (Step 2) in paper
    centroids = isnothing(init) ? smart_init(X, k, n_threads, weights, rng, init = k_init).centroids : deepcopy(init)

    # Counter for the no. of data in each cluster - (Step 3) in paper
    N = zeros(T, k)

    # Nearest-center label for each sample of the current batch.
    labels = Vector{Int}(undef, alg.b)

    converged = false
    niters = 0
    # BUGFIX: `counter` must persist across iterations; it was previously
    # reset to 0 at the top of every while-pass, so `max_no_improvement`
    # could never be reached.
    counter = 0
    J_previous = zero(T)
    J = zero(T)

    # BUGFIX: `<` instead of `<=` so the loop performs at most `max_iters`
    # iterations (niters starts at 0 and increments at the end of the body).
    while niters < max_iters

        # b examples picked randomly from X (Step 5 in paper)
        batch_rand_idx = isnothing(weights) ? rand(rng, 1:ncol, alg.b) : wsample(rng, 1:ncol, weights, alg.b)
        batch_sample = X[:, batch_rand_idx]

        # Cache/label the batch samples nearest to the centers (Step 6 & 7)
        @inbounds for i in axes(batch_sample, 2)
            min_dist = distance(metric, batch_sample, centroids, i, 1)
            label = 1

            for j in 2:size(centroids, 2)
                dist = distance(metric, batch_sample, centroids, i, j)
                label = dist < min_dist ? j : label
                min_dist = dist < min_dist ? dist : min_dist
            end

            labels[i] = label
        end

        # Batch gradient step
        for j in axes(batch_sample, 2) # iterate over examples (Step 9)

            # Get cached center/label for this x => labels[j] (Step 10)
            label = labels[j]

            # Update per-center counts (Step 11).
            # BUGFIX: `weights` is indexed by original column, not batch
            # position, so translate through `batch_rand_idx`.
            N[label] += isnothing(weights) ? 1 : weights[batch_rand_idx[j]]

            # Get per-center learning rate (Step 12)
            lr = 1 / N[label]

            # Take gradient step (Step 13). `@views` avoids allocating the
            # column slices, fusing the update into one in-place loop.
            @views centroids[:, label] .= (1 - lr) .* centroids[:, label] .+ lr .* batch_sample[:, j]
        end

        # Batch cost — placeholder convergence metric for now.
        J = sum_of_squares(batch_sample, labels, centroids)

        if verbose
            # Show progress of the batch cost.
            println("Iteration $niters: Jclust = $J")
        end

        # Early stopping.
        # BUGFIX: the original `(niters > 1) & abs(J - J_previous)` applied
        # `&` to a Float64 (MethodError); compare the improvement against
        # the relative tolerance instead.
        if niters > 1 && abs(J - J_previous) < tol * J
            counter += 1

            # Declare convergence if max_no_improvement criterion is met
            if counter >= max_no_improvement
                converged = true
                break
            end
        else
            # Reset the streak once a sufficiently large improvement occurs.
            counter = 0
        end

        J_previous = J
        niters += 1
    end

    return centroids, niters, converged, labels, J # TODO: push learned artifacts to KmeansResult
end
| 100 | + |
# TODO: Only being used to test generic implementation. Get rid off after!
"""
    sum_of_squares(x, labels, centre)

Return the total squared Euclidean distance from each column `i` of `x`
to its assigned center column `centre[:, labels[i]]`.
"""
function sum_of_squares(x, labels, centre)
    # Type-stable accumulator: was hard-coded `0.0`, which forced Float64
    # regardless of the element type of `x`.
    s = zero(eltype(x))

    for i in axes(x, 2)
        for j in axes(x, 1)
            # abs2 is the idiomatic squared-magnitude helper.
            s += abs2(x[j, i] - centre[j, labels[i]])
        end
    end
    return s
end
0 commit comments