-
Notifications
You must be signed in to change notification settings - Fork 120
/
train_model.py
179 lines (129 loc) · 6.16 KB
/
train_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
'''
This module describes the contrastive predictive coding model from DeepMind:
Oord, Aaron van den, Yazhe Li, and Oriol Vinyals.
"Representation Learning with Contrastive Predictive Coding."
arXiv preprint arXiv:1807.03748 (2018).
'''
from data_utils import SortedNumberGenerator
from os.path import join, basename, dirname, exists
import keras
from keras import backend as K
def network_encoder(x, code_size):
    ''' Map a single image tensor to a ``code_size``-dimensional embedding.

    The trunk is four identical stages of a 3x3, stride-2, 64-filter Conv2D
    followed by BatchNormalization and LeakyReLU. The result is flattened and
    passed through a 256-unit Dense + BatchNormalization + LeakyReLU layer,
    then a final linear Dense projection named ``'encoder_embedding'``.

    Args:
        x: Keras tensor holding the input image batch.
        code_size: number of units in the output embedding.

    Returns:
        The embedding tensor produced by the final Dense layer.
    '''
    h = x
    # Convolutional trunk: four identical downsampling stages.
    for _stage in range(4):
        h = keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear')(h)
        h = keras.layers.BatchNormalization()(h)
        h = keras.layers.LeakyReLU()(h)

    # Dense head: bottleneck, then a linear projection to the code size.
    h = keras.layers.Flatten()(h)
    h = keras.layers.Dense(units=256, activation='linear')(h)
    h = keras.layers.BatchNormalization()(h)
    h = keras.layers.LeakyReLU()(h)
    embedding_layer = keras.layers.Dense(units=code_size, activation='linear', name='encoder_embedding')
    return embedding_layer(h)
def network_autoregressive(x):
    ''' Summarise a sequence of embeddings into a single context vector.

    One GRU layer (256 units, named ``'ar_context'``) reads the sequence.
    Because ``return_sequences=False``, only the final output step is
    returned.

    Args:
        x: Keras tensor holding the sequence of embeddings
           (presumably (batch, time, features) -- TODO confirm).

    Returns:
        The GRU's final output tensor.
    '''
    context_gru = keras.layers.GRU(units=256, return_sequences=False, name='ar_context')
    return context_gru(x)
def network_prediction(context, code_size, predict_terms):
    ''' Map the context vector to ``predict_terms`` predicted embeddings.

    Each predicted step gets its own independent linear Dense head
    (``code_size`` units, named ``'z_t_<i>'``). The head outputs are stacked
    along a new axis 1, so the result always has one entry per predicted step
    on that axis.

    Args:
        context: Keras tensor with the context vector.
        code_size: dimensionality of each predicted embedding.
        predict_terms: number of future steps to predict; must be >= 1.

    Returns:
        A single tensor with the predictions stacked on axis 1.

    Raises:
        ValueError: if ``predict_terms`` is less than 1. Otherwise an empty list
            would reach ``K.stack`` and fail with an opaque backend error.
    '''
    if predict_terms < 1:
        raise ValueError('predict_terms must be >= 1, got {0!r}'.format(predict_terms))

    # One independent linear head per predicted step.
    outputs = [
        keras.layers.Dense(units=code_size, activation="linear", name='z_t_{i}'.format(i=i))(context)
        for i in range(predict_terms)
    ]

    if len(outputs) == 1:
        # A single head is not stacked. Add the step axis explicitly so the
        # output rank matches the multi-step case.
        return keras.layers.Lambda(lambda x: K.expand_dims(x, axis=1))(outputs[0])
    return keras.layers.Lambda(lambda x: K.stack(x, axis=1))(outputs)
class CPCLayer(keras.layers.Layer):
    ''' Score predicted embeddings against encoded targets as a probability.

    Takes a pair ``[preds, y_encoded]`` and returns a batch of values in (0, 1)
    with shape (batch, 1):
      1. element-wise product of the two tensors, averaged over the last axis;
      2. averaged again over the next remaining axis (the temporal/step axis in
         network_cpc), keeping that dimension;
      3. squashed with a sigmoid so a probability-based loss such as
         binary cross-entropy can consume it.
    In network_cpc both inputs are (batch, predict_terms, code_size) tensors.
    '''

    def __init__(self, **kwargs):
        super(CPCLayer, self).__init__(**kwargs)

    def call(self, inputs):
        preds, y_encoded = inputs
        # Mean element-wise product over the embedding dimension (a scaled dot product).
        per_step_similarity = K.mean(y_encoded * preds, axis=-1)
        # Average across the predicted steps, keeping a trailing axis of size 1.
        sequence_similarity = K.mean(per_step_similarity, axis=-1, keepdims=True)
        # The loss used in network_cpc expects probabilities, not logits.
        return K.sigmoid(sequence_similarity)

    def compute_output_shape(self, input_shape):
        batch_dim = input_shape[0][0]
        return (batch_dim, 1)
