Tensorflow模型的保存與恢復(fù)

在這篇tensorflow教程中,我會解釋:

1) Tensorflow的模型(model)長什么樣子?

2) 如何保存tensorflow的模型?

3) 如何恢復(fù)一個tensorflow模型來用于預(yù)測或者遷移學(xué)習(xí)?

4) 如何使用預(yù)訓(xùn)練好的模型(imported pretrained models)來用于fine-tuning和?modification

1. Tensorflow模型是什么?

當(dāng)你已經(jīng)訓(xùn)練好一個神經(jīng)網(wǎng)絡(luò)之后,你想要保存它,用于以后的使用,部署到產(chǎn)品里面去。所以,Tensorflow模型是什么?Tensorflow模型主要包含網(wǎng)絡(luò)的設(shè)計(jì)或者圖(graph),和我們已經(jīng)訓(xùn)練好的網(wǎng)絡(luò)參數(shù)的值。因此Tensorflow模型有兩個主要的文件:

A)?Meta graph:

這是一個保存完整Tensorflow graph的protocol buffer,比如說,所有的?variables, operations, collections等等。這個文件的后綴是.meta。

B)?Checkpoint file:

這是一個包含所有權(quán)重(weights),偏置(biases),梯度(gradients)和所有其他保存的變量(variables)的二進(jìn)制文件。它包含兩個文件:

mymodel.data-00000-of-00001

mymodel.index

其中,.data文件包含了我們的訓(xùn)練變量。

另外,除了這兩個文件,Tensorflow有一個叫做checkpoint的文件,記錄著已經(jīng)最新的保存的模型文件。

:Tensorflow 0.11版本以前,Checkpoint file只有一個后綴名為.ckpt的文件。

?因此,總結(jié)來說,Tensorflow(版本0.10以后)模型長這個樣子:

? ? ? ?Tensorflow版本0.11以前,只包含以下三個文件:

inception_v1.meta

inception_v1.ckpt

checkpoint


?????? 接下來說明如何保存模型。


2. 保存一個Tensorflow模型

當(dāng)網(wǎng)絡(luò)訓(xùn)練結(jié)束時,我們要保存所有變量和網(wǎng)絡(luò)結(jié)構(gòu)體到文件中。在Tensorflow中,我們可以創(chuàng)建一個tf.train.Saver()?類的實(shí)例,如下:

saver = tf.train.Saver()


由于Tensorflow變量僅僅只在session中存在,因此需要調(diào)用save方法來將模型保存在一個session中。

saver.save(sess,'my-test-model')

在這里,sess是一個session對象,其中my-test-model是你給模型起的名字。下面是一個完整的例子:


import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')

w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')

saver = tf.train.Saver()

sess = tf.Session()

sess.run(tf.global_variables_initializer())

saver.save(sess, 'my_test_model')# This will save following files in Tensorflow v >= 0.11# my_test_model.data-00000-of-00001# my_test_model.index# my_test_model.meta# checkpoint

如果我們想在訓(xùn)練1000次迭代之后保存模型,可以使用如下方法保存

saver.save(sess,'my_test_model',global_step=1000)

這個將會在模型名字的后面追加上‘-1000’,下面的文件將會被創(chuàng)建:

my_test_model-1000.index

my_test_model-1000.meta

my_test_model-1000.data-00000-of-00001

checkpoint


由于網(wǎng)絡(luò)的圖(graph)在訓(xùn)練的時候是不會改變的,因此,我們沒有必要每次都重復(fù)保存.meta文件,可以使用如下方法:


saver.save(sess,'my-model',global_step=step,write_meta_graph=False)

如果你只想要保存最新的4個模型,并且想要在訓(xùn)練的時候每2個小時保存一個模型,那么你可以使用max_to_keep 和 keep_checkpoint_every_n_hours,如下所示:


#saves a model every 2 hours and maximum 4 latest models are saved.saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

注意到,我們在tf.train.Saver()中并沒有指定任何東西,因此它將保存所有變量。如果我們不想保存所有的變量,只想保存其中一些變量,我們可以在創(chuàng)建tf.train.Saver實(shí)例的時候,給它傳遞一個我們想要保存的變量的list或者字典。示例如下:


