今天繼續(xù)看 TensorFlow Mechanics 101:
https://www.tensorflow.org/get_started/mnist/mechanics
完整版教程可以看中文版tutorial:
http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_tf.html
這一節(jié)講了使用 MNIST 數(shù)據(jù)集訓(xùn)練并評估一個簡易前饋神經(jīng)網(wǎng)絡(luò)(feed-forward neural network)
input,output 和前兩節(jié)是一樣的:即劃分?jǐn)?shù)據(jù)集并預(yù)測圖片的 label
data_sets.train 55000個圖像和標(biāo)簽(labels),作為主要訓(xùn)練集。
data_sets.validation 5000個圖像和標(biāo)簽,用于迭代驗證訓(xùn)練準(zhǔn)確度。
data_sets.test 10000個圖像和標(biāo)簽,用于最終測試訓(xùn)練準(zhǔn)確度(trained accuracy)。
主要有兩個代碼:
mnist.py
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist.py
- 構(gòu)建一個全連接網(wǎng)絡(luò),由 2 個隱藏層,1 個 `softmax_linearv 輸出構(gòu)成
- 定義損失函數(shù),用 `cross entropyv
- 定義訓(xùn)練時的優(yōu)化器,用
GradientDescentOptimizer
- 定義評價函數(shù),用
tf.nn.in_top_k
**fully_connected_feed.py **
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/fully_connected_feed.py
- 向
placeholder_inputs
傳入batch size
,得到 image 和 label 的兩個placeholder - 定義生成
feed_dict
的函數(shù),key 是 placeholders,value 是 data - 定義
do_eval
函數(shù),每隔 1000 個訓(xùn)練步驟,就對模型進行以下評估,分別作用于訓(xùn)練集、驗證集和測試集 - 訓(xùn)練時:
- 導(dǎo)入數(shù)據(jù)
- 得到 image 和 label 兩個 placeholder
- 傳入
mnist.inference
定義的 NN, 得到 predictions - 將 predictions 傳給
mnist.loss
計算 loss - loss 傳給
mnist.training
進行優(yōu)化訓(xùn)練 - 再用
mnist.evaluation
評價預(yù)測值和實際值
代碼中涉及到下面幾個函數(shù):
with tf.Graph().as_default():
即所有已經(jīng)構(gòu)建的操作都要與默認的 tf.Graph
全局實例關(guān)聯(lián)起來,tf.Graph
實例是一系列可以作為整體執(zhí)行的操作
summary = tf.summary.merge_all():
為了釋放 TensorBoard 所使用的 events file,所有的即時數(shù)據(jù)都要在圖表構(gòu)建時合并至一個操作 op 中,每次運行 summary 時,都會向 events file 中寫入最新的即時數(shù)據(jù)
summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph):
用于寫入包含了圖表本身和即時數(shù)據(jù)具體值的 events file。
saver = tf.train.Saver():
就是向訓(xùn)練文件夾中寫入包含了當(dāng)前所有可訓(xùn)練變量值 checkpoint file
with tf.name_scope('hidden1'):
主要用于管理一個圖里面的各種 op,返回的是一個以 scope_name
命名的 context manager,一個 graph 會維護一個 name_space
的堆,實現(xiàn)一種層次化的管理,避免各個 op 之間命名沖突。例如,如果額外使用 tf.get_variable()
定義的變量是不會被 tf.name_scope()
當(dāng)中的名字所影響的
tf.nn.in_top_k(logits, labels, 1):
意思是在 K 個最有可能的預(yù)測中如果可以發(fā)現(xiàn) true,就將輸出標(biāo)記為 correct。本文 K 為 1,也就是只有在預(yù)測是 true 時,才判定它是 correct。
推薦閱讀 歷史技術(shù)博文鏈接匯總
http://www.lxweimin.com/p/28f02bb59fe5
也許可以找到你想要的