forked from FLming/CRNN.tf2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel.py
44 lines (34 loc) · 1.66 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from tensorflow import keras
from tensorflow.keras import layers
def vgg_style(input_tensor):
    """
    The original feature extraction structure from CRNN paper.

    Stacks seven convolutions and four max-pooling layers on top of
    ``input_tensor`` and returns the resulting 512-filter feature tensor.
    Related paper: https://ieeexplore.ieee.org/abstract/document/7801919
    """

    def conv_relu(tensor, filters):
        # 3x3 'same' convolution with the ReLU applied by the layer itself.
        conv = layers.Conv2D(
            filters, 3, padding='same', activation='relu')
        return conv(tensor)

    def conv_bn_relu(tensor, filters, kernel_size=3, padding='same'):
        # Convolution -> batch normalization -> ReLU. The convolution bias
        # is disabled because it is followed directly by BatchNormalization.
        tensor = layers.Conv2D(
            filters, kernel_size, padding=padding, use_bias=False)(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.Activation('relu')(tensor)

    def pool(tensor, strides=None):
        # 2x2 'same' max pooling; strides=None makes Keras use pool_size
        # as the stride, i.e. a stride of 2 along both spatial axes.
        return layers.MaxPool2D(
            pool_size=2, strides=strides, padding='same')(tensor)

    features = conv_relu(input_tensor, 64)
    features = pool(features)

    features = conv_relu(features, 128)
    features = pool(features)

    features = conv_bn_relu(features, 256)
    features = conv_relu(features, 256)
    # Stride 2 on the first spatial axis only; stride 1 on the second.
    features = pool(features, strides=(2, 1))

    features = conv_bn_relu(features, 512)
    features = conv_relu(features, 512)
    features = pool(features, strides=(2, 1))

    # Final 2x2 convolution without padding ('valid' is the Conv2D default).
    return conv_bn_relu(features, 512, kernel_size=2, padding='valid')
def build_model(num_classes, image_width=None, channels=1):
    """Build the CNN-RNN (CRNN) model.

    Args:
        num_classes: Number of units of the final Dense layer, i.e. the
            size of the last axis of the model output.
        image_width: Width used in the input shape; ``None`` leaves that
            dimension unspecified.
        channels: Number of channels used in the input shape.

    Returns:
        A ``keras.Model`` named 'CRNN' whose input has shape
        ``(32, image_width, channels)`` (batch axis excluded) and whose
        output is produced by a Dense layer with no activation.
    """
    image = keras.Input(shape=(32, image_width, channels))
    feature_map = vgg_style(image)

    # Rearrange the CNN output into a sequence of 512-dimensional vectors
    # so it can be consumed by the recurrent layers.
    sequence = layers.Reshape(target_shape=(-1, 512))(feature_map)

    # Two stacked bidirectional LSTMs, each returning the full sequence.
    for _ in range(2):
        recurrent = layers.LSTM(units=256, return_sequences=True)
        sequence = layers.Bidirectional(recurrent)(sequence)

    dense_out = layers.Dense(units=num_classes)(sequence)
    return keras.Model(inputs=image, outputs=dense_out, name='CRNN')