forked from fuhx-ia/reid-yolo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
169 lines (145 loc) · 7.73 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# encoding: utf-8
import argparse
import os
import sys
import torch
from torch.backends import cudnn
sys.path.append('tools')
from config import cfg
from data import make_data_loader
from engine.trainer import do_train, do_train_with_center
from modeling import build_model
from layers import make_loss, make_loss_with_center
from solver import make_optimizer, make_optimizer_with_center, WarmupMultiStepLR
from loguru import logger
def _build_scheduler(cfg, optimizer, start_epoch=None):
    """Create the ``WarmupMultiStepLR`` scheduler described by ``cfg.SOLVER``.

    When *start_epoch* is given (resuming a self-trained checkpoint) it is
    forwarded as the scheduler's trailing positional argument; otherwise the
    scheduler's own default is used, as for a fresh run.
    """
    scheduler_args = (optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
                      cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD)
    if start_epoch is None:
        return WarmupMultiStepLR(*scheduler_args)
    return WarmupMultiStepLR(*scheduler_args, start_epoch)


def _checkpoint_epoch(weights_path):
    """Return the epoch number encoded at the end of a checkpoint file name.

    The epoch is the last ``_``-separated token of the file name without its
    extension, e.g. ``output/resnet50_model_120.pth`` -> ``120``.

    Raises:
        ValueError: if that token is not an integer.
    """
    stem = os.path.splitext(os.path.basename(weights_path))[0]
    token = stem.rsplit('_', 1)[-1]
    try:
        # int() instead of eval(): the token is data, never code to execute.
        return int(token)
    except ValueError:
        raise ValueError('Cannot infer the start epoch from checkpoint {!r}: the file name must end '
                         'with "_<epoch>"'.format(weights_path)) from None


def _sibling_checkpoint(weights_path, kind):
    """Return the companion checkpoint path for *kind* (e.g. ``'optimizer'``).

    Only the file name has ``'model'`` replaced by *kind*, so directories whose
    names contain ``model`` (such as ``models/``) are left intact.
    """
    head, name = os.path.split(weights_path)
    return os.path.join(head, name.replace('model', kind))


def _load_matching_state(target, path):
    """Load into *target* the checkpoint entries whose keys *target* already has.

    *target* is any object exposing ``state_dict()`` / ``load_state_dict()``
    (model, optimizer, center-loss criterion). Checkpoint keys unknown to
    *target* are dropped; keys missing from the checkpoint keep their current
    values.
    """
    state = target.state_dict()
    checkpoint = torch.load(path)
    state.update({k: v for k, v in checkpoint.items() if k in state})
    target.load_state_dict(state)


