TensorFlow利用卷積神經網絡在谷歌inception_v3模型基礎上解決花朵分類問題

本篇更多的是在代碼實戰方向,不會涉及太多的理論。本文主要針對TensorFlow和卷積神經網絡有一定基礎的同學,并對圖像處理有一定的了解。

閱讀本文你大概需要以下知識:

1.TensorFlow基礎
2.TensorFlow實現卷積神經網絡的前向傳播過程
3.TFRecord數據格式
4.Dataset的使用
5.Slim的使用

好了廢話不多說,下面開始。

一.數據準備

首先我們需要有一個讓我們訓練的數據集,這里谷歌已經幫我們做好了。這里要把數據集下載下來,打開命令行,執行如下命令:

wget http://download.tensorflow.org/example_image/flower_photo.tgz
//解壓
tar xzf flower_photos.tgz

這里需要注意的是,文件最好是下載到你的工程目錄下方便你的讀取。什么?你還不會搭建TensorFlow程序?請移步https://www.tensorflow.org/install/
選擇自己的操作系統,在這里我的是macOS。我使用的是Virtualenv來搭建TensorFlow運行環境。
數據集下載并解壓后,我們可以看到大概是這個樣子

每一個文件夾里都是一個種類的花的圖片,這里總共有五種花。
好了,數據有了?接下來該怎么辦呢?當然是把數據進行預處理拉,你不會覺得我們的TensorFlow可以直接識別這些圖片進行訓練吧,hhhhhh。

二.數據預處理

接下來我們在目錄下新建pre_data.python文件。TensorFlow對圖片做處理一般是生成TFRecord文件。什么是TFRecord?后面我們會講到。

首先我們要引入我們需要的庫。

# glob模塊的主要方法就是glob,該方法返回所有匹配的文件路徑列表(list)
import glob
#os.path生成路徑方便glob獲取
import os.path
#這里主要用到隨機數
import numpy as np
#引入tensorflow框架
import tensorflow as tf
#引入gflie對圖片做處理
from tensorflow.python.platform import gfile

相關庫在我們這個程序中的功能都作了簡單介紹,下面用到的時候我們會更加詳細的說明。

大家都知道我們的數據集一般分訓練,測試和驗證數據集。觀察上面的數據集,谷歌只是給出了每一種花的圖片,并沒有給去哪些我訓練,哪些是測試,哪些是驗證數據集。所以在這里我們要進行劃分。

#輸入圖片地址
INPUT_DATA = '../../flower_photos'
#訓練數據集
OUTPUT_FILE = './path/to/output.tfrecords'
#測試數據集
OUTPUT_TEST_FILE = './path/to/output_test.tfrecords'
#驗證數據集
OUTPUT_VALIDATION_FILE = './path/to/output_validation.tfrecords'
#測試數據和驗證數據的比例
VALIDATION_PERCENTAGE = 10
TEST_PERCENTAGE = 10

關于VALIDATION_PERCENTAGE和TEST_PERCENTAGE這兩個常量,我們在后面的例子會給出。

下面我們就來定義處理數據的方法:

def create_image_lists(sess,testing_percentage,validation_percentage):
    #拿到INPUT_DATA文件夾下的所有目錄(包括root)
    sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
    #如果是root_dir不需要做處理
    is_root_dir = True
    #定義圖片對應的標簽,從0-4分別代表不同的花
    current_label = 0
    #寫入TFRecord的數據需要首先定義writer
    #這里定義三個writer分別存儲訓練,測試和驗證數據
    writer = tf.python_io.TFRecordWriter(OUTPUT_FILE)
    writer_test = tf.python_io.TFRecordWriter(OUTPUT_TEST_FILE)
    writer_validation = tf.python_io.TFRecordWriter(OUTPUT_VALIDATION_FILE)
    #循環目錄
    for sub_dir in sub_dirs:
        if is_root_dir:
            #跳過根目錄
            is_root_dir = False
            continue
        #定義空數組來裝圖片路徑
        file_list = []
        #生成查找路徑
        dir_name = os.path.basename(sub_dir)
        file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + "jpg")
        # extend合并兩個數組
        # glob模塊的主要方法就是glob,該方法返回所有匹配的文件路徑列表(list)
        # 比如:glob.glob(r’c:*.txt’) 這里就是獲得C盤下的所有txt文件
        file_list.extend(glob.glob(file_glob))
        #路徑下沒有文件就跳過,不繼續操作
        if not file_list: continue
        #這里我定義index來打印當前進度
        index = 0
        #file_list此時是圖片路徑列表
        for file_name in file_list:
            #使用gfile從路徑中讀取圖片
            image_raw_data = gfile.FastGFile(file_name, 'rb').read()
            #對圖像解碼,解碼結果為一個張量
            image = tf.image.decode_jpeg(image_raw_data)

            #對圖像矩陣進行歸一化處理
            #因為為了將圖片數據能夠保存到 TFRecord 結構體中
            #所以需要將其圖片矩陣轉換成 string
            #所以為了在使用時能夠轉換回來
            #這里確定下數據格式為 tf.float32  
            if image.dtype != tf.float32:
                image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            # 將圖片轉化成299*299方便模型處理
            image = tf.image.resize_images(image, [299, 299])
            #為了拿到圖片的真實數據這里我們要運行一個session op
            image_value = sess.run(image)
           
            pixels = image_value.shape[1]
            #存儲在TFrecord里面的不能是array的形式
            #所以我們需要利用tostring()將上面的矩陣
            #轉化成字符串
            #再通過tf.train.BytesList轉化成可以存儲的形式
            image_raw = image_value.tostring()

            #存到features
            #隨機劃分測試集和訓練集
            #這里存入TFRecord三個數據,圖像的pixels像素
            #圖像原張量,這里我們需要轉成string
            #以及當前圖像對應的標簽
            example = tf.train.Example(features=tf.train.Features(feature={
                'pixels': _int64_feature(pixels),
                'label': _int64_feature(current_label),
                'image_raw': _bytes_feature(image_raw)
            }))
            chance = np.random.randint(100)
            #隨機劃分數據集
            if chance < validation_percentage:
                writer_validation.write(example.SerializeToString())
            elif chance < (testing_percentage+validation_percentage):
                writer_test.write(example.SerializeToString())
            else:
                writer.write(example.SerializeToString())
            # print('example',index)
            index = index + 1

        #每一個文件夾下的所有圖片都是一個類別
        #所以這里每遍歷完一個文件夾,標簽就增加1
        current_label += 1

    writer.close()
    writer_validation.close()
    writer_test.close()

運行上述程序需要一定時間,我的電腦比較爛,大概跑了三十分鐘左右。這時候在你的./path/to目錄下可以看到output.tfrecords,output_test.tfrecords,output_validation.tfrecords三個文件,分別存放了訓練,測試和驗證數據集。上述代碼將所有圖片劃分成訓練、驗證和測試數據集。并且把圖片從原始的jpg格式轉換成inception-v3模型需要的299 * 299 * 3的數字矩陣。在數據處理完畢之后,通過以下命令可以下載谷歌提供好的Inception_v3模型。

wget http://download.tensorflow.org/models/inception_v3_2016_08_26.tar.gz
//解壓之后可以得到訓練好的模型文件inception_v3.ckpt
tar xzf inception_v3_2016_08

二.訓練

當新的數據集和已經訓練好的模型都準備好之后,我們來寫代碼在谷歌inception_v3的基礎上訓練新數據集。

首先同樣我們導入相關的庫并且定義相關常量。在這里我們通過slim工具來直接加載模型,而不用自己再定義前向傳播過程。

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
# 加載通過TensorFlow-Silm定義好的 inception_v3模型
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3

# 輸入數據文件
INPUT_DATA = './path/to/output.tfrecords'
# 驗證數據集
VALIDATION_DATA = './path/to/output_validation.tfrecords'
# 保存訓練好的模型的路徑
ls = './path/to/save_model'
# 谷歌提供的訓練好的模型文件地址
CKPT_FILE = './path/to/inception_v3.ckpt'
TRAIN_FILE = './path/to/save_model'

# 定義訓練中使用的參數
LEARNING_RATE = 0.01
#組合batch的大小
BATCH = 32

#用于one_hot函數輸出概率分布
N_CLASSES = 5
#打亂順序,并設置出隊和入隊中元素最少的個數,這里是10000個
shuffle_buffer = 10000

