Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 0 additions & 13 deletions .github/workflows/wip.yml

This file was deleted.

18 changes: 10 additions & 8 deletions frame/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def main():

# * Get checkpoint and prepare Explainer
model = models.select_model(model_name, tune)
model.load_state_dict(torch.load(path_checkpoint,
map_location=device,
weights_only=True))
model.load_state_dict(torch.load(path_checkpoint))
model.eval()

if task == "classification":
Expand All @@ -79,7 +77,7 @@ def main():
return_type="raw"))

for data in tqdm(dataloader, ncols=120, desc="Explaining"):
data = data.to(device)
data.to(device)

# * Make predictions
model_out = model(x=data.x.float(),
Expand All @@ -89,19 +87,23 @@ def main():

# * Read prediction values
if task == "classification":
detach = torch.sigmoid(model_out).cpu().detach()
pred = list(torch.ravel(detach).cpu().detach().numpy())
logit = model_out.cpu().detach()
logit_list = list(torch.ravel(logit).numpy())
detach = torch.sigmoid(logit)
pred = list(torch.ravel(detach).numpy())
pred_lbl = (detach >= 0.5).int()
else:
detach = model_out.cpu().detach()
pred = list(torch.ravel(detach).cpu().detach().numpy())
pred = list(torch.ravel(detach).numpy())
logit_list = pred
pred_lbl = [None] * detach.shape[0]

# * Explain
explanation = explainer(data.x.float(), data.edge_index,
edge_attr=data.edge_attr.float(),
batch=data.batch)

mol_exp = explain.MolExplain(explanation, pred, pred_lbl, loader, out)
mol_exp = explain.MolExplain(explanation, logit_list, pred, pred_lbl,
loader, out)
mol_exp.retrieve_info(data)
mol_exp.plot_explanations(data)
46 changes: 34 additions & 12 deletions frame/source/explain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@


class MolExplain:
def __init__(self, explanation, pred, pred_lbl, loader, out_dir, k=10):
def __init__(self, explanation, logit, pred, pred_lbl, loader,
out_dir, k=10):
self.mask = explanation.node_mask.detach().cpu()
self.batch = explanation.batch.detach().cpu()
self.pred = pred
self.pred_lbl = pred_lbl
self.logit = logit
self.loader = loader
self.out = out_dir
self.k = k
Expand Down Expand Up @@ -103,17 +105,19 @@ def _info_fragment(self, graphs):
def plot_explanations(self, graphs):
batch_num = self.batch.unique()
masks = [self.mask[self.batch == b] for b in batch_num]
pred_label = ""

for idx, node_mask in enumerate(masks):
data = graphs[idx]
name = data.idx
pred = self.pred[idx]
if self.pred_lbl[idx] is not None:
pred_label = self.pred_lbl[idx].numpy()[0]
logit = self.logit[idx]
pred_label = self.pred_lbl[idx]
if pred_label is not None:
pred_label = pred_label.numpy()[0]

if self.loader == "default":
self._explain_atom(data, node_mask, pred, pred_label, name)
self._explain_atom(data, node_mask, pred, logit,
pred_label, name)

else:
# * Feature-level bar plot
Expand All @@ -124,13 +128,22 @@ def plot_explanations(self, graphs):
# self._frag_visualization(node_mask, fragments, name)

# * Molecule-level visualization
self._explain_frag(data, node_mask, pred, pred_label, name)
self._explain_frag(data, node_mask, pred, logit,
pred_label, name)

def _explain_atom(self, data, node_mask, pred, pred_label, name):
def _rescale_mask(sel, mask_atom, logit):
"""Rescale per-atom attributions so they sum to the raw logit."""
current_sum = mask_atom.sum()
if current_sum == 0:
return mask_atom
return mask_atom * (logit / current_sum)

def _explain_atom(self, data, node_mask, pred, logit, pred_label, name):
smiles = data.smiles
mol = Chem.MolFromSmiles(smiles)

mask_atom = torch.sum(node_mask, dim=1).numpy()
mask_atom = self._rescale_mask(mask_atom, logit)
mask_atom = np.round(mask_atom, 3)

min_val = mask_atom.min()
Expand Down Expand Up @@ -165,8 +178,12 @@ def _explain_atom(self, data, node_mask, pred, pred_label, name):
atom.SetProp("atomNote", str(mask_atom[idx]))

legend = (f"Graph ID: {name}\n{smiles}\n"
f"Prediction: {pred:.3f} ({pred_label})"
f"\tTrue: {float(data.y)}")
f"Prediction: {pred:.3f}\tLogits: {logit:.3f}\t|\t"
f"Class: {pred_label}\tTrue: {int(data.y)}")

if pred_label is None: # Regression
legend = (f"Graph ID: {name}\n{smiles}\n"
f"Prediction: {pred:.3f}\tTrue: {float(data.y):.3f}")

drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800)
opts = drawer.drawOptions()
Expand All @@ -180,12 +197,13 @@ def _explain_atom(self, data, node_mask, pred, pred_label, name):
with open(self.out / f"{data.idx}.svg", "w") as f:
f.write(drawer.GetDrawingText())

def _explain_frag(self, data, node_mask, pred, pred_label, name):
def _explain_frag(self, data, node_mask, pred, logit, pred_label, name):
smiles = data.smiles
mol = Chem.MolFromSmiles(smiles)

atom_map = dict(zip(data.atom_map[0], data.atom_map[1]))
mask_atom = torch.sum(node_mask, dim=1).numpy()
mask_atom = self._rescale_mask(mask_atom, logit)
mask_atom = np.round(mask_atom, 3)

min_val = mask_atom.min()
Expand Down Expand Up @@ -221,8 +239,12 @@ def _explain_frag(self, data, node_mask, pred, pred_label, name):
atom.SetProp("atomNote", str(frag_val))

legend = (f"Graph ID: {name}\n{smiles}\n"
f"Prediction: {pred:.3f} ({pred_label})"
f"\tTrue: {float(data.y)}")
f"Prediction: {pred:.3f}\tLogits: {logit:.3f}\t|\t"
f"Class: {pred_label}\tTrue: {int(data.y)}")

if pred_label is None: # Regression
legend = (f"Graph ID: {name}\n{smiles}\n"
f"Prediction: {pred:.3f}\tTrue: {float(data.y):.3f}")

drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800)
opts = drawer.drawOptions()
Expand Down
Loading