一介紹
該部分主要介紹mnist數(shù)據(jù)集上的神經(jīng)網(wǎng)絡(luò)模型,變量管理,模型持久化這幾部分。
二 變量管理
Tensorflow通過(guò)變量名稱獲取變量的機(jī)制主要是通過(guò)tf.get_variable和tf.variable_scope函數(shù)來(lái)現(xiàn)。? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?tf.get_variable創(chuàng)建變量時(shí)和Variable基本等價(jià)
get_variable和Variable不同在于,如果該變量名已經(jīng)存在的化,會(huì)報(bào)錯(cuò),但是Variable卻不會(huì)報(bào)錯(cuò),所以get_variable要獲取變量時(shí),需要通過(guò)variable_scope函數(shù)來(lái)生成一個(gè)上下文管理器。
三 模型持久化
模型持久化就是將模型保存,以方便復(fù)用。
model.ckpt.meta:保存tensorflow計(jì)算圖的結(jié)構(gòu)。
model.ckpt:保存了tensorflow程序中每一個(gè)變量的取值。
checkpoint:? 這個(gè)文件中保存了一個(gè)目錄下所有模型文件列表。
如果不想將tensorflow的網(wǎng)絡(luò)結(jié)構(gòu)重新一遍的化,可以直接加載,但是麻煩在于獲取張量的方式。
為了保存或者加載部分變量,在聲明tf.train.Saver類時(shí)可以提供一個(gè)來(lái)指定需要保存或者加載的變量。比如在上面代碼的例子中,想加載進(jìn)v1變量,可以saver = tf.train.Saver([v1])這種方式,但是因?yàn)関2沒(méi)有加載進(jìn)去,所以會(huì)報(bào)錯(cuò)v2沒(méi)有初始化的錯(cuò)誤。
重命名加載的變量。
關(guān)于重命名的方式很適合上一章節(jié)講述的滑動(dòng)平均值,每一個(gè)變量的滑動(dòng)平均值是通過(guò)影子變量維護(hù)的,所以要獲取變量的滑動(dòng)平均值就是獲取這個(gè)影子變量的取值。
在滑動(dòng)平均模型中有這個(gè)應(yīng)用,提供了variables_to_restore函數(shù)來(lái)生成tf.train.Saver所需要的變量重命名字典。
使用tf.train.Saver()會(huì)保存運(yùn)行tensorflow程序所需要的全部信息,然而有時(shí)并不需要某些信息,比如在測(cè)試或離線預(yù)測(cè)時(shí),只需要知道神經(jīng)網(wǎng)絡(luò)從輸入層到輸出層即可,不需要變量初始化,模型保存等輔助信息。根據(jù)這些需求,tensorflow提供了convert_variables_to_constants函數(shù),通過(guò)這個(gè)函數(shù)可以將計(jì)算圖中的變量及取值通過(guò)常量的方式保存。
可以參考:保存,凍結(jié),讀取