diff --git a/backend/ks_search_tool.py b/backend/ks_search_tool.py index 3004a02..b093aff 100644 --- a/backend/ks_search_tool.py +++ b/backend/ks_search_tool.py @@ -10,6 +10,90 @@ from difflib import SequenceMatcher + +def rerank_results_using_metadata(results: List[dict]) -> List[dict]: + """ + Re-rank search results based on metadata signals: + - More recent datasets are preferred + - Higher citation count is preferred + - Trusted sources can be boosted + """ + def score_result(r: dict) -> float: + score = r.get("_score", 1.0) # original search score + meta = r.get("metadata", {}) + + # Example 1: Boost newer datasets + year = meta.get("publication_year") or meta.get("year") + if year: + try: + score += float(year) / 10000 # small boost for recent years + except: + pass + + # Example 2: Boost by citations if available + citations = meta.get("citations") or meta.get("citation_count") + if citations: + try: + score += float(citations) / 100 # small boost + except: + pass + + # Example 3: Boost trusted sources + trusted_sources = ["Allen Brain Atlas", "GENSAT", "EBRAINS"] + source_name = r.get("datasource_name") or meta.get("source") or "" + if any(ts.lower() in str(source_name).lower() for ts in trusted_sources): + score += 0.5 # boost for trusted source + print(f"Result: {r.get('title') or r.get('title_guess')} | Score: {score}") + return score + + return sorted(results, key=score_result, reverse=True) + + + + + + + + + + +# Query Expansion Code +# --- Query Expansion for Neuroscience terms --- +QUERY_SYNONYMS = { + "mouse brain": ["Rattus norvegicus", "somatosensory cortex", "cortex", "hippocampus"], + "memory": ["hippocampus", "synaptic plasticity"], + "hippocampus": ["CA1", "CA3", "dentate gyrus"], + # add more phrases and synonyms as needed +} + + +def expand_query(query: str) -> str: + query_lower = query.lower() + expanded = [query_lower] # original query + + # Keep track of added terms to avoid duplicates + added_terms = set(expanded) + + # Phrase match + for phrase, synonyms in QUERY_SYNONYMS.items(): + if phrase in query_lower: + for syn in synonyms: + if syn not in added_terms: + expanded.append(syn) + added_terms.add(syn) + + # Word match + for word in query_lower.split(): + if word in QUERY_SYNONYMS: + for syn in QUERY_SYNONYMS[word]: + if syn not in added_terms: + expanded.append(syn) + added_terms.add(syn) + + return " ".join(expanded) + + + def tool(args_schema): def decorator(func): func.args_schema = args_schema @@ -334,6 +418,8 @@ async def general_search_async(query: str, top_k: int = 10, enrich_details: bool if enrich_details and normalized_results: print(" -> Using parallel async enrichment...") normalized_results = await enrich_with_dataset_details_async(normalized_results, top_k) + normalized_results = rerank_results_using_metadata(normalized_results) + return {"combined_results": normalized_results[:top_k]} except Exception as e: print(f" -> Error during async general search: {e}") @@ -376,6 +462,7 @@ def general_search(query: str, top_k: int = 10, enrich_details: bool = True) -> print(" -> Enriching results with detailed dataset information (parallel)...") # Use sync enrichment for now - we'll make the whole function async later normalized_results = enrich_with_dataset_details(normalized_results, top_k) + normalized_results = rerank_results_using_metadata(normalized_results) return {"combined_results": normalized_results[:top_k]} except requests.RequestException as e: print(f" -> Error during general search: {e}") @@ -439,7 +526,9 @@ def _perform_search(data_source_id: str, query: str, filters: dict, all_configs: "metadata": src, } ) + out = rerank_results_using_metadata(out) return out + except requests.RequestException as e: print(f" -> Error searching {data_source_id}: {e}") return [] @@ -452,7 +541,7 @@ def smart_knowledge_search( data_source: Optional[str] = None, top_k: int = 10, ) -> dict: - q = query or "*" + q = expand_query(query) if query else "*" if filters: config_path = "datasources_config.json" if os.path.exists(config_path): @@ -463,3 +552,39 @@ def smart_knowledge_search( results = _perform_search(target_id, q, dict(filters), all_configs) return {"combined_results": results[:top_k]} return general_search(q, top_k, enrich_details=True) + + +# Test code + + + +# if __name__ == "__main__": +# test_queries = ["mouse brain", "memory", "hippocampus"] + +# for q in test_queries: +# print(f"Searching for: {q}") +# results = smart_knowledge_search(q, top_k=3) +# for i, r in enumerate(results.get("combined_results", [])): +# print(f" {i+1}. {r.get('title') or r.get('title_guess')} - {r.get('primary_link')}") +# print("-" * 50) + +def test_rerank(): + mock_results = [ + {"title_guess": "Dataset A", "_score": 1.0, "metadata": {"year": 2020, "citations": 5, "source": "GENSAT"}}, + {"title_guess": "Dataset B", "_score": 1.0, "metadata": {"year": 2023, "citations": 2, "source": "OtherSource"}}, + {"title_guess": "Dataset C", "_score": 1.0, "metadata": {"year": 2019, "citations": 10, "source": "EBRAINS"}}, + ] + + print("Before rerank:") + for r in mock_results: + print(r["title_guess"], r["_score"]) + + ranked = rerank_results_using_metadata(mock_results) + + print("\nAfter rerank:") + for r in ranked: + print(r["title_guess"], r["_score"]) + +if __name__ == "__main__": + test_rerank() +