通過上一篇 13 馴獸師:神經網絡調教綜述,對神經網絡的調教有了一個整體印象,本篇從學習緩慢這一常見問題入手,引入交叉熵損失函數,并分析它是如何克服學習緩慢問題。
“嚴重錯誤”導致學習緩慢
回顧識別MNIST的網絡架構,我們采用了經典的S型神經元,以及常見的基于均方誤差(MSE)的二次函數作為損失函數。殊不知這種組合,在實際輸出與預期偏離較大時,會造成學習緩慢。
簡單的說,如果在初始化權重和偏置時,故意產生一個背離預期較大的輸出,那么訓練網絡的過程中需要用很多次迭代,才能抵消掉這種背離,恢復正常的學習。這種現象與人類學習的經驗相悖:對于明顯的錯誤,人類能進行快速的修正。
為了看清楚這個現象,可以用一個S型神經元,從微觀角度進行重現。這個神經元接受1個固定的輸入“1”,期望經過訓練后能輸出“0”,因此待訓練參數為1個權重w和1個偏置b,如下圖:
先觀察一個“正常”初始化的情況。
令w=0.6,b=0.9,可認為其符合均值為0,標準差為1的正態分布。此時,輸入1,輸出0.82。接下來開始使用梯度下降法進行迭代訓練,從Epoch-Cost曲線可以看到“損失”快速降低,到第100次時就很低了,到第300次迭代時已經幾乎為0,符合預期,如下圖:
接下來換一種初始化策略。
將w和b都賦值為“2.0”。此時,輸入1,輸出為0.98——比之前的0.82偏離預期值0更遠了。接下來的訓練Epoch-Cost曲線顯示200次迭代后“損失”依然很高,減少緩慢,而最后100次迭代才開始恢復正常的學習,如下圖:
學習緩慢原因分析
單個樣本情況下,基于均方誤差的二次損失函數為:
一個神經元的情況下就不用反向傳播求導了,已知a = σ(z),z = wx + b,直接使用鏈式求導即可:
將唯一的一個訓練樣本(x=1,y=0)代入,得到:
觀察σ(z)函數曲線會發現,當σ接近于1時,σ曲線特別的平坦,所以此處σ'(z)是一個非常小的值,由上式可推斷C的梯度也會非常小,“下降”自然也就會變得緩慢。這種情況也成為神經元飽和。這就解釋了前面初始的神經元輸出a=0.98,為什么會比a=0.82學習緩慢那么多。
交叉熵損失函數
S型神經元,與二次均方誤差損失函數的組合,一旦神經元輸出發生“嚴重錯誤”,網絡將陷入一種艱難而緩慢的學習“沼澤”中。
對此一個簡單的策略就是更換損失函數,使用交叉熵損失函數可以明顯的改善當發生“嚴重錯誤”時導致的學習緩慢,使神經網絡的學習更符合人類經驗——快速從錯誤中修正。
交叉熵損失函數定義如下:
在證明它真的能避免學習緩慢之前,有必要先確認它是否至少可以衡量“損失”,后者并不顯而易見。
一個函數能夠作為損失函數,要符合以下兩個特性:
- 非負;
- 當實際輸出接近預期,那么損失函數應該接近0。
交叉熵全部符合。首先,實際輸出a的取值范圍為(0, 1),所以無論是lna還是ln(1-a)都是負數,期望值y的取值非0即1,因此中括號里面每項都是負數,再加上表達式最前面的一個負號,所以整體為非負。再者,當預期y為0時,如果實際輸出a接近0時,C也接近0;當預期y為1時,如果實際輸出a接近1,那么C也接近0。
接下來分析為什么交叉熵可以避免學習緩慢,仍然從求C的偏導開始。
單樣本情況下,交叉熵損失函數可以記為:
對C求w的偏導數:
a = σ(z),將其代入:
對于Sigmoid函數,有σ'(z) = σ(z)(1-σ(z)),所以上式中的σ'(z)被抵消了,得到:
由此可見,C的梯度不再與σ'(z)有關,而與a-y相關,其結果就是:實際輸出與預期偏離越大,梯度越大,學習越快。
對于偏置,同理有:
更換損失函數為交叉熵后,回到之前學習緩慢的例子,重新訓練,Epoch-Cost曲線顯示學習緩慢的情況消失了。
推廣到多神經元網絡
前面的有效性證明是基于一個神經元所做的微觀分析,將其推廣到多層神經元網絡也是很容易的。從分量的角度來看,假設輸出神經元的預期值是y = y1,y2,...,實際輸出aL = aL1,aL2,...,那么交叉熵損失函數計算公式如下:
評價交叉熵損失,注意以下3點:
交叉熵無法改善隱藏層中神經元發生的學習緩慢。損失函數定義中的aL是最后一層神經元的實際輸出,所以“損失”C針對輸出層神經元的權重wLj求偏導數,可以產生抵消σ'(zLj)的效果,從而避免輸出層神經元的學習緩慢問題。但是“損失”C對于隱藏層神經元的權重wL-1j求偏導,就無法產生抵消σ'(zL-1j)的效果。
交叉熵損失函數只對網絡輸出“明顯背離預期”時發生的學習緩慢有改善效果,如果初始輸出背離預期并不明顯,那么應用交叉熵損失函數也無法觀察到明顯的改善。從另一個角度看,應用交叉熵損失是一種防御性策略,增加訓練的穩定性。
應用交叉熵損失并不能改善或避免神經元飽和,而是當輸出層神經元發生飽和時,能夠避免其學習緩慢的問題。
小結
現有神經網絡中存在一種風險:由于初始化或其他巧合因素,一旦出現輸出與預期偏離過大,就會導致網絡學習緩慢。本篇分析了該現象出現的原因,引入交叉熵損失函數,并推理證明了其有效性。
附完整代碼
代碼基于12 TF構建3層NN玩轉MNIST中的tf_12_mnist_nn.py,修改了損失函數,TensorFlow提供了交叉熵的封裝:
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3))
import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
FLAGS = None
def main(_):
# Import data
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
# Create the model
x = tf.placeholder(tf.float32, [None, 784])
W_2 = tf.Variable(tf.random_normal([784, 30]))
b_2 = tf.Variable(tf.random_normal([30]))
z_2 = tf.matmul(x, W_2) + b_2
a_2 = tf.sigmoid(z_2)
W_3 = tf.Variable(tf.random_normal([30, 10]))
b_3 = tf.Variable(tf.random_normal([10]))
z_3 = tf.matmul(a_2, W_3) + b_3
a_3 = tf.sigmoid(z_3)
# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])
# loss = tf.reduce_mean(tf.norm(y_ - a_3, axis=1)**2) / 2
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3))
train_step = tf.train.GradientDescentOptimizer(3.0).minimize(loss)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# Train
best = 0
for epoch in range(30):
for _ in range(5000):
batch_xs, batch_ys = mnist.train.next_batch(10)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
# Test trained model
correct_prediction = tf.equal(tf.argmax(a_3, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
accuracy_currut = sess.run(accuracy, feed_dict={x: mnist.test.images,
y_: mnist.test.labels})
print("Epoch %s: %s / 10000" % (epoch, accuracy_currut))
best = (best, accuracy_currut)[best <= accuracy_currut]
# Test trained model
print("best: %s / 10000" % best)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', type=str, default='/MNIST/',
help='Directory for storing input data')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
下載 tf_14_mnist_nn_cross_entropy.py。
上一篇 13 AI馴獸師:神經網絡調教綜述
下一篇 15 1/sqrt(n)權重初始化
共享協議:署名-非商業性使用-禁止演繹(CC BY-NC-ND 3.0 CN)
轉載請注明:作者黑猿大叔(簡書)