44Output: The decision tree maps a real number input to a real number output.
55"""
66
7- import numpy as np
87from collections import Counter
98
9+ import numpy as np
10+
1011
1112class DecisionTree :
1213 def __init__ (self , depth = 5 , min_leaf_size = 5 , task = "regression" , criterion = "gini" ):
@@ -54,7 +55,7 @@ def gini(self, y):
5455
5556 Lower Gini value indicates better purity (best split).
5657 """
57- classes , counts = np .unique (y , return_counts = True )
58+ _ , counts = np .unique (y , return_counts = True )
5859 prob = counts / counts .sum ()
5960 return 1 - np .sum (prob ** 2 )
6061
@@ -67,7 +68,7 @@ def entropy(self, y):
6768
6869 Lower entropy means higher purity.
6970 """
70- classes , counts = np .unique (y , return_counts = True )
71+ _ , counts = np .unique (y , return_counts = True )
7172 prob = counts / counts .sum ()
7273 return - np .sum (prob * np .log2 (prob + 1e-9 ))
7374
@@ -76,7 +77,8 @@ def information_gain(self, parent, left, right):
7677 Computes the information gain from splitting a dataset.
7778 Information gain represents the reduction in impurity
7879 after a dataset is split into left and right subsets.
79- Formula: IG = Impurity(parent) - [weighted impurity(left) + weighted impurity(right)]
80+ Formula: IG = Impurity(parent) - [
81+ weighted impurity(left) + weighted impurity(right)]
8082
8183 Higher information gain indicates a better split.
8284 """
@@ -155,10 +157,7 @@ def train(self, x, y):
155157 then the data set is not split and the average for the entire array is used as
156158 the predictor
157159 """
158- if self .task == "regression" :
159- best_score = float ("inf" )
160- else :
161- best_score = - float ("inf" )
160+ best_score = float ("inf" ) if self .task == "regression" else - float ("inf" )
162161
163162 for i in range (len (x )):
164163 if len (x [:i ]) < self .min_leaf_size :
@@ -209,11 +208,10 @@ def train(self, x, y):
209208 self .left .train (left_x , left_y )
210209 self .right .train (right_x , right_y )
211210
211+ elif self .task == "regression" :
212+ self .prediction = np .mean (y )
212213 else :
213- if self .task == "regression" :
214- self .prediction = np .mean (y )
215- else :
216- self .prediction = self .most_common_label (y )
214+ self .prediction = self .most_common_label (y )
217215
218216 def predict (self , x ):
219217 """
0 commit comments