-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdetect.py
155 lines (129 loc) · 6.06 KB
/
detect.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python
import argparse
import time
from sys import platform
from models import *
from utils.datasets import *
from utils.utils import *
def detect(cfg,
           data,
           weights,
           images='data/samples',  # input folder
           output='output',  # output folder
           fourcc='mp4v',  # video codec
           img_size=416,
           conf_thres=0.5,
           nms_thres=0.5,
           save_txt=False,
           save_images=True):
    """Run the Darknet model over a source and save/show annotated results.

    The output folder is deleted and recreated on every call. When the
    externally defined ``ONNX_EXPORT`` flag is truthy, the model is exported to
    'weights/export.onnx' and the function returns without running inference.

    NOTE: this function reads the module-level ``opt`` namespace (``opt.half``
    and ``opt.webcam``) and writes ``opt.half`` back, so ``opt`` must exist
    before it is called.

    Args:
        cfg: Model configuration path handed to ``Darknet``.
        data: Data-config path; its 'names' entry is used to load class names.
        weights: Weights path. A '.pt' suffix is loaded as a PyTorch
            checkpoint (key 'model'); anything else goes through
            ``load_darknet_weights``.
        images: Input location handed to ``LoadImages`` (ignored for webcam).
        output: Folder that receives annotated images/videos and label files.
        fourcc: Four-character codec code used for output videos.
        img_size: Inference size passed to the model and the dataloader.
        conf_thres: Confidence threshold passed to ``non_max_suppression``.
        nms_thres: NMS threshold passed to ``non_max_suppression``.
        save_txt: If True, append one line per detection to
            '<save_path>.txt' (box coordinates, class, confidence).
        save_images: If True, write annotated images/videos to ``output``.
            Forced to False when ``opt.webcam`` is set.

    Returns:
        None.
    """
    # Initialize
    device = torch_utils.select_device(force_cpu=ONNX_EXPORT)
    torch.backends.cudnn.benchmark = False  # set False for reproducible results
    if os.path.exists(output):
        shutil.rmtree(output)  # delete output folder
    os.makedirs(output)  # make new output folder

    # Initialize model
    if ONNX_EXPORT:
        s = (320, 192)  # (320, 192) or (416, 256) or (608, 352) onnx model image size (height, width)
        model = Darknet(cfg, s)
    else:
        model = Darknet(cfg, img_size)

    # Load weights
    if weights.endswith('.pt'):  # pytorch format
        model.load_state_dict(torch.load(weights, map_location=device)['model'])
    else:  # darknet format
        _ = load_darknet_weights(model, weights)

    # Fuse Conv2d + BatchNorm2d layers
    # model.fuse()

    # Eval mode
    model.to(device).eval()

    # Export mode
    if ONNX_EXPORT:
        img = torch.zeros((1, 3, s[0], s[1]))
        torch.onnx.export(model, img, 'weights/export.onnx', verbose=True)
        return

    # Half precision
    opt.half = opt.half and device.type != 'cpu'  # half precision only supported on CUDA
    if opt.half:
        model.half()

    # Set Dataloader
    vid_path, vid_writer = None, None
    if opt.webcam:
        save_images = False
        dataloader = LoadWebcam(img_size=img_size, half=opt.half)
    else:
        dataloader = LoadImages(images, img_size=img_size, half=opt.half)

    # Get classes and colors
    classes = load_classes(parse_data_cfg(data)['names'])
    colors = [[random.randint(0, 255) for _ in range(3)] for _ in classes]

    # Run inference
    t0 = time.time()
    save_path = None  # stays None when the dataloader yields nothing
    try:
        for path, img, im0, vid_cap in dataloader:
            t = time.time()
            save_path = str(Path(output) / Path(path).name)

            # Get detections
            img = torch.from_numpy(img).unsqueeze(0).to(device)
            pred, _ = model(img)
            det = non_max_suppression(pred.float(), conf_thres, nms_thres)[0]

            if det is not None and len(det) > 0:
                # Rescale boxes from the inference size to the true image size
                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()

                # Print results to screen
                print('%gx%g ' % img.shape[2:], end='')  # print image size
                for c in det[:, -1].unique():
                    n = (det[:, -1] == c).sum()
                    print('%g %ss' % (n, classes[int(c)]), end=', ')

                # Draw bounding boxes and labels of detections
                txt_lines = []
                for *xyxy, conf, cls_conf, cls in det:
                    if save_txt:
                        txt_lines.append(('%g ' * 6 + '\n') % (*xyxy, cls, conf))

                    # Add bbox to the image
                    label = '%s %.2f' % (classes[int(cls)], conf)
                    plot_one_box(xyxy, im0, label=label, color=colors[int(cls)])

                if save_txt:
                    # Open the label file once per image instead of once per
                    # detection; append mode keeps frames of one video together.
                    with open(save_path + '.txt', 'a') as file:
                        file.writelines(txt_lines)

            print('Done. (%.3fs)' % (time.time() - t))

            if opt.webcam:  # Show live webcam
                cv2.imshow(weights, im0)

            if save_images:  # Save image with detections
                if dataloader.mode == 'images':
                    cv2.imwrite(save_path, im0)
                else:
                    if vid_path != save_path:  # new video
                        vid_path = save_path
                        if isinstance(vid_writer, cv2.VideoWriter):
                            vid_writer.release()  # release previous video writer

                        fps = vid_cap.get(cv2.CAP_PROP_FPS)
                        width = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        height = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (width, height))
                    vid_writer.write(im0)
    finally:
        # The writer of the last video was previously never released, which can
        # leave the output file unfinalized; release it even if inference fails.
        if isinstance(vid_writer, cv2.VideoWriter):
            vid_writer.release()

    if save_images:
        print('Results saved to %s' % (os.getcwd() + os.sep + output))
        if platform == 'darwin':  # macos
            import subprocess

            # Argument list (no shell) so paths containing spaces or shell
            # metacharacters are passed through verbatim.
            targets = [output] if save_path is None else [output, save_path]
            subprocess.run(['open'] + targets, check=False)

    print('Done. (%.3fs)' % (time.time() - t0))
