forked from PaddlePaddle/PaddleClas
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_json_config.py
138 lines (128 loc) · 4.55 KB
/
generate_json_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import argparse
import json
import os
import yaml
def parse_args(argv=None):
    """Parse the command-line options used to build the lite Shitu JSON config.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, so existing callers behave
            exactly as before; passing a list allows programmatic use.

    Returns:
        argparse.Namespace: The parsed options. ``fpn_stride`` is a list of
        ints, given on the command line as e.g. ``--fpn-stride 8 16 32 64``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--yaml_path', type=str, default='../configs/inference_drink.yaml')
    parser.add_argument(
        '--img_dir',
        type=str,
        default=None,
        help='The dir path for inference images')
    parser.add_argument(
        '--img_path',
        type=str,
        default=None,
        help='The image path for inference')
    parser.add_argument(
        '--det_model_path',
        type=str,
        default='./det.nb',
        help="The model path for mainbody detection")
    parser.add_argument(
        '--rec_model_path',
        type=str,
        default='./rec.nb',
        help="The rec model path")
    parser.add_argument(
        '--rec_label_path',
        type=str,
        default='./label.txt',
        help='The rec model label')
    parser.add_argument(
        '--arch',
        type=str,
        default='PicoDet',
        help='The model structure for detection model')
    # `type=list` would split a CLI string into single characters; parse one
    # or more integers instead. The default is unchanged.
    parser.add_argument(
        '--fpn-stride',
        type=int,
        nargs='+',
        default=[8, 16, 32, 64],
        help="The fpn strides for detection model")
    parser.add_argument(
        '--keep_top_k',
        type=int,
        default=100,
        help='The params for nms(postprocess for detection)')
    parser.add_argument(
        '--nms-name',
        type=str,
        default='MultiClassNMS',
        help='The nms name for postprocess of detection model')
    parser.add_argument(
        '--nms_threshold',
        type=float,
        default=0.5,
        help='The nms nms_threshold for detection postprocess')
    parser.add_argument(
        '--nms_top_k',
        type=int,
        default=1000,
        help='The nms_top_k in postprocess of detection model')
    parser.add_argument(
        '--score_threshold',
        type=float,
        default=0.3,
        help='The score_threshold for postprocess of detection')
    return parser.parse_args(argv)
def _convert_transform_ops(ops, keep_empty):
    """Flatten YAML transform ops into the list-of-dicts layout of the JSON.

    Each YAML entry is a mapping ``{OpName: params}``; only its first key is
    used. It becomes a new dict ``{**params, "type": OpName}``, so the parsed
    YAML is not mutated.

    Args:
        ops: The ``transform_ops`` list read from the YAML config.
        keep_empty: Handling for an op whose ``params`` is null (declared with
            no parameters): if true it is emitted as ``{"type": OpName}``,
            otherwise it is skipped.

    Returns:
        list[dict]: The converted ops, in their original order.
    """
    converted = []
    for op in ops:
        op_name = next(iter(op))
        params = op[op_name]
        if params is None:
            if not keep_empty:
                continue
            params = {}
        converted.append({**params, "type": op_name})
    return converted


def _build_global_config(args, config_yaml):
    """Build the ``Global`` JSON section from the YAML config and CLI args."""
    yaml_global = config_yaml["Global"]
    global_config = {}
    if args.img_dir is not None:
        # An image directory takes precedence; infer_imgs is written as null.
        global_config["infer_imgs"] = None
        global_config["infer_imgs_dir"] = args.img_dir
    else:
        # Prefer --img_path, otherwise fall back to the YAML's infer_imgs.
        global_config["infer_imgs"] = (
            args.img_path if args.img_path else yaml_global["infer_imgs"])
    global_config.update({
        "batch_size": yaml_global["batch_size"],
        # Thread count from the YAML, capped at 4.
        "cpu_num_threads": min(yaml_global["cpu_num_threads"], 4),
        "image_shape": yaml_global["image_shape"],
        "det_model_path": args.det_model_path,
        "rec_model_path": args.rec_model_path,
        "rec_label_path": args.rec_label_path,
        # NOTE(review): 'labe_list' and 'rec_nms_thresold' (sic) are the YAML
        # keys this script reads; presumably the config files spell them this
        # way — confirm before renaming. The JSON output keys stay as-is.
        "label_list": yaml_global["labe_list"],
        "rec_nms_thresold": yaml_global["rec_nms_thresold"],
        "max_det_results": yaml_global["max_det_results"],
        "det_fpn_stride": args.fpn_stride,
        "det_arch": args.arch,
        "return_k": config_yaml["IndexProcess"]["return_k"],
    })
    return global_config


def main():
    """Convert the YAML inference config into ``shitu_config.json``.

    Reads the YAML file given by ``--yaml_path``, combines it with the
    command-line options from ``parse_args``, and writes the result as
    indented JSON to ``shitu_config.json`` in the current working directory.
    """
    args = parse_args()
    # Binary mode lets PyYAML detect the file encoding itself; the context
    # manager guarantees the handle is closed.
    with open(args.yaml_path, 'rb') as fd:
        config_yaml = yaml.safe_load(fd)

    config_json = {
        "Global": _build_global_config(args, config_yaml),
        "DetPreProcess": {
            # A det op declared without parameters is kept as {"type": name}.
            "transform_ops": _convert_transform_ops(
                config_yaml["DetPreProcess"]["transform_ops"],
                keep_empty=True),
        },
        "DetPostProcess": {
            "keep_top_k": args.keep_top_k,
            "name": args.nms_name,
            "nms_threshold": args.nms_threshold,
            "nms_top_k": args.nms_top_k,
            "score_threshold": args.score_threshold,
        },
        "RecPreProcess": {
            # Rec ops declared without parameters are dropped.
            "transform_ops": _convert_transform_ops(
                config_yaml["RecPreProcess"]["transform_ops"],
                keep_empty=False),
        },
    }
    with open('shitu_config.json', 'w') as fd:
        json.dump(config_json, fd, indent=4)
if __name__ == "__main__":
    # Run the converter only when executed as a script, not on import.
    main()