-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsoftmax_examples.jl
More file actions
96 lines (65 loc) · 2.18 KB
/
softmax_examples.jl
File metadata and controls
96 lines (65 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Environment setup: activate the local project so the pinned versions of
# CUDA / Flux / BenchmarkTools / Pluto from its Project.toml are used.
using Pkg
Pkg.activate(".")
using CUDA
using Flux
using BenchmarkTools
using Pluto
# NOTE(review): Pluto.run blocks until the notebook server is shut down, so
# everything below only executes after Pluto exits — confirm this is intended.
Pluto.run(workspace_use_distributed=false)
"""
    softmax_basic(x)

Compute the softmax of `x` along the first dimension (columnwise for a
matrix, over all elements for a vector).

Subtracts the columnwise maximum before exponentiation: mathematically a
no-op (the shift cancels in the ratio), but it prevents `exp` overflow —
the original `exp.(x)` returns `Inf` for entries ≳ 709 (Float64), making
the result `NaN`.
"""
function softmax_basic(x)
    # Shift so the largest entry per column is 0; exp of the shifted values
    # is then always in (0, 1], so the sum is finite and nonzero.
    shifted = x .- maximum(x, dims=1)
    exps = exp.(shifted)
    return exps ./ sum(exps, dims=1)
end
#inner_max(x::X) where X = CUDA.max(x)
"""
    softmax_kernel_sub_max(input, acc)

CUDA kernel: one thread per *column*. Thread `i` writes a numerically
stable softmax of `input[:, i]` into `acc[:, i]` (subtract the column max,
exponentiate, normalize by the column sum). Returns `nothing`, as kernels
must.
"""
function softmax_kernel_sub_max(input::CuDeviceArray{T}, acc::CuDeviceArray{T}) where {T}
    # Global thread index along x; one thread handles one whole column.
    i = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    # Guard: grid may be padded past the number of columns.
    if i <= size(input, 2)
        col = @view input[:, i]
        # Column maximum. Start from typemin(T): the original `-1f8` sentinel
        # was wrong for columns whose entries are all below -1e8, and was a
        # Float32 literal even when T == Float64 (type-unstable accumulator).
        v_max = typemin(T)
        for v in col  # distinct name: the original shadowed the thread index `i`
            if v > v_max
                v_max = v
            end
        end
        # Stable softmax of the column, written in place into acc[:, i].
        out = @view acc[:, i]
        out .= exp.(col .- v_max)
        out ./= sum(out)
    end
    return nothing
end
# Float32 host data to match the Float32 device buffers below — the original
# Float64 matrix was silently down-converted by `cu` on every transfer.
v1 = randn(Float32, 19962, 10000)
@btime begin
    threads_per_blockm = 1024  # max threads per block on current NVIDIA GPUs
    blocks_per_gridm = cld(size($v1, 2), threads_per_blockm)
    accsm = CUDA.zeros(Float32, size($v1))
    # CUDA.@sync: `@cuda` launches asynchronously, so without it @btime would
    # time only the launch overhead, not the kernel itself. `$v1` interpolates
    # the global so BenchmarkTools doesn't pay untyped-global access costs.
    # (The original also allocated an unused `sumsm` buffer — removed.)
    CUDA.@sync @cuda threads=threads_per_blockm blocks=blocks_per_gridm softmax_kernel_sub_max(cu($v1), accsm)
end
#! CuDynamicSharedArray
#! Figure out how to get the online max.
# 2-D kernel: one thread per *element* (i = row, j = column).
# NOTE(review): work in progress — only the elementwise exp is computed; the
# per-column denominator accumulation (atomic_add! below) is commented out,
# so `sumsi` is currently unused and `acc` ends up holding exp(input), not a
# normalized softmax. The division happens host-side after the launch.
function softmax_kernel(input::CuDeviceArray{T}, sumsi::CuDeviceArray{T}, acc::CuDeviceArray{T}) where {T}
    # Thread index
    i = (blockIdx().x - 1)* blockDim().x + threadIdx().x
    j = (blockIdx().y - 1)* blockDim().y + threadIdx().y
    #sumsi = CuDynamicSharedArray(T, size(input,2))
    # Guard: only threads mapping to a valid (row, column) pair do work.
    if i <= size(input,1) && j<=size(input,2)
        # Exponentiation and denominator addition
        # No max subtraction here, so large inputs can overflow exp —
        # presumably why the author left the logsumexp TODO below.
        acc[i,j] = exp(input[i,j])
        #! I may still need logsumexp?
        #CUDA.atomic_add!(pointer(sumsi,j), acc[i,j])
    end
    return nothing
end
#! Shared Memory is very limited.
v1 = randn(Float32, (19962, 16000));
@btime begin
    threads_per_block = (1024, 1)  # 1024 threads along rows, 1 along columns
    blocks_per_grid = (cld(size($v1, 1), threads_per_block[1]), cld(size($v1, 2), threads_per_block[2]))
    sums = CUDA.zeros(Float32, size($v1, 2))
    accs = CUDA.zeros(Float32, size($v1))
    # CUDA.@sync: `@cuda` launches asynchronously, so without it @btime would
    # time only the launch overhead, not the kernel execution. `$v1`
    # interpolates the global for BenchmarkTools.
    CUDA.@sync @cuda threads=threads_per_block blocks=blocks_per_grid softmax_kernel(cu($v1), sums, accs)
    # Grid-synced division, kind of cheating.
    #res_sm = accs./sums[:,:]'
end