簡介
KNN(K-Nearest Neighbor)最近鄰分類算法是數據挖掘中最簡單而有效的分類技術之一。它的核心思想類似于“近朱者赤,近墨者黑”,通過依靠鄰近樣本的類別來判斷未知樣本的類別。
核心原理
為了對未知樣本進行分類,首先將所有已知類別的樣本作為參考,計算未知樣本與每個已知樣本的距離。然后,選取距離未知樣本最近的K個已知樣本,采用多數投票法則(majority-voting),將未知樣本分配到K個最近樣本中所占比例較多的類別中。
這就是KNN算法在分類任務中的基本思想,其中K表示選取的最近鄰樣本的個數。默認是5,一般不大于20。
需要注意的是,KNN算法是一種基于最近鄰的分類方法,它不同于判別類域方法,特別適用于存在類域交叉或重疊的待分樣本集。
KNN算法關鍵要點:
特征可比較的量化: 所有樣本特征必須以可比較的數值形式表示。如果存在非數值特征,需通過量化方法將其轉換為數值,例如將顏色轉換為灰度值,以便進行距離計算。
樣本特征歸一化: 樣本可能包含多個參數,具有不同的定義域和取值范圍,對距離計算產生影響。因此,應對樣本參數進行歸一化處理,以避免某些參數的影響過大。常用的方法是對所有特征進行歸一化處理。
距離函數選擇: 需要選擇適當的距離函數來計算樣本之間的距離。常用的距離函數包括歐氏距離、余弦距離、漢明距離、曼哈頓距離等。通常情況下,歐氏距離適用于連續變量,而漢明距離適用于非連續變量,如文本分類。
確定K值: K值的選擇很關鍵。選擇過大的K值可能導致欠擬合,而過小的K值可能導致過擬合。應通過交叉驗證等方法來確定合適的K值。
KNN算法優缺點總結:
優點:
- 簡單易懂,易于實現,無需參數估計和訓練。
- 適用于稀有事件分類。
- 特別適合多分類問題,性能比SVM更好。
缺點:
- 在樣本不平衡時表現不佳。大樣本類別可能占據鄰居中的多數,導致預測偏向這些大樣本類別。
- 計算量大,需要對每個待分類樣本計算與所有已知樣本的距離。
總結
KNN算法是一種簡單有效的分類方法,通過選取最近鄰樣本來進行分類判定。優化KNN算法需要注意數據預處理、距離函數選擇和K值確定等關鍵因素。雖然KNN在一些場景下存在不足,但其在多類別分類和稀有事件分類等方面的優點使其在實際應用中具有廣泛價值。
代碼
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
# 數據
# 樣本
data_x = [
[1.3, 6],
[3.5, 5],
[4.2, 2],
[5, 3.3],
[2, 9],
[5, 7.5],
[7.2, 4 ],
[8.1, 8],
[9, 2.5]
]
# 分類
data_y = [0, 0, 0, 0,
1, 1, 1, 1, 1]
#訓練集
x_train = np.array(data_x)
y_train = np.array(data_y)
x1 = x_train[y_train == 0,0]
y1 = x_train[y_train == 0,1]
x2 = x_train[y_train == 1,0]
y2 = x_train[y_train == 1,1]
# 畫圖
plt.title("KNN")
plt.scatter(x1, y1, color = "orange", marker = "o")
plt.scatter(x2, y2, color = "blue", marker = "x")
# 新的點
data_new = np.array([4, 5])
plt.scatter(data_new[0], data_new[1], color = "purple", marker = "^")
plt.show()
# 計算新的點到舊的點的距離
dist_array = np.array([])
for temp in x_train:
dist = temp - data_new
dist = np.sqrt(np.sum(dist * dist))
print(f'距離為:{dist}')
dist_array = np.append(dist_array, dist)
print(f"dist_array == {dist_array}")
# 排序距離
sort_array = np.sort(dist_array)
print(f"sort_array == {sort_array}")
sort_indext_array = np.argsort(dist_array)
print(f"sort_indext_array == {sort_indext_array}")
# 設定k值
k = 5
# 距離最近的k個點投票
first_k = [y_train[i] for i in sort_indext_array[:k]]
print(f"first_k == {first_k}")
# 列出n個最常見的元素及其從最常見元素開始的計數
c = Counter(first_k).most_common()
ans = c[0][0]
print(f"self ans == {ans}")
# scikit-learn中的KNN算法
from sklearn.neighbors import KNeighborsClassifier
# k的值
knn_obj = KNeighborsClassifier(n_neighbors=5)
# 樣本集
knn_obj.fit(x_train, y_train)
# 預測
ans = knn_obj.predict([data_new])
print(f"cikit-learn ans == {ans[0]}")