[kaggle系列 二] 使用決策樹判斷是否能從泰坦尼克號生還

題目

連接:https://www.kaggle.com/c/titanic

簡析

上一篇用了貝葉斯分類器,這次用決策樹和隨機森林試一試,不過最終的得分沒有貝葉斯分類器高,好吧,說實話,感覺再用幾個不同的機器學習方法應該結果也差不多,現在主要是試水,先搞懂基礎的算法,然后再通過數據的處理與分析去優化結果。

決策樹

我個人認為,決策樹應該是比較好理解的機器學習算法了。其中心思想就是ifelse,存在很多個條件的時候,如果第一個條件是A,第二個條件是B…………就選擇方案C。是一個很自然的方法,我們平常生活中也可能很常用,比如下圖就是一個屌絲假日的日常決策樹:



看起來是很簡單的,但是要怎么應用到機器學習上,讓它成為一個分類器呢?以本例中的問題來說,我知道每一個人的生還情況,還有他的各種屬性(特征),我們要根據這些特征,來生成一顆決策樹,最終到達葉子節點的時候,我們就知道對于某一系列的屬性,最終是生還還是死亡,大概就是要生成這樣一棵樹(進行了簡化):

中間還可以加很多其他屬性,比如艙位信息之類的,最終,在篩完所有信息以后,你就可以通過當前訓練數據給出一個概率,表示當前葉結點生還的概率。
接下來的問題是如何實現,選擇條件的順序是否影響結果,是不是可以隨便選擇條件呢?當然這里樹的結構肯定會影響最終的結果,如何去構造一棵樹呢?
這里需要用到一個信息熵的概念。熵這個東西應該大家都聽說過,熵表明的是一個事物的混亂程度,熵越大,混亂程度越高,熵越小,表明混亂程度越低。信息熵的概念也是一樣的,就是用來表明信息的混亂程度,我們選擇一個樹根的時候,最好的情況肯定是通過這個屬性把數據分成幾類以后,這些數據的熵越小越好,因為越小代表越有序,分類越清晰。那么我們要做的就是計算每個條件作為當前的根結點的信息熵,最終選一個最小的分類方法作為根節點,并以此類推,直到葉節點~
信息熵的計算公式如下:

其中,m是最終的分類結果,在本例中,就是生還與否兩個類,pi是這個決策(分類)發生的概率。

隨機森林

隨機森林就是決策樹的加強版,決策樹這種方法,雖然有信息熵作為劃分方法,但是實際上,如果劃分到最精細的一層,那么就會出現過擬合的問題,泛化能力就比較差,往往訓練數據上表現的比較好,對于新的數據,準確度就會變低。在我寫的決策樹代碼就出現了這個問題,當時沒多想,直接分到最后一層,結果準確率只有0.57。但是如果不分得更精細,準確度也不夠高。可能需要進行大量的測試,才能找到一個平衡點,既不至于過擬合,也不至于欠擬合導致準確率太低。
隨機森林提供了一個比較通用的解決方法,就是隨機生成多個比較淺的決策樹,當進行擬合的時候,讓多個決策樹進行投票,最終哪個分類的票高就決定是哪個類。

代碼與結果

首先是決策樹的代碼,說實話,這代碼寫的比較丑,寫起來不是很順手,邊學邊寫,邏輯搞的有點亂。最終準確率只有0.57416,這個我認為是劃分過于精細導致模型過擬合了,在適當的分支進行剪枝效果可能會更好,當然,也有可能哪里寫出了點小bug(笑)。

import csv
import os
import random
import math

class Node:
    def __init__(self):
        self.attr_name = ""
        self.value_type = ""
        self.classifier = None
        self.childrens = []
        self.entropy = 0

    def getNext(self, value):
        result = 0
        if self.value_type == 'disperse_data':
            pos = 0
            if self.classifier.has_key(value):
                pos = self.classifier[value]
            else:
                pos = random.randint(0, len(self.childrens) - 1) 
            result = self.childrens[pos]
        elif self.value_type == 'continuity_data':
            if value <= self.classifier:
                result = self.childrens[0]
            else:
                result = self.childrens[1] 
        return result
        

