Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.env
77 changes: 69 additions & 8 deletions server/clustering/cluster_scoring.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,87 @@
from typing import Dict, List

def compute_entity_similarity(a: Dict, b: Dict) -> float:
    """
    Compute entity similarity between two articles with primary/secondary
    importance weighting.

    Args:
        a: Entity dict with keys primary_subject, secondary_subject,
           primary_orgs, secondary_orgs, primary_event, secondary_event.
        b: Entity dict with the same structure.

    Returns:
        Similarity score (0.0 to 1.0+). Org overlap adds a fixed increment
        per shared org, so the score is unbounded above; callers are
        expected to normalize (see compute_final_score).
    """
    score = 0.0

    # --- Subject matching (mutually exclusive branches, best match wins) ---
    # Primary <-> primary match carries the highest weight.
    if a.get("primary_subject") and a["primary_subject"] == b.get("primary_subject"):
        score += 1.0
    # Secondary <-> secondary match, lower weight.
    elif a.get("secondary_subject") and a["secondary_subject"] == b.get("secondary_subject"):
        score += 0.3
    # Cross-match: one article's primary equals the other's secondary.
    elif (a.get("primary_subject") and a["primary_subject"] == b.get("secondary_subject")) or \
         (a.get("secondary_subject") and a["secondary_subject"] == b.get("primary_subject")):
        score += 0.2

    # --- Event matching (same tiering as subjects, smaller weights) ---
    if a.get("primary_event") and a["primary_event"] == b.get("primary_event"):
        score += 0.5
    elif a.get("secondary_event") and a["secondary_event"] == b.get("secondary_event"):
        score += 0.2
    elif (a.get("primary_event") and a["primary_event"] == b.get("secondary_event")) or \
         (a.get("secondary_event") and a["secondary_event"] == b.get("primary_event")):
        score += 0.15

    # --- Organization matching with primary/secondary distinction ---
    # `or []` also guards against an explicit None value, which a plain
    # dict.get default would not cover.
    primary_orgs_a = set(a.get("primary_orgs") or [])
    primary_orgs_b = set(b.get("primary_orgs") or [])
    secondary_orgs_a = set(a.get("secondary_orgs") or [])
    secondary_orgs_b = set(b.get("secondary_orgs") or [])

    # Each shared org contributes its tier's weight; empty intersections
    # contribute 0, so no emptiness guards are needed.
    score += 0.3 * len(primary_orgs_a & primary_orgs_b)      # primary <-> primary
    score += 0.1 * len(secondary_orgs_a & secondary_orgs_b)  # secondary <-> secondary
    score += 0.1 * len(primary_orgs_a & secondary_orgs_b)    # cross, either direction
    score += 0.1 * len(secondary_orgs_a & primary_orgs_b)

    return score

def compute_final_score(
    semantic_score: float,
    entity_score: float,
    cross_score: float = 0.5,
    w_sem: float = 0.3,
    w_ent: float = 0.4,
    w_cross: float = 0.3,
) -> float:
    """
    Compute the final clustering score combining multiple signals.

    Args:
        semantic_score: Embedding-based similarity (0.0-1.0).
        entity_score: Entity matching score (0.0-1.0+, see
            compute_entity_similarity).
        cross_score: Cross-encoder score (0.0-1.0); defaults to the
            neutral value 0.5 when no cross-encoder result is available.
        w_sem: Weight for semantic similarity.
        w_ent: Weight for (normalized) entity similarity.
        w_cross: Weight for the cross-encoder score.

    Returns:
        Final combined score.
    """
    # Entity scores can exceed 1.0 (per-org increments are unbounded);
    # clamp into [0, 1] so the weights stay comparable across signals.
    normalized_entity = min(entity_score / 2.0, 1.0)

    return (w_sem * semantic_score +
            w_ent * normalized_entity +
            w_cross * cross_score)
124 changes: 124 additions & 0 deletions server/cross_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import logging
from typing import Dict, List, Any, Tuple
from sentence_transformers import CrossEncoder
import numpy as np

