-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdata_utils.py
81 lines (69 loc) · 2.41 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import copy
import pickle
from collections import Counter

import numpy as np
import tensorflow as tf
def batch_generator(arr, n_seqs, n_steps):
    """Yield an endless stream of (input, target) mini-batches from ``arr``.

    ``arr`` is copied, truncated to a whole number of batches and reshaped to
    ``n_seqs`` rows.  At the start of every pass the rows are shuffled in
    place, then consecutive windows of ``n_steps`` columns are yielded.

    Args:
        arr: 1-D numpy array of encoded tokens (must support ``len``,
            slicing and ``reshape``).  The caller's array is not modified.
        n_seqs: number of rows (sequences) in each batch.
        n_steps: number of columns (time steps) in each batch.

    Yields:
        Tuples ``(x, y)`` of arrays with shape ``(n_seqs, n_steps)``.  ``y``
        is ``x`` shifted one step to the left; its last column is filled with
        the first column of ``x`` (wrap-around within the window).

    Raises:
        ValueError: if ``arr`` holds fewer than ``n_seqs * n_steps`` elements,
            i.e. not even one full batch can be built.

    Note:
        ``x`` is a view into the generator's internal buffer, which is
        reshuffled in place at the start of each pass; copy it if it has to
        outlive the current pass.
    """
    arr = copy.copy(arr)
    batch_size = n_seqs * n_steps
    # Integer floor division avoids the float round-trip of int(a / b).
    n_batches = len(arr) // batch_size
    if n_batches < 1:
        # Without this guard the reshaped array has zero columns and the
        # loop below would spin forever without ever yielding.
        raise ValueError(
            'arr has {} elements, fewer than one batch of '
            'n_seqs * n_steps = {}'.format(len(arr), batch_size)
        )
    # Drop the tail that does not fill a whole batch, then lay out as rows.
    arr = arr[:batch_size * n_batches].reshape((n_seqs, -1))
    while True:
        # Shuffles along the first axis only, i.e. reorders whole rows.
        np.random.shuffle(arr)
        for start in range(0, arr.shape[1], n_steps):
            x = arr[:, start:start + n_steps]
            y = np.zeros_like(x)
            # Targets are the inputs shifted by one; the last target wraps
            # around to the first input of the same window.
            y[:, :-1], y[:, -1] = x[:, 1:], x[:, 0]
            yield x, y
class TextConverter(object):
    """Bidirectional mapping between tokens and integer ids.

    The vocabulary is either loaded from a pickle file or built from ``text``
    by keeping the ``max_vocab`` most frequent tokens.  Ids ``0 ..
    len(vocab) - 1`` address vocabulary entries; the extra id ``len(vocab)``
    stands for any out-of-vocabulary token and is rendered as ``'*'``.
    """

    def __init__(self, text=None, max_vocab=8000, filename=None):
        """Build or load the vocabulary.

        Args:
            text: iterable of hashable tokens (e.g. a string, iterated
                character by character).  Used only when ``filename`` is None.
            max_vocab: maximum number of tokens kept when building from
                ``text``; the most frequent ones win.
            filename: path of a pickle file previously written by
                ``save_to_file``.  Takes precedence over ``text``.

        Raises:
            TypeError: if neither ``text`` nor ``filename`` is given.
        """
        if filename is not None:
            # NOTE: pickle.load can execute arbitrary code while
            # deserialising; only load vocabulary files from trusted sources.
            with open(filename, 'rb') as f:
                self.vocab = pickle.load(f)
        else:
            if text is None:
                raise TypeError(
                    'TextConverter requires either `text` or `filename`'
                )
            # most_common() sorts by descending count and keeps tokens with
            # equal counts in first-appearance order, so the resulting ids
            # are reproducible across runs (iterating a set is not).
            ranked = Counter(text).most_common()
            self.vocab = [token for token, _ in ranked[:max_vocab]]

        self.word_to_int_table = {c: i for i, c in enumerate(self.vocab)}
        self.int_to_word_table = dict(enumerate(self.vocab))

    @property
    def vocab_size(self):
        """Number of distinct ids: the vocabulary plus the unknown-token id."""
        return len(self.vocab) + 1

    def word_to_int(self, word):
        """Return the id of ``word``, or ``len(vocab)`` if it is unknown."""
        return self.word_to_int_table.get(word, len(self.vocab))

    def int_to_word(self, index):
        """Return the token for ``index``.

        ``len(vocab)`` maps to the unknown-token placeholder ``'*'``.

        Raises:
            Exception: if ``index`` is greater than ``len(vocab)``.
        """
        if index == len(self.vocab):
            return '*'
        if index < len(self.vocab):
            return self.int_to_word_table[index]
        raise Exception('Unknown index!')

    def text_to_arr(self, text):
        """Encode every token of ``text`` and return the ids as a numpy array."""
        return np.array([self.word_to_int(word) for word in text])

    def arr_to_text(self, arr):
        """Decode an iterable of ids into one string by joining the tokens."""
        return "".join(self.int_to_word(index) for index in arr)

    def save_to_file(self, filename):
        """Pickle the vocabulary list to ``filename``."""
        with open(filename, 'wb') as f:
            pickle.dump(self.vocab, f)