Skip to content

Commit e2a9a7a

Browse files
committed
Use drug pivot table to compute interactions between drugs in JSON
1 parent 9f35568 commit e2a9a7a

3 files changed

Lines changed: 96 additions & 5 deletions

File tree

app/pipeline.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
#!/usr/bin/env python3
22

33
import csv
4-
from typing import List, Optional
4+
import json
55
from loguru import logger
6+
from typing import List, Optional
67

78
from app.clients import query_llama
89
from app.prompts import DETERMINE_TASK_PROMPT, GENERAL_PROMPT, \
910
DENY_PROMPT, TASKS, GRAPH_NEEDED, FIND_SUBSTANCES_PROMPT, GRAPH_PROMPT
1011
from app.gpraph import run_subgraph_builder
12+
from app.substance_mapper import create_json_for_llm
1113

1214
entities_file = "data/entity_name_mapping.json"
1315
substances_file = "data/drugbank/drugbank_vocabulary.csv"
@@ -119,9 +121,10 @@ def process_pipeline(query: str, history: List[str]=[], graph: Optional[object]=
119121
logger.info(f"Found substances by LLM: {substances}. Try to find in the DrugBank vocabulary and bulding a graph")
120122
response['graph'] = run_subgraph_builder(substances)
121123
logger.info(f"Subgraph built with {len(response['graph'].vs)} vertices and {len(response['graph'].es)} edges.")
122-
123-
if response['graph']:
124-
prompt = f"{GENERAL_PROMPT}\n\n{TASKS[discovered_class]}\n{GRAPH_PROMPT}\n{response['graph']}\nTask: {query}"
124+
supplemental_json = create_json_for_llm(substances)
125+
# logger.debug(json.dumps(supplemental_json, indent=4))
126+
if supplemental_json and len(supplemental_json) > 2:
127+
prompt = f"{GENERAL_PROMPT}\n\n{TASKS[discovered_class]}\n{GRAPH_PROMPT}\n{supplemental_json}\nTask: {query}"
125128

126129
else:
127130
prompt = f"{GENERAL_PROMPT}\n\n{TASKS[discovered_class]}\nQuery: {query}"

app/prompts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@
4646
FIND_SUBSTANCES_PROMPT = """Find any substances in the query below.
4747
Return a list of original words from text separated by ',' without spaces after ','. Do not separate one substance with ',' if it takes more than one word."""
4848

49-
GRAPH_PROMPT = "Additional information below includes a subgraph displaying relations between those substances in iGraph format. Use it to form the response as a ground truth"
49+
# Instruction placed before the supplemental data in the LLM prompt.
# The final sentence asks the model not to leak the transport format into its answer.
GRAPH_PROMPT = "Additional information below includes a subgraph displaying relations between those substances in JSON format. Use it to form the response as a ground truth. Do not mention the word JSON in the output"

app/substance_mapper.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import json
2+
from loguru import logger
3+
import time
4+
5+
import pandas as pd
6+
import numpy as np
7+
8+
9+
# Pathway / molecular-function columns that create_json_for_llm removes from
# the per-compound table before it is returned.
columns_pathway_function = [
    'gene_pathways_activated_by_drug',
    'gene_pathways_inhibited_by_drug',
    'molecular_function_activated_by_drug',
    'molecular_function_inhibited_by_drug',
]
15+
16+
def load_json_file(filepath):
    """Parse the UTF-8 encoded JSON file at *filepath* and return its content.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The decoded JSON document (a dict or a list, depending on the file).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)
23+
24+
def process_mapping(ent_mapper):
    """Build the display-name mapper used to relabel the drug pivot table.

    Every value of *ent_mapper* is reduced to the text after its last ``':'``
    (values without a colon are kept whole). Nine fixed relation keys are then
    added, overriding any entry of the same name, so that relation columns get
    descriptive names.

    Args:
        ent_mapper: Mapping of identifier -> string name. It is not modified.

    Returns:
        A new dict containing the shortened names plus the relation overrides.
    """
    ent_mapper_new = {
        key: value.split(':')[-1] for key, value in ent_mapper.items()
    }
    # Relation keys -> descriptive column names. Applied last on purpose so
    # they win over anything the input mapping provides for the same key.
    # NOTE(review): 'assosiated' below is misspelled, but the string is a
    # runtime column name; it is kept as-is because other code may match on it.
    ent_mapper_new.update({
        'drug_disease_minus': 'disease_associated_with_drug',
        'drug_disease_plus': 'disease_cured_by_drug',
        'drug_gene_minus': 'genes_inhibited_or_suppressed_by_drug',
        'drug_gene_plus': 'genes_enhanced_or_activated_by_drug',
        'drug_side_effect_plus': 'side_effects_assosiated_with_drug',
        'gene_pathway_plus': 'gene_pathways_activated_by_drug',
        'gene_pathway_minus': 'gene_pathways_inhibited_by_drug',
        'gene_function_plus': 'molecular_function_activated_by_drug',
        'gene_function_minus': 'molecular_function_inhibited_by_drug',
    })
    return ent_mapper_new
36+
37+
def map_cell(cell, mapper):
    """Translate one DataFrame cell through *mapper*.

    - list / numpy array: every element is looked up individually and a new
      list is returned; elements without a mapping are kept.
    - ``None`` / NaN: returned unchanged.
    - any other scalar: replaced by its mapping when one exists.

    Values that cannot be dictionary keys (e.g. nested lists or dicts) are
    passed through unchanged instead of raising ``TypeError``.

    Args:
        cell: The cell value to translate.
        mapper: Dict of original value -> replacement value.

    Returns:
        The translated value (a list for list/array input).
    """
    def _lookup(value):
        # Unhashable values can never be keys of `mapper`; keep them as they are.
        try:
            return mapper.get(value, value)
        except TypeError:
            return value

    # 1) Containers first: pd.isna() on a list/array yields an array, which
    #    cannot be used in the boolean test of step 2.
    if isinstance(cell, (list, np.ndarray)):
        mapped = [_lookup(item) for item in cell]
        logger.debug(f"List/array mapped: {cell} -> {mapped}")
        return mapped

    # 2) Missing scalars are left untouched.
    if cell is None or pd.isna(cell):
        logger.debug("Missing value encountered, leaving unchanged")
        return cell

    # 3) Plain scalar lookup; only log when something actually changed.
    new_val = _lookup(cell)
    if new_val != cell:
        logger.debug(f"Scalar mapped: {cell} -> {new_val}")
    return new_val
60+
61+
62+
def map_dataframe(df, entity_mapper):
    """Return a copy of *df* with cells, index and columns relabelled.

    Cell values are translated with ``map_cell``; index labels are translated
    and lower-cased; column labels are renamed through *entity_mapper*.
    The input frame is not modified.

    Args:
        df: DataFrame to relabel.
        entity_mapper: Dict of original label/value -> replacement.

    Returns:
        The relabelled DataFrame.
    """
    def _map_index_label(label):
        mapped = entity_mapper.get(label, label)
        # Only strings can be lower-cased; numeric or missing labels are kept
        # as they are instead of raising AttributeError.
        return mapped.lower() if isinstance(mapped, str) else mapped

    # 1) map all cells (DataFrame.map is element-wise and returns a new frame)
    df_mapped = df.map(lambda x: map_cell(x, entity_mapper))
    logger.info("Finished mapping cell values")

    # 2) map row index
    logger.info("Mapping DataFrame index")
    df_mapped.index = [_map_index_label(idx) for idx in df_mapped.index]

    # 3) map column labels; in-place is safe because df_mapped is already a copy
    logger.info("Mapping DataFrame columns")
    df_mapped.rename(columns=entity_mapper, inplace=True)
    return df_mapped
75+
76+
# Module-level setup: the pivot table and the name mapper are loaded and
# relabelled once, at import time, and then reused as the default arguments
# of create_json_for_llm.
start = time.time()
drug_pivot = pd.read_json("data/drug_pivot_full.json", orient="table").set_index("compound")
ent_mapper_new = process_mapping(load_json_file('data/entity_name_mapping.json'))
drug_pivot_mapped = map_dataframe(drug_pivot, ent_mapper_new)
logger.info(f"Loaded substance mapping graph in {time.time() - start}s...")
81+
82+
def create_json_for_llm(compounds: list, drug_pivot=drug_pivot_mapped, mapper=ent_mapper_new) -> dict:
    """Collect the pivot-table rows for *compounds* as a plain dict.

    Compounds that are not present in the index of *drug_pivot* are skipped,
    so one unknown name no longer discards the data of the known ones.
    Columns that are empty for every selected compound are dropped, as are
    the columns listed in ``columns_pathway_function``.

    Args:
        compounds: Compound names to look up; they must match the index
            labels of *drug_pivot* exactly.
        drug_pivot: DataFrame indexed by compound name.
        mapper: Unused; kept so existing callers keep working.

    Returns:
        ``{column: {compound: value}}`` for the selected rows, or ``{}`` when
        nothing matches or the lookup fails.
    """
    try:
        # De-duplicate while keeping the caller's order, then keep only the
        # names the table knows about.
        known = [name for name in dict.fromkeys(compounds or [])
                 if name in drug_pivot.index]
        if not known:
            return {}

        selected = drug_pivot.loc[known].dropna(axis=1, how='all')
        # errors='ignore': dropna above may already have removed some of the
        # pathway columns; that must not turn a valid result into {}.
        selected = selected.drop(columns=columns_pathway_function, errors='ignore')
        return selected.to_dict()
    except Exception:
        # Best-effort helper: the caller falls back to a prompt without
        # supplemental data when {} is returned, but the failure is logged.
        logger.exception(f"Could not build supplemental JSON for compounds: {compounds}")
        return {}
88+

0 commit comments

Comments
 (0)