PyTorch DDP configuration
1. master node: the main GPU, responsible for synchronization, making copies, loading models, and writing logs;
2. process group: if you want to train/test the model over K GPUs, the K processes form a group, which is backed by a communication backend (PyTorch manages this for you; according to the documentation, NCCL is the recommended backend for GPU training);
3. rank: within the process group, each process is identified by its rank, from 0 to K-1;
4. world size: the number of processes in the group, i.e. the number of GPUs K;
5. multiprocessing: all child processes, together with the parent process, run the same code. In PyTorch, torch.multiprocessing provides convenient ways to create parallel processes. As the official documentation says, the spawn function "addresses these concerns and takes care of error propagation, out of order termination, and will actively terminate processes upon detecting an error in one of them." A minimal spawn sketch follows this list.
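To make these terms concrete, here is a minimal sketch (not part of the original script) that spawns one process per GPU on a single node and has each process report its rank; the address, port, and GPU counts are assumed toy values, and it expects at least two GPUs since the nccl backend is used.
import os
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(local_rank, node_rank, gpus_per_node, world_size):
    # global rank = node rank * GPUs per node + local GPU index
    rank = node_rank * gpus_per_node + local_rank
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=world_size, rank=rank)
    print(f'rank {rank} of {world_size} running on local GPU {local_rank}')
    dist.destroy_process_group()

if __name__ == '__main__':
    os.environ['MASTER_ADDR'] = '127.0.0.1'    # assumed: everything on one machine
    os.environ['MASTER_PORT'] = '8888'
    nodes, gpus_per_node, node_rank = 1, 2, 0  # assumed toy values
    mp.spawn(worker, nprocs=gpus_per_node,
             args=(node_rank, gpus_per_node, nodes * gpus_per_node))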
- initialize the process group in the training function
def train(gpu, args):
    # global rank of this process = node rank * GPUs per node + local GPU index
    rank = args.nr * args.gpus + gpu
    torch.cuda.set_device(gpu)          # pin this process to its own GPU
    dist.init_process_group(
        backend='nccl',
        init_method='env://',           # read MASTER_ADDR/MASTER_PORT from the environment
        world_size=args.world_size,     # total number of processes across all nodes
        rank=rank)
- give each process in the group its own shard of the data, which can easily be achieved with torch.utils.data.distributed.DistributedSampler or any customized sampler (the DataLoader wiring is sketched below)
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=args.world_size,   # one shard per process in the group
    rank=rank                       # this process only sees its own shard
)
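A minimal sketch of how the sampler would be hooked into a DataLoader; the batch size and worker count are assumed values, not from the original code. Shuffling is left to the sampler, and sampler.set_epoch(epoch) is called at the start of every epoch so each epoch gets a different ordering:
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=100,                  # assumed value, tune as needed
    shuffle=False,                   # the DistributedSampler already shuffles
    num_workers=0,
    pin_memory=True,
    sampler=train_sampler)

for epoch in range(args.epochs):
    train_sampler.set_epoch(epoch)   # re-seed the shuffle for this epoch
    for images, labels in train_loader:
        pass                         # forward/backward/step go here (see the step sketch below)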
- wrap the model with DDP (the model must already live on this process's GPU)
model = model.cuda(gpu)   # move the model to the GPU owned by this process
model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
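Each process then builds its own loss and optimizer over the wrapped model; during backward(), DDP averages the gradients across all processes, so the training step itself looks like ordinary single-GPU code. A minimal sketch, assuming a cross-entropy classification loss and SGD (both assumptions, not from the original script):
criterion = nn.CrossEntropyLoss().cuda(gpu)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)  # assumed learning rate

for images, labels in train_loader:                       # loader from the sampler sketch above
    images, labels = images.cuda(gpu), labels.cuda(gpu)
    outputs = model(images)              # forward pass on this process's shard of the batch
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()                      # DDP averages gradients across all processes here
    optimizer.step()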
- spawn one process per GPU in main
# imports used by the script fragments above
import os
import argparse
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
                        help='number of nodes')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('-nr', '--nr', default=0, type=int,
                        help='ranking within the nodes')
    parser.add_argument('--epochs', default=2, type=int, metavar='N',
                        help='number of total epochs to run')
    args = parser.parse_args()
    #########################################################
    args.world_size = args.gpus * args.nodes            # total number of processes
    os.environ['MASTER_ADDR'] = '10.57.23.164'          # IP of the master node
    os.environ['MASTER_PORT'] = '8888'                  # a free port on the master node
    mp.spawn(train, nprocs=args.gpus, args=(args,))     # one process per local GPU
    #########################################################

if __name__ == '__main__':
    main()
- launch on a single node with 8 GPUs:
python mnist-distributed.py -n 1 -g 8 -nr 0
- launch on two nodes with 8 GPUs each, running one command on each node (node 0 is the master):
python mnist-distributed.py -n 2 -g 8 -nr 0
python mnist-distributed.py -n 2 -g 8 -nr 1
Reference
- code: https://theaisummer.com/distributed-training-pytorch/
- https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html
- https://medium.com/codex/a-comprehensive-tutorial-to-pytorch-distributeddataparallel-1f4b42bb1b51
- backend choice
- https://ai.googleblog.com/2022/05/alpa-automated-model-parallel-deep.html