diff --git a/numpy_ml/trees/dt.py b/numpy_ml/trees/dt.py
index 3bd033c..9a60e67 100644
--- a/numpy_ml/trees/dt.py
+++ b/numpy_ml/trees/dt.py
@@ -122,17 +122,11 @@ def predict_class_probs(self, X):
     def _grow(self, X, Y, cur_depth=0):
         # if all labels are the same, return a leaf
         if len(set(Y)) == 1:
-            if self.classifier:
-                prob = np.zeros(self.n_classes)
-                prob[Y[0]] = 1.0
-            return Leaf(prob) if self.classifier else Leaf(Y[0])
+            return Leaf(self._leaf_value(Y))
 
         # if we have reached max_depth, return a leaf
         if cur_depth >= self.max_depth:
-            v = np.mean(Y, axis=0)
-            if self.classifier:
-                v = np.bincount(Y, minlength=self.n_classes) / len(Y)
-            return Leaf(v)
+            return Leaf(self._leaf_value(Y))
 
         cur_depth += 1
         self.depth = max(self.depth, cur_depth)
@@ -142,14 +136,24 @@ def _grow(self, X, Y, cur_depth=0):
         # greedily select the best split according to `criterion`
         feat, thresh = self._segment(X, Y, feat_idxs)
 
+        if feat is None:
+            return Leaf(self._leaf_value(Y))
+
         l = np.argwhere(X[:, feat] <= thresh).flatten()
         r = np.argwhere(X[:, feat] > thresh).flatten()
 
+        if len(l) == 0 or len(r) == 0:
+            return Leaf(self._leaf_value(Y))
+
         # grow the children that result from the split
         left = self._grow(X[l, :], Y[l], cur_depth)
         right = self._grow(X[r, :], Y[r], cur_depth)
         return Node(left, right, (feat, thresh))
 
+    def _leaf_value(self, Y):
+        if self.classifier:
+            return np.bincount(Y, minlength=self.n_classes) / len(Y)
+        return np.mean(Y, axis=0)
+
     def _segment(self, X, Y, feat_idxs):
         """
         Find the optimal split rule (feature index and split threshold) for the
@@ -160,7 +164,10 @@ def _segment(self, X, Y, feat_idxs):
         for i in feat_idxs:
             vals = X[:, i]
             levels = np.unique(vals)
-            thresholds = (levels[:-1] + levels[1:]) / 2 if len(levels) > 1 else levels
+            if len(levels) <= 1:
+                continue
+
+            thresholds = (levels[:-1] + levels[1:]) / 2
             gains = np.array([self._impurity_gain(Y, t, vals) for t in thresholds])
 
             if gains.max() > best_gain: