# successor_model.py
import os
import random
class SuccessorModel(object):
    SENTENCE_LOWER_BOUND = 6
    SENTENCE_UPPER_BOUND = 15
    stop_words = ['of', 'a', 'an', 'the', 'by', 'for', 'from']
    end_punctuation = ['.', '!', '?']
    BIGRAM_MULTIPLICATIVE_FACTOR = 5
    STOP_WORDS_FACTOR = 50

    def __init__(self, filename):
        self.path = process(filename)
        self.tokens = get_tokens(self.path)
        self.unigram_table = build_unigram_table(self.tokens)
        self.bigram_table = build_bigram_table(self.tokens)
    ########################
    #####Public Methods#####
    ########################
    def random_sent(self):
        # Resample until the sentence length lands inside the class bounds.
        # Seeding from the successors of '.' starts at a sentence opener.
        sentence = ''
        while (len(sentence.split()) < SuccessorModel.SENTENCE_LOWER_BOUND) or \
                (len(sentence.split()) > SuccessorModel.SENTENCE_UPPER_BOUND):
            sentence = self.construct_sent(find_weighted_random(self.unigram_table['.']),
                                           self.unigram_table, self.bigram_table)
        return sentence
    def generate_k_sentences(self, k):
        # Print k randomly generated sentences, one per line.
        sentences = ''
        for _ in range(k):
            sentences += self.random_sent() + '\n'
        print(sentences[:-1])
    def generate_sentence_length(self, k):
        # Rejection-sample until a sentence of exactly k words appears; the
        # -1 excludes the trailing punctuation token. k must be compatible
        # with the sentence length bounds or this loop never terminates.
        sentence = ''
        while len(sentence.split()) - 1 != k:
            sentence = self.random_sent()
        return sentence
    ###############################
    #####Sentence Construction#####
    ###############################
    def create_options_with_bigram(self, bigram_table, prev_key, words_with_weights):
        # Boost successors that followed this exact word pair; boost even
        # harder after stop words, which bind tightly to what follows.
        factor = SuccessorModel.BIGRAM_MULTIPLICATIVE_FACTOR
        one_prev = prev_key[1]
        if one_prev in SuccessorModel.stop_words:
            factor = SuccessorModel.STOP_WORDS_FACTOR
        # Copy the inner [word, count] pairs as well; a shallow copy would
        # let the multiplication below corrupt the stored bigram counts.
        bigram_mappings = [[w, c] for w, c in bigram_table[prev_key]]
        for each_word in bigram_mappings:
            each_word[1] *= factor
        words_with_weights = deep_merge(words_with_weights, bigram_mappings)
        return words_with_weights
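
    # Example: with unigram options [['sat', 2], ['ran', 1]] and a bigram
    # entry ('the', 'cat') -> [['sat', 1]], the bigram weight is multiplied
    # by 5 and the merged options become [['sat', 7], ['ran', 1]].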
    def construct_sent(self, word, unigram_table, bigram_table, limit=None):
        # Grow a sentence one word at a time until an end-punctuation token
        # is drawn; with `limit` set, force an ending once the budget is hit.
        result, wordcounter = '', 0
        two_prev, one_prev = '', ''
        while not check_end(word):
            result += word + ' '
            # Copy the [word, count] pairs so merging cannot disturb the
            # learned unigram table.
            words_with_weights = [[w, c] for w, c in unigram_table[word]]
            two_prev, one_prev = one_prev, word
            prev_key = (two_prev, one_prev)
            if one_prev != '' and two_prev != '' and prev_key in bigram_table:
                words_with_weights = self.create_options_with_bigram(
                    bigram_table, prev_key, words_with_weights)
            if (limit is not None) and (wordcounter == limit):
                # Out of budget: end now if any option can close the
                # sentence, otherwise signal failure with an empty string.
                end = check_contains_end(words_with_weights)
                if end != '':
                    return result + end
                return ''
            word = find_weighted_random(words_with_weights)
            wordcounter += 1
        return result + word
##################
#####Learning#####
##################
def build_bigram_table(tokens):
    # Map each adjacent word pair (two_prev, one_prev) to the words that
    # followed it, with occurrence counts: {(w1, w2): [[w3, count], ...]}.
    if len(tokens) < 3:
        return {}
    table = {}
    two_prev, one_prev = tokens[0], tokens[1]
    prev_key = (two_prev, one_prev)
    for i in range(2, len(tokens)):
        word = tokens[i]
        # Only record a bigram when the pair doesn't span a sentence
        # boundary (i.e. neither word contains end punctuation).
        if not (check_end(prev_key[0]) or check_end(prev_key[1])):
            if prev_key not in table:
                table[prev_key] = [[word, 1]]
            else:
                duplicate = False
                for successor in table[prev_key]:
                    if successor[0] == word:
                        successor[1] += 1
                        duplicate = True
                        break
                if not duplicate:
                    table[prev_key].append([word, 1])
        two_prev, one_prev = one_prev, word
        prev_key = (two_prev, one_prev)
    return table
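
# Example: for tokens ['the', 'cat', 'sat', '.'] the table is
# {('the', 'cat'): [['sat', 1]], ('cat', 'sat'): [['.', 1]]};
# any pair containing '.' is skipped, so bigrams never span sentences.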
def build_unigram_table(tokens):
    # Map each word to the words that followed it, with occurrence counts:
    # {word: [[successor, count], ...]}. Seeding prev with '.' makes
    # sentence openers the successors of '.'.
    table = {}
    prev = '.'
    for word in tokens:
        if prev not in table:
            table[prev] = [[word, 1]]
        else:
            duplicate = False
            for successor in table[prev]:
                if successor[0] == word:
                    successor[1] += 1
                    duplicate = True
                    break
            if not duplicate:
                table[prev].append([word, 1])
        prev = word
    return table
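
# Example: build_unigram_table(['the', 'cat', 'sat', '.', 'the', 'dog', 'sat', '.'])
# yields {'.': [['the', 2]], 'the': [['cat', 1], ['dog', 1]],
#         'cat': [['sat', 1]], 'dog': [['sat', 1]], 'sat': [['.', 2]]}.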
###################
#####Utilities#####
###################
def get_tokens(path):
    # Whitespace-tokenize the processed corpus; a missing file yields [].
    if os.path.exists(path):
        with open(path, encoding='utf-8') as f:
            return f.read().split()
    return []
def process(filename):
    # Normalize a raw corpus file: pad '.', '!' and '?' with spaces so they
    # become standalone tokens, and drop quotation marks and dashes. Writes
    # the result next to the input as '<filename>.processed'.
    if not os.path.exists(filename):
        return ''
    with open(filename, 'r', encoding='utf-8') as preprocessed_file:
        text = preprocessed_file.read()
    processed = ''
    for c in text:
        if c in ['.', '!', '?']:
            processed += ' ' + c + ' '
        elif c in ['"', '“', '”', '-', '—']:
            # c is a single character, so only single-character entries
            # can match here.
            continue
        else:
            processed += c
    new_path = filename + '.processed'
    with open(new_path, 'w', encoding='utf-8') as processed_file:
        processed_file.write(processed)
    return new_path
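
# Example: a file containing 'He said "hi." She left!' is rewritten as
# 'He said hi .  She left ! ', after which get_tokens splits it into
# ['He', 'said', 'hi', '.', 'She', 'left', '!'].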
def check_contains_end(l):
    # Return the first end-punctuation mark found among the [word, weight]
    # candidates, or '' if none of them can end the sentence.
    for punct in ['.', '!', '?']:
        for subl in l:
            if punct in subl[0]:
                return punct
    return ''
# Checks if a word contains a sentence-ending character.
# Complexity: O(n) where n is the character length of the word.
def check_end(word):
    for punct in ['.', '!', '?']:
        if punct in word:
            return True
    return False
# From a list of [word, weight] pairs, draw one word at random with
# probability proportional to its weight.
# Complexity: O(n) where n is the number of pairs in word_weights.
def find_weighted_random(word_weights):
    elements = []
    weights = []
    weighted_ranges = []
    for each_word in word_weights:
        elements.append(each_word[0])
        weights.append(each_word[1])
    if len(elements) == 0:
        return None
    if len(elements) == 1:
        return elements[0]
    # Build cumulative probability boundaries, then locate the draw r.
    r = random.random()
    total_weight, current_range = sum(weights), 0
    for weight in weights:
        current_range += weight / total_weight
        weighted_ranges.append(current_range)
    start_range = 0
    for i in range(len(weighted_ranges)):
        end_range = weighted_ranges[i]
        if start_range <= r <= end_range:
            return elements[i]
        start_range = end_range
    # Guard against floating-point rounding leaving r past the last boundary.
    return elements[-1]
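
# Example: for [['cat', 3], ['dog', 1]] the cumulative boundaries are
# [0.75, 1.0], so 'cat' is returned when r falls below 0.75 (75% of
# draws) and 'dog' otherwise.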
# Merges two lists of [word, weight] pairs, summing the weights of
# duplicate words. Mutates and returns l2; l1 itself is left unchanged.
# Complexity: O(n * m) for lists of length n and m.
def deep_merge(l1, l2):
    for sublist1 in l1:
        duplicate = False
        for sublist2 in l2:
            if sublist1[0] == sublist2[0]:
                sublist2[1] += sublist1[1]
                duplicate = True
                break
        if not duplicate:
            l2.append(sublist1)
    return l2
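
# Example: deep_merge([['cat', 2]], [['cat', 1], ['dog', 1]]) returns
# [['cat', 3], ['dog', 1]].

# Minimal usage sketch; 'corpus.txt' is a hypothetical plain-text file
# and should be replaced with a real corpus path.
if __name__ == '__main__':
    model = SuccessorModel('corpus.txt')
    model.generate_k_sentences(3)
    print(model.generate_sentence_length(8))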