diff --git a/config/genefunction/gfun_vis.yaml b/config/genefunction/gfun_vis.yaml index 90dc844..a80634a 100644 --- a/config/genefunction/gfun_vis.yaml +++ b/config/genefunction/gfun_vis.yaml @@ -1,8 +1,9 @@ -output_dir: experiments/ +output_dir: /pasteur/appa/homes/claudy/2026-BPN-paper-data/gene-function-prediction/ + dataset: class: biomedical - path: data/PC_KEGG_0928 + path: /pasteur/appa/homes/claudy/BioPathNet/data/PC_KEGG_0602 include_factgraph: yes task: diff --git a/config/mock/mockdata_inductive.yaml b/config/mock/mockdata_inductive.yaml index d2ee30e..b3b0b0e 100644 --- a/config/mock/mockdata_inductive.yaml +++ b/config/mock/mockdata_inductive.yaml @@ -1,8 +1,8 @@ -output_dir: experiments/ +output_dir: /Users/claudy/work/projects/BioPathNet/experiments/ dataset: class: BiomedicalInductive - path: data/mock_inductive + path: /Users/claudy/work/projects/BioPathNet/data/mock_inductive task: class: KnowledgeGraphCompletionBiomedInductive diff --git a/config/mock/mockdata_inductive_pred.yaml b/config/mock/mockdata_inductive_pred.yaml index 383598e..d614fe9 100644 --- a/config/mock/mockdata_inductive_pred.yaml +++ b/config/mock/mockdata_inductive_pred.yaml @@ -1,8 +1,8 @@ -output_dir: experiments/ +output_dir: /Users/claudy/work/projects/BioPathNet/experiments/ dataset: class: BiomedicalInductive - path: data/mock_inductive + path: /Users/claudy/work/projects/BioPathNet/data/mock_inductive task: class: KnowledgeGraphCompletionBiomedInductive @@ -37,4 +37,4 @@ train: num_epoch: 5 metric: mrr -checkpoint: {{checkpoint}} \ No newline at end of file +checkpoint: {{checkpoint}} diff --git a/config/mock/mockdata_inductive_vis.yaml b/config/mock/mockdata_inductive_vis.yaml new file mode 100644 index 0000000..e9c5d87 --- /dev/null +++ b/config/mock/mockdata_inductive_vis.yaml @@ -0,0 +1,43 @@ +output_dir: /Users/claudy/work/projects/BioPathNet/experiments/ + +dataset: + class: BiomedicalInductive + path: /Users/claudy/work/projects/BioPathNet/data/mock_inductive + 
+task: + class: KnowledgeGraphCompletionBiomedInductive + model: + class: NBFNet + input_dim: 32 + hidden_dims: [32, 32, 32, 32, 32, 32] + message_func: distmult + aggregate_func: pna + short_cut: yes + layer_norm: yes + dependent: yes + symmetric: no + criterion: bce + num_negative: 32 + strict_negative: yes + adversarial_temperature: 0.5 + sample_weight: no + heterogeneous_negative: yes + heterogeneous_evaluation: yes + full_batch_eval: no + remove_pos: no + +optimizer: + class: Adam + lr: 5.0e-3 + +engine: + gpus: {{ gpus }} + batch_size: 4 + + +train: + num_epoch: 5 + +metric: mrr + 
+checkpoint: {{checkpoint}} diff --git a/script/eval_and_predict_inductive.py b/script/eval_and_predict_inductive.py index 0513949..487f37a 100644 --- a/script/eval_and_predict_inductive.py +++ b/script/eval_and_predict_inductive.py @@ -167,9 +167,10 @@ def load_dataset_and_solver(cfg, test_file): if __name__ == "__main__": args, vars = util.parse_args() + logger = util.get_root_logger() cfg = util.load_config(args.config, context=vars) working_dir = util.create_working_directory(cfg) - print(working_dir) + logger.warning(f"working directory = {working_dir}") # get entity names if 'entity_files' in cfg.dataset: vfile = cfg.dataset.entity_files[1] @@ -180,8 +181,6 @@ def load_dataset_and_solver(cfg, test_file): torch.manual_seed(args.seed + comm.get_rank()) - logger = util.get_root_logger() - logger.warning("Working directory: %s" % working_dir) if comm.get_rank() == 0: logger.warning("Config file: %s" % args.config) #logger.warning(pprint.pformat(cfg)) @@ -192,6 +191,7 @@ def load_dataset_and_solver(cfg, test_file): # Prediction phase cfg_pred, solver, _dataset, entity_vocab, relation_vocab = load_dataset_and_solver(cfg, 'test_pred.txt') + logger.warning(f"******* solver = {solver}") test(cfg_pred, solver) logger.warning(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>") @@ -206,4 +206,4 @@ def load_dataset_and_solver(cfg, test_file): logger.warning(f"Saving to file {os.path.join(working_dir, 'predictions.csv')}") df.to_csv(os.path.join(working_dir, "predictions.csv"), index=False, sep="\t") logger.warning("Done") - logger.warning("------------------------------") \ No newline at end of file + logger.warning("------------------------------") diff --git a/script/visualize_inductive.py b/script/visualize_inductive.py new file mode 100644 index 0000000..db382d6 --- /dev/null +++ b/script/visualize_inductive.py @@ -0,0 +1,157 @@ +import os +import sys +import pprint + +import torch + +from torchdrug import core +from torchdrug.utils import comm + 
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from biopathnet import dataset, layer, model, task, util


