KNN例子

參考CS231n,將KNN 跑起來了,成功將系統搞死,,內存和計算能力開銷太大。

以下代碼 切記不用輕易跑。。


數據集

http://www.cs.toronto.edu/~kriz/cifar.html


code:


import os

import sys

import numpy as np

import pickle

def load_CIFAR_batch(filename):

"""

cifar-10數據集是分batch存儲的,這是載入單個batch

@參數 filename: cifar文件名

@r返回值: X, Y: cifar batch中的 data 和 labels

"""

with open(filename,"rb") as f :

datadict = pickle.load(f,encoding='iso-8859-1')

print(filename)

X=datadict['data']

Y=datadict['labels']

X=X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")

Y=np.array(Y)

return X, Y

def load_CIFAR10(ROOT):

"""

讀取載入整個 CIFAR-10 數據集

@參數 ROOT: 根目錄名

@return: X_train, Y_train: 訓練集 data 和 labels

X_test, Y_test: 測試集 data 和 labels

"""

xs=[]

ys=[]

for b in range(1,6):

f=os.path.join(ROOT, "data_batch_%d" % (b, ))

X, Y=load_CIFAR_batch(f)

xs.append(X)

ys.append(Y)

X_train=np.concatenate(xs)

Y_train=np.concatenate(ys)

del X, Y

X_test, Y_test=load_CIFAR_batch(os.path.join(ROOT, "test_batch"))

return X_train, Y_train, X_test, Y_test

# 載入訓練和測試數據集

X_train, Y_train, X_test, Y_test = load_CIFAR10('data/cifar/')

# 把32*32*3的多維數組展平

Xtr_rows = X_train.reshape(X_train.shape[0], 32 * 32 * 3) # Xtr_rows : 50000 x 3072

Xte_rows = X_test.reshape(X_test.shape[0], 32 * 32 * 3) # Xte_rows : 10000 x 3072

class NearestNeighbor:

def __init__(self):

pass

def train(self, X, y):

"""

這個地方的訓練其實就是把所有的已有圖片讀取進來 -_-||

"""

# the nearest neighbor classifier simply remembers all the training data

self.Xtr = X

self.ytr = y

def predict(self, X):

"""

所謂的預測過程其實就是掃描所有訓練集中的圖片,計算距離,取最小的距離對應圖片的類目

"""

num_test = X.shape[0]

# 要保證維度一致哦

Ypred = np.zeros(num_test, dtype = self.ytr.dtype)

# 把訓練集掃一遍 -_-||

for i in range(num_test):

# 計算l1距離,并找到最近的圖片

distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)

min_index = np.argmin(distances) # 取最近圖片的下標

Ypred[i] = self.ytr[min_index] # 記錄下label

return Ypred

nn = NearestNeighbor() # 初始化一個最近鄰對象

nn.train(Xtr_rows, Y_train) # 訓練...其實就是讀取訓練集

Yte_predict = nn.predict(Xte_rows) # 預測

# 比對標準答案,計算準確率

print ('accuracy: %f' % ( np.mean(Yte_predict == Y_test)))

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容