# 不需要從谷歌模型中加載的參數,這里就是最后的全連接層。因為輸出類別不一樣,所以最后全連接層的參數也不一樣
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
# 需要訓練的網絡層參數 這里就是最后的全連接層
TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'

接下來我們定義幾個輔助方法。首先因為我們的數據存在TFRecord里,需要定義方法從TFRecord解析數據。

def parse(record):
    features = tf.parse_single_example(
        record,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'pixels': tf.FixedLenFeature([], tf.int64)
        }
    )
    #decode_raw用于解析TFRecord里面的字符串
    decoded_image = tf.decode_raw(features['image_raw'], tf.uint8)
    label = features['label']
    #要注意這里的decoded_image并不能直接進行reshape操作
    #之前我們在存儲的時候,把圖片進行了tostring()操作
    #這會導致圖片的長度在原來基礎上*8
    #后面我們要用到numpy的fromstring來處理
    return decoded_image, label

接下來定義兩個方法。因為我們已經下載了谷歌訓練好的inception_v3模型的參數,下面我們需要定義兩個方法從里面加載參數。

#直接從inception_v3.ckpt中讀取的參數
def get_tuned_variables():
    #strip刪除頭尾字符,默認為空格
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(",")]
    variables_to_restore = []
    #這里給出了所有slim模型下的參數
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
            if not excluded:
                variables_to_restore.append(var)
        return variables_to_restore

#需要重新訓練的參數
def get_trainable_variables():
    #strip刪除頭尾字符,默認為空格
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(",")]
    variables_to_train = []
    # 枚舉所有需要訓練的參數前綴,并通過這些前綴找到所有的參數。
    for scope in scopes:
      #從TRAINABLE_VARIABLES集合中獲取名為scope的變量
      #也就是我們需要重新訓練的參數
        variables = tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train

這里我們就寫完了所需要的工具函數,接下來我們定義主函數。主函數主要完成數據讀取,模型定義,通過模型得出前向傳播結果,通過損失函數計算損失,最后把損失交給優化器做處理。首先我們先來完成數據讀取的代碼,這里我們使用的是TensorFlow高層API Dataset。不清楚的可以去看一下Dataset的用法。

這里我們在訓練的同時也對模型做了驗證。所以我們需要加載訓練和驗證數據

#讀取測試數據
    #利用TFRecordDataset讀取TFRecord文件
    dataset = tf.data.TFRecordDataset([INPUT_DATA])
    #解析TFRecord
    dataset = dataset.map(parse)
    #把數據打亂順序并組裝成batch
    dataset = dataset.shuffle(shuffle_buffer).batch(BATCH)
    #定義數據重復的次數
    NUM_EPOCHS = 10
    dataset = dataset.repeat(NUM_EPOCHS)
    #定義迭代器來獲取處理后的數據
    iterator = dataset.make_one_shot_iterator()
    #迭代器開始迭代
    img, label = iterator.get_next()

    #讀取驗證數據(同上)
    valida_dataset = tf.data.TFRecordDataset([VALIDATION_DATA])
    valida_dataset = valida_dataset.map(parse)
    valida_dataset = valida_dataset.batch(BATCH)
    valida_iterator = valida_dataset.make_one_shot_iterator()
    valida_img,valida_label = valida_iterator.get_next()

    #定義inception-v3的輸入,images為輸入圖片,label為每一張圖片對應的標簽
    #再解釋下每一個維度 None為batch的大小,299為圖片大小,3為通道
    images = tf.placeholder(tf.float32,[None,299,299,3],name='input_images')
    labels = tf.placeholder(tf.int64,[None],name='labels')

要注意上述定義的只是tensorflow的張量,保存的只是計算過程并沒有具體的數據。只有運行session之后才會拿到具體的數據。

