【PyTorch】torch.distributed 模块:支持分布式训练和计算

1. 概述

torch.distributed 是 PyTorch 提供的一个模块,用于支持分布式训练和计算。它允许在多个进程、多个 GPU 甚至多个节点(机器)之间协同工作,适用于数据并行、模型并行等场景。torch.distributed 提供了进程间通信的工具(如集合通信操作)和分布式训练的关键组件,配合工具如 torchrun 可以轻松启动分布式任务。

2. 核心功能

torch.distributed 的主要功能包括初始化进程组、进程间通信和分布式训练支持。以下是关键点:

2.1. 进程组初始化
  • 作用:在分布式训练中,每个进程需要加入一个进程组以协调通信。
  • 核心函数torch.distributed.init_process_group
    • 初始化分布式环境,设置通信后端和进程排名。
    • 参数:
      • backend:通信后端,如 'nccl'(GPU)、'gloo'(CPU/GPU)、'mpi'(高性能计算)。
      • init_method:初始化方法,如 'env://'(环境变量,默认)、'tcp://'(TCP 地址)、'file://'(共享文件)。
      • rank:当前进程的全局排名(0 到 world_size-1)。
      • world_size:总进程数。
    • 示例:
      import torch.distributed as dist
      dist.init_process_group(backend='nccl', init_method='env://')
      
2.2. 进程间通信
  • 集合通信:支持多进程间的数据交换,常见操作包括:
    • dist.broadcast(tensor, src):从源进程(src)广播张量到所有进程。
    • dist.all_reduce(tensor, op=dist.ReduceOp.SUM):对所有进程的张量执行归约(如求和、平均),结果同步到每个进程。
    • dist.reduce(tensor, dst, op=dist.ReduceOp.SUM):将所有进程的张量归约到目标进程(dst)。
    • dist.all_gather(tensor_list, tensor):收集所有进程的张量到每个进程的列表中。
    • dist.scatter(tensor, scatter_list, src):从源进程分发张量到各进程。
    • 示例:
      import torch
      tensor = torch.tensor([1.0, 2.0])
      dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
      print(tensor)  # 所有进程的 tensor 都被求和
      
2.3. 分布式训练支持
  • 分布式数据并行torch.nn.parallel.DistributedDataParallel (DDP)
    • 封装模型,同步梯度,加速多 GPU 训练。
    • DataParallel 更高效,适合多节点和单节点多 GPU。
    • 示例:
      import torch.nn as nn
      model = YourModel().to(device)
      model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
      
  • 数据采样torch.utils.data.distributed.DistributedSampler
    • 确保每个进程获取数据集的不同子集,避免重复。
    • 示例:
      from torch.utils.data import DataLoader, DistributedSampler
      dataset = YourDataset()
      sampler = DistributedSampler(dataset)
      loader = DataLoader(dataset, batch_size=32, sampler=sampler)
      

3. 关键概念

3.1. 后端
  • NCCL:NVIDIA 集体通信库,优化 GPU 训练,推荐用于多 GPU 场景。
  • Gloo:跨平台后端,支持 CPU 和 GPU,适合调试或 CPU 训练。
  • MPI:高性能计算后端,适用于集群环境。
3.2. 排名和规模
  • Rank:每个进程的唯一标识(0 到 world_size-1)。
  • World Size:总进程数,等于节点数 × 每节点进程数。
  • Local Rank:进程在当前节点内的排名,用于绑定 GPU。
3.3. 初始化方法
  • env://:默认方法,通过环境变量(如 RANKWORLD_SIZEMASTER_ADDRMASTER_PORT)配置,由 torchrun 自动设置。
  • tcp://:指定主节点地址,如 tcp://192.168.1.1:1234
  • file://:使用共享文件系统协调,如 file:///tmp/sharedfile

4. 使用步骤

4.1. 启动分布式训练
  • 使用 torchrun 启动脚本,自动设置环境变量:
    torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --rdzv_endpoint=localhost:1234 train.py
    
4.2. 编写训练脚本

以下是一个典型的分布式训练脚本示例:

1.  import torch
2.  import torch.distributed as dist
3.  import torch.nn as nn
4.  import os
5.  from torch.utils.data import DataLoader, DistributedSampler
6.  
7.  def main():
8.      # 4.2.1. 初始化进程组
9.      dist.init_process_group(backend='nccl')
10.     local_rank = int(os.environ['LOCAL_RANK'])
11.     device = torch.device(f'cuda:{local_rank}')
12.     torch.cuda.set_device(local_rank)
13. 
14.     # 4.2.2. 定义模型
15.     model = YourModel().to(device)
16.     model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
17. 
18.     # 4.2.3. 数据加载
19.     dataset = YourDataset()
20.     sampler = DistributedSampler(dataset)
21.     loader = DataLoader(dataset, batch_size=32, sampler=sampler)
22. 
23.     # 4.2.4. 训练循环
24.     for epoch in range(10):
25.         sampler.set_epoch(epoch)  # 确保数据洗牌
26.         for data, target in loader:
27.             data, target = data.to(device), target.to(device)
28.             # 训练步骤:前向、损失、反向、优化
29.             if dist.get_rank() == 0:  # 仅主进程打印或保存
30.                 print(f"Epoch {epoch}, Loss: ...")
31. 
32.     # 4.2.5. 清理
33.     dist.destroy_process_group()
34. 
35. if __name__ == '__main__':
36.     main()

5. 注意事项

5.1. 环境变量
  • torchrun 自动设置 RANKLOCAL_RANKWORLD_SIZEMASTER_ADDRMASTER_PORT
  • 脚本中通过 os.environ 获取,如 local_rank = int(os.environ['LOCAL_RANK'])
5.2. I/O 控制
  • 仅在主进程(dist.get_rank() == 0)执行文件保存、日志打印等操作,避免冲突。
5.3. 性能优化
  • 使用 NCCL 后端以获得最佳 GPU 性能。
  • 确保批大小适中,过小可能导致通信开销过大。

6. 总结

torch.distributed 模块是 PyTorch 分布式训练的核心,提供进程组初始化、通信原语(如 all_reduce、broadcast)和 DDP 支持。配合 torchrun,可以轻松在多 GPU 或多节点上运行训练任务。确保正确初始化进程组、使用 DistributedDataParallelDistributedSampler,并注意 I/O 管理。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值