來源:數據分析網
線性判別式分析(Linear Discriminant Analysis)簡稱LDA,是模式識別的經典算法。通過對歷史數據進行投影,以保證投影后同一類別的數據盡量靠近,不同類別的數據盡量分開。并生成線性判別模型對新生成的數據進行分離和預測。本篇文章使用機器學習庫scikit-learn建立LDA模型,并通過繪圖展示LDA的分類結果。
準備工作
首先是開始前的準備工作,導入需要使用的庫文件,本篇文章中除了常規的數值計算庫numpy,科學計算庫pandas,和繪圖庫matplotlib以外,還有繪圖庫中的顏色庫,以及機器學習中的數據預處理和LDA庫。
#導入數值計算庫
import numpy as np
#導入科學計算庫
import pandas as pd
#導入繪圖庫
import matplotlib.pyplot as plt
#導入繪圖色彩庫產生內置顏色
from matplotlib.colors import ListedColormap
#導入數據預處理庫
from sklearn import preprocessing
#導入linear discriminant analysis庫
from sklearn.lda import LDA
讀取數據
讀取并創建名稱為data的數據表,后面我們將使用這個數據表創建LDA模型并繪圖。
#讀取數據并創建名為data的數據表
data=pd.DataFrame(pd.read_csv('LDA_data.csv'))
使用head函數查看數據表的前5行,這里可以看到數據表共有三個字段,分別為貸款金額loan_amnt,用戶收入annual_inc和貸款狀態loan_status。
#查看數據表的前5行
data.head()
設置模型特征X和目標Y
將數據表中的貸款金額和用戶收入設置為模型特征X,將貸款狀態設置為模型目標Y,也就是我們要分類的結果。
#設置貸款金額和用戶收入為特征X
X = np.array(data[['loan_amnt','annual_inc']])
#設置貸款狀態為目標Y
Y = np.array(data['loan_status’])
對特征進行標準化處理
貸款金額和用戶收入間差異較大,屬于兩個不同量級的數據。因此需要對數據進行標準化處理,轉化為無量綱的純數值。
#特征數據進行標準化
scaler = preprocessing.StandardScaler().fit(X)
X_Standard=scaler.transform(X)
下面是經過標準化處理后的特征數據。
#查看標準化后的特征數據
X_Standard
#設置分類平滑度
h = .01
創建LDA模型并擬合數據
將標準化后的特征X和目標Y代入到LDA模型中。下面是具體的代碼和計算結果。
#創建LDA模型
clf = LDA()
clf.fit(X_Standard,Y)
繪圖數據預處理
對繪圖數據進行預處理,計算X和Y的邊界值,并使用meshgrid函數計算坐標向量矩陣。
#設置X和Y的邊界值
x_min, x_max = X_Standard[, 0].min() - 1, X_Standard[, 0].max() + 1
y_min, y_max = X_Standard[, 1].min() - 1, X_Standard[, 1].max() + 1
#使用meshgrid函數返回X和Y兩個坐標向量矩陣
xx, yy = np.meshgrid(np.arange(x_min, x_max,h), np.arange(y_min, y_max,h))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
設置圖表所使用的顏色,這里使用的是HEX值。
#設置colormap顏色
cm_bright = ListedColormap(['#D9E021', '#0D8ECF’])
繪制LDA分類圖表
首先繪制LDA分類圖表的邊界,這里使用之前計算的坐標矩陣,并設置的colormap顏色和透明度。
#繪制分類邊界
Z = Z.reshape(xx.shape)
plt.pcolormesh(xx, yy, Z, cmap=cm_bright,alpha=0.6)
最后繪制LDA圖表中的數據點,并設置colormap顏色以及圖表標題。以下是具體代碼和圖表。
#繪制數據點
plt.scatter(X_Standard[, 0], X_Standard[, 1], c=Y, cmap=cm_bright)
plt.title('Linear Discriminant Analysis Classifiers')
plt.axis('tight')
plt.show()
今年第六屆大會PyConChina2016,由PyChina.org發起,CPyUG/TopGeek 等社區協辦,將在2016年9月10日(上海)9月25日(深圳)10月15日(北京、杭州)地舉辦的針對Python開發者所舉辦的最盛大和權威的Python相關技術會議,由PyChina社區主辦,致力于推動各類Python相關的技術在互聯網、企業應用等領域的研發和應用。
您可以點擊此處
了解更多詳情,或者掃描下圖二維碼: