Skip to content

Commit afd927c

Browse files
authored
Merge pull request #12 from PyDataBlog/master
Get the experimental branch up to date with the current stable version
2 parents 5239c9e + ac87d1b commit afd927c

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ julia:
88
- nightly
99
after_success:
1010
- julia -e 'using Pkg; Pkg.add("Coverage"); using Coverage; Coveralls.submit(process_folder())'
11+
coveralls: true
1112
jobs:
1213
allow_failures:
1314
- julia: nightly

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://PyDataBlog.github.io/ParallelKMeans.jl/stable)
44
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://PyDataBlog.github.io/ParallelKMeans.jl/dev)
55
[![Build Status](https://www.travis-ci.org/PyDataBlog/ParallelKMeans.jl.svg?branch=master)](https://www.travis-ci.org/PyDataBlog/ParallelKMeans.jl)
6-
[![Coveralls](https://coveralls.io/repos/github/PyDataBlog/ParallelKMeans.jl/badge.svg?branch=master)](https://coveralls.io/github/PyDataBlog/ParallelKMeans.jl?branch=master)
6+
[![Coverage Status](https://coveralls.io/repos/github/PyDataBlog/ParallelKMeans.jl/badge.svg?branch=master)](https://coveralls.io/github/PyDataBlog/ParallelKMeans.jl?branch=master)

src/ParallelKMeans.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,9 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
127127

128128
# randonmly select the first centroid from the data (X)
129129
centroids = zeros(k, n_col)
130+
rand_indices = Vector{Int}(undef, k)
130131
rand_idx = rand(1:n_row)
132+
rand_indices[1] = rand_idx
131133
centroids[1, :] .= X[rand_idx, :]
132134
distances = Array{Float64}(undef, n_row, 1)
133135
new_distances = Array{Float64}(undef, n_row, 1)
@@ -143,6 +145,7 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
143145
# choose the next centroid, the probability for each data point to be chosen
144146
# is directly proportional to its squared distance from the nearest centroid
145147
r_idx = sample(1:n_row, ProbabilityWeights(vec(distances)))
148+
rand_indices[i] = r_idx
146149
centroids[i, :] .= X[r_idx, :]
147150

148151
# Ignore setting the last centroid to help the separation of centroids
@@ -168,7 +171,7 @@ function smart_init(X::Array{Float64, 2}, k::Int, mode::T = SingleThread();
168171
centroids = X[rand_indices, :]
169172
end
170173

171-
return centroids, n_row, n_col
174+
return (centroids = centroids, indices = rand_indices)
172175
end
173176

174177

@@ -210,10 +213,10 @@ Details of operations can be either printed or not by setting verbose accordingl
210213
A tuple representing labels, centroids, and sum_squares respectively is returned.
211214
"""
212215
function kmeans(design_matrix::Array{Float64, 2}, k::Int, mode::T = SingleThread();
213-
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4,
214-
verbose::Bool = true) where {T <: CalculationMode}
216+
k_init::String = "k-means++", max_iters::Int = 300, tol = 1e-4, verbose::Bool = true, init = nothing) where {T <: CalculationMode}
215217

216-
centroids, n_row, n_col = smart_init(design_matrix, k, mode, init=k_init)
218+
n_row, n_col = size(design_matrix)
219+
centroids = init == nothing ? smart_init(design_matrix, k, mode, init=k_init).centroids : init
217220

218221
labels = Vector{Int}(undef, n_row)
219222
distances = Vector{Float64}(undef, n_row)

0 commit comments

Comments
 (0)