-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathsentence.py
119 lines (102 loc) · 4.15 KB
/
sentence.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import argparse
import numpy as np
import features
from main import load_subreddit
def make_ngram_freq(X_train, model):
    """
    Rank every ngram in the model's vocabulary by its total score.

    Sums X_train column-wise (one column per vectorizer feature) and pairs
    each column total with the corresponding ngram string.

    Args:
        X_train: 2-D document-by-feature matrix whose columns follow the
            vectorizer's feature order. Both inputs whose ``sum(0)`` is a
            1 x n matrix (e.g. scipy sparse / ``np.matrix``) and dense 2-D
            numpy arrays (whose ``sum(0)`` is 1-D) are accepted.
        model: object exposing a fitted ``vectorizer`` attribute.

    Returns:
        A list of (ngram, score) tuples sorted by score, highest first,
        where ngram is the feature string and score is the column sum.
        Ngrams with equal scores keep the vectorizer's feature order.
    """
    # Column totals, one per feature. ravel() flattens both the 1 x n
    # matrix produced by sparse/matrix input and the 1-D array produced
    # by a dense ndarray, so indexing [0] is no longer needed.
    X_freq = np.asarray(X_train.sum(0)).ravel().tolist()

    vectorizer = model.vectorizer
    # Newer scikit-learn releases provide get_feature_names_out() and have
    # removed get_feature_names(); prefer the new name when present.
    if hasattr(vectorizer, "get_feature_names_out"):
        ngrams = [str(name) for name in vectorizer.get_feature_names_out()]
    else:
        ngrams = list(vectorizer.get_feature_names())

    ngram_freq = list(zip(ngrams, X_freq))
    # Index into the pair rather than using a tuple-parameter lambda,
    # which is Python 2-only syntax. The sort is stable, so ties keep
    # feature order.
    ngram_freq.sort(key=lambda pair: pair[1], reverse=True)
    return ngram_freq
def find_next_word(first, ngram_freq, random=False):
    """
    Pick the word that follows `first` according to the ngram scores.

    Assumes ngram_freq is a list of (ngram, score) pairs already sorted by
    score, highest first (the shape make_ngram_freq returns). Only ngrams
    of two or more words whose first word equals `first` are considered;
    the returned word is the second word of the chosen ngram.

    Args:
        first: the current word.
        ngram_freq: list of (ngram string, score) pairs, sorted descending.
        random: when True, sample an ngram with probability proportional to
            its score instead of taking the highest-scoring one.

    Returns:
        The next word, or "" when no multi-word ngram starts with `first`.
    """
    # Collect (tokens, score) for matching ngrams, preserving the sorted
    # order. Single-word ngrams (e.g. when N=1) have no following word and
    # would otherwise raise IndexError on tokens[1].
    candidates = []
    for ngram, score in ngram_freq:
        tokens = ngram.split()
        if len(tokens) >= 2 and tokens[0] == first:
            candidates.append((tokens, score))
    if not candidates:
        return ""

    if random:
        scores = [score for _, score in candidates]
        total = float(sum(scores))
        # Treat the scores as a multinomial distribution so higher-scored
        # ngrams are more likely to be drawn. A non-positive total cannot
        # be normalised, so fall through to the deterministic choice.
        if total > 0:
            probs = [score / total for score in scores]
            draw = np.random.multinomial(n=1, pvals=probs, size=1)
            index = draw[0].tolist().index(1)
            return candidates[index][0][1]

    # Deterministic: second word of the best-scoring ngram. If that would
    # just repeat the current word (producing "x x x ..."), use the
    # runner-up instead when one exists.
    next_word = candidates[0][0][1]
    if next_word != first or len(candidates) == 1:
        return next_word
    return candidates[1][0][1]
def build_sentence(start, nwords, ngram_freq, random=False):
    """
    Grow a sentence from `start` by repeatedly appending a next word.

    Each step asks find_next_word for a successor of the last word in the
    sentence. Generation stops after `nwords` words have been appended, or
    earlier as soon as no successor is found. The result therefore holds
    between 1 and nwords + 1 words: the start word plus up to `nwords` more.

    Args:
        start: first word of the sentence.
        nwords: maximum number of words to append after `start`.
        ngram_freq: (ngram, score) pairs sorted by score, highest first.
        random: forwarded to find_next_word; sample by score when True.

    Returns:
        The sentence as a list of words, beginning with `start`.
    """
    sentence = [start]
    for _ in range(nwords):
        next_word = find_next_word(sentence[-1], ngram_freq, random=random)
        if not next_word:
            # Dead end: no ngram continues from the current word.
            break
        sentence.append(next_word)
    return sentence
if __name__ == '__main__':
    # Command-line entry point: load a subreddit file, rank its ngrams, and
    # print greedy and sampled sentences for one or more seed words.
    parser = argparse.ArgumentParser("sentence builder")
    parser.add_argument("subreddit", help="path to subreddit file", type=str)
    parser.add_argument("ngram", help="N to use", type=int)
    parser.add_argument("--start", help="starting word", type=str,
                        dest="start", default="")
    parser.add_argument("--nwords",
                        help="maximum number of words to add after the "
                             "starting word (default: 20)",
                        type=int, dest="nwords", default=20)
    args = vars(parser.parse_args())

    ngram = args["ngram"]
    model = features.NGramModel(ngram)
    print("N used for ngram: %d" % ngram)

    data_file = args["subreddit"]
    df = load_subreddit(data_file)
    print("loaded: %s" % data_file)

    # Make the training set; the labels are not used for sentence building.
    print("making training data...")
    X_train, _ = model.make_training_xy(df)

    # Rank ngrams once; every trial below reuses the same table.
    ngram_freq = make_ngram_freq(X_train, model)
    nwords = args["nwords"]

    def print_trial(start, nwords, use_random):
        """Print a header line, then the sentence grown from `start`."""
        print(">>> start: %s, length: %d, random: %s"
              % (start, nwords, str(use_random)))
        sentence = build_sentence(start, nwords, ngram_freq,
                                  random=use_random)
        print(" ".join(sentence))

    # Without --start, run a fixed set of demo seed words. Each seed word
    # is tried once greedily and once with score-weighted sampling.
    if args["start"] == "":
        starts = ["obama", "liberal", "liberals", "republican", "republicans"]
    else:
        starts = [args["start"]]
    for start in starts:
        print_trial(start, nwords, False)
        print_trial(start, nwords, True)