Sharing the current design for SPMD on GPU for PyTorch/XLA. Feel free to comment and suggest changes.
🚀 Objective
This design describes what is needed to make GSPMD work in PyTorch/XLA on GPU. This will enable large-scale PyTorch training via GSPMD and let users leverage the compiler-based sharding framework and tools.
Goals
- Functionality: Users should be able to train their models using SPMD on PyTorch/XLA GPU, such as the model in test/spmd/test_train_spmd_imagenet.py.
- Usability: We should make it easy for users to try SPMD. The user experience should be similar to the existing one and should not require much extra work from the user, such as setting additional environment variables.
- Performance: The performance of SPMD on PyTorch/XLA GPU should be competitive with that of JAX SPMD on GPU. We will benchmark with Llama 2.
Non-goals
- SPMD on CUDA tensors.
- SPMD composability with other technologies such as dynamo, FSDPv2.
Design
Usability/User experience
Today's machine learning models increasingly contain billions of parameters and are trained on large amounts of data, so users will likely run SPMD across multiple GPU machines, as is already the case on TPU. We therefore design the user interface around that assumption.
With that in mind, we can make the SPMD user experience similar to the existing multi-host training experience. In multi-host training on GPU, users run a torchrun command such as:
```shell
PJRT_DEVICE=CUDA \
torchrun \
--nnodes=2 \
--node_rank=0 \
--nproc_per_node=4 \
--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:12355" test_pytorch_xla_script.py
```
Similarly, we propose that users keep using torchrun to start SPMD training:
```shell
# Each machine runs exactly one process, as SPMD requires (--nproc_per_node=1).
PJRT_DEVICE=CUDA \
torchrun \
--nnodes=${NUM_GPU_MACHINES} \
--node_rank=${RANK_OF_CURRENT_MACHINE} \
--nproc_per_node=1 \
--rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:12355" \
test_pytorch_xla_spmd_script.py
```
The only notable difference is how we set the --nproc_per_node flag. In general multinode training, --nproc_per_node specifies how many processes to launch on the current node, and it can be as large as the number of GPU devices on that node. For SPMD, however, it must always be 1, because SPMD requires exactly one process per node. All other torchrun flags remain the same.
The script then enables SPMD mode by calling xr.use_spmd(). In summary, --nproc_per_node=1 and xr.use_spmd() enable users to run PyTorch/XLA SPMD on GPU.
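Because SPMD depends on --nproc_per_node=1, a script might want to fail fast when launched the multi-process way. The sketch below is a hypothetical guard, not part of PyTorch/XLA: it reads LOCAL_WORLD_SIZE, which torchrun sets to the number of processes on the node, and rejects any value other than 1. The helper name and error message are assumptions for illustration.

```python
import os
from typing import Mapping


def validate_spmd_launch(environ: Mapping[str, str] = os.environ) -> None:
    """Hypothetical guard: SPMD needs exactly one process per node.

    torchrun exports LOCAL_WORLD_SIZE (the number of processes on this node),
    which equals --nproc_per_node. A missing value is treated as 1, e.g. when
    the script is started with plain `python` on a single machine.
    """
    local_world_size = int(environ.get("LOCAL_WORLD_SIZE", "1"))
    if local_world_size != 1:
        raise RuntimeError(
            "SPMD requires --nproc_per_node=1, "
            f"but LOCAL_WORLD_SIZE={local_world_size}"
        )


# Intended use at the top of an SPMD script (torch_xla lines kept as comments
# so this sketch runs without torch_xla installed):
#   validate_spmd_launch()
#   import torch_xla.runtime as xr
#   xr.use_spmd()
```

Checking the environment rather than parsing torchrun's arguments keeps the guard independent of how the job was launched (directly with torchrun or through GKE tooling that wraps it).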
The benefit of this proposal is that it keeps the user experience consistent with PyTorch's multinode training, which predominantly uses torchrun. We can also leverage the GKE tooling we developed for PyTorch/XLA multinode training and provide a user experience similar to a workload manager such as SLURM, which PyTorch also recommends. Finally, by using xr.use_spmd(), we align with go/pytorch-spmd-usability: we avoid another environment variable, XLA_USE_SPMD, which simplifies the user experience.
GPU client
Whether the user starts SPMD training directly with torchrun or through GKE, the torchrun command above runs on each host and sets a few environment variables:
- LOCAL_RANK: The local rank of the process within its machine. It is always 0, since SPMD creates one process per host.
- RANK: The global rank of the process. It equals the rank of the current GPU machine (its --node_rank), ranging from 0 to the number of GPU machines minus 1, again because SPMD creates one process per host.
- LOCAL_WORLD_SIZE: The local world size (the number of processes running on this machine). It is 1, since SPMD creates one process per host.
- WORLD_SIZE: The total number of processes across all hosts. It equals the number of GPU machines.
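The following illustrative model, not torchrun's own code, shows how these values relate, assuming every node runs the same number of processes and ranks are assigned contiguously per node (RANK = node_rank * nproc_per_node + local_rank). The function name torchrun_env is hypothetical.

```python
from typing import Dict


def torchrun_env(nnodes: int, node_rank: int, nproc_per_node: int,
                 local_rank: int) -> Dict[str, int]:
    """Model of the rank variables torchrun sets for one worker process.

    Assumes homogeneous nodes and contiguous rank assignment per node.
    """
    if not 0 <= node_rank < nnodes:
        raise ValueError("node_rank must be in [0, nnodes)")
    if not 0 <= local_rank < nproc_per_node:
        raise ValueError("local_rank must be in [0, nproc_per_node)")
    return {
        "LOCAL_RANK": local_rank,
        "RANK": node_rank * nproc_per_node + local_rank,
        "LOCAL_WORLD_SIZE": nproc_per_node,
        "WORLD_SIZE": nnodes * nproc_per_node,
    }


# SPMD launch: 2 machines, one process each, so RANK equals the node rank.
spmd = torchrun_env(nnodes=2, node_rank=1, nproc_per_node=1, local_rank=0)
# spmd == {"LOCAL_RANK": 0, "RANK": 1, "LOCAL_WORLD_SIZE": 1, "WORLD_SIZE": 2}

# Multi-process launch for comparison: 2 machines, 4 processes (GPUs) each.
mp = torchrun_env(nnodes=2, node_rank=1, nproc_per_node=4, local_rank=3)
# mp == {"LOCAL_RANK": 3, "RANK": 7, "LOCAL_WORLD_SIZE": 4, "WORLD_SIZE": 8}
```

With nproc_per_node=1 the formula collapses to RANK == node_rank and WORLD_SIZE == nnodes, which is why each SPMD process can be identified by its machine alone.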
Next, we need to make sure the single process on each GPU machine can access all of that machine's GPU devices; currently, each process can only access one GPU device. To accomplish this, we construct the StreamExecutorGpuClient with the appropriate GpuClientOptions. Most notably, GpuClientOptions.allowed_devices must be empty so that the StreamExecutorGpuClient automatically detects all GPU devices attached to the current node.
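To make the role of an empty allowed_devices concrete, here is a toy Python model of the filtering behavior described above. It is not the XLA C++ implementation: visible_devices is a hypothetical function, and pinning each current process to the GPU matching its LOCAL_RANK is an assumption used to represent the "one GPU per process" setup.

```python
from typing import List, Optional, Sequence, Set


def visible_devices(detected: Sequence[int],
                    allowed_devices: Optional[Set[int]]) -> List[int]:
    """Toy model of an allowed-devices filter on a GPU client.

    An empty (or missing) filter means "use every detected GPU"; otherwise
    only the listed device ordinals are kept.
    """
    if not allowed_devices:
        return list(detected)
    return [d for d in detected if d in allowed_devices]


gpus_on_node = range(4)  # assume a machine with 4 GPUs

# Current multi-process setup (assumed): each process sees only its own GPU.
per_process = [visible_devices(gpus_on_node, {local_rank})
               for local_rank in range(4)]
# per_process == [[0], [1], [2], [3]]

# SPMD setup: the single process per node leaves allowed_devices empty.
spmd_process = visible_devices(gpus_on_node, set())
# spmd_process == [0, 1, 2, 3]
```

Under this model, a cluster of N such machines gives the SPMD program N * 4 devices, all owned by N processes.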
Process group
One concern about using torchrun is how to handle process groups. A process group is a group of processes working on one task, and it enables communication among those processes. However, an XLA device process group is not supported under SPMD, because SPMD runs a single replica and the compiler, not the user, should handle communication and coordination instead of users manually coordinating processes via collective ops. The problem is that if an XLA process group is created (by torchrun or something else) and we call dist.init_process_group, the code would crash in SPMD mode.
It turns out this is not an issue. First, the XLA process group is only created when the torch_xla.distributed.xla_backend module is imported, and SPMD scripts do not import that module. Second, because a process group is not needed for SPMD, there is no reason to call dist.init_process_group in an SPMD training or inference script.
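Following that reasoning, an SPMD script could assert that the XLA backend module was never imported. The check below is a hypothetical sanity check, not a PyTorch/XLA API; it only relies on the fact that Python records every imported module in sys.modules.

```python
import sys

XLA_BACKEND_MODULE = "torch_xla.distributed.xla_backend"


def xla_backend_imported() -> bool:
    """Hypothetical sanity check for SPMD scripts.

    Per the reasoning above, an XLA process group can only be created after
    torch_xla.distributed.xla_backend has been imported. If the module is
    absent from sys.modules, nothing in this process has imported it yet.
    """
    return XLA_BACKEND_MODULE in sys.modules


# Intended use in an SPMD script, before training starts:
#   assert not xla_backend_imported(), "SPMD scripts should not use XLA process groups"
# and no dist.init_process_group(...) call anywhere in the script.
```

This only detects the import; it does not prevent a later import, so it works best as an early assertion in the training entry point.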
PyTorch/XLA technologies
The design should work with dynamo and non-dynamo execution, inference, and FSDPv2, because it does not change the SPMD interface in any way. Those technologies should work in a hardware-agnostic way and will be discussed in detail in other documents.
Future work
Eventually, we want to adopt the one-GPU-per-process model, even for SPMD, for performance's sake. The one-GPU-per-process model is the optimal configuration because each process can then be bound to a single NUMA domain and a single NIC.
Enabling it may require an overhaul of the existing SPMD design, so it will come after the current work.
Performance
We will use Llama 2 for benchmarking, since it is the model we benchmark for PyTorch/XLA SPMD on TPU. As a first step, we will compare performance with JAX SPMD on GPU, since the only difference is the ML framework (PyTorch/XLA vs. JAX) while everything else (the XLA GPU compiler and the hardware) stays the same. Later, we will compare performance with PyTorch with Inductor, PyTorch eager, PyTorch/XLA SPMD on TPU, and PyTorch/XLA FSDP.