多头注意力
多头注意力允许模型在不同的表示子空间中同时关注不同的位置,是Transformer表达能力强度的关键来源。
01多头注意力原理
多头注意力将Q、K、V投影到多个子空间,在每个子空间独立计算注意力,最后拼接输出。
为什么需要多头
单头注意力只能捕捉一种类型的依赖关系,而不同类型的依赖关系需要不同的注意力模式:
不同头学习不同模式
- 语法头:捕捉主谓宾关系
- 语义头:捕捉同义词、指代关系
- 位置头:关注邻近词
- 长距离头:捕捉长距离依赖
多头公式
计算过程
1. 线性投影到h个不同的子空间:
Qi = XWiQ, Ki = XWiK, Vi = XWiV
2. 并行计算每个头的注意力:
headi = Attention(Qi, Ki, Vi)
3. 拼接并线性变换:
MultiHead(Q,K,V) = Concat(head1,...,headh)WO
参数设置
- h:注意力头数(通常为8)
- dk:每个头的维度 = dmodel/h
- dmodel:模型维度(通常为512或768)
保持 dk × h = dmodel,使得总计算量与单头注意力相当。
02并行计算
多头注意力可以高效地并行计算,充分利用GPU的矩阵运算能力。
并行计算原理
多头注意力可以通过批矩阵乘法高效实现:
矩阵形式
输入 X:(batch_size, seq_len, dmodel)
投影:(batch_size, seq_len, h, dk)
Q、K、V:reshape为 (batch_size, h, seq_len, dk)
注意力:单次矩阵乘法即可完成所有头的QKT
效率分析
| 操作 | 时间复杂度 | 说明 |
|---|---|---|
| QKV投影 | O(n · dmodel2) | 3个线性变换 |
| 注意力计算 | O(n2 · dk) | h个头并行 |
| 输出投影 | O(n · dmodel2) | 最终线性变换 |
关于O(n2)复杂度
序列长度n的平方复杂度是注意力机制的瓶颈。 长序列(如16K、100K tokens)的计算和内存需求成为挑战。 这催生了各种优化:Flash Attention、稀疏注意力、线性注意力等。
03注意力可视化
注意力权重可以直观地展示模型关注的位置,是理解和调试Transformer的重要工具。
可视化方法
热力图
最常用的方法,横纵轴分别是源位置和目标位置, 颜色深浅表示注意力权重大小。
弧线图
用弧线连接源位置和目标位置,弧线粗细表示注意力权重。 适合展示稀疏的注意力模式。
头部重要性分析
分析每个头对最终任务的贡献,识别和剪枝冗余头。
应用场景
可解释性研究
- 理解模型如何处理特定任务
- 发现语言现象的表示方式
- 分析不同层的注意力模式变化
实际应用
- 模型调试:发现异常注意力模式
- 模型压缩:剪枝不重要或不稳定的头
- 知识编辑:定位存储知识的注意力头
- 人机交互:展示模型的决策依据
头数选择建议
- 小模型:4-8头,dk较大
- 大模型:16-32头,dk较小
- 超长上下文:考虑稀疏注意力而非增加头数