作為深度學(xué)習(xí)界的“hello world!”,學(xué)習(xí)起來真沒那么容易。
接觸深度學(xué)習(xí),第一個接觸的就是mnist。但是初次接觸就只跑了三個腳本
get_mnist.sh
create_mnist.sh
train_lenet.sh
然后就結(jié)束了,對此我蒙逼了許久。因?yàn)閷τ赾affe的整體框架不熟悉,對CNN不深入,因此感覺舉步維艱。經(jīng)過1個多月的沉淀終于能完整的走一遍MNIST。
對于初學(xué)者,深度學(xué)習(xí)分為三步:1.數(shù)據(jù)準(zhǔn)備 2.訓(xùn)練 3.預(yù)測
一.數(shù)據(jù)準(zhǔn)備
官方例程推薦的數(shù)據(jù)集為
t10k-images-idx3-ubyte
t10k-labels-idx1-ubyte
train-images-idx3-ubyte
train-labels-idx1-ubyte
相信許多人和我一樣會問:這是什么啊,打開還是一推二進(jìn)制數(shù)。確實(shí),官方的數(shù)據(jù)集可視化不好,但是可以借助matlab或者python解析出來。但是,對于普通人拿到的數(shù)據(jù)往往都是圖片格式,而且是很多。
這該進(jìn)行怎么加載訓(xùn)練呢。
先粗略的看下,官方的數(shù)據(jù)集。可以看出images對應(yīng)一個labels,所以我們準(zhǔn)備的數(shù)據(jù)包括圖片和標(biāo)簽。
1)基礎(chǔ)準(zhǔn)備
在data文件夾下創(chuàng)建如下文件夾,準(zhǔn)備訓(xùn)練集,驗(yàn)證集和測試集
創(chuàng)建 train test文件夾和對應(yīng)的txt將你的訓(xùn)練集放到train中,將驗(yàn)證集放到test中。(這里應(yīng)該多建一個valid文件夾,里面存放的是驗(yàn)證集,而test中放測試集,這里偷工減料了)
接著要制作標(biāo)簽,如果量少可以考慮手敲,但是大數(shù)據(jù)就只能借助代碼了。
創(chuàng)建make_list.py
#coding=utf-8
#caffe and opencv test mnist
#test by yuzefan
import os
from os.path import join, isdir
def gen_listfile(dir):
cwd=os.getcwd() # 獲取當(dāng)前目錄
os.chdir(dir) # 改變當(dāng)前的目錄
sd=[d for d in os.listdir('.') if isdir(d)] # 列出當(dāng)前目錄下的所有文件和目錄名,os.listdir可以列出文件和目錄
sd.sort()
class_id=0
with open(join(dir,'listfile.txt'),'w') as f : #join():connect string,"with...as"is used for safety,without it,you must write by"file = open("/tmp/foo.txt") file.close()
for d in sd :
fs=[join(d,x) for x in os.listdir(d)]
for img in fs:
f.write(img + ' '+str(class_id)+'\n')
class_id+=1
os.chdir(cwd)
if __name__ == "__main__":
root_dir = raw_input('image root dir: ')
while not isdir(root_dir):
raw_input('not exist, re-input please: ')
gen_listfile(root_dir)
運(yùn)行后可以得到標(biāo)簽,如下:
list已經(jīng)準(zhǔn)備好了,接著要把數(shù)據(jù)轉(zhuǎn)成lmdb。caffe之所以速度快,得益于lmdb數(shù)據(jù)格式。
創(chuàng)建creat_lmdb.sh腳本
#coding=utf-8
#!/usr/bin/env sh
#指定腳本的解釋程序
#by yuzefan
set -e #如果任何語句的執(zhí)行結(jié)果不是true則應(yīng)該退出
# CAFFEIMAGEPATH is the txt file path
# DATA is the txt file path
CAFFEDATAPATH=mytest/chinese/data
DATA=mytest/chinese/data/mnist
TOOLS=~/caffe-master/build/tools
# TRAIN_DATA_PATH & VAL_DATA_ROOT is root path of your images path, so your train.txt file must do not contain
# this line again!!
TRAIN_DATA_ROOT=/home/ubuntu/caffe-master/mytest/chinese/data/mnist/train/
VAL_DATA_ROOT=/home/ubuntu/caffe-master/mytest/chinese/data/mnist/test/
# Set RESIZE=true to resize the images to 28x28. Leave as false if images have
# already been resized using another tool.
RESIZE=true
if $RESIZE;then
RESIZE_HEIGHT=28
RESIZE_WIDTH=28
else
RESIZE_HEIGHT=0
RESIZE_WIDTH=0
fi
if [ ! -d "$TRAIN_DATA_ROOT" ]; then
echo "Error: TRAIN_DATA_ROOT is not a path to a directory: $TRAIN_DATA_ROOT"
echo "Set the TRAIN_DATA_ROOT variable in create_imagenet.sh to the path" \
"where the ImageNet training data is stored."
exit 1
fi
if [ ! -d "$VAL_DATA_ROOT" ]; then
echo "Error: VAL_DATA_ROOT is not a path to a directory: $VAL_DATA_ROOT"
echo "Set the VAL_DATA_ROOT variable in create_imagenet.sh to the path" \
"where the ImageNet validation data is stored."
exit 1
fi
echo "Creating train lmdb..."
GLOG_logtostderr=1 $TOOLS/convert_imageset \
--resize_height=$RESIZE_HEIGHT \
--resize_width=$RESIZE_WIDTH \
--shuffle \
--gray=true\
$TRAIN_DATA_ROOT \
$DATA/train.txt \
$CAFFEDATAPATH/caffe_train_lmdb
echo "Creating val lmdb..."
GLOG_logtostderr=1 $TOOLS/convert_imageset \
--resize_height=$RESIZE_HEIGHT \
--resize_width=$RESIZE_WIDTH \
--shuffle \
--gray=true\
$VAL_DATA_ROOT \
$DATA/test.txt \
$CAFFEDATAPATH/caffe_val_lmdb
echo "Done."
運(yùn)行完后在data目錄下出現(xiàn)
caffe_train_lmdb
caffe_val_lmdb
這里使用了caffe的tools中的convert_imageset。使用方法:
convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME
其中
參數(shù):ROOTFOLDER 表示輸入的文件夾
參數(shù):LISTFILE 表示輸入文件列表,其每一行為:類似 subfolder1/file1.JPEG 7
可選參數(shù):[FLAGS] 可以指示是否使用shuffle,顏色空間,編碼等。
--gray=true \-------------------------------------------->如果灰度圖的話加上即可
還調(diào)用了opencv,對輸入圖像進(jìn)行尺寸變換,滿足網(wǎng)絡(luò)的要求。
注意:
TRAIN_DATA_PATH & VAL_DATA_ROOT is root path of your images path, so your train.txt file must do not contain
到此,數(shù)據(jù)準(zhǔn)備就結(jié)束了。
二.訓(xùn)練
訓(xùn)練需要模型描述文件和模型求解文件。
lenet_train_test.prototxt
lenet_solver.prototxt
對于lenet_train_test.prototxt,需要改的地方只有數(shù)據(jù)層
name: "LeNet"
layer {
name: "mnist" #名字隨便
type: "Data"
top: "data"
top: "label"
include {
phase: TRAIN
}
transform_param {
scale: 0.00390625
}
data_param {
source: "mytest/chinese/data/caffe_train_lmdb" #這里是上一步生成的lmdb
batch_size: 64#一次壓入網(wǎng)絡(luò)的數(shù)量
backend: LMDB
}
}
layer {
name: "mnist"
type: "Data"
top: "data"
top: "label"
include {
phase: TEST
}
transform_param {
scale: 0.00390625
}
data_param {
source: "mytest/chinese/data/caffe_val_lmdb"
batch_size: 100
backend: LMDB
}
}
對于lenet_solver.prototxt
# The train/test net protocol buffer definition
net: "mytest/chinese/lenet_train_test.prototxt"#這里可以把訓(xùn)練和驗(yàn)證放到一起,實(shí)際可以分開
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
test_iter: 100 #test_iter * batch_size= 10000(test集的大小)
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "inv"
gamma: 0.0001
power: 0.75
# Display every 100 iterations
display: 20
# The maximum number of iterations
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "mytest/chinese/lenet"
# solver mode: CPU or GPU
solver_mode: GPU
訓(xùn)練可以執(zhí)行train_lenet.sh,實(shí)際上還是調(diào)用了tools
#!/usr/bin/env sh
set -e
./build/tools/caffe train --solver=mytest/chinese/lenet_solver.prototxt $@
沒有意外的話就能正常開始訓(xùn)練了。
三.預(yù)測
預(yù)測可以參考我之前寫的
Caffe學(xué)習(xí)筆記1:用訓(xùn)練好的mnist模型進(jìn)行預(yù)測(兩種方法)
http://www.lxweimin.com/p/6fcdefbacf5b
小筆記:均值計(jì)算
減均值預(yù)處理能提高訓(xùn)練和預(yù)測的速度,利用tools
二進(jìn)制格式的均值計(jì)算
build/tools/compute_image_mean examples/mnist/mnist_train_lmdb examples/mnist/mean.binaryproto
帶兩個參數(shù):
第一個參數(shù):examples/mnist/mnist_train_lmdb, 表示需要計(jì)算均值的數(shù)據(jù),格式為lmdb的訓(xùn)練數(shù)據(jù)。
第二個參數(shù):examples/mnist/mean.binaryproto, 計(jì)算出來的結(jié)果保存文件。
接下來的計(jì)劃:現(xiàn)在說白了是個10類的分類器,接下來增強(qiáng)網(wǎng)絡(luò)使其能夠訓(xùn)練并預(yù)測出0~9 and ‘a(chǎn)’~‘z’