logger = logging.getLogger(__name__)


class CrossEncoderManager:
    """Manages a cross-encoder model for computing semantic relevance scores
    between articles.

    If the model fails to load, or scoring raises, every method degrades
    gracefully by returning a neutral score instead of propagating the error.
    """

    # Neutral fallback returned whenever the model is unavailable or fails.
    _NEUTRAL_SCORE = 0.5

    def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
        """
        Initialize the cross encoder.

        Args:
            model_name: HuggingFace model identifier for cross-encoder
                Default: ms-marco-MiniLM-L-6-v2 (efficient and accurate for relevance)
        """
        self.model_name = model_name
        try:
            self.model = CrossEncoder(model_name)
            logger.info(f"Cross-encoder loaded: {model_name}")
        except Exception as e:
            # Degrade gracefully: scoring methods will return neutral values.
            logger.error(f"Failed to load cross-encoder: {e}")
            self.model = None

    def compute_relevance_score(
        self,
        query_article: Dict[str, Any],
        candidate_article: Dict[str, Any]
    ) -> float:
        """
        Compute semantic relevance score between two articles.

        Args:
            query_article: Source article dict with title, description, full_content
            candidate_article: Target article dict for comparison

        Returns:
            Relevance score between 0 and 1 (0.5 when the model is
            unavailable or prediction fails).
        """
        if self.model is None:
            logger.warning("Cross-encoder model not loaded, returning 0.5")
            return self._NEUTRAL_SCORE

        try:
            query_text = self._build_article_text(query_article)
            candidate_text = self._build_article_text(candidate_article)
            scores = self.model.predict([[query_text, candidate_text]])
            # Raw cross-encoder outputs are unbounded logits; squash to (0, 1).
            return float(self._sigmoid(scores[0]))
        except Exception as e:
            logger.error(f"Error computing relevance score: {e}")
            return self._NEUTRAL_SCORE

    def compute_batch_relevance_scores(
        self,
        query_article: Dict[str, Any],
        candidate_articles: List[Dict[str, Any]]
    ) -> List[float]:
        """
        Compute relevance scores between one query article and multiple candidates.

        Batched prediction: one model call for all pairs instead of one per
        candidate.

        Args:
            query_article: Source article
            candidate_articles: List of candidate articles

        Returns:
            List of relevance scores, one per candidate, in input order
            (all 0.5 when the model is unavailable or prediction fails).
        """
        if self.model is None or not candidate_articles:
            # Empty candidate list yields [] either way.
            return [self._NEUTRAL_SCORE] * len(candidate_articles)

        try:
            query_text = self._build_article_text(query_article)
            pairs = [
                [query_text, self._build_article_text(candidate)]
                for candidate in candidate_articles
            ]
            scores = self.model.predict(pairs)
            return [float(self._sigmoid(score)) for score in scores]
        except Exception as e:
            logger.error(f"Error computing batch relevance scores: {e}")
            return [self._NEUTRAL_SCORE] * len(candidate_articles)

    def _build_article_text(self, article: Dict[str, Any]) -> str:
        """
        Build a text representation of an article for the cross-encoder.

        Tolerates fields that are missing OR explicitly None (a plain
        dict.get default only covers the missing-key case and would raise
        AttributeError on None.strip()).

        Args:
            article: Article dictionary

        Returns:
            Title and description joined by a single space; whichever one is
            non-empty if only one is present; "" if neither is.
        """
        title = (article.get("title") or "").strip()
        description = (article.get("description") or "").strip()
        # Joining only the non-empty parts reproduces all four cases
        # (both / title only / description only / neither) in one expression.
        return " ".join(part for part in (title, description) if part)

    @staticmethod
    def _sigmoid(x: float) -> float:
        """Apply a numerically-safe sigmoid to normalize cross-encoder logits."""
        import math
        try:
            return 1.0 / (1.0 + math.exp(-x))
        except OverflowError:
            # exp(-x) overflows for large |x|: saturate to the limit value.
            return 0.0 if x < 0 else 1.0
Loading
Loading