def train(cfg, args):
    """Build data, model, loss and optimizer(s) from *cfg* / *args*, then train.

    ``args.IF_WITH_CENTER`` selects the training mode:

    * false: one optimizer, ``make_loss`` and ``do_train``;
    * true: an extra center-loss criterion with its own optimizer,
      ``make_loss_with_center`` and ``do_train_with_center``.

    ``args.pretrain_choice`` selects how training starts:

    * ``'imagenet'``: start at epoch 0 with a freshly built scheduler;
    * ``'self'``: resume from the checkpoint at ``args.weights``. The start
      epoch is parsed from its file name, and the companion optimizer /
      center-loss checkpoints are found by replacing ``'model'`` in the file
      name with ``'optimizer'``, ``'center_param'`` and ``'optimizer_center'``.

    Raises:
        ValueError: if ``args.pretrain_choice`` is neither ``'imagenet'`` nor
            ``'self'``, or the resume epoch cannot be parsed from the name.
    """
    # Fail fast, before the (expensive) dataset and model are built.
    if args.pretrain_choice not in ('imagenet', 'self'):
        raise ValueError('Only support pretrain_choice for imagenet and self, but got {}'.format(
            args.pretrain_choice))
    resume = args.pretrain_choice == 'self'

    # Prepare the dataset.
    train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
    # Initialize the model.
    model = build_model(args, num_classes)

    if not args.IF_WITH_CENTER:
        print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
        optimizer = make_optimizer(cfg, model)
        loss_func = make_loss(cfg, num_classes)  # modified by gu
        if resume:
            start_epoch = _checkpoint_epoch(args.weights)
            print('Start epoch:', start_epoch)
            path_to_optimizer = _sibling_checkpoint(args.weights, 'optimizer')
            print('Path to the checkpoint of optimizer:', path_to_optimizer)
            _load_matching_state(model, args.weights)
            _load_matching_state(optimizer, path_to_optimizer)
            scheduler = _build_scheduler(cfg, optimizer, start_epoch)
        else:
            start_epoch = 0
            scheduler = _build_scheduler(cfg, optimizer)
        logger.info('ready train...')
        do_train(
            cfg,
            model,
            train_loader,
            val_loader,
            optimizer,
            scheduler,  # modify for using self trained model
            loss_func,
            num_query,
            start_epoch,  # add for using self trained model
            args
        )
        return

    print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE)
    loss_func, center_criterion = make_loss_with_center(cfg, num_classes, args)  # modified by gu
    optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion)
    if resume:
        start_epoch = _checkpoint_epoch(args.weights)
        print('Start epoch:', start_epoch)
        path_to_optimizer = _sibling_checkpoint(args.weights, 'optimizer')
        print('Path to the checkpoint of optimizer:', path_to_optimizer)
        path_to_center_param = _sibling_checkpoint(args.weights, 'center_param')
        print('Path to the checkpoint of center_param:', path_to_center_param)
        path_to_optimizer_center = _sibling_checkpoint(args.weights, 'optimizer_center')
        print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center)
        _load_matching_state(model, args.weights)
        _load_matching_state(optimizer, path_to_optimizer)
        # The center-loss parameters live in their own checkpoint, not in the model's.
        _load_matching_state(center_criterion, path_to_center_param)
        _load_matching_state(optimizer_center, path_to_optimizer_center)
        scheduler = _build_scheduler(cfg, optimizer, start_epoch)
    else:
        start_epoch = 0
        scheduler = _build_scheduler(cfg, optimizer)
    do_train_with_center(
        cfg,
        model,
        center_criterion,
        train_loader,
        val_loader,
        optimizer,
        optimizer_center,
        scheduler,  # modify for using self trained model
        loss_func,
        num_query,
        start_epoch,  # add for using self trained model
        args
    )
def main():
    """Command-line entry point: parse arguments, load the config, log the setup and train.

    The YAML file given by ``--config_file`` (if non-empty) is merged into the
    global ``cfg``, then the trailing ``opts`` are passed to
    ``cfg.merge_from_list`` and ``cfg`` is frozen. When
    ``cfg.MODEL.DEVICE == "cuda"``, ``CUDA_VISIBLE_DEVICES`` is set to
    ``cfg.MODEL.DEVICE_ID`` before :func:`train` runs.
    """
    parser = argparse.ArgumentParser(description="Yolo v5 with ReID Baseline Training")
    parser.add_argument(
        "--config_file", type=str, default=r"./configs/softmax_triplet.yml", help="path to config file"
    )
    parser.add_argument('--LAST_STRIDE', type=int, default=1, help='last stride')
    parser.add_argument('--weights', type=str, default=r'./weights/r50_ibn_2.pth')
    parser.add_argument('--neck', type=str, default='bnneck', help='If train with BNNeck, options: bnneck or no')
    parser.add_argument('--test_neck', type=str, default='after',
                        help='Which feature of BNNeck to be used for test, before or after BNNeck, '
                             'options: before or after')
    parser.add_argument('--model_name', type=str, default='resnet50_ibn_a', help='Name of backbone')
    # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model
    parser.add_argument('--pretrain_choice', type=str, default='imagenet', choices=('imagenet', 'self'),
                        help="'imagenet': initialize the backbone from ImageNet weights; "
                             "'self': initialize the whole model from a self-trained checkpoint")
    # store_true flag: present means "train with center loss", absent means without.
    parser.add_argument('--IF_WITH_CENTER', action='store_true', default=False,
                        help='If set, the training loss includes center loss. Loss with center loss '
                             'has a different optimizer configuration')
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # WORLD_SIZE is only read for logging here.
    num_gpus = int(os.environ.get("WORLD_SIZE", 1))

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        # exist_ok avoids the check-then-create race when several processes start at once.
        os.makedirs(output_dir, exist_ok=True)

    logger.info("yolov5 reid_baseline")
    logger.info("Using {} GPUS".format(num_gpus))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r', encoding='utf-8') as cf:
            config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.DEVICE == "cuda":
        os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID  # new add by gu
    cudnn.benchmark = True
    train(cfg, args)
if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    main()