StartseiteArtikel

FlashAttention-4 wird offiziell veröffentlicht: Das Algorithmus-Pipeline wird grundlegend geändert, mit Matrixmultiplikations-Geschwindigkeit.

机器之心2026-03-06 17:16
Es ist nicht mehr notwendig, zwischen "Flexibilität" und "Hochleistung" zu wählen.

Nach einem Jahr harter Arbeit ist FlashAttention-4 endlich offiziell online gegangen.

In letzter Zeit hat die wichtige untere Optimierungstechnologie FlashAttention im Bereich des Deep Learnings ein großer Versionsupdate erhalten.

Tri Dao, der Hauptautor von FlashAttention und Assistentprofessor an der Princeton University, sagte: Auf der Blackwell GPU ist die Ausführungsgeschwindigkeit des Attention-Mechanismus jetzt fast so schnell wie die der Matrixmultiplikation, auch wenn die Engpässe völlig unterschiedlich sind!

Derzeit ist die Geschwindigkeit der Tensor Core so schnell, dass der Engpass bei der Vorwärtsausbreitung des Attention-Mechanismus exponentiell zunimmt, während der Engpass bei der Rückwärtsausbreitung die Bandbreite des gemeinsamen Speichers ist.

Der neu gestaltete Algorithmus enthält einige Mechanismen, die darauf abzielen, diese Engpässe zu überwinden, einschließlich der Verwendung von Polynomen zur Exponentialsimulation, eines neuen Online-Softmax, der 90 % der Softmax-Neuskalierungen vermeiden kann, und der 2CTA MMA-Befehle, die es zwei Threadblöcken ermöglichen, Operanden zu teilen, um den SMEM-Datenverkehr zu reduzieren.

  • Publikationsadresse: https://github.com/Dao-AILab/flash-attention/blob/main/assets/fa4_paper.pdf
  • Code-Link: https://github.com/Dao-AILab/flash-attention

Im Folgenden werden wir uns das genauer ansehen.

Hardwaretrends: Asymmetrische Hardwareskalierung

Seit langem ist Attention als Kernschicht der allgegenwärtigen Transformer-Architektur ein Leistungseingriff für große Sprachmodelle und Anwendungen mit langen Kontexten.

Früher hat FlashAttention-3 Attention durch asynchrone Ausführung und Warp-Spezialisierung optimiert, aber es war hauptsächlich für die Hopper GPU (H100)-Architektur konzipiert.

Die KI-Branche hat sich jedoch schnell auf die Implementierung von Blackwell-Architektur-Systemen wie B200 und GB200 konzentriert. Moderne Beschleuniger wie die Blackwell GPU setzen einen Trend fort: Asymmetrische Hardwareskalierung.

Unter diesem Trend wächst die Durchsatzleistung der Tensor Core viel schneller als die anderer Hardware-Ressourcen, wie die Bandbreite des gemeinsamen Speichers, die speziellen Funktionsmodule (SFU) für übergeordnete Funktionsberechnungen wie Exponenten und die allgemeinen ganzzahligen und Gleitkomma-ALU...

Zum Beispiel hat sich die Durchsatzleistung der BF16-Tensor Core von Hopper H100 auf Blackwell B200 um das 2,25-Fache erhöht (von 1 auf 2,25 PFLOPs), aber die Anzahl der SFU und die Bandbreite des gemeinsamen Speichers sind im Wesentlichen gleich geblieben.

Diese Skalierungsasymmetrie hat einen tiefgreifenden Einfluss auf die Optimierung von komplexen Kernels wie Attention.

Genauer gesagt enthält der Kern von Attention zwei allgemeine Matrixmultiplikationen (GEMM):

Dazwischen liegt Softmax, aber in der Realität beinhaltet Attention auch eine Menge an Hilfsarbeiten, wie Datenverschiebung, Synchronisierung, Datenlayout-Transformation, elementweise Berechnungen, Planung, Maskenverarbeitung usw.

Die traditionelle Ansicht besagt, dass die Leistung von Attention vollständig von der Geschwindigkeit der GEMM bestimmt wird. Eine Analyse der "Geschwindigkeit und Zufuhr" für B200 zeigt jedoch, dass der Hauptengpass nicht in der Tensor Core liegt, sondern:

Die SFU-Einheiten für die Softmax-Exponentenberechnung bei der Vorwärtsausbreitung;

Der Datenverkehr im gemeinsamen Speicher bei der Rückwärtsausbreitung, der durch die Bandbreite des gemeinsamen Speichers begrenzt ist.

Deshalb hat das Team FlashAttention-4, ein kooperatives Design von Algorithmus und Kernel, vorgestellt. Das Hauptziel besteht darin, durch Maximierung der Überlappung zwischen Matrixmultiplikation und anderen Engpass-Ressourcen auf der B200 (BF16) eine Leistung von bis zu 1605 TFLOPs/s (71 % Auslastung) zu erreichen, was 1,3-mal schneller als cuDNN 9.13 und 2,7-mal schneller als Triton ist.

