Tensorflow使用筆記(1): Tensorflow的模型保存和使用

如何保存和使用訓練好的模型參數

引言

最近在學習Tensorflow 構建CNN,訓練需要耗費時間,把訓練好的各個參數保存下來是最簡便的,網上有很多教程,但是跟著教程走不一定一帆風順還是踩了一些坑,然后自己填了一下坑

如何保存好訓練的結果:

假設會話為sess,計算圖為graph
網上看了很多資料,使用

saver=tf.train.Saver()  # 不傳入參數代表默認存入全部參數
file_name = 'saved_model/model.ckpt'  # 將保存到當前目錄下的的saved_model文件夾下model.ckpt文件
saver.saver(sess,file_name )  # 保存好的模型文件

這樣來保存模型,這樣就行了嗎?還不行
但是按照這個方式,開始IDE總是會報錯:No Variable to save


然后我的想法是:可能要把sess傳給saver是嗎?下面是我的代碼,和一次嘗試性的修改

graph = tf.Graph()  # 計算圖
with graph.as_default():
    # 定義計算圖
    ...
    # 以上是一些權重和卷積層的定義,這里就不貼出來了
sess = tf.Session(graph=graph)  # 把上一步定義的計算圖載入到會話中
# 給定義saver一個sess作為輸入,結果也是不行的
saver=tf.train.Saver(sess)  # 不傳入參數代表默認存入全部參數
saver.saver(sess,'saved_model/model.ckpt')

這樣還是會報錯:沒有可以被用來保存的變量。我思來想去,可能要載入graph?

graph = tf.Graph()
with graph.as_default():
    # 定義計算圖
    ... 
    # ---
sess = tf.Session(graph=graph)  # 把上一步定義的計算圖載入到會話中
# 這次把graph傳進來做參數
saver = tf.train.Saver(graph)
saver.saver(sess,'saved_model/model.ckpt')

結果還是不行


多次嘗試之后,終于修改正確了

graph = tf.Graph()
with graph.as_default():
    # 定義計算圖
    ... 
    # ---
    saver = tf.train.Saver()  # 默認存儲上面全部定義參數,如果不想全部存下來,也可以輸入你想要保存的參數

sess = tf.Session(graph=graph)  # 把上一步定義的計算圖載入到會話中
# 保存,這次就可以成功了
saver.saver(sess,'saved_model/model.ckpt')

保存成功,文件夾saved_model下會出現幾個文件
checkpoint文件(這個文件很重要,記錄了) 還有幾個文件,它們的后綴分別是 .data , .index , .meta 。我們似乎可以不搭理這三個文件


先小結一下

如果定義了子圖,或者說自己定義graph而不使用tensorflow的默認計算圖的時候,定義要在定義graph最后進行定義,想保存哪個子圖的變量,就要在哪個子圖定義相關的Saver,這樣才能實現想要的效果。
Tensorflow的圖graph和會話session還是有點抽象的,一不小心就整亂了。


Notes: 假設沒有使用with graph.as_graph():這種結構,直接在腳本上定義了tensorflow的變量,再使用saver()應該是沒有問題的。


讀取和使用保存好的模型參數

那存好了之后,該怎么調用了
假設我需要在另一個新的腳本,例如 :test.py文件上使用我的代碼做測試,要怎么使用保存好的模型參數呢

有兩種方式:


先把之前訓練,構建計算圖已經你定義網絡參數的那些代碼粘貼到test.py文件下

graph = tf.Graph()
with graph.as_default():
    # 定義計算圖
    ... 
    # ---
    saver = tf.train.Saver()

特別注意,運行到上一步的時候

然后再使用如下代碼

with tf.Session(graph=graph) as sess:
        check_point_path = 'saved_model/' # 保存好模型的文件路徑
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir=check_point_path)

# 從模型中恢復參數      saver.restore(sess,ckpt.model_checkpoint_path)  # 讀取成功,然后就可以使用模型參數進行預測,或者測試了。

如果你覺得上面那個方式有點繁瑣,可以直接import train.py
假設train.py這個腳本的代碼是這樣的:

graph = tf.Graph()
with graph.as_default():
    # 定義計算圖
    ... 
    # ---
    saver = tf.train.Saver()

那么你在你的test.py中可以這樣寫

import train
# 參考python的命名空間方法
graph = train.graph
sess = train.sess

with tf.Session(graph=graph) as sess:
        check_point_path = 'saved_model/' # 保存好模型的文件路徑
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir=check_point_path)

# 從模型中恢復參數      saver.restore(sess,ckpt.model_checkpoint_path)  


Notes: 在使用saver.restore(sess,ckpt.model_checkpoint_path) 后,這個時候,就不需要再次使用sess.run(init) 對參數進行初始化了(否則會覆蓋掉訓練好的參數),如果你在前面使用run,進行初始化,權重會根據你的定義進行初始化,但是你使用這個語句后,模型中的參數會把它覆蓋掉


最后再說一下

好像高級的使用方法,可以根據選擇不同迭代次數更新時候的權重,這里只做簡單總結一下,以后學習到了再更新吧

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容