Skip to content

Commit 0776097

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8d10c36 commit 0776097

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

machine_learning/decision_tree.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ def __init__(self, depth=5, min_leaf_size=5, task="regression", criterion="gini"
1818
self.prediction = None
1919
self.task = task
2020
self.criterion = criterion
21-
21+
2222
def mean_squared_error(self, labels, prediction):
2323
"""
2424
mean_squared_error:
@@ -51,20 +51,20 @@ def gini(self, y):
5151
would be incorrectly classified.
5252
Formula: Gini = 1 - sum(p_i^2)
5353
where p_i is the probability of class i.
54-
54+
5555
Lower Gini value indicates better purity (best split).
5656
"""
5757
classes, counts = np.unique(y, return_counts=True)
5858
prob = counts / counts.sum()
59-
return 1 - np.sum(prob ** 2)
59+
return 1 - np.sum(prob**2)
6060

6161
def entropy(self, y):
6262
"""
6363
Computes the entropy (impurity) of a set of labels.
6464
Entropy measures the randomness or disorder in the data.
6565
Formula: Entropy = - sum(p_i * log2(p_i))
6666
where p_i is the probability of class i.
67-
67+
6868
Lower entropy means higher purity.
6969
"""
7070
classes, counts = np.unique(y, return_counts=True)
@@ -77,7 +77,7 @@ def information_gain(self, parent, left, right):
7777
Information gain represents the reduction in impurity
7878
after a dataset is split into left and right subsets.
7979
Formula: IG = Impurity(parent) - [weighted impurity(left) + weighted impurity(right)]
80-
80+
8181
Higher information gain indicates a better split.
8282
"""
8383
if self.criterion == "gini":
@@ -90,9 +90,7 @@ def information_gain(self, parent, left, right):
9090
weight_l = len(left) / len(parent)
9191
weight_r = len(right) / len(parent)
9292

93-
return func(parent) - (
94-
weight_l * func(left) + weight_r * func(right)
95-
)
93+
return func(parent) - (weight_l * func(left) + weight_r * func(right))
9694

9795
def most_common_label(self, y):
9896
return Counter(y).most_common(1)[0][0]
@@ -150,7 +148,7 @@ def train(self, x, y):
150148
return
151149

152150
best_split = 0
153-
151+
154152
"""
155153
loop over all possible splits for the decision tree. find the best split.
156154
if no split exists that is less than 2 * error for the entire array
@@ -180,7 +178,7 @@ def train(self, x, y):
180178
best_score = score
181179
best_split = i
182180

183-
else:
181+
else:
184182
gain = self.information_gain(y, left_y, right_y)
185183

186184
if gain > best_score:
@@ -234,7 +232,7 @@ def predict(self, x):
234232

235233
raise ValueError("Decision tree not yet trained")
236234

237-
235+
238236
class TestDecisionTree:
239237
"""Decision Tres test class"""
240238

@@ -252,7 +250,7 @@ def helper_mean_squared_error_test(labels, prediction):
252250

253251
return float(squared_error_sum / labels.size)
254252

255-
253+
256254
def main():
257255
"""
258256
In this demonstration we're generating a sample data set from the sin function in
@@ -270,15 +268,17 @@ def main():
270268
x_cls = np.array([1, 2, 3, 4, 5, 6])
271269
y_cls = np.array([0, 0, 0, 1, 1, 1])
272270

273-
clf = DecisionTree(depth=3, min_leaf_size=1, task="classification", criterion="gini")
271+
clf = DecisionTree(
272+
depth=3, min_leaf_size=1, task="classification", criterion="gini"
273+
)
274274
clf.train(x_cls, y_cls)
275275

276-
print("Classification prediction (2):", clf.predict(2))
277-
print("Classification prediction (5):", clf.predict(5))
276+
print("Classification prediction (2):", clf.predict(2))
277+
print("Classification prediction (5):", clf.predict(5))
278278

279279

280280
if __name__ == "__main__":
281281
main()
282282
import doctest
283283

284-
doctest.testmod(name="mean_squared_error", verbose=True)
284+
doctest.testmod(name="mean_squared_error", verbose=True)

0 commit comments

Comments (0)