-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsemantic_server.py
More file actions
90 lines (73 loc) · 2.23 KB
/
semantic_server.py
File metadata and controls
90 lines (73 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from pathway.xpacks.llm.document_store import DocumentStore
from pathway.xpacks.llm.servers import DocumentStoreServer
from pathway.stdlib.indexing import BruteForceKnnFactory
from pathway.udfs import DiskCache
from pathway.xpacks.llm import embedders
import pathway as pw
from dotenv import load_dotenv
import config
from langchain_core.documents import Document
# Load variables from a local .env file into the process environment
# (presumably the OpenAI credentials the embedder needs -- confirm).
load_dotenv()

# OpenAI embedding client; embedding calls are cached on disk via DiskCache.
embedder = embedders.OpenAIEmbedder(cache_strategy=DiskCache())

# Brute-force k-NN index scored by cosine similarity. `dimensions` has to
# equal the embedder's vector width. NOTE(review): 1536 presumably matches the
# default OpenAI embedding model -- change both together.
_KNN_OPTIONS = {
    "reserved_space": 1000,
    "metric": pw.engine.BruteForceKnnMetricKind.COS,
    "dimensions": 1536,
}
knn_index = BruteForceKnnFactory(embedder=embedder, **_KNN_OPTIONS)
# Schema of the JSON records read from data_cache/ and data_convo/.
class InputSchema(pw.Schema):
    """Columns every JSON record must provide for pw.io.fs.read."""

    record_id: str  # record identifier; ParseUtf8 returns it in document metadata
    query: str  # question text; ParseUtf8 returns it as the document text
    answer: str  # answer text; ParseUtf8 returns it in document metadata
    type: str  # Metadata tag for differentiation; carried as a column, not read by ParseUtf8
def _read_qa_records(path: str, is_cache: str) -> pw.Table:
    """Read JSON Q&A records from *path* and prepare them for the DocumentStore.

    Every row keeps its ``InputSchema`` columns and gains two more:

    * ``data`` -- ``query``, ``answer`` and ``record_id`` joined with
      ``"########"`` (the layout ``ParseUtf8`` splits apart again);
    * ``_metadata`` -- ``{"is_cache": is_cache}``, tagging the source.

    Args:
        path: Directory passed to ``pw.io.fs.read``.
        is_cache: String flag stored in the metadata: ``"True"`` for rows
            from ``data_cache/``, ``"False"`` for rows from ``data_convo/``.

    Returns:
        The read table with the ``data`` and ``_metadata`` columns added.
    """
    raw = pw.io.fs.read(
        path=path,
        format="json",
        schema=InputSchema,
    )
    return raw.select(
        data=(
            pw.this.query
            + "########"
            + pw.this.answer
            + "########"
            + pw.this.record_id
        ),
        _metadata={"is_cache": is_cache},
        **raw,
    )


t1 = _read_qa_records("data_cache/", "True")
t2 = _read_qa_records("data_convo/", "False")

# Table.concat requires the inputs' id universes to be disjoint; the two
# sources are different directories, so we promise Pathway they never overlap.
pw.universes.promise_are_pairwise_disjoint(t1, t2)
t3 = t1.concat(t2)
_RECORD_SEPARATOR = "########"


class ParseUtf8(pw.UDF):
    """Turn one serialized Q&A record back into a ``(text, metadata)`` document.

    The ``data`` column fed to the DocumentStore in this file is built as
    ``query + "########" + answer + "########" + record_id``. This parser
    reverses that: the query becomes the document text, and the answer and
    record id are returned as the document's metadata.
    """

    def __wrapped__(self, contents: bytes | str) -> list[tuple[str, dict]]:
        """Split *contents* into ``[(question, {"answer": ..., "record_id": ...})]``.

        Args:
            contents: The serialized record. The tables in this file produce
                ``str``; ``bytes`` is decoded as UTF-8 so raw file contents
                are accepted too.

        Returns:
            A single-element list with the question text and its metadata.

        Raises:
            ValueError: If *contents* does not contain two separators.
        """
        if isinstance(contents, bytes):
            contents = contents.decode("utf-8")
        # Peel the question off the front and the record id off the back, so a
        # separator that happens to occur inside the answer stays part of the
        # answer instead of truncating it and shifting record_id.
        question, head_sep, rest = contents.partition(_RECORD_SEPARATOR)
        answer, tail_sep, record_id = rest.rpartition(_RECORD_SEPARATOR)
        if not (head_sep and tail_sep):
            raise ValueError(
                "malformed record, expected "
                f"'query{_RECORD_SEPARATOR}answer{_RECORD_SEPARATOR}record_id': "
                f"{contents!r}"
            )
        return [(question, {"answer": answer, "record_id": record_id})]

    def __call__(self, contents: pw.ColumnExpression, **kwargs) -> pw.ColumnExpression:
        """Apply the parser to a column expression (delegates to ``pw.UDF``)."""
        return super().__call__(contents, **kwargs)


parser = ParseUtf8()
# Build the document store over the combined cache + conversation table:
# rows are parsed by `parser`, left unsplit, and indexed with `knn_index`.
vector_store = DocumentStore(
    t3,
    parser=parser,
    splitter=None,
    retriever_factory=knn_index,
)

# Serve the document store on the host/port configured in `config`, then start it.
server = DocumentStoreServer(
    document_store=vector_store,
    host=config.CACHE_STORE_HOST,
    port=config.CACHE_STORE_PORT,
)
server.run()