-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathbigram_model.py
108 lines (90 loc) · 3.23 KB
/
bigram_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
"""Probabilistic bigram language model."""
import collections
import numpy as np
START_SONG = 'XXSS' # Token for the start of a new song.
START_LINE = 'XXSL' # Token for the start of a new line.
END_LINE = 'XXEL' # Token for the end of a line.
END_SONG = 'XXES' # Token for the end of a song.
PREFIXES = ('(')
# NOTE: Check `...` before `.`.
SUFFIXES = (')', '?', '!', ',' , '...', '.', ':', '"')
class ProbabilisticBigramLanguageModel(object):
    """A probabilistic bigram language model.

    Given a list of bigrams, the model builds a probability distribution
    over the bigrams.
    """

    def __init__(self):
        # Maps curr_word -> Counter of observed next_word frequencies.
        self._model = collections.defaultdict(collections.Counter)

    def fit(self, bigrams):
        """Accumulates next-word counts from an iterable of (curr, next) pairs."""
        for curr_word, next_word in bigrams:
            # Idiomatic Counter increment (was `.update({next_word: 1})`).
            self._model[curr_word][next_word] += 1

    def predict(self, curr_word):
        """Samples the word following `curr_word` from the learned distribution.

        Raises:
            ValueError: if `curr_word` was never seen during `fit` (the
                empty Counter cannot be unpacked into words and counts).
        """
        words, counts = list(zip(*self._model[curr_word].most_common()))
        counts = np.array(counts)
        # Normalize raw counts into a probability distribution.
        probs = counts / np.sum(counts)
        return np.random.choice(words, p=probs)
def process_token(token):
    """Splits out prefixes and suffixes from the token.

    Returns a list containing any stripped prefixes, the bare token, then
    any stripped suffixes, all in their original left-to-right order.
    """
    # Split prefixes, repeating until no prefix matches (e.g. '((word').
    prefixes = []
    is_clean = False
    while not is_clean:
        is_clean = True
        for prefix in PREFIXES:
            if token.startswith(prefix):
                is_clean = False
                prefixes.append(token[:len(prefix)])
                token = token[len(prefix):]
    # Split suffixes. Bug fix: the original did a single pass over SUFFIXES
    # (its `is_clean` flag was vestigial), so stacked suffixes like 'what?!'
    # were only partially stripped; loop until clean, mirroring the prefix
    # logic. Suffixes are inserted at the front so the returned order
    # matches the token's left-to-right order (e.g. 'word.)' -> '.', ')').
    suffixes = []
    is_clean = False
    while not is_clean:
        is_clean = True
        for suffix in SUFFIXES:
            if token.endswith(suffix):
                is_clean = False
                suffixes.insert(0, token[-len(suffix):])
                token = token[:-len(suffix)]
    return prefixes + [token] + suffixes
def create_bigrams(line):
    """Returns bigrams of the line with token prefixes and suffixes expanded."""
    tokens = []
    for raw_token in line.split(' '):
        tokens.extend(process_token(raw_token))
    # Pair every token with its immediate successor.
    return list(zip(tokens, tokens[1:]))
if __name__ == '__main__':
    # Load the dataset: one song per physical line, tab-separated
    # title/author/lyrics, with '\' delimiting lines inside the lyrics.
    with open('dataset.txt') as file_:
        songs = file_.readlines()
    lines = []
    for song in songs:
        song = song.strip()
        # Title and author are unused; only the lyrics feed the model.
        title, author, lyrics = song.split('\t')
        lyrics = '{} {} {}'.format(START_SONG, lyrics, END_SONG)
        lines.extend(lyrics.split('\\'))
    bigrams = []
    for line in lines:
        # TODO(eugenhotaj): There is probably a smarter way to handle this.
        if not line.endswith(END_SONG):
            line = '{} {} {} {}'.format(START_LINE, line, END_LINE, START_LINE)
        bigrams.extend(create_bigrams(line))
    model = ProbabilisticBigramLanguageModel()
    model.fit(bigrams)

    # Sample a new song token by token until the model emits END_SONG.
    song = []
    pred = START_SONG
    while pred != END_SONG:
        pred = model.predict(pred)
        if pred == END_LINE:
            song.append("\n")
        elif pred in (START_SONG, START_LINE, END_SONG):
            pass
        elif pred in PREFIXES:
            # Prefixes attach directly to the following word (no space).
            song.append(pred)
        elif pred in SUFFIXES and song:
            # Attach the suffix to the previous word: drop its trailing
            # space, append the suffix, restore the space. The `and song`
            # guard prevents an IndexError if a suffix is sampled first.
            song[-1] = song[-1][:-1] + pred + " "
        else:
            song.append(pred + " ")
    print("".join(song))