Keras handwritten digit recognition (dataset: MNIST)

Prerequisites

  • keras
  • tensorflow
  • numpy
  • PIL

Downloading the MNIST dataset

from keras.datasets import mnist

mnist.load_data(path)

path is the location where the downloaded dataset will be cached.
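For reference, load_data returns two (images, labels) tuples of NumPy arrays, and the path argument is optional (it defaults to a cache file under ~/.keras/datasets). A minimal shape check:

from keras.datasets import mnist

# load_data returns (x_train, y_train), (x_test, y_test)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
print(x_train.shape, y_train.shape)  # (60000, 28, 28) (60000,)
print(x_test.shape, y_test.shape)    # (10000, 28, 28) (10000,)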

Model architecture

(Figure: model architecture diagram, model1.png)

This model uses two Convolution2D layers, two MaxPooling2D layers, one Flatten layer, and two fully connected Dense layers. The activation function is ReLU and the optimizer is Adam.
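A diagram like the one above can be generated from the model built in the training code below with Keras's plot_model utility (this assumes the pydot and graphviz packages are installed; it may or may not be how the original figure was produced):

from keras.utils import plot_model

# Writes a PNG diagram of the layer graph to model1.png.
plot_model(model, to_file='model1.png', show_shapes=True)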

Training code

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Convolution2D, Dense, Flatten, Activation, MaxPooling2D
from keras.utils import to_categorical
from keras.optimizers import Adam
import numpy as np

# Load MNIST and add the single greyscale channel dimension.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(60000, 28, 28, 1)
x_test = x_test.reshape(10000, 28, 28, 1)
# One-hot encode the labels into 10 classes.
y_test = to_categorical(y_test, 10)
y_train = to_categorical(y_train, 10)

# design model
model = Sequential()
model.add(Convolution2D(25, (5, 5), input_shape=(28, 28, 1)))
model.add(MaxPooling2D(2, 2))
model.add(Activation('relu'))
model.add(Convolution2D(50, (5, 5)))
model.add(MaxPooling2D(2, 2))
model.add(Activation('relu'))
model.add(Flatten())

model.add(Dense(50))
model.add(Activation('relu'))
model.add(Dense(10))
model.add(Activation('softmax'))
adam = Adam(lr=0.001)
# compile model
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy'])
# training model
model.fit(x_train, y_train, batch_size=100, epochs=5)
# test model: evaluate returns [loss, accuracy]
print(model.evaluate(x_test, y_test, batch_size=100))
# save model
model.save('/Users/zhang/Desktop/my_model2.h5')
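With Keras's default 'valid' padding, the feature-map shapes can be worked out by hand and confirmed with model.summary(); the comments below are my own calculation, not output copied from a run:

# 28x28x1 -> Conv 5x5, 25 filters -> 24x24x25
#         -> MaxPool 2x2          -> 12x12x25
#         -> Conv 5x5, 50 filters ->  8x8x50
#         -> MaxPool 2x2          ->  4x4x50
#         -> Flatten -> 800 -> Dense 50 -> Dense 10
model.summary()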
Training results

Using TensorFlow backend.
Epoch 1/5
2017-09-12 14:49:32.779373: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-09-12 14:49:32.779389: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-09-12 14:49:32.779393: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-09-12 14:49:32.779398: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-09-12 14:49:32.779401: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
60000/60000 [==============================] - 33s - loss: 2.5862 - acc: 0.8057    
Epoch 2/5
60000/60000 [==============================] - 32s - loss: 0.0603 - acc: 0.9820    
Epoch 3/5
60000/60000 [==============================] - 32s - loss: 0.0409 - acc: 0.9873    
Epoch 4/5
60000/60000 [==============================] - 32s - loss: 0.0338 - acc: 0.9895    
Epoch 5/5
60000/60000 [==============================] - 33s - loss: 0.0259 - acc: 0.9922    
 9900/10000 [============================>.] - ETA: 0s[0.054905540546023986, 0.98440000832080843]
  • On the test set the model reaches a recognition accuracy of 98.44%.
  • There is an element of luck in training: with batch_size=100, if the optimizer does not head toward a good optimum early on, it can get stuck in a dead end and training accuracy stays at roughly 9% (chance level for 10 classes). One common remedy is sketched below.
  • Reaching 80% accuracy after just one epoch also suggests this training setup may well overfit, even though the test-set accuracy is 98%.
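A plausible way to make training more stable (my own suggestion; the original run used raw 0-255 pixel values) is to normalize the inputs before calling model.fit, so the optimizer works with better-conditioned values:

# Suggested preprocessing step, to run before model.fit():
# scale pixel values from [0, 255] to [0, 1].
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0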

Testing

I drew a few handwritten-digit images in Photoshop for testing.
(Figure: hand-drawn test digit images, WX20170912-145816@2x.png)
These are all 28×28 images.
The test code is as follows:
from keras.models import load_model
import numpy as np
from PIL import Image


def ImageToMatrix(filename):
    # Open the image, convert it to greyscale, and return the
    # pixel values as a flat integer array.
    im = Image.open(filename)
    im = im.convert('L')
    data = np.array(im.getdata(), dtype='int')
    return data


model = load_model('/Users/zhang/Desktop/my_model.h5')


while 1:
    i = input('number:')
    j = input('type:')
    data = ImageToMatrix('/Users/zhang/Desktop/picture/' + str(i) + '_' + str(j) + '.png')
    # Reshape to the (batch, height, width, channels) layout the model expects.
    data = data.reshape(1, 28, 28, 1)
    print('test[' + str(i) + '_' + str(j) + '], num=' + str(i) + ':')
    print(model.predict_classes(data, batch_size=1, verbose=0))
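Note that predict_classes is only available on Sequential models and has been removed in recent versions of Keras; if your installation lacks it, taking the argmax of the softmax output is equivalent:

# Equivalent to predict_classes: index of the highest class probability.
print(np.argmax(model.predict(data, batch_size=1, verbose=0), axis=-1))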
A few selected results:
number:7
type:1
test[7_1], num=7:
[7]

number:7
type:2
test[7_2], num=7:
[7]

number:7
type:3
test[7_3], num=7:
[7]

number:2
type:1
test[2_1], num=2:
[2]

number:1
type:1
test[1_1], num=1:
[4]

number:1
type:2
test[1_2], num=1:
[1]

number:1
type:3
test[1_3], num=1:
[1]

number:6
type:1
test[6_1], num=6:
[5]

Summary:

  • The model cannot correctly recognize digits drawn in a small size (this is related to the digit size in the training set).
  • Digits like '1' are not recognized accurately when placed at the edge of the image, as in 1_1.
  • Naturally, digits that are upside down or rotated onto their side are not recognized accurately either.
  • On the MNIST test set itself, only about 200 of the 10,000 images are misclassified.

Possible improvements

  • Preprocess the incoming images (flipping, centering, rescaling, and so on) to make them easier to recognize; a sketch follows this list.
  • Process the training set as well, training on digits in different orientations (at the cost of a larger training workload).
  • Improve the network layers (I have not studied this deeply, and this network was honestly put together ad hoc, simply combining convolutional and fully connected layers).
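As a sketch of the first point, here is one way such preprocessing might look with PIL. The specific steps (cropping to the digit's bounding box, rescaling, and centering on a 28×28 canvas, roughly matching MNIST's layout) are my assumption of what the preprocessing would involve, not code from this post:

from PIL import Image
import numpy as np


def preprocess(filename):
    # Assumes a light digit on a dark background, like MNIST.
    im = Image.open(filename).convert('L')
    # Crop to the bounding box of the non-zero (non-black) pixels.
    box = im.getbbox()
    if box is not None:
        im = im.crop(box)
    # Rescale to 20x20 and paste centred on a 28x28 black canvas.
    im = im.resize((20, 20))
    canvas = Image.new('L', (28, 28), 0)
    canvas.paste(im, (4, 4))
    return np.array(canvas).reshape(1, 28, 28, 1)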
I also recommend Hung-yi Lee's machine learning videos (YouTube: 李宏毅).