1. 概述
torch.distributed
是 PyTorch 提供的一个模块,用于支持分布式训练和计算。它允许在多个进程、多个 GPU 甚至多个节点(机器)之间协同工作,适用于数据并行、模型并行等场景。torch.distributed
提供了进程间通信的工具(如集合通信操作)和分布式训练的关键组件,配合工具如 torchrun
可以轻松启动分布式任务。
2. 核心功能
torch.distributed
的主要功能包括初始化进程组、进程间通信和分布式训练支持。以下是关键点:
2.1. 进程组初始化
- 作用:在分布式训练中,每个进程需要加入一个进程组以协调通信。
- 核心函数:
torch.distributed.init_process_group
- 初始化分布式环境,设置通信后端和进程排名。
- 参数:
backend
:通信后端,如'nccl'
(GPU)、'gloo'
(CPU/GPU)、'mpi'
(高性能计算)。init_method
:初始化方法,如'env://'
(环境变量,默认)、'tcp://'
(TCP 地址)、'file://'
(共享文件)。rank
:当前进程的全局排名(0 到 world_size-1)。world_size
:总进程数。
- 示例:
import torch.distributed as dist dist.init_process_group(backend='nccl', init_method='env://')
2.2. 进程间通信
- 集合通信:支持多进程间的数据交换,常见操作包括:
dist.broadcast(tensor, src)
:从源进程(src)广播张量到所有进程。dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
:对所有进程的张量执行归约(如求和、平均),结果同步到每个进程。dist.reduce(tensor, dst, op=dist.ReduceOp.SUM)
:将所有进程的张量归约到目标进程(dst)。dist.all_gather(tensor_list, tensor)
:收集所有进程的张量到每个进程的列表中。dist.scatter(tensor, scatter_list, src)
:从源进程分发张量到各进程。- 示例:
import torch tensor = torch.tensor([1.0, 2.0]) dist.all_reduce(tensor, op=dist.ReduceOp.SUM) print(tensor) # 所有进程的 tensor 都被求和
2.3. 分布式训练支持
- 分布式数据并行:
torch.nn.parallel.DistributedDataParallel
(DDP)- 封装模型,同步梯度,加速多 GPU 训练。
- 比
DataParallel
更高效,适合多节点和单节点多 GPU。 - 示例:
import torch.nn as nn model = YourModel().to(device) model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
- 数据采样:
torch.utils.data.distributed.DistributedSampler
- 确保每个进程获取数据集的不同子集,避免重复。
- 示例:
from torch.utils.data import DataLoader, DistributedSampler dataset = YourDataset() sampler = DistributedSampler(dataset) loader = DataLoader(dataset, batch_size=32, sampler=sampler)
3. 关键概念
3.1. 后端
- NCCL:NVIDIA 集体通信库,优化 GPU 训练,推荐用于多 GPU 场景。
- Gloo:跨平台后端,支持 CPU 和 GPU,适合调试或 CPU 训练。
- MPI:高性能计算后端,适用于集群环境。
3.2. 排名和规模
- Rank:每个进程的唯一标识(0 到 world_size-1)。
- World Size:总进程数,等于节点数 × 每节点进程数。
- Local Rank:进程在当前节点内的排名,用于绑定 GPU。
3.3. 初始化方法
env://
:默认方法,通过环境变量(如RANK
、WORLD_SIZE
、MASTER_ADDR
、MASTER_PORT
)配置,由torchrun
自动设置。tcp://
:指定主节点地址,如tcp://192.168.1.1:1234
。file://
:使用共享文件系统协调,如file:///tmp/sharedfile
。
4. 使用步骤
4.1. 启动分布式训练
- 使用
torchrun
启动脚本,自动设置环境变量:torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --rdzv_endpoint=localhost:1234 train.py
4.2. 编写训练脚本
以下是一个典型的分布式训练脚本示例:
1. import torch
2. import torch.distributed as dist
3. import torch.nn as nn
4. import os
5. from torch.utils.data import DataLoader, DistributedSampler
6.
7. def main():
8. # 4.2.1. 初始化进程组
9. dist.init_process_group(backend='nccl')
10. local_rank = int(os.environ['LOCAL_RANK'])
11. device = torch.device(f'cuda:{local_rank}')
12. torch.cuda.set_device(local_rank)
13.
14. # 4.2.2. 定义模型
15. model = YourModel().to(device)
16. model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
17.
18. # 4.2.3. 数据加载
19. dataset = YourDataset()
20. sampler = DistributedSampler(dataset)
21. loader = DataLoader(dataset, batch_size=32, sampler=sampler)
22.
23. # 4.2.4. 训练循环
24. for epoch in range(10):
25. sampler.set_epoch(epoch) # 确保数据洗牌
26. for data, target in loader:
27. data, target = data.to(device), target.to(device)
28. # 训练步骤:前向、损失、反向、优化
29. if dist.get_rank() == 0: # 仅主进程打印或保存
30. print(f"Epoch {epoch}, Loss: ...")
31.
32. # 4.2.5. 清理
33. dist.destroy_process_group()
34.
35. if __name__ == '__main__':
36. main()
5. 注意事项
5.1. 环境变量
torchrun
自动设置RANK
、LOCAL_RANK
、WORLD_SIZE
、MASTER_ADDR
、MASTER_PORT
。- 脚本中通过
os.environ
获取,如local_rank = int(os.environ['LOCAL_RANK'])
。
5.2. I/O 控制
- 仅在主进程(
dist.get_rank() == 0
)执行文件保存、日志打印等操作,避免冲突。
5.3. 性能优化
- 使用 NCCL 后端以获得最佳 GPU 性能。
- 确保批大小适中,过小可能导致通信开销过大。
6. 总结
torch.distributed
模块是 PyTorch 分布式训练的核心,提供进程组初始化、通信原语(如 all_reduce、broadcast)和 DDP 支持。配合 torchrun
,可以轻松在多 GPU 或多节点上运行训练任务。确保正确初始化进程组、使用 DistributedDataParallel
和 DistributedSampler
,并注意 I/O 管理。