TensorFlow從0到1 - 15 - 重新思考神經(jīng)網(wǎng)絡(luò)初始化

TensorFlow從0到1系列回顧

上一篇14 交叉熵?fù)p失函數(shù)——克服學(xué)習(xí)緩慢從最優(yōu)化算法層面入手,將二次的均方誤差(MSE)更換為交叉熵作為損失函數(shù),避免了當(dāng)出現(xiàn)“嚴(yán)重錯(cuò)誤”時(shí)導(dǎo)致的學(xué)習(xí)緩慢。

本篇引入1/sqrt(nin)權(quán)重初始化方法,從另一個(gè)層面——參數(shù)初始化(神經(jīng)網(wǎng)絡(luò)調(diào)教的5個(gè)層面歸納在13 AI馴獸師:神經(jīng)網(wǎng)絡(luò)調(diào)教綜述)入手改善網(wǎng)絡(luò)的學(xué)習(xí)速度。

相比之前采用的標(biāo)準(zhǔn)正態(tài)分布初始化,1/sqrt(nin)權(quán)重初始化不僅明顯的加快了學(xué)習(xí)速度,而且單純性(其他任何參數(shù)不變)的提升了測試集識別精度1~2個(gè)百分點(diǎn)。

理解了1/sqrt(nin)權(quán)重初始化的思想,就能很容易的理解Xavier、He權(quán)重初始化方法。

參數(shù)初始化之“重”

神經(jīng)網(wǎng)絡(luò)的訓(xùn)練過程,就是自動(dòng)調(diào)整網(wǎng)絡(luò)中參數(shù)的過程。在訓(xùn)練的起初,網(wǎng)絡(luò)的參數(shù)總要從某一狀態(tài)開始,而這個(gè)初始狀態(tài)的設(shè)定,就是神經(jīng)網(wǎng)絡(luò)的初始化。

之所以要重新思考神經(jīng)網(wǎng)絡(luò)權(quán)重和偏置的初始化,是因?yàn)樗鼘τ诤罄m(xù)的訓(xùn)練非常重要。

12 TensorFlow構(gòu)建3層NN玩轉(zhuǎn)MNIST中就踩了“參數(shù)初始化的坑”:簡單將權(quán)重和偏置初始化為0,導(dǎo)致了網(wǎng)絡(luò)訓(xùn)練陷入了一個(gè)局部最優(yōu)沼澤而無法自拔,最終識別率僅為60%。

不僅有“局部最優(yōu)”的坑,在14 交叉熵?fù)p失函數(shù)——防止學(xué)習(xí)緩慢還見識了初始化導(dǎo)致“神經(jīng)元飽和”的坑。

合適網(wǎng)絡(luò)初始值,不僅有助于梯度下降法在一個(gè)好的“起點(diǎn)”上去尋找最優(yōu)值,還能避免神經(jīng)元發(fā)生學(xué)習(xí)飽和

重新審視標(biāo)準(zhǔn)正態(tài)分布

Initialization

在之前實(shí)現(xiàn)的MNIST數(shù)字識別案例中,權(quán)重和偏置的初始化采用的是符合均值為0、標(biāo)準(zhǔn)差為1的標(biāo)準(zhǔn)正態(tài)分布(Standard Noraml Distribution)隨機(jī)化方法。基于它的訓(xùn)練過程還算平穩(wěn)。但它是最佳的初始化策略嗎?

它如此“特別”更像是一個(gè)警告:我們并不總能輕易的得到最佳答案,一定還有“壓榨”的空間。

一個(gè)尋找切入點(diǎn)的常用方法,就是人為誘導(dǎo)其產(chǎn)生問題。讓一個(gè)具有1000個(gè)神經(jīng)元輸入層的網(wǎng)絡(luò),以標(biāo)準(zhǔn)正態(tài)分布做隨機(jī)初始化,然后人造干預(yù):令輸入層神經(jīng)元一半(500個(gè))值為1,另一半(另500個(gè))值為0。現(xiàn)在聚焦到接下來隱藏層中的一個(gè)神經(jīng)元:

隱藏層神經(jīng)元

如上圖所示,1000個(gè)輸入層神經(jīng)元全部連接到了隱藏層的第一個(gè)神經(jīng)元。此時(shí)考察神經(jīng)元的加權(quán)和z = ∑jwjxj + b:

  • 將z的表達(dá)式展開,初始共有1001項(xiàng)(不要漏掉偏置b);
  • 人為令輸入xj中的500個(gè)為0,所以z的表達(dá)式最終有501項(xiàng);
  • 人為令輸入xj的其余500個(gè)為1,所以z由500項(xiàng)wj和1項(xiàng)b組成,它們符合標(biāo)準(zhǔn)正態(tài)分布N(0,1);
  • 推導(dǎo)得到z符合均值為0,標(biāo)準(zhǔn)差為√501(501的平方根)的正態(tài)分布,推導(dǎo)過程稍后解釋;

