import os
import multiprocessing

import torch
import torch.distributed as dist


def setup_for_distributed(is_master):
    """
    Disable printing when not in the master process, so log output is emitted
    once instead of once per worker. Pass force=True to print from any process.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
def is_dist_avail_and_initialized():
    """Return True if torch.distributed is both available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    """Return the number of distributed processes, or 1 outside distributed mode."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Return the rank of the current process, or 0 outside distributed mode."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()
def init_distributed_mode(args):
    """
    Initialize torch.distributed with the NCCL backend and attach the
    distributed settings (rank, gpu, device, n_gpu, cpu_cont) to args.
    Expects args.world_size and args.dist_url to be set by the caller.
    """
    cpu_cont = multiprocessing.cpu_count()
    # Environment-variable / SLURM based rank detection (disabled):
    # if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
    #     args.rank = int(os.environ["RANK"])
    #     args.world_size = int(os.environ['WORLD_SIZE'])
    #     args.gpu = int(os.environ['LOCAL_RANK'])
    # elif 'SLURM_PROCID' in os.environ:
    #     args.rank = int(os.environ['SLURM_PROCID'])
    #     args.gpu = args.rank % torch.cuda.device_count()
    # else:
    #     print('Not using distributed mode')
    #     args.distributed = False
    #     return
    args.distributed = True
    args.rank = get_rank()
    # args.world_size = get_world_size()
    args.gpu = args.rank % torch.cuda.device_count()
    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}, world size {}): {}'.format(
        args.rank, args.world_size, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend,
                                         init_method=args.dist_url,
                                         world_size=args.world_size,
                                         rank=args.rank)
    torch.distributed.barrier()
    device = torch.device("cuda", args.gpu)
    args.n_gpu = torch.cuda.device_count()
    args.device = device
    args.cpu_cont = cpu_cont
    setup_for_distributed(args.rank == 0)
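

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of the original module).
# It assumes a single-GPU, single-process run; the --world_size and --dist_url
# argument names are assumptions for this example and may differ from the
# repository's actual training scripts. A CUDA device and NCCL are required
# because init_distributed_mode hard-codes the 'nccl' backend.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=1,
                        help='total number of distributed processes (assumed name)')
    parser.add_argument('--dist_url', default='tcp://127.0.0.1:23456',
                        help='URL used to set up distributed training (assumed name)')
    example_args = parser.parse_args()

    init_distributed_mode(example_args)
    # After setup, printing is suppressed on non-master ranks unless force=True.
    print('rank {} using device {}'.format(example_args.rank, example_args.device))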