diff --git a/src/kmeans.jl b/src/kmeans.jl index 5588710c..616298a8 100644 --- a/src/kmeans.jl +++ b/src/kmeans.jl @@ -17,6 +17,7 @@ const _kmeans_default_init = :kmpp const _kmeans_default_maxiter = 100 const _kmeans_default_tol = 1.0e-6 const _kmeans_default_display = :none +const _kmeans_default_n_init = 10 function kmeans!{T<:AbstractFloat}(X::Matrix{T}, centers::Matrix{T}; weights=nothing, @@ -44,20 +45,34 @@ function kmeans{T<:AbstractFloat}(X::Matrix{T}, k::Int; weights=nothing, init=_kmeans_default_init, maxiter::Integer=_kmeans_default_maxiter, + n_init::Integer=_kmeans_default_n_init, tol::Real=_kmeans_default_tol, display::Symbol=_kmeans_default_display, distance::SemiMetric=SqEuclidean()) m, n = size(X) (2 <= k < n) || error("k must have 2 <= k < n.") - iseeds = initseeds(init, X, k) - centers = copyseeds(X, iseeds) - kmeans!(X, centers; - weights=weights, - maxiter=maxiter, - tol=tol, - display=display, - distance=distance) + n_init > 0 || throw(ArgumentError("n_init must be greater than 0")) + + lowestcost::Float64 = Inf + local bestresult::KmeansResult + + for i = 1:n_init + iseeds = initseeds(init, X, k) + centers = copyseeds(X, iseeds) + result = kmeans!(X, centers; + weights=weights, + maxiter=maxiter, + tol=tol, + display=display, + distance=distance) + + if result.totalcost < lowestcost + lowestcost = result.totalcost + bestresult = result + end + end + return bestresult end #### Core implementation diff --git a/test/kmeans.jl b/test/kmeans.jl index 33a61b16..db439adf 100644 --- a/test/kmeans.jl +++ b/test/kmeans.jl @@ -15,7 +15,7 @@ k = 10 x = rand(m, n) # non-weighted -r = kmeans(x, k; maxiter=50) +r = kmeans(x, k; maxiter=50, n_init=2) @test isa(r, KmeansResult{Float64}) @test size(r.centers) == (m, k) @test length(r.assignments) == n @@ -27,7 +27,7 @@ r = kmeans(x, k; maxiter=50) @test isapprox(sum(r.costs), r.totalcost) # non-weighted (float32) -r = kmeans(map(Float32, x), k; maxiter=50) +r = kmeans(@compat(map(Float32, x)), k; maxiter=50, n_init=2) @test isa(r, KmeansResult{Float32}) @test size(r.centers) == (m, k) @test length(r.assignments) == n @@ -40,7 +40,7 @@ r = kmeans(map(Float32, x), k; maxiter=50) # weighted w = rand(n) -r = kmeans(x, k; maxiter=50, weights=w) +r = kmeans(x, k; maxiter=50, weights=w, n_init=2) @test isa(r, KmeansResult{Float64}) @test size(r.centers) == (m, k) @test length(r.assignments) == n