-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathsample.py
executable file
·76 lines (59 loc) · 2.47 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python
__author__ = 'Tony Beltramelli www.tonybeltramelli.com - 19/08/2016'
import argparse
import codecs
from modules.Model import *
from modules.Vocabulary import *
from collections import deque
def main():
    """Sample text from a trained character-level model and save it to a file.

    Command-line arguments:
        --model_name       name given to ``Model``; also used as the directory
                           searched for a TensorFlow checkpoint (required).
        --vocabulary_file  file handed to ``Vocabulary.retrieve`` (required).
        --output_file      path of the UTF-8 text file to write (required).
        --seed             text the sample starts with
                           (default "Once upon a time, ").
        --sample_length    number of characters to generate (default 1500).
        --log_frequency    print progress every this many generated
                           characters; 0 or less disables progress output
                           (default 100).

    The seed is always written to the output file. Characters are generated
    only if a checkpoint is found; otherwise a notice is printed and the
    function returns early.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--vocabulary_file', type=str, required=True)
    parser.add_argument('--output_file', type=str, required=True)
    parser.add_argument('--seed', type=str, default="Once upon a time, ")
    parser.add_argument('--sample_length', type=int, default=1500)
    parser.add_argument('--log_frequency', type=int, default=100)
    args = parser.parse_args()

    model_name = args.model_name
    output_file = args.output_file
    sample_length = args.sample_length
    log_frequency = args.log_frequency

    # Python 2 hands us a byte string that must be decoded; on Python 3 the
    # argument is already text and has no ``decode`` method.
    seed = args.seed
    if isinstance(seed, bytes):
        seed = seed.decode('utf-8')

    model = Model(model_name)
    model.restore()
    classifier = model.get_classifier()

    vocabulary = Vocabulary()
    vocabulary.retrieve(args.vocabulary_file)

    # Sliding window of the most recent characters fed to the network.
    # ``maxlen`` keeps it at sequence_length even when the seed is longer
    # than that (only the tail of the seed is kept), and makes every append
    # drop the oldest character automatically.
    # NOTE(review): presumably model.x expects exactly sequence_length steps,
    # as the left-padding below suggests - confirm against modules.Model.
    window = deque(maxlen=model.sequence_length)
    # Left-pad with spaces when the seed is shorter than the window
    # (a negative count yields an empty string, i.e. no padding).
    window.extend(u' ' * (model.sequence_length - len(seed)))

    # The context manager closes the file on every exit path, including
    # exceptions and the early return below.
    with codecs.open(output_file, 'w', 'utf_8') as sample_file:
        for char in seed:
            if char not in vocabulary.vocabulary:
                print(u"{} is not in vocabulary file".format(char))
                # Unknown characters are replaced by a space.
                char = u' '
            window.append(char)
            sample_file.write(char)

        with tf.Session() as sess:
            tf.global_variables_initializer().run()

            saver = tf.train.Saver(tf.global_variables())
            ckpt = tf.train.get_checkpoint_state(model_name)

            if not (ckpt and ckpt.model_checkpoint_path):
                # Without restored weights the variables only hold their
                # initial values, so there is nothing meaningful to sample.
                print("No checkpoint found in {}".format(model_name))
                return

            saver.restore(sess, ckpt.model_checkpoint_path)

            for i in range(sample_length):
                # Batch of one: encode each character of the current window.
                vector = np.array(
                    [[vocabulary.binary_vocabulary[char] for char in window]])
                prediction = sess.run(classifier, feed_dict={model.x: vector})
                predicted_char = vocabulary.char_lookup[np.argmax(prediction)]

                # Appending to the full window discards its oldest character.
                window.append(predicted_char)
                sample_file.write(predicted_char)

                if log_frequency > 0 and i % log_frequency == 0:
                    # Floor division keeps the integer percentage on both
                    # Python 2 and Python 3.
                    print("Progress: {}%".format((i * 100) // sample_length))

    print("Sample saved in {}".format(output_file))
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()