Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions placeholder/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def main():
# * Initialize
name = config["name"]
if name.lower() == "none":
name = str(uuid.uuid4()).split["-"][0]
name = str(uuid.uuid4()).split("-")[0]
config["name"] = name

cwd = Path(os.getcwd())
Expand Down Expand Up @@ -71,11 +71,6 @@ def main():
task_level="graph",
return_type="raw"))

labels = ",".join(explain.V1)
header = "id,smiles,real,pred_label,pred,fragment,"
with open(out / "predictions.csv", "w") as f:
f.write(f"{header+labels}\n")

for data in tqdm(dataloader, ncols=120, desc="Explaining"):
data.to(device)

Expand Down
341 changes: 201 additions & 140 deletions placeholder/source/explain/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def __init__(self, explanation, pred, pred_lbl, loader, out_dir, k=10):
self.batch = explanation.batch.detach().cpu()
self.pred = pred
self.pred_lbl = pred_lbl
self.loader = loader
self.out = out_dir
self.k = k
self.cut = k // 2
Expand All @@ -51,6 +52,39 @@ def __init__(self, explanation, pred, pred_lbl, loader, out_dir, k=10):
self.labels = np.array(V2)

def retrieve_info(self, graphs):
if self.loader == "default":
header = "id,smiles,real,pred_label,pred,fragment"
with open(self.out / "predictions.csv", "w") as f:
f.write(f"{header}\n")

self._info_atom(graphs)

else:
labels = ",".join(V1)
header = "id,smiles,real,pred_label,pred"
with open(self.out / "predictions.csv", "w") as f:
f.write(f"{header+labels}\n")

self._info_fragment(graphs)

def _info_atom(self, graphs):
batch_num = self.batch.unique()
masks = [self.mask[self.batch == b] for b in batch_num]

for idx in range(len(masks)):
data = graphs[idx]
real_label = int(data.y.cpu().numpy()[0])
pred = self.pred[idx]
pred_label = self.pred_lbl[idx].numpy()[0]

text = (f"{data.idx},{data.smiles},{real_label},"
f"{pred_label},{pred:.3f}\n")

# * Export prediction
with open(self.out / "predictions.csv", "a") as f:
f.writelines(text)

def _info_fragment(self, graphs):
batch_num = self.batch.unique()
masks = [self.mask[self.batch == b] for b in batch_num]

Expand Down Expand Up @@ -81,106 +115,159 @@ def plot_explanations(self, graphs):

for idx, node_mask in enumerate(masks):
data = graphs[idx]
smiles = data.smiles
fragments = data.frag
name = data.idx
pred = self.pred[idx]
pred_label = self.pred_lbl[idx].numpy()[0]
mol = Chem.MolFromSmiles(smiles)

# * Feature-level bar plot
mask_feat = torch.sum(node_mask, dim=0).numpy()
feats = self._get_top(mask_feat)

fig, ax = plt.subplots(figsize=(10, 6))
all_lbl = np.append(feats[1]["labels"], feats[0]["labels"])
all_val = np.append(feats[1]["contrib"], feats[0]["contrib"])
colors = (["SteelBlue"] * len(feats[1]["labels"]) +
["DarkOrange"] * len(feats[0]["labels"]))

ax.barh(all_lbl, all_val, color=colors)
ax.set_title(f"Top {self.k} Features - {data.idx}")
ax.set_xlabel("Contribution")
ax.invert_yaxis()

plt.xlim(mask_feat.min() * 1.15, mask_feat.max() * 1.15)
for i, v in enumerate(all_val):
x_off = 0.02 if v > 0 else -0.3
ax.text(v + x_off, i, str(v), va="center", fontsize=8)

plt.tight_layout()
out_feat = self.out / f"{data.idx}_feat.svg"
fig.savefig(out_feat, format="svg")
plt.close(fig)

# * Molecule-level visualization
atom_map = dict(zip(data.atom_map[0], data.atom_map[1]))
mask_atom = torch.sum(node_mask, dim=1).numpy()
mask_atom = np.round(mask_atom, 3)

min_val = mask_atom.min()
max_val = mask_atom.max()
if min_val > 0:
max_val *= 1.3
cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=0,
vmax=max_val),
cmap=mpl.cm.Blues)
elif max_val < 0:
min_val *= 1.3
cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=min_val,
vmax=0),
cmap=mpl.cm.Oranges_r)

if self.loader == "default":
self._explain_atom(data, node_mask, pred, pred_label, name)

else:
min_val *= 1.3
max_val *= 1.3
pos_colors = plt.cm.Blues(np.linspace(0, 1, 128))
neg_colors = plt.cm.Oranges_r(np.linspace(0, 1, 128))
combined = np.vstack((neg_colors, pos_colors))
color = LinearSegmentedColormap.from_list("OrBu", combined)
cmap = mpl.cm.ScalarMappable(norm=TwoSlopeNorm(vmin=min_val,
vcenter=0,
vmax=max_val),
cmap=color)

