-
Notifications
You must be signed in to change notification settings - Fork 2
/
follow_vae.py
155 lines (123 loc) · 4.96 KB
/
follow_vae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
# Dependency imports
from absl import flags
import numpy as np
from six.moves import urllib
import tensorflow as tf
from model import AudioMPS
# Command-line flags. main() passes all flag values to the Estimator as
# `params`, so model_fn can read them by name.
flags.DEFINE_integer(
    "viz_steps",
    default=2,
    help="Number of training steps per estimator.train call; also the "
         "checkpoint-saving interval.")
flags.DEFINE_integer(
    "max_steps", default=5001, help="Number of training steps to run.")
flags.DEFINE_integer(
    "bond_d", default=10, help="Bond dimension.")
flags.DEFINE_float(
    "dt", default=0.001, help="Time discretization.")
flags.DEFINE_bool(
    "discr",
    default=False,
    help="If false, we are using a pure state.")
# flags.DEFINE_float(
#     "learning_rate", default=1e-3, help="Initial learning rate.")
flags.DEFINE_integer(
    "batch_size",
    default=32,
    help="Batch size.")
flags.DEFINE_string(
    "model_dir",
    default=os.path.join(os.getenv("TEST_TMPDIR", "/tmp"), "vae/"),
    help="Directory where the Estimator writes checkpoints and summaries.")
flags.DEFINE_string(
    "data_dir",
    # NOTE(review): machine-specific absolute default; override on other hosts.
    default="/Users/mencia/PhD_local/audioMPS/data/pitch_30.tfrecords",
    help="Path to the TFRecord file holding the training audio.")
FLAGS = flags.FLAGS
def _r_matrix(bond_d):
    """Return the trainable matrix ``R`` cast to complex64.

    ``R`` is a float32 variable of shape ``[bond_d, bond_d]`` created with
    ``tf.get_variable`` (default initializer) in the current variable scope,
    so this should be called once per scope unless variable reuse is enabled.
    """
    r_real = tf.get_variable(name="R", shape=[bond_d, bond_d],
                             dtype=tf.float32, initializer=None)
    return tf.cast(r_real, dtype=tf.complex64)


def build_loss_psi(data, bond_d=None):
    """Build the mean squared-error loss of the pure-state model on a batch.

    Args:
      data: Signal tensor indexed as ``data[batch, time]`` (rank 2; it is
        transposed with ``[1, 0]`` below).
      bond_d: Bond dimension of the state. Defaults to ``FLAGS.bond_d``.

    Returns:
      Scalar tensor: the per-example loss summed over time, averaged over
      the batch.
    """
    if bond_d is None:
        bond_d = FLAGS.bond_d
    # Create R once, outside the fold, so the loop body only reads it.
    r_c = _r_matrix(bond_d)
    batch_zeros = tf.zeros_like(data[:, 0])
    # Every example starts in basis state 0 (one-hot over bond_d).
    psi_0 = tf.one_hot(tf.cast(batch_zeros, dtype=tf.int32), bond_d,
                       dtype=tf.complex64)
    # tf.foldl iterates over the first dimension, so put time first.
    data_by_time = tf.transpose(data, [1, 0])
    step_fn = functools.partial(_psi_and_loss_update, r_c=r_c)
    _, loss = tf.foldl(step_fn, data_by_time,
                       initializer=(psi_0, batch_zeros), name="loss_fold")
    return tf.reduce_mean(loss)


def _psi_and_loss_update(psi_and_loss, signal, r_c=None):
    """One fold step: add the loss for this time step to the running total.

    NOTE(review): ``psi`` is returned unchanged, so every time step is scored
    against the initial state. If the state is meant to evolve (e.g. under a
    Hamiltonian), that update is missing here; confirm intent.
    """
    psi, loss = psi_and_loss
    loss += _inc_loss_psi(psi, signal, r_c)
    return psi, loss


def _inc_loss_psi(psi, signal, r_c=None):
    """Half the squared difference between ``signal`` and the state's expectation."""
    return (signal - _expectation_psi(psi, r_c)) ** 2 / 2


def _expectation_psi(psi, r_c=None):
    """Return ``2 * Re(<psi|R|psi>)`` for each batch row of ``psi``.

    Args:
      psi: complex64 tensor indexed as ``psi[batch, bond]``.
      r_c: complex64 ``R`` matrix. When omitted, ``R`` is fetched/created via
        ``tf.get_variable`` using psi's static last dimension as bond size.
    """
    if r_c is None:
        r_c = _r_matrix(int(psi.shape[-1]))
    expval = tf.einsum("ab,bc,ac->a", tf.conj(psi), r_c, psi)
    return 2 * tf.real(expval)
