@@ -18,7 +18,7 @@ def __init__(self, depth=5, min_leaf_size=5, task="regression", criterion="gini"
1818 self .prediction = None
1919 self .task = task
2020 self .criterion = criterion
21-
21+
2222 def mean_squared_error (self , labels , prediction ):
2323 """
2424 mean_squared_error:
@@ -51,20 +51,20 @@ def gini(self, y):
5151 would be incorrectly classified.
5252 Formula: Gini = 1 - sum(p_i^2)
5353 where p_i is the probability of class i.
54-
54+
5555 Lower Gini value indicates better purity (best split).
5656 """
5757 classes , counts = np .unique (y , return_counts = True )
5858 prob = counts / counts .sum ()
59- return 1 - np .sum (prob ** 2 )
59+ return 1 - np .sum (prob ** 2 )
6060
6161 def entropy (self , y ):
6262 """
6363 Computes the entropy (impurity) of a set of labels.
6464 Entropy measures the randomness or disorder in the data.
6565 Formula: Entropy = - sum(p_i * log2(p_i))
6666 where p_i is the probability of class i.
67-
67+
6868 Lower entropy means higher purity.
6969 """
7070 classes , counts = np .unique (y , return_counts = True )
@@ -77,7 +77,7 @@ def information_gain(self, parent, left, right):
7777 Information gain represents the reduction in impurity
7878 after a dataset is split into left and right subsets.
7979 Formula: IG = Impurity(parent) - [weighted impurity(left) + weighted impurity(right)]
80-
80+
8181 Higher information gain indicates a better split.
8282 """
8383 if self .criterion == "gini" :
@@ -90,9 +90,7 @@ def information_gain(self, parent, left, right):
9090 weight_l = len (left ) / len (parent )
9191 weight_r = len (right ) / len (parent )
9292
93- return func (parent ) - (
94- weight_l * func (left ) + weight_r * func (right )
95- )
93+ return func (parent ) - (weight_l * func (left ) + weight_r * func (right ))
9694
9795 def most_common_label (self , y ):
9896 return Counter (y ).most_common (1 )[0 ][0 ]
@@ -150,7 +148,7 @@ def train(self, x, y):
150148 return
151149
152150 best_split = 0
153-
151+
154152 """
155153 loop over all possible splits for the decision tree. find the best split.
156154 if no split exists that is less than 2 * error for the entire array
@@ -180,7 +178,7 @@ def train(self, x, y):
180178 best_score = score
181179 best_split = i
182180
183- else :
181+ else :
184182 gain = self .information_gain (y , left_y , right_y )
185183
186184 if gain > best_score :
@@ -234,7 +232,7 @@ def predict(self, x):
234232
235233 raise ValueError ("Decision tree not yet trained" )
236234
237-
235+
238236class TestDecisionTree :
239237 """Decision Tree test class"""
240238
@@ -252,7 +250,7 @@ def helper_mean_squared_error_test(labels, prediction):
252250
253251 return float (squared_error_sum / labels .size )
254252
255-
253+
256254def main ():
257255 """
258256 In this demonstration we're generating a sample data set from the sin function in
@@ -270,15 +268,17 @@ def main():
270268 x_cls = np .array ([1 , 2 , 3 , 4 , 5 , 6 ])
271269 y_cls = np .array ([0 , 0 , 0 , 1 , 1 , 1 ])
272270
273- clf = DecisionTree (depth = 3 , min_leaf_size = 1 , task = "classification" , criterion = "gini" )
271+ clf = DecisionTree (
272+ depth = 3 , min_leaf_size = 1 , task = "classification" , criterion = "gini"
273+ )
274274 clf .train (x_cls , y_cls )
275275
276- print ("Classification prediction (2):" , clf .predict (2 ))
277- print ("Classification prediction (5):" , clf .predict (5 ))
276+ print ("Classification prediction (2):" , clf .predict (2 ))
277+ print ("Classification prediction (5):" , clf .predict (5 ))
278278
279279
280280if __name__ == "__main__" :
281281 main ()
282282 import doctest
283283
284- doctest .testmod (name = "mean_squared_error" , verbose = True )
284+ doctest .testmod (name = "mean_squared_error" , verbose = True )
0 commit comments