-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathvae.py
122 lines (111 loc) · 5.27 KB
/
vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Input, Lambda, Flatten, Reshape, Conv2D
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.optimizers import Adamax
from normalizing_flows.flows import Flow
from normalizing_flows.layers import GatedConv2D, GatedConv2DTranspose, FlowLayer
class GatedConvVAE(tf.Module):
    """
    Gated, convolutional variational autoencoder with support for normalizing flows.

    The model is assembled from two Keras sub-models:

    * ``encoder``: maps a ``(img_wt, img_ht, 1)`` input to ``z_mu`` and
      ``z_log_var`` (plus a flow parameter vector when ``flow`` is given).
    * ``decoder``: passes the encoder outputs through a ``FlowLayer`` and
      decodes the last latent it returns back to a single-channel image.

    ``model`` chains both and is compiled with Adamax
    (``learning_rate=1.0E-4``, ``clipnorm=1.``).

    Args:
        img_wt: size of the first spatial axis of the input.
        img_ht: size of the second spatial axis of the input.
        flow: optional normalizing flow; ``None`` disables the extra
            flow-parameter head of the encoder.
        hidden_units: base filter count; the i-th (1-based) down/upsampling
            stage uses ``hidden_units * i`` filters.
        z_size: width of the ``z_mu`` / ``z_log_var`` heads.
        encoder_strides: one stride per encoder downsampling stage.
        decoder_strides: one stride per decoder upsampling stage.
        callbacks: extra Keras callbacks, stored in ``self.callbacks`` after
            the internal beta-update callback.
        metrics: metrics forwarded to ``Model.compile``.
        output_activation: activation of the final 1x1 convolution.
        loss: loss forwarded to ``Model.compile``.
        beta_update_fn: callable ``(epoch, beta)`` invoked at the start of
            every epoch by the beta-update callback.
    """

    def __init__(self, img_wt, img_ht, flow: Flow = None, hidden_units=32, z_size=64,
                 encoder_strides=(2, 2), decoder_strides=(2, 2),
                 callbacks=None, metrics=None, output_activation='sigmoid',
                 loss='binary_crossentropy', beta_update_fn=None):
        super().__init__()
        if beta_update_fn is None:
            def beta_update_fn(i, beta):
                return 1.0E-2*i
        # Copy the sequence arguments so instances never share (or mutate) the
        # caller's lists or the function defaults.
        encoder_strides = list(encoder_strides)
        decoder_strides = list(decoder_strides)
        callbacks = [] if callbacks is None else list(callbacks)
        metrics = [] if metrics is None else list(metrics)
        self.flow = flow
        self.hidden_units = hidden_units
        self.z_size = z_size
        self.num_downsamples = len(encoder_strides)
        self.num_upsamples = len(decoder_strides)
        self.encoder_strides = encoder_strides
        self.decoder_strides = decoder_strides
        self.output_activation = output_activation
        self.encoder = self._create_encoder(img_wt, img_ht)
        self.decoder, self.flow_layer = self._create_decoder(img_wt, img_ht)
        # NOTE(review): the value returned by beta_update_fn is discarded here,
        # so the default schedule has no visible effect unless the function
        # itself mutates `beta` in place. Presumably the result was meant to be
        # assigned to flow_layer.beta -- confirm against FlowLayer before
        # changing, as that would alter training behaviour.
        beta_update = LambdaCallback(
            on_epoch_begin=lambda i, _: beta_update_fn(i, self.flow_layer.beta))
        self.callbacks = [beta_update] + callbacks
        decoder_output = self.decoder(self.encoder(self.encoder.inputs))
        # Only the reconstruction (first decoder output) is exposed for training.
        self.model = Model(inputs=self.encoder.inputs, outputs=decoder_output[0])
        self.model.compile(loss=loss,
                           optimizer=Adamax(learning_rate=1.0E-4, clipnorm=1.),
                           metrics=metrics)

    def fit(self, *args, **kwargs):
        """
        Passthrough to tf.keras.Model::fit

        ``self.callbacks`` is not injected automatically; pass
        ``callbacks=vae.callbacks`` explicitly to use it.
        """
        return self.model.fit(*args, **kwargs)

    def predict(self, *args, **kwargs):
        """
        Passthrough to tf.keras.Model::predict
        """
        return self.model.predict(*args, **kwargs)

    def evaluate(self, *args, **kwargs):
        """
        Passthrough to tf.keras.Model::evaluate
        """
        return self.model.evaluate(*args, **kwargs)

    def sample(self, x, n=1):
        """
        Sample from the conditional distribution Z ~ P(z|x)

        Every input in the batch is repeated ``n`` times along the batch axis
        (the copies of one input are adjacent) before being encoded and decoded.

        Returns a tuple (x', zs) where x' is the reconstructed input and
        zs = [z_0, z_1, ... , z_k] where k is the number of flows.
        """
        input_shape = tf.shape(x)
        # add sample dim
        x = tf.expand_dims(x, axis=1)
        # broadcast according to number of samples
        x = tf.broadcast_to(x, (input_shape[0], n, *input_shape[1:]))
        # fold sample dim back into batch axis
        x = tf.reshape(x, (input_shape[0]*n, *input_shape[1:]))
        # encode/decode inputs and retrieve samples
        if self.flow is not None:
            z_mu, z_log_var, params = self.encoder.predict(x)
            outputs = self.decoder.predict([z_mu, z_log_var, params])
        else:
            z_mu, z_log_var = self.encoder.predict(x)
            outputs = self.decoder.predict([z_mu, z_log_var])
        # return x', (z_0, ..., z_k)
        return outputs[0], outputs[1:]

    def _conv_downsample(self, f, strides, x):
        """Apply a linear gated conv followed by a strided gated conv, both with ``f`` filters."""
        g = GatedConv2D(f, 3, activation='linear')
        g_downsample = GatedConv2D(f, 3, strides=strides)
        return g_downsample(g(x))

    def _conv_upsample(self, f, strides, x):
        """Apply a linear gated transposed conv followed by a strided one, both with ``f`` filters."""
        g = GatedConv2DTranspose(f, 3, activation='linear')
        g_upsample = GatedConv2DTranspose(f, 3, strides=strides)
        return g_upsample(g(x))

    def _create_encoder(self, wt, ht):
        """
        Build the encoder model for single-channel inputs of shape (wt, ht, 1).

        Outputs are ``[z_mu, z_log_var]``, followed by the flow parameter
        vector when a flow is configured.
        """
        input_0 = Input((wt, ht, 1))
        h = input_0
        for stage, strides in enumerate(self.encoder_strides, start=1):
            h = self._conv_downsample(self.hidden_units*stage, strides, h)
        # Flatten once and share the result between all dense heads.
        h_flat = Flatten()(h)
        z_mu = Dense(self.z_size, activation='linear')(h_flat)
        z_log_var = Dense(self.z_size, activation='linear')(h_flat)
        outputs = [z_mu, z_log_var]
        if self.flow is not None:
            z_shape = tf.TensorShape((None, self.z_size))
            params = Dense(self.flow.param_count(z_shape), activation='linear')(h_flat)
            outputs += [params]
        return Model(inputs=input_0, outputs=outputs)

    def _create_decoder(self, wt, ht):
        """
        Build the decoder model and its flow layer.

        Returns a tuple ``(decoder, flow_layer)``. The decoder takes the
        encoder outputs as inputs and produces ``[x'] + zs``, where ``zs`` is
        the list of latents returned by the flow layer.
        """
        z_mu = Input(shape=(self.z_size,))
        z_log_var = Input(shape=(self.z_size,))
        inputs = [z_mu, z_log_var]
        if self.flow is not None:
            z_shape = tf.TensorShape((None, self.z_size))
            params = Input(shape=(self.flow.param_count(z_shape),))
            inputs += [params]
            self.flow.initialize(z_shape)
        flow_layer = FlowLayer(self.flow, min_beta=1.0E-3)
        # ldj and kld are returned by the flow layer but not used in this method.
        zs, ldj, kld = flow_layer(inputs)
        z_k = zs[-1]
        # The seed feature map is upsampled by the decoder strides, so its
        # resolution must be derived from them (not from the encoder strides).
        s = int(np.prod(self.decoder_strides))
        seed_wt, seed_ht = wt // s, ht // s
        # Dense width must equal the element count of the Reshape target, also
        # when wt/ht are not exact multiples of s.
        h_k = Dense(seed_wt * seed_ht, activation='linear')(z_k)
        h_k = Reshape((seed_wt, seed_ht, 1))(h_k)
        for stage, strides in enumerate(self.decoder_strides, start=1):
            h_k = self._conv_upsample(self.hidden_units*stage, strides, h_k)
        output_0 = Conv2D(1, 1, activation=self.output_activation, padding='same')(h_k)
        return Model(inputs=inputs, outputs=[output_0] + zs), flow_layer