forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nets.py
106 lines (85 loc) · 3.54 KB
/
nets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Copyright 2019 DeepMind Technologies Limited and Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Network utilities."""
import functools
import re
import numpy as np
import sonnet as snt
import tensorflow.compat.v1 as tf
import tensorflow_gan as tfgan
def _sn_custom_getter():
  """Builds a variable getter that spectrally normalizes weight variables.

  Returns:
    A custom getter (from `tfgan.features`) that applies spectral
    normalization to every variable whose name ends in `w` or `w_<suffix>`.
  """
  def _is_weight(name):
    # Only kernel/weight variables (names like `.../w` or `.../w_foo`)
    # are normalized; biases and norm parameters are left untouched.
    return bool(re.match(r'.*w(_.*)?$', name))
  return tfgan.features.spectral_normalization_custom_getter(
      name_filter=_is_weight)
class SNGenNet(snt.AbstractModule):
  """As in the SN paper."""

  def __init__(self, name='conv_gen'):
    super(SNGenNet, self).__init__(name=name)

  def _build(self, inputs, is_training):
    """Maps latent vectors to 32x32x3 images.

    Args:
      inputs: `[batch_size, latent_dim]` tensor of latent samples; the
        batch size must be statically known.
      is_training: bool, forwarded to the batch-norm layers.

    Returns:
      A `[batch_size, 32, 32, 3]` tensor with values in [-1, 1] (tanh).
    """
    num_samples = inputs.get_shape().as_list()[0]
    seed_shape = [4, 4, 512]
    # Project the latent vector up and reshape it into a 4x4 spatial seed.
    projected = snt.Linear(np.prod(seed_shape))(inputs)
    seed = tf.reshape(projected, shape=[num_samples] + seed_shape)
    # Transposed-conv stack: 4x4 -> 8x8 -> 16x16 -> 32x32, batch-norm
    # (with scale) between layers but not after the final conv.
    deconv_stack = snt.nets.ConvNet2DTranspose(
        output_channels=[256, 128, 64, 3],
        output_shapes=[(8, 8), (16, 16), (32, 32), (32, 32)],
        kernel_shapes=[(4, 4), (4, 4), (4, 4), (3, 3)],
        strides=[2, 2, 2, 1],
        normalization_ctor=snt.BatchNormV2,
        normalization_kwargs={'scale': True},
        normalize_final=False,
        paddings=[snt.SAME], activate_final=False, activation=tf.nn.relu)
    logits = deconv_stack(seed, is_training=is_training)
    return tf.nn.tanh(logits)
class SNMetricNet(snt.AbstractModule):
  """Spectral normalization discriminator (metric) architecture."""

  def __init__(self, num_outputs=2, name='sn_metric'):
    super(SNMetricNet, self).__init__(name=name)
    self._num_outputs = num_outputs

  def _build(self, inputs):
    """Scores a batch of images.

    Args:
      inputs: image batch fed to the convolutional stack.

    Returns:
      A `[batch_size, num_outputs]` tensor of logits.
    """
    # Every weight variable created under this scope is spectrally
    # normalized by the custom getter.
    with tf.variable_scope('', custom_getter=_sn_custom_getter()):
      conv_stack = snt.nets.ConvNet2D(
          output_channels=[64, 64, 128, 128, 256, 256, 512],
          kernel_shapes=[
              (3, 3), (4, 4), (3, 3), (4, 4), (3, 3), (4, 4), (3, 3)],
          strides=[1, 2, 1, 2, 1, 2, 1],
          paddings=[snt.SAME], activate_final=True,
          activation=functools.partial(tf.nn.leaky_relu, alpha=0.1))
      features = snt.BatchFlatten()(conv_stack(inputs))
      return snt.Linear(self._num_outputs)(features)
class MLPGeneratorNet(snt.AbstractModule):
  """MNIST generator net."""

  def __init__(self, name='mlp_generator'):
    super(MLPGeneratorNet, self).__init__(name=name)

  def _build(self, inputs, is_training=True):
    """Maps latent vectors to `[batch, 28, 28, 1]` images in [-1, 1]."""
    del is_training  # No normalization layers, so the flag is unused.
    flat = snt.nets.MLP([500, 500, 784], activation=tf.nn.leaky_relu)(inputs)
    # Squash to [-1, 1], then lay the 784-vector out as an MNIST image.
    return snt.BatchReshape([28, 28, 1])(tf.nn.tanh(flat))
class MLPMetricNet(snt.AbstractModule):
  """Same as in Grover and Ermon, ICLR workshop 2017."""

  def __init__(self, num_outputs=2, name='mlp_metric'):
    super(MLPMetricNet, self).__init__(name=name)
    self._layer_size = [500, 500, num_outputs]

  def _build(self, inputs):
    """Flattens `inputs` and returns `[batch, num_outputs]` logits."""
    flattened = snt.BatchFlatten()(inputs)
    mlp = snt.nets.MLP(self._layer_size, activation=tf.nn.leaky_relu)
    return mlp(flattened)