-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathretriever.py
More file actions
executable file
·152 lines (122 loc) · 4.09 KB
/
retriever.py
File metadata and controls
executable file
·152 lines (122 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/python3
# needs improving to remove forced type conversions
import sys
import re
import json
import math
from nltk.stem.snowball import SnowballStemmer
import string
import csv
# global declarations for doclist, postings, vocabulary
# These are empty placeholders; read_index_files() rebinds each of them to
# the data it loads from the JSON index files.
# docids: indexed with int(docid); main() prints the entry as the result URL.
docids = []
# postings: keyed by str(term id); each posting is read as
# [docid, tf, weight] (indices 0, 1 and 2) by retrieve_bool().
postings = {}
# cache: indexed by docid; entry[0] is printed as the title and entry[1]
# as the snippet.
cache = []
# doclength: keyed by str(docid); used as the divisor when a posting is scored.
doclength = {}
# vocab: list of terms; a term's position in the list is its term id.
vocab = []
def main():
    """Run a ranked query from the command line and report the top results.

    Every command-line argument after the script name is one query term.
    Each term has its ASCII punctuation removed and is stemmed with the
    English Snowball stemmer before the list is passed to
    ``retrieve_bool``. For each of the ten best-ranked documents the title,
    URL and snippet are printed, and the URL is written as one row of
    ``output.csv`` in the current directory (overwritten on every run).

    Exits with status 1 after printing a usage line when no term is given.
    """
    # code for testing offline
    if len(sys.argv) < 2:
        print('usage: ./retriever.py term [term ...]')
        sys.exit(1)
    query = sys.argv[1:]

    read_index_files()

    print('Query: ', query)  # shows terms that were input

    # Strip punctuation, then stem, so the terms match the stemmed words
    # stored in vocab.
    stemmer = SnowballStemmer("english")
    query_terms = []
    for raw_term in query:
        stripped = "".join(
            ch for ch in raw_term if ch not in string.punctuation
        )
        query_terms.append(stemmer.stem(stripped))

    answer = retrieve_bool(query_terms)

    # prints results
    print()
    # store output to file for data collection; newline='' is what the csv
    # module asks for on files it writes to.
    with open('output.csv', 'w', newline='') as file:
        # One writer for the whole run instead of a new one per result.
        spamwriter = csv.writer(file, delimiter=' ',
                                quotechar='|', quoting=csv.QUOTE_MINIMAL)
        # Only the top 10 results are shown and recorded.
        # NOTE(review): assumes the CSV row was meant to be written inside
        # the top-10 branch together with the prints - confirm.
        for rank, docid in enumerate(answer[:10], start=1):
            print(rank, cache[docid][0])  # Title
            url = docids[int(docid)]
            print(url)  # URL
            print(cache[docid][1])  # Snippet
            print()  # space
            # writerow() takes a sequence of fields; a bare string would be
            # split into one field per character.
            spamwriter.writerow([url])
def _load_json(path):
    """Return the value decoded from the JSON text file at *path*."""
    # The context manager closes the file even when decoding fails.
    with open(path, 'r') as handle:
        return json.load(handle)


def read_index_files():
    """Load the index from its JSON files into the module-level globals.

    Reads ``docids.txt``, ``vocab.txt``, ``cache.txt``, ``doclength.txt``
    and ``postings.txt`` from the current working directory. JSON is used
    so the list/dictionary structures survive the round trip.

    The globals ``docids``, ``vocab``, ``cache``, ``doclength`` and
    ``postings`` are rebound only after all five files have loaded, so a
    missing or malformed file leaves the previous index untouched.

    Raises:
        OSError: if one of the files cannot be opened.
        json.JSONDecodeError: if one of the files is not valid JSON.
    """
    # declare refs to global variables
    global docids
    global postings
    global cache
    global doclength
    global vocab

    # load everything into locals first ...
    loaded_docids = _load_json('docids.txt')
    loaded_vocab = _load_json('vocab.txt')
    loaded_cache = _load_json('cache.txt')
    loaded_doclength = _load_json('doclength.txt')
    loaded_postings = _load_json('postings.txt')

    # ... then publish the complete index in one go
    docids = loaded_docids
    vocab = loaded_vocab
    cache = loaded_cache
    doclength = loaded_doclength
    postings = loaded_postings
def retrieve_bool(query_terms):
    """Score documents against the query terms and return their ids, best first.

    Despite the name this is a ranked retrieval, not a boolean one. It reads
    the module-level ``vocab``, ``postings`` and ``doclength`` structures.
    Each posting is read as ``[docid, tf, weight]``.

    For every distinct query term found in ``vocab``:

    * a term weight is accumulated from index 2 of each of its postings;
    * an idf-style value ``(1 + ln(number of postings)) / len(doclength)``
      is computed.

    Every posting of a term then adds
    ``idf * tf / doclength[str(docid)] * (idf / number_of_distinct_terms
    + term_weight)`` to its document's score.

    Args:
        query_terms: iterable of term strings (already stemmed by the
            caller). Duplicates count once; terms are lower-cased for the
            vocabulary lookup.

    Returns:
        A list of document ids sorted by descending score. Empty when no
        query term has postings. Terms missing from the vocabulary are
        reported on stdout and skipped.
    """
    # De-duplicate while keeping the caller's order. Iterating a set of
    # strings would visit the terms in an order that changes between runs.
    unique_terms = list(dict.fromkeys(query_terms))

    idf = {}
    weights = {}
    for term in unique_terms:
        try:
            # check if term is in vocab; its position is the term id
            termid = str(vocab.index(term.lower()))
        except ValueError:  # the term is not in the vocab
            print('Not found: ', term, ' is not in vocabulary')
            continue

        term_postings = postings.get(termid)
        if not term_postings:
            # In the vocabulary but without postings: nothing to score, and
            # log(0) below would be a math domain error.
            continue

        # gets the word's weight
        # NOTE(review): the second addition is assumed to sit inside the
        # loop over postings, as the statement order suggests - confirm.
        weight = 0
        for post in term_postings:
            weight = weight + post[2]
            if weight > 0:
                # if it has a weight, give weight
                weight = weight + post[2]
        # Keep the weight per term; a single shared variable would apply the
        # last visited term's weight to every term.
        weights[termid] = weight

        # get idf weight
        idf[termid] = (1 + math.log(len(term_postings))) / (len(doclength))

    # now calculate tf*idf and score for each doc and the query
    scores = {}
    for termid in sorted(idf, key=idf.get, reverse=True):
        query_component = idf[termid] / len(unique_terms) + weights[termid]
        for post in postings[termid]:
            docid = post[0]
            contribution = (
                (idf[termid] * post[1]) / doclength[str(docid)]
                * query_component
            )
            scores[docid] = scores.get(docid, 0) + contribution

    # rank the list
    return sorted(scores, key=scores.get, reverse=True)
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()