# Entries dropped from a checkpoint's model state before it is loaded; the
# original code notes that loading the stored graphs back causes problems.
_GRAPH_STATE_KEYS = (
    "fact_graph",
    "fact_graph_supervision",
    "graph",
    "train_graph",
    "valid_graph",
    "test_graph",
    "full_valid_graph",
    "full_test_graph",
)


def solver_load(checkpoint, load_optimizer=True):
    """Load a checkpoint into the module-level ``solver``.

    Graph entries listed in ``_GRAPH_STATE_KEYS`` are removed from the saved
    model state, and the remainder is loaded with ``strict=False``.

    Parameters:
        checkpoint (str): path to the checkpoint file; ``~`` is expanded.
        load_optimizer (bool): when true, also restore the optimizer state
            and move its tensors to ``solver.device``.
    """
    if comm.get_rank() == 0:
        logger.warning("Load checkpoint from %s" % checkpoint)
    checkpoint = os.path.expanduser(checkpoint)
    state = torch.load(checkpoint, map_location=solver.device)

    model_state = state["model"]
    for key in _GRAPH_STATE_KEYS:
        model_state.pop(key, None)
    solver.model.load_state_dict(model_state, strict=False)

    if load_optimizer:
        solver.optimizer.load_state_dict(state["optimizer"])
        # Use a distinct name here so the checkpoint dict is not shadowed.
        for param_state in solver.optimizer.state.values():
            for k, v in param_state.items():
                if isinstance(v, torch.Tensor):
                    param_state[k] = v.to(solver.device)

    comm.synchronize()


def build_solver(cfg):
    """Build a ``core.Engine`` from the task/optimizer/scheduler config.

    Reads the module-level ``_dataset``, ``train_set``, ``valid_set`` and
    ``test_set`` created by the ``__main__`` script below.

    Parameters:
        cfg: loaded configuration with ``task``, ``optimizer``, ``engine``
            and optionally ``scheduler`` sections.

    Returns:
        The constructed ``core.Engine``.
    """
    cfg.task.model.num_relation = _dataset.num_relation
    _task = core.Configurable.load_config_dict(cfg.task)
    cfg.optimizer.params = _task.parameters()
    optimizer = core.Configurable.load_config_dict(cfg.optimizer)
    if "scheduler" in cfg:
        cfg.scheduler.optimizer = optimizer
        scheduler = core.Configurable.load_config_dict(cfg.scheduler)
    else:
        scheduler = None
    return core.Engine(_task, train_set, valid_set, test_set, optimizer, scheduler, **cfg.engine)


