-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
153 lines (121 loc) · 4.09 KB
/
app.py
File metadata and controls
153 lines (121 loc) · 4.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import json
import requests
import weaviate
from dotenv import dotenv_values
from weaviate.classes.query import MetadataQuery
from utils import get_embedding
# Load configuration from .env and connect to the local Weaviate instance.
# COLLECTION_NAME selects the collection that holds the document chunks.
config = dotenv_values(".env")
client = weaviate.connect_to_local()
collection = client.collections.get(name=config["COLLECTION_NAME"])

# Running chat history sent to the LLM. Index 0 is the system prompt and is
# the only entry kept when the user starts a "new" conversation.
messages = [
    {
        "role": "system",
        "content": "I am a helpful assistant.",
    }
]

# True until the first user turn of a conversation has been augmented with
# retrieved context; only that first turn goes through RAG retrieval.
initial_query = True
def keyword_search(query, K=10):
    """Run a BM25 keyword search and return the top-K chunk contents.

    Each returned item is the ``content`` property of a matched object.
    """
    hits = collection.query.bm25(
        query=query,
        limit=K,
        return_metadata=MetadataQuery(score=True),
    ).objects
    return [hit.properties["content"] for hit in hits]
def vector_search(query, K=10):
    """Embed *query* and return the contents of its K nearest chunks.

    Uses the project's ``get_embedding`` helper to vectorize the query,
    then performs a near-vector search on the collection.
    """
    query_vector = get_embedding(query)
    hits = collection.query.near_vector(
        near_vector=query_vector,
        limit=K,
        return_metadata=MetadataQuery(distance=True),
    ).objects
    return [hit.properties["content"] for hit in hits]
def hybrid_search(query, K=10):
    """Combine BM25 and vector search; return the top-K chunk contents.

    Supplies both the raw query text and its embedding so Weaviate can
    fuse keyword and semantic relevance scores.
    """
    query_vector = get_embedding(query)
    hits = collection.query.hybrid(
        query=query,
        vector=query_vector,
        limit=K,
        return_metadata=MetadataQuery(score=True),
    ).objects
    return [hit.properties["content"] for hit in hits]
def getRelevantChunks(query, strategy=config["RETRIEVAL"]):
    """Dispatch *query* to the retriever named by *strategy*.

    NOTE: the default is evaluated once at import time from the .env
    RETRIEVAL setting, so later config changes do not affect it.

    Raises:
        ValueError: if *strategy* is not one of the known retrievers.
    """
    retrievers = {
        "keyword": keyword_search,
        "vector": vector_search,
        "hybrid": hybrid_search,
    }
    if strategy in retrievers:
        return retrievers[strategy](query)
    raise ValueError(
        f'Invalid argument: {strategy}. Expected strategy one of "keyword", "vector" or "hybrid".'
    )
def createQueryPrompt(query):
    """Build a RAG prompt: the question followed by numbered context chunks."""
    chunks = getRelevantChunks(query)
    parts = [
        f"Use the context provided below to answer the following question: {query}\n\n"
    ]
    # Number the retrieved chunks starting at 1, one per line.
    for position, chunk in enumerate(chunks, start=1):
        parts.append(f"{position}. {chunk} \n")
    parts.append("\nYou have all the context you need provided above.")
    return "".join(parts)
def prompt_llm(messages):
    """Stream a chat completion from the LLM HTTP endpoint.

    Yields each decoded JSON chunk (a dict) as it arrives. After the
    stream ends, yields the complete list of chunks as a final sentinel
    so the caller can detect end-of-stream. On a non-200 HTTP status,
    prints the status and yields ``None``.

    Args:
        messages: chat history in OpenAI-style {"role", "content"} dicts.

    NOTE(review): no request timeout is set — a hung server would block
    forever; confirm with callers whether adding one is acceptable.
    """
    params = {
        "model": config["CHAT_MODEL"],
        "messages": messages,
        "options": {
            "num_predict": 512,
        },
        "stream": True,
    }
    # Use the response as a context manager so the underlying HTTP
    # connection is always released — the original never closed it, which
    # leaks the connection (notably if the consumer abandons the generator
    # mid-stream or on the error path).
    with requests.post(config["LLM_URL"], json=params, stream=True) as response:
        if response.status_code == 200:
            results = []
            # Each non-empty line of the stream is one JSON chunk.
            for line in response.iter_lines():
                if line:
                    result = json.loads(line)
                    results.append(result)
                    # Yield each chunk as it is received.
                    yield result
            # Sentinel: the full list signals end-of-stream to the caller.
            yield results
        else:
            # Request failed; report the status and signal the error.
            print("Error:", response.status_code)
            yield None
# --- Interactive chat loop --------------------------------------------------
# Commands: "exit" quits the program; "new" resets the history to just the
# system prompt and re-enables retrieval for the next question. Only the
# first question of each conversation is augmented with retrieved context.
while True:
    userInput = input("\n\nUser : ")
    if userInput.lower() == "exit":
        break
    if userInput.lower() == "new":
        initial_query = True
        # Keep only the system message at index 0.
        messages = messages[:1]
        continue
    prompt = userInput
    if initial_query:
        # First turn of a conversation: wrap the question with RAG context.
        prompt = createQueryPrompt(userInput)
    message = {"role": "user", "content": prompt}
    messages.append(message)
    # prompt_llm is a generator: it yields chunk dicts while streaming, then
    # a final list of all chunks, or None on an HTTP error.
    results = prompt_llm(messages)
    assistant_message = ""
    print("\nAssistant: ", end="")
    for result in results:
        if result:
            if isinstance(result, list):
                # Final results list, not doing anything with it as we already processed all results
                messages.append({"role": "assistant", "content": assistant_message})
            else:
                # Individual streamed result
                if not result["done"]:
                    predicted_token = result["message"]["content"]
                    assistant_message += predicted_token
                    print(predicted_token, end="", flush=True)
        else:
            # prompt_llm yielded None: HTTP error already printed upstream.
            print("Error communicating with the language model.")
            break
    initial_query = False
# Release the Weaviate connection on exit.
client.close()