-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
117 lines (106 loc) · 4.26 KB
/
main.py
File metadata and controls
117 lines (106 loc) · 4.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import streamlit as st
import os
import logging
from src.data_loader import DataLoader
from src.embedder import Embedder
from src.retriever import Retriever
from src.generator import Generator
import yaml
# Logging configuration: records at INFO and above are formatted as
# "<time> - <level> - <message>" and sent to two handlers, a file handler on
# 'rag_app_cpu.log' and a stream handler using its default stream.
logging.basicConfig(
    handlers=[
        logging.FileHandler('rag_app_cpu.log'),
        logging.StreamHandler(),
    ],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by the functions in this file.
logger = logging.getLogger(__name__)
def load_config(config_path: str = 'config.yaml') -> dict:
    """Load the application configuration from a YAML file.

    Args:
        config_path: Path of the YAML file to read. Defaults to
            ``config.yaml``, resolved against the current working directory.

    Returns:
        The parsed top-level mapping of the YAML document.

    Raises:
        OSError: If the file cannot be opened or read.
        yaml.YAMLError: If the content is not valid YAML.
        ValueError: If the document's top level is not a mapping (this
            includes an empty file, which parses to ``None``).
    """
    try:
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # default text encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        # Callers index the result by key, so reject anything that is not a
        # mapping here rather than failing later with an unrelated TypeError.
        if not isinstance(config, dict):
            raise ValueError(
                f"Expected a mapping at the top level of {config_path!r}, "
                f"got {type(config).__name__}"
            )
    except Exception as e:
        # Log with traceback at this boundary, then let the caller decide
        # how to surface the failure.
        logger.exception("Error loading config: %s", e)
        raise
    logger.info("Configuration loaded successfully")
    return config
def _save_uploaded_pdf(uploaded_file, target_dir="data"):
    """Write an uploaded file into *target_dir* and return the saved path.

    Only the final path component of the client-supplied file name is used,
    so a name that carries directory parts cannot place the file outside
    *target_dir*.

    Args:
        uploaded_file: Object returned by ``st.file_uploader``; its ``name``
            attribute and ``getbuffer()`` method are used.
        target_dir: Directory the file is written to; created if missing.

    Returns:
        The path of the written file.

    Raises:
        ValueError: If the supplied name has no usable final component.
        OSError: If the directory or file cannot be written.
    """
    # The file name originates from the client, so treat it as untrusted.
    safe_name = os.path.basename(uploaded_file.name)
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid file name: {uploaded_file.name!r}")
    os.makedirs(target_dir, exist_ok=True)
    pdf_path = os.path.join(target_dir, safe_name)
    with open(pdf_path, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return pdf_path


def _answer_query(retriever, generator, query):
    """Retrieve documents for *query*, generate an answer, and render both.

    Failures are reported in the UI and logged; nothing is raised to the
    caller.
    """
    with st.spinner("Generating response..."):
        try:
            retrieved_docs, sources = retriever.retrieve_documents(query)
            if not retrieved_docs:
                st.warning("No relevant documents found.")
                logger.warning("No relevant documents found for query")
                return
            response = generator.generate_response(query, retrieved_docs)
            st.write("**Response:**")
            st.write(response)
            st.write("**Retrieved from:**")
            for source in sources:
                st.write(f"- {source}")
        except Exception as e:
            st.error(f"Error generating response: {e}")
            logger.exception("Error generating response: %s", e)


def main():
    """Streamlit application for RAG over an uploaded PDF.

    Flow: load the configuration, build the loader/embedder/generator,
    accept a PDF upload and save it under ``data/``, reuse a stored vector
    database for that PDF's hash or build one from the PDF's chunks, then
    answer the question entered by the user. Every failure along the way is
    shown with ``st.error``/``st.warning`` and ends the current run.
    """
    st.title("RAG Application for PDF Question Answering")
    st.write("Upload a PDF file and ask questions about its content.")

    # Load configuration
    try:
        config = load_config()
    except Exception as e:
        st.error(f"Failed to load configuration: {e}")
        return

    # Initialize components. A missing config key or a component that fails
    # to construct is reported in the UI like every other step.
    # NOTE(review): Streamlit re-executes the script on each interaction, so
    # these objects are rebuilt every run; consider st.cache_resource if
    # construction is expensive -- confirm against the component classes.
    try:
        data_loader = DataLoader(
            chunk_size=config['chunk_size'],
            chunk_overlap=config['chunk_overlap']
        )
        embedder = Embedder(
            embedding_model=config['embedding_model'],
            vector_db_dir=config['vector_db_dir']
        )
        generator = Generator(
            generation_model=config['generation_model']
        )
        max_retrieved_docs = config['max_retrieved_docs']
    except Exception as e:
        st.error(f"Failed to initialize components: {e}")
        logger.exception("Failed to initialize components: %s", e)
        return

    # PDF upload
    uploaded_file = st.file_uploader("Upload a PDF file", type="pdf")
    if uploaded_file is None:
        st.info("Please upload a PDF file to start.")
        return

    # Save uploaded file
    try:
        pdf_path = _save_uploaded_pdf(uploaded_file)
    except (OSError, ValueError) as e:
        st.error(f"Failed to save uploaded file: {e}")
        logger.exception("Failed to save uploaded file: %s", e)
        return
    st.success(f"Uploaded {uploaded_file.name}")

    try:
        # Check for existing vector database
        pdf_hash = data_loader.get_pdf_hash(pdf_path)
        vector_db = embedder.load_existing_db(pdf_hash)
        # Load and process PDF if no valid database exists
        if vector_db is None:
            with st.spinner("Processing PDF and creating embeddings..."):
                documents = data_loader.load_pdf(pdf_path)
                chunks = data_loader.split_documents(documents)
                if not chunks:
                    st.error("No document chunks available. Check the PDF file.")
                    logger.error("No document chunks to process")
                    return
                vector_db = embedder.store_embeddings(chunks, pdf_hash)
                st.success("PDF processed and embeddings stored.")
    except Exception as e:
        st.error(f"Error processing PDF: {e}")
        logger.exception("Error processing PDF: %s", e)
        return

    # Initialize retriever
    retriever = Retriever(vector_db, max_retrieved_docs)

    # Query input
    query = st.text_input("Enter your question:", placeholder="e.g., What was the revenue in 2023?")
    if not st.button("Submit"):
        return

    # Treat whitespace-only input the same as no input.
    question = query.strip() if query else ""
    if not question:
        st.warning("Please enter a question.")
        return
    _answer_query(retriever, generator, question)
# Run the app only when this file is executed as the entry script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()