def model_fn(features, labels, mode, params, config):
    """Builds the model function for use in an estimator.

    Arguments:
      features: The input batch from the input_fn. When the input_fn comes
        from ``build_input_fns`` this is the batched "audio" tensor, indexed
        as ``[batch, time]``.
      labels: The labels, unused here.
      mode: Signifies whether it is train or test or predict.
      params: Hyperparameters as a dictionary (``main`` passes all flag
        values). Reads ``"bond_d"`` and, if present, ``"learning_rate"``
        (default 1e-3).
      config: The RunConfig, unused here.
    Returns:
      EstimatorSpec: A tf.estimator.EstimatorSpec instance.
    """
    del labels, config
    loss = build_loss_psi(features, params["bond_d"])

    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        # Reuse the Estimator's global step instead of declaring a second
        # variable named "global_step", which can collide with it.
        global_step = tf.train.get_or_create_global_step()
        learning_rate = params.get("learning_rate", 1e-3)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(
            loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops={"loss": tf.metrics.mean(loss)})
def static_nsynth_dataset(directory):
    """Build a tf.data.Dataset of raw NSynth audio vectors from TFRecords.

    Args:
      directory: TFRecord file path(s), handed straight to
        ``tf.data.TFRecordDataset`` (despite the name, not a directory).

    Returns:
      A dataset whose elements are the float32 "audio" feature of each
      record: a fixed-length vector of ``2 ** 16`` samples.
    """
    records = tf.data.TFRecordDataset(directory)

    def _parser(example_proto):
        spec = {"audio": tf.FixedLenFeature([2 ** 16], dtype=tf.float32)}
        # Yield the bare tensor rather than the feature dict, so batches
        # come out as a plain [batch, time] tensor.
        return tf.parse_single_example(example_proto, spec)["audio"]

    return records.map(_parser)
def build_input_fns(data_dir, batch_size, shuffle_buffer=24):
    """Build the training input_fn for the Estimator.

    Args:
      data_dir: TFRecord path(s) forwarded to ``static_nsynth_dataset``.
      batch_size: Number of examples per batch.
      shuffle_buffer: Shuffle buffer size (default 24).

    Returns:
      A zero-argument input_fn that returns the next shuffled, repeating
      batch of audio tensors.
    """
    def train_input_fn():
        # Build the pipeline inside the input_fn so its ops are created in
        # the graph the Estimator is constructing, not in whichever graph
        # was the default when build_input_fns was called.
        dataset = static_nsynth_dataset(data_dir)
        dataset = dataset.shuffle(buffer_size=shuffle_buffer).repeat().batch(batch_size)
        return dataset.make_one_shot_iterator().get_next()

    return train_input_fn
def main(argv):
    """Train the model with a tf.estimator.Estimator.

    Runs exactly ``FLAGS.max_steps`` training steps, calling
    ``estimator.train`` repeatedly with at most ``FLAGS.viz_steps`` steps per
    call; checkpoints are saved every ``FLAGS.viz_steps`` steps.

    Raises:
      ValueError: if ``FLAGS.viz_steps`` is not positive.
    """
    del argv  # unused
    if FLAGS.viz_steps <= 0:
        raise ValueError("viz_steps must be positive, got %d" % FLAGS.viz_steps)
    params = FLAGS.flag_values_dict()
    tf.gfile.MakeDirs(FLAGS.model_dir)
    train_input_fn = build_input_fns(FLAGS.data_dir, FLAGS.batch_size)
    estimator = tf.estimator.Estimator(
        model_fn,
        params=params,
        config=tf.estimator.RunConfig(
            model_dir=FLAGS.model_dir,
            save_checkpoints_steps=FLAGS.viz_steps,
        ),
    )
    # Train in chunks; the final chunk absorbs any remainder, which the old
    # `max_steps // viz_steps` loop silently dropped.
    remaining = FLAGS.max_steps
    while remaining > 0:
        chunk = min(FLAGS.viz_steps, remaining)
        estimator.train(train_input_fn, steps=chunk)
        remaining -= chunk
if __name__ == "__main__":
    # Parse flags and dispatch to main() only when run as a script, so that
    # importing this module does not start training.
    tf.app.run()