highlight_node = {}
for atom in mol.GetAtoms():
idx = atom.GetIdx()
frag_val = mask_atom[atom_map[idx]]
rgb = cmap.to_rgba(frag_val)[:-1]
highlight_node[idx] = [rgb]
atom.SetProp("atomNote", str(frag_val))

legend = (f"Graph ID: {data.idx}\n{smiles}\n"
f"Prediction: {pred:.3f} ({pred_label})"
f"\tTrue: {float(data.y)}")

drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800)
opts = drawer.drawOptions()
opts.fillHighlights = True
opts.annotationFontScale = 0.5
opts.legendFontSize = 25
drawer.DrawMoleculeWithHighlights(mol, legend, highlight_node,
{}, {}, {})
drawer.FinishDrawing()

with open(self.out / f"{data.idx}_mol.svg", "w") as f:
f.write(drawer.GetDrawingText())

# * Fragment-level visualization
mask_frag = node_mask.numpy().tolist()
frag_imgs = []

for i, frag in enumerate(fragments):
top_val = self._get_top(mask_frag[i])
label = np.append(top_val[1]["labels"], top_val[0]["labels"])
cntrb = np.append(top_val[1]["contrib"], top_val[0]["contrib"])

contrib = np.sum(mask_frag[i]).round(3)
entries = [f"{lbl}: {val}" for lbl, val in zip(label, cntrb)]
frag_imgs.append(self._subplot(frag, entries, contrib))

# Create image
fig = self._create_frag_image(frag_imgs, 1600, 300)
fig.save(self.out / f"{data.idx}_frag.svg")
# * Feature-level bar plot
self._bar_plot(node_mask, name)

# * Fragment-level visualization
fragments = data.frag
self._frag_visualization(node_mask, fragments, name)

# * Molecule-level visualization
self._explain_frag(data, node_mask, pred, pred_label, name)

def _explain_atom(self, data, node_mask, pred, pred_label, name):
smiles = data.smiles
mol = Chem.MolFromSmiles(smiles)

mask_atom = torch.sum(node_mask, dim=1).numpy()
mask_atom = np.round(mask_atom, 3)

min_val = mask_atom.min()
max_val = mask_atom.max()
if min_val > 0:
max_val *= 1.3
cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=0,
vmax=max_val),
cmap=mpl.cm.Blues)
elif max_val < 0:
min_val *= 1.3
cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=min_val,
vmax=0),
cmap=mpl.cm.Oranges_r)
else:
min_val *= 1.3
max_val *= 1.3
pos_colors = plt.cm.Blues(np.linspace(0, 1, 128))
neg_colors = plt.cm.Oranges_r(np.linspace(0, 1, 128))
combined = np.vstack((neg_colors, pos_colors))
color = LinearSegmentedColormap.from_list("OrBu", combined)
cmap = mpl.cm.ScalarMappable(norm=TwoSlopeNorm(vmin=min_val,
vcenter=0,
vmax=max_val),
cmap=color)

highlight_node = {}
for atom in mol.GetAtoms():
idx = atom.GetIdx()
rgb = cmap.to_rgba(mask_atom[idx])[:-1]
highlight_node[idx] = [rgb]
atom.SetProp("atomNote", str(mask_atom[idx]))

legend = (f"Graph ID: {name}\n{smiles}\n"
f"Prediction: {pred:.3f} ({pred_label})"
f"\tTrue: {float(data.y)}")

drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800)
opts = drawer.drawOptions()
opts.fillHighlights = True
opts.annotationFontScale = 0.5
opts.legendFontSize = 25
drawer.DrawMoleculeWithHighlights(mol, legend, highlight_node,
{}, {}, {})
drawer.FinishDrawing()

with open(self.out / f"{data.idx}_mol.svg", "w") as f:
f.write(drawer.GetDrawingText())

def _explain_frag(self, data, node_mask, pred, pred_label, name):
smiles = data.smiles
mol = Chem.MolFromSmiles(smiles)

atom_map = dict(zip(data.atom_map[0], data.atom_map[1]))
mask_atom = torch.sum(node_mask, dim=1).numpy()
mask_atom = np.round(mask_atom, 3)

min_val = mask_atom.min()
max_val = mask_atom.max()
if min_val > 0:
max_val *= 1.3
cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=0,
vmax=max_val),
cmap=mpl.cm.Blues)
elif max_val < 0:
min_val *= 1.3
cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=min_val,
vmax=0),
cmap=mpl.cm.Oranges_r)
else:
min_val *= 1.3
max_val *= 1.3
pos_colors = plt.cm.Blues(np.linspace(0, 1, 128))
neg_colors = plt.cm.Oranges_r(np.linspace(0, 1, 128))
combined = np.vstack((neg_colors, pos_colors))
color = LinearSegmentedColormap.from_list("OrBu", combined)
cmap = mpl.cm.ScalarMappable(norm=TwoSlopeNorm(vmin=min_val,
vcenter=0,
vmax=max_val),
cmap=color)

