-
Notifications
You must be signed in to change notification settings - Fork 519
/
test.py
84 lines (74 loc) · 3.57 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#coding=utf-8
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import mxnet as mx
import numpy as np
import cv2, random
import os
from io import BytesIO
from train import gen_rand, gen_sample
# Class-index -> character lookup table used when decoding network output.
# Layout: indices 0-30 province abbreviations, 31-40 digits,
# 41-64 Latin capitals (the letters I and O are not present).
chars = [
    # province abbreviations (indices 0-30)
    "京", "沪", "津", "渝", "冀", "晋", "蒙", "辽",
    "吉", "黑", "苏", "浙", "皖", "闽", "赣", "鲁",
    "豫", "鄂", "湘", "粤", "桂", "琼", "川", "贵",
    "云", "藏", "陕", "甘", "青", "宁", "新",
    # digits (indices 31-40)
    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
    # capital letters (indices 41-64)
    "A", "B", "C", "D", "E", "F", "G", "H",
    "J", "K", "L", "M", "N", "P", "Q", "R",
    "S", "T", "U", "V", "W", "X", "Y", "Z",
]
def getnet():
    """Build the plate-recognition CNN as an MXNet symbol.

    Two convolution -> pooling -> ReLU stages (5x5 kernels, 32 filters;
    max pooling first, average pooling second) are flattened into a
    120-unit fully connected layer. That layer fans out into seven
    independent 65-way fully connected heads, which are concatenated
    along axis 0 and fed to a single SoftmaxOutput named "softmax".

    Returns:
        The SoftmaxOutput symbol for the whole network.
    """
    # NOTE(review): layers are unnamed, so MXNet assigns names itself;
    # presumably those names must match the keys stored in the checkpoint,
    # so keep the creation order below stable - confirm before reordering.
    net = mx.symbol.Variable('data')
    # Declared here but not passed to any operator explicitly.
    _label = mx.symbol.Variable('softmax_label')

    # Feature extractor: the two stages differ only in their pooling type.
    for pooling in ("max", "avg"):
        net = mx.symbol.Convolution(data=net, kernel=(5, 5), num_filter=32)
        net = mx.symbol.Pooling(data=net, pool_type=pooling,
                                kernel=(2, 2), stride=(1, 1))
        net = mx.symbol.Activation(data=net, act_type="relu")

    flat = mx.symbol.Flatten(data=net)
    hidden = mx.symbol.FullyConnected(data=flat, num_hidden=120)

    # Seven 65-way heads, all reading the same hidden layer.
    heads = [mx.symbol.FullyConnected(data=hidden, num_hidden=65)
             for _ in range(7)]
    merged = mx.symbol.Concat(*heads, dim=0)
    return mx.symbol.SoftmaxOutput(data=merged, name="softmax")
def _decode_plate(probs):
    """Convert per-position class scores into the printed plate string.

    Each row of ``probs`` is one character position. The argmax is taken
    over a restricted slice of the ``chars`` table depending on position:
    row 0 over indices 0-30, row 1 over 41-64, every later row over 31-64.

    Returns:
        The decoded characters, each followed by a single space.
    """
    pieces = []
    for position, scores in enumerate(probs):
        if position == 0:
            lo, hi = 0, 31
        elif position == 1:
            lo, hi = 41, 65
        else:
            lo, hi = 31, 65
        # argmax is relative to the slice, so shift it back by ``lo``.
        pieces.append(chars[np.argmax(scores[lo:hi]) + lo] + " ")
    return "".join(pieces)


def TestRecognizeOne(img):
    """Recognize one licence-plate image and print the predicted characters.

    The image is resized to 120x30, shown in an OpenCV window, and run
    through the network built by ``getnet()`` on the CPU, using parameters
    loaded from checkpoint prefix "cnn-ocr", epoch 1. The function then
    blocks in ``cv2.waitKey(0)`` until a key is pressed.

    Args:
        img: Image array such as the result of ``cv2.imread``. The network
            is bound to a (1, 3, 30, 120) input, so a 3-channel image is
            expected.

    Returns:
        The predicted string, each character followed by a space.

    Raises:
        ValueError: If ``img`` is None (e.g. ``cv2.imread`` could not read
            the file).
    """
    if img is None:
        raise ValueError("img is None; the image could not be loaded")

    img = cv2.resize(img, (120, 30))
    cv2.imshow("img", img)
    print(img.shape)
    # Move the channel axis first: (height, width, channel) ->
    # (channel, height, width), matching the bound data shape below.
    img = np.transpose(img, (2, 0, 1))
    print(img.shape)

    batch_size = 1
    _, arg_params, _ = mx.model.load_checkpoint("cnn-ocr", 1)
    sym = getnet()
    executor = sym.simple_bind(ctx=mx.cpu(), data=(batch_size, 3, 30, 120))

    # Copy every checkpoint parameter the executor knows about; executor
    # arguments that have no checkpoint entry are left as bound.
    for key, target in executor.arg_dict.items():
        if key in arg_params:
            arg_params[key].copyto(target)

    # Pure inference: no training-mode forward pass is needed here.
    executor.forward(is_train=False, data=mx.nd.array([img]))
    probs = executor.outputs[0].asnumpy()

    line = _decode_plate(probs)
    print('predicted: ' + line)
    cv2.waitKey(0)
    return line
if __name__ == '__main__':
    # Script entry point: load the sample plate image and recognize it.
    sample_plate = cv2.imread("./plate/01.jpg")
    TestRecognizeOne(sample_plate)