forked from dennybritz/chatbot-retrieval
-
Notifications
You must be signed in to change notification settings - Fork 0
/
udc_train.py
executable file
·64 lines (50 loc) · 2.01 KB
/
udc_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import time
import itertools
import tensorflow as tf
import udc_model
import udc_hparams
import udc_metrics
import udc_inputs
from models.dual_encoder import dual_encoder_model
# ---------------------------------------------------------------------------
# Command-line flags and derived run configuration for the training script.
# ---------------------------------------------------------------------------
tf.flags.DEFINE_string(
    "input_dir",
    "./data",
    "Directory containing input data files 'train.tfrecords' and 'validation.tfrecords'")
# Help text previously claimed the default was ./runs; the code below actually
# defaults to ./runs/<unix timestamp>, so the help now says so.
tf.flags.DEFINE_string(
    "model_dir",
    None,
    "Directory to store model checkpoints (defaults to ./runs/<timestamp>)")
# 20 matches INFO in Python's logging level numbering.
tf.flags.DEFINE_integer("loglevel", 20, "Tensorflow log level")
tf.flags.DEFINE_integer("num_epochs", None, "Number of training Epochs. Defaults to indefinite.")
tf.flags.DEFINE_integer("eval_every", 2000, "Evaluate after this many train steps")

FLAGS = tf.flags.FLAGS

# Integer seconds since the epoch, taken once at import time; used to give
# each run its own default model directory.
TIMESTAMP = int(time.time())

# NOTE(review): FLAGS values are read here at import time, before tf.app.run()
# parses argv. Older tf.flags parsed lazily on first attribute access; the
# absl-backed tf.flags in later TF 1.x releases raises UnparsedFlagAccessError
# instead. Confirm the TensorFlow version this script targets.
if FLAGS.model_dir:
    # A user-supplied directory is used as given (not made absolute).
    MODEL_DIR = FLAGS.model_dir
else:
    MODEL_DIR = os.path.abspath(os.path.join("./runs", str(TIMESTAMP)))

# Absolute paths to the TFRecord files expected inside --input_dir.
TRAIN_FILE = os.path.abspath(os.path.join(FLAGS.input_dir, "train.tfrecords"))
VALIDATION_FILE = os.path.abspath(os.path.join(FLAGS.input_dir, "validation.tfrecords"))

tf.logging.set_verbosity(FLAGS.loglevel)
def main(unused_argv):
    """Build the dual-encoder Estimator and train it with periodic validation.

    Model checkpoints go to the module-level ``MODEL_DIR``. Training reads
    ``TRAIN_FILE``. A validation monitor evaluates on ``VALIDATION_FILE``
    every ``--eval_every`` training steps.

    Args:
      unused_argv: Leftover command-line arguments handed over by
        ``tf.app.run``; not used.
    """
    hparams = udc_hparams.create_hparams()
    dual_encoder_fn = udc_model.create_model_fn(
        hparams,
        model_impl=dual_encoder_model,
    )

    learn = tf.contrib.learn
    estimator = learn.Estimator(
        model_fn=dual_encoder_fn,
        model_dir=MODEL_DIR,
        config=learn.RunConfig(),
    )

    # Training input; epoch count comes from --num_epochs
    # (per the flag's help text, None means no limit).
    train_input = udc_inputs.create_input_fn(
        mode=learn.ModeKeys.TRAIN,
        input_files=[TRAIN_FILE],
        batch_size=hparams.batch_size,
        num_epochs=FLAGS.num_epochs,
    )

    # Validation input, capped at a single epoch (num_epochs=1).
    validation_input = udc_inputs.create_input_fn(
        mode=learn.ModeKeys.EVAL,
        input_files=[VALIDATION_FILE],
        batch_size=hparams.eval_batch_size,
        num_epochs=1,
    )

    metrics = udc_metrics.create_evaluation_metrics()
    validation_monitor = learn.monitors.ValidationMonitor(
        input_fn=validation_input,
        every_n_steps=FLAGS.eval_every,
        metrics=metrics,
    )

    # No explicit step limit: training length is governed by the train input.
    estimator.fit(
        input_fn=train_input,
        steps=None,
        monitors=[validation_monitor],
    )
if __name__ == "__main__":
    # Hand `main` to tf.app.run explicitly. It parses flags and calls main(argv).
    # This matches the implicit lookup of __main__.main that tf.app.run does
    # when no function is passed.
    tf.app.run(main=main)