-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild_embeddings.py
More file actions
69 lines (61 loc) · 2.67 KB
/
build_embeddings.py
File metadata and controls
69 lines (61 loc) · 2.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
build_embeddings.py — builds FAISS index from enriched corpus.
No parallelism, sequential GPU encoding with progress bar.
"""
import json
import logging
import sys
from pathlib import Path
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
# Module-wide logger; every emitted record is prefixed with a timestamp.
logging.basicConfig(format="%(asctime)s %(message)s", level=logging.INFO)
log = logging.getLogger(__name__)

# Input corpus (one JSON object per line) and the two artifacts written by main().
CORPUS: Path = Path("data/corpus/corpus/protocols_corpus_enriched.jsonl")
FAISS_OUT: Path = Path("data/faiss_index.bin")
EMB_OUT: Path = Path("data/e5_embeddings_cache.npy")

# Embedding model and encoding parameters.
MODEL_NAME: str = "intfloat/multilingual-e5-large"
BATCH: int = 16  # passages per encode() batch
MAX_DIAG: int = 500  # leading characters of each record's "text" kept in its passage
def _format_passage(record, max_chars):
    """Render one corpus record as an E5 ``passage:`` input string.

    The layout ``passage: {title}. МКБ-10: {codes}. {text}`` is kept exactly as
    before so new vectors stay comparable with previously built indexes.

    Args:
        record: One parsed JSONL object. ``title``, ``icd_codes`` and ``text``
            are read; each may be missing or ``null``.
        max_chars: Number of leading characters of ``text`` to keep.

    Returns:
        The formatted passage string.
    """
    # `or ""` (rather than a .get default) also covers keys present with null.
    title = (record.get("title") or "").strip()
    # str() so numeric codes do not break the join.
    codes = ", ".join(str(code) for code in record.get("icd_codes") or [])
    text = (record.get("text") or "")[:max_chars]
    return f"passage: {title}. МКБ-10: {codes}. {text}"


def _load_passages(path, max_chars):
    """Read a JSONL corpus and return one formatted passage per non-blank line.

    Args:
        path: Location of the JSON Lines corpus.
        max_chars: Forwarded to ``_format_passage`` to truncate each text.

    Returns:
        List of passage strings, in file order.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON (the offending line
            number is logged first).
    """
    passages = []
    # utf-8-sig reads plain UTF-8 unchanged and also strips a leading BOM,
    # which json.loads would otherwise reject on the first line.
    with open(path, "r", encoding="utf-8-sig") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                log.error("Invalid JSON on line %d of %s", lineno, path)
                raise
            passages.append(_format_passage(record, max_chars))
    return passages


def main():
    """Build the FAISS inner-product index and embeddings cache from CORPUS.

    Reads the JSONL corpus into E5 ``passage:`` strings, encodes them with
    MODEL_NAME (FP16 on CUDA, FP32 on CPU), writes an ``IndexFlatIP`` to
    FAISS_OUT and the float32 embedding matrix to EMB_OUT. Row i of both
    outputs corresponds to the i-th non-blank corpus line.

    Exits via ``sys.exit`` with a message when the corpus yields no passages.
    """
    # ── Load corpus ──────────────────────────────────────────────────────
    log.info("Reading corpus: %s", CORPUS)
    passages = _load_passages(CORPUS, MAX_DIAG)
    log.info("Loaded %d passages", len(passages))
    if not passages:
        # Without this, the failure would surface later as an opaque
        # IndexError on embeddings.shape[1].
        sys.exit(f"No passages found in {CORPUS}; nothing to index.")

    # ── Load model (FP16 only on GPU) ────────────────────────────────────
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    precision = "FP16" if use_cuda else "FP32"
    log.info("Loading model %s on %s in %s...", MODEL_NAME, device, precision)
    model = SentenceTransformer(MODEL_NAME, device=device)
    if use_cuda:
        # Half precision saves GPU memory and time. On CPU, FP16 is slow and
        # some ops are unsupported in older PyTorch, so the model stays FP32.
        model.half()

    # ── Encode sequentially (encode() batches internally) ────────────────
    log.info("Encoding %d passages (batch=%d)...", len(passages), BATCH)
    embeddings = np.ascontiguousarray(
        model.encode(
            passages,
            batch_size=BATCH,
            normalize_embeddings=True,
            convert_to_numpy=True,
            show_progress_bar=True,
        ),
        dtype=np.float32,
    )
    # Renormalize in float32. The normalization above ran in the model's dtype
    # (FP16 on GPU), so norms can drift slightly from 1, which would skew the
    # inner-product (cosine) scores. normalize_L2 needs C-contiguous float32.
    faiss.normalize_L2(embeddings)
    log.info("Embeddings shape: %s", embeddings.shape)

    # ── Build + save FAISS index ─────────────────────────────────────────
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)
    # Ensure output directories exist in case the output paths are changed.
    FAISS_OUT.parent.mkdir(parents=True, exist_ok=True)
    EMB_OUT.parent.mkdir(parents=True, exist_ok=True)
    faiss.write_index(index, str(FAISS_OUT))
    np.save(str(EMB_OUT), embeddings)
    log.info("FAISS index saved to %s (%d vectors)", FAISS_OUT, index.ntotal)
    log.info("Embeddings cache saved to %s", EMB_OUT)
if __name__ == "__main__":
    # Run the build only when executed as a script, not when imported.
    main()