-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchatbot.py
More file actions
74 lines (61 loc) · 2.69 KB
/
chatbot.py
File metadata and controls
74 lines (61 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# chatbot.py
import os
import time
import textwrap
import torch
from langdetect import detect
from contextlib import contextmanager
from sentence_transformers import CrossEncoder
from llama_index.core import StorageContext, load_index_from_storage
from config import Config
from load_model import load_model_and_tokenizer, setup_llm, setup_embed_model
from clean_response import clean_response
HF_TOKEN = os.getenv("HF_TOKEN") or (lambda: (_ for _ in ()).throw(ValueError("❌ Chưa có HF_TOKEN!")))()
@contextmanager
def timer():
start = time.time()
yield
print(f"⏱️ Thời gian phản hồi: {time.time() - start:.2f} giây")
def setup_index(storage_dir):
storage_context = StorageContext.from_defaults(persist_dir=storage_dir)
return load_index_from_storage(storage_context)
# Thiết lập
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"⚙️ Đang sử dụng thiết bị: {device.upper()} - {torch.cuda.get_device_name(0) if device == 'cuda' else 'CPU'}")
try:
setup_embed_model()
model, tokenizer = load_model_and_tokenizer(HF_TOKEN)
llm = setup_llm(model, tokenizer)
index = setup_index(Config.STORAGE_DIR)
query_engine = index.as_query_engine(llm=llm, similarity_top_k=3)
reranker = CrossEncoder("BAAI/bge-reranker-base")
except Exception as e:
print(f"❌ Lỗi khởi tạo: {e}")
exit(1)
print("🤖 Chatbot IDC đã sẵn sàng. Gõ 'exit' để thoát.\n")
# Vòng lặp chính
while True:
user_input = input("👤 Bạn: ")
if user_input.lower() in ["exit", "quit"]:
print("👋 Tạm biệt!")
break
lang = detect(user_input)
print(f"🌐 Ngôn ngữ phát hiện: {lang}")
print("🤔 Chatbot đang suy nghĩ...")
with timer():
results = query_engine.retrieve(user_input)
pairs = [(user_input, r.node.text) for r in results]
scores = reranker.predict(pairs)
top_node = results[scores.argmax()].node.text
system_prompt = f"""
Bạn là trợ lý AI trả lời câu hỏi dựa trên tài liệu nội bộ. Trả lời chính xác theo dữ liệu.
Nếu không chắc chắn, hãy nói rõ là chưa tìm thấy trong dữ liệu.
Câu hỏi: {user_input}
"""
response = llm.complete(top_node + "\n" + system_prompt)
cleaned_text = clean_response(response)
wrapped_text = textwrap.fill(cleaned_text, width=100)
print("📄 Đoạn văn mô hình đang dùng để trả lời:")
print("-", top_node[:300])
prefix = "(Tiếng Việt)" if lang == "vi" else "(English)"
print("💬", prefix, wrapped_text)