def network_cpc(image_shape, terms, predict_terms, code_size, learning_rate):
    ''' Build and compile the full CPC network.

    The model has two image-sequence inputs:
      * ``x``: ``terms`` context frames. Each frame is encoded, the sequence is
        summarised by the GRU, and ``predict_terms`` future embeddings are
        predicted from the resulting context.
      * ``y``: ``predict_terms`` candidate frames, encoded with the SAME encoder
        sub-model, so the weights are shared.
    A CPCLayer turns the predictions and encoded candidates into one
    probability per sample. The model is compiled with Adam, binary
    cross-entropy and binary accuracy. Both the encoder and the full model
    print a summary.

    Args:
        image_shape: per-frame shape passed to the encoder Input. Its first
            three entries define the frame part of the sequence inputs
            (presumably (height, width, channels) -- TODO confirm).
        terms: number of context frames in ``x``.
        predict_terms: number of frames in ``y`` and number of predictions.
        code_size: embedding dimensionality.
        learning_rate: Adam learning rate.

    Returns:
        The compiled ``keras.models.Model`` with inputs ``[x, y]``.
    '''
    # Force learning phase 1 (training); see
    # https://stackoverflow.com/questions/42969779/keras-error-you-must-feed-a-value-for-placeholder-tensor-bidirectional-1-keras
    K.set_learning_phase(1)

    # Per-frame encoder, built once and shared by both sequence branches.
    frame_input = keras.layers.Input(image_shape)
    encoder_model = keras.models.Model(frame_input, network_encoder(frame_input, code_size), name='encoder')
    encoder_model.summary()

    frame_shape = (image_shape[0], image_shape[1], image_shape[2])

    # Context branch: encode past frames, summarise them, predict future embeddings.
    x_input = keras.layers.Input((terms,) + frame_shape)
    context = network_autoregressive(keras.layers.TimeDistributed(encoder_model)(x_input))
    preds = network_prediction(context, code_size, predict_terms)

    # Target branch: encode the candidate future frames with the same encoder.
    y_input = keras.layers.Input((predict_terms,) + frame_shape)
    y_encoded = keras.layers.TimeDistributed(encoder_model)(y_input)

    probabilities = CPCLayer()([preds, y_encoded])
    cpc_model = keras.models.Model(inputs=[x_input, y_input], outputs=probabilities)

    # NOTE(review): `lr` is the older Keras spelling; newer releases expect
    # `learning_rate`. Confirm against the pinned Keras version.
    cpc_model.compile(
        optimizer=keras.optimizers.Adam(lr=learning_rate),
        loss='binary_crossentropy',
        metrics=['binary_accuracy'],
    )
    cpc_model.summary()
    return cpc_model
def train_model(epochs, batch_size, output_dir, code_size, lr=1e-4, terms=4, predict_terms=4, image_size=28, color=False):
    ''' Train a CPC model on sorted-number sequences and save it to ``output_dir``.

    Steps:
      1. Create the output directory.
      2. Build the 'train' and 'valid' SortedNumberGenerator instances, each
         requesting ``batch_size // 2`` positive samples per batch.
      3. Build the network with network_cpc.
      4. Train it with ``fit_generator``, using ReduceLROnPlateau on
         ``val_loss``.
      5. Write two HDF5 files: ``cpc.h5`` (full model) and ``encoder.h5``
         (the shared per-frame encoder only).

    Args:
        epochs: number of training epochs.
        batch_size: samples per batch; half are requested as positives.
        output_dir: directory for the saved models. It is created if missing.
        code_size: embedding dimensionality.
        lr: initial Adam learning rate.
        terms: number of context frames per sample.
        predict_terms: number of future frames predicted per sample.
        image_size: side length of the square frames.
        color: forwarded to SortedNumberGenerator.

    Note:
        The scheduler's ``min_lr`` is fixed at 1e-4. With the default
        ``lr=1e-4``, the learning rate is therefore never reduced.
        Loading ``cpc.h5`` later requires
        ``custom_objects={'CPCLayer': CPCLayer}``.
    '''
    import os  # for makedirs; path helpers come from the module-level os.path import

    # Create the output directory up front, so that a missing path does not
    # make model.save fail after the entire training run.
    if output_dir and not exists(output_dir):
        os.makedirs(output_dir)

    # Prepare data
    train_data = SortedNumberGenerator(batch_size=batch_size, subset='train', terms=terms,
                                       positive_samples=batch_size // 2, predict_terms=predict_terms,
                                       image_size=image_size, color=color, rescale=True)
    validation_data = SortedNumberGenerator(batch_size=batch_size, subset='valid', terms=terms,
                                            positive_samples=batch_size // 2, predict_terms=predict_terms,
                                            image_size=image_size, color=color, rescale=True)

    # NOTE(review): the channel count is fixed at 3 regardless of `color`. This
    # assumes SortedNumberGenerator always yields 3-channel frames -- confirm.
    model = network_cpc(image_shape=(image_size, image_size, 3), terms=terms, predict_terms=predict_terms,
                        code_size=code_size, learning_rate=lr)

    # Use float division: under Python 2, `1/3` would evaluate to 0.
    callbacks = [keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=1.0 / 3, patience=2, min_lr=1e-4)]

    model.fit_generator(
        generator=train_data,
        steps_per_epoch=len(train_data),
        validation_data=validation_data,
        validation_steps=len(validation_data),
        epochs=epochs,
        verbose=1,
        callbacks=callbacks
    )

    # Save the full model. Loading it requires custom_objects={'CPCLayer': CPCLayer}.
    model.save(join(output_dir, 'cpc.h5'))

    # Save the encoder alone. Both TimeDistributed wrappers built by
    # network_cpc wrap the same encoder sub-model, so the first one found
    # suffices. This avoids relying on a fixed position in `model.layers`.
    encoder = None
    for layer in model.layers:
        if isinstance(layer, keras.layers.TimeDistributed):
            encoder = layer.layer
            break
    if encoder is None:
        raise RuntimeError('Could not locate the TimeDistributed encoder wrapper in the CPC model')
    encoder.save(join(output_dir, 'encoder.h5'))
if __name__ == "__main__":
    # Script entry point: train a 64x64 colour model and write it to models/64x64.
    run_config = dict(
        epochs=10,
        batch_size=32,
        output_dir='models/64x64',
        code_size=128,
        lr=1e-3,
        terms=4,
        predict_terms=4,
        image_size=64,
        color=True,
    )
    train_model(**run_config)