Tensorflow: how to save/restore a model?
PS? 馬上要鎖門了,先把代碼 貼出來,.
一、入門
Question:
After you train a model in Tensorflow:
1. How do you save the trained model?
2. How do you later restore this saved model?
程序設計目標:
使用saver.save() 保存簡單的模型
并使用
saver = tf.train.import_meta_graph('保存的模型文件')
saver.restore(sess,tf.train.latest_checkpoint('指定CKPT文件'))
Save Model
import tensorflow as tf
# 定義要占位的變量 , i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}
# 定義要保存的操作
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# 創建saver 對象保存
saver = tf.train.Saver()
# 運行
print(sess.run(w4,feed_dict))
#Prints 24 which is sum of (w1+w2)*b1
#保存圖 my_test_model是指定保存模型的路徑
saver.save(sess, './my_test_model',global_step=1000)
Restore the model:
import tensorflow as tf
sess=tf.Session()
# 首先加載模型 ,注意這里要說明加載的文件
saver = tf.train.import_meta_graph('./my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# 訪問模型的中的變量
print(sess.run('bias:0'))
# 會打印出 2, 這個是在上一段程序中保存過的變量
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated