Skip to content

Conversation

@jonahsamost
Copy link

This kernel mainly attains a speed up from doing an online softmax in both the forward and backward passes. It also avoids writing intermediate values to be used in the backward pass in favor of recomputation.

Tests were done on a 4090.

Forward Pass

Config Original Optimized Speedup
N=32, T=64, A=4 0.041ms 0.024ms 1.67x
N=32, T=64, A=6 0.045ms 0.025ms 1.80x
N=64, T=32, A=4 0.040ms 0.024ms 1.69x
N=128, T=64, A=6 0.045ms 0.024ms 1.87x
N=128, T=64, A=18 0.074ms 0.025ms 2.91x
N=256, T=32, A=15 0.067ms 0.025ms 2.63x
N=512, T=64, A=4 0.042ms 0.025ms 1.69x
N=512, T=64, A=18 0.076ms 0.026ms 2.94x
N=1024, T=32, A=18 0.076ms 0.026ms 2.88x
N=1024, T=64, A=18 0.125ms 0.027ms 4.56x
N=2048, T=32, A=18 0.125ms 0.027ms 4.67x
N=2048, T=64, A=6 0.109ms 0.028ms 3.92x
N=512, T=64, A=64 0.188ms 0.039ms 4.76x
N=256, T=64, A=128 0.344ms 0.053ms 6.48x

Backward Pass

Config Original Optimized Speedup
N=32, T=64, A=4 0.038ms 0.023ms 1.62x
N=32, T=64, A=6 0.043ms 0.023ms 1.87x
N=64, T=32, A=4 0.038ms 0.023ms 1.63x
N=128, T=64, A=6 0.043ms 0.023ms 1.83x
N=128, T=64, A=18 0.075ms 0.027ms 2.83x
N=256, T=32, A=15 0.066ms 0.026ms 2.53x
N=512, T=64, A=4 0.037ms 0.023ms 1.63x
N=512, T=64, A=18 0.080ms 0.037ms 2.16x
N=1024, T=32, A=18 0.081ms 0.036ms 2.28x
N=1024, T=64, A=18 0.142ms 0.048ms 2.98x
N=2048, T=32, A=18 0.143ms 0.048ms 2.99x
N=2048, T=64, A=6 0.101ms 0.029ms 3.44x
N=512, T=64, A=64 0.246ms 0.091ms 2.69x
N=256, T=64, A=128 0.409ms 0.110ms 3.72x

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant