Skip to content

Commit 1372d73

Browse files
authored
improve topk/perm's worst case scenario (#79)
* improve topk/perm's worst case scenario * using binary search for larger k
1 parent 7a4f764 commit 1372d73

File tree

3 files changed

+205
-4
lines changed

3 files changed

+205
-4
lines changed

src/stat/hp_stat.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,46 @@ Base.@propagate_inbounds function hp_topk_vals(x::AbstractVector{T}, k::Int, lt_
8787
end
8888
topk_vals(res_out, k, lt_fun, by)
8989
end
90+
Base.@propagate_inbounds function hp_topk_vals(x::Union{Vector{T}, SubArray{T, N, Vector{T}, Tuple{I}, L}}, k::Int, lt_fun::F, by) where {T<:Union{Missing, FLOATS, INTEGERS}} where {F} where N where I <: UnitRange{Int} where L
91+
k < 1 && throw(ArgumentError("k must be greater than 1"))
92+
all(ismissing, x) && return Union{Missing,T}[missing]
93+
nt = Threads.nthreads()
94+
res = Vector{T}(undef, k * nt)
95+
res_out = allowmissing(res)
96+
fill!(res_out, missing)
97+
cz = div(length(x), nt)
98+
Threads.@threads for i in 1:nt
99+
lo = (i - 1) * cz + 1
100+
i == nt ? hi = length(x) : hi = i * cz
101+
th_res = view(res, (i-1)*k+1:i*k)
102+
th_x = view(x, lo:hi)
103+
th_res_out = view(res_out, (i-1)*k+1:i*k)
104+
idx, cnt = initiate_topk_res!(th_res, th_x, by)
105+
topk_sort!(th_res, 1, cnt, lt_fun)
106+
if k < 21
107+
for i in idx+1:length(th_x)
108+
if !ismissing(by(th_x[i]))
109+
insert_fixed_sorted!(th_res, th_x[i], lt_fun)
110+
cnt += 1
111+
end
112+
end
113+
else
114+
for i in idx+1:length(th_x)
115+
if !ismissing(by(th_x[i]))
116+
insert_fixed_sorted_binary!(th_res, th_x[i], lt_fun)
117+
cnt += 1
118+
end
119+
end
120+
end
121+
if cnt < k
122+
view(th_res_out, 1:cnt) .= view(th_res, 1:cnt)
123+
else
124+
th_res_out .= th_res
125+
end
126+
end
127+
topk_vals(res_out, k, lt_fun, by)
128+
end
129+
90130

91131
Base.@propagate_inbounds function hp_topk_perm(x::AbstractVector{T}, k::Int, lt_fun::F, by) where {T} where {F}
92132
k < 1 && throw(ArgumentError("k must be greater than 1"))
@@ -124,4 +164,51 @@ Base.@propagate_inbounds function hp_topk_perm(x::AbstractVector{T}, k::Int, lt_
124164
end
125165
end
126166
perm_out[topk_perm(res_out, k, lt_fun, by)]
167+
end
168+
169+
Base.@propagate_inbounds function hp_topk_perm(x::Union{Vector{T}, SubArray{T, N, Vector{T}, Tuple{I}, L}}, k::Int, lt_fun::F, by) where {T<:Union{Missing, FLOATS, INTEGERS}} where {F} where N where I <: UnitRange{Int} where L
170+
k < 1 && throw(ArgumentError("k must be greater than 1"))
171+
all(ismissing, x) && return Union{Missing,Int}[missing]
172+
nt = Threads.nthreads()
173+
res = Vector{T}(undef, k * nt)
174+
res_out = allowmissing(res)
175+
fill!(res_out, missing)
176+
perm = zeros(Int, k * nt)
177+
perm_out = allowmissing(perm)
178+
fill!(perm_out, missing)
179+
cz = div(length(x), nt)
180+
Threads.@threads for i in 1:nt
181+
lo = (i - 1) * cz + 1
182+
i == nt ? hi = length(x) : hi = i * cz
183+
th_res = view(res, (i-1)*k+1:i*k)
184+
th_perm = view(perm, (i-1)*k+1:i*k)
185+
th_x = view(x, lo:hi)
186+
th_res_out = view(res_out, (i-1)*k+1:i*k)
187+
th_perm_out = view(perm_out, (i-1)*k+1:i*k)
188+
idx, cnt = initiate_topk_res_perm!(th_perm, th_res, th_x, by, offset = lo - 1)
189+
topk_sort_permute!(th_res, th_perm, 1, cnt, lt_fun)
190+
if k < 16
191+
for i in idx+1:length(th_x)
192+
if !ismissing(by(th_x[i]))
193+
insert_fixed_sorted_perm!(th_perm, th_res, i + lo - 1, th_x[i], lt_fun)
194+
cnt += 1
195+
end
196+
end
197+
else
198+
for i in idx+1:length(th_x)
199+
if !ismissing(by(th_x[i]))
200+
insert_fixed_sorted_perm_binary!(th_perm, th_res, i + lo - 1, th_x[i], lt_fun)
201+
cnt += 1
202+
end
203+
end
204+
end
205+
if cnt < k
206+
view(th_res_out, 1:cnt) .= view(th_res, 1:cnt)
207+
view(th_perm_out, 1:cnt) .= view(th_perm, 1:cnt)
208+
else
209+
th_res_out .= th_res
210+
th_perm_out .= th_perm
211+
end
212+
end
213+
perm_out[topk_perm(res_out, k, lt_fun, by)]
127214
end

src/stat/non_hp_stat.jl

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,22 @@ function initiate_topk_res_perm!(perm, res, x, by; offset=0)
426426
end
427427
idx, cnt - 1
428428
end
429-
429+
# it is unsafe because x must be continuous in memory
430+
function unsafe_shift_insert!(x::AbstractVector{T}, i, item) where T<:Union{Missing, FLOATS, INTEGERS}
431+
n = length(x)
432+
ccall(:memmove, Ptr{Cvoid}, (Ptr{Cvoid}, Ptr{Cvoid}, Csize_t), pointer(x, i + 1), pointer(x, i), (n - i) * Base.aligned_sizeof(T))
433+
x[i] = item
434+
x
435+
end
436+
Base.@propagate_inbounds function insert_fixed_sorted_binary!(x, item, lt_fun)
437+
if !lt_fun(item, x[end])
438+
return
439+
end
440+
idx = searchsortedlast(x, item, lt=lt_fun)
441+
unsafe_shift_insert!(x, idx + 1, item)
442+
nothing
443+
444+
end
430445
Base.@propagate_inbounds function insert_fixed_sorted!(x, item, lt_fun)
431446
if !lt_fun(item, x[end])
432447
return
@@ -445,6 +460,16 @@ Base.@propagate_inbounds function insert_fixed_sorted!(x, item, lt_fun)
445460
nothing
446461
end
447462
# TODO we do not need x, this is just easier to implement, later we may fix this
463+
Base.@propagate_inbounds function insert_fixed_sorted_perm_binary!(perm, x, idx, item, lt_fun)
464+
if !lt_fun(item, x[end])
465+
return
466+
end
467+
i = searchsortedlast(x, item, lt=lt_fun)
468+
unsafe_shift_insert!(x, i + 1, item)
469+
unsafe_shift_insert!(perm, i + 1, idx)
470+
nothing
471+
end
472+
448473
Base.@propagate_inbounds function insert_fixed_sorted_perm!(perm, x, idx, item, lt_fun)
449474
if !lt_fun(item, x[end])
450475
return
@@ -468,6 +493,7 @@ end
468493
Base.@propagate_inbounds function topk_vals(x::AbstractVector{T}, k::Int, lt_fun::F, by) where {T} where {F}
469494
k < 1 && throw(ArgumentError("k must be greater than 1"))
470495
all(ismissing, x) && return Union{Missing,T}[missing]
496+
# TODO should we use similar() here?
471497
res = Vector{T}(undef, k)
472498
idx, cnt = initiate_topk_res!(res, x, by)
473499
topk_sort!(res, 1, cnt, lt_fun)
@@ -484,6 +510,35 @@ Base.@propagate_inbounds function topk_vals(x::AbstractVector{T}, k::Int, lt_fun
484510
end
485511
end
486512

513+
#if k is greater than 20 (15 in topkperm) we switch to binary search - 21 and 16 are selected based on simulation study
514+
Base.@propagate_inbounds function topk_vals(x::Union{Vector{T}, SubArray{T, N, Vector{T}, Tuple{I}, L}}, k::Int, lt_fun::F, by) where {T<:Union{Missing, FLOATS, INTEGERS}} where {F} where N where I <: UnitRange{Int} where L
515+
k < 1 && throw(ArgumentError("k must be greater than 1"))
516+
all(ismissing, x) && return Union{Missing,T}[missing]
517+
res = Vector{T}(undef, k)
518+
idx, cnt = initiate_topk_res!(res, x, by)
519+
topk_sort!(res, 1, cnt, lt_fun)
520+
if k < 21
521+
for i in idx+1:length(x)
522+
if !ismissing(by(x[i]))
523+
insert_fixed_sorted!(res, x[i], lt_fun)
524+
cnt += 1
525+
end
526+
end
527+
else
528+
for i in idx+1:length(x)
529+
if !ismissing(by(x[i]))
530+
insert_fixed_sorted_binary!(res, x[i], lt_fun)
531+
cnt += 1
532+
end
533+
end
534+
end
535+
if cnt < k
536+
allowmissing(resize!(res, cnt))
537+
else
538+
allowmissing(res)
539+
end
540+
end
541+
487542
# ktop permutation
488543
#TODO should we return [missing] or Int[] when all elements are missings?
489544
Base.@propagate_inbounds function topk_perm(x::AbstractVector{T}, k::Int, lt_fun::F, by) where {T} where {F}
@@ -505,6 +560,34 @@ Base.@propagate_inbounds function topk_perm(x::AbstractVector{T}, k::Int, lt_fun
505560
allowmissing(perm)
506561
end
507562
end
563+
Base.@propagate_inbounds function topk_perm(x::Union{Vector{T}, SubArray{T, N, Vector{T}, Tuple{I}, L}}, k::Int, lt_fun::F, by) where {T<:Union{Missing, FLOATS, INTEGERS}} where {F} where N where I <: UnitRange{Int} where L
564+
k < 1 && throw(ArgumentError("k must be greater than 1"))
565+
all(ismissing, x) && return Union{Missing,Int}[missing]
566+
res = Vector{T}(undef, k)
567+
perm = zeros(Int, k)
568+
idx, cnt = initiate_topk_res_perm!(perm, res, x, by)
569+
topk_sort_permute!(res, perm, 1, cnt, lt_fun)
570+
if k < 16
571+
for i in idx+1:length(x)
572+
if !ismissing(by(x[i]))
573+
insert_fixed_sorted_perm!(perm, res, i, x[i], lt_fun)
574+
cnt += 1
575+
end
576+
end
577+
else
578+
for i in idx+1:length(x)
579+
if !ismissing(by(x[i]))
580+
insert_fixed_sorted_perm_binary!(perm, res, i, x[i], lt_fun)
581+
cnt += 1
582+
end
583+
end
584+
end
585+
if cnt < k
586+
allowmissing(resize!(perm, cnt))
587+
else
588+
allowmissing(perm)
589+
end
590+
end
508591

509592
"""
510593
topk(x, k; rev = false, lt = <, by = identity, threads = false)

test/stats.jl

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Random
1+
using Random,PooledArrays,CategoricalArrays
22
@testset "topk" begin
33
# general usage
44
for i in 1:100
@@ -24,7 +24,7 @@ using Random
2424
@test partialsortperm(x, 1:j, rev=true) == topkperm(x, j)
2525
end
2626
x = rand(Int8, 10000)
27-
for j in 1:15
27+
for j in 1:30
2828
@test partialsort(x, 1:j) == topk(x, j, rev=true) == topk(x, j, rev=true, threads=true)
2929
@test partialsort(x, 1:j, rev=true) == topk(x, j) == topk(x, j, threads=true)
3030
@test partialsortperm(x, 1:j) == topkperm(x, j, rev=true) == topkperm(x, j, rev=true, threads=true)
@@ -49,12 +49,27 @@ using Random
4949
@test partialsortperm(x, 1:min(11, j), rev=true) == topkperm(x, j)
5050
end
5151
x = [randstring() for _ in 1:101]
52-
for j in 1:15
52+
for j in 1:30
5353
@test partialsort(x, 1:j) == topk(x, j, rev=true) == topk(x, j, rev=true, threads=true)
5454
@test partialsort(x, 1:j, rev=true) == topk(x, j) == topk(x, j, threads=true)
5555
@test partialsortperm(x, 1:j) == topkperm(x, j, rev=true) == topkperm(x, j, rev=true, threads=true)
5656
@test partialsortperm(x, 1:j, rev=true) == topkperm(x, j) == topkperm(x, j, threads = true)
5757
end
58+
x = PooledArray(rand(1:100, 100))
59+
for j in 1:50
60+
@test partialsort(x, 1:j) == topk(x, j, rev=true) == topk(x, j, rev=true, threads=true)
61+
@test partialsort(x, 1:j, rev=true) == topk(x, j) == topk(x, j, threads=true)
62+
@test partialsortperm(x, 1:j) == topkperm(x, j, rev=true) == topkperm(x, j, rev=true, threads=true)
63+
@test partialsortperm(x, 1:j, rev=true) == topkperm(x, j) == topkperm(x, j, threads = true)
64+
end
65+
x = CategoricalArray(rand(100))
66+
for j in 1:50
67+
@test partialsort(x, 1:j) == topk(x, j, rev=true, lt = isless)
68+
@test partialsort(x, 1:j, rev=true) == topk(x, j, lt = isless)
69+
@test partialsortperm(x, 1:j) == topkperm(x, j, rev=true, lt = isless)
70+
@test partialsortperm(x, 1:j, rev=true) == topkperm(x, j, lt = isless)
71+
end
72+
5873
end
5974
x = [1, 10, missing, 100, -1000, 32, 54, 0, missing, missing, -1]
6075
@test topk(x, 2) == [100, 54] == topk(x, 2, threads = true)
@@ -93,4 +108,20 @@ using Random
93108
@test topkperm(x, 3, by=ff678) == [8, 1]
94109
@test topk(x, 3, by=ff678, rev=true) == [-1,-100]
95110
@test topkperm(x, 3, by=ff678, rev=true) == [1,8]
111+
112+
x=[missing for _ in 1:1000]
113+
@test isequal(topk(x, 10), topk(x,10,threads=true))
114+
@test isequal(topk(x, 10), [missing])
115+
@test isequal(topk(x, 100), topk(x,100,threads=true))
116+
@test isequal(topk(x, 100), [missing])
117+
@test isequal(topkperm(x, 100), topkperm(x,100,threads=true))
118+
@test isequal(topkperm(x, 100), [missing])
119+
@test isequal(topkperm(x, 10), topkperm(x,10,threads=true))
120+
@test isequal(topkperm(x, 10), [missing])
121+
@test isequal(topkperm(x, 10,rev=true), topkperm(x,10,threads=true,rev=true))
122+
@test isequal(topkperm(x, 10,rev=true), [missing])
123+
124+
x=CategoricalArray(rand(1000))
125+
# TODO categorical array is not thread safe - fortunately, it throws Errors - however, in future we may need to fix it
126+
@test_throws UndefRefError topk(x,10,lt=isless,threads=true)
96127
end

0 commit comments

Comments
 (0)