FP16混合精度训练
Mixed Precision Training with FP16
概述
混合精度训练使用FP16(半精度)进行前向和反向传播计算, 同时保留FP32(单精度)的权重副本用于参数更新。 这种方法可以显著加速训练并减少显存占用。
核心优势:显存减少~50%,训练加速2-3x,精度损失极小。
精度对比
| 精度 | 位数 | 范围 | 精度 | 显存 |
|---|---|---|---|---|
| FP32 | 32位 | ±3.4×10³⁸ | ~7位小数 | 4字节/数 |
| FP16 | 16位 | ±65,504 | ~3位小数 | 2字节/数 |
| BF16 | 16位 | ±3.4×10³⁸ | ~2位小数 | 2字节/数 |
混合精度原理
FP16的问题
- 溢出风险:梯度可能超出FP16表示范围
- 下溢问题:小梯度可能变为零
- 精度损失:权重更新精度不足
解决方案
- FP32权重副本:用于精确更新
- 损失缩放:放大梯度防止下溢
- FP32累加:优化器内部用FP32计算
训练流程
1. 维护FP32权重副本 (master weights)
2. 前向传播:FP32权重 → FP16 → 计算 → FP16激活
3. 损失缩放:loss × scale
4. 反向传播:FP16梯度
5. 梯度缩放:gradient / scale
6. 参数更新:FP32权重更新,复制回FP16
损失缩放
损失缩放是解决梯度下溢的关键技术:
# 前向传播后
scaled_loss = loss * scale_factor
# 反向传播后
unscaled_gradient = gradient / scale_factor
# scale_factor通常为2^16 = 65536
动态损失缩放
自动调整scale_factor:
- 检测到溢出(NaN/Inf)时减小scale
- 连续N步无溢出时增大scale
- 跳过有溢出的更新步
代码示例
PyTorch自动混合精度
import torch
from torch.cuda.amp import autocast, GradScaler
# 创建GradScaler(动态损失缩放)
scaler = GradScaler()
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for batch in dataloader:
inputs, labels = batch
optimizer.zero_grad()
# 自动混合精度上下文
with autocast():
outputs = model(inputs)
loss = criterion(outputs, labels)
# 缩放损失并反向传播
scaler.scale(loss).backward()
# 梯度解缩并裁剪(可选)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 优化器步进
scaler.step(optimizer)
scaler.update()Hugging Face Trainer
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir='./output',
fp16=True, # 启用FP16
fp16_opt_level='O1', # 优化级别
# 其他参数
per_device_train_batch_size=8,
learning_rate=2e-5,
)
# Trainer会自动处理混合精度
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
)
trainer.train()BF16 vs FP16
| 特性 | FP16 | BF16 |
|---|---|---|
| 数值范围 | 窄(易溢出) | 宽(同FP32) |
| 精度 | 较高 | 较低 |
| 需要损失缩放 | 是 | 否 |
| 硬件支持 | 广泛 | Ampere+ |
| 推荐场景 | 通用 | 大模型训练 |
BF16在A100/H100等新硬件上首选,无需损失缩放,更稳定。
性能提升
显存节省
~50%
激活值和梯度占用减半
训练加速
2-3x
Tensor Core加速
批量提升
2x
可用更大批量
最佳实践
- Ampere架构(A100/H100)优先使用BF16
- FP16配合动态损失缩放使用
- 梯度裁剪在unscale后进行
- 监控训练中的NaN/Inf
- 敏感操作(如Softmax)保持高精度
- 验证模型精度与FP32训练一致
参考资料
- Mixed Precision Training (Micikevicius et al., 2018)
- A Manual for Implementing Mixed Precision Training (NVIDIA)
- PyTorch AMP Documentation
----