通過人為設(shè)置特殊的輸入,由權(quán)重w和偏置b的統(tǒng)計(jì)分布,得到了z的統(tǒng)計(jì)分布:

z的分布

從圖中可見,由于標(biāo)準(zhǔn)差√501非常大,導(dǎo)致z的分布從-30到30出現(xiàn)的比例都很高,也就是說,∣z∣ >> 1出現(xiàn)的概率非常大。還記得Sigmoid曲線嗎?當(dāng)∣z∣ >> 1時(shí),σ'(z)就會(huì)變得非常小,神經(jīng)元學(xué)習(xí)飽和。

Sigmoid

類似的,網(wǎng)絡(luò)中后續(xù)層中的神經(jīng)元也有同樣的性質(zhì)。

雖然是人為制造特殊的輸入數(shù)據(jù)暴露了網(wǎng)絡(luò)的問題,但是從中可以得到一個(gè)啟示:如果網(wǎng)絡(luò)的權(quán)重和偏置采用N(0,1)初始化,那么網(wǎng)絡(luò)中各層的神經(jīng)元數(shù)量n越多,造成后續(xù)層神經(jīng)元加權(quán)和z的標(biāo)準(zhǔn)差越大,∣z∣ >> 1出現(xiàn)的概率也越大,最終造成神經(jīng)元飽和——學(xué)習(xí)緩慢

1/sqrt(nin)權(quán)重初始化

順著上面的分析,一個(gè)比較自然的思路是:既然神經(jīng)元加權(quán)和z的標(biāo)準(zhǔn)差與網(wǎng)絡(luò)上一層神經(jīng)元的數(shù)量nin有相關(guān)性,那么為了抵消掉神經(jīng)元數(shù)量的影響,初始化分布的標(biāo)準(zhǔn)差就不應(yīng)該是一個(gè)常數(shù)。

本篇引入的1/sqrt(nin)權(quán)重初始化就是答案所在:使用均值為0,標(biāo)準(zhǔn)差為1/sqrt(nin)的正態(tài)分布來初始化權(quán)重。sqrt表示開根號,同√。

繼續(xù)使用之前的人為輸入數(shù)據(jù)和網(wǎng)絡(luò)架構(gòu):

  • 將z的表達(dá)式展開,初始共有1001項(xiàng);
  • 人為令輸入xj中的500個(gè)為0,所以z的表達(dá)式最終有501項(xiàng);
  • 人為令輸入xj的其余500個(gè)為1,所以z由500項(xiàng)wj和1項(xiàng)b組成,它們符合正態(tài)分布N(0,1/sqrt(nin));
  • 推導(dǎo)得到z符合均值為0,標(biāo)準(zhǔn)差為√(3/2)(3/2的平方根)正態(tài)分布,推導(dǎo)過程稍后解釋;

得到了新的z的統(tǒng)計(jì)分布:

z的分布

此時(shí)的正態(tài)曲線變的非常尖銳,z的可能取值都在0附近,再看Sigmoid曲線就會(huì)發(fā)現(xiàn),z在0附近時(shí),σ(z)曲線最“陡”,σ'(z)值越大,學(xué)習(xí)速度越快。

注意一點(diǎn),由于神經(jīng)元的偏置b對于加權(quán)和z的貢獻(xiàn)不受上一層神經(jīng)元數(shù)量nin的影響,所以偏置b的初始化可以沿用之前的N(0,1)。

z的概率分布推導(dǎo)

回來解釋下已知w和b的分布,如何計(jì)算z = ∑jwj + b的分布(xj為1,故省略)。

先準(zhǔn)備兩個(gè)特性:

  • 獨(dú)立隨機(jī)變量和的方差,是每個(gè)獨(dú)立隨機(jī)變量方差的和
  • 方差是標(biāo)準(zhǔn)差的平方

權(quán)重和偏置分布為N(0,1)情況下的推導(dǎo):

  • 已知wj和b的標(biāo)準(zhǔn)差是1,那么wj和b的方差也是1;
  • 由于人為輸入,z的展開式有501=1000/2+1項(xiàng),每項(xiàng)標(biāo)準(zhǔn)差為1;
  • z的方差 = 12 x 501;
  • z的標(biāo)準(zhǔn)差 = √501;

