forked from nakulbh/Meve-framework
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathphase2_verification.py
More file actions
69 lines (55 loc) · 2.72 KB
/
phase2_verification.py
File metadata and controls
69 lines (55 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# phase_2_verification.py
from meve_data import ContextChunk, Query, MeVeConfig
from typing import List
from sentence_transformers import CrossEncoder
import torch
# Load cross-encoder model for relevance scoring.
# NOTE: this runs at import time and may download the model on first use;
# importing this module therefore has a noticeable startup cost.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
def get_relevance_score(query_text: str, chunk_content: str) -> float:
    """Score the relevance of a chunk to a query with the cross-encoder.

    The raw model output (a logit) is squashed through a sigmoid so the
    returned score is always in [0, 1]. On any model error, falls back to
    a simple keyword-overlap heuristic instead of raising.

    Args:
        query_text: The user query.
        chunk_content: The candidate chunk's text.

    Returns:
        A relevance score in [0, 1].
    """
    try:
        # Cross-encoders take query-document pairs and output relevance scores.
        raw = cross_encoder.predict([(query_text, chunk_content)])
        # predict() may return a torch.Tensor, a numpy array, a list/tuple,
        # or a bare scalar depending on configuration. The original code
        # missed the numpy-array case (float() on a 1-element ndarray is
        # deprecated/raises in recent NumPy), so handle any sequence-like
        # return by taking its first element.
        if isinstance(raw, torch.Tensor):
            score = raw.reshape(-1)[0].item() if raw.dim() > 0 else raw.item()
        elif hasattr(raw, "__len__"):  # list, tuple, or numpy array
            score = float(raw[0])
        else:
            score = float(raw)
        # Apply sigmoid to normalize the logit into [0, 1].
        return torch.sigmoid(torch.tensor(score)).item()
    except Exception as e:
        print(f"Error in cross-encoder scoring: {e}")
        # Fallback to simple similarity check
        return simulate_cross_encoder_fallback(query_text, chunk_content)
def simulate_cross_encoder_fallback(query_text: str, chunk_content: str) -> float:
    """Cheap keyword-overlap stand-in for the cross-encoder.

    Computes the fraction of distinct (case-folded) query words that also
    appear in the chunk. An empty query yields 0.0; the result is capped
    at 1.0.
    """
    terms = set(query_text.lower().split())
    if not terms:
        return 0.0
    shared = terms & set(chunk_content.lower().split())
    return min(len(shared) / len(terms), 1.0)
def execute_phase_2(query: Query, initial_candidates: List[ContextChunk], config: MeVeConfig) -> List[ContextChunk]:
    """Phase 2: Relevance Verification (Cross-Encoder).

    Scores every candidate chunk against the query and keeps only those
    whose relevance score meets the configured threshold (tau).

    Args:
        query: The query whose text drives scoring.
        initial_candidates: Candidate chunks from the previous phase.
        config: Pipeline configuration; ``tau_relevance`` is the cutoff.

    Returns:
        The verified subset of ``initial_candidates`` (C_ver), each with
        its ``relevance_score`` attribute populated.
    """
    print(f"--- Phase 2: Relevance Verification (Tau={config.tau_relevance}) ---")
    if not initial_candidates:
        print("No initial candidates to verify.")
        return []

    threshold = config.tau_relevance
    verified: List[ContextChunk] = []

    # Score each candidate and keep those at or above the threshold.
    for candidate in initial_candidates:
        relevance = get_relevance_score(query.text, candidate.content)
        candidate.relevance_score = relevance
        print(f"Chunk {candidate.doc_id}: relevance_score={relevance:.3f}")
        if relevance >= threshold:
            verified.append(candidate)
            print(f" → VERIFIED (score >= {config.tau_relevance})")
        else:
            print(f" → FILTERED OUT (score < {config.tau_relevance})")

    print(f"Verified {len(verified)} out of {len(initial_candidates)} chunks (C_ver).")
    return verified