数据并行
分布式训练的核心技术
基本原理
数据并行(Data Parallelism)是最简单的分布式训练方法。 将训练数据分成多份,每个GPU持有完整的模型副本,独立计算梯度后同步更新。
# 数据并行流程
1. 将batch分割到多个GPU
2. 每个GPU独立前向+反向计算
3. AllReduce同步梯度
4. 各GPU同步更新权重
DDP (Distributed Data Parallel)
PyTorch DDP
PyTorch原生数据并行实现,使用环形AllReduce高效同步梯度。
# PyTorch DDP示例
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
model = DistributedDataParallel(
model, device_ids=[local_rank]
)
特点
- • 每个GPU持有完整模型副本
- • 通信量与模型参数量成正比
- • 显存效率低(冗余存储)
- • 适合中小规模模型
FSDP (Fully Sharded Data Parallel)
完全分片数据并行
来自DeepSpeed ZeRO的思想,将模型参数、梯度、优化器状态分片到各GPU。
ZeRO-1
分片优化器状态
ZeRO-2
+分片梯度
ZeRO-3
+分片参数
# PyTorch FSDP示例
from torch.distributed.fsdp import FullyShardedDataParallel
model = FullyShardedDataParallel(
model,
mixed_precision=mixed_precision_policy
)
显存对比
| 方法 | 模型存储 | 梯度存储 | 优化器状态 |
|---|---|---|---|
| DDP | 完整副本 | 完整副本 | 完整副本 |
| FSDP ZeRO-1 | 完整副本 | 完整副本 | 分片 |
| FSDP ZeRO-2 | 完整副本 | 分片 | 分片 |
| FSDP ZeRO-3 | 分片 | 分片 | 分片 |
----