Flash Attention

IO感知的精确注意力计算

概述

Flash Attention是2022年提出的一种IO感知的注意力算法,通过优化内存访问模式, 在保持计算精度的同时,将注意力计算的显存占用从O(n²)降低到O(n),并显著加速训练。

核心创新:分块计算 + 重计算策略,避免存储完整的注意力矩阵。

问题分析

标准注意力的显存问题

标准注意力计算需要存储完整的n×n注意力矩阵:

# 标准注意力

Q, K, V: [batch, heads, seq_len, head_dim]

Attention = softmax(Q @ K^T / √d) @ V


# 中间结果:S = Q @ K^T

S: [batch, heads, seq_len, seq_len] # O(n²) 显存!

# 对于seq_len=8K, S需要约1GB显存/batch

IO瓶颈

GPU内存层次结构带来的IO开销:

  • HBM:高带宽内存,~1.5TB/s,容量大
  • SRAM:片上缓存,~20TB/s,容量小(~20MB)
  • 标准注意力频繁在HBM和SRAM之间读写

Flash Attention原理

分块计算 (Tiling)

将注意力计算分成小块,在SRAM中完成:

  • 将Q、K、V分成小块(如128×128)
  • 每个块的计算完全在SRAM中完成
  • 避免存储完整的n×n注意力矩阵
  • 只保留必要的归一化统计量

在线Softmax技巧

通过增量计算实现分块softmax:

# 标准softmax

softmax(x) = exp(x_i) / Σ exp(x_j)


# 在线softmax(增量计算)

m_new = max(m_old, m_block) # 最大值

d_new = d_old * exp(m_old - m_new) + Σ exp(x - m_new)

# 合并时调整之前的结果

重计算策略

反向传播时不存储注意力矩阵,而是重新计算:
- 前向:只存储Q、K、V和输出
- 反向:重新计算注意力分数(计算量增加,但IO减少)

性能提升

显存占用

O(n) → O(n)

从O(n²)降至O(n)

训练速度

2-4x

长序列提升更明显

序列长度

16K+

支持更长上下文

速度对比

序列长度标准注意力Flash Attention加速比
5121.0x1.2x1.2x
10241.0x1.5x1.5x
20481.0x2.0x2.0x
40961.0x2.8x2.8x
8192OOM-可用

代码示例

PyTorch集成

# PyTorch 2.0+ 内置Flash Attention
import torch
import torch.nn as nn

# 使用scaled_dot_product_attention
# PyTorch会自动选择Flash Attention
query = torch.randn(32, 8, 1024, 64, device='cuda')
key = torch.randn(32, 8, 1024, 64, device='cuda')
value = torch.randn(32, 8, 1024, 64, device='cuda')

# 启用Flash Attention
with torch.backends.cuda.sdp_kernel(enable_flash=True):
    output = torch.nn.functional.scaled_dot_product_attention(
        query, key, value
    )

# 或者在nn.MultiheadAttention中
attn = nn.MultiheadAttention(512, 8, device='cuda')
# PyTorch 2.0+ 自动使用Flash Attention

Flash Attention 2

# 使用flash-attn库
from flash_attn import flash_attn_func

# Flash Attention 2 更快
q = torch.randn(1, 8, 4096, 128, device='cuda', dtype=torch.float16)
k = torch.randn(1, 8, 4096, 128, device='cuda', dtype=torch.float16)
v = torch.randn(1, 8, 4096, 128, device='cuda', dtype=torch.float16)

output = flash_attn_func(q, k, v, softmax_scale=None, causal=True)

版本演进

版本改进速度
Flash Attention 1分块计算、重计算基准
Flash Attention 2优化并行、减少非矩阵乘法2x
Flash Attention 3H100优化、异步计算1.5-2x

应用场景

长上下文训练

支持16K-128K序列长度训练,扩展模型上下文窗口。

推理加速

减少KV Cache占用,支持更长上下文推理。

显存优化

同样GPU可以训练更大批量或更大模型。

精确计算

数值精确,无需近似,效果与标准注意力一致。

最佳实践

  • PyTorch 2.0+自动启用,无需额外配置
  • 使用FP16/BF16精度获得最佳性能
  • 长序列训练时效果更显著
  • 因果注意力(causal)使用causal=True参数
  • 结合其他优化(如梯度检查点)进一步节省显存

参考资料

  • FlashAttention: Fast and Memory-Efficient Exact Attention (Dao et al., 2022)
  • FlashAttention-2: Faster Attention with Better Parallelism (Dao, 2023)
  • FlashAttention-3: Fast and Accurate Attention (Dao et al., 2024)
----