-
Notifications
You must be signed in to change notification settings - Fork 119
/
options.py
95 lines (73 loc) · 3.58 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import os
def parse_common_args(parser):
    """Register the options shared by the train and test command lines.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object) to
            extend in place.

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    add = parser.add_argument

    # Which model / dataset implementation to use (see the help strings).
    add('--model_type', type=str, default='base_model',
        help='used in model_entry.py')
    add('--data_type', type=str, default='base_dataset',
        help='used in data_entry.py')
    add('--save_prefix', type=str, default='pref',
        help='some comment for model or test result dir')

    # Checkpoint loading.
    add('--load_model_path', type=str,
        default='checkpoints/base_model_pref/0.pth',
        help='model path for pretrain or test')
    add('--load_not_strict', action='store_true',
        help='allow to load only common state dicts')

    add('--val_list', type=str,
        default='/data/dataset1/list/base/val.txt',
        help='val list in train, test list path in test')

    # One or more integer ids; the attribute is None when the flag is absent.
    add('--gpus', nargs='+', type=int)
    add('--seed', type=int, default=1234)
    return parser
def parse_train_args(parser):
    """Register the common options plus the training-only options.

    Args:
        parser: the argument parser to extend.

    Returns:
        The parser returned by ``parse_common_args`` with the training
        options added to it.
    """
    parser = parse_common_args(parser)
    add = parser.add_argument

    # Optimizer settings.
    add('--lr', type=float, default=1e-3, help='learning rate')
    add('--momentum', metavar='M', type=float, default=0.9,
        help='momentum for sgd, alpha parameter for adam')
    add('--beta', metavar='M', type=float, default=0.999,
        help='beta parameters for adam')
    # Stored as ``weight_decay`` on the namespace; ``--wd`` is a short alias.
    add('--weight-decay', '--wd', metavar='W', type=float, default=0,
        help='weight decay')

    # Output location and data / schedule settings.
    add('--model_dir', type=str, default='',
        help='leave blank, auto generated')
    add('--train_list', type=str,
        default='/data/dataset1/list/base/train.txt')
    add('--batch_size', type=int, default=16)
    add('--epochs', type=int, default=100)
    return parser
def parse_test_args(parser):
    """Register the common options plus the test-only options.

    Args:
        parser: the argument parser to extend.

    Returns:
        The parser returned by ``parse_common_args`` with ``--save_viz``
        and ``--result_dir`` added to it.
    """
    parser = parse_common_args(parser)
    add = parser.add_argument
    add('--save_viz', action='store_true',
        help='save viz result in eval or not')
    add('--result_dir', type=str, default='',
        help='leave blank, auto generated')
    return parser
def get_train_args():
    """Parse the process command line using the training option set.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    train_parser = parse_train_args(argparse.ArgumentParser())
    return train_parser.parse_args()
def get_test_args():
    """Parse the process command line using the test option set.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    test_parser = parse_test_args(argparse.ArgumentParser())
    return test_parser.parse_args()
def get_train_model_dir(args):
    """Create the checkpoint directory for a run and record it on ``args``.

    The directory is ``checkpoints/<model_type>_<save_prefix>``, relative to
    the current working directory.

    Args:
        args: namespace providing ``model_type`` and ``save_prefix``; its
            ``model_dir`` attribute is overwritten with the directory path.

    Raises:
        OSError: if the directory cannot be created.
    """
    model_dir = os.path.join('checkpoints', args.model_type + '_' + args.save_prefix)
    # Create the directory through the OS API rather than a shell command:
    # both name components come from the command line, so spaces or shell
    # metacharacters must be treated as literal characters. ``exist_ok``
    # makes a separate existence check unnecessary.
    os.makedirs(model_dir, exist_ok=True)
    args.model_dir = model_dir
def get_test_result_dir(args):
    """Create the result directory for a test run and record it on ``args``.

    The directory is ``<model path without its extension>/<val_info>_<save_prefix>``
    where ``val_info`` is ``<parent dir name of val_list>_<val_list file name
    without a trailing '.txt'>``.

    Args:
        args: namespace providing ``load_model_path``, ``val_list`` and
            ``save_prefix``; its ``result_dir`` attribute is overwritten
            with the directory path.

    Raises:
        OSError: if the directory cannot be created.
    """
    model_path = args.load_model_path
    # Strip only the final extension of the file name. A plain
    # ``str.replace`` would also delete the same text wherever else it
    # occurs in the path (e.g. 'pth_runs/0.pth'), and would mangle paths
    # whose file name has no extension at all.
    _stem, dot, ext = os.path.basename(model_path).rpartition('.')
    if dot:
        # The dot itself is kept ('.../0.pth' -> '.../0.') so the result
        # path stays the same as before for ordinary checkpoint names.
        model_dir = model_path[:len(model_path) - len(ext)]
    else:
        model_dir = model_path

    val_list = args.val_list
    list_name = os.path.basename(val_list)
    # Drop '.txt' only when it is really the suffix of the file name.
    if list_name.endswith('.txt'):
        list_name = list_name[:-len('.txt')]
    val_info = os.path.basename(os.path.dirname(val_list)) + '_' + list_name

    result_dir = os.path.join(model_dir, val_info + '_' + args.save_prefix)
    # Use the OS API instead of a shell command so that user-supplied path
    # components are never interpreted by a shell, and failures surface as
    # exceptions instead of being ignored.
    os.makedirs(result_dir, exist_ok=True)
    args.result_dir = result_dir
def save_args(args, save_dir):
    """Write the string form of ``args`` to ``args.txt`` inside ``save_dir``.

    Every ``', '`` in ``str(args)`` is turned into a comma followed by a
    newline, so the entries end up on separate lines. An existing file is
    overwritten.

    Args:
        args: object to dump (its ``str()`` is what gets written).
        save_dir: existing directory that receives ``args.txt``.
    """
    with open(os.path.join(save_dir, 'args.txt'), 'w') as out_file:
        pieces = str(args).split(', ')
        out_file.write(',\n'.join(pieces))
def prepare_train_args():
    """Parse the training arguments, set up the model directory and dump the arguments there.

    Returns:
        The parsed namespace, after ``get_train_model_dir`` has been
        applied to it.
    """
    parsed = get_train_args()
    get_train_model_dir(parsed)  # sets parsed.model_dir
    output_dir = parsed.model_dir
    save_args(parsed, output_dir)
    return parsed
def prepare_test_args():
    """Parse the test arguments, set up the result directory and dump the arguments there.

    Returns:
        The parsed namespace, after ``get_test_result_dir`` has been
        applied to it.
    """
    parsed = get_test_args()
    get_test_result_dir(parsed)  # sets parsed.result_dir
    output_dir = parsed.result_dir
    save_args(parsed, output_dir)
    return parsed
if __name__ == "__main__":
    # When run as a script, parse the command line once with each option
    # set; the results are only bound to names, nothing else is done.
    train_args, test_args = get_train_args(), get_test_args()