Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion backend/ks_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,90 @@
from difflib import SequenceMatcher



def rerank_results_using_metadata(results: List[dict]) -> List[dict]:
    """
    Re-rank search results based on metadata signals.

    Signals (each adds a small bonus to the original search score):
      - publication year: newer datasets score slightly higher
      - citation count: more-cited datasets score slightly higher
      - trusted source: a fixed +0.5 boost for known-good providers

    Args:
        results: search hits; each dict may carry an original ``_score``
            and a ``metadata`` dict with year / citation / source fields.

    Returns:
        A new list sorted by adjusted score, highest first. The input list
        and its dicts are not modified (``_score`` is NOT updated in place).
    """
    # Hoisted and pre-lowercased so the tuple is not rebuilt and
    # re-normalized for every result scored.
    trusted_sources = ("allen brain atlas", "gensat", "ebrains")

    def score_result(r: dict) -> float:
        score = r.get("_score", 1.0)  # original search score
        meta = r.get("metadata") or {}  # tolerate metadata missing OR None

        # 1) Boost newer datasets (year/10000 keeps the bonus small).
        year = meta.get("publication_year") or meta.get("year")
        if year:
            try:
                score += float(year) / 10000
            except (TypeError, ValueError):
                # Non-numeric year — skip this signal rather than fail.
                pass

        # 2) Boost by citation count when available.
        citations = meta.get("citations") or meta.get("citation_count")
        if citations:
            try:
                score += float(citations) / 100  # small boost
            except (TypeError, ValueError):
                pass

        # 3) Fixed boost for trusted neuroscience sources.
        source_name = r.get("datasource_name") or meta.get("source") or ""
        if any(ts in str(source_name).lower() for ts in trusted_sources):
            score += 0.5
        return score

    return sorted(results, key=score_result, reverse=True)










# --- Query Expansion for Neuroscience terms ---
# Maps a lowercase query phrase (or single word) to related terms that are
# appended to the query before searching.
QUERY_SYNONYMS = {
    # NOTE(review): "Rattus norvegicus" is the brown rat, not a mouse
    # (Mus musculus) — confirm this cross-species expansion is intended.
    "mouse brain": ["Rattus norvegicus", "somatosensory cortex", "cortex", "hippocampus"],
    "memory": ["hippocampus", "synaptic plasticity"],
    "hippocampus": ["CA1", "CA3", "dentate gyrus"],
    # add more phrases and synonyms as needed
}


def expand_query(query: str) -> str:
    """Return *query* (lowercased) followed by known synonym terms.

    Two passes over QUERY_SYNONYMS:
      1. phrase pass — every key that occurs as a substring of the query
      2. word pass  — every whitespace-separated token that is itself a key

    Synonyms are appended in dictionary/insertion order; duplicates are
    added only once. The result is a single space-joined string.
    """
    normalized = query.lower()
    expanded_terms = [normalized]  # original query always comes first
    seen = {normalized}

    def _collect(candidates):
        # Append each not-yet-seen synonym, preserving encounter order.
        for term in candidates:
            if term not in seen:
                seen.add(term)
                expanded_terms.append(term)

    for phrase, synonyms in QUERY_SYNONYMS.items():
        if phrase in normalized:
            _collect(synonyms)

    for token in normalized.split():
        if token in QUERY_SYNONYMS:
            _collect(QUERY_SYNONYMS[token])

    return " ".join(expanded_terms)



def tool(args_schema):
def decorator(func):
func.args_schema = args_schema
Expand Down Expand Up @@ -334,6 +418,8 @@ async def general_search_async(query: str, top_k: int = 10, enrich_details: bool
if enrich_details and normalized_results:
print(" -> Using parallel async enrichment...")
normalized_results = await enrich_with_dataset_details_async(normalized_results, top_k)
normalized_results = rerank_results_using_metadata(normalized_results)

return {"combined_results": normalized_results[:top_k]}
except Exception as e:
print(f" -> Error during async general search: {e}")
Expand Down Expand Up @@ -376,6 +462,7 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) ->
print(" -> Enriching results with detailed dataset information (parallel)...")
# Use sync enrichment for now - we'll make the whole function async later
normalized_results = enrich_with_dataset_details(normalized_results, top_k)
normalized_results = rerank_results_using_metadata(normalized_results)
return {"combined_results": normalized_results[:top_k]}
except requests.RequestException as e:
print(f" -> Error during general search: {e}")
Expand Down Expand Up @@ -439,7 +526,9 @@ def _perform_search(data_source_id: str, query: str, filters: dict, all_configs:
"metadata": src,
}
)
out = rerank_results_using_metadata(out)
return out

except requests.RequestException as e:
print(f" -> Error searching {data_source_id}: {e}")
return []
Expand All @@ -452,7 +541,7 @@ def smart_knowledge_search(
data_source: Optional[str] = None,
top_k: int = 10,
) -> dict:
q = query or "*"
q = expand_query(query) if query else "*"
if filters:
config_path = "datasources_config.json"
if os.path.exists(config_path):
Expand All @@ -463,3 +552,39 @@ def smart_knowledge_search(
results = _perform_search(target_id, q, dict(filters), all_configs)
return {"combined_results": results[:top_k]}
return general_search(q, top_k, enrich_details=True)


# Test code



# if __name__ == "__main__":
# test_queries = ["mouse brain", "memory", "hippocampus"]

# for q in test_queries:
# print(f"Searching for: {q}")
# results = smart_knowledge_search(q, top_k=3)
# for i, r in enumerate(results.get("combined_results", [])):
# print(f" {i+1}. {r.get('title') or r.get('title_guess')} - {r.get('primary_link')}")
# print("-" * 50)

def test_rerank() -> None:
    """Smoke-test rerank_results_using_metadata with a tiny fixture.

    Expected ordering (base score 1.0 each):
      C: +2019/10000 +10/100 +0.5 (EBRAINS, trusted) = 1.8019 -> 1st
      A: +2020/10000 +5/100  +0.5 (GENSAT, trusted)  = 1.7520 -> 2nd
      B: +2023/10000 +2/100  (untrusted source)      = 1.2223 -> 3rd

    Raises:
        AssertionError: if the rerank produces any other ordering.
    """
    mock_results = [
        {"title_guess": "Dataset A", "_score": 1.0, "metadata": {"year": 2020, "citations": 5, "source": "GENSAT"}},
        {"title_guess": "Dataset B", "_score": 1.0, "metadata": {"year": 2023, "citations": 2, "source": "OtherSource"}},
        {"title_guess": "Dataset C", "_score": 1.0, "metadata": {"year": 2019, "citations": 10, "source": "EBRAINS"}},
    ]

    ranked = rerank_results_using_metadata(mock_results)
    titles = [r["title_guess"] for r in ranked]

    # NOTE: the rerank computes scores inside its sort key only and never
    # writes them back to "_score", so we assert on ordering — printing
    # "_score" after the rerank (as the old version did) shows 1.0 for all.
    assert titles == ["Dataset C", "Dataset A", "Dataset B"], titles
    print("test_rerank passed — order:", titles)


if __name__ == "__main__":
    test_rerank()