Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion build/plugins/graph_processing/cg-gnn-cuda.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ RUN pip install scipy==1.15.1
RUN pip install dgl -f https://data.dgl.ai/wheels/cu118/repo.html
RUN pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html
ENV DGLBACKEND=pytorch
RUN pip install cg-gnn==0.3.2

# Make the files you need in this directory available everywhere in the container
ADD . /app
Expand Down
1 change: 0 additions & 1 deletion build/plugins/graph_processing/cg-gnn.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ RUN pip install torch --index-url https://download.pytorch.org/whl/cpu
RUN pip install dgl -f https://data.dgl.ai/wheels/repo.html
RUN pip install dglgo -f https://data.dgl.ai/wheels-test/repo.html
ENV DGLBACKEND=pytorch
RUN pip install cg-gnn==0.3.2

# Make the files you need in this directory available everywhere in the container
ADD . /app
Expand Down
5 changes: 5 additions & 0 deletions plugin/graph_processing/cg-gnn/cggnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Train and explain a graph neural network on a dataset of cell graphs."""

from cggnn.train import train, infer, infer_with_model
from cggnn.importance import calculate_importance, unify_importance_across, save_importances
from cggnn.separability import calculate_separability
102 changes: 102 additions & 0 deletions plugin/graph_processing/cg-gnn/cggnn/importance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
"""
Calculate importance scores per node in an ROI.

As used in:
"Quantifying Explainers of Graph Neural Networks in Computational Pathology",
Jaume et al, CVPR, 2021.
"""

from typing import List, Optional, Dict, Tuple, DefaultDict

from tqdm import tqdm
from numpy import average
from torch import FloatTensor
from torch.cuda import is_available
from dgl import DGLGraph
from pandas import Series

from cggnn.util import CellGraphModel, set_seeds
from cggnn.util.constants import IMPORTANCES, INDICES
from cggnn.util.interpretability import (BaseExplainer, GraphLRPExplainer, GraphGradCAMExplainer,
GraphGradCAMPPExplainer, GraphPruningExplainer)
from cggnn.train import infer_with_model

IS_CUDA = is_available()
DEVICE = 'cuda:0' if IS_CUDA else 'cpu'


def calculate_importance(cell_graphs: List[DGLGraph],
                         model: CellGraphModel,
                         explainer_model: str,
                         random_seed: Optional[int] = None
                         ) -> List[DGLGraph]:
    """Calculate the importance for all cells in every graph.

    The explainer named by `explainer_model` is run over each graph and the
    per-node scores are stored in-place under `graph.ndata[IMPORTANCES]`.
    Raises ValueError if `explainer_model` names no known explainer.
    """
    # Map every accepted alias (case- and whitespace-insensitive) to its
    # explainer class; unknown names fall through to the error below.
    explainer_by_alias = {
        'lrp': GraphLRPExplainer,
        'graphlrpexplainer': GraphLRPExplainer,
        'cam': GraphGradCAMExplainer,
        'gradcam': GraphGradCAMExplainer,
        'graphgradcamexplainer': GraphGradCAMExplainer,
        'pp': GraphGradCAMPPExplainer,
        'campp': GraphGradCAMPPExplainer,
        'gradcampp': GraphGradCAMPPExplainer,
        'graphgradcamppexplainer': GraphGradCAMPPExplainer,
        'pruning': GraphPruningExplainer,
        'gnn': GraphPruningExplainer,
        'graphpruningexplainer': GraphPruningExplainer,
    }
    requested = explainer_model.lower().strip()
    if requested not in explainer_by_alias:
        raise ValueError("explainer_model not recognized.")
    explainer: BaseExplainer = explainer_by_alias[requested](model=model)

    if random_seed is not None:
        set_seeds(random_seed)

    # Set model to train so it'll let us do backpropagation.
    # This shouldn't be necessary since we don't want the model to change at
    # all while running the explainer, and it isn't necessary with the original
    # histocartography code; with this python/torch combination, however,
    # staying in eval mode raises a can't-backprop-in-eval error in torch
    # because calculating the weights requires backprop-ing to get the
    # backward_hook.
    # TODO: Fix this.
    model = model.train()

    # Annotate every graph with its per-node importance scores.
    for graph in tqdm(cell_graphs):
        scores, _ = explainer.process(graph.to(DEVICE))
        graph.ndata[IMPORTANCES] = FloatTensor(scores)

    return cell_graphs


def unify_importance_across(graphs_by_specimen: List[List[DGLGraph]],
                            model: CellGraphModel,
                            random_seed: Optional[int] = None
                            ) -> Dict[int, float]:
    """Merge importance values for all cells in all ROIs in all specimens.

    Raises RuntimeError if a histological structure ID shows up in more than
    one specimen, since each structure must belong to exactly one specimen.
    """
    if random_seed is not None:
        set_seeds(random_seed)
    combined: Dict[int, float] = {}
    for specimen_graphs in graphs_by_specimen:
        per_specimen = _unify_importance(specimen_graphs, model)
        if combined.keys() & per_specimen.keys():
            raise RuntimeError(
                'The same histological structure ID appears in multiple specimens.')
        combined.update(per_specimen)
    return combined


def _unify_importance(graphs: List[DGLGraph], model: CellGraphModel) -> Dict[int, float]:
    """Merge the importance values for each cell in a specimen.

    Each observation of a cell is weighted by the model's confidence (maximum
    predicted class probability) in the ROI that observation came from.
    """
    probs = infer_with_model(model, graphs, return_probability=True)
    observations: Dict[int, List[Tuple[float, float]]] = DefaultDict(list)
    for graph_index, graph in enumerate(graphs):
        # Confidence is constant per graph, so look it up once per ROI.
        confidence = max(probs[graph_index])
        hs_ids = graph.ndata[INDICES]
        importances = graph.ndata[IMPORTANCES]
        for node in range(graph.num_nodes()):
            observations[hs_ids[node].item()].append((importances[node], confidence))
    unified: Dict[int, float] = {}
    for hs_id, importance_confidences in observations.items():
        unified[hs_id] = average([pair[0] for pair in importance_confidences],
                                 weights=[pair[1] for pair in importance_confidences])
    return unified


def save_importances(hs_id_to_importance: Dict[int, float], out_directory: str) -> None:
    """Save importance scores per histological structure to CSV.

    Parameters
    ----------
    hs_id_to_importance : Dict[int, float]
        Importance score keyed by histological structure ID. Rows are written
        sorted by ID under the column name 'importance'.
    out_directory : str
        Destination for the CSV. Despite the name, callers pass a full file
        path; if an existing directory is given instead, the file is written
        to 'importances.csv' inside it.
    """
    # Accept either a directory or an explicit file path (backward compatible:
    # a file path is used verbatim, as before).
    out_path = join(out_directory, 'importances.csv') if isdir(out_directory) else out_directory
    series = Series(hs_id_to_importance, name='importance').sort_index()
    series.to_csv(out_path)
80 changes: 80 additions & 0 deletions plugin/graph_processing/cg-gnn/cggnn/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Functions that run key pipelines of the cggnn model."""

