Flash Attention
IO感知的精确注意力计算
概述
Flash Attention是2022年提出的一种IO感知的注意力算法,通过优化内存访问模式, 在保持计算精度的同时,将注意力计算的显存占用从O(n²)降低到O(n),并显著加速训练。
核心创新:分块计算 + 重计算策略,避免存储完整的注意力矩阵。
问题分析
标准注意力的显存问题
标准注意力计算需要存储完整的n×n注意力矩阵:
# 标准注意力
Q, K, V: [batch, heads, seq_len, head_dim]
Attention = softmax(Q @ K^T / √d) @ V
# 中间结果:S = Q @ K^T
S: [batch, heads, seq_len, seq_len] # O(n²) 显存!
# 对于seq_len=8K, S需要约1GB显存/batch
IO瓶颈
GPU内存层次结构带来的IO开销:
- HBM:高带宽内存,~1.5TB/s,容量大
- SRAM:片上缓存,~20TB/s,容量小(~20MB)
- 标准注意力频繁在HBM和SRAM之间读写
Flash Attention原理
分块计算 (Tiling)
将注意力计算分成小块,在SRAM中完成:
- 将Q、K、V分成小块(如128×128)
- 每个块的计算完全在SRAM中完成
- 避免存储完整的n×n注意力矩阵
- 只保留必要的归一化统计量
在线Softmax技巧
通过增量计算实现分块softmax:
# 标准softmax
softmax(x) = exp(x_i) / Σ exp(x_j)
# 在线softmax(增量计算)
m_new = max(m_old, m_block) # 最大值
d_new = d_old * exp(m_old - m_new) + Σ exp(x - m_new)
# 合并时调整之前的结果
重计算策略
反向传播时不存储注意力矩阵,而是重新计算:
- 前向:只存储Q、K、V和输出
- 反向:重新计算注意力分数(计算量增加,但IO减少)
性能提升
显存占用
O(n) → O(n)
从O(n²)降至O(n)
训练速度
2-4x
长序列提升更明显
序列长度
16K+
支持更长上下文
速度对比
| 序列长度 | 标准注意力 | Flash Attention | 加速比 |
|---|---|---|---|
| 512 | 1.0x | 1.2x | 1.2x |
| 1024 | 1.0x | 1.5x | 1.5x |
| 2048 | 1.0x | 2.0x | 2.0x |
| 4096 | 1.0x | 2.8x | 2.8x |
| 8192 | OOM | - | 可用 |
代码示例
PyTorch集成
# PyTorch 2.0+ 内置Flash Attention
import torch
import torch.nn as nn
# 使用scaled_dot_product_attention
# PyTorch会自动选择Flash Attention
query = torch.randn(32, 8, 1024, 64, device='cuda')
key = torch.randn(32, 8, 1024, 64, device='cuda')
value = torch.randn(32, 8, 1024, 64, device='cuda')
# 启用Flash Attention
with torch.backends.cuda.sdp_kernel(enable_flash=True):
output = torch.nn.functional.scaled_dot_product_attention(
query, key, value
)
# 或者在nn.MultiheadAttention中
attn = nn.MultiheadAttention(512, 8, device='cuda')
# PyTorch 2.0+ 自动使用Flash AttentionFlash Attention 2
# 使用flash-attn库 from flash_attn import flash_attn_func # Flash Attention 2 更快 q = torch.randn(1, 8, 4096, 128, device='cuda', dtype=torch.float16) k = torch.randn(1, 8, 4096, 128, device='cuda', dtype=torch.float16) v = torch.randn(1, 8, 4096, 128, device='cuda', dtype=torch.float16) output = flash_attn_func(q, k, v, softmax_scale=None, causal=True)
版本演进
| 版本 | 改进 | 速度 |
|---|---|---|
| Flash Attention 1 | 分块计算、重计算 | 基准 |
| Flash Attention 2 | 优化并行、减少非矩阵乘法 | 2x |
| Flash Attention 3 | H100优化、异步计算 | 1.5-2x |
应用场景
长上下文训练
支持16K-128K序列长度训练,扩展模型上下文窗口。
推理加速
减少KV Cache占用,支持更长上下文推理。
显存优化
同样GPU可以训练更大批量或更大模型。
精确计算
数值精确,无需近似,效果与标准注意力一致。
最佳实践
- PyTorch 2.0+自动启用,无需额外配置
- 使用FP16/BF16精度获得最佳性能
- 长序列训练时效果更显著
- 因果注意力(causal)使用causal=True参数
- 结合其他优化(如梯度检查点)进一步节省显存
参考资料
- FlashAttention: Fast and Memory-Efficient Exact Attention (Dao et al., 2022)
- FlashAttention-2: Faster Attention with Better Parallelism (Dao, 2023)
- FlashAttention-3: Fast and Accurate Attention (Dao et al., 2024)