第 3 章 決策樹
[TOC]
本章內容
- 決策樹簡介
- 在數據集中度量一致性
- 使用遞歸構造決策樹
- 使用 Matplotlib 繪制樹形圖
與 k-近鄰算法相比,決策樹 的主要優勢在于數據形式非常容易理解
1. 決策樹的構造
決策樹 :
- 優點:計算復雜度不高,輸出結果易于理解,對中間值得缺失不敏感,可以處理不相關特征數據
- 缺點:可能會產生過度匹配問題
- 適用數據類型:數值型和標稱型
本節將一步步地構造決策樹算法,首先我們討論數學上如何使用 信息論 劃分數據集,然后編寫代碼將理論應用到具體的數據集上,最后編寫代碼構建決策樹。
- 在構造決策樹時,我們需要解決的第一個問題就是,當前數據集上那個特征在劃分數據分類時起決定性作用。為了找到決定性的特征,劃分出最好的結果,我們必須評估每個特征。
- 完成測試后,原始數據就被劃分為幾個數據子集。這些數據子集會分布在第一個決策點的所有分支上。
- 如果某個分支下的數據屬于同一類型,則到這里以及正確地劃分數據分類,無序進一步對數據集進行分割。
- 如果數據子集內的數據不屬于同一類型,則需要重復劃分數據子集的過程。
- 如何劃分數據子集的算法和劃分原始數據集的方法相同,直到所有具有相同類型的數據均在一個數據子集內。
創建分支的偽代碼函數 createBranch() 如下:
檢測數據集中的每一個子項是否屬于同一分類:
if so return 類標簽;
else
尋找劃分數據集的最好特征
劃分數據集
創建分支節點
for 每個劃分的子集
調用函數 createBranch 并增加返回結果到分支節點中
return 分支節點
決策樹的一般流程:
- 收集數據
- 準備數據:樹構造算法只適用于標稱型數據,因此數值型數據必須離散化
- 分析數據:可以使用任何方法,構造樹完成之后,應該檢查圖形是否符合預期
- 訓練算法:使用經驗樹計算錯誤率
- 使用算法:此步驟可以適用于任何監督學習算法,而使用決策樹可以更好地理解數據的內在含義
如果依據某個屬性劃分法數據將會產生 4 個可能的值,我們將把數據劃分為 4 塊,并 創建 4 個不同的分支。這里將使用 ID3 算法劃分數據集。
問題 :每次劃分數據集時我們只選取一個特征屬性,如果訓練集中存在 20 個特征,第一次我們選擇哪個特征作為話的參考屬性呢?
表3-1 的數據包含 5 個海洋動物,特征包括:不浮出水面是否可以生存,以及是否有腳蹼。我們可以將這些動物分成兩類:魚類和非魚類。在回答這個問題之前,我們必須采用量化的方法判斷如何劃分數據:
1.1 信息增益
劃分數據集的大原則是:將無需的數據變得更加有序。組織雜亂無章數據的一種方法就是使用信息論度量信息,信息論是量化處理信息的分支科學。我們可以在劃分數據之前使用信息論量化度量信息的內容。
在劃分數據集之前之后信息發生的變化稱為 信息增益,知道如何計算信息增益,我們就可以計算每個特征值劃分數據集獲得的信息增益,獲得信息增益最高的特征就是最好的選擇 。
在可以評測哪種數據劃分方式就是最好的數據劃分之前,必須學習如何計算信息增益。集合信息的度量方式稱為香農熵(information gain) 或者簡稱為 熵(entropy) 。
熵(entropy),定義為信息的期望值。
信息:如果待分類的事務可能劃分在多個分類之中,則符號
的信息定義為
? 其中
是選擇該分類的概率。
為了計算熵,需要計算所有類別所有可能只包含的信息期望值,通過下面的公式得到:
其中 n 是分類的數目。
創建名為 s_2_tree.py 的文件,添加 calcShannonEnt 函數,其功能是計算給定數據集的熵:
import math
def calcShannonEnt(dataSet):
"""
計算給定數據集的香農熵
:param dataSet: 數據集
:return: 香農熵
"""
numEntries = len(dataSet)
labelCounts = {}
for featVec in dataSet:
currentLabel = featVec[-1] # 可能值
if currentLabel not in labelCounts.keys():
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1 # 計數加一
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries # 可能值的期望
shannonEnt -= prob * math.log(prob, 2) # 熵
return shannonEnt
熵越高,則混合的數據也越多,我們可以在數據集中添加更多的分類,觀察熵是如何變化的。
得到熵之后,就可以按照獲取最大信息增益的方法劃分數據集。另一個度量集合無序程度的方法是 基尼不純度(Gini impurity),簡單地說就是從一個數據集中隨機選取子項,度量其被錯誤分類到其他分組里的概率。
1.2 劃分數據集
添加 splitDataSet 函數,按照給定特征劃分數據集:
def splitDataSet(dateSet, axis, value):
"""
按照給定特征劃分數據集
:param dateSet: 待劃分的數據集
:param axis: 劃分數據集的特征
:param value: 特征的返回值
:return:劃分完的數據集
"""
retDataSet = [] # 因python不用考慮內存分配問題,在函數中傳遞的是列表的引用,所以需聲明一個新列表對象
for featVec in dateSet:
if featVec[axis] == value: # 該特征值等于判斷值
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis+1:]) #上一步和這一步是排除掉特征值
retDataSet.append(reducedFeatVec) # 加入返回的數據集
print(retDataSet)
return retDataSet
接下來將遍歷整個數據集,循環計算香農熵和 splitDataSet 函數,找到最好的特征劃分方式。熵計算將會告訴我們如何劃分數據集時最好的數據組織方式。
添加 chooseBestFeatureToSplit 函數,選擇最好的數據集劃分方式:
def chooseBestFeatureToSplit(dataSet):
"""
選擇最好的數據集劃分方式
:param dataSet: 待劃分的數據集
:return: 最好的特征
"""
numFeatures = len(dataSet[0]) - 1
baseEntropy = calcShannonEnt(dataSet) # 原始的香農熵
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
# 1.創建唯一的分類標簽列表
featList = [example[i] for example in dataSet]
uniqueVals = set(featList)
newEntropy = 0.0
# 2.計算每種劃分方式的信息熵
for value in uniqueVals:
subDataSet = splitDataSet(dataSet, i, value)
prob = len(subDataSet) / float(len(dataSet)) # 概率
newEntropy += prob * calcShannonEnt(subDataSet) # 香農熵,也就是信息量
infoGain = baseEntropy - newEntropy # 該特征的熵
# 3.計算最好的信息增益
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
return bestFeature
測試上面代碼的實際輸出結果:
myDataSet,labels = createDataSet()
bestFeature = chooseBestFeatureToSplit(myDataSet)
print(bestFeature) # 輸出:0
運行結果告訴我們,第 0 個特征是最好的用于劃分數據集的特征。
1.3 遞歸構建決策樹
從數據集構造決策樹算法所需要的子功能模塊,其工作原理如下:
- 得到原始數據集
- 基于最好的屬性值劃分數據集,由于特征值可能多余兩個,因此可能存在大于兩個分支的數據集劃分。
- 第一次劃分之后,數據將被向下傳遞到樹分支的下一個節點,在這個節點上,我們可以再次劃分數據。因此可以采用遞歸的原則處理數據集。
- 遞歸結束的條件是程序遍歷完所有劃分數據集的屬性,或者每個分支下的所有實例都具有相同的分類。如果所有實例具有相同的分類,則得到一個葉子節點或者終止塊。任何到達葉子節點的數據必然屬于葉子節點的分類,參見圖3-2。
- 如果數據集已經處理了所有屬性,但是類標簽依然不是唯一的,此時我們需要決定如何定義該葉子節點,在這種情況下,我們通常會采用多數表決的方法決定葉子節點的分類。
添加 majorityCnt 函數,多數表決來決定葉子節點的分類:
def majorityCnt(classList):
"""
多數表決判斷分類
:param classList:分類名稱的列表
:return: 表決得到的分類名稱
"""
# 1.分類計數
classCount = {}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
# 2.排序,取出次數最多的分類名稱
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
添加 createTree 函數,創建樹:
def createTree(dataSet, labels):
"""
遞歸創建樹
:param dataSet: 數據集
:param labels: 標簽集
:return:樹
"""
# 1. 取出所有類別
classList = [example[-1] for example in dataSet]
# 2. 判斷:類別完全相同則停止繼續劃分
if classList.count(classList[0]) == len(classList):
return classList[0]
# 3. 遍歷完所有特征時返回出現次數最多的
if len(dataSet[0]) == 1:
return majorityCnt(classList)
# 4. 取出最好的分類特征
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLable = labels[bestFeat]
myTree = {bestFeatLable:{}}
del(labels[bestFeat])
# 5. 得到列表包含的所有屬性值
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
# 6.遞歸創建樹
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLable][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
2. 在 python 中使用 Matplotlib 注解繪制樹形圖
使用 Matplotlib 可創建樹形圖,決策樹的主要優點就是直觀易于理解,如果不能將其直觀地顯示出來,就無法發揮其優勢。
2.1 Matplotlib 注解
Matplotlib 提供了一個 **注解工具 annotations **,可以在數據圖形上添加文本注釋。注解通常用于解釋數據的內容
下面是 treePlotter.py 的代碼(使用文本注解繪制樹節點):
import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
繪制 parentPt 指向 centerPt 帶箭頭的線,箭頭節點的文本為 nodeTxt
:param nodeTxt:節點文本
:param centerPt:子節點
:param parentPt:父節點
:param nodeType:節點樣式
:return:
"""
createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',\
xytext=centerPt, textcoords='axes fraction',\
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot():
fig = plt.figure(1, facecolor='white')
fig.clf()
createPlot.axl = plt.subplot(111, frameon=False)
plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
plt.show()
這是第一個版本的 createPlot 函數,與例子文件中的 createPlot 函數有些不同,隨著內容的深入,我們將逐步添加缺失的代碼。代碼定義了樹節點格式的常量。然后定義 plotNode 函數執行了實際的繪圖功能,該函數需要一個繪圖區,該區域由全局變量 createPlot.ax1 定義。python 語言中所有的變量默認都是全局有效的,只要我們清楚知道當前代碼的主要功能,并不會引入太大的麻煩。最好定義了 createPlot 函數,它是這段代碼的核心。createPlot 函數首先創建了一個新圖形并清空繪圖區,然后再繪圖區上繪制了兩個代表不同類型的樹節點,后面我們將用這兩個節點繪制樹圖形。
2.2 構造注解樹
添加 getNumLeafs 函數,獲取葉節點的數目:
def getNumLeafs(myTree):
"""
獲取葉節點的數目
:param myTree: 樹
:return: 葉節點的數目
"""
numLeafs = 0
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
添加 getTreeDepth 函數,獲取樹的層數
def getTreeDepth(myTree):
"""
獲取樹的層數
:param myTree:樹
:return: 層數
"""
maxDepth = 0
firstStr = list(myTree.keys())[0] # python3 中 dict.keys 返回的是 dict_keys 對象,支持 iterable,但不支持 indexable,所以要轉換成list
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 +getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth : maxDepth = thisDepth
return maxDepth
為了節省時間,函數 retrieveTree 輸出預先存儲的樹信息,避免每次測試都要創建樹的麻煩:
def retrieveTree(i):
listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0:{'head': {0:'no', 1: 'yes'}}, 1:'no'}}}}
]
return listOfTrees[i]
接下來是主要的繪圖部分:
添加 plotNode 函數,繪制節點:
decisionNode = dict(boxstyle="sawtooth", fc="0.8") # 注解樣式
leafNode = dict(boxstyle="round4", fc="0.8") # 注解樣式
arrow_args = dict(arrowstyle="<-") # 箭頭樣式
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
繪制 parentPt 指向 centerPt 帶箭頭的線,箭頭節點的文本為 nodeTxt
:param nodeTxt:節點文本
:param centerPt:子節點
:param parentPt:父節點
:param nodeType:節點樣式
:return:
"""
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
添加 plotMidText 函數:
def plotMidText(cntrPt, parentPt, txtString):
"""
計算父節點和子節點的中間位置,并添加文本標簽信息
:param cntrPt:子節點位置
:param parentPt:父節點位置
:param txtString:文本標簽信息
:return:
"""
xMid = (parentPt[0] - cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString)
添加 plotTree 函數,繪制樹形圖:
def plotTree(myTree, parentPt, nodeTxt):
"""
繪制樹形圖
:param myTree:樹
:param parentPt:父節點位置
:param nodeTxt:節點文本
:return:
"""
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0] # 找到第一個元素,根節點
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff) # 節點位置
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr] # 獲取節點下的內容
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD # 減少 y 的值,將樹的總深度平分,向下移動(樹是向下繪制)
for key in secondDict.keys(): # 鍵值:0、1
if type(secondDict[key]) == dict: # 判斷是 dict 還是 value
plotTree(secondDict[key], cntrPt, str(key)) # 遞歸調用
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW # 更新 x 值
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
添加 createPlot 函數,創建繪圖:
def createPlot(inTree):
"""
創建繪圖
:param inTree:
:return:
"""
fig = plt.figure(1, facecolor='white')
fig.clf()
axprps = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprps) # 定義繪圖區
plotTree.totalW = float(getNumLeafs(inTree)) # 存儲樹的寬度
plotTree.totalD = float(getTreeDepth(inTree)) #存儲樹的深度
# 使用了這兩個全局變量追蹤已經繪制的節點位置,以及放置下一個節點的恰當位置
plotTree.xOff = -0.5/plotTree.totalW # 存儲樹在 x 軸的偏移
plotTree.yOff = 1.0 # 存儲樹在 y 軸的偏移
plotTree(inTree, (0.5,1.0), ' ')
plt.show()
- 函數 createPlot() 是使用的主函數,它調用了 plotTree() ,函數 plotTree() 又依次調用了前面的函數和 plotMidText() 。
- 函數 plotTree() 是遞歸函數:
- 計算樹的寬和高。全局變量 plotTree.totalW 、plotTree.totalD 存儲樹的寬度、深度,使用這兩個變量計算樹節點的擺放位置,這樣可以將樹繪制在水平方向和垂直方向的中心位置。
- 樹的寬度用于計算放置判斷節點的位置,主要的計算原則是將它放在所有葉子節點的中間,而不僅僅是它子節點的中間。
- 同時我們使用兩個全局變量 plotTree.xOff 、plotTree.yOff 追蹤已經繪制的節點位置,以及放置下一個節點的恰當位置。
- 另一個需要說明的問題是,繪制圖形的 x 軸有效范圍是 0.0 到 1.0,y 軸有效范圍也是 0.0 到 1.0。
- 通過計算樹包含的所有葉子節點數,劃分圖形的寬度,從而計算得到當前節點的中心位置,也就是說,我們按照葉子節點的數目將 x 軸劃分為若干部分。按照圖形比例繪制樹形圖的最大好處是無序關心實際輸出圖形的大小,一旦圖形大小發生了變化,函數會自動按照圖形大小重新繪制。
- 繪制出子節點具有的特征值,或者沿此分支向下的數據實例必須具有的特征值。
- 使用 plotMidText() 計算父節點和子節點的中間位置,并在此處添加簡單的文本標簽信息
- 按比例減少全局變量 plotTree.yOff,并標注此處將要繪制子節點,這些節點即可以是葉子節點也可以是判斷節點,此處需要只保存繪制圖形的軌跡
- 采用 getNumLeafs() 和 getTreeDepth() 以相同的方式遞歸遍歷整棵樹,如果節點是葉子節點則在圖形畫出葉子節點,如果不是葉子節點則遞歸調用 plotTree() 函數。在繪制了所有子節點之后,增加全局變量 Y 的偏移
3. 測試和存儲分類器
3.1 測試算法:使用決策樹執行分類
依靠訓練數據構造了決策樹之后,可以將它用于實際數據的分類:
- 在執行數據分類時,需要決策樹以及用于構造樹的標簽向量;
- 程序比較測試數據域決策樹上的數值,遞歸執行該過程直到進入葉子節點;
- 將測試數據定義為葉子節點所屬的類型
添加 classify() 函數,使用決策樹進行分類:
def classify(inputTree, featLabels, testVec):
"""
使用決策樹分類
:param inputTree:樹
:param featLabels: 標簽集
:param testVec:測試向量
:return:分類名稱
"""
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]) == dict:
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
3.2 使用算法:決策時的存儲
為了節省計算時間,最好能夠在每次執行分類時調用已經構造好的決策樹。需要使用 python 模塊 pickle 序列化對象,然后在磁盤上保存對象,并在需要的時候取出來。
使用 pickle 模塊存儲決策樹:
def storeTree(inputTree, filename):
import pickle
fw = open(filename, 'w')
pickle.dump(inputTree, fw)
fw.close()
def gradTree(filename):
import pickle
fr = open(filename)
return pickle.load(fr)
通過上面的代碼,我們可以將分類器存儲在硬盤上,而不用每次對數據分類時重新學習一遍,這也是決策樹的優點之一,像 k-近鄰算法就無法持久化分類器。可以余弦提煉并存儲數據集中包含的知識信息,在需要對事物進行分類時再使用這些知識。
4. 示例:使用決策樹預測隱形眼鏡類型
本節將通過一個例子講解決策樹如何預測患者需要佩戴的隱形眼鏡類型。
示例:使用決策樹預測隱形眼鏡類型:
- 收集數據:提供的文本文件
- 準備數據:解析 tab 鍵分隔的數據行
- 分析數據:快速檢查數據,確保正確地解析數據內容,使用 createPlot() 函數繪制最終的樹形圖
- 訓練算法:使用 第 1 節的 createTree() 函數
- 測試算法:編寫測試函數驗證決策樹可以正確分類給定的數據實例
- 使用算法:存儲樹的數據結構,以便下次使用時無需重新構造樹
訓練代碼:
fr = open('lenses.txt')
lenses = [inst.strip().split('\t') for inst in fr.readlines()]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
lensesTree = createTree(lenses,lensesLabels)
createPlot(lensesTree)
本章使用的算法稱為 ID3,是一個好的算法但是并不完美。ID3 無法直接處理數值型數據,盡管可以通過量化的方法將數值型數據轉化為標稱型數值,但是如果蔡遵太多的特征劃分,ID3 仍然會面臨其他問題。
5. 本章小結
決策樹分類器就像帶有終止塊的流程圖,終止塊表示分類結果。開始處理數據集時,首先需要測量集合中數據的不一致性,也就是熵,然后尋找最優方案劃分數據,直到數據集中的所有數據屬于同一分類。ID3 算法可以用于劃分標稱型數據集。構建決策樹時,通常采用遞歸的方法將數據集轉化為決策樹。