Recently I tweeted about the realistic Speed-of-Light (SOL) of the 5090 and RTX PRO 6000 for some dtypes, and mobicham asked me about FP8 MMA with FP16 accumulation. The me of last year would have turned to Triton for this - it's trivial to change the accumulation dtype of tl.dot(). However, I now roughly know how to write a fast matmul kernel, so why not do it myself! In addition, I have been tinkering with torch.cuda._compile_kernel(), which compiles CUDA kernels super fast via NVRTC. This seems ideal...
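To make the Triton remark concrete, here is a minimal sketch (my own illustration, not code from this post) of an FP8 matmul whose tl.dot() accumulates in FP16 via the out_dtype argument. The kernel and helper names are made up, and the single-block simplification is an assumption to keep it short; a real kernel would tile over M/N/K and mask its loads.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fp8_mm_fp16_acc_kernel(
    a_ptr, b_ptr, c_ptr,
    M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
):
    # Illustrative single-block kernel: M, N, K must be powers of 2 (>= 16)
    # and small enough that one program covers the whole output.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)

    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])  # (M, K) FP8 E4M3
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])  # (K, N) FP8 E4M3

    # out_dtype=tl.float16 requests FP16 accumulation instead of the default FP32.
    acc = tl.dot(a, b, out_dtype=tl.float16)

    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)


def fp8_mm_fp16_acc(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    M, K = a.shape
    K2, N = b.shape
    assert K == K2 and a.dtype == b.dtype == torch.float8_e4m3fn
    c = torch.empty(M, N, device=a.device, dtype=torch.float16)
    fp8_mm_fp16_acc_kernel[(1,)](a, b, c, M=M, N=N, K=K)
    return c


if __name__ == "__main__":
    a = torch.randn(64, 64, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(64, 64, device="cuda").to(torch.float8_e4m3fn)
    c = fp8_mm_fp16_acc(a, b)
    print((c - a.half() @ b.half()).abs().max())
```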
In this post, I will walk through how I learned to implement Flash Attention for the 5090 in CUDA C++. The main objective is to learn to write attention in CUDA C++, since many features are not available in Triton, such as MXFP8 / NVFP4 MMA for sm120. I also feel this is a natural next step after learning about matmul kernels. Lastly, while there are many excellent blog posts on writing fast matmul kernels, there are none on attention. So I want to take this chance to write something up nicely.