下面我們來通過slim加載inception-v3模型

 #定義inception-v3模型結構 inception_v3.ckpt里只有參數的取值
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        #logits  inception_v3前向傳播得到的結果
        logits,_ = inception_v3.inception_v3(images,num_classes=N_CLASSES)
        #獲取需要訓練的變量
        trainable_variables = get_trainable_variables()
        #這里用交叉熵作為損失函數,注意一下tf.losses.softmax_cross_entropy的參數
        # tf.losses.softmax_cross_entropy(
        #     onehot_labels,  # 注意此處參數名就叫 onehot_labels
        #     logits,
        #     weights=1.0,
        #     label_smoothing=0,
        #     scope=None,
        #     loss_collection=tf.GraphKeys.LOSSES,
        #     reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
        # )
        #這里要把labels轉成one_hot類型,logits就是神經網絡的輸出        
        tf.losses.softmax_cross_entropy(tf.one_hot(labels,N_CLASSES),logits,weights=1.0)
        #把計算的損失交給優化器處理
        train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss())

        #計算正確率。
        with tf.name_scope('evaluation'):
            correct_prediction = tf.equal(tf.argmax(logits,1),labels)
            evaluation_step = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
        #定義加載模型的函數
        load_fn = slim.assign_from_checkpoint_fn(CKPT_FILE,get_tuned_variables(),ignore_missing_vars=True)
        #定義保存新的訓練好的模型的函數
        saver = tf.train.Saver()
        with tf.Session() as sess:
            #初始化所有變量
            init = tf.global_variables_initializer()
            sess.run(init)
            print('Loading tuned variables from %s'%CKPT_FILE)
            #加載谷歌已經訓練好的模型
            load_fn(sess)
            step = 0;
            #在這里我們用一個while來循環訓練,直到dataset里沒有數據就結束循環
            while True:
                try:
                    if step % 30  == 0 or step + 1 == STEPS:
                      #每30輪輸出一次正確率
                        if step != 0:
                            #每30輪保存一次當前模型的參數,以便中途訓練中斷可以繼續
                            saver.save(sess,TRAIN_FILE,global_step=step)
                       #運行session拿到真實圖片的數據
                        valida_img_batch,valida_label_batch = sess.run([valida_img,valida_label])
                        #上面有提到TFRecord里圖片數據被轉成了string,在這里轉回來
                        valida_img_batch = np.fromstring(valida_img_batch, dtype=np.float32)
                        #把圖片張量拉成新的維度
                        valida_img_batch = tf.reshape(valida_img_batch, [32, 299, 299, 3])
                        #用session運行上述操作,得到處理后的圖片張量
                        valida_img_batch = sess.run(valida_img_batch)
                        #把圖片張量傳到feed_dict算出正確率并顯示
                        validation_accuracy = sess.run(evaluation_step,feed_dict={
                            images:valida_img_batch,
                            labels:valida_label_batch
                        })
                        print('Step %d: Validation accurary = %.1f%%'%(step,validation_accuracy*100.0))
                    #下面是對訓練數據的操作,同上
                    img_batch,label_batch = sess.run([img,label])
                    img_batch = np.fromstring(img_batch, dtype=np.float32)
                    img_batch = tf.reshape(img_batch, [32,299, 299, 3])
                    img_batch = sess.run(img_batch)

                    sess.run(train_step,feed_dict={
                        images:img_batch,
                        labels:label_batch
                    })
                    #step僅僅用于記錄
                    step = step + 1
                except tf.errors.OutOfRangeError:
                    break

運行上述程序開始訓練。在這里我暫時是使用cpu進行訓練,訓練過程大約3小時,可以得到類型下面的結果。

step 0:Validation accuracy = 12.5%
step 30:Validation accuracy = 22.2%
step 60:Validation accuracy = 63.2%
step 90:Validation accuracy = 79.8%
step 120:Validation accuracy = 86.4%
step 150:Validation accuracy = 88.5%
.....

以上就是我使用谷歌Inception-v3模型訓練新的數據集的全部內容。

?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 230,362評論 6 544
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 99,577評論 3 429
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事?!?“怎么了?”我有些...
    開封第一講書人閱讀 178,486評論 0 383
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,852評論 1 317
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 72,600評論 6 412
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 55,944評論 1 328
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,944評論 3 447
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 43,108評論 0 290
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 49,652評論 1 336
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 41,385評論 3 358
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 43,616評論 1 374
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 39,111評論 5 364
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,798評論 3 350
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 35,205評論 0 28
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 36,537評論 1 295
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 52,334評論 3 400
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 48,570評論 2 379

推薦閱讀更多精彩內容