tensorflow 恢復(restore)模型的兩種方式

image

1. 介紹

首先我們要理解TensorFlow的一個規則,首先構建計算圖(graph),然后初始化graph中的data,這兩步是分開的。

2. 如何恢復模型

有兩種方式(這兩種方式有比較大的不同):

2.1 重新使用代碼構建圖

舉個例子(完整代碼):

def build_graph():
    w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32)
    w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32)
    w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3')
    w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4')
    add = tf.add(w1,w2,name='add')
    add1 = tf.add(add,w3,name='add1')
    return w3,add1

with tf.Session() as sess:
    ckpt_state = tf.train.get_checkpoint_state('./temp/')
    if ckpt_state:
        w3,add1=build_graph()
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_state.model_checkpoint_path)
    else:
        w3,add1=build_graph()
        init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
        sess.run(init_op)
        saver = tf.train.Saver()
    a = sess.run(add1,feed_dict={
            w3:[1,2,3,4]
        })
    print(a)
    saver.save(sess,'./temp/model')

上面的流程很簡單,首先build_graph(),然后如果有ckpt文件就從該文件中讀取數據,否則用sess.run(init_op)初始化數據。

那么第一種restore方法就出來了:

build_graph()
saver = tf.train.Saver()
saver.restore(sess, ckpt_state.model_checkpoint_path)

首先build graph,等于是將圖重新建立了一遍,和之前圖的一樣,然后將ckpt文件里的數據restore到圖里的變量里。

當然,在build graph的過程中,你可以在原有的圖里加一些變量,但是加的變量一定要初始化,但是要注意到一個問題,如果使用:

init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)

這種方式時,如果定義init_op時的graph中已經存在原有圖的變量,那么sess.run(init_op)會將加載進來的數據清空。

為了解決這個問題,兩種方式:

  1. 新定義的變量放在init_op之前,在init_op之后restore(注意,加載好變量后才run(init_op)同樣會覆蓋)
    即,init_op得到當前圖中的所有變量,sess.run(init_op)對init_op中的變量進行初始化,所以什么時候定義init_op和什么時候運行run(init_op)都很重要

  2. 只初始化未初始化的變量

def get_uninitialized_variables(sess):
global_vars = tf.global_variables()

# print([str(i.name) for i in global_vars])

is_not_initialized = sess.run([tf.is_variable_initialized(var) for var in global_vars])
not_initialized_vars = [v for (v, f) in zip(global_vars, is_not_initialized) if not f]
print([str(i.name) for i in not_initialized_vars])
return not_initialized_vars
sess.run(tf.variables_initializer(get_uninitialized_variables(sess)))

PS:注意saver = tf.train.Saver()要定義在圖構建完成之后

? 即將被restore的變量不用初始化,但是只有在restore之后,這些變量才會被初始化,所以在restore之前運行這些值會報沒有初始化的錯。

2.2 利用保存的.meta文件恢復圖

參考:Tensorflow如何保存、讀取model (即利用訓練好的模型測試新數據的準確度)

上面的方式適用于斷點續訓,且自己有構建圖的完整代碼,如果我要用別人的網絡(fine tune),或者在自己原有網絡上修改(即修改原有網絡的某個部分),那么將網絡的圖重新構建一遍會很麻煩,那么我們可以直接從.meta文件中加載網絡結構。

2.2.1 get_tensor_by_name

完整代碼:

def build_graph():
    w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32)
    w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32)
    w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3')
    w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4')
    add = tf.add(w1,w2,name='add')
    add1 = tf.add(add,w3,name='add1')
    return w3,add1

with tf.Session() as sess:
    ckpt_state = tf.train.get_checkpoint_state('./temp/')
    if ckpt_state:
        saver = tf.train.import_meta_graph('./temp/model.meta')
        graph = tf.get_default_graph()
        w3 = graph.get_tensor_by_name('W3:0')
        add1 = graph.get_tensor_by_name('add1:0')
        saver.restore(sess, tf.train.latest_checkpoint('./temp/'))
        print(sess.run(tf.get_collection('w1')[0]))
    else:
        w3,add1=build_graph()
        init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
        sess.run(init_op)
        saver = tf.train.Saver()
    a = sess.run(add1,feed_dict={
            w3:[1,2,3,4]
        })
    print(a)
    saver.save(sess,'./temp/model')

上面使用了import_meta_graph()來加載圖,并用restore給變量賦值。

通過get_tensor_by_name來獲取保存的圖中的op或變量,之后可以對獲取的值進行操作,如果之后save的話,也會將import_meta_graph()中圖引用的部分保存下來。

2.2.2

def build_graph():
    w1 = tf.Variable([1,3,10,15],name='W1',dtype=tf.float32)
    w2 = tf.Variable([3,4,2,18],name='W2',dtype=tf.float32)
    w3 = tf.placeholder(shape=[4],dtype=tf.float32,name='W3')
    w4 = tf.Variable([100,100,100,100],dtype=tf.float32,name='W4')
    add = tf.add(w1,w2,name='add')
    add1 = tf.add(add,w3,name='add1')
    tf.add_to_collection('w1','W1:0')
    tf.add_to_collection('w3',w3)
    tf.add_to_collection('add1',add1)
    return w3,add1

with tf.Session() as sess:
    ckpt_state = tf.train.get_checkpoint_state('./temp/')
    if ckpt_state:
        saver = tf.train.import_meta_graph('./temp/model.meta')
        w3 = tf.get_collection('w3')[0]
        add1 = tf.get_collection('add1')[0]
        # run init_op before restore
        saver.restore(sess, tf.train.latest_checkpoint('./temp/'))
    else:
        w3,add1=build_graph()
        init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
        sess.run(init_op)
        saver = tf.train.Saver()
    a = sess.run(add1,feed_dict={
            w3:[1,2,3,4]
        })
    print(a)
    saver.save(sess,'./temp/model')

通過import_meta_graph引進圖,通過get_collection獲得變量,其實和get_tensor_by_name差不多,但是可能會更方便一點。

3. 總結

總的來說,兩種方式都是先構造好圖,然后通過restore來給圖里的變量賦值。

一個常見的問題是,要引入新的變量,對以前的圖進行改造,那么如何初始化新的變量且不覆蓋原來的數據?

  • 可以先啥都不管把所有的圖相關的部分構造好后,得到init_op,然后在restore前run(init_op)
  • 對未初始化的變量進行初始化

4. 最后

image
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容