if __name__ == '__main__':
    # Command-line options, registered in display order as (flag, add_argument kwargs).
    cli_options = (
        ('--cfg', dict(type=str, default='cfg/yolov3.cfg', help='cfg file path')),
        ('--data', dict(type=str, default='data/ball.data', help='coco.data file path')),
        ('--weights', dict(type=str, default='weights/ball.pt', help='path to weights file')),
        ('--images', dict(type=str, default='data/samples', help='path to images')),
        ('--img-size', dict(type=int, default=416, help='inference size (pixels)')),
        ('--conf-thres', dict(type=float, default=0.3, help='object confidence threshold')),
        ('--nms-thres', dict(type=float, default=0.5, help='iou threshold for non-maximum suppression')),
        ('--fourcc', dict(type=str, default='mp4v', help='fourcc output video codec (verify ffmpeg support)')),
        ('--output', dict(type=str, default='output', help='specifies the output path for images and videos')),
        ('--half', dict(action='store_true', help='half precision FP16 inference')),
        ('--webcam', dict(action='store_true', help='use webcam')),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)

    # `opt` must keep this module-level name: detect() reads opt.half / opt.webcam.
    opt = parser.parse_args()
    print(opt)

    # Inference only, so gradient tracking is switched off for the whole run.
    with torch.no_grad():
        detect(
            cfg=opt.cfg,
            data=opt.data,
            weights=opt.weights,
            images=opt.images,
            output=opt.output,
            fourcc=opt.fourcc,
            img_size=opt.img_size,
            conf_thres=opt.conf_thres,
            nms_thres=opt.nms_thres,
        )