Die Kernidee des kooperativen Designs lautet wie folgt:

  • Neue Pipeline: Es wurden neue Software-Pipelines für die Vorwärts- und Rückwärtsausbreitung entwickelt, die die voll asynchrone MMA und die größere Blockgröße (Tile) von Blackwell nutzen, um die Überlappung von Tensor Core-Berechnungen, Softmax-Berechnungen und Speicheroperationen zu maximieren;
  • Vorwärtsausbreitung (FWD): Es wird eine Software-Simulation der Exponentialfunktion durch Polynomnäherung auf der FMA-Einheit durchgeführt, um die Durchsatzleistung der Exponentenberechnung zu erhöhen. Gleichzeitig wird eine bedingte Softmax-Neuskalierung eingeführt, um unnötige Neuskalierungen zu überspringen und so den SFU-Engpass zu lindern;
  • Rückwärtsausbreitung (BWD): Der Tensor-Speicher (TMEM) wird zur Speicherung von Zwischenergebnissen genutzt, um den Datenverkehr im gemeinsamen Speicher zu reduzieren. Gleichzeitig wird die neue 2-CTA MMA-Modus von Blackwell verwendet, um den Zugriff auf den gemeinsamen Speicher weiter zu verringern und die Anzahl der atomaren Reduktionen um die Hälfte zu verringern. Darüber hinaus wird ein deterministisches Ausführungsmodus unterstützt, um reproduzierbares Training zu ermöglichen;
  • Planungsoptimierung: Ein neuer Tile-Planer wird eingeführt, um die Lastungleichgewichte aufgrund von kausalen Masken und Sequenzen variabler Länge zu beheben.

Neue Hardwaremerkmale von Blackwell

Tensor-Speicher (TMEM): Auf der B200 ist jeder der 148 SM (Stream Multiprocessor) mit 256 KB TMEM ausgestattet, der direkt mit der Tensor Core verbunden ist und zur Speicherung von Zwischenergebnissen für die Warp-Synchronisierung verwendet wird.

Vollständig asynchrone fünfte Generation Tensor Core: Der Befehl tcgen05.mma unterstützt asynchrone Ausführung und speichert die Akkumulationsergebnisse im TMEM. Für BF16 und FP16 ist die maximale UMMA-Tile, die ein einzelner CTA verwenden kann, 128×256×16, was etwa das Doppelte der größten WGMMA-Atomblöcke in der Hopper-Architektur ist. Die UMMA wird von einem einzelnen Thread gestartet, wodurch der Registerdruck verringert wird und es einfacher ist, größere Tiles und tiefere Pipelines zu verwenden, ohne dass es zu Registerüberläufen wie bei der Hopper Warpgroup MMA kommt.

Darüber hinaus macht dies die Warp-Spezialisierung noch praktikabler: Einige Warps sind für die Verschiebung von Tiles verantwortlich, andere für das Starten der MMA, wodurch die Matrixmultiplikation und -addition mit der Softmax-Berechnung und dem Speicherzugriff überlappt werden können. tcgen05.mma kann auch direkt die Operanden A aus dem TMEM lesen.

2-CTA MMA: Blackwell unterstützt die gemeinsame Ausführung einer UMMA-Operation durch ein Paar von CTA in demselben Cluster und über das TMEM von zwei CTA hinweg. Ein Thread im Leader-CTA startet die MMA, aber während der Ausführung müssen beide CTA aktiv bleiben. Durch die Teilung der M- und N-Dimensionen zwischen diesem CTA-Paar kann die Tile-Größe der MMA auf 256×256×16 erweitert werden, wodurch redundante Datenübertragungen reduziert und der Ressourcenverbrauch jedes CTA verringert werden.

Programmiersprache und Framework: CuTe-DSL

FlashAttention-4 (FA4) ist vollständig in CuTe-DSL implementiert, einer Python-Kernel-DSL, die von CUTLASS bereitgestellt wird.

Der Kernel-Code wird in Python geschrieben, und dann wird die DSL in PTX heruntergestuft und von der CUDA-Toolkette in GPU-Maschinencode kompiliert.

Dieses Programmiermodell stimmt auf Abstraktionsebene mit CuTe / CUTLASS überein und bietet gleichzeitig eine PTX-Ebene Escape Hatch (eine Schnittstelle für die untere Steuerung). Im Vergleich zur Verwendung von C++-Templates kann auf diese Weise die Kompilierungszeit um etwa das 20 - 30-Fache verkürzt werden.

Tri Dao schrieb sogar in einem Post auf X, dass er sich "seltsam aufgeregt" fühle. Das bedeutet, dass die Installation / "Kompilierung" jetzt nur einige Sekunden dauert, anstatt Minuten / Stunden.

Attention-Leistungsbenchmark

Das Team hat die Leistungsergebnisse von FlashAttention-4 auf der B200 (BF16) gezeigt und sie mit den Implementierungen von FlashAttention-2 sowie Triton, Gluon und cuDNN verglichen.

Die Ergebnisse zeigen:

  • Vorwärtsausbreitung (Forward Pass): FlashAttention-4 ist 1,1 - 1,3-mal schneller als cuDNN 9.13 und 2,1 - 2,7-mal schneller als die Triton-Implementierung.
  • Rückwärtsausbreitung (Backward Pass): In Szenarien mit langer Sequenzlänge schneidet FlashAttention-4 immer besser ab als die anderen Benchmark-Modelle.

Sobald FlashAttention-4 veröffentlicht wurde, hat es auch zu vielen Diskussionen geführt.

Das Pytorch-Team hat offiziell angekündigt, dass FlexAttention jetzt den FlashAttention-4-Backend unterstützt.

Pytorch hat erklärt, dass FlexAttention es Forschern seit langem ermöglicht, verschiedene benutzerdefinierte Attention-Varianten schnell zu prototypen. Derzeit werden es von mehr als 1000 Code-Repositories verwendet, und es wurden mehrere Dutzend Papers darüber veröffentlicht.

Die Benutzer hatten jedoch oft mit Leistungseingriffen zu kämpfen, bis FlashAttention-4 auf den Markt kam.

Jetzt haben sie für FlexAttention auf Hopper- und Blackwell-GPUs den FlashAttention-4-Backend hinzugefügt. PyTorch kann jetzt automatisch den Score/Mask-Modifikationscode von CuTeDSL generieren und ihn über JIT-Compilierung für benutzerdefinierte Attention-Varianten in FlashAttention-4 instanziieren.