如何保存和使用訓練好的模型參數
引言
最近在學習Tensorflow 構建CNN,訓練需要耗費時間,把訓練好的各個參數保存下來是最簡便的,網上有很多教程,但是跟著教程走不一定一帆風順還是踩了一些坑,然后自己填了一下坑
如何保存好訓練的結果:
假設會話為sess,計算圖為graph
網上看了很多資料,使用
saver=tf.train.Saver() # 不傳入參數代表默認存入全部參數
file_name = 'saved_model/model.ckpt' # 將保存到當前目錄下的的saved_model文件夾下model.ckpt文件
saver.saver(sess,file_name ) # 保存好的模型文件
這樣來保存模型,這樣就行了嗎?還不行
但是按照這個方式,開始IDE總是會報錯:No Variable to save
然后我的想法是:可能要把sess傳給saver是嗎?下面是我的代碼,和一次嘗試性的修改
graph = tf.Graph() # 計算圖
with graph.as_default():
# 定義計算圖
...
# 以上是一些權重和卷積層的定義,這里就不貼出來了
sess = tf.Session(graph=graph) # 把上一步定義的計算圖載入到會話中
# 給定義saver一個sess作為輸入,結果也是不行的
saver=tf.train.Saver(sess) # 不傳入參數代表默認存入全部參數
saver.saver(sess,'saved_model/model.ckpt')
這樣還是會報錯:沒有可以被用來保存的變量。我思來想去,可能要載入graph?
graph = tf.Graph()
with graph.as_default():
# 定義計算圖
...
# ---
sess = tf.Session(graph=graph) # 把上一步定義的計算圖載入到會話中
# 這次把graph傳進來做參數
saver = tf.train.Saver(graph)
saver.saver(sess,'saved_model/model.ckpt')
結果還是不行
多次嘗試之后,終于修改正確了
graph = tf.Graph()
with graph.as_default():
# 定義計算圖
...
# ---
saver = tf.train.Saver() # 默認存儲上面全部定義參數,如果不想全部存下來,也可以輸入你想要保存的參數
sess = tf.Session(graph=graph) # 把上一步定義的計算圖載入到會話中
# 保存,這次就可以成功了
saver.saver(sess,'saved_model/model.ckpt')
保存成功,文件夾saved_model下會出現幾個文件
checkpoint文件(這個文件很重要,記錄了) 還有幾個文件,它們的后綴分別是 .data , .index , .meta 。我們似乎可以不搭理這三個文件
先小結一下
如果定義了子圖,或者說自己定義graph而不使用tensorflow的默認計算圖的時候,定義要在定義graph最后進行定義,想保存哪個子圖的變量,就要在哪個子圖定義相關的Saver,這樣才能實現想要的效果。
Tensorflow的圖graph和會話session還是有點抽象的,一不小心就整亂了。
Notes: 假設沒有使用with graph.as_graph():這種結構,直接在腳本上定義了tensorflow的變量,再使用saver()應該是沒有問題的。
讀取和使用保存好的模型參數
那存好了之后,該怎么調用了
假設我需要在另一個新的腳本,例如 :test.py文件上使用我的代碼做測試,要怎么使用保存好的模型參數呢
有兩種方式:
先把之前訓練,構建計算圖已經你定義網絡參數的那些代碼粘貼到test.py文件下
graph = tf.Graph()
with graph.as_default():
# 定義計算圖
...
# ---
saver = tf.train.Saver()
特別注意,運行到上一步的時候
然后再使用如下代碼
with tf.Session(graph=graph) as sess:
check_point_path = 'saved_model/' # 保存好模型的文件路徑
ckpt = tf.train.get_checkpoint_state(checkpoint_dir=check_point_path)
# 從模型中恢復參數 saver.restore(sess,ckpt.model_checkpoint_path) # 讀取成功,然后就可以使用模型參數進行預測,或者測試了。
如果你覺得上面那個方式有點繁瑣,可以直接import train.py
假設train.py這個腳本的代碼是這樣的:
graph = tf.Graph()
with graph.as_default():
# 定義計算圖
...
# ---
saver = tf.train.Saver()
那么你在你的test.py中可以這樣寫
import train
# 參考python的命名空間方法
graph = train.graph
sess = train.sess
with tf.Session(graph=graph) as sess:
check_point_path = 'saved_model/' # 保存好模型的文件路徑
ckpt = tf.train.get_checkpoint_state(checkpoint_dir=check_point_path)
# 從模型中恢復參數 saver.restore(sess,ckpt.model_checkpoint_path)
Notes: 在使用saver.restore(sess,ckpt.model_checkpoint_path)
后,這個時候,就不需要再次使用sess.run(init) 對參數進行初始化了(否則會覆蓋掉訓練好的參數),如果你在前面使用run,進行初始化,權重會根據你的定義進行初始化,但是你使用這個語句后,模型中的參數會把它覆蓋掉
最后再說一下
好像高級的使用方法,可以根據選擇不同迭代次數更新時候的權重,這里只做簡單總結一下,以后學習到了再更新吧