@@ -13,7 +13,7 @@ MiniBatch() = MiniBatch(100)
1313function kmeans! (alg:: MiniBatch , X, k;
1414 weights = nothing , metric = Euclidean (), n_threads = Threads. nthreads (),
1515 k_init = " k-means++" , init = nothing , max_iters = 300 ,
16- tol = 0 , max_no_improvement = 10 , verbose = false , rng = Random. GLOBAL_RNG)
16+ tol = eltype (X)( 1e-6 ) , max_no_improvement = 10 , verbose = false , rng = Random. GLOBAL_RNG)
1717
1818 # Get the type and dimensions of design matrix, X
1919 T = eltype (X)
@@ -27,15 +27,16 @@ function kmeans!(alg::MiniBatch, X, k;
2727
2828 # Initialize nearest centers
2929 labels = Vector {Int} (undef, alg. b)
30+ final_labels = Vector {Int} (undef, ncol)
3031
3132 converged = false
3233 niters = 0
34+ counter = 0
3335 J_previous = zero (T)
3436 J = zero (T)
3537
3638 # TODO : Main Steps. Batch update centroids until convergence
3739 while niters <= max_iters
38- counter = 0
3940
4041 # b examples picked randomly from X (Step 5 in paper)
4142 batch_rand_idx = isnothing (weights) ? rand (rng, 1 : ncol, alg. b) : wsample (rng, 1 : ncol, weights, alg. b)
@@ -79,12 +80,27 @@ function kmeans!(alg::MiniBatch, X, k;
7980 end
8081
8182 # TODO : Check for early stopping convergence
82- if (niters > 1 ) & abs (J - J_previous)
83+ if (niters > 1 ) & abs (( J - J_previous) < (tol * J) )
8384 counter += 1
8485
8586 # Declare convergence if max_no_improvement criterion is met
8687 if counter >= max_no_improvement
8788 converged = true
89+ # TODO : Compute label assignment for the complete dataset
90+ @inbounds for i in axes (X, 2 )
91+ min_dist = distance (metric, X, centroids, i, 1 )
92+ label = 1
93+
94+ for j in 2 : size (centroids, 2 )
95+ dist = distance (metric, X, centroids, i, j)
96+ label = dist < min_dist ? j : label
97+ min_dist = dist < min_dist ? dist : min_dist
98+ end
99+
100+ final_labels[i] = label
101+ end
102+ # TODO : Compute totalcost for the complete dataset
103+ J = sum_of_squares (X, final_labels, centroids) # just a placeholder for now
88104 break
89105 end
90106
@@ -94,7 +110,7 @@ function kmeans!(alg::MiniBatch, X, k;
94110 niters += 1
95111 end
96112
97- return centroids, niters, converged, labels , J # TODO : push learned artifacts to KmeansResult
113+ return centroids, niters, converged, final_labels , J # TODO : push learned artifacts to KmeansResult
98114 # return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
99115end
100116
0 commit comments