Distributed PyTorch Training¶
Abstract
This tutorial aims to introduce dispatching a distributed training job on Isambard-AI.
Prerequisites
We welcome people from all domain backgrounds who have experience training AI models with Data Parallelism on PyTorch.
This tutorial is meant to bridge the gap between training on a single node to distributing jobs across multiple nodes. High Performance Computing (Slurm) knowledge is not required.
A working conda installation is required to set up the environment; the instructions can be followed in the conda guide.
Learning Objectives
The learning objectives of this tutorial are as follows:
- Be able to launch a distributed PyTorch job.
- Understand how Slurm, MPI, and NCCL interact.
- Understand how MPI and NCCL are dispatched and the modules required for them.
- Combine the above to make use of the high speed network (Slingshot).
Tutorial¶
1. Introduction and Setup¶
Some AI models are simply too large to fit in a single GPU's memory; large language models in particular can exceed the capacity of even the most powerful single GPU. Distributed training allows you to split these large models across multiple GPUs and nodes, making it possible to train models that would otherwise fail on a single device, for example with CUDA "out of memory" (OOM) errors.
Distributed training jobs differ according to the infrastructure of the underlying system. This tutorial will introduce how to launch multi-node jobs on Isambard-AI to make use of the Slingshot 11 High Speed Network (hsn) using MPI and NCCL (Nvidia Collective Communications Library).
Distributed PyTorch is based around the torch.distributed Python module, which is very similar to other parallel computing methods such as MPI. PyTorch will launch MPI processes behind the scenes; however, torch has to assign a GPU to each process and connect them using NCCL.
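As a minimal sketch of those core calls (assuming the launcher has already set RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT and LOCAL_RANK, as shown later in this tutorial):
import os
import torch
import torch.distributed as dist

# Assumes RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT and LOCAL_RANK are set by the launcher
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)        # pin this process to one GPU
dist.init_process_group(backend='nccl')  # rank and world size are read from the environment
print(f"rank {dist.get_rank()} of {dist.get_world_size()} on GPU {local_rank}")
dist.destroy_process_group()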
PyTorch torch.distributed.launch vs. torchrun
PyTorch used to provide the command line tool torch.distributed.launch, which has been replaced by torchrun. Please see the PyTorch distributed documentation for transitioning to the torchrun command.
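As an illustration (the script name here is hypothetical), a four-process single-node launch that previously used python -m torch.distributed.launch would now be written as:
$ torchrun --nnodes=1 --nproc_per_node=4 my_training_script.py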
Setting up the environment¶
We will start by setting up a conda environment with our dependencies.
name: pytorch_env
channels:
- conda-forge
- nodefaults
dependencies:
- python=3.10
- pytorch=2.5.1
- torchvision
- transformers
Click here to download the file: pytorch_conda_env.yml
We can create and activate the environment by executing the following commands. We then test GPU capability by printing whether CUDA is available.
$ conda env create -f pytorch_conda_env.yml
$ conda activate pytorch_env
$ python3 -c "import torch; print(torch.cuda.is_available())"
True
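As an optional extra check, you can also print how many GPUs PyTorch can see (the count depends on the node you are on):
$ python3 -c "import torch; print(torch.cuda.device_count())"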
Installing ML Applications
More information on ML applications and frameworks is available in the ML Applications Documentation
2. Understanding the backend¶
Launching jobs with mpirun¶
MPI is used to orchestrate the processes at a high level and assign ranks to them. The Nvidia Collective Communications Library (nccl) provides collectives (e.g. All Reduce) for efficient GPU-to-GPU communication. For in-depth information on how PyTorch uses MPI and NCCL for collectives, please see the PyTorch distributed documentation.
We will use the BriCS-supplied OpenMPI and NCCL since they include performance improvements for the network.
$ module load brics/ompi brics/nccl
We will now use mpirun to dispatch 4 processes on the login node, where -np specifies the number of processes to launch:
$ mpirun -np 4 env | grep LOCAL_RANK
OMPI_COMM_WORLD_LOCAL_RANK=0
OMPI_COMM_WORLD_LOCAL_RANK=1
OMPI_COMM_WORLD_LOCAL_RANK=2
OMPI_COMM_WORLD_LOCAL_RANK=3
What environment variables does MPI set on the different processes?
Try running mpirun -np 4 env | grep MPI and have a look at which environment variables are set for each process.
Launching jobs with srun¶
To run on multiple nodes we must use the process manager's (Slurm) srun command, which is similar to MPI's mpirun command: it manages the distribution of processes across multiple nodes. Using the command below, you can execute a job across two nodes, with each node running the specified command.
$ srun -N 2 hostname
nid001031
nid001033
What environment variables does Slurm set on the different processes?
Try running srun -N 2 env | grep SLURM and have a look at which environment variables are set for each process.
MPI and Slurm Documentation
For more background, the MPI and PMI documentation is available here: MPI Guide. Our Slurm Documentation explains more about how to use the srun and sbatch commands.
PyTorch distributed¶
PyTorch's torch.distributed module simplifies launching distributed training jobs.
To demonstrate a simple distributed setup, we can use a Python script named launch.py that initializes a distributed process group and prints the rank of each process and the node it is running on.
Here is the content of launch.py:
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def init_process(rank, world_size, fn, backend='gloo'):
    dist.init_process_group(backend, rank=rank, world_size=world_size)
    fn(rank, world_size)

def fn(rank, world_size):
    """Distributed function to be implemented."""
    print(f"Hello from rank {rank} out of {world_size} processes on {os.uname().nodename}")
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = int(os.environ['WORLD_SIZE'])
    mp.spawn(init_process,
             args=(world_size, fn, 'gloo'),
             nprocs=world_size,
             join=True)
Click here to download the file: launch.py
This script illustrates how PyTorch multiprocessing and distributed training work:
- init_process(): Initializes the distributed environment using the specified backend, then executes the provided function fn() for each process.
- fn(): Prints a message from each process, indicating its rank and the total number of processes, then cleans up the process group after execution.
- "__main__": Retrieves the total number of processes from the WORLD_SIZE environment variable and uses mp.spawn to launch multiple processes, each running the init_process() function.
You can run this script with the following command:
(pytorch_env) $ MASTER_ADDR='localhost' MASTER_PORT=29600 WORLD_SIZE=2 python3 launch.py
Hello from rank 0 out of 2 processes on nid001040
Hello from rank 1 out of 2 processes on nid001040
This command will launch 2 processes on the login node, and each process will print its rank and the total number of processes.
- MASTER_ADDR: The address of the master node (in this case, localhost). This is the hostname of the machine where the master process is running. On Isambard-AI this will resolve to the node name; we show below how to set it.
- MASTER_PORT: The port on the master node to which the worker nodes will connect (in this case, 29600). This helps in establishing communication between the master and worker nodes.
- WORLD_SIZE: The total number of processes participating in the job (in this case, 2). This includes both the master and worker processes.
We set these environment variables manually in this example; however, PyTorch can also pick up some of them from MPI/Slurm.
The gloo backend
The gloo backend is used in this example for simplicity. It is suitable for CPU-based distributed training and is included in the standard PyTorch installation. By using the gloo backend, you can easily test and understand the basics of distributed training without the complexity of setting up MPI.
3. Our first distributed job¶
To demonstrate how the network impacts performance, we will run a training benchmark of the BERT model. We will be using PyTorch's Distributed Data Parallel functionality, which uses torch.multiprocessing under the hood.
import torch
import transformers
from transformers import BertTokenizer, BertForSequenceClassification
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import time
from datetime import timedelta
import os

transformers.utils.logging.set_verbosity_error()

BATCH_SIZE = 2048
TRAINING_STEPS = 100

def init_process(backend='gloo'):  # change to 'nccl' to use NCCL backend
    torch.cuda.set_device(int(os.environ['LOCAL_RANK']))
    dist.init_process_group(backend=backend, timeout=timedelta(seconds=60*60))
    rank = dist.get_rank()
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = dist.get_world_size()
    if rank == 0:
        num_cuda_devices = torch.cuda.device_count()
        print(f"Process {rank} initialized on local rank {local_rank} with backend {backend} and world_size {world_size}.")
        print(f"Number of locally available CUDA devices: {num_cuda_devices}")

def benchmark():
    model_name = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(model_name)
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = dist.get_world_size()

    # Calculate per-GPU batch size
    per_gpu_batch_size = BATCH_SIZE // world_size

    model = BertForSequenceClassification.from_pretrained(model_name).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.Adam(model.parameters())

    # Create different data for each GPU
    start_idx = local_rank * per_gpu_batch_size
    end_idx = start_idx + per_gpu_batch_size
    texts = [f"This is sample sentence {i} for benchmarking BERT." for i in range(start_idx, end_idx)]
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(local_rank)
    labels = torch.ones(per_gpu_batch_size, dtype=torch.long).to(local_rank)

    dist.barrier()
    start_time = time.time()
    for _ in range(TRAINING_STEPS):
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
    dist.barrier()
    end_time = time.time()

    if dist.get_rank() == 0:
        print(f"Time taken for {TRAINING_STEPS} forward and backward pass(es) with BATCH_SIZE={BATCH_SIZE}: {end_time - start_time} seconds")

    dist.destroy_process_group()

if __name__ == "__main__":
    init_process()
    benchmark()
Click here to download the file: train.py
Just like above, we set up an init_process() function and an fn() function (here named benchmark()), and execute both of them in the main section:
- init_process(): Initializes the distributed environment using the specified backend (gloo or nccl) and sets the CUDA device for the current process based on the LOCAL_RANK environment variable.
- benchmark():
  - Loads a pre-trained BERT model and tokenizer.
  - Wraps the model with DistributedDataParallel (DDP) for distributed training.
  - Prepares a batch of sample text and dummy labels (a sketch of sharding a real dataset follows this list).
  - Measures the time taken for the forward and backward passes through the model.
  - Prints the benchmark result from the process with rank 0.
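The benchmark builds its own dummy batch on every rank. In a real training job you would normally shard a dataset across the processes instead; a minimal sketch using torch.utils.data.DistributedSampler (not part of the tutorial scripts, shown here as one way you might extend train.py) is:
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Hypothetical dataset for illustration; replace with your tokenised corpus
dataset = TensorDataset(torch.randn(4096, 16), torch.zeros(4096, dtype=torch.long))

# DistributedSampler reads the rank and world size from the initialised
# process group, so each process receives a distinct shard of the data
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=256, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)        # reshuffle shards between epochs
    for features, labels in loader:
        pass                        # forward/backward as in train.py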
Now that we have a benchmark script, the following batch script shows how to launch the benchmark on Isambard-AI, including the #SBATCH directives and srun arguments. We use scontrol to set MASTER_ADDR automatically. Note that the arguments have to be set in both the sbatch directives and the srun command.
#!/bin/bash
#SBATCH --job-name=Torch_Distributed
#SBATCH --nodes=2
#SBATCH --gpus=8
#SBATCH --ntasks-per-node=4
module load brics/ompi brics/nccl
source $HOME/miniforge3/etc/profile.d/conda.sh
conda activate pytorch_env
export MASTER_ADDR=$(scontrol show hostname $SLURM_NODELIST | head -n 1) # e.g. nid001038
export MASTER_PORT=29600
# Print the job has started
echo "Job Started"
# Run the function with srun
srun -N 2 \
--gpus=8 \
--ntasks-per-node=4 \
--mpi=pmix \
bash -c 'export WORLD_SIZE=$SLURM_NTASKS; export RANK=$PMIX_RANK; export LOCAL_RANK=$SLURM_LOCALID; python train.py'
Click here to download the file: sbatch_pytorch.sh
To run our job we can simply execute:
$ sbatch sbatch_pytorch.sh
Process 0 initialized on local rank 0 with backend gloo and world_size 8.
Number of locally available CUDA devices: 4
Time taken for 100 forward and backward pass(es) with BATCH_SIZE=2048: 55.21897602081299 seconds
Make a note of the performance.
Using the nccl backend¶
We can now change the default backend in train.py from 'gloo' to 'nccl' so that dist.init_process_group() uses the NCCL backend.
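In train.py this is a one-line change to the default argument of init_process():
def init_process(backend='nccl'):  # previously backend='gloo'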
$ sbatch sbatch_pytorch.sh
Process 0 initialized on local rank 0 with backend nccl and world_size 8.
Number of locally available CUDA devices: 4
Time taken for 100 forward and backward pass(es) with BATCH_SIZE=2048: [HIDDEN] seconds
nccl vs. gloo
How does the performance change between the different backends? What if the batch size is increased?
You can now try scaling down to 1 GPU or scaling up beyond two nodes. The high speed network provides both low latency and high bandwidth, whatever your desired model size and batch size.
4. Conclusion¶
In this tutorial, you've learned how to launch distributed PyTorch training jobs on Isambard-AI's high-speed network. We covered how Slurm orchestrates processes across nodes, how MPI provides the initial process rankings, and how NCCL enables efficient GPU-to-GPU communication. You've seen how switching from the gloo to the nccl backend can significantly improve training performance by leveraging the Slingshot network.
MPI is only used to set up the process ranks; the backend is then taken over by NCCL to enable GPU collectives. NCCL bypasses the CPU and lets the GPUs communicate directly with each other without the operating system interfering. This process is known as RDMA (Remote Direct Memory Access), and it dramatically decreases latency and increases bandwidth, ensuring your job scales as the model size and batch size increase.
DistributedDataParallel (DDP) works by combining PyTorch's multiprocessing (torch.multiprocessing) and distributed (torch.distributed) modules. When you wrap a model with DDP, it creates replicas of your model across different processes, each assigned to a specific GPU. During training, gradients are automatically synchronized across all processes using NCCL's efficient GPU-to-GPU communication primitives. This synchronization happens in the backward pass, ensuring all model replicas maintain identical parameters.
Checklist for HSN (High Speed Network) Usage
Before running your distributed job, ensure:
- Required modules are loaded: module load brics/ompi brics/nccl.
- The NCCL backend is used: backend=nccl in init_process_group().
- MASTER_ADDR is set correctly using export MASTER_ADDR=$(scontrol show hostname $SLURM_NODELIST | head -n 1).
- Each process has been assigned to the correct GPU using LOCAL_RANK.
Debugging NCCL
Set NCCL_DEBUG=INFO in your environment to see detailed information about NCCL initialization and communication:
$ export NCCL_DEBUG=INFO
This will help verify you're using the HSN and diagnose any communication issues.
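If the full INFO output is too verbose, NCCL also lets you restrict logging to particular subsystems via NCCL_DEBUG_SUBSYS (see the NCCL documentation for the full list of values), for example:
$ export NCCL_DEBUG=INFO
$ export NCCL_DEBUG_SUBSYS=INIT,NET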
PMI Types¶
Note that the environment variables depend on the PMI and MPI version you use according to the table below.
Environment Variable | --mpi=cray_shasta | --mpi=pmi2 | --mpi=pmix |
---|---|---|---|
World Size | SLURM_NTASKS | SLURM_NTASKS | SLURM_NTASKS |
Number of Nodes | SLURM_NNODES | SLURM_NNODES | SLURM_NNODES |
Global Rank | PMI_RANK | PMI_RANK | PMIX_RANK |
Node Rank | SLURM_NODEID | SLURM_NODEID | SLURM_NODEID |
Local Rank | SLURM_LOCALID | SLURM_LOCALID | SLURM_LOCALID |
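For example, based on the table above, switching the batch script from --mpi=pmix to --mpi=pmi2 only changes where the global rank comes from:
# with --mpi=pmi2 the global rank is exposed as PMI_RANK instead of PMIX_RANK
export WORLD_SIZE=$SLURM_NTASKS
export RANK=$PMI_RANK
export LOCAL_RANK=$SLURM_LOCALID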
Note: Containers¶
All of the above is applicable to containers, provided you follow the multi-node container documentation.