
FlashAttention-4 officially released: a major overhaul of the algorithm pipeline, reaching matrix-multiplication-level speed

MachineHeart, 2026-03-06 17:16
There's no need to choose between "flexibility" and "high performance" anymore.

After a year of hard work, FlashAttention-4 has finally been officially launched.

FlashAttention, a foundational low-level optimization in deep learning, has just received a major version update.

Tri Dao, the core author of FlashAttention and an assistant professor at Princeton University, said: "On Blackwell GPUs, even though the bottlenecks are quite different, attention now runs almost as fast as matrix multiplication!"

Tensor Cores are now so fast that the forward pass of attention is bottlenecked by the exponential function, while the backward pass is bottlenecked by shared memory bandwidth.

The redesigned algorithm includes several mechanisms aimed at these bottlenecks: the exponential is emulated with polynomials; a new online softmax avoids 90% of softmax rescaling; and the 2-CTA MMA instruction lets two thread blocks share operands to reduce shared memory (smem) traffic, among others.
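To make the first idea concrete, here is a minimal sketch of computing 2^x without the SFU: split x into an integer and a fractional part, evaluate a short polynomial for the fractional part using only multiply-adds (which map onto FMA units), and fold the integer part into the float's exponent. The degree-4 Taylor coefficients and the names `exp2_poly` and `exp_poly` are illustrative assumptions, not FlashAttention-4's actual coefficients or code.

```python
import math

# Illustrative coefficients: the Taylor series of 2**f = exp(f * ln 2), truncated at degree 4.
# These are an assumption for this sketch, not FlashAttention-4's actual coefficients.
LN2 = math.log(2.0)
COEFFS = [LN2 ** k / math.factorial(k) for k in range(5)]  # c0, c1, c2, c3, c4


def exp2_poly(x: float) -> float:
    """Approximate 2**x without calling an exp routine.

    1. Range reduction: x = n + f with integer n and f in [-0.5, 0.5].
    2. Evaluate p(f) ~= 2**f with Horner's rule: one multiply-add per step,
       which a GPU would execute as one FMA.
    3. Scale by 2**n; in hardware this is an integer add into the exponent bits.
    """
    n = round(x)
    f = x - n
    p = COEFFS[-1]
    for c in reversed(COEFFS[:-1]):
        p = p * f + c
    return math.ldexp(p, n)  # p * 2**n


def exp_poly(x: float) -> float:
    """exp(x) rewritten as 2**(x * log2(e))."""
    return exp2_poly(x * math.log2(math.e))
```

With this degree the relative error stays around 1e-4. In a BF16 kernel the softmax probabilities are rounded to BF16 before the second GEMM anyway, which is presumably why a short polynomial can be accurate enough.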

  • Paper address: https://github.com/Dao-AILab/flash-attention/blob/main/assets/fa4_paper.pdf
  • Code link: https://github.com/Dao-AILab/flash-attention

Next, let's take a detailed look.

Hardware Trend: Asymmetric Hardware Scaling

As the core layer of the ubiquitous Transformer architecture, Attention has long been the performance bottleneck for large language models and long-context applications.

FlashAttention-3 optimized Attention through asynchronous execution and warp specialization, but it primarily targeted the Hopper (H100) architecture.

However, the AI industry has quickly shifted to deploying Blackwell-based systems such as the B200 and GB200. Modern accelerators like Blackwell GPUs continue a trend of asymmetric hardware scaling.

Under this trend, Tensor Core throughput grows much faster than other hardware resources, such as shared memory bandwidth, the special function units (SFUs) that evaluate transcendental functions like the exponential, and the general-purpose integer and floating-point ALUs.

For example, from the Hopper H100 to the Blackwell B200, BF16 Tensor Core throughput increased 2.25 times (from about 1 to 2.25 PFLOPS), while the number of SFUs and the shared memory bandwidth stayed essentially unchanged.

This scaling asymmetry has a profound impact on the optimization of complex kernels like Attention.

Specifically, the core of Attention consists of two general matrix multiplications (GEMMs), S = QKᵀ and O = PV, with a softmax P = softmax(S) in between. In practice, however, Attention also involves a large amount of auxiliary work, such as data movement, synchronization, data layout conversion, element-wise operations, scheduling, and masking.
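For reference, the unfused computation can be sketched in plain Python: the first GEMM produces the full score matrix, a row-wise softmax turns it into probabilities, and the second GEMM mixes the values. The 1/√d scaling is the usual Transformer convention; FlashAttention computes the same result tile by tile instead of materializing `s` and `p`. All function names here are illustrative.

```python
import math


def matmul(a, b):
    """Plain-Python matrix product; a and b are lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax_rows(s):
    """Row-wise softmax, subtracting each row's max for numerical stability."""
    out = []
    for row in s:
        m = max(row)
        e = [math.exp(val - m) for val in row]
        total = sum(e)
        out.append([val / total for val in e])
    return out


def reference_attention(q, k, v):
    """O = softmax(Q K^T / sqrt(d)) V, materializing the full N x N score matrix."""
    d = len(q[0])
    k_t = [list(col) for col in zip(*k)]
    s = [[x / math.sqrt(d) for x in row] for row in matmul(q, k_t)]  # GEMM 1: S = Q K^T
    p = softmax_rows(s)                                               # softmax in between
    return matmul(p, v)                                               # GEMM 2: O = P V
```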

The conventional view is that Attention performance is determined entirely by GEMM speed. However, a "speeds and feeds" analysis of the B200 shows that the main bottlenecks are not the Tensor Cores, but:

  • the SFUs that compute the softmax exponentials in the forward pass;
  • shared memory traffic in the backward pass, limited by shared memory bandwidth.
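A back-of-the-envelope version of this analysis compares how many cycles one SM spends on the two GEMMs of a 128×128 forward tile versus on its exponentials. The per-SM throughputs below are assumptions (roughly 2.25 PFLOPS spread over 148 SMs at a clock near 1.85 GHz, and 16 exponentials per clock per SM), not figures from the article.

```python
# Assumed per-SM, per-clock throughputs for a B200-class GPU (illustrative, not official):
#   BF16 dense Tensor Core: ~8192 FLOPs/clock/SM
#   SFU exponential:        ~16 ops/clock/SM
TC_FLOPS_PER_CLK = 8192
EXP_OPS_PER_CLK = 16


def tile_cycles(m: int, n: int, d: int):
    """Cycles one SM spends on a forward tile of m queries x n keys with head dim d.

    Tensor Core work: two GEMMs (S = Q K^T and O = P V), each 2*m*n*d FLOPs.
    SFU work: one exponential per score, m*n in total.
    Returns (mma_cycles, exp_cycles).
    """
    mma_cycles = 4 * m * n * d / TC_FLOPS_PER_CLK
    exp_cycles = m * n / EXP_OPS_PER_CLK
    return mma_cycles, exp_cycles
```

Under these assumptions, `tile_cycles(128, 128, 128)` gives 1024 cycles for both, and at head dimension 64 the exponentials take twice as long as the GEMMs, which is why the SFU, not the Tensor Core, sets the pace unless exponential work is moved elsewhere or overlapped.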

To address this, the team developed FlashAttention-4, a co-design of algorithm and kernel whose core goal is to maximize the overlap between matrix multiplication and work on the other bottlenecked resources. On the B200 (BF16), it reaches up to 1605 TFLOPS (71% utilization), up to 1.3 times faster than cuDNN 9.13 and up to 2.7 times faster than Triton.

The core ideas of the co-design are as follows:

  • New Pipelines: Separate software pipelines are designed for the forward and backward passes. By leveraging Blackwell's fully asynchronous MMA and larger tile sizes, they maximize the overlap among Tensor Core computation, softmax computation, and memory operations.
  • Forward Propagation (FWD): The exponential function is emulated in software with a polynomial approximation evaluated on the FMA units, raising exponential throughput. Conditional softmax rescaling skips unnecessary rescaling operations. Together, these relieve the SFU bottleneck.
  • Backward Propagation (BWD): Tensor memory (TMEM) holds intermediate results to relieve shared memory traffic. Blackwell's new 2-CTA MMA mode further reduces shared memory accesses and halves the number of atomic reductions. A deterministic execution mode is also supported for reproducible training.
  • Scheduling Optimization: A new tile scheduler addresses the load imbalance caused by causal masks and variable-length sequences.
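The conditional rescaling idea can be illustrated on a single query row. Classic online softmax rescales the running sum and output whenever a new block of keys raises the running maximum. The sketch below rescales only when the new maximum exceeds the current reference by more than `threshold`; otherwise it keeps the stale maximum, so individual exponentials may exceed 1 but stay below exp(threshold). Because numerator and denominator always share the same reference, the normalized output is unchanged. The function name and threshold value are hypothetical; the article does not give FlashAttention-4's exact rule.

```python
import math


def online_softmax_attention(scores, values, block=4, threshold=0.0):
    """One query row of attention, processing keys block by block with online softmax.

    scores:    attention logits, one per key
    values:    value vectors (lists of floats), one per key
    threshold: rescale only when a block's max exceeds the reference max by more than this.
               threshold=0.0 is the classic always-rescale rule; a positive value
               (hypothetical, not FA4's actual setting) skips small updates.
    Returns (output vector, number of rescaling steps after the first block).
    """
    d = len(values[0])
    m = -math.inf      # reference max used in every exp(s - m)
    l = 0.0            # running denominator: sum of exp(s - m)
    o = [0.0] * d      # running unnormalized numerator: sum of exp(s - m) * v
    rescales = 0
    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        m_blk = max(s_blk)
        if m_blk - m > threshold:
            if m != -math.inf:
                rescales += 1
            alpha = math.exp(m - m_blk)  # 0.0 on the first block
            l *= alpha
            o = [x * alpha for x in o]
            m = m_blk
        # With a stale m, exp(s - m) can exceed 1 but is bounded by exp(threshold).
        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - m)
            l += p
            o = [x + p * y for x, y in zip(o, v)]
    return [x / l for x in o], rescales
```

For slowly rising scores such as `[0.1 * i for i in range(32)]`, `threshold=0.0` rescales after every block, while `threshold=8.0` never rescales after the first block; both return the same output up to floating-point rounding.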

New Hardware Features of Blackwell

Tensor Memory (TMEM): On the B200, each of the 148 SMs (Streaming Multiprocessors) has 256 KB of TMEM, which is wired directly to the Tensor Cores and holds intermediate results such as MMA accumulators.

Fully Asynchronous Fifth-Generation Tensor Core: The tcgen05.mma instruction executes asynchronously and stores its accumulators in TMEM. For BF16 and FP16, the largest UMMA tile a single CTA can use is 128×256×16, about twice the size of the largest WGMMA atom on Hopper. UMMA is issued by a single thread, which reduces register pressure and makes larger tiles and deeper pipelines practical without the register spilling that Hopper's warpgroup MMA suffers from.

This also makes warp specialization more practical: some warps move tiles while others issue MMAs, so matrix multiply-accumulate, softmax computation, and memory access can overlap. tcgen05.mma can also read operand A directly from TMEM.

2-CTA MMA: Blackwell lets a pair of CTAs in the same cluster jointly execute one UMMA operation that spans the TMEM of both CTAs. A thread in the leader CTA issues the MMA, but both CTAs must remain active during execution. By splitting the M and N dimensions across the pair, the MMA tile grows to 256×256×16, reducing redundant data movement and per-CTA resource usage. Within a kernel, the CTA group size (1 or 2) must be consistent between TMEM operations and Tensor Core computations.
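A simplified model shows why sharing operands across a CTA pair reduces shared memory traffic. In a single-CTA 128×256×16 MMA, the CTA stages a 128×16 slice of A and a full 256×16 slice of B. In the 2-CTA 256×256×16 form, each CTA stages its 128 rows of A but only half of B, yet still produces its own 128×256 output slice. This model ignores pipelining, multicast, and the specific GEMMs of the backward pass, and the function name is our own.

```python
def smem_operand_traffic(m: int, n: int, k: int, cta_pair: bool):
    """Simplified per-CTA shared-memory operand footprint for one MMA step.

    Single CTA: reads an m x k slice of A and an n x k slice of B, computes an m x n output.
    CTA pair:   the combined (2m) x n x k MMA is split so each CTA supplies m rows of A
                and only n/2 columns' worth of B, while still producing an m x n output slice.
    Returns (operand elements staged per CTA, operand elements per multiply-accumulate).
    """
    a_elems = m * k
    b_elems = (n // 2) * k if cta_pair else n * k
    elems = a_elems + b_elems
    macs = m * n * k  # multiply-accumulates performed by this CTA
    return elems, elems / macs
```

In this model, `smem_operand_traffic(128, 256, 16, cta_pair=False)` returns `(6144, 0.01171875)` and the paired version returns `(4096, 0.0078125)`: one third fewer operand elements per multiply-accumulate.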

Programming Language and Framework: CuTe-DSL

FlashAttention-4 (FA4) is fully implemented using CuTe-DSL, a Python kernel DSL provided by CUTLASS.

The kernel code is written in Python; the DSL lowers it to PTX, which the CUDA toolchain then compiles into GPU machine code.

This programming model mirrors the CuTe/CUTLASS abstractions and provides a PTX-level escape hatch for low-level control. Compared with C++ templates, it cuts compilation time by roughly 20–30 times.

Tri Dao even posted on X that he was "weirdly excited" about this: installation and "compilation" now take seconds instead of minutes or hours.

Attention Performance Benchmark

The team benchmarked FlashAttention-4 on the B200 (BF16) against FlashAttention-2, Triton, Gluon, and cuDNN.

The results show:

  • Forward Pass: FlashAttention-4 is 1.1–1.3 times faster than cuDNN 9.13 and 2.1–2.7 times faster than the Triton implementation.
  • Backward Pass: At long sequence lengths, FlashAttention-4 consistently outperforms the other baselines.

The release of FlashAttention-4 also sparked lively discussion.

The official PyTorch team announced that FlexAttention now supports a FlashAttention-4 backend.

According to PyTorch, FlexAttention has long let researchers quickly prototype custom Attention variants. More than 1,000 code repositories have adopted it, and dozens of papers cite it.

However, until FlashAttention-4 arrived, users often ran into performance bottlenecks.

The team has now added a FlashAttention-4 backend to FlexAttention on Hopper and Blackwell GPUs. PyTorch automatically generates CuTe-DSL code for score and mask modifications and instantiates FlashAttention-4 for custom Attention variants through JIT compilation.
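FlexAttention expresses a variant as a `score_mod(score, b, h, q_idx, kv_idx)` callback applied to each attention logit. The pure-Python sketch below mimics that callback shape to show what the generated kernel must compute; it does not use PyTorch and does not show how the FlashAttention-4 backend is selected. `attention_with_score_mod`, `causal`, and `alibi_like` are illustrative names, and the distance penalty slope is a hypothetical value.

```python
import math


def attention_with_score_mod(q, k, v, score_mod, b=0, h=0):
    """Naive single-head attention that applies score_mod(score, b, h, q_idx, kv_idx)
    to every scaled logit before the softmax."""
    d = len(q[0])
    out = []
    for q_idx, q_row in enumerate(q):
        logits = [
            score_mod(sum(a * c for a, c in zip(q_row, k_row)) / math.sqrt(d), b, h, q_idx, kv_idx)
            for kv_idx, k_row in enumerate(k)
        ]
        m = max(logits)
        w = [math.exp(s - m) for s in logits]  # masked logits (-inf) become 0.0
        z = sum(w)
        out.append([sum(wi * v_row[j] for wi, v_row in zip(w, v)) / z for j in range(len(v[0]))])
    return out


def causal(score, b, h, q_idx, kv_idx):
    """Causal masking expressed as a score modification."""
    return score if kv_idx <= q_idx else -math.inf


def alibi_like(score, b, h, q_idx, kv_idx, slope=0.5):
    """Hypothetical ALiBi-style distance penalty combined with causal masking."""
    return causal(score - slope * (q_idx - kv_idx), b, h, q_idx, kv_idx)
```

Here the first query under `causal` can only attend to the first key, so its output equals the first value row, while the last query sees every key and matches unmasked attention.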

The results show that on compute-bound workloads it achieves a 1.2 to 3.2 times speedup over Triton, so researchers no longer have to choose between "flexibility" and "high performance".

One netizen commented that "FlashAttention-4 is a milestone." On the Blackwell architecture, Attention can now run at close to matrix multiplication (matmul) speed, which means the computational bottleneck will shift entirely to memory and communication. Attention throughput of about 1600 TFLOPS is remarkable, 2–3 times that of FlashAttention-3. "This will directly benefit all cutting-edge large models," because faster Attention means longer effective context windows, lower inference costs, and stronger large-scale inference capabilities.

Reference links:

https://x.com/tri_dao/status/2029569881151263082

https://tridao.me/blog/2026/flash4/

This article is from the WeChat official account "MachineHeart", and is published by 36Kr with authorization.