-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathidcee.py
More file actions
115 lines (94 loc) · 3.93 KB
/
idcee.py
File metadata and controls
115 lines (94 loc) · 3.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import time
import json
import glob
from pathlib import Path
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from load_model import load_llm
# ===== CẤU HÌNH ========
TOP_K = 3
INDEX_DIR = Path("data/index")
EMBEDDING_PATH = "data/encoder/bge-m3"
# ===== LOAD MODEL =====
def load_embedding_model():
print("📥 Đang tải mô hình embedding bge-m3...")
return SentenceTransformer(EMBEDDING_PATH)
def load_all_indexes():
all_indexes, all_mappings = [], []
for faiss_path in glob.glob(str(INDEX_DIR / "**/index.faiss"), recursive=True):
try:
index = faiss.read_index(faiss_path)
mapping_path = Path(faiss_path).parent / "mapping.json"
if not mapping_path.exists():
continue
with open(mapping_path, encoding="utf-8") as f:
mapping = json.load(f)
all_indexes.append(index)
all_mappings.append(mapping)
print(f"✅ Loaded index: {faiss_path}")
except Exception as e:
print(f"❌ Lỗi khi load {faiss_path}: {e}")
return all_indexes, all_mappings
# ===== SEARCH & RETRIEVE =====
def search_similar_chunks(query, model, indexes, mappings, top_k=TOP_K):
query_emb = model.encode(
f"Represent this sentence for searching relevant passages: {query}",
convert_to_numpy=True
)
results = []
for index, texts in zip(indexes, mappings):
D, I = index.search(np.array([query_emb]), top_k)
results += [(score, texts[idx]) for score, idx in zip(D[0], I[0]) if 0 <= idx < len(texts)]
return [text for _, text in sorted(results, key=lambda x: x[0])[:top_k]]
# ===== CONTEXT =====
def limit_context(chunks, max_chars=800):
context = ""
for c in chunks:
if len(context) + len(c) > max_chars:
break
context += c + "\n\n"
return context.strip()
# ===== MAIN LOOP =====
def main():
llm = load_llm()
embed_model = load_embedding_model()
indexes, mappings = load_all_indexes()
print("🤖 IDCee sẵn sàng! Gõ 'exit' để thoát.\n")
while True:
user_input = input("🧑 You: ").strip()
if user_input.lower() in ["exit", "quit"]:
break
chunks = search_similar_chunks(user_input, embed_model, indexes, mappings)
context = limit_context(chunks)
prompt = f"""Bạn là trợ lý AI IDCee. Trả lời ngắn gọn, bằng tiếng Việt và chỉ sử dụng thông tin từ phần Thông tin nội bộ.
👉 Nếu không có thông tin phù hợp, trả lời duy nhất câu sau:
"Tôi không tìm thấy thông tin trong tài liệu nội bộ để trả lời câu hỏi này."
❗ Không được bịa, suy đoán hoặc thêm nội dung ngoài dữ liệu cung cấp. Chỉ sử dụng nội dung chứa từ khóa: "{user_input}"
### Câu hỏi:
{user_input}
### Thông tin nội bộ:
{context}
### Trả lời:
"""
start = time.time()
try:
result = llm(prompt, max_tokens=192)
if isinstance(result, str):
reply = result.strip()
elif hasattr(result, "__iter__") and not isinstance(result, dict):
reply = "".join(chunk for chunk in result).strip()
elif isinstance(result, dict) and "choices" in result:
reply = result["choices"][0]["text"].strip()
else:
reply = "(Không thể xử lý phản hồi)"
elapsed = time.time() - start
if not reply:
print("🤖 IDCee: (Không thể tạo câu trả lời)")
else:
print(f"🤖 IDCee: {reply}")
print(f"⏱️ Thời gian phản hồi: {elapsed:.2f} giây\n")
except Exception as e:
print(f"❌ Lỗi infer: {e}")
if __name__ == "__main__":
main()