Skip to content

Commit 83288d1

Browse files
committed
Reduce allocations when calculating var in the fast path of gatherby
1 parent d707ae6 commit 83288d1

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

src/sort/gatherby.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -229,9 +229,10 @@ function _fill_mapreduce_col!(x, f, op, y, loc)
229229
end
230230
end
231231

232-
function _fill_mapreduce_col!(x, f::Vector, op, y, loc)
232+
# only for calculating var - mval is a vector of means
233+
function _fill_mapreduce_col!(x, mval::Vector, op, y, loc)
233234
@inbounds for i in 1:length(y)
234-
x[loc[i]] = op(x[loc[i]], f[loc[i]](y[i]))
235+
x[loc[i]] = op(x[loc[i]], _abs2mean(y[i], mval[loc[i]]))
235236
end
236237
end
237238

@@ -247,11 +248,12 @@ function _fill_mapreduce_col_threaded!(x, f, op, y, loc, nt)
247248
end
248249
end
249250

250-
function _fill_mapreduce_col_threaded!(x, f::Vector, op, y, loc, nt)
251+
# only for calculating var - mval is a vector of means
252+
function _fill_mapreduce_col_threaded!(x, mval::Vector, op, y, loc, nt)
251253
@sync for thid in 0:nt-1
252254
Threads.@spawn for i in 1:length(y)
253255
@inbounds if loc[i] % nt == thid
254-
x[loc[i]] = op(x[loc[i]], f[loc[i]](y[i]))
256+
x[loc[i]] = op(x[loc[i]], _abs2mean(y[i], mval[loc[i]]))
255257
end
256258
end
257259
end
@@ -345,14 +347,15 @@ function _fill_gatherby_var_barrier!(res, countnan, meanval, ss, nval, cal_std,
345347
end
346348

347349
# TODO: calculating var directly would be a better approach
350+
# Return `abs2` of the difference between `x` and `meanval`,
# i.e. the squared magnitude of the deviation of `x` from `meanval`.
function _abs2mean(x, meanval)
    deviation = x - meanval
    return abs2(deviation)
end
348351
function _gatherby_var(gds, col; dof = true, cal_std = false, threads = true)
349352
if threads
350353
nt = Threads.nthreads()
351354
nt2 = max(div(nt,2),1)
352355
t1 = Threads.@spawn _gatherby_cntnan(gds, col, nt = nt2)
353356
t2 = Threads.@spawn _gatherby_mean(gds, col, nt = nt2)
354357
meanval = fetch(t2)
355-
t3 = Threads.@spawn gatherby_mapreduce(gds, [x->abs2(x - meanval[i]) for i in 1:length(meanval)], _stat_add_sum, col, nt2, missing, Val(Float64))
358+
t3 = Threads.@spawn gatherby_mapreduce(gds, meanval, _stat_add_sum, col, nt2, missing, Val(Float64))
356359
t4 = Threads.@spawn _gatherby_n(gds, col, nt = nt2)
357360
countnan = fetch(t1)
358361
ss = fetch(t3)
@@ -361,7 +364,7 @@ function _gatherby_var(gds, col; dof = true, cal_std = false, threads = true)
361364
t1 = _gatherby_cntnan(gds, col, threads = threads)
362365
t2 = _gatherby_mean(gds, col, threads = threads)
363366
meanval = t2
364-
t3 = gatherby_mapreduce(gds, [x->abs2(x - meanval[i]) for i in 1:length(meanval)], _stat_add_sum, col, Threads.nthreads(), missing, Val(Float64), threads = threads)
367+
t3 = gatherby_mapreduce(gds, meanval, _stat_add_sum, col, Threads.nthreads(), missing, Val(Float64), threads = threads)
365368
t4 = _gatherby_n(gds, col, threads = threads)
366369
countnan = t1
367370
ss = t3

0 commit comments

Comments
 (0)