Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions config/genefunction/gfun_vis.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
output_dir: experiments/
output_dir: /pasteur/appa/homes/claudy/2026-BPN-paper-data/gene-function-prediction/


dataset:
class: biomedical
path: data/PC_KEGG_0928
path: /pasteur/appa/homes/claudy/BioPathNet/data/PC_KEGG_0602
include_factgraph: yes

task:
Expand Down
4 changes: 2 additions & 2 deletions config/mock/mockdata_inductive.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
output_dir: experiments/
output_dir: /Users/claudy/work/projects/BioPathNet/experiments/

dataset:
class: BiomedicalInductive
path: data/mock_inductive
path: /Users/claudy/work/projects/BioPathNet/data/mock_inductive

task:
class: KnowledgeGraphCompletionBiomedInductive
Expand Down
6 changes: 3 additions & 3 deletions config/mock/mockdata_inductive_pred.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
output_dir: experiments/
output_dir: /Users/claudy/work/projects/BioPathNet/experiments/

dataset:
class: BiomedicalInductive
path: data/mock_inductive
path: /Users/claudy/work/projects/BioPathNet/data/mock_inductive

task:
class: KnowledgeGraphCompletionBiomedInductive
Expand Down Expand Up @@ -37,4 +37,4 @@ train:
num_epoch: 5

metric: mrr
checkpoint: {{checkpoint}}
checkpoint: {{checkpoint}}
52 changes: 52 additions & 0 deletions config/mock/mockdata_inductive_vis.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
output_dir: /Users/claudy/work/projects/BioPathNet/experiments/

dataset:
class: BiomedicalInductive
path: /Users/claudy/work/projects/BioPathNet/data/mock_inductive

task:
class: KnowledgeGraphCompletionBiomedInductive
model:
class: NBFNet
input_dim: 32
hidden_dims: [32, 32, 32, 32, 32, 32]
message_func: distmult
aggregate_func: pna
short_cut: yes
layer_norm: yes
dependent: yes
symmetric: no
criterion: bce
num_negative: 32
strict_negative: yes
adversarial_temperature: 0.5
sample_weight: no
heterogeneous_negative: yes
heterogeneous_evaluation: yes
full_batch_eval: no
<<<<<<< HEAD
=======
remove_pos: no
>>>>>>> 776380a (feat(config files): adds config files for inductive models on mock data and missing visualisation for gene function prediction.)

optimizer:
class: Adam
lr: 5.0e-3

engine:
gpus: {{ gpus }}
batch_size: 4

<<<<<<< HEAD
=======

>>>>>>> 776380a (feat(config files): adds config files for inductive models on mock data and missing visualisation for gene function prediction.)
train:
num_epoch: 5

metric: mrr
<<<<<<< HEAD
=======

>>>>>>> 776380a (feat(config files): adds config files for inductive models on mock data and missing visualisation for gene function prediction.)
checkpoint: {{checkpoint}}
8 changes: 4 additions & 4 deletions script/eval_and_predict_inductive.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,10 @@ def load_dataset_and_solver(cfg, test_file):

if __name__ == "__main__":
args, vars = util.parse_args()
logger = util.get_root_logger()
cfg = util.load_config(args.config, context=vars)
working_dir = util.create_working_directory(cfg)
print(working_dir)
logger.warning(f"working directory = {working_dir}")
# get entity names
if 'entity_files' in cfg.dataset:
vfile = cfg.dataset.entity_files[1]
Expand All @@ -180,8 +181,6 @@ def load_dataset_and_solver(cfg, test_file):

torch.manual_seed(args.seed + comm.get_rank())

logger = util.get_root_logger()
logger.warning("Working directory: %s" % working_dir)
if comm.get_rank() == 0:
logger.warning("Config file: %s" % args.config)
#logger.warning(pprint.pformat(cfg))
Expand All @@ -192,6 +191,7 @@ def load_dataset_and_solver(cfg, test_file):

# Prediction phase
cfg_pred, solver, _dataset, entity_vocab, relation_vocab = load_dataset_and_solver(cfg, 'test_pred.txt')
logger.warning(f"******* solver = {solver}")
test(cfg_pred, solver)

