# 附上实现的ID3算法Python代码
# 参考《机器学习实战》编写
#-*- coding: UTF-8 -*-
import copy
import operator
import os
from collections import Counter
from math import log

import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
# Build the toy loan-approval training set
def createDataSet():
    """Return the toy loan-approval training set and its feature names.

    Each sample is [age, has_job, has_house, credit_rating, decision]:
    age is 'young' / 'middle' / 'old'; has_job and has_house are 0/1;
    credit_rating is 0 = fair, 1 = good, 2 = very good; decision is
    'yes' (grant the loan) or 'no' (refuse it).

    Returns:
        dataSet - list of 15 samples (each a list, class label last)
        labels - names of the four feature columns, in column order
    """
    samples = [
        ['young', 0, 0, 0, 'no'],
        ['young', 0, 0, 1, 'no'],
        ['young', 1, 0, 1, 'yes'],
        ['young', 1, 1, 0, 'yes'],
        ['young', 0, 0, 0, 'no'],
        ['middle', 0, 0, 0, 'no'],
        ['middle', 0, 0, 1, 'no'],
        ['middle', 1, 1, 1, 'yes'],
        ['middle', 0, 1, 2, 'yes'],
        ['middle', 0, 1, 2, 'yes'],
        ['old', 0, 1, 2, 'yes'],
        ['old', 0, 1, 1, 'yes'],
        ['old', 1, 0, 1, 'yes'],
        ['old', 1, 0, 2, 'yes'],
        ['old', 0, 0, 0, 'no'],
    ]
    featureNames = ['年龄', '有工作', '有房子', '贷款情况']
    return samples, featureNames
# Shannon entropy of the class labels
def calShannonEnt(dataSet):
    """Compute the base-2 Shannon entropy of the class labels in dataSet.

    Parameters:
        dataSet - list of samples; the last element of each sample is its
                  class label
    Returns:
        shannonEnt - entropy in bits; 0.0 for an empty data set
    """
    total = len(dataSet)
    shannonEnt = 0.0
    for count in Counter(item[-1] for item in dataSet).values():
        # float() keeps this a true division even under Python 2, where
        # int / int would truncate p to 0 and log(0, 2) would raise
        p = float(count) / total
        shannonEnt -= p * log(p, 2)
    return shannonEnt
# Split the data set on one feature value
def splitDataSet(dataSet, index, value):
    """Select the samples whose feature `index` equals `value`, with that column removed.

    Parameters:
        dataSet - list of samples (each a list)
        index - column index of the feature to test
        value - feature value to keep
    Returns:
        returnData - new list of new sample lists; dataSet is not modified
    """
    def _without_column(row):
        # slicing copies, so extending the copy leaves `row` untouched
        reduced = row[:index]
        reduced.extend(row[index + 1:])
        return reduced

    return [_without_column(row) for row in dataSet if row[index] == value]