highlight_node = {}
for atom in mol.GetAtoms():
idx = atom.GetIdx()
frag_val = mask_atom[atom_map[idx]]
rgb = cmap.to_rgba(frag_val)[:-1]
highlight_node[idx] = [rgb]
atom.SetProp("atomNote", str(frag_val))

legend = (f"Graph ID: {name}\n{smiles}\n"
f"Prediction: {pred:.3f} ({pred_label})"
f"\tTrue: {float(data.y)}")

drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800)
opts = drawer.drawOptions()
opts.fillHighlights = True
opts.annotationFontScale = 0.5
opts.legendFontSize = 25
drawer.DrawMoleculeWithHighlights(mol, legend, highlight_node,
{}, {}, {})
drawer.FinishDrawing()

with open(self.out / f"{name}_mol.svg", "w") as f:
f.write(drawer.GetDrawingText())

def _bar_plot(self, node_mask, name):
# * Feature-level bar plot
mask_feat = torch.sum(node_mask, dim=0).numpy()
feats = self._get_top(mask_feat)

fig, ax = plt.subplots(figsize=(10, 6))
all_lbl = np.append(feats[1]["labels"], feats[0]["labels"])
all_val = np.append(feats[1]["contrib"], feats[0]["contrib"])
colors = (["SteelBlue"] * len(feats[1]["labels"]) +
["DarkOrange"] * len(feats[0]["labels"]))

ax.barh(all_lbl, all_val, color=colors)
ax.set_title(f"Top {self.k} Features - {name}")
ax.set_xlabel("Contribution")
ax.invert_yaxis()

plt.xlim(mask_feat.min() * 1.15, mask_feat.max() * 1.15)
for i, v in enumerate(all_val):
x_off = 0.02 if v > 0 else -0.3
ax.text(v + x_off, i, str(v), va="center", fontsize=8)

plt.tight_layout()
out_feat = self.out / f"{name}_feat.svg"
fig.savefig(out_feat, format="svg")
plt.close(fig)

def _get_top(self, mask, fragments=None):
mask = np.round(mask, 3)
Expand Down Expand Up @@ -217,6 +304,23 @@ def _get_top(self, mask, fragments=None):
cuts = {0: neg, 1: pos}
return cuts

def _frag_visualization(self, node_mask, fragments, name):
mask_frag = node_mask.numpy().tolist()
frag_imgs = []

for i, frag in enumerate(fragments):
top_val = self._get_top(mask_frag[i])
label = np.append(top_val[1]["labels"], top_val[0]["labels"])
cntrb = np.append(top_val[1]["contrib"], top_val[0]["contrib"])

contrib = np.sum(mask_frag[i]).round(3)
entries = [f"{lbl}: {val}" for lbl, val in zip(label, cntrb)]
frag_imgs.append(self._subplot(frag, entries, contrib))

# Create image
fig = self._create_frag_image(frag_imgs, 1600, 300)
fig.save(self.out / f"{name}_frag.svg")

def _subplot(self, frag, entries, contrib, size=(500, 250)):
mol = Chem.MolFromSmiles(frag)

Expand Down Expand Up @@ -267,46 +371,3 @@ def _create_frag_image(self, images, width=1600, height=300):
fig = sg.SVGFigure(str(width), str(h))
fig.append([background] + images)
return fig


def plot_counters(data, out, prefix="", top_n=35):
fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(18, 12))

col = 0
keys = ["0_0", "0_1", "1_0", "1_1"]
title = {"0_0": "Predicted: 0, Real: 0",
"0_1": "Predicted: 0, Real: 1",
"1_1": "Predicted: 1, Real: 1",
"1_0": "Predicted: 1, Real: 0"}

for key in keys:
inner_dict = data[key]

# Get class 0 and class 1 dict
class_0 = inner_dict[0]
class_0 = dict(sorted(class_0.items(), key=lambda item: item[1],
reverse=True))
class_0_lbl = list(class_0.keys())[: top_n]
class_0_num = list(class_0.values())[: top_n]

class_1 = inner_dict[1]
class_1 = dict(sorted(class_1.items(), key=lambda item: item[1],
reverse=True))
class_1_lbl = list(class_1.keys())[: top_n]
class_1_num = list(class_1.values())[: top_n]

# Plot
axes[0, col].bar(class_1_lbl, class_1_num, color="RoyalBlue")
axes[0, col].set_title(f"{title[key]} - Class 1")
axes[0, col].tick_params(axis="x", rotation=90)

axes[1, col].bar(class_0_lbl, class_0_num, color="Crimson")
axes[1, col].set_title(f"{title[key]} - Class 0")
axes[1, col].tick_params(axis="x", rotation=90)

col += 1

plt.tight_layout()
out_feat = out / f"all_{prefix}.svg"
fig.savefig(out_feat, format="svg")
plt.close(fig)