多头注意力

多头注意力允许模型在不同的表示子空间中同时关注不同的位置,是Transformer表达能力强度的关键来源。

共 3 篇文章·阅读时间:约35分钟

01多头注意力原理

多头注意力将Q、K、V投影到多个子空间,在每个子空间独立计算注意力,最后拼接输出。

为什么需要多头

单头注意力只能捕捉一种类型的依赖关系,而不同类型的依赖关系需要不同的注意力模式:

不同头学习不同模式

  • 语法头:捕捉主谓宾关系
  • 语义头:捕捉同义词、指代关系
  • 位置头:关注邻近词
  • 长距离头:捕捉长距离依赖

多头公式

计算过程

1. 线性投影到h个不同的子空间:

Qi = XWiQ, Ki = XWiK, Vi = XWiV

2. 并行计算每个头的注意力:

headi = Attention(Qi, Ki, Vi)

3. 拼接并线性变换:

MultiHead(Q,K,V) = Concat(head1,...,headh)WO

参数设置

  • h:注意力头数(通常为8)
  • dk:每个头的维度 = dmodel/h
  • dmodel:模型维度(通常为512或768)

保持 dk × h = dmodel,使得总计算量与单头注意力相当。

02并行计算

多头注意力可以高效地并行计算,充分利用GPU的矩阵运算能力。

并行计算原理

多头注意力可以通过批矩阵乘法高效实现:

矩阵形式

输入 X:(batch_size, seq_len, dmodel)

投影:(batch_size, seq_len, h, dk)

Q、K、V:reshape为 (batch_size, h, seq_len, dk)

注意力:单次矩阵乘法即可完成所有头的QKT

效率分析

操作时间复杂度说明
QKV投影O(n · dmodel2)3个线性变换
注意力计算O(n2 · dk)h个头并行
输出投影O(n · dmodel2)最终线性变换

关于O(n2)复杂度

序列长度n的平方复杂度是注意力机制的瓶颈。 长序列(如16K、100K tokens)的计算和内存需求成为挑战。 这催生了各种优化:Flash Attention、稀疏注意力、线性注意力等。

03注意力可视化

注意力权重可以直观地展示模型关注的位置,是理解和调试Transformer的重要工具。

可视化方法

热力图

最常用的方法,横纵轴分别是源位置和目标位置, 颜色深浅表示注意力权重大小。

弧线图

用弧线连接源位置和目标位置,弧线粗细表示注意力权重。 适合展示稀疏的注意力模式。

头部重要性分析

分析每个头对最终任务的贡献,识别和剪枝冗余头。

应用场景

可解释性研究

  • 理解模型如何处理特定任务
  • 发现语言现象的表示方式
  • 分析不同层的注意力模式变化

实际应用

  • 模型调试:发现异常注意力模式
  • 模型压缩:剪枝不重要或不稳定的头
  • 知识编辑:定位存储知识的注意力头
  • 人机交互:展示模型的决策依据

头数选择建议

  • 小模型:4-8头,dk较大
  • 大模型:16-32头,dk较小
  • 超长上下文:考虑稀疏注意力而非增加头数
----