本篇更多的是在代碼實戰方向,不會涉及太多的理論。本文主要針對TensorFlow和卷積神經網絡有一定基礎的同學,并對圖像處理有一定的了解。
閱讀本文你大概需要以下知識:
1.TensorFlow基礎
2.TensorFlow實現卷積神經網絡的前向傳播過程
3.TFRecord數據格式
4.Dataset的使用
5.Slim的使用
好了廢話不多說,下面開始。
一.數據準備
首先我們需要有一個讓我們訓練的數據集,這里谷歌已經幫我們做好了。這里要把數據集下載下來,打開命令行,執行如下命令:
wget http://download.tensorflow.org/example_image/flower_photo.tgz
//解壓
tar xzf flower_photos.tgz
這里需要注意的是,文件最好是下載到你的工程目錄下方便你的讀取。什么?你還不會搭建TensorFlow程序?請移步https://www.tensorflow.org/install/
選擇自己的操作系統,在這里我的是macOS。我使用的是Virtualenv來搭建TensorFlow運行環境。
數據集下載并解壓后,我們可以看到大概是這個樣子
好了,數據有了?接下來該怎么辦呢?當然是把數據進行預處理拉,你不會覺得我們的TensorFlow可以直接識別這些圖片進行訓練吧,hhhhhh。
二.數據預處理
接下來我們在目錄下新建pre_data.python文件。TensorFlow對圖片做處理一般是生成TFRecord文件。什么是TFRecord?后面我們會講到。
首先我們要引入我們需要的庫。
# glob模塊的主要方法就是glob,該方法返回所有匹配的文件路徑列表(list)
import glob
#os.path生成路徑方便glob獲取
import os.path
#這里主要用到隨機數
import numpy as np
#引入tensorflow框架
import tensorflow as tf
#引入gflie對圖片做處理
from tensorflow.python.platform import gfile
相關庫在我們這個程序中的功能都作了簡單介紹,下面用到的時候我們會更加詳細的說明。
大家都知道我們的數據集一般分訓練,測試和驗證數據集。觀察上面的數據集,谷歌只是給出了每一種花的圖片,并沒有給去哪些我訓練,哪些是測試,哪些是驗證數據集。所以在這里我們要進行劃分。
#輸入圖片地址
INPUT_DATA = '../../flower_photos'
#訓練數據集
OUTPUT_FILE = './path/to/output.tfrecords'
#測試數據集
OUTPUT_TEST_FILE = './path/to/output_test.tfrecords'
#驗證數據集
OUTPUT_VALIDATION_FILE = './path/to/output_validation.tfrecords'
#測試數據和驗證數據的比例
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10
關于VALIDATION_PERCENTAGE和TEST_PERCENTAGE這兩個常量,我們在后面的例子會給出。
下面我們就來定義處理數據的方法:
def create_image_lists(sess,testing_percentage,validation_percentage):
#拿到INPUT_DATA文件夾下的所有目錄(包括root)
sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
#如果是root_dir不需要做處理
is_root_dir = True
#定義圖片對應的標簽,從0-4分別代表不同的花
current_label = 0
#寫入TFRecord的數據需要首先定義writer
#這里定義三個writer分別存儲訓練,測試和驗證數據
writer = tf.python_io.TFRecordWriter(OUTPUT_FILE)
writer_test = tf.python_io.TFRecordWriter(OUTPUT_TEST_FILE)
writer_validation = tf.python_io.TFRecordWriter(OUTPUT_VALIDATION_FILE)
#循環目錄
for sub_dir in sub_dirs:
if is_root_dir:
#跳過根目錄
is_root_dir = False
continue
#定義空數組來裝圖片路徑
file_list = []
#生成查找路徑
dir_name = os.path.basename(sub_dir)
file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + "jpg")
# extend合并兩個數組
# glob模塊的主要方法就是glob,該方法返回所有匹配的文件路徑列表(list)
# 比如:glob.glob(r’c:*.txt’) 這里就是獲得C盤下的所有txt文件
file_list.extend(glob.glob(file_glob))
#路徑下沒有文件就跳過,不繼續操作
if not file_list: continue
#這里我定義index來打印當前進度
index = 0
#file_list此時是圖片路徑列表
for file_name in file_list:
#使用gfile從路徑中讀取圖片
image_raw_data = gfile.FastGFile(file_name, 'rb').read()
#對圖像解碼,解碼結果為一個張量
image = tf.image.decode_jpeg(image_raw_data)
#對圖像矩陣進行歸一化處理
#因為為了將圖片數據能夠保存到 TFRecord 結構體中
#所以需要將其圖片矩陣轉換成 string
#所以為了在使用時能夠轉換回來
#這里確定下數據格式為 tf.float32
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# 將圖片轉化成299*299方便模型處理
image = tf.image.resize_images(image, [299, 299])
#為了拿到圖片的真實數據這里我們要運行一個session op
image_value = sess.run(image)
pixels = image_value.shape[1]
#存儲在TFrecord里面的不能是array的形式
#所以我們需要利用tostring()將上面的矩陣
#轉化成字符串
#再通過tf.train.BytesList轉化成可以存儲的形式
image_raw = image_value.tostring()
#存到features
#隨機劃分測試集和訓練集
#這里存入TFRecord三個數據,圖像的pixels像素
#圖像原張量,這里我們需要轉成string
#以及當前圖像對應的標簽
example = tf.train.Example(features=tf.train.Features(feature={
'pixels': _int64_feature(pixels),
'label': _int64_feature(current_label),
'image_raw': _bytes_feature(image_raw)
}))
chance = np.random.randint(100)
#隨機劃分數據集
if chance < validation_percentage:
writer_validation.write(example.SerializeToString())
elif chance < (testing_percentage+validation_percentage):
writer_test.write(example.SerializeToString())
else:
writer.write(example.SerializeToString())
# print('example',index)
index = index + 1
#每一個文件夾下的所有圖片都是一個類別
#所以這里每遍歷完一個文件夾,標簽就增加1
current_label += 1
writer.close()
writer_validation.close()
writer_test.close()
運行上述程序需要一定時間,我的電腦比較爛,大概跑了三十分鐘左右。這時候在你的./path/to目錄下可以看到output.tfrecords,output_test.tfrecords,output_validation.tfrecords三個文件,分別存放了訓練,測試和驗證數據集。上述代碼將所有圖片劃分成訓練、驗證和測試數據集。并且把圖片從原始的jpg格式轉換成inception-v3模型需要的299 * 299 * 3的數字矩陣。在數據處理完畢之后,通過以下命令可以下載谷歌提供好的Inception_v3模型。
wget http://download.tensorflow.org/models/inception_v3_2016_08_26.tar.gz
//解壓之后可以得到訓練好的模型文件inception_v3.ckpt
tar xzf inception_v3_2016_08
二.訓練
當新的數據集和已經訓練好的模型都準備好之后,我們來寫代碼在谷歌inception_v3的基礎上訓練新數據集。
首先同樣我們導入相關的庫并且定義相關常量。在這里我們通過slim工具來直接加載模型,而不用自己再定義前向傳播過程。
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
# 加載通過TensorFlow-Silm定義好的 inception_v3模型
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3
# 輸入數據文件
INPUT_DATA = './path/to/output.tfrecords'
# 驗證數據集
VALIDATION_DATA = './path/to/output_validation.tfrecords'
# 保存訓練好的模型的路徑
ls = './path/to/save_model'
# 谷歌提供的訓練好的模型文件地址
CKPT_FILE = './path/to/inception_v3.ckpt'
TRAIN_FILE = './path/to/save_model'
# 定義訓練中使用的參數
LEARNING_RATE = 0.01
#組合batch的大小
BATCH = 32
#用于one_hot函數輸出概率分布
N_CLASSES = 5
#打亂順序,并設置出隊和入隊中元素最少的個數,這里是10000個
shuffle_buffer = 10000
# 不需要從谷歌模型中加載的參數,這里就是最后的全連接層。因為輸出類別不一樣,所以最后全連接層的參數也不一樣
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
# 需要訓練的網絡層參數 這里就是最后的全連接層
TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
接下來我們定義幾個輔助方法。首先因為我們的數據存在TFRecord里,需要定義方法從TFRecord解析數據。
def parse(record):
features = tf.parse_single_example(
record,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'pixels': tf.FixedLenFeature([], tf.int64)
}
)
#decode_raw用于解析TFRecord里面的字符串
decoded_image = tf.decode_raw(features['image_raw'], tf.uint8)
label = features['label']
#要注意這里的decoded_image并不能直接進行reshape操作
#之前我們在存儲的時候,把圖片進行了tostring()操作
#這會導致圖片的長度在原來基礎上*8
#后面我們要用到numpy的fromstring來處理
return decoded_image, label
接下來定義兩個方法。因為我們已經下載了谷歌訓練好的inception_v3模型的參數,下面我們需要定義兩個方法從里面加載參數。
#直接從inception_v3.ckpt中讀取的參數
def get_tuned_variables():
#strip刪除頭尾字符,默認為空格
exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")]
variables_to_restore = []
#這里給出了所有slim模型下的參數
for var in slim.get_model_variables():
excluded = False
for exclusion in exclusions:
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
return variables_to_restore
#需要重新訓練的參數
def get_trainable_variables():
#strip刪除頭尾字符,默認為空格
scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")]
variables_to_train = []
# 枚舉所有需要訓練的參數前綴,并通過這些前綴找到所有的參數。
for scope in scopes:
#從TRAINABLE_VARIABLES集合中獲取名為scope的變量
#也就是我們需要重新訓練的參數
variables = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES, scope)
variables_to_train.extend(variables)
return variables_to_train
這里我們就寫完了所需要的工具函數,接下來我們定義主函數。主函數主要完成數據讀取,模型定義,通過模型得出前向傳播結果,通過損失函數計算損失,最后把損失交給優化器做處理。首先我們先來完成數據讀取的代碼,這里我們使用的是TensorFlow高層API Dataset。不清楚的可以去看一下Dataset的用法。
這里我們在訓練的同時也對模型做了驗證。所以我們需要加載訓練和驗證數據
#讀取測試數據
#利用TFRecordDataset讀取TFRecord文件
dataset = tf.data.TFRecordDataset([INPUT_DATA])
#解析TFRecord
dataset = dataset.map(parse)
#把數據打亂順序并組裝成batch
dataset = dataset.shuffle(shuffle_buffer).batch(BATCH)
#定義數據重復的次數
NUM_EPOCHS = 10
dataset = dataset.repeat(NUM_EPOCHS)
#定義迭代器來獲取處理后的數據
iterator = dataset.make_one_shot_iterator()
#迭代器開始迭代
img, label = iterator.get_next()
#讀取驗證數據(同上)
valida_dataset = tf.data.TFRecordDataset([VALIDATION_DATA])
valida_dataset = valida_dataset.map(parse)
valida_dataset = valida_dataset.batch(BATCH)
valida_iterator = valida_dataset.make_one_shot_iterator()
valida_img,valida_label = valida_iterator.get_next()
#定義inception-v3的輸入,images為輸入圖片,label為每一張圖片對應的標簽
#再解釋下每一個維度 None為batch的大小,299為圖片大小,3為通道
images = tf.placeholder(tf.float32,[None,299,299,3],name='input_images')
labels = tf.placeholder(tf.int64,[None],name='labels')
要注意上述定義的只是tensorflow的張量,保存的只是計算過程并沒有具體的數據。只有運行session之后才會拿到具體的數據。
下面我們來通過slim加載inception-v3模型
#定義inception-v3模型結構 inception_v3.ckpt里只有參數的取值
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
#logits inception_v3前向傳播得到的結果
logits,_ = inception_v3.inception_v3(images,num_classes=N_CLASSES)
#獲取需要訓練的變量
trainable_variables = get_trainable_variables()
#這里用交叉熵作為損失函數,注意一下tf.losses.softmax_cross_entropy的參數
# tf.losses.softmax_cross_entropy(
# onehot_labels, # 注意此處參數名就叫 onehot_labels
# logits,
# weights=1.0,
# label_smoothing=0,
# scope=None,
# loss_collection=tf.GraphKeys.LOSSES,
# reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
# )
#這里要把labels轉成one_hot類型,logits就是神經網絡的輸出
tf.losses.softmax_cross_entropy(tf.one_hot(labels,N_CLASSES),logits,weights=1.0)
#把計算的損失交給優化器處理
train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss())
#計算正確率。
with tf.name_scope('evaluation'):
correct_prediction = tf.equal(tf.argmax(logits,1),labels)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#定義加載模型的函數
load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE,get_tuned_variables(),ignore_missing_vars=True)
#定義保存新的訓練好的模型的函數
saver = tf.train.Saver()
with tf.Session() as sess:
#初始化所有變量
init = tf.global_variables_initializer()
sess.run(init)
print('Loading tuned variables from %s'%CKPT_FILE)
#加載谷歌已經訓練好的模型
load_fn(sess)
step = 0;
#在這里我們用一個while來循環訓練,直到dataset里沒有數據就結束循環
while True:
try:
if step % 30 == 0 or step + 1 == STEPS:
#每30輪輸出一次正確率
if step != 0:
#每30輪保存一次當前模型的參數,以便中途訓練中斷可以繼續
saver.save(sess,TRAIN_FILE,global_step=step)
#運行session拿到真實圖片的數據
valida_img_batch,valida_label_batch = sess.run([valida_img,valida_label])
#上面有提到TFRecord里圖片數據被轉成了string,在這里轉回來
valida_img_batch = np.fromstring(valida_img_batch, dtype=np.float32)
#把圖片張量拉成新的維度
valida_img_batch = tf.reshape(valida_img_batch, [32, 299, 299, 3])
#用session運行上述操作,得到處理后的圖片張量
valida_img_batch = sess.run(valida_img_batch)
#把圖片張量傳到feed_dict算出正確率并顯示
validation_accuracy = sess.run(evaluation_step,feed_dict={
images:valida_img_batch,
labels:valida_label_batch
})
print('Step %d: Validation accurary = %.1f%%'%(step,validation_accuracy*100.0))
#下面是對訓練數據的操作,同上
img_batch,label_batch = sess.run([img,label])
img_batch = np.fromstring(img_batch, dtype=np.float32)
img_batch = tf.reshape(img_batch, [32,299, 299, 3])
img_batch = sess.run(img_batch)
sess.run(train_step,feed_dict={
images:img_batch,
labels:label_batch
})
#step僅僅用于記錄
step = step + 1
except tf.errors.OutOfRangeError:
break
運行上述程序開始訓練。在這里我暫時是使用cpu進行訓練,訓練過程大約3小時,可以得到類型下面的結果。
step 0:Validation accuracy = 12.5%
step 30:Validation accuracy = 22.2%
step 60:Validation accuracy = 63.2%
step 90:Validation accuracy = 79.8%
step 120:Validation accuracy = 86.4%
step 150:Validation accuracy = 88.5%
.....
以上就是我使用谷歌Inception-v3模型訓練新的數據集的全部內容。