Skip to content

Commit 4a67e2c

Browse files
committed
exploit multithreading in gathering observation in special cases
1 parent e92e409 commit 4a67e2c

File tree

1 file changed

+55
-11
lines changed

1 file changed

+55
-11
lines changed

src/other/utils.jl

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -554,24 +554,67 @@ isequal_row(cols1::Tuple{Vararg{AbstractVector}}, r1::Int,
554554

555555

556556
# Return `DataAPI.refarray(x)` when `DataAPI.refpool(x)` is not `nothing`, and `x`
# itself otherwise. The check uses `===` because `nothing` is a singleton: identity
# comparison cannot be affected by a custom `==` method defined on the pool type.
_grabrefs(x) = DataAPI.refpool(x) === nothing ? x : DataAPI.refarray(x)
557-
# Assign a group id to every row of `ds` based on the columns selected by `cols`.
#
# Row hashes come from `byrow(ds, hash, cols)`, and each selected column is passed
# through `_grabrefs`. With `threads = true` the rows are bucketed by hash
# (`_gather_groups_hugeds_splitter`), each bucket is grouped on its own
# (`_gather_groups_hugeds_collector`), and the per-bucket ids are then offset
# (`_gather_groups_hugeds_cleanup!`). With `threads = false` a single call to
# `create_dict_hugeds_multicols!` covers all rows.
#
# Returns `(groups, T[], ngroups)`; the second element is always an empty `Vector{T}`.
function _gather_groups_hugeds_multicols(ds, cols, ::Val{T}; threads::Bool = true) where T
    selected = index(ds)[cols]
    rhashes = byrow(ds, hash, cols, threads = threads)
    colsvals = ntuple(k -> _grabrefs(_columns(ds)[selected[k]]), length(selected))
    # The output buffer is needed by both code paths, so allocate it once up front.
    groups = Vector{T}(undef, length(rhashes))
    ngroups = if threads
        buckets = _gather_groups_hugeds_splitter(rhashes, Val(T))
        per_bucket = _gather_groups_hugeds_collector(groups, buckets, rhashes, colsvals, Val(T))
        _gather_groups_hugeds_cleanup!(groups, per_bucket, buckets)
    else
        create_dict_hugeds_multicols!(groups, 1:length(rhashes), colsvals, rhashes, Val(T))
    end
    return groups, T[], ngroups
end
573+
574+
# TODO: what happens if the values are not randomly grouped based on cols?
575+
# Partition the indices of `rhashes` into `nt` buckets according to the hash value.
#
# Index `i` is appended to bucket `mod(rhashes[i], nt) + 1`, so entries with equal
# hashes always land in the same bucket, and indices inside a bucket stay in
# increasing order.
#
# Arguments
# - `rhashes`: indexable collection of integer hashes, one per row.
# - `Val(T)`: element type used to store the row indices in each bucket.
# - `nt`: number of buckets (default 997). Must be positive.
#   TODO: be careful that the chosen value doesn't degrade the actual dictionary
#   creation in subsequent steps.
#
# Returns a `Vector{Vector{T}}` of length `nt`.
# Throws `ArgumentError` when `nt` is not positive.
function _gather_groups_hugeds_splitter(rhashes, ::Val{T}; nt::Integer = 997) where T
    nt > 0 || throw(ArgumentError("nt must be positive, got $nt"))
    # Expected bucket size if hashes spread evenly; only used as a capacity hint.
    sz = div(length(rhashes), nt)
    rngs = [sizehint!(T[], sz) for _ in 1:nt]
    for i in eachindex(rhashes)
        # `mod` (unlike `%`) is never negative for a positive `nt`, so the bucket
        # index stays within 1:nt even when the hashes are signed integers.
        push!(rngs[mod(rhashes[i], nt) + 1], i)
    end
    return rngs
end
584+
585+
# Run `create_dict_hugeds_multicols!` once per entry of `rngs`, in a threaded loop.
#
# Each call receives a view of `groups` restricted to that entry's indices, together
# with the indices themselves, `colsvals` and `rhashes`. The value returned by each
# call is stored at the matching position of the result.
#
# Returns a `Vector{Int}` with one entry per element of `rngs`.
function _gather_groups_hugeds_collector(groups, rngs, rhashes, colsvals, ::Val{T}) where T
    counts = Vector{Int}(undef, length(rngs))
    Threads.@threads for b in eachindex(rngs)
        rows = rngs[b]
        counts[b] = create_dict_hugeds_multicols!(view(groups, rows), rows, colsvals, rhashes, Val(T))
    end
    return counts
end
563593

564-
function create_dict_hugeds_multicols(colvals, rhashes, ::Val{T}) where T
565-
sz = max(1 + ((5 * length(rhashes)) >> 2), 16)
594+
# Shift the ids stored in `groups` by a per-bucket offset taken from `ngroups`.
#
# `ngroups` is first passed to `our_cumsum!` (presumably an in-place cumulative sum;
# confirm against its definition). Then, for every bucket `b >= 2`, `ngroups[b - 1]`
# is added to `groups[row]` for each `row` in `rngs[b]`. The first bucket is left
# untouched.
#
# Mutates both `groups` and `ngroups`; returns the last element of `ngroups`.
function _gather_groups_hugeds_cleanup!(groups, ngroups, rngs)
    our_cumsum!(ngroups)
    Threads.@threads for b in 2:length(rngs)
        for row in rngs[b]
            groups[row] += ngroups[b - 1]
        end
    end
    return last(ngroups)
end
603+
604+
# groups is a list of integers for which the dict is going to be created
605+
# get index and set index should sometimes be adjusted based on rng
606+
# make sure groups is a vector{T}
607+
function create_dict_hugeds_multicols!(groups, rng, colvals, rhashes, ::Val{T}) where T
608+
isempty(rng) && return 0
609+
sz = max(1 + ((5 * length(groups)) >> 2), 16)
566610
sz = 1 << (8 * sizeof(sz) - leading_zeros(sz - 1))
567-
@assert 4 * sz >= 5 * length(rhashes)
611+
@assert 4 * sz >= 5 * length(groups)
568612
szm1 = sz-1
569613
gslots = zeros(T, sz)
570-
groups = Vector{T}(undef, length(rhashes))
571614
ngroups = 0
572-
@inbounds for i in eachindex(rhashes)
615+
@inbounds for i in eachindex(rng)
573616
# find the slot and group index for a row
574-
slotix = rhashes[i] & szm1 + 1
617+
slotix = rhashes[rng[i]] & szm1 + 1
575618
gix = -1
576619
probe = 0
577620
while true
@@ -580,8 +623,8 @@ function create_dict_hugeds_multicols(colvals, rhashes, ::Val{T}) where T
580623
gslots[slotix] = i
581624
gix = ngroups += 1
582625
break
583-
elseif rhashes[i] == rhashes[g_row] # occupied slot, check if miss or hit
584-
if isequal_row(colvals, i, Int(g_row)) # hit
626+
elseif rhashes[rng[i]] == rhashes[rng[g_row]] # occupied slot, check if miss or hit
627+
if isequal_row(colvals, Int(rng[i]), Int(rng[g_row])) # hit
585628
gix = groups[g_row]
586629
break
587630
end
@@ -590,9 +633,10 @@ function create_dict_hugeds_multicols(colvals, rhashes, ::Val{T}) where T
590633
probe += 1
591634
@assert probe < sz
592635
end
636+
# groups[i] has done its work we can modify it
593637
groups[i] = gix
594638
end
595-
return groups, gslots, ngroups
639+
return ngroups
596640
end
597641

598642

0 commit comments

Comments
 (0)