def readData(fileName):
    result = {}
    with open(fileName,'rb') as f:
        rows = csv.reader(f)
        for row in rows:
            if result.has_key('attr_list'):
                for i in range(len(result['attr_list'])):
                    key = result['attr_list'][i]
                    if not result.has_key(key):
                        result[key] = []
                    result[key].append(row[i])
            else:
                result['attr_list'] = row
    return result

def writeData(fileName, data):
    csvFile = open(fileName, 'w')
    writer = csv.writer(csvFile)
    n = len(data)
    for i in range(n):
        writer.writerow(data[i])
    csvFile.close()

def convertData(dataList):
    hashTable = {}
    count = 0
    for i in range(len(dataList)):
        if not hashTable.has_key(dataList[i]):
            hashTable[dataList[i]] = count
            count += 1
        dataList[i] = str(hashTable[dataList[i]])

def convertValueData(dataList):
    sumValue = 0.0
    count = 0
    for i in range(len(dataList)):
        if dataList[i] == "":
            continue
        sumValue += float(dataList[i])
        count += 1
        dataList[i] = float(dataList[i])
    avg = sumValue / count
    for i in range(len(dataList)):
        if dataList[i] == "":
            dataList[i] = avg

def dataPredeal(data):
    useDataList = ['Sex','Pclass', 'SibSp','Parch','Embarked']
    result = {}
    convertValueData(data["Age"])
    result['Age'] = data['Age']
    for i in range(len(useDataList)):
        attrName = useDataList[i]
        convertData(data[attrName])
        result[attrName] = data[attrName]
    return result

def calEntropy(dataList, labelList, isContinuity):
    if not isContinuity:
        count = 0.0
        attrCount = {}
        for i in range(len(dataList)):
            key = dataList[i]
            label = labelList[i]
            count += 1
            if not attrCount.has_key(key):
                attrCount[key] = {'0':0.0,'1':0.0}
            if not attrCount[key].has_key(label):
                attrCount[key][label] = 0.0
            attrCount[key][label] += 1.0
        entropy = 0
        for key in attrCount:
            p0 = attrCount[key]['0']/(attrCount[key]['0'] + attrCount[key]['1'])
            p1 = attrCount[key]['1']/(attrCount[key]['0'] + attrCount[key]['1'])
            v0 = 0 if p0 == 0 else p0*math.log(p0,2)
            v1 = 0 if p1 == 0 else p1*math.log(p1,2)
            temp = (attrCount[key]['0'] + attrCount[key]['1']) / count * (v0 + v1)
            entropy -= temp
        return entropy, None
    else:
        ageList = set([dataList[i] for i in range(len(dataList))])
        ageList = list(ageList)
        ageList.sort()
        minEntropy = 1
        targetAge = 0
        for i in range(len(ageList) - 1):
            avgAge = (ageList[i] + ageList[i + 1]) / 2
            count = 0.0
            left_sum = {'0':0.0,'1':0.0}
            right_sum = {'0':0.0,'1':0.0}
            for j in range(len(dataList)):
                if dataList[j] <= avgAge:
                    left_sum[labelList[j]] += 1.0
                else:
                    right_sum[labelList[j]] += 1.0
                count += 1.0
            pl = (left_sum['0'] + left_sum['1']) / count
            pl0 = left_sum['0']/(left_sum['0'] + left_sum['1'])
            pl1 = 1.0 - pl0
            pr = (right_sum['0'] + right_sum['1']) / count
            pr0 = right_sum['0']/(right_sum['0'] + right_sum['1'])
            pr1 = 1.0 - pr0
            vl0 = 0 if pl0 == 0 else pl0*math.log(pl0,2)
            vl1 = 0 if pl1 == 0 else pl1*math.log(pl1,2)
            vr0 = 0 if pr0 == 0 else pr0*math.log(pr0,2)
            vr1 = 0 if pr1 == 0 else pr1*math.log(pr1,2)
            entropy = - pl*(vl0 + vl1) - pr*(vr0 + vr1)
            if entropy < minEntropy:
                minEntropy = entropy
                targetAge = avgAge
        return minEntropy, targetAge

