BN(Batch Normalization)在TensorFlow的實(shí)現(xiàn)

BN是Google inception系列模型里,從inception v2到inception v3的一個重要升級,在activation層之前,將卷積層的輸出進(jìn)行歸一化,使activation的輸入在[0,1]之間,避免梯度消失的問題。

具體地,BN在TF中實(shí)現(xiàn),涉及到兩個方法:tf.nn.moments 和 tf.nn.batch_normalization。

具體的方法說明請參考官方API文檔。主要思路是moments計算數(shù)據(jù)的mean和variance,batch_normalization利用mean和variance計算歸一化后的數(shù)據(jù)。

一、tf.nn.moments

def moments(x, axes, name=None, keep_dims=False)

參數(shù)解釋:

·x 可以理解為我們輸出的數(shù)據(jù),形如 [batchsize, height, width, kernels]
·axes 表示在哪個維度上求解,是個list,例如 [0, 1, 2]
·name 就是個名字,不多解釋
·keep_dims 是否保持維度,不多解釋

這個函數(shù)的輸出就是BN需要的mean和variance。
Test code:

import tensorflow as tf
sess = tf.InteractiveSession()
img = tf.random_normal([2, 3])
axis = list(range(len(img.get_shape()) - 1))
mean, variance = tf.nn.moments(img, axis)
mean.eval()
variance.eval()

輸出

img = [[ 0.69495416  2.08983064 -1.08764684]
       [ 0.31431156 -0.98923939 -0.34656194]]
mean =  [ 0.50463283  0.55029559 -0.71710438]
variance =  [ 0.0362222   2.37016821  0.13730171]

可以理解為batchsize=2,kernels=3,最終得到每個kernel對應(yīng)的mean和variance。

img=[128,32,32,64]對應(yīng)的物理意義

二、tf.nn.batch_normalization

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name=None)

參數(shù)解釋:
·x同moments方法
·mean moments方法的輸出之一
·variance moments方法的輸出之一
·offset BN需要學(xué)習(xí)的參數(shù)
·scale BN需要學(xué)習(xí)的參數(shù)
·variance_epsilon 歸一化時防止分母為0加的一個常量

參數(shù)對應(yīng)的BN計算公式:


BN計算公式

其中Xi對應(yīng)x,μ即為mean,δ對應(yīng)variance。第3個公式做初步的Norm,第4個公式中,γ即為scale,β對應(yīng)offset。

BN在實(shí)際中,由于mean和variance是和batch內(nèi)的數(shù)據(jù)有關(guān)的,因此需要注意訓(xùn)練過程和預(yù)測過程中,mean和variance無法使用相同的數(shù)據(jù)。需要一個trick,即moving_average,代碼如下:

update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, BN_DECAY)
update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, BN_DECAY)
tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_mean)
tf.add_to_collection(UPDATE_OPS_COLLECTION, update_moving_variance)
mean, variance = control_flow_ops.cond(['is_training'], lambda: (mean, variance), lambda: (moving_mean, moving_variance))

在訓(xùn)練的過程中,通過每個step得到的mean和variance,疊加計算對應(yīng)的moving_average(滑動平均),并最終保存下來以便在inference的過程中使用。
對于assign_moving_average方法如下:

def assign_moving_average(variable, value, decay, zero_debias=True, name=None)

其實(shí)內(nèi)部計算比較簡單,公式表達(dá)如下:
variable = variable * decay + value * (1 - decay)
變換一下:
variable = variable - (1 - decay) * (variable - value)
減號后面的項(xiàng)就是moving_average的更新delta了。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

推薦閱讀更多精彩內(nèi)容