import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')

w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')

saver = tf.train.Saver([w1,w2])

sess = tf.Session()

sess.run(tf.global_variables_initializer())

saver.save(sess, 'my_test_model',global_step=1000)


3. 導(dǎo)入一個已經(jīng)訓(xùn)練好的模型

如果你想要使用別人已經(jīng)訓(xùn)練好的模型來fine-tuning,那么你需要做兩個步驟:

A)創(chuàng)建網(wǎng)絡(luò)Create the network:

?????? 你可以通過寫python代碼,來手動地創(chuàng)建每一個、每一層,使得跟原始網(wǎng)絡(luò)一樣。

但是,如果你仔細(xì)想的話,我們已經(jīng)將模型保存在了.meta文件中,因此我們可以使用tf.train.import()函數(shù)來重新創(chuàng)建網(wǎng)絡(luò),使用方法如下:

saver = tf.train.import_meta_graph('my_test_model-1000.meta')

? ? ? ?注意,這僅僅是將已經(jīng)定義的網(wǎng)絡(luò)導(dǎo)入到當(dāng)前的graph中,但是我們還是需要加載網(wǎng)絡(luò)的參數(shù)值。


B)加載參數(shù)Load the parameters

?????? 我們可以通過調(diào)用restore函數(shù)來恢復(fù)網(wǎng)絡(luò)的參數(shù),如下:

with tf.Session() as sess:

? new_saver = tf.train.import_meta_graph('my_test_model-1000.meta')

? new_saver.restore(sess, tf.train.latest_checkpoint('./'))

在這之后,像w1和w2的tensor的值已經(jīng)被恢復(fù),并且可以獲取到:

with tf.Session() as sess:? ?

? ? saver = tf.train.import_meta_graph('my-model-1000.meta')

? ? saver.restore(sess,tf.train.latest_checkpoint('./'))

? ? print(sess.run('w1:0'))##Model has been restored. Above statement will print the saved value of w1.

? ? ? ?上面介紹了如何保存和恢復(fù)一個Tensorflow模型。下面介紹一個加載任何預(yù)訓(xùn)練模型的實(shí)用方法。



4. Working with restored models

下面介紹如何恢復(fù)任何一個預(yù)訓(xùn)練好的模型,并使用它來預(yù)測,fine-tuning或者進(jìn)一步訓(xùn)練。當(dāng)你使用Tensorflow時,你會定義一個圖(graph),其中,你會給這個圖喂(feed)訓(xùn)練數(shù)據(jù)和一些超參數(shù)(比如說learning rate,global step等)。下面我們使用placeholder建立一個小的網(wǎng)絡(luò),然后保存該網(wǎng)絡(luò)。注意到,當(dāng)網(wǎng)絡(luò)被保存時,placeholder的值并不會被保存。

import tensorflow as tf#Prepare to feed input, i.e. feed_dict and placeholdersw1 = tf.placeholder("float", name="w1")

w2 = tf.placeholder("float", name="w2")

b1= tf.Variable(2.0,name="bias")

feed_dict ={w1:4,w2:8}#Define a test operation that we will restorew3 = tf.add(w1,w2)

w4 = tf.multiply(w3,b1,name="op_to_restore")

sess = tf.Session()

sess.run(tf.global_variables_initializer())#Create a saver object which will save all the variablessaver = tf.train.Saver()#Run the operation by feeding inputprint sess.run(w4,feed_dict)#Prints 24 which is sum of (w1+w2)*b1 #Now, save the graphsaver.save(sess,'my_test_model',global_step=1000)

現(xiàn)在,我們想要恢復(fù)這個網(wǎng)絡(luò),我們不僅需要恢復(fù)圖(graph)和權(quán)重,而且也需要準(zhǔn)備一個新的feed_dict,將新的訓(xùn)練數(shù)據(jù)喂給網(wǎng)絡(luò)。我們可以通過使用graph.get_tensor_by_name()方法來獲得已經(jīng)保存的操作(operations)和placeholder variables。

