Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions numpy_ml/trees/dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,11 @@ def predict_class_probs(self, X):
def _grow(self, X, Y, cur_depth=0):
# if all labels are the same, return a leaf
if len(set(Y)) == 1:
if self.classifier:
prob = np.zeros(self.n_classes)
prob[Y[0]] = 1.0
return Leaf(prob) if self.classifier else Leaf(Y[0])
return Leaf(self._leaf_value(Y))

# if we have reached max_depth, return a leaf
if cur_depth >= self.max_depth:
v = np.mean(Y, axis=0)
if self.classifier:
v = np.bincount(Y, minlength=self.n_classes) / len(Y)
return Leaf(v)
return Leaf(self._leaf_value(Y))

cur_depth += 1
self.depth = max(self.depth, cur_depth)
Expand All @@ -142,14 +136,24 @@ def _grow(self, X, Y, cur_depth=0):

# greedily select the best split according to `criterion`
feat, thresh = self._segment(X, Y, feat_idxs)
if feat is None:
return Leaf(self._leaf_value(Y))

l = np.argwhere(X[:, feat] <= thresh).flatten()
r = np.argwhere(X[:, feat] > thresh).flatten()
if len(l) == 0 or len(r) == 0:
return Leaf(self._leaf_value(Y))

# grow the children that result from the split
left = self._grow(X[l, :], Y[l], cur_depth)
right = self._grow(X[r, :], Y[r], cur_depth)
return Node(left, right, (feat, thresh))

def _leaf_value(self, Y):
if self.classifier:
return np.bincount(Y, minlength=self.n_classes) / len(Y)
return np.mean(Y, axis=0)

def _segment(self, X, Y, feat_idxs):
"""
Find the optimal split rule (feature index and split threshold) for the
Expand All @@ -160,7 +164,10 @@ def _segment(self, X, Y, feat_idxs):
for i in feat_idxs:
vals = X[:, i]
levels = np.unique(vals)
thresholds = (levels[:-1] + levels[1:]) / 2 if len(levels) > 1 else levels
if len(levels) <= 1:
continue

thresholds = (levels[:-1] + levels[1:]) / 2
gains = np.array([self._impurity_gain(Y, t, vals) for t in thresholds])

if gains.max() > best_gain:
Expand Down