diff --git a/CIGIN_V2/molecular_graph.py b/CIGIN_V2/molecular_graph.py index 8580397..9ab3df7 100644 --- a/CIGIN_V2/molecular_graph.py +++ b/CIGIN_V2/molecular_graph.py @@ -91,6 +91,6 @@ def get_graph_from_smile(molecule_smile): bond_features_ij = get_bond_features(bond_ij) edge_features.append(bond_features_ij) - G.ndata['x'] = np.array(node_features) - G.edata['w'] = np.array(edge_features) + G.ndata['x'] = torch.from_numpy(np.array(node_features)) + G.edata['w'] = torch.from_numpy(np.array(edge_features)) return G