layout: post
title: 深度學習入門 基于Python的理論實現
subtitle: 第三章 手寫數字識別
tags: [Machine learning, Reading]
第三章 神經網絡
3.6 手寫數字識別
上一個post介紹了神經網絡的基本內容,這一節搭配項目解決實際問題。這個例子非常簡單,是一個機器學習里的Hello world。手寫數字識別問題。但是這個例子是不完全的,我們<font color=red>假設學習已經全部完成</font>,我們用學習到的參數,先實現神經網絡的“推理處理”。這也叫神經網絡的前向傳播。
3.6.1 MNIST數據集
這個數據集網上的資料實在太多了,就連他的進階版本Fashion MNIST也出來很久了,相信能看到現在的人沒有太多人不知道這個數據集。
介紹簡單帶過。MNIST數據集(Mixed National Institute of Standards and Technology database)是美國國家標準與技術研究院收集整理的大型手寫數字數據庫,包含60,000個示例的訓練集以及10,000個示例的測試集。
MNIST的圖像是2828像素的灰度圖像(1通道),像素的取值在0到255之間。每個圖像都標有對應的阿拉伯數字標簽。
這本書提供了數據集和相應的代碼。傳送門
import sys,os
sys.path.append(os.pardir)
from dataset.mnist import load_mnist
(x_train, t_train),(x_test, t_test) = load_mnist(flatten=True, normalize=False)
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)
第一次執行代碼可能比較慢,原因是需要下載,服務器在國外,下載的比較慢。也可以手動下載放在文件夾,我就是用的這個方法(自動下載實在太慢了)。
于是我們打印出訓練集,測試集和對應的label的shape。
這里對代碼做一點簡單說明,這里的load_mnist函數是將數據集做導入,分別為兩個訓練集兩個測試集,flatten參數為True代表將2828的圖像扁平化,變成1
784的向量。normalize的含義是將數值標準化為0到1之間的數字,這個函數還可以傳入一個參數,就是one_hot_label,這個參數設置為True將會讓標簽變為one hot representation。
因為這里并不涉及參數的訓練,因此我們需要導入參數,這離有一個pkl文件,保存著訓練好的參數,直接導入就可以。下來簡單顯示一下圖片。
import sys,os,cv2
import numpy as np
sys.path.append(os.pardir)
from dataset.mnist import load_mnist
(x_train, t_train),(x_test, t_test) = load_mnist(flatten=True, normalize=True)
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)
first_image = x_train[0]
first_label = t_train[0]
print(first_label)
img = first_image.reshape(28,28)
cv2.imshow('img',img)
cv2.waitKey(0)
cv2.destroyAllWindows()
3.6.2 神經網絡的推理處理
根據之前的內容,我們設計的神經網絡的輸入層的神經元個數,應該是784,也就是圖像拉長之后的向量長度。輸出層為10個神經元,因為輸出層神經元的數量應該和分類的種類相等。另外,這神經網絡有兩個隱藏層,第一個隱藏層50個神經元,第二個有100個神經元。結合之前學過的知識,得到代碼如下。
import sys,os,cv2,pickle
import numpy as np
sys.path.append(os.pardir)
from dataset.mnist import load_mnist
def sigmoid(x):
return 1/(1+np.exp(-x))
def softmax(x):
exp_x = np.exp(x-np.max(x))
sum_exp_x = np.sum(exp_x)
return exp_x/sum_exp_x
def get_data():
(x_train, t_train),(x_test, t_test) = load_mnist(flatten=True, normalize=True)
return x_test,t_test #沒有訓練階段,因此只取測試數據
def init_network():
with open("sample_weight.pkl",'rb') as f:
network = pickle.load(f)
return network
def predict(network,x):
W1,W2,W3 = network['W1'],network['W2'],network['W3']
b1,b2,b3 = network['b1'],network['b2'],network['b3']
a1 = np.dot(x,W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1,W2) + b2
z2 = sigmoid(a2)
a3 = np.dot(z2,W3) + b3
y = softmax(a3)
return y
x,t = get_data()
accuracy_cnt = 0
for i in range(len(x)):
y = predict(network,x[i])
p = y.argmax()
if p == t[i]:
accuracy_cnt += 1
accuracy = accuracy_cnt/t.shape[0]
print(accuracy)
最后得到的正確率為93.52%。
3.6.3 批處理
上面就是神經網絡的實現,但是在實際寫代碼的過程中,有一個問題,那就是在計算正確率的時候,我們是一遍一遍調用predict函數,得到結果和label進行比較,使用了for循環,這顯然是不好的,我們引入矩陣運算就是為了應對這樣的情況。但是還應該注意到的是,我們也不能一次將所有的數據輸入進去,因為這會引起內存的溢出等等一系列問題,因此我們使用批處理可以讓計算更加高效。之前計算時,矩陣大小的傳遞過程如下:
下面我們一次性傳入100張圖片,也就是輸入矩陣的大小改變,變成如下的形式。
通過比較可以明顯的看出,一次可以計算100張圖片,輸出100個結果。具體的細節就不多講了,直接上實現。
import sys,os,cv2,pickle
import numpy as np
sys.path.append(os.pardir)
from dataset.mnist import load_mnist
def sigmoid(x):
return 1/(1+np.exp(-x))
def softmax(x):
exp_x = np.exp(x-np.max(x))
sum_exp_x = np.sum(exp_x)
return exp_x/sum_exp_x
def get_data():
(x_train, t_train),(x_test, t_test) = load_mnist(flatten=True, normalize=True)
return x_test,t_test #沒有訓練階段,因此只取測試數據
def init_network():
with open("sample_weight.pkl",'rb') as f:
network = pickle.load(f)
return network
def predict(network,x):
W1,W2,W3 = network['W1'],network['W2'],network['W3']
b1,b2,b3 = network['b1'],network['b2'],network['b3']
a1 = np.dot(x,W1) + b1
z1 = sigmoid(a1)
a2 = np.dot(z1,W2) + b2
z2 = sigmoid(a2)
a3 = np.dot(z2,W3) + b3
y = softmax(a3)
return y
x,t = get_data()
batch_size = 100
accuracy_cnt = 0
for i in range(0,len(x),batch_size):
x_batch = x[i:i+batch_size]
y_batch = predict(network,x_batch)
p = np.argmax(y_batch,axis=1)
accuracy_cnt += np.sum(p == t[i:i+batch_size])
accuracy = accuracy_cnt/t.shape[0]
print(accuracy)
最終得到的結果是完全一樣的。
3.7 小結
第三章講的是神經網絡的前向傳播。也就是數據是如何傳遞的,當然這一章的內容也是不完整的,因為沒有訓練部分,而是直接載入參數。
- 實現部分主要的重點在于對批處理的理解。