TensorFlow-3: 用 feed-forward neural network 識別數(shù)字

今天繼續(xù)看 TensorFlow Mechanics 101:
https://www.tensorflow.org/get_started/mnist/mechanics

完整版教程可以看中文版tutorial:
http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mnist_tf.html

這一節(jié)講了使用 MNIST 數(shù)據(jù)集訓(xùn)練并評估一個簡易前饋神經(jīng)網(wǎng)絡(luò)(feed-forward neural network)

input,output 和前兩節(jié)是一樣的:即劃分?jǐn)?shù)據(jù)集并預(yù)測圖片的 label

data_sets.train 55000個圖像和標(biāo)簽(labels),作為主要訓(xùn)練集。
data_sets.validation    5000個圖像和標(biāo)簽,用于迭代驗證訓(xùn)練準(zhǔn)確度。
data_sets.test  10000個圖像和標(biāo)簽,用于最終測試訓(xùn)練準(zhǔn)確度(trained accuracy)。

主要有兩個代碼:

mnist.py
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist.py

  • 構(gòu)建一個全連接網(wǎng)絡(luò),由 2 個隱藏層,1 個 `softmax_linearv 輸出構(gòu)成
  • 定義損失函數(shù),用 `cross entropyv
  • 定義訓(xùn)練時的優(yōu)化器,用 GradientDescentOptimizer
  • 定義評價函數(shù),用 tf.nn.in_top_k

**fully_connected_feed.py **
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/fully_connected_feed.py

  • placeholder_inputs 傳入 batch size,得到 image 和 label 的兩個placeholder
  • 定義生成 feed_dict 的函數(shù),key 是 placeholders,value 是 data
  • 定義 do_eval 函數(shù),每隔 1000 個訓(xùn)練步驟,就對模型進行以下評估,分別作用于訓(xùn)練集、驗證集和測試集
  • 訓(xùn)練時:
    • 導(dǎo)入數(shù)據(jù)
    • 得到 image 和 label 兩個 placeholder
    • 傳入 mnist.inference 定義的 NN, 得到 predictions
    • 將 predictions 傳給 mnist.loss 計算 loss
    • loss 傳給 mnist.training 進行優(yōu)化訓(xùn)練
    • 再用 mnist.evaluation 評價預(yù)測值和實際值

代碼中涉及到下面幾個函數(shù):

with tf.Graph().as_default():
即所有已經(jīng)構(gòu)建的操作都要與默認的 tf.Graph 全局實例關(guān)聯(lián)起來,tf.Graph 實例是一系列可以作為整體執(zhí)行的操作

summary = tf.summary.merge_all():
為了釋放 TensorBoard 所使用的 events file,所有的即時數(shù)據(jù)都要在圖表構(gòu)建時合并至一個操作 op 中,每次運行 summary 時,都會向 events file 中寫入最新的即時數(shù)據(jù)

summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph):
用于寫入包含了圖表本身和即時數(shù)據(jù)具體值的 events file。

saver = tf.train.Saver():
就是向訓(xùn)練文件夾中寫入包含了當(dāng)前所有可訓(xùn)練變量值 checkpoint file

with tf.name_scope('hidden1'):
主要用于管理一個圖里面的各種 op,返回的是一個以 scope_name 命名的 context manager,一個 graph 會維護一個 name_space 的堆,實現(xiàn)一種層次化的管理,避免各個 op 之間命名沖突。例如,如果額外使用 tf.get_variable() 定義的變量是不會被 tf.name_scope() 當(dāng)中的名字所影響的

tf.nn.in_top_k(logits, labels, 1):
意思是在 K 個最有可能的預(yù)測中如果可以發(fā)現(xiàn) true,就將輸出標(biāo)記為 correct。本文 K 為 1,也就是只有在預(yù)測是 true 時,才判定它是 correct。


推薦閱讀 歷史技術(shù)博文鏈接匯總
http://www.lxweimin.com/p/28f02bb59fe5
也許可以找到你想要的

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

推薦閱讀更多精彩內(nèi)容