from typing import Dict, List, DefaultDict, Tuple, Optional
from os.path import join

from pandas import DataFrame
from dgl import DGLGraph # type: ignore

from cggnn import train, calculate_importance, unify_importance_across, save_importances, \
calculate_separability
from cggnn.util import GraphData, load_cell_graphs, save_cell_graphs, instantiate_model, \
load_label_to_result, CellGraphModel


def train_and_evaluate(cg_directory: str,
                       in_ram: bool = False,
                       batch_size: int = 1,
                       epochs: int = 10,
                       learning_rate: float = 1e-3,
                       k_folds: int = 0,
                       explainer: Optional[str] = None,
                       merge_rois: bool = False,
                       random_seed: Optional[int] = None,
                       ) -> Tuple[CellGraphModel, List[GraphData], Optional[Dict[int, float]]]:
    """Train a CG-GNN on pre-split sets of cell graphs and explain it if requested.

    Returns the trained model, the (possibly importance-annotated) graph data,
    and, when `explainer` and `merge_rois` are both given, importance scores
    merged per histological structure across ROIs (otherwise None).
    """
    graphs_data = load_cell_graphs(cg_directory)[0]
    model = train(
        graphs_data,
        cg_directory,
        in_ram=in_ram,
        epochs=epochs,
        learning_rate=learning_rate,
        batch_size=batch_size,
        k_folds=k_folds,
        random_seed=random_seed,
    )

    importance_by_hs_id: Optional[Dict[int, float]] = None
    if explainer is not None:
        # Annotate every graph with per-cell importance and persist the result.
        annotated = calculate_importance(
            [d.graph for d in graphs_data], model, explainer, random_seed=random_seed)
        graphs_data = [data._replace(graph=graph)
                       for data, graph in zip(graphs_data, annotated)]
        save_cell_graphs(graphs_data, cg_directory)
        if merge_rois:
            # Group graphs by their specimen so importance can be merged
            # across all the ROIs of each specimen.
            by_specimen: Dict[str, List[DGLGraph]] = DefaultDict(list)
            for data in graphs_data:
                by_specimen[data.specimen].append(data.graph)
            importance_by_hs_id = unify_importance_across(
                list(by_specimen.values()), model, random_seed=random_seed)
            save_importances(importance_by_hs_id, join(cg_directory, 'importances.csv'))
    return model, graphs_data, importance_by_hs_id


def find_separability(cg_path: str,
                      model_checkpoint_path: str,
                      label_to_result_path: Optional[str] = None,
                      prune_misclassified: bool = False,
                      output_directory: Optional[str] = None,
                      random_seed: Optional[int] = None,
                      ) -> Tuple[DataFrame,
                                 DataFrame,
                                 Dict[Tuple[int, int] | Tuple[str, str], DataFrame]]:
    """Calculate separability scores for a cell graph dataset.

    Prints and returns the per-concept scores, the aggregated scores, and the
    per-class-pair k-distance DataFrames.
    """
    graphs_data, feature_names = load_cell_graphs(cg_path)
    model = instantiate_model(graphs_data, model_checkpoint_path=model_checkpoint_path)
    label_to_result = load_label_to_result(label_to_result_path) \
        if label_to_result_path else None
    df_concept, df_aggregated, dfs_k_dist = calculate_separability(
        graphs_data,
        model,
        feature_names,
        label_to_result=label_to_result,
        prune_misclassified=prune_misclassified,
        out_directory=output_directory,
        random_seed=random_seed)
    print(df_concept)
    print(df_aggregated)
    for class_pair, df_k in dfs_k_dist.items():
        print(class_pair)
        print(df_k)
    return df_concept, df_aggregated, dfs_k_dist
