数据并行

分布式训练的核心技术

基本原理

数据并行(Data Parallelism)是最简单的分布式训练方法。 将训练数据分成多份,每个GPU持有完整的模型副本,独立计算梯度后同步更新。

# 数据并行流程
1. 将batch分割到多个GPU
2. 每个GPU独立前向+反向计算
3. AllReduce同步梯度
4. 各GPU同步更新权重

DDP (Distributed Data Parallel)

PyTorch DDP

PyTorch原生数据并行实现,使用环形AllReduce高效同步梯度。

# PyTorch DDP示例
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
model = DistributedDataParallel(
model, device_ids=[local_rank]
)

特点

  • • 每个GPU持有完整模型副本
  • • 通信量与模型参数量成正比
  • • 显存效率低(冗余存储)
  • • 适合中小规模模型

FSDP (Fully Sharded Data Parallel)

完全分片数据并行

来自DeepSpeed ZeRO的思想,将模型参数、梯度、优化器状态分片到各GPU。

ZeRO-1
分片优化器状态
ZeRO-2
+分片梯度
ZeRO-3
+分片参数
# PyTorch FSDP示例
from torch.distributed.fsdp import FullyShardedDataParallel
model = FullyShardedDataParallel(
model,
mixed_precision=mixed_precision_policy
)

显存对比

方法模型存储梯度存储优化器状态
DDP完整副本完整副本完整副本
FSDP ZeRO-1完整副本完整副本分片
FSDP ZeRO-2完整副本分片分片
FSDP ZeRO-3分片分片分片
----