def checkFinal(data,labelList, root):
    diff_count = 0
    hash_key = {}
    attrName = ""
    for key in data:
        if not hash_key.has_key(key):
            hash_key[key] = True
            diff_count += 1
            attrName = key
        if diff_count > 1:
            break
    if diff_count > 1:
        return False
    root.attr_name = attrName
    root.value_type = 'continuity_data' if attrName == 'Age' else 'disperse_data'
    ageBoundary = None
    if attrName == 'Age':
        entropy,ageBoundary = calEntropy(data[attrName], labelList, True)
    statistics = {}
    for i in range(len(data[attrName])):
        key = data[attrName][i]
        if ageBoundary != None:
            key = 0 if key <= ageBoundary else 1
        if not statistics.has_key(key):
            statistics[key] = [0.0,0.0]
        pos = int(labelList[i])
        statistics[key][pos] += 1.0
    
    root.classifier = ageBoundary if attrName == 'Age' else {}
    root.childrens = [] if attrName != 'Age' else [0,0]
    count = 0
    for key in statistics:
        if ageBoundary == None:
            if not root.classifier.has_key(key):
                root.classifier[key] = count
                root.childrens.append(0)
                count += 1
            root.childrens[root.classifier[key]] = 0 if statistics[key][0] > statistics[key][1] else 1
        else:
            root.childrens[key] = 0 if statistics[key][0] > statistics[key][1] else 1
    return True

def deepPrint(deep, info):
    s = ''
    for i in range(deep):
        s += ' '
    s += 'deep:' + str(deep) + '   attr:' + info
    print s

def buildTree(data, labelList, deep=0):
    root = Node()
    if checkFinal(data, labelList, root) == True:
        #deepPrint(deep, root.attr_name)
        return root
    minEntropy = 1
    targetAttrName = ''
    continuityValueBoundary = None
    for key in data:
        entropy, targetAge = calEntropy(data[key], labelList, key == 'Age')
        if entropy < minEntropy:
            minEntropy = entropy
            targetAttrName = key
            continuityValueBoundary = targetAge
    root.attr_name = targetAttrName
    #deepPrint(deep, root.attr_name)
    if continuityValueBoundary != None:
        root.value_type = 'continuity_data'
        root.classifier = continuityValueBoundary
        root.childrens = [0,0]
    else:
        root.value_type = 'disperse_data'
        root.classifier = {}
        root.childrens = []
    subDatas = {}
    for i in range(len(data[targetAttrName])):
        key = data[targetAttrName][i]
        if continuityValueBoundary != None:
            if key <= root.classifier:
                key = 0
            else:
                key = 1
        if not subDatas.has_key(key):
            subDatas[key] = {'data':{},'labelList':[]}
        for k in data:
            if k != targetAttrName:
                if not subDatas[key]['data'].has_key(k):
                    subDatas[key]['data'][k] = []
                subDatas[key]['data'][k].append(data[k][i])
        subDatas[key]['labelList'].append(labelList[i])

    count = 0
    for key in subDatas:
        child = buildTree(subDatas[key]['data'], subDatas[key]['labelList'], deep+1)
        if root.value_type == 'continuity_data':
            root.childrens[key] = child
        else:
            root.classifier[key] = count
            root.childrens.append(child)
            count += 1
    return root
    
def train(train_data):
    x = dataPredeal(train_data)
    tree = buildTree(x, train_data['Survived'])
    return tree

def fit(tree, test_data, pos):
    result = tree
    while(result != 0 and result != 1):
        result = result.getNext(test_data[result.attr_name][pos])
    return [test_data['PassengerId'][pos],result]

