機器學習貝葉斯網絡分類水果

水果部分數據


捕獲.PNG

代碼

import numpy as np
import math
import csv
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import pylab as pl
import random
from matplotlib import cm
from sklearn.model_selection import train_test_split

# 求平均值
def mean(numbers):
    return sum(numbers)/float(len(numbers))

# 求平均差
def stdev(numbers):
  avg = mean(numbers)
  variance = sum([pow(x-avg,2) for x in numbers])/float(len(numbers)-1)
  return math.sqrt(variance)

# 求各列的平均值和方差--提取數據特征
def summarize(dataset):
    parameter = [(mean(attribute), stdev(attribute)) for attribute in zip(*dataset)]
    #parameter = [(mean(dataset.iloc[:,i]),stdev(dataset.iloc[:,i])) for i in range(dataset.shape[1]) ]
    del parameter[-1]
    return parameter

# 進行分類
def separatedByClass(dataset):
    separated = {}
    #創建字典
    for i in range(len(dataset)):
        vector = dataset[i]
        if (vector[-1] not in separated):
            #根據最后一個元素,隨后一個元素為1,2,3,4,代表著水果的種類,作為鍵值key
            separated[vector[-1]] = []
        separated[vector[-1]].append(vector)
    return separated

# 類別屬性提取特征,即每一類四種特征總的均值和方差
def summarizeByClass(dataset):
    separated = separatedByClass(dataset)
    summaries = { }
    #創建字典
    for classValue, instances in separated.items():
        summaries[classValue] = summarize(instances)
    return summaries


# 求出高斯概率密度函數
def calculateProbability(x, mean, stdev):
    exponent = math.exp(-(math.pow(x - mean, 2) / (2 * math.pow(stdev, 2))))
    return (1 / (math.sqrt(2 * math.pi) * stdev)) * exponent

#所屬類的概率
def calculateClassProbabilities(summaries, inputVector):
    probabilities = {}
    #字典
    for classValue, classSummaries in summaries.items():
        probabilities[classValue] = 1
        for i in range(len(classSummaries)):
            mean, stdev = classSummaries[i]
            x = inputVector[i]
            probabilities[classValue] *= calculateProbability(x, mean, stdev)
            #求出總的高斯密度的乘積
    return probabilities

# 對數據單一預測
# 每組測試數據最有可能的情況
def predict(summaries, inputVector):
    probabilities = calculateClassProbabilities(summaries, inputVector)
    bestLabel, bestProb = None, -1
    for classValue, probability in probabilities.items():
        if bestLabel is None or probability > bestProb:

            bestProb = probability
            bestLabel = classValue
    return bestLabel

#進行多重預測
def getPredictions(summaries, testSet):
    predictions = []        #來存儲結果
    for i in range(len(testSet)):
        result = predict(summaries, testSet[i])
        predictions.append(result)
    return predictions     # 最終返回輸出結果

#輸出結果計算準確率
def getAccuracy(testSet, predictions):
    correct = 0
    print("結果:")
    for x in range(len(testSet)):
        print("預測的結果:", predictions[x], "----", testSet[x][-1], ":正確的結果")
        if testSet[x][-1] == predictions[x]:
            correct += 1
    return (correct / float(len(testSet))) * 100.0

def main():

    fruits = pd.read_table('E:/fruit.txt') #fruit.txt所在位置,我將它放在E盤。
    feature_names = ['fruit_label', 'mass', 'width', 'height', 'color_score']
    X = fruits[['mass', 'width', 'height', 'color_score', 'fruit_label']]
    Y = fruits['fruit_label']
    X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.25, random_state=0)  #通過pandas取出數據,再隨機生成X_train和X_test 訓練和測試數據
    Traindataset = X_train.values
    Testdataset = X_test.values

    '''
    代碼原因將數據轉換成一下格式,目的是為了去掉pandas中dataframe的index,如mass,width 等特征值
       mass  width  height  color_score  fruit_label
   42   154    7.2     7.2         0.82            3
   48   174    7.3    10.1         0.72            4
   變成
  [[154.     7.2    7.2    0.82   3.  ]
   [174.     7.3   10.1    0.72   4.  ]
   [ 76.     5.8    4.     0.81   2.  ]]   
   '''
    summaries = summarizeByClass(Traindataset)            #根據測試數據進行提取數據特征, 分類,求方差,均值,然后對每類進行特征值提取
    print("特征的提取:",summaries)                      #輸出貝葉斯整理的結果
    predictions = getPredictions(summaries, Testdataset)  #輸入測試數據
    accuracy = getAccuracy(Testdataset, predictions)
    print("準確率:",accuracy,'%')

if __name__ == "__main__":
    main()

運行結果


捕獲1.PNG
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容

  • 后期整理字體以及排版問題,修訂不適合的翻譯 “A wealth of information. Smart, ye...
    iamzzz閱讀 759評論 0 0
  • 3.1. 介紹 現在,您已經安裝了Wireshark并有可能熱衷于開始捕捉您的第一個數據包。在接下來的章節中,我們...
    wwyyzz閱讀 1,400評論 0 1
  • ¥開啟¥ 【iAPP實現進入界面執行逐一顯】 〖2017-08-25 15:22:14〗 《//首先開一個線程,因...
    小菜c閱讀 6,497評論 0 17
  • 第一部分 HTML&CSS整理答案 1. 什么是HTML5? 答:HTML5是最新的HTML標準。 注意:講述HT...
    kismetajun閱讀 27,588評論 1 45
  • 有一個人經常喜歡指責別人。王陽明對他說:“學習應該多反省自己,如果只是看到別人的不對,責怪別人,就不會看到自己的不...
    六爸啦啦啦閱讀 415評論 0 1