權(quán)重分布為N(0,1/sqrt(nin)),偏置分布為N(0,1)情況下的推導(dǎo):

  • 已知wj的標(biāo)準(zhǔn)差是1/sqrt(nin),那么wj的方差是1/nin,已知b的標(biāo)準(zhǔn)差是1,那么b的方差也是1;
  • 由于人為輸入,z的展開式有nin/2+1項(xiàng),前nin/2項(xiàng)為權(quán)重wj,每項(xiàng)方差為1/nin,最后1項(xiàng)為偏置b,方差為1;
  • z的方差 = 1/n x n/2 + 1 = 3/2;
  • z的標(biāo)準(zhǔn)差 = √(3/2);

結(jié)果對比

本篇基于12 TensorFlow構(gòu)建3層NN玩轉(zhuǎn)MNIST中的實(shí)現(xiàn),單純性的使用N(0,1/sqrt(nin))權(quán)重初始化與前者進(jìn)行了對比,結(jié)果如下:

N(0,1)參數(shù)初始化
N(0,1/sqrt(n))

從輸出明顯看出,采用N(0,1/sqrt(nin))權(quán)重初始化的學(xué)習(xí)速度明顯快了很多,第一次迭代Epoch 0就獲得了94%的識別率,而前面的N(0,1)實(shí)現(xiàn)到Epoch 7才達(dá)到了94%。

不僅學(xué)習(xí)速率變快,30次迭代結(jié)束后,采用N(0,1/sqrt(nin))權(quán)重初始化的識別精度比前者高了1個(gè)百分點(diǎn),達(dá)到了96%以上。

小結(jié)

本篇引入1/sqrt(nin)權(quán)重初始化方法,改變了神經(jīng)元加權(quán)和z的隨機(jī)概率分布,有效的避免了神經(jīng)元飽和,最終不僅加快了學(xué)習(xí)速率,而且網(wǎng)絡(luò)的性能也有明顯的改善。

有很多其他的權(quán)重初始化方法,比如Xavier、He等,其基本思想都是相似的。

附完整代碼

N(0,1/sqrt(nin))權(quán)重初始化的有效性分析,花了我們不少功夫,但是代碼實(shí)現(xiàn)卻異常簡潔:

W_2 = tf.Variable(tf.random_normal([784, 30]) / tf.sqrt(784.0))
...
W_3 = tf.Variable(tf.random_normal([30, 10]) / tf.sqrt(30.0))

完整代碼:

import argparse
import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

FLAGS = None


def main(_):
    # Import data
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])
    W_2 = tf.Variable(tf.random_normal([784, 30]) / tf.sqrt(784.0))
    b_2 = tf.Variable(tf.random_normal([30]))
    z_2 = tf.matmul(x, W_2) + b_2
    a_2 = tf.sigmoid(z_2)

    W_3 = tf.Variable(tf.random_normal([30, 10]) / tf.sqrt(30.0))
    b_3 = tf.Variable(tf.random_normal([10]))
    z_3 = tf.matmul(a_2, W_3) + b_3
    a_3 = tf.sigmoid(z_3)

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 10])

    loss = tf.reduce_mean(tf.norm(y_ - a_3, axis=1)**2) / 2
    # loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=z_3))
    train_step = tf.train.GradientDescentOptimizer(3.0).minimize(loss)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    # Train
    best = 0
    for epoch in range(30):
        for _ in range(5000):
            batch_xs, batch_ys = mnist.train.next_batch(10)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        # Test trained model
        correct_prediction = tf.equal(tf.argmax(a_3, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_sum(tf.cast(correct_prediction, tf.int32))
        accuracy_currut = sess.run(accuracy, feed_dict={x: mnist.test.images,
                                                        y_: mnist.test.labels})
        print("Epoch %s: %s / 10000" % (epoch, accuracy_currut))
        best = (best, accuracy_currut)[best <= accuracy_currut]

    # Test trained model
    print("best: %s / 10000" % best)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='/MNIST/',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

下載 tf_15_mnist_nn_weight_init.py

上一篇 14 交叉熵?fù)p失函數(shù)——防止學(xué)習(xí)緩慢
下一篇 16 L2正則化對抗“過擬合”


共享協(xié)議:署名-非商業(yè)性使用-禁止演繹(CC BY-NC-ND 3.0 CN)
轉(zhuǎn)載請注明:作者黑猿大叔(簡書)

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲(chǔ)服務(wù)。

推薦閱讀更多精彩內(nèi)容