def run():
    dataRoot = '../../kaggledata/titanic/'
    train_data = readData(dataRoot + 'train.csv')
    test_data = readData(dataRoot + 'test.csv')
    tree = train(train_data)
    result_list = []
    result_list.append(['PassengerId', 'Survived'])
    for i in range(len(test_data['PassengerId'])):
        result_list.append(fit(tree, test_data, i))
    writeData(dataRoot + 'result.csv', result_list)

run()

下面的代碼用了sklearn庫里的隨機森林的方法,還是很方便的,效果也還行,準確率有0.74641,果然三個臭皮匠干死諸葛亮~

import csv
import os
import random
import math
from sklearn.ensemble import RandomForestClassifier

def readData(fileName):
    result = {}
    with open(fileName,'rb') as f:
        rows = csv.reader(f)
        for row in rows:
            if result.has_key('attr_list'):
                for i in range(len(result['attr_list'])):
                    key = result['attr_list'][i]
                    if not result.has_key(key):
                        result[key] = []
                    result[key].append(row[i])
            else:
                result['attr_list'] = row
    return result

def writeData(fileName, data):
    csvFile = open(fileName, 'w')
    writer = csv.writer(csvFile)
    n = len(data)
    for i in range(n):
        writer.writerow(data[i])
    csvFile.close()

def convertData(dataList):
    hashTable = {}
    count = 0
    for i in range(len(dataList)):
        if not hashTable.has_key(dataList[i]):
            hashTable[dataList[i]] = count
            count += 1
        dataList[i] = str(hashTable[dataList[i]])

def convertValueData(dataList):
    sumValue = 0.0
    count = 0
    for i in range(len(dataList)):
        if dataList[i] == "":
            continue
        sumValue += float(dataList[i])
        count += 1
        dataList[i] = float(dataList[i])
    avg = sumValue / count
    for i in range(len(dataList)):
        if dataList[i] == "":
            dataList[i] = avg

def dataPredeal(data):
    useDataList = ['Sex','Pclass', 'SibSp','Parch','Embarked']
    convertValueData(data["Age"])
    for i in range(len(useDataList)):
        attrName = useDataList[i]
        convertData(data[attrName])
    
def train(train_data):
    dataPredeal(train_data)
    useList = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch']
    x = []
    y = []
    for i in range(len(train_data['Survived'])):
        item = []
        for j in range(len(useList)):
            item.append(train_data[useList[j]][i])
        x.append(item)
        y.append(train_data['Survived'][i])
    clf = RandomForestClassifier().fit(x,y)
    return clf

def predict(clf, test_data, pos):
    x = [[]]
    useList = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch']
    for i in range(len(useList)):
        x[0].append(test_data[useList[i]][pos])
    result = clf.predict(x)
    return [test_data['PassengerId'][pos],int(result[0])]

def run():
    dataRoot = '../../kaggledata/titanic/'
    train_data = readData(dataRoot + 'train.csv')
    test_data = readData(dataRoot + 'test.csv')
    clf = train(train_data) 
    dataPredeal(test_data)
    result_list = []
    result_list.append(['PassengerId', 'Survived'])
    for i in range(len(test_data['PassengerId'])):
        result_list.append(predict(clf, test_data, i))
        print 'cal:' + str(i)
    writeData(dataRoot + 'result.csv', result_list)

run()
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 228,786評論 6 534
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 98,656評論 3 419
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 176,697評論 0 379
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,098評論 1 314
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 71,855評論 6 410
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 55,254評論 1 324
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,322評論 3 442
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 42,473評論 0 289
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 49,014評論 1 335
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 40,833評論 3 355
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 43,016評論 1 371
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,568評論 5 362
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,273評論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,680評論 0 26
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 35,946評論 1 288
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 51,730評論 3 393
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 48,006評論 2 374

推薦閱讀更多精彩內容