def load_vocab(dataset):
    """Return human-readable entity and relation names for ``dataset``.

    Entity names come from the tab-separated file at the module-level
    ``vocab_file`` (one ``<token>\\t<name>`` pair per line).

    Parameters:
        dataset: object exposing ``train_entity_vocab``,
            ``test_entity_vocab`` and ``relation_vocab`` sequences.

    Returns:
        tuple[list[str], list[str]]: ``(entity_vocab, relation_vocab)``.

    Raises:
        KeyError: if a dataset entity token is missing from ``vocab_file``.
    """
    logger.debug(f"###### vocab_file = {vocab_file}")

    entity_mapping = {}
    with open(vocab_file, "r") as fin:
        for line in fin:
            line = line.strip()
            if not line:
                # Tolerate blank / trailing empty lines in the names file.
                continue
            k, v = line.split("\t")
            entity_mapping[k] = v

    # De-duplicate while preserving order. The previous ``list(set(...))``
    # produced an arbitrary, hash-seed-dependent order, so the position of a
    # name in ``entity_vocab`` was unrelated to any entity id.
    # NOTE(review): this assumes ids are numbered as train vocab followed by
    # the unseen test vocab -- confirm against the dataset class.
    dataset_entity_vocab = list(
        dict.fromkeys(dataset.train_entity_vocab + dataset.test_entity_vocab)
    )
    entity_vocab = [entity_mapping[t] for t in dataset_entity_vocab]
    relation_vocab = ["%s (%d)" % (t[t.rfind("/") + 1:].replace("_", " "), i)
                      for i, t in enumerate(dataset.relation_vocab)]

    return entity_vocab, relation_vocab


def visualize_path(solver, triplet, entity_vocab, relation_vocab):
    """Log the ranking and the weighted paths for one test triplet.

    Both the triplet ``(h, t, r)`` and its inverse
    ``(t, h, r + num_relation)`` are reported.

    Parameters:
        solver: engine whose ``model`` provides ``predict_and_target`` and
            ``visualize``.
        triplet: tensor holding ``(h, t, r)`` indices.
        entity_vocab (list[str]): entity names indexed by entity id.
        relation_vocab (list[str]): relation names indexed by relation id.
    """
    num_relation = len(relation_vocab)
    h, t, r = triplet.tolist()
    triplet = torch.as_tensor([[h, t, r]], device=solver.device)
    inverse = torch.as_tensor([[t, h, r + num_relation]], device=solver.device)
    solver.model.split = "test"
    solver.model.eval()
    pred, (mask, target) = solver.model.predict_and_target(triplet)
    pos_pred = pred.gather(-1, target.unsqueeze(-1))
    rankings = torch.sum((pos_pred <= pred) & mask, dim=-1) + 1
    rankings = rankings.squeeze(0)

    logger.warning("")
    samples = (triplet, inverse)
    for sample, ranking in zip(samples, rankings):
        h, t, r = sample.squeeze(0).tolist()
        h_name = entity_vocab[h]
        t_name = entity_vocab[t]
        r_name = relation_vocab[r % num_relation]
        if r >= num_relation:
            r_name += "^(-1)"
        logger.warning(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        logger.warning("rank(%s | %s, %s) = %g" % (t_name, h_name, r_name, ranking))

        paths, weights = solver.model.visualize(sample)
        for path, weight in zip(paths, weights):
            triplets = []
            for edge_h, edge_t, edge_r in path:
                edge_h_name = entity_vocab[edge_h]
                edge_t_name = entity_vocab[edge_t]
                edge_r_name = relation_vocab[edge_r % num_relation]
                # Relation ids >= num_relation are rendered as inverses.
                if edge_r >= num_relation:
                    edge_r_name += "^(-1)"
                triplets.append("<%s, %s, %s>" % (edge_h_name, edge_r_name, edge_t_name))
            logger.warning("weight: %g\n\t%s" % (weight, " ->\n\t".join(triplets)))


if __name__ == "__main__":
    # The names assigned here (logger, vocab_file, _dataset, train_set,
    # valid_set, test_set, solver) are module globals read by the functions
    # above, so this script body must stay at module level.
    args, context_vars = util.parse_args()
    cfg = util.load_config(args.config, context=context_vars)
    working_dir = util.create_working_directory(cfg)
    print(working_dir)
    vocab_file = os.path.join(os.path.dirname(__file__), cfg.dataset.path, "entity_names.txt")
    vocab_file = os.path.abspath(vocab_file)
    torch.manual_seed(args.seed + comm.get_rank())

    logger = util.get_root_logger()
    logger.warning("Working directory: %s" % working_dir)
    logger.warning("Config file: %s" % args.config)
    logger.warning(pprint.pformat(cfg))

    _dataset = core.Configurable.load_config_dict(cfg.dataset)
    train_set, valid_set, test_set = _dataset.split()
    solver = build_solver(cfg)
    logger.warning(f"****** solver = {solver}")

    if "checkpoint" in cfg:
        solver_load(cfg.checkpoint)

    entity_vocab, relation_vocab = load_vocab(_dataset)

    for i in range(len(solver.test_set)):
        visualize_path(solver, solver.test_set[i], entity_vocab, relation_vocab)