# retrievers.py
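"""Minimal RAG example: index a PDF with Gemini embeddings in a Chroma vector
store and answer questions over it with a Gemini chat model via LangChain's
RetrievalQA chain.

Expects a GOOGLE_API_KEY entry in a .env file (or the environment) and the
paper at ./data/react-paper.pdf.
"""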
import os
import google.generativeai as genai
from dotenv import load_dotenv, find_dotenv
# LangChain integrations for Gemini models, PDF loading, and the Chroma vector store
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
# Load env variables
load_dotenv(find_dotenv())
gemini_api_key = os.getenv("GOOGLE_API_KEY")
genai.configure(api_key=gemini_api_key)
# Initialize the Gemini chat model through the LangChain wrapper so it can be
# passed to RetrievalQA (a raw google.generativeai model is not a LangChain LLM)
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", google_api_key=gemini_api_key)
# Initialize embeddings using the updated import
embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001", google_api_key=gemini_api_key)
# Load PDF document
loader = PyPDFLoader("./data/react-paper.pdf")
docs = loader.load()
# Split document
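# chunk_size and chunk_overlap are measured in characters; the 200-character
# overlap preserves context across chunk boundaries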
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(docs)
# Store embeddings in Chroma vector database
persist_directory = './data/db/chroma/'
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings, persist_directory=persist_directory)
vectorstore.persist()  # explicit persist; recent Chroma versions also persist automatically
# Load vector store and make retriever
vector_store = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
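# k=2: return the two chunks most similar to each query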
retriever = vector_store.as_retriever(search_kwargs={"k": 2})
# Retrieve relevant chunks (named relevant_docs so the loaded PDF docs are not shadowed)
relevant_docs = retriever.invoke("Tell me more about ReAct prompting")
print(relevant_docs[0].page_content)
# Create QA Chain with Gemini
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    verbose=True,
    return_source_documents=True,
)
# Helper function to format LLM response
def process_llm_response(llm_response):
    print(llm_response['result'])
    print('\n\nSources:')
    for source in llm_response["source_documents"]:
        print(source.metadata['source'])
# Query the LLM
query = "tell me more about ReAct prompting"
llm_response = qa_chain.invoke({"query": query})
process_llm_response(llm_response=llm_response)
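# The same chain can be reused for follow-up questions, e.g.:
#   llm_response = qa_chain.invoke({"query": "How does ReAct combine reasoning and acting?"})
#   process_llm_response(llm_response=llm_response)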