-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathspellchecker.py
More file actions
108 lines (85 loc) · 3.81 KB
/
spellchecker.py
File metadata and controls
108 lines (85 loc) · 3.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import sys
from copy import copy
from time import time
from Engine.Classifier.QueryClassifier import QueryClassifier
from Engine.ErrorModel import ErrorModel
from Engine.Generators.GrammarGenerator import GrammarGenerator
from Engine.Generators.JoinGenerator import JoinGenerator
from Engine.Generators.LayoutGenerator import LayoutGenerator
from Engine.Generators.SplitGenerator import SplitGenerator
from Engine.TextFormatter import TextFormatter
from Engine.utils.utils import load_obj, print_error
def correctLayout(layoutGenerator, words, correction, probs):
    """Try the single candidate produced by ``layoutGenerator`` for ``words``.

    The candidate (presumably the query retyped in another keyboard layout,
    judging by the generator's name -- confirm against LayoutGenerator) is
    always recorded in the module-level ``all_generation`` list.  When the
    module-level classifier ``qc`` accepts it, its formatted text is appended
    to ``correction`` and its ``lm.get_prob`` score to ``probs``, so the two
    lists stay index-aligned.

    Relies on the globals ``all_generation``, ``textFormatter``, ``qc`` and
    ``lm`` being bound by the script's main section before it is called.
    Returns None; results are reported by mutating the passed-in lists.
    """
    candidate = layoutGenerator.generate_correction(words)
    all_generation.append(candidate)

    candidate_text = textFormatter.format_text(candidate)
    if not qc.is_correct(candidate_text, candidate):
        return

    correction.append(candidate_text)
    probs.append(lm.get_prob(candidate))
def correctGrammar(grammarGenerator, words, correction, probs):
    """Try every candidate produced by ``grammarGenerator`` for ``words``.

    Each candidate is recorded in the module-level ``all_generation`` list.
    Candidates accepted by the module-level classifier ``qc`` contribute
    their formatted text to ``correction`` and their ``lm.get_prob`` score to
    ``probs`` (same index in both lists).

    Relies on the globals ``all_generation``, ``textFormatter``, ``qc`` and
    ``lm`` being bound by the script's main section before it is called.
    Returns None; results are reported by mutating the passed-in lists.
    """
    for candidate in grammarGenerator.generate_correction(words):
        all_generation.append(candidate)

        candidate_text = textFormatter.format_text(candidate)
        if not qc.is_correct(candidate_text, candidate):
            continue

        correction.append(candidate_text)
        probs.append(lm.get_prob(candidate))
def correctJoin(joinGenerator, words, correction, probs):
    """Try every candidate returned by ``joinGenerator.generate_joins(words)``.

    All candidates are first added to the module-level ``all_generation``
    list.  Each candidate is then turned into a query string by joining its
    items with a single space; if the module-level classifier ``qc`` accepts
    it, the string goes to ``correction`` and the candidate's ``lm.get_prob``
    score goes to ``probs`` (same index in both lists).

    Relies on the globals ``all_generation``, ``qc`` and ``lm`` being bound
    by the script's main section before it is called.  Returns None.
    """
    candidates = joinGenerator.generate_joins(words)
    # The result is consumed twice (extend, then the loop below).
    # NOTE(review): this assumes generate_joins returns a re-iterable
    # sequence rather than a one-shot iterator -- confirm.
    all_generation.extend(candidates)

    for candidate in candidates:
        candidate_text = u" ".join(candidate)
        if not qc.is_correct(candidate_text, candidate):
            continue
        correction.append(candidate_text)
        probs.append(lm.get_prob(candidate))
