FlashAttention-3
Figure (FP16 forward pass): https://github.com/togethercomputer/flash-attention-3/raw/main/assets/flash3_fp16_fwd.png
FlashAttention-3 uses three main techniques to speed up attention on Hopper GPUs such as the H100:
(1) overlap overall computation and data movement via warp-specialization,
(2) interleave block-wise matmul and softmax operations, and
(3) apply incoherent processing that leverages hardware support for FP8 low precision.
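Item (2) relies on the fact that softmax attention can be computed one K/V block at a time with an online softmax, as in earlier FlashAttention versions: each block needs one matmul to produce its scores, followed by a cheap softmax update of a small running state (running max, running sum, running output). Because that state is small, the softmax of one block can run while the next block's matmul is in flight on the GPU. The sketch below is a minimal pure-Python illustration with hypothetical function names; it runs the two steps sequentially and does not model warp-specialization, Tensor Cores, or TMA.

```python
import math


def naive_attention(q, k, v):
    """Reference: softmax(q k^T / sqrt(d)) v, materializing every score row."""
    d = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(len(v))) / z
                    for c in range(len(v[0]))])
    return out


def blockwise_attention(q, k, v, block_size=2):
    """Visit K/V one block at a time, alternating a matmul step (scores for the
    block) with a softmax step (online update of a small running state)."""
    d = len(q[0])
    dv = len(v[0])
    out = []
    for qi in q:
        m = -math.inf     # running max of the scores seen so far
        denom = 0.0       # running sum of exp(score - m)
        acc = [0.0] * dv  # running, unnormalized output row
        for start in range(0, len(k), block_size):
            kb = k[start:start + block_size]
            vb = v[start:start + block_size]
            # Matmul step: scores of this query against this K block only.
            s = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in kb]
            # Softmax step: move the running state to the new max.
            m_new = max(m, max(s))
            rescale = math.exp(m - m_new)  # 0.0 on the first block (m = -inf)
            p = [math.exp(x - m_new) for x in s]
            denom = denom * rescale + sum(p)
            acc = [acc[c] * rescale + sum(p[j] * vb[j][c] for j in range(len(vb)))
                   for c in range(dv)]
            m = m_new
        out.append([a / denom for a in acc])  # normalize once at the end
    return out
```

Any block size gives the same result as the naive version (up to floating-point rounding); only the amount of state held at once changes.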
FlashAttention-3 is 1.5-2.0x faster than FlashAttention-2.
With FP16, it reaches up to 740 TFLOPS, i.e., 75% utilization of the H100’s theoretical maximum FLOPS.
With FP8, FlashAttention-3 reaches close to 1.2 PFLOPS, with 2.6x smaller error than baseline FP8 attention.
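Roughly, incoherent processing (item (3)) multiplies Q and K by the same orthogonal matrix before quantizing; the FlashAttention-3 paper builds that matrix from a Hadamard transform with random signs. Because the matrix is orthogonal, Q K^T is unchanged, but large outlier channels get spread across all channels, so a shared quantization scale wastes less resolution. The sketch below is a minimal illustration under stated assumptions: it uses a plain (unrandomized) Hadamard transform, and it rounds onto an 8-bit symmetric integer grid as a simplified stand-in for FP8. Function names and input values are hypothetical.

```python
import math


def hadamard_transform(x):
    """Normalized fast Walsh-Hadamard transform (len(x) must be a power of 2).
    The normalized transform is orthogonal and is its own inverse."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2")
    y = list(x)
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                y[j], y[j + h] = y[j] + y[j + h], y[j] - y[j + h]
        h *= 2
    s = 1.0 / math.sqrt(n)
    return [v * s for v in y]


def quantize_dequantize(x, levels=127):
    """Round onto a symmetric integer grid with one shared scale
    (assumption: a simplified stand-in for FP8, a floating-point format)."""
    peak = max(abs(v) for v in x)
    if peak == 0.0:
        return list(x)
    scale = peak / levels
    return [round(v / scale) * scale for v in x]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def squared_error(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


q = [8.0, 0.3, -0.2, 0.1]  # one outlier channel
k = [0.5, -1.0, 0.25, 2.0]
hq, hk = hadamard_transform(q), hadamard_transform(k)

# Rotating both Q and K by the same orthogonal matrix keeps q . k unchanged.
print(dot(q, k), dot(hq, hk))
# The outlier is spread across channels, so the largest magnitude shrinks.
print(max(map(abs, q)), max(map(abs, hq)))
# Quantize directly vs. quantize in the rotated basis and rotate back.
err_direct = squared_error(q, quantize_dequantize(q))
err_rotated = squared_error(q, hadamard_transform(quantize_dequantize(hq)))
print(err_direct, err_rotated)
```

For this input the rotated version has a much smaller reconstruction error; the random signs used in practice make the spreading hold with high probability for arbitrary inputs rather than for one hand-picked vector.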
More efficient GPU utilization
FlashAttention-3 uses up to 75% of an H100 GPU’s theoretical maximum FLOPS, up from 35% with FlashAttention-2. This makes training and running large language models (LLMs) significantly faster (1.5-2x) than with previous versions.
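The utilization figures are achieved throughput divided by the GPU's peak. The sketch below shows that arithmetic, assuming a peak of about 989 TFLOPS (dense FP16/BF16 Tensor Core throughput of an H100 SXM5; other H100 variants differ). It also counts attention FLOPs the usual way, as the two forward matmuls Q K^T and P V, each costing 2 * N^2 * d per head, with causal masking counted as half. Names are hypothetical.

```python
# ASSUMPTION: ~989 TFLOPS dense FP16/BF16 Tensor Core peak (H100 SXM5).
H100_PEAK_FP16_TFLOPS = 989.0


def attention_fwd_flops(batch, seqlen, nheads, headdim, causal=False):
    """FLOPs of the two forward matmuls (Q K^T and P V): 2 * 2 * N^2 * d per
    head. Causal masking is counted as skipping half of the score matrix."""
    flops = 4 * batch * seqlen ** 2 * nheads * headdim
    return flops // 2 if causal else flops


def achieved_tflops(flops, seconds):
    """Measured throughput from a FLOP count and a wall-clock time."""
    return flops / seconds / 1e12


def utilization(tflops, peak=H100_PEAK_FP16_TFLOPS):
    return tflops / peak


print(f"740 TFLOPS is {utilization(740.0):.1%} of peak")           # 74.8%
print(f"35% of peak is {0.35 * H100_PEAK_FP16_TFLOPS:.0f} TFLOPS")  # 346
```

To turn a benchmark timing into TFLOPS, pass `attention_fwd_flops(...)` for the benchmarked shape and the measured seconds to `achieved_tflops`.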
Better performance with lower precision
FlashAttention-3 can work with lower-precision numbers (FP8) while maintaining accuracy. This allows even faster processing and potentially lower memory usage (an FP8 value takes 1 byte versus 2 bytes for FP16), which could lead to cost savings and improved efficiency for customers running large-scale AI operations.
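One way to keep low-precision error in check is per-block scaling: each block of a tensor gets its own scale, so a block of large values does not wipe out the resolution of a block of small ones. The FlashAttention-3 paper pairs FP8 with block-wise quantization along these lines. The sketch below is a minimal illustration, assuming an 8-bit symmetric integer grid as a simplified stand-in for FP8; function names and values are hypothetical.

```python
def quantize_dequantize(x, levels=127):
    """Symmetric round-to-nearest onto an integer grid with one shared scale
    (assumption: a simplified stand-in for FP8, a floating-point format)."""
    peak = max(abs(v) for v in x)
    if peak == 0.0:
        return list(x)
    scale = peak / levels
    return [round(v / scale) * scale for v in x]


def blockwise_quantize_dequantize(x, block_size, levels=127):
    """Give each contiguous block its own scale instead of one per tensor."""
    out = []
    for start in range(0, len(x), block_size):
        out.extend(quantize_dequantize(x[start:start + block_size], levels))
    return out


def squared_error(a, b):
    return sum((u - w) ** 2 for u, w in zip(a, b))


x = [100.0, -60.0, 0.3, 0.2]  # one large-magnitude block, one small one
err_per_tensor = squared_error(x, quantize_dequantize(x))
err_per_block = squared_error(x, blockwise_quantize_dequantize(x, block_size=2))
print(err_per_tensor, err_per_block)
```

With a single scale, 0.3 and 0.2 round to zero; with one scale per block they survive. The cost is storing one extra scale per block alongside the 1-byte values.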
Ability to use longer context in LLMs
The compute cost of attention grows quadratically with sequence length, so attention dominates at long contexts. By speeding up attention, FlashAttention-3 lets models process much longer pieces of text more efficiently. This could enable applications that understand and generate longer, more complex content without a large slowdown.
The latest version of FlashAttention (beta)