diff --git a/.github/workflows/wip.yml b/.github/workflows/wip.yml deleted file mode 100755 index e989566..0000000 --- a/.github/workflows/wip.yml +++ /dev/null @@ -1,13 +0,0 @@ -name: WIP - -on: - pull_request: - types: [opened, synchronize, reopened, edited] - -jobs: - wip: - runs-on: ubuntu-latest - steps: - - uses: wip/action@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/frame/explain.py b/frame/explain.py index 465c0ed..f379072 100644 --- a/frame/explain.py +++ b/frame/explain.py @@ -60,9 +60,7 @@ def main(): # * Get checkpoint and prepare Explainer model = models.select_model(model_name, tune) - model.load_state_dict(torch.load(path_checkpoint, - map_location=device, - weights_only=True)) + model.load_state_dict(torch.load(path_checkpoint)) model.eval() if task == "classification": @@ -79,7 +77,7 @@ def main(): return_type="raw")) for data in tqdm(dataloader, ncols=120, desc="Explaining"): - data = data.to(device) + data.to(device) # * Make predictions model_out = model(x=data.x.float(), @@ -89,12 +87,15 @@ def main(): # * Read prediction values if task == "classification": - detach = torch.sigmoid(model_out).cpu().detach() - pred = list(torch.ravel(detach).cpu().detach().numpy()) + logit = model_out.cpu().detach() + logit_list = list(torch.ravel(logit).numpy()) + detach = torch.sigmoid(logit) + pred = list(torch.ravel(detach).numpy()) pred_lbl = (detach >= 0.5).int() else: detach = model_out.cpu().detach() - pred = list(torch.ravel(detach).cpu().detach().numpy()) + pred = list(torch.ravel(detach).numpy()) + logit_list = pred pred_lbl = [None] * detach.shape[0] # * Explain @@ -102,6 +103,7 @@ def main(): edge_attr=data.edge_attr.float(), batch=data.batch) - mol_exp = explain.MolExplain(explanation, pred, pred_lbl, loader, out) + mol_exp = explain.MolExplain(explanation, logit_list, pred, pred_lbl, + loader, out) mol_exp.retrieve_info(data) mol_exp.plot_explanations(data) diff --git a/frame/source/explain/__init__.py 
b/frame/source/explain/__init__.py index b06bebe..b555819 100644 --- a/frame/source/explain/__init__.py +++ b/frame/source/explain/__init__.py @@ -24,11 +24,13 @@ class MolExplain: - def __init__(self, explanation, pred, pred_lbl, loader, out_dir, k=10): + def __init__(self, explanation, logit, pred, pred_lbl, loader, + out_dir, k=10): self.mask = explanation.node_mask.detach().cpu() self.batch = explanation.batch.detach().cpu() self.pred = pred self.pred_lbl = pred_lbl + self.logit = logit self.loader = loader self.out = out_dir self.k = k @@ -103,17 +105,19 @@ def _info_fragment(self, graphs): def plot_explanations(self, graphs): batch_num = self.batch.unique() masks = [self.mask[self.batch == b] for b in batch_num] - pred_label = "" for idx, node_mask in enumerate(masks): data = graphs[idx] name = data.idx pred = self.pred[idx] - if self.pred_lbl[idx] is not None: - pred_label = self.pred_lbl[idx].numpy()[0] + logit = self.logit[idx] + pred_label = self.pred_lbl[idx] + if pred_label is not None: + pred_label = pred_label.numpy()[0] if self.loader == "default": - self._explain_atom(data, node_mask, pred, pred_label, name) + self._explain_atom(data, node_mask, pred, logit, + pred_label, name) else: # * Feature-level bar plot @@ -124,13 +128,22 @@ def plot_explanations(self, graphs): # self._frag_visualization(node_mask, fragments, name) # * Molecule-level visualization - self._explain_frag(data, node_mask, pred, pred_label, name) + self._explain_frag(data, node_mask, pred, logit, + pred_label, name) - def _explain_atom(self, data, node_mask, pred, pred_label, name): + def _rescale_mask(self, mask_atom, logit): + """Rescale per-atom attributions so they sum to the raw logit.""" + current_sum = mask_atom.sum() + if current_sum == 0: + return mask_atom + return mask_atom * (logit / current_sum) + + def _explain_atom(self, data, node_mask, pred, logit, pred_label, name): smiles = data.smiles mol = Chem.MolFromSmiles(smiles) mask_atom = torch.sum(node_mask, 
dim=1).numpy() + mask_atom = self._rescale_mask(mask_atom, logit) mask_atom = np.round(mask_atom, 3) min_val = mask_atom.min() @@ -165,8 +178,12 @@ def _explain_atom(self, data, node_mask, pred, pred_label, name): atom.SetProp("atomNote", str(mask_atom[idx])) legend = (f"Graph ID: {name}\n{smiles}\n" - f"Prediction: {pred:.3f} ({pred_label})" - f"\tTrue: {float(data.y)}") + f"Prediction: {pred:.3f}\tLogits: {logit:.3f}\t|\t" + f"Class: {pred_label}\tTrue: {int(data.y)}") + + if pred_label is None: # Regression + legend = (f"Graph ID: {name}\n{smiles}\n" + f"Prediction: {pred:.3f}\tTrue: {float(data.y):.3f}") drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800) opts = drawer.drawOptions() @@ -180,12 +197,13 @@ def _explain_atom(self, data, node_mask, pred, pred_label, name): with open(self.out / f"{data.idx}.svg", "w") as f: f.write(drawer.GetDrawingText()) - def _explain_frag(self, data, node_mask, pred, pred_label, name): + def _explain_frag(self, data, node_mask, pred, logit, pred_label, name): smiles = data.smiles mol = Chem.MolFromSmiles(smiles) atom_map = dict(zip(data.atom_map[0], data.atom_map[1])) mask_atom = torch.sum(node_mask, dim=1).numpy() + mask_atom = self._rescale_mask(mask_atom, logit) mask_atom = np.round(mask_atom, 3) min_val = mask_atom.min() @@ -221,8 +239,12 @@ def _explain_frag(self, data, node_mask, pred, pred_label, name): atom.SetProp("atomNote", str(frag_val)) legend = (f"Graph ID: {name}\n{smiles}\n" - f"Prediction: {pred:.3f} ({pred_label})" - f"\tTrue: {float(data.y)}") + f"Prediction: {pred:.3f}\tLogits: {logit:.3f}\t|\t" + f"Class: {pred_label}\tTrue: {int(data.y)}") + + if pred_label is None: # Regression + legend = (f"Graph ID: {name}\n{smiles}\n" + f"Prediction: {pred:.3f}\tTrue: {float(data.y):.3f}") drawer = rdMolDraw2D.MolDraw2DSVG(1200, 800) opts = drawer.drawOptions()