Skip to content

Commit 3c7ca03

Browse files
committed
Update utils.jl
1 parent 4a67e2c commit 3c7ca03

File tree

1 file changed

+33
-17
lines changed

1 file changed

+33
-17
lines changed

src/other/utils.jl

Lines changed: 33 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -430,7 +430,7 @@ function _gather_groups(ds, cols, ::Val{T}; mapformats = false, stable = true, t
430430
_max_level = nrow(ds)
431431

432432

433-
if nrow(ds) > 2^23 && !stable && 5<length(colidx)<16 # the result is stable anyway
433+
if nrow(ds) > 2^23 && !stable && 5<length(colidx)<16
434434
if !mapformats || all(==(identity), getformat.(Ref(ds), colidx))
435435
return _gather_groups_hugeds_multicols(ds, cols, Val(T); threads = threads)
436436
end
@@ -559,10 +559,10 @@ function _gather_groups_hugeds_multicols(ds, cols, ::Val{T}; threads::Bool = tru
559559
rhashes = byrow(ds, hash, cols, threads = threads)
560560
colsvals = ntuple(i->_grabrefs(_columns(ds)[colidx[i]]), length(colidx))
561561
if threads
562-
rngs = _gather_groups_hugeds_splitter(rhashes, Val(T))
562+
rngs, sz = _gather_groups_hugeds_splitter(rhashes, Val(T))
563563
groups = Vector{T}(undef, length(rhashes))
564-
ngroups_all = _gather_groups_hugeds_collector(groups, rngs, rhashes, colsvals, Val(T))
565-
ngroups = _gather_groups_hugeds_cleanup!(groups, ngroups_all, rngs)
564+
ngroups_all = _gather_groups_hugeds_collector(groups, rngs, sz, rhashes, colsvals, Val(T))
565+
ngroups = _gather_groups_hugeds_cleanup!(groups, ngroups_all, rngs, sz)
566566
else
567567
groups = Vector{T}(undef, length(rhashes))
568568
rng = 1:length(rhashes)
@@ -574,28 +574,44 @@ end
574574
# TODO what happens if the values are not randomly grouped based on cols
575575
function _gather_groups_hugeds_splitter(rhashes, ::Val{T}) where T
576576
nt = 997 # TODO this should be an argument, however, we must be careful that this value doesn't degrade actual dictionary creation in Subsequent steps
577-
sz = div(length(rhashes), nt)
578-
rngs = [sizehint!(T[], sz) for _ in 1:nt]
577+
sz = zeros(T, nt)
578+
# It is safe to record `_id`: its memory will be released, and it does not add extra memory to the total amount (we later need to allocate `groups`)
579+
_id = Vector{Int16}(undef, length(rhashes))
579580
for i in eachindex(rhashes)
580-
push!(rngs[(rhashes[i] % nt)+1], i)
581+
_id[i] = (rhashes[i] % nt)+1
582+
sz[_id[i]] += 1
583+
end
584+
rngs = Vector{T}(undef, length(rhashes))
585+
prepend!(sz, T(0))
586+
our_cumsum!(sz)
587+
sz_cp = copy(sz)
588+
589+
for i in eachindex(rhashes)
590+
idx=_id[i]
591+
sz_cp[idx] += 1
592+
rngs[sz_cp[idx]] = i
581593
end
582-
rngs
594+
rngs, sz
583595
end
584596

585-
function _gather_groups_hugeds_collector(groups, rngs, rhashes, colsvals, ::Val{T}) where T
586-
ngroups = Vector{Int}(undef, length(rngs))
587-
Threads.@threads for i in 1:length(rngs)
588-
_tmp = view(groups, rngs[i])
589-
ngroups[i] = create_dict_hugeds_multicols!(_tmp, rngs[i], colsvals, rhashes, Val(T))
597+
function _gather_groups_hugeds_collector(groups, rngs, sz, rhashes, colsvals, ::Val{T}) where T
598+
ngroups = Vector{Int}(undef, length(sz)-1)
599+
Threads.@threads for i in 2:length(sz)
600+
hi = sz[i]
601+
lo = sz[i-1]+1
602+
_tmp = view(groups, view(rngs, lo:hi))
603+
ngroups[i-1] = create_dict_hugeds_multicols!(_tmp, view(rngs, lo:hi), colsvals, rhashes, Val(T))
590604
end
591605
ngroups
592606
end
593607

594-
function _gather_groups_hugeds_cleanup!(groups, ngroups, rngs)
608+
function _gather_groups_hugeds_cleanup!(groups, ngroups, rngs, sz)
595609
our_cumsum!(ngroups)
596-
Threads.@threads for i in 2:length(rngs)
597-
for j in rngs[i]
598-
groups[j] += ngroups[i-1]
610+
Threads.@threads for i in 3:length(sz)
611+
hi=sz[i]
612+
lo=sz[i-1]+1
613+
for j in lo:hi
614+
groups[rngs[j]] += ngroups[i-2]
599615
end
600616
end
601617
return ngroups[end]

0 commit comments

Comments
 (0)