logger.warning(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
Expand All @@ -206,4 +206,4 @@ def load_dataset_and_solver(cfg, test_file):
logger.warning(f"Saving to file {os.path.join(working_dir, 'predictions.csv')}")
df.to_csv(os.path.join(working_dir, "predictions.csv"), index=False, sep="\t")
logger.warning("Done")
logger.warning("------------------------------")
logger.warning("------------------------------")
157 changes: 157 additions & 0 deletions script/visualize_inductive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import os
import sys
import pprint

import torch

from torchdrug import core
from torchdrug.utils import comm

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from biopathnet import dataset, layer, model, task, util




def solver_load(checkpoint, load_optimizer=True):

if comm.get_rank() == 0:
logger.warning("Load checkpoint from %s" % checkpoint)
checkpoint = os.path.expanduser(checkpoint)
state = torch.load(checkpoint, map_location=solver.device)
# some issues with loading back the graphs if present
# remove
state["model"].pop("fact_graph", 0)
state["model"].pop("fact_graph_supervision", 0)
state["model"].pop("graph", 0)
state["model"].pop("train_graph", 0)
state["model"].pop("valid_graph", 0)
state["model"].pop("test_graph", 0)
state["model"].pop("full_valid_graph", 0)
state["model"].pop("full_test_graph", 0)
# load without
solver.model.load_state_dict(state["model"], strict=False)


if load_optimizer:
solver.optimizer.load_state_dict(state["optimizer"])
for state in solver.optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(solver.device)

comm.synchronize()

def build_solver(cfg):
cfg.task.model.num_relation = _dataset.num_relation
_task = core.Configurable.load_config_dict(cfg.task)
cfg.optimizer.params = _task.parameters()
optimizer = core.Configurable.load_config_dict(cfg.optimizer)
if "scheduler" in cfg:
cfg.scheduler.optimizer = optimizer
scheduler = core.Configurable.load_config_dict(cfg.scheduler)
else:
scheduler = None
return core.Engine(_task, train_set, valid_set, test_set, optimizer, scheduler, **cfg.engine)



def load_vocab(dataset):
entity_mapping = {}
logger.debug(f"###### vocab_file = {vocab_file}")

entity_mapping = {}
with open(vocab_file, "r") as fin:
for line in fin:
k, v = line.strip().split("\t")
entity_mapping[k] = v

dataset_entity_vocab = list(set(dataset.train_entity_vocab + dataset.test_entity_vocab))
entity_vocab = [entity_mapping[t] for t in dataset_entity_vocab]
relation_vocab = ["%s (%d)" % (t[t.rfind("/") + 1:].replace("_", " "), i)
for i, t in enumerate(dataset.relation_vocab)]

# with open(vocab_file, "r") as fin:
# for line in fin:
# k, v = line.strip().split("\t")
# entity_mapping[k] = v
# logger.debug(f"###### entity_mapping = {entity_mapping}")
# entity_vocab = [entity_mapping[t] for t in dataset.entity_vocab]
# logger.debug(f"###### entity_vocab = {entity_vocab}")
# relation_vocab = ["%s (%d)" % (t[t.rfind("/") + 1:].replace("_", " "), i)
# for i, t in enumerate(dataset.relation_vocab)]
# logger.debug(f"###### relation_vocab = {relation_vocab}")

return entity_vocab, relation_vocab

def visualize_path(solver, triplet, entity_vocab, relation_vocab):
num_relation = len(relation_vocab)
h, t, r = triplet.tolist()
triplet = torch.as_tensor([[h, t, r]], device=solver.device)
inverse = torch.as_tensor([[t, h, r + num_relation]], device=solver.device)
solver.model.split = "test"
solver.model.eval()
pred, (mask, target) = solver.model.predict_and_target(triplet)
pos_pred = pred.gather(-1, target.unsqueeze(-1))
rankings = torch.sum((pos_pred <= pred) & mask, dim=-1) + 1
rankings = rankings.squeeze(0)

logger.warning("")
samples = (triplet, inverse)
for sample, ranking in zip(samples, rankings):
h, t, r = sample.squeeze(0).tolist()
h_name = entity_vocab[h]
t_name = entity_vocab[t]
r_name = relation_vocab[r % num_relation]
if r >= num_relation:
r_name += "^(-1)"
logger.warning(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
logger.warning("rank(%s | %s, %s) = %g" % (t_name, h_name, r_name, ranking))

paths, weights = solver.model.visualize(sample)
for path, weight in zip(paths, weights):
triplets = []
for h, t, r in path:
# try:
h_name = entity_vocab[h]
t_name = entity_vocab[t]
r_name = relation_vocab[r % num_relation]
# except IndexError as e:
# logger.error(f"entity_vocab = {entity_vocab}")
# logger.error(f"entity_vocab.size = {len(entity_vocab)}")
# logger.error(f"h = {h}")
# logger.error(f"t = {t}")
# logger.error(f"r = {r}")
# break
if r >= num_relation:
r_name += "^(-1)"
triplets.append("<%s, %s, %s>" % (h_name, r_name, t_name))
logger.warning("weight: %g\n\t%s" % (weight, " ->\n\t".join(triplets)))


if __name__ == "__main__":
args, vars = util.parse_args()
cfg = util.load_config(args.config, context=vars)
working_dir = util.create_working_directory(cfg)
print(working_dir)
vocab_file = os.path.join(os.path.dirname(__file__), cfg.dataset.path, "entity_names.txt")
vocab_file = os.path.abspath(vocab_file)
torch.manual_seed(args.seed + comm.get_rank())

logger = util.get_root_logger()
logger.warning("Working directory: %s" % working_dir)
logger.warning("Config file: %s" % args.config)
logger.warning(pprint.pformat(cfg))

_dataset = core.Configurable.load_config_dict(cfg.dataset)
train_set, valid_set, test_set = _dataset.split()
solver = build_solver(cfg)
logger.warning(f"****** solver = {solver}")

if "checkpoint" in cfg:
solver_load(cfg.checkpoint)

entity_vocab, relation_vocab = load_vocab(_dataset)

for i in range(len(solver.test_set)):
visualize_path(solver, solver.test_set[i], entity_vocab, relation_vocab)