diff --git a/tree.py b/tree.py index 0e624aa..1a38696 100644 --- a/tree.py +++ b/tree.py @@ -4,6 +4,7 @@ import operator import treePlotter from collections import Counter +from copy import deepcopy pre_pruning = True @@ -55,10 +56,10 @@ def cal_entropy(dataset): labelCounts = {} # 给所有可能分类创建字典 for featVec in dataset: - currentlabel = featVec[-1] + currentlabel = featVec[-1] if currentlabel not in labelCounts.keys(): labelCounts[currentlabel] = 0 - labelCounts[currentlabel] += 1 + labelCounts[currentlabel] += 1 Ent = 0.0 for key in labelCounts: p = float(labelCounts[key]) / numEntries @@ -184,6 +185,7 @@ def ID3_createTree(dataset, labels, test_dataset): print(u"此时最优索引为:" + (bestFeatLabel)) ID3Tree = {bestFeatLabel: {}} + labels_for_post_pruning = deepcopy(labels) del (labels[bestFeat]) # 得到列表包括节点所有的属性值 featValues = [example[bestFeat] for example in dataset] @@ -224,8 +226,8 @@ def ID3_createTree(dataset, labels, test_dataset): if post_pruning: tree_output = classifytest(ID3Tree, - featLabels=['年龄段', '有工作', '有自己的房子', '信贷情况'], - testDataSet=test_dataset) + featLabels=labels_for_post_pruning, + testDataSet=test_dataset)# 这里传入的数据集的特征集合是变化后的,所以应该传入变化的特征集 ans = [] for vec in test_dataset: ans.append(vec[-1])