def correctSplit(splitGenerator, words, correction, probs):
    """Try every candidate returned by ``splitGenerator.generate_splits(words)``.

    All candidates are first added to the module-level ``all_generation``
    list.  Each candidate is then turned into a query string by joining its
    items with a single space; if the module-level classifier ``qc`` accepts
    it, the string goes to ``correction`` and the candidate's ``lm.get_prob``
    score goes to ``probs`` (same index in both lists).

    Relies on the globals ``all_generation``, ``qc`` and ``lm`` being bound
    by the script's main section before it is called.  Returns None.
    """
    candidates = splitGenerator.generate_splits(words)
    # The result is consumed twice (extend, then the loop below).
    # NOTE(review): this assumes generate_splits returns a re-iterable
    # sequence rather than a one-shot iterator -- confirm.
    all_generation.extend(candidates)

    for candidate in candidates:
        candidate_text = u" ".join(candidate)
        if not qc.is_correct(candidate_text, candidate):
            continue
        correction.append(candidate_text)
        probs.append(lm.get_prob(candidate))
def correct(layoutGenerator, grammarGenerator, joinGenerator, splitGenerator, words, correction, probs):
    """Run all four correction stages on ``words``, in a fixed order.

    The stages are layout, grammar, join and split.  Every stage receives the
    same ``words`` and appends to the same ``correction`` / ``probs`` lists,
    so after the call those lists hold the accepted candidates of all stages.
    Returns None.
    """
    # (stage function, generator it should use) -- order is significant
    # because it determines the order of entries in the output lists.
    stages = (
        (correctLayout, layoutGenerator),
        (correctGrammar, grammarGenerator),
        (correctJoin, joinGenerator),
        (correctSplit, splitGenerator),
    )
    for run_stage, generator in stages:
        run_stage(generator, words, correction, probs)
if __name__ == "__main__":
    # Script entry point (Python 2): reads one query per line from stdin and
    # prints exactly one UTF-8 encoded line per query to stdout.
    #
    # NOTE: lm, qc, textFormatter and all_generation are bound at module
    # level under these exact names on purpose -- the correct* helper
    # functions of this file read them as globals.

    # Maximum number of generate-and-check rounds attempted per query.
    MAX_ITER = 2

    lm = load_obj("LanguageModel")
    em = load_obj("ErrorModel")
    #qc_input = QueryClassifier(load_obj("classifier_input"), lm)
    qc = QueryClassifier(load_obj("classifier"), lm)

    layoutGenerator = LayoutGenerator()
    splitGenerator = SplitGenerator(lm)
    joinGenerator = JoinGenerator()
    grammarGenerator = GrammarGenerator(em, lm)

    for s in sys.stdin:
        t = time()  # start time; only used by the disabled timing print below
        textFormatter = TextFormatter(s)
        words = textFormatter.get_query_list()
        query = textFormatter.text

        # print(<single expression>) behaves exactly like the Python 2 print
        # statement with that expression.
        if qc.is_correct(query, words):
            # The classifier accepts the query as typed: print the
            # formatter's text for it.
            print(query.encode("utf-8"))
        else:
            # `iteration` counts down MAX_ITER-1 .. 0, one value per round.
            for iteration in reversed(range(MAX_ITER)):
                correction = []
                all_generation = []
                probs = []
                correct(layoutGenerator, grammarGenerator, joinGenerator,
                        splitGenerator, words, correction, probs)

                if correction:
                    # At least one candidate was accepted: print the one with
                    # the highest language-model score (first one on ties).
                    best_accepted = probs.index(max(probs))
                    print(correction[best_accepted].encode("utf-8"))
                    break

                # Nothing was accepted this round: continue the search from
                # the highest-scoring generated candidate (first one on ties).
                gen_prob = [lm.get_prob(g) for g in all_generation]
                best_generated = gen_prob.index(max(gen_prob))
                words = list(all_generation[best_generated])
            else:
                # Loop finished without `break`, i.e. no round produced an
                # accepted correction: print the last best guess instead.
                print(textFormatter.format_text(words).encode("utf-8"))
        #print "{} for {} iter".format(t-time(), MAX_ITER-iteration)