-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathmain.py
184 lines (157 loc) · 6.93 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""
A main training script.
"""
# Copyright (c) Facebook, Inc. and its affiliates.
import warnings
warnings.filterwarnings('ignore') # never print matching warnings
import logging
import os
from collections import OrderedDict
import torch
import uniperceiver.utils.comm as comm
from uniperceiver.config import get_cfg, CfgNode
from uniperceiver.engine import DefaultTrainer, default_argument_parser, default_setup, launch, build_engine, add_moe_arguments
#!TODO re-implement hooks
from uniperceiver.engine import hooks
from uniperceiver.modeling import add_config
from uniperceiver.utils.env import init_distributed_mode, check_dist_portfile
# DeepSpeed is optional: when it cannot be imported, the script still runs and
# `get_args_parser` simply skips DeepSpeed's extra command-line options.
try:
    import deepspeed
    DEEPSPEED_INSTALLED = True
except Exception:  # broad on purpose (import may fail with non-ImportError), but not bare:
    # a bare `except:` would also swallow KeyboardInterrupt / SystemExit.
    DEEPSPEED_INSTALLED = False
import copy
# Values starting with any of these are treated as already-resolved locations
# (absolute, explicitly relative, remote, or run-output dirs) and never prefixed.
_DATA_PREFIX_SKIP = ('.', '/', 'work_dirs', 'cluster', 's3://')
# Symbolic names (not file paths) that must never be prefixed.
_DATA_PREFIX_WHITELIST = ("BERT", "CLIP", "CLIP_CAPTION")


def _needs_data_prefix(value):
    """Return True if ``value`` is a non-empty relative path string eligible for prefixing."""
    return (
        isinstance(value, str)  # e.g. None or list values are left alone instead of crashing
        and value != ''
        and not value.startswith(_DATA_PREFIX_SKIP)
        and value not in _DATA_PREFIX_WHITELIST
    )


def _prefix_item(container, item, data_dir):
    """Prefix ``container[item]`` with ``data_dir`` in place when present and eligible."""
    if item in container and _needs_data_prefix(container[item]):
        container[item] = os.path.join(data_dir, container[item])


def add_data_prefix(cfg):
    """Prefix relative dataset/model paths in ``cfg`` with ``$DATA_PATH``, in place.

    Nothing happens unless the ``DATA_PATH`` environment variable is set and
    non-empty. Otherwise a value is rewritten to ``os.path.join(DATA_PATH, value)``
    only if it is a non-empty string that does not start with ``.``, ``/``,
    ``work_dirs``, ``cluster`` or ``s3://`` and is not one of the whitelisted
    names (``BERT``, ``CLIP``, ``CLIP_CAPTION``).

    Three places are rewritten:

    * the global ``DATALOADER`` / ``INFERENCE`` / ``MODEL`` entries in ``mapping_list``;
    * the same entries inside each raw task config in ``cfg.TASKS`` (looked up
      under the section key when the task has it, otherwise at the task's top level);
    * ``SHARED_TARGETS_CFG.FILE_PATH`` of every entry of ``cfg.SHARED_TARGETS``
      (``None`` is normalised to ``[]`` first). Entries without a
      ``SHARED_TARGETS_CFG`` section are skipped.

    Args:
        cfg: the root config node; it must not be frozen yet.
    """
    # TODO: more flexible method
    data_dir = os.getenv("DATA_PATH", None)
    if not data_dir:
        return

    # [global config node, attribute name, key path to the same section in a task config]
    mapping_list = [
        [cfg.DATALOADER, 'FEATS_FOLDER', ['DATALOADER', ]],
        [cfg.DATALOADER, 'ANNO_FOLDER', ['DATALOADER', ]],
        [cfg.DATALOADER, 'CLASS_NAME_FILE', ['DATALOADER', ]],
        [cfg.INFERENCE, 'VOCAB', ['INFERENCE', ]],
        [cfg.INFERENCE, 'VAL_ANNFILE', ['INFERENCE', ]],
        [cfg.INFERENCE, 'TEST_ANNFILE', ['INFERENCE', ]],
        [cfg.MODEL, 'WEIGHTS', ['MODEL', ]],
    ]

    # Global defaults: write through setattr so the config node's own setter is used.
    for node, attr, _ in mapping_list:
        if _needs_data_prefix(node[attr]):
            setattr(node, attr, os.path.join(data_dir, node[attr]))

    # Per-task overrides (raw dicts/nodes as read from the config file).
    for task in cfg.TASKS:
        for _, item, key_list in mapping_list:
            section = task
            for key in key_list:
                if key in section:
                    section = section[key]
            _prefix_item(section, item, data_dir)

    # Shared target files.
    if cfg.SHARED_TARGETS is None:
        cfg.SHARED_TARGETS = []
    for share_targets in cfg.SHARED_TARGETS:
        target_cfg = share_targets.get('SHARED_TARGETS_CFG')
        if target_cfg is not None:
            _prefix_item(target_cfg, 'FILE_PATH', data_dir)
def add_default_setting_for_multitask_config(cfg):
    """Replace ``cfg.TASKS`` (raw per-task overrides) with fully populated task configs.

    The raw task list is detached from ``cfg`` first, so the snapshots taken
    below do not contain it. One deep copy of the remaining ``cfg`` is then
    made per task, that task's overrides are merged into its copy via
    ``merge_from_other_cfg(CfgNode(...))``, and the result is converted with
    ``to_dict_object()``.

    Mutates ``cfg`` in place and returns ``None``.
    """
    task_overrides = cfg.TASKS
    cfg.pop('TASKS', None)
    # Snapshot the shared defaults once per task before any merging happens.
    cfg.TASKS = [copy.deepcopy(cfg) for _ in range(len(task_overrides))]
    for idx, override in enumerate(task_overrides):
        task_cfg = cfg.TASKS[idx]
        task_cfg.merge_from_other_cfg(CfgNode(override))
        cfg.TASKS[idx] = task_cfg.to_dict_object()
def setup(args):
    """Build, post-process and freeze the run configuration.

    Order matters:
      1. start from ``get_cfg()``;
      2. call ``add_config`` with a preliminary load of ``args.config_file``
         (presumably to register config keys the file needs — confirm in add_config);
      3. merge ``args.config_file`` itself;
      4. prefix relative data paths with ``$DATA_PATH`` (``add_data_prefix``);
      5. merge the ``args.opts`` overrides — these come after step 4, so paths
         given there are not prefixed;
      6. expand ``cfg.TASKS`` into per-task configs, freeze, and run ``default_setup``.

    Returns:
        The frozen config.
    """
    config_file = args.config_file
    cfg = get_cfg()
    add_config(cfg, cfg.load_from_file_tmp(config_file))
    cfg.merge_from_file(config_file)
    add_data_prefix(cfg)
    cfg.merge_from_list(args.opts)

    add_default_setting_for_multitask_config(cfg)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg
def _evaluate_split(trainer, args, data_loader, evaluator):
    """Run ``trainer.test`` on one split, print the result on the main process, return it.

    Uses the EMA weights (``trainer.model_ema.ema``) when ``--eval-ema`` was
    given and an EMA model exists. Otherwise it uses ``trainer.model``, and
    warns on the main process if EMA was requested but is unavailable.
    """
    if trainer.model_ema is not None and args.eval_ema:
        if comm.is_main_process():
            print('using ema model for evaluation')
        model = trainer.model_ema.ema
    else:
        if args.eval_ema and comm.is_main_process():
            print('no ema model exists! using master model for evaluation')
        model = trainer.model
    res = trainer.test(trainer.cfg, model, data_loader, evaluator, epoch=-1)
    if comm.is_main_process():
        print(res)
    return res


def main(args):
    """Build the engine from ``args``, then either train or (with ``--eval-only``) evaluate.

    If you'd like to do anything fancier than the standard training logic,
    consider writing your own training loop (see plain_train_net.py) or
    subclassing the trainer.

    Returns:
        When training, whatever ``trainer.train()`` returns. With
        ``--eval-only``, the result of the last split evaluated: the test split
        if it has a loader, otherwise the validation split, otherwise ``None``.
    """
    cfg = setup(args)
    trainer = build_engine(cfg)
    trainer.resume_or_load(resume=args.resume)
    trainer.cast_layers()

    if not args.eval_only:
        return trainer.train()

    print('---------------------------')
    print('eval model only')
    print('---------------------------\n')
    res = None
    # Validation first, then test; the test result (if any) overwrites the validation one.
    if trainer.val_data_loader is not None:
        res = _evaluate_split(trainer, args, trainer.val_data_loader, trainer.val_evaluator)
    if trainer.test_data_loader is not None:
        res = _evaluate_split(trainer, args, trainer.test_data_loader, trainer.test_evaluator)
    return res
def get_args_parser():
    """Build this script's command-line parser and parse ``sys.argv``.

    Despite its name, this returns the parsed ``argparse.Namespace``, not the
    parser. On top of ``default_argument_parser()`` it adds:

    * DeepSpeed's config options (only when ``deepspeed`` imported successfully);
    * the MoE options from ``add_moe_arguments``;
    * ``--init_method`` (``'slurm'`` by default; ``'pytorch'``; anything else
      falls back to ``launch``/``mp.spawn`` in ``__main__``);
    * ``--local_rank`` / ``--local-rank``;
    * ``--eval-ema``.
    """
    parser = default_argument_parser()
    if DEEPSPEED_INSTALLED:
        parser = deepspeed.add_config_arguments(parser)
    parser = add_moe_arguments(parser)
    parser.add_argument('--init_method', default='slurm', type=str)
    # torch.distributed.launch from PyTorch 2.0 passes `--local-rank` (hyphen);
    # older launchers pass `--local_rank`. Accept both; the attribute stays `args.local_rank`.
    parser.add_argument('--local_rank', '--local-rank', dest='local_rank', default=0, type=int)
    parser.add_argument("--eval-ema", action="store_true", help="perform evaluation using ema")
    args = parser.parse_args()
    return args
if __name__ == "__main__":
    args = get_args_parser()
    print("Command Line Args:", args)
    init_method = args.init_method
    if init_method in ('slurm', 'pytorch'):
        # Both paths run `main` in this process; only SLURM needs explicit dist init here.
        if init_method == 'slurm':
            check_dist_portfile()
            init_distributed_mode(args)
        main(args)
    else:
        # Any other value: detectron2-style spawning of `main` via `launch` (mp.spawn).
        print('using \'mp.spawn\' for dist init! ')
        launch(
            main,
            args.num_gpus,
            num_machines=args.num_machines,
            machine_rank=args.machine_rank,
            dist_url=args.dist_url,
            args=(args,),
        )