# Choose the best feature to split on (ID3 information gain)
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the largest information gain.

    The gain of every feature is printed as it is computed.

    Parameters:
        dataSet - non-empty list of samples; every sample is a list whose
                  last element is the class label
    Returns:
        bestFeature - index of the best feature, or -1 when no feature has a
                      strictly positive information gain
    """
    featureNum = len(dataSet[0]) - 1
    # float so the subset weights below are true divisions on Python 2 too
    total = float(len(dataSet))
    baseEnt = calShannonEnt(dataSet)
    maxGain = 0.0
    bestFeature = -1
    for i in range(featureNum):
        # conditional entropy: entropy of each value's subset, weighted by its size
        currentEnt = 0.0
        for value in {item[i] for item in dataSet}:
            splitData = splitDataSet(dataSet, i, value)
            currentEnt += len(splitData) / total * calShannonEnt(splitData)
        currentGain = baseEnt - currentEnt
        print("第%d个特征的增益为%.3f" % (i, currentGain))
        # strict '>' keeps the lowest index on ties and leaves -1 if no gain is positive
        if currentGain > maxGain:
            maxGain = currentGain
            bestFeature = i
    return bestFeature
# Majority vote: the most frequent element of classList
def majorityCnt(classList):
    """Return the most frequent class label in classList.

    The full (label, count) ranking is printed before returning.
    On Python 3.7+ ties go to the label that occurs first in classList.

    Parameters:
        classList - list of class labels
    Returns:
        the label with the highest count
    Raises:
        IndexError - if classList is empty
    """
    # most_common() with no argument returns every (label, count) pair
    # sorted by count, descending
    sortedClassCount = Counter(classList).most_common()
    print(sortedClassCount)
    return sortedClassCount[0][0]
# Build the decision tree
def createTree(dataSet, labels, featLabels):
    """Recursively build an ID3 decision tree.

    Parameters:
        dataSet - training samples; the last element of each sample is its
                  class label
        labels - feature names aligned with the feature columns of dataSet;
                 this list is not modified
        featLabels - output list: the name of every feature chosen for an
                     internal node is appended, in the order the nodes are
                     created (depth-first, parent before children)
    Returns:
        myTree - a class label (leaf), or a nested dict
                 {feature_name: {feature_value: subtree, ...}}
    """
    classList = [example[-1] for example in dataSet]
    # every sample has the same class -> leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # only the class column is left -> majority vote
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # -1 means no feature has positive gain (e.g. identical features with
    # different classes); indexing with -1 would split on the class column
    if bestFeat < 0:
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree = {bestFeatLabel: {}}
    # a fresh list per level: deleting from `labels` in place would leak
    # into the caller and into sibling branches that share the same list
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    for value in set(example[bestFeat] for example in dataSet):
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels, featLabels)
    return myTree
"""
函数说明:获取决策树叶子结点的数目
Parameters:
myTree - 决策树
Returns:
numLeafs - 决策树的叶子结点的数目
"""
def getNumLeafs(myTree):
numLeafs=0
firstStr=next(iter(myTree))
secondDict=myTree[firstStr]
for key in secondDict.keys():
if(type(secondDict[key]).__name__=='dict'):
numLeafs+=getNumLeafs(secondDict[key])
else:
numLeafs+=1
return numLeafs
"""
函数说明:获取决策树的层数
Parameters:
myTree - 决策树
Returns:
maxDepth - 决策树的层数
"""
def getTreeDepth(myTree):
maxDepth=0
firstStr=next(iter(myTree))
secondDict=myTree[firstStr]
for key in secondDict.keys():
if(type(secondDict[key]).__name__=='dict'):
thisDepth=1+getTreeDepth(secondDict[key])
else:
thisDepth=1
if(thisDepth>maxDepth):
maxDepth=thisDepth
return maxDepth
"""
函数说明:绘制结点
Parameters:
nodeTxt - 结点名
centerPt - 文本位置
parentPt - 标注的箭头位置
nodeType - 结点格式
Returns:
无
"""
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
arrow_args = dict(arrowstyle="<-") #定义箭头格式
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14) #设置中文字体
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', #绘制结点
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)
"""
函数说明:标注有向边属性值
Parameters:
cntrPt、parentPt - 用于计算标注位置
txtString - 标注的内容
Returns:
无
"""
def plotMidText(cntrPt,parentPt,txtString):
#计算标注位置
xMid=(parentPt[0]-cntrPt[0])/2+cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
"""
函数说明:绘制决策树
Parameters:
myTree - 决策树(字典)
parentPt - 标注的内容
nodeTxt - 结点名
Returns:
无
"""
def plotTree(myTree, parentPt, nodeTxt):
decisionNode = dict(boxstyle="sawtooth", fc="0.8") #设置结点格式
leafNode = dict(boxstyle="round4", fc="0.8") #设置叶结点格式
numLeafs = getNumLeafs(myTree) #获取决策树叶结点数目,决定了树的宽度
depth = getTreeDepth(myTree) #获取决策树层数
firstStr = next(iter(myTree)) #下个字典
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) #中心位置
plotMidText(cntrPt, parentPt, nodeTxt) #标注有向边属性值
plotNode(firstStr, cntrPt, parentPt, decisionNode) #绘制结点
secondDict = myTree[firstStr] #下一个字典,也就是继续绘制子结点
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #y偏移
for key in secondDict.keys():
if type(secondDict[key]).__name__=='dict': #测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
plotTree(secondDict[key],cntrPt,str(key)) #不是叶结点,递归调用继续绘制
else: #如果是叶结点,绘制叶结点,并标注有向边属性值
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
"""
函数说明:创建绘制面板
Parameters:
inTree - 决策树(字典)
Returns:
无
"""
def createPlot(inTree):
fig = plt.figure(1, facecolor='white') #创建fig
fig.clf() #清空fig
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) #去掉x、y轴
plotTree.totalW = float(getNumLeafs(inTree)) #获取决策树叶结点数目
plotTree.totalD = float(getTreeDepth(inTree)) #获取决策树层数
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0; #x偏移
plotTree(inTree, (0.5,1.0), '') #绘制决策树
plt.show()
"""
函数说明:使用决策树分类
Parameters:
inputTree - 已经生成的决策树
featLabels - 特征标签
testVec - 测试数据列表
Returns:
classLabel - 分类结果
"""
def classify(inputTree,featLabels,testVec):
firstStr=next(iter(inputTree))
secondDict=inputTree[firstStr]
featIndex=featLabels.index(firstStr)
for key in secondDict.keys():
if(testVec[featIndex]==key):
if(type(secondDict[key]).__name__=='dict'):
classLabel=classify(secondDict[key],featLabels,testVec)
else:
classLabel=secondDict[key]
return classLabel
if __name__ == '__main__':
    dataSet, labels = createDataSet()
    # keep an untouched copy of the feature names for classify()
    labelTemp = copy.copy(labels)
    print(dataSet)
    print(calShannonEnt(dataSet))
    print("最优特征索引值:" + str(chooseBestFeatureToSplit(dataSet)))
    featLabels = []
    myTree = createTree(dataSet, labels, featLabels)
    print(myTree)
    createPlot(myTree)
    testVec = [0, 1, 0, 1]  # sample to classify, in createDataSet column order
    result = classify(myTree, labelTemp, testVec)
    if result == 'yes':
        print('放贷')
    elif result == 'no':
        print('不放贷')