From 389e27f1c6782336e5da3fd8819f17503c2ab44a Mon Sep 17 00:00:00 2001 From: HTQ17double <100686379+HTQ17double@users.noreply.github.com> Date: Wed, 13 Sep 2023 19:38:26 +0800 Subject: [PATCH 1/2] Add files via upload --- tree.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tree.py b/tree.py index 0e624aa..2335eb9 100644 --- a/tree.py +++ b/tree.py @@ -4,6 +4,7 @@ import operator import treePlotter from collections import Counter +from copy import deepcopy pre_pruning = True @@ -29,7 +30,7 @@ def read_dataset(filename): for line in all_lines[0:]: line = line.strip().split(',') # 以逗号为分割符拆分列表 dataset.append(line) - return dataset, labels + return dataset, labels#这里的label是特征 def read_testset(testfile): @@ -55,10 +56,10 @@ def cal_entropy(dataset): labelCounts = {} # 给所有可能分类创建字典 for featVec in dataset: - currentlabel = featVec[-1] + currentlabel = featVec[-1] #类别标签 if currentlabel not in labelCounts.keys(): labelCounts[currentlabel] = 0 - labelCounts[currentlabel] += 1 + labelCounts[currentlabel] += 1# 各个类的总个数 Ent = 0.0 for key in labelCounts: p = float(labelCounts[key]) / numEntries @@ -184,6 +185,7 @@ def ID3_createTree(dataset, labels, test_dataset): print(u"此时最优索引为:" + (bestFeatLabel)) ID3Tree = {bestFeatLabel: {}} + labels_for_post_pruning = deepcopy(labels) del (labels[bestFeat]) # 得到列表包括节点所有的属性值 featValues = [example[bestFeat] for example in dataset] @@ -224,7 +226,7 @@ def ID3_createTree(dataset, labels, test_dataset): if post_pruning: tree_output = classifytest(ID3Tree, - featLabels=['年龄段', '有工作', '有自己的房子', '信贷情况'], + featLabels=labels_for_post_pruning, testDataSet=test_dataset) ans = [] for vec in test_dataset: From c8d5319d8a83998dd0b3d0883380a3640417edce Mon Sep 17 00:00:00 2001 From: HTQ17double <100686379+HTQ17double@users.noreply.github.com> Date: Wed, 13 Sep 2023 19:42:06 +0800 Subject: [PATCH 2/2] Add files via upload --- tree.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tree.py b/tree.py index 2335eb9..1a38696 100644 --- a/tree.py +++ b/tree.py @@ -30,7 +30,7 @@ def read_dataset(filename): for line in all_lines[0:]: line = line.strip().split(',') # 以逗号为分割符拆分列表 dataset.append(line) - return dataset, labels#这里的label是特征 + return dataset, labels def read_testset(testfile): @@ -56,10 +56,10 @@ def cal_entropy(dataset): labelCounts = {} # 给所有可能分类创建字典 for featVec in dataset: - currentlabel = featVec[-1] #类别标签 + currentlabel = featVec[-1] if currentlabel not in labelCounts.keys(): labelCounts[currentlabel] = 0 - labelCounts[currentlabel] += 1# 各个类的总个数 + labelCounts[currentlabel] += 1 Ent = 0.0 for key in labelCounts: p = float(labelCounts[key]) / numEntries @@ -227,7 +227,7 @@ def ID3_createTree(dataset, labels, test_dataset): if post_pruning: tree_output = classifytest(ID3Tree, featLabels=labels_for_post_pruning, - testDataSet=test_dataset) + testDataSet=test_dataset)# 这里传入的数据集的特征集合是变化后的,所以应该传入变化的特征集 ans = [] for vec in test_dataset: ans.append(vec[-1])