From 42829d1afca5a566075402b1c595bdc0e5a9714f Mon Sep 17 00:00:00 2001 From: Rafael Date: Tue, 18 Nov 2025 17:57:33 +0200 Subject: [PATCH] chg: add atom-based explanation --- placeholder/explain.py | 7 +- placeholder/source/explain/__init__.py | 341 +++++++++++++++---------- 2 files changed, 202 insertions(+), 146 deletions(-) diff --git a/placeholder/explain.py b/placeholder/explain.py index c428c78..5d4f77b 100644 --- a/placeholder/explain.py +++ b/placeholder/explain.py @@ -36,7 +36,7 @@ def main(): # * Initialize name = config["name"] if name.lower() == "none": - name = str(uuid.uuid4()).split["-"][0] + name = str(uuid.uuid4()).split("-")[0] config["name"] = name cwd = Path(os.getcwd()) @@ -71,11 +71,6 @@ def main(): task_level="graph", return_type="raw")) - labels = ",".join(explain.V1) - header = "id,smiles,real,pred_label,pred,fragment," - with open(out / "predictions.csv", "w") as f: - f.write(f"{header+labels}\n") - for data in tqdm(dataloader, ncols=120, desc="Explaining"): data.to(device) diff --git a/placeholder/source/explain/__init__.py b/placeholder/source/explain/__init__.py index 402917b..0c0e6d8 100644 --- a/placeholder/source/explain/__init__.py +++ b/placeholder/source/explain/__init__.py @@ -39,6 +39,7 @@ def __init__(self, explanation, pred, pred_lbl, loader, out_dir, k=10): self.batch = explanation.batch.detach().cpu() self.pred = pred self.pred_lbl = pred_lbl + self.loader = loader self.out = out_dir self.k = k self.cut = k // 2 @@ -51,6 +52,39 @@ def __init__(self, explanation, pred, pred_lbl, loader, out_dir, k=10): self.labels = np.array(V2) def retrieve_info(self, graphs): + if self.loader == "default": + header = "id,smiles,real,pred_label,pred,fragment" + with open(self.out / "predictions.csv", "w") as f: + f.write(f"{header}\n") + + self._info_atom(graphs) + + else: + labels = ",".join(V1) + header = "id,smiles,real,pred_label,pred" + with open(self.out / "predictions.csv", "w") as f: + f.write(f"{header+labels}\n") + + self._info_fragment(graphs) + + def _info_atom(self, graphs): + batch_num = self.batch.unique() + masks = [self.mask[self.batch == b] for b in batch_num] + + for idx in range(len(masks)): + data = graphs[idx] + real_label = int(data.y.cpu().numpy()[0]) + pred = self.pred[idx] + pred_label = self.pred_lbl[idx].numpy()[0] + + text = (f"{data.idx},{data.smiles},{real_label}," + f"{pred_label},{pred:.3f}\n") + + # * Export prediction + with open(self.out / "predictions.csv", "a") as f: + f.writelines(text) + + def _info_fragment(self, graphs): batch_num = self.batch.unique() masks = [self.mask[self.batch == b] for b in batch_num] @@ -81,106 +115,159 @@ def plot_explanations(self, graphs): for idx, node_mask in enumerate(masks): data = graphs[idx] - smiles = data.smiles - fragments = data.frag + name = data.idx pred = self.pred[idx] pred_label = self.pred_lbl[idx].numpy()[0] - mol = Chem.MolFromSmiles(smiles) - - # * Feature-level bar plot - mask_feat = torch.sum(node_mask, dim=0).numpy() - feats = self._get_top(mask_feat) - - fig, ax = plt.subplots(figsize=(10, 6)) - all_lbl = np.append(feats[1]["labels"], feats[0]["labels"]) - all_val = np.append(feats[1]["contrib"], feats[0]["contrib"]) - colors = (["SteelBlue"] * len(feats[1]["labels"]) + - ["DarkOrange"] * len(feats[0]["labels"])) - - ax.barh(all_lbl, all_val, color=colors) - ax.set_title(f"Top {self.k} Features - {data.idx}") - ax.set_xlabel("Contribution") - ax.invert_yaxis() - - plt.xlim(mask_feat.min() * 1.15, mask_feat.max() * 1.15) - for i, v in enumerate(all_val): - x_off = 0.02 if v > 0 else -0.3 - ax.text(v + x_off, i, str(v), va="center", fontsize=8) - - plt.tight_layout() - out_feat = self.out / f"{data.idx}_feat.svg" - fig.savefig(out_feat, format="svg") - plt.close(fig) - - # * Molecule-level visualization - atom_map = dict(zip(data.atom_map[0], data.atom_map[1])) - mask_atom = torch.sum(node_mask, dim=1).numpy() - mask_atom = np.round(mask_atom, 3) - - min_val = mask_atom.min() - max_val = mask_atom.max() - if min_val > 0: - max_val *= 1.3 - cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=0, - vmax=max_val), - cmap=mpl.cm.Blues) - elif max_val < 0: - min_val *= 1.3 - cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=min_val, - vmax=0), - cmap=mpl.cm.Oranges_r) + + if self.loader == "default": + self._explain_atom(data, node_mask, pred, pred_label, name) + else: - min_val *= 1.3 - max_val *= 1.3 - pos_colors = plt.cm.Blues(np.linspace(0, 1, 128)) - neg_colors = plt.cm.Oranges_r(np.linspace(0, 1, 128)) - combined = np.vstack((neg_colors, pos_colors)) - color = LinearSegmentedColormap.from_list("OrBu", combined) - cmap = mpl.cm.ScalarMappable(norm=TwoSlopeNorm(vmin=min_val, - vcenter=0, - vmax=max_val), - cmap=color) - - highlight_node = {} - for atom in mol.GetAtoms(): - idx = atom.GetIdx() - frag_val = mask_atom[atom_map[idx]] - rgb = cmap.to_rgba(frag_val)[:-1] - highlight_node[idx] = [rgb] - atom.SetProp("atomNote", str(frag_val)) - - legend = (f"Graph ID: {data.idx}\n{smiles}\n" - f"Prediction: {pred:.3f} ({pred_label})" - f"\tTrue: {float(data.y)}") - - drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800) - opts = drawer.drawOptions() - opts.fillHighlights = True - opts.annotationFontScale = 0.5 - opts.legendFontSize = 25 - drawer.DrawMoleculeWithHighlights(mol, legend, highlight_node, - {}, {}, {}) - drawer.FinishDrawing() - - with open(self.out / f"{data.idx}_mol.svg", "w") as f: - f.write(drawer.GetDrawingText()) - - # * Fragment-level visualization - mask_frag = node_mask.numpy().tolist() - frag_imgs = [] - - for i, frag in enumerate(fragments): - top_val = self._get_top(mask_frag[i]) - label = np.append(top_val[1]["labels"], top_val[0]["labels"]) - cntrb = np.append(top_val[1]["contrib"], top_val[0]["contrib"]) - - contrib = np.sum(mask_frag[i]).round(3) - entries = [f"{lbl}: {val}" for lbl, val in zip(label, cntrb)] - frag_imgs.append(self._subplot(frag, entries, contrib)) - - # Create image - fig = self._create_frag_image(frag_imgs, 1600, 300) - fig.save(self.out / f"{data.idx}_frag.svg") + # * Feature-level bar plot + self._bar_plot(node_mask, name) + + # * Fragment-level visualization + fragments = data.frag + self._frag_visualization(node_mask, fragments, name) + + # * Molecule-level visualization + self._explain_frag(data, node_mask, pred, pred_label, name) + + def _explain_atom(self, data, node_mask, pred, pred_label, name): + smiles = data.smiles + mol = Chem.MolFromSmiles(smiles) + + mask_atom = torch.sum(node_mask, dim=1).numpy() + mask_atom = np.round(mask_atom, 3) + + min_val = mask_atom.min() + max_val = mask_atom.max() + if min_val > 0: + max_val *= 1.3 + cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=0, + vmax=max_val), + cmap=mpl.cm.Blues) + elif max_val < 0: + min_val *= 1.3 + cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=min_val, + vmax=0), + cmap=mpl.cm.Oranges_r) + else: + min_val *= 1.3 + max_val *= 1.3 + pos_colors = plt.cm.Blues(np.linspace(0, 1, 128)) + neg_colors = plt.cm.Oranges_r(np.linspace(0, 1, 128)) + combined = np.vstack((neg_colors, pos_colors)) + color = LinearSegmentedColormap.from_list("OrBu", combined) + cmap = mpl.cm.ScalarMappable(norm=TwoSlopeNorm(vmin=min_val, + vcenter=0, + vmax=max_val), + cmap=color) + + highlight_node = {} + for atom in mol.GetAtoms(): + idx = atom.GetIdx() + rgb = cmap.to_rgba(mask_atom[idx])[:-1] + highlight_node[idx] = [rgb] + atom.SetProp("atomNote", str(mask_atom[idx])) + + legend = (f"Graph ID: {name}\n{smiles}\n" + f"Prediction: {pred:.3f} ({pred_label})" + f"\tTrue: {float(data.y)}") + + drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800) + opts = drawer.drawOptions() + opts.fillHighlights = True + opts.annotationFontScale = 0.5 + opts.legendFontSize = 25 + drawer.DrawMoleculeWithHighlights(mol, legend, highlight_node, + {}, {}, {}) + drawer.FinishDrawing() + + with open(self.out / f"{data.idx}_mol.svg", "w") as f: + f.write(drawer.GetDrawingText()) + + def _explain_frag(self, data, node_mask, pred, pred_label, name): + smiles = data.smiles + mol = Chem.MolFromSmiles(smiles) + + atom_map = dict(zip(data.atom_map[0], data.atom_map[1])) + mask_atom = torch.sum(node_mask, dim=1).numpy() + mask_atom = np.round(mask_atom, 3) + + min_val = mask_atom.min() + max_val = mask_atom.max() + if min_val > 0: + max_val *= 1.3 + cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=0, + vmax=max_val), + cmap=mpl.cm.Blues) + elif max_val < 0: + min_val *= 1.3 + cmap = mpl.cm.ScalarMappable(norm=Normalize(vmin=min_val, + vmax=0), + cmap=mpl.cm.Oranges_r) + else: + min_val *= 1.3 + max_val *= 1.3 + pos_colors = plt.cm.Blues(np.linspace(0, 1, 128)) + neg_colors = plt.cm.Oranges_r(np.linspace(0, 1, 128)) + combined = np.vstack((neg_colors, pos_colors)) + color = LinearSegmentedColormap.from_list("OrBu", combined) + cmap = mpl.cm.ScalarMappable(norm=TwoSlopeNorm(vmin=min_val, + vcenter=0, + vmax=max_val), + cmap=color) + + highlight_node = {} + for atom in mol.GetAtoms(): + idx = atom.GetIdx() + frag_val = mask_atom[atom_map[idx]] + rgb = cmap.to_rgba(frag_val)[:-1] + highlight_node[idx] = [rgb] + atom.SetProp("atomNote", str(frag_val)) + + legend = (f"Graph ID: {name}\n{smiles}\n" + f"Prediction: {pred:.3f} ({pred_label})" + f"\tTrue: {float(data.y)}") + + drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800) + opts = drawer.drawOptions() + opts.fillHighlights = True + opts.annotationFontScale = 0.5 + opts.legendFontSize = 25 + drawer.DrawMoleculeWithHighlights(mol, legend, highlight_node, + {}, {}, {}) + drawer.FinishDrawing() + + with open(self.out / f"{name}_mol.svg", "w") as f: + f.write(drawer.GetDrawingText()) + + def _bar_plot(self, node_mask, name): + # * Feature-level bar plot + mask_feat = torch.sum(node_mask, dim=0).numpy() + feats = self._get_top(mask_feat) + + fig, ax = plt.subplots(figsize=(10, 6)) + all_lbl = np.append(feats[1]["labels"], feats[0]["labels"]) + all_val = np.append(feats[1]["contrib"], feats[0]["contrib"]) + colors = (["SteelBlue"] * len(feats[1]["labels"]) + + ["DarkOrange"] * len(feats[0]["labels"])) + + ax.barh(all_lbl, all_val, color=colors) + ax.set_title(f"Top {self.k} Features - {name}") + ax.set_xlabel("Contribution") + ax.invert_yaxis() + + plt.xlim(mask_feat.min() * 1.15, mask_feat.max() * 1.15) + for i, v in enumerate(all_val): + x_off = 0.02 if v > 0 else -0.3 + ax.text(v + x_off, i, str(v), va="center", fontsize=8) + + plt.tight_layout() + out_feat = self.out / f"{name}_feat.svg" + fig.savefig(out_feat, format="svg") + plt.close(fig) def _get_top(self, mask, fragments=None): mask = np.round(mask, 3) @@ -217,6 +304,23 @@ def _get_top(self, mask, fragments=None): cuts = {0: neg, 1: pos} return cuts + def _frag_visualization(self, node_mask, fragments, name): + mask_frag = node_mask.numpy().tolist() + frag_imgs = [] + + for i, frag in enumerate(fragments): + top_val = self._get_top(mask_frag[i]) + label = np.append(top_val[1]["labels"], top_val[0]["labels"]) + cntrb = np.append(top_val[1]["contrib"], top_val[0]["contrib"]) + + contrib = np.sum(mask_frag[i]).round(3) + entries = [f"{lbl}: {val}" for lbl, val in zip(label, cntrb)] + frag_imgs.append(self._subplot(frag, entries, contrib)) + + # Create image + fig = self._create_frag_image(frag_imgs, 1600, 300) + fig.save(self.out / f"{name}_frag.svg") + def _subplot(self, frag, entries, contrib, size=(500, 250)): mol = Chem.MolFromSmiles(frag) @@ -267,46 +371,3 @@ def _create_frag_image(self, images, width=1600, height=300): fig = sg.SVGFigure(str(width), str(h)) fig.append([background] + images) return fig - - -def plot_counters(data, out, prefix="", top_n=35): - fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(18, 12)) - - col = 0 - keys = ["0_0", "0_1", "1_0", "1_1"] - title = {"0_0": "Predicted: 0, Real: 0", - "0_1": "Predicted: 0, Real: 1", - "1_1": "Predicted: 1, Real: 1", - "1_0": "Predicted: 1, Real: 0"} - - for key in keys: - inner_dict = data[key] - - # Get class 0 and class 1 dict - class_0 = inner_dict[0] - class_0 = dict(sorted(class_0.items(), key=lambda item: item[1], - reverse=True)) - class_0_lbl = list(class_0.keys())[: top_n] - class_0_num = list(class_0.values())[: top_n] - - class_1 = inner_dict[1] - class_1 = dict(sorted(class_1.items(), key=lambda item: item[1], - reverse=True)) - class_1_lbl = list(class_1.keys())[: top_n] - class_1_num = list(class_1.values())[: top_n] - - # Plot - axes[0, col].bar(class_1_lbl, class_1_num, color="RoyalBlue") - axes[0, col].set_title(f"{title[key]} - Class 1") - axes[0, col].tick_params(axis="x", rotation=90) - - axes[1, col].bar(class_0_lbl, class_0_num, color="Crimson") - axes[1, col].set_title(f"{title[key]} - Class 0") - axes[1, col].tick_params(axis="x", rotation=90) - - col += 1 - - plt.tight_layout() - out_feat = out / f"all_{prefix}.svg" - fig.savefig(out_feat, format="svg") - plt.close(fig)