#How to access saved variable/Tensor/placeholders w1 = graph.get_tensor_by_name("w1:0")## How to access saved operationop_to_restore = graph.get_tensor_by_name("op_to_restore:0")

如果我們僅僅想要用不同的數(shù)據(jù)運(yùn)行這個網(wǎng)絡(luò),可以簡單的使用feed_dict來將新的數(shù)據(jù)傳遞給網(wǎng)絡(luò)。

import tensorflow as tf

sess=tf.Session()? ? #First let's load meta graph and restore weightssaver = tf.train.import_meta_graph('my_test_model-1000.meta')

saver.restore(sess,tf.train.latest_checkpoint('./'))# Now, let's access and create placeholders variables and# create feed-dict to feed new datagraph = tf.get_default_graph()

w1 = graph.get_tensor_by_name("w1:0")

w2 = graph.get_tensor_by_name("w2:0")

feed_dict ={w1:13.0,w2:17.0}#Now, access the op that you want to run. op_to_restore = graph.get_tensor_by_name("op_to_restore:0")print sess.run(op_to_restore,feed_dict)#This will print 60 which is calculated #using new values of w1 and w2 and saved value of b1.

如果你想要給graph增加更多的操作(operations)然后訓(xùn)練它,可以像如下那么做:

import tensorflow as tf

sess=tf.Session()? ? #First let's load meta graph and restore weightssaver = tf.train.import_meta_graph('my_test_model-1000.meta')

saver.restore(sess,tf.train.latest_checkpoint('./'))# Now, let's access and create placeholders variables and# create feed-dict to feed new datagraph = tf.get_default_graph()

w1 = graph.get_tensor_by_name("w1:0")

w2 = graph.get_tensor_by_name("w2:0")

feed_dict ={w1:13.0,w2:17.0}#Now, access the op that you want to run. op_to_restore = graph.get_tensor_by_name("op_to_restore:0")#Add more to the current graphadd_on_op = tf.multiply(op_to_restore,2)print sess.run(add_on_op,feed_dict)#This will print 120.


但是,你可以只恢復(fù)舊的graph的一部分,然后插入一些操作用于fine-tuning?當(dāng)然可以。僅僅需要通過?by graph.get_tensor_by_name()?方法來獲取合適的operation,然后在這上面建立graph。下面是一個實(shí)際的例子,我們使用meta graph?加載了一個預(yù)訓(xùn)練好的vgg模型,并且在最后一層將輸出個數(shù)改成2,然后用新的數(shù)據(jù)fine-tuning。

......

......

saver = tf.train.import_meta_graph('vgg.meta')# Access the graphgraph = tf.get_default_graph()## Prepare the feed_dict for feeding data for fine-tuning #Access the appropriate output for fine-tuningfc7= graph.get_tensor_by_name('fc7:0')#use this if you only want to change gradients of the last layerfc7 = tf.stop_gradient(fc7)# It's an identity functionfc7_shape= fc7.get_shape().as_list()

new_outputs=2weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))

biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))

output = tf.matmul(fc7, weights) + biases

pred = tf.nn.softmax(output)# Now, you run this with fine-tuning data in sess.run()

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 228,983評論 6 537
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 98,772評論 3 422
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事?!?“怎么了?”我有些...
    開封第一講書人閱讀 176,947評論 0 381
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經(jīng)常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,201評論 1 315
  • 正文 為了忘掉前任,我火速辦了婚禮,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 71,960評論 6 410
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 55,350評論 1 324
  • 那天,我揣著相機(jī)與錄音,去河邊找鬼。 笑死,一個胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,406評論 3 444
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 42,549評論 0 289
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 49,104評論 1 335
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 40,914評論 3 356
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發(fā)現(xiàn)自己被綠了。 大學(xué)時的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 43,089評論 1 371
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,647評論 5 362
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 44,340評論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,753評論 0 28
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 36,007評論 1 289
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 51,834評論 3 395
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 48,106評論 2 375

推薦閱讀更多精彩內(nèi)容