66 changes: 66 additions & 0 deletions plugin/graph_processing/cg-gnn/cggnn/scripts/separability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Explain a cell graph (CG) prediction using a pretrained CG-GNN and a graph explainer."""

from argparse import ArgumentParser

from cggnn.run import find_separability


def parse_arguments():
    """Process command line arguments.

    Returns
    -------
    argparse.Namespace
        Parsed arguments for the separability script.
    """
    parser = ArgumentParser(
        description='Explain a cell graph prediction using a model and a graph explainer.',
    )
    parser.add_argument(
        '--cg_path',
        type=str,
        help='Directory with the cell graphs, metadata, and feature names.',
        required=True
    )
    # This argument is not consumed anywhere: find_separability reads feature
    # names from cg_path via load_cell_graphs. Kept as optional (instead of
    # the previous required=True) so existing invocations that still pass it
    # keep working while new ones may omit it.
    parser.add_argument(
        '--feature_names_path',
        type=str,
        help='Path to the list of feature names.',
        required=False
    )
    parser.add_argument(
        '--model_checkpoint_path',
        type=str,
        help='Path to the model checkpoint.',
        required=True
    )
    parser.add_argument(
        '--label_to_result_path',
        type=str,
        help='Where to find the data mapping label ints to their string results.',
        required=False
    )
    parser.add_argument(
        '--prune_misclassified',
        help='Remove entries for misclassified cell graphs when calculating separability scores.',
        action='store_true'
    )
    parser.add_argument(
        '--output_directory',
        type=str,
        help='Where to save the output reporting.',
        default=None,
        required=False
    )
    parser.add_argument(
        '--random_seed',
        type=int,
        help='Random seed to use for reproducibility.',
        default=None,
        required=False
    )
    return parser.parse_args()


if __name__ == "__main__":
    # Run separability scoring with keyword arguments for readability.
    args = parse_arguments()
    find_separability(args.cg_path,
                      args.model_checkpoint_path,
                      label_to_result_path=args.label_to_result_path,
                      prune_misclassified=args.prune_misclassified,
                      output_directory=args.output_directory,
                      random_seed=args.random_seed)
89 changes: 89 additions & 0 deletions plugin/graph_processing/cg-gnn/cggnn/scripts/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""Train a CG-GNN on pre-split sets of cell graphs."""

from argparse import ArgumentParser

from cggnn.run import train_and_evaluate


def parse_arguments():
    """Parse command line arguments for CG-GNN training.

    Returns
    -------
    argparse.Namespace
        Parsed arguments for the training script.
    """
    arg_parser = ArgumentParser(
        description='Train a GNN on cell graphs.',
    )
    arg_parser.add_argument(
        '--cg_directory', type=str, required=True,
        help='Directory with the cell graphs, metadata, and feature names. '
             'Model results and any other output will be saved to this directory.')
    arg_parser.add_argument(
        '--in_ram', action='store_true',
        help='If the data should be stored in RAM.')
    arg_parser.add_argument(
        '-b', '--batch_size', type=int, default=1, required=False,
        help='Batch size to use during training.')
    arg_parser.add_argument(
        '--epochs', type=int, default=10, required=False,
        help='Number of training epochs to do.')
    arg_parser.add_argument(
        '-l', '--learning_rate', type=float, default=1e-3, required=False,
        help='Learning rate to use during training.')
    arg_parser.add_argument(
        '-k', '--k_folds', type=int, default=0, required=False,
        help='Folds to use in k-fold cross validation. 0 means don\'t use k-fold cross validation '
             'unless no validation dataset is provided, in which case k defaults to 3.')
    arg_parser.add_argument(
        '--explainer', type=str, default=None, required=False,
        help='Which explainer type to use. If provided, importance scores will be calculated.')
    arg_parser.add_argument(
        '--merge_rois', action='store_true',
        help='Save a CSV of importance scores merged across ROIs from a single specimen.')
    arg_parser.add_argument(
        '--random_seed', type=int, default=None, required=False,
        help='Random seed to use for reproducibility.')
    return arg_parser.parse_args()


if __name__ == "__main__":
    # Run training with keyword arguments so each flag maps visibly to a
    # parameter.
    args = parse_arguments()
    train_and_evaluate(args.cg_directory,
                       in_ram=args.in_ram,
                       batch_size=args.batch_size,
                       epochs=args.epochs,
                       learning_rate=args.learning_rate,
                       k_folds=args.k_folds,
                       explainer=args.explainer,
                       merge_rois=args.merge_rois,
                       random_seed=args.random_seed)
Loading