diff --git a/lime/lime_tabular.py b/lime/lime_tabular.py index 880f3d391..706eaed83 100644 --- a/lime/lime_tabular.py +++ b/lime/lime_tabular.py @@ -359,6 +359,8 @@ def explain_instance(self, ).ravel() yss = predict_fn(inverse) + if not isinstance(yss, np.ndarray): #pytorch output + yss = yss.detach().numpy() # for classification, the model needs to provide a list of tuples - classes # along with prediction probabilities @@ -505,7 +507,6 @@ def __data_inverse(self, else: num_cols = data_row.shape[0] data = np.zeros((num_samples, num_cols)) - categorical_features = range(num_cols) if self.discretizer is None: instance_sample = data_row scale = self.scaler.scale_