交叉熵代價函數(Cross-entropy cost function)是用來衡量人工神經網絡(ANN)的預測值與實際值的一種方式。與二次代價函數相比,它能更有效地促進ANN的訓練。在介紹交叉熵代價函數之前,本文先簡要介紹二次代價函數,以及其存在的不足。
- 二次代價函數的不足
ANN的設計目的之一是為了使機器可以像人一樣學習知識。人在學習分析新事物時,當發現自己犯的錯誤越大時,改正的力度就越大。比如投籃:當運動員發現自己的投籃方向離正確方向越遠,那么他調整的投籃角度就應該越大,籃球就更容易投進籃筐。同理,我們希望:ANN在訓練時,如果預測值與實際值的誤差越大,那么在反向傳播訓練的過程中,各種參數調整的幅度就要更大,從而使訓練更快收斂。然而,如果使用二次代價函數訓練ANN,看到的實際效果是,如果誤差越大,參數調整的幅度可能更小,訓練更緩慢。
以一個神經元的二類分類訓練為例,進行兩次實驗(ANN常用的激活函數為sigmoid函數,該實驗也采用該函數):輸入一個相同的樣本數據x=1.0(該樣本對應的實際分類y=0);兩次實驗各自隨機初始化參數,從而在各自的第一次前向傳播后得到不同的輸出值,形成不同的代價(誤差):
實驗1:第一次輸出值為0.82
實驗2:第一次輸出值為0.98
在實驗1中,隨機初始化參數,使得第一次輸出值為0.82(該樣本對應的實際值為0);經過300次迭代訓練后,輸出值由0.82降到0.09,逼近實際值。而在實驗2中,第一次輸出值為0.98,同樣經過300迭代訓練,輸出值只降到了0.20。
從兩次實驗的代價曲線中可以看出:實驗1的代價隨著訓練次數增加而快速降低,但實驗2的代價在一開始下降得非常緩慢;直觀上看,初始的誤差越大,收斂得越緩慢。
其實,誤差大導致訓練緩慢的原因在于使用了二次代價函數。二次代價函數的公式如下:
如圖所示,實驗2的初始輸出值(0.98)對應的梯度明顯小于實驗1的輸出值(0.82),因此實驗2的參數梯度下降得比實驗1慢。這就是初始的代價(誤差)越大,導致訓練越慢的原因。與我們的期望不符,即:不能像人一樣,錯誤越大,改正的幅度越大,從而學習得越快。
可能有人會說,那就選擇一個梯度不變化或變化不明顯的激活函數不就解決問題了嗎?圖樣圖森破,那樣雖然簡單粗暴地解決了這個問題,但可能會引起其他更多更麻煩的問題。而且,類似sigmoid這樣的函數(比如tanh函數)有很多優點,非常適合用來做激活函數,具體請自行google之
說起交叉熵損失函數「Cross Entropy Loss」,腦海中立馬浮現出它的公式:
我們已經對這個交叉熵函數非常熟悉,大多數情況下都是直接拿來使用就好。但是它是怎么來的?為什么它能表征真實樣本標簽和預測概率之間的差值?上面的交叉熵函數是否有其它變種?也許很多朋友還不是很清楚!沒關系,接下來我將盡可能以最通俗的語言回答上面這幾個問題。
1. 交叉熵損失函數的數學原理
我們知道,在二分類問題模型:例如邏輯回歸「Logistic Regression」、神經網絡「Neural Network」等,真實樣本的標簽為 [0,1],分別表示負類和正類。模型的最后通常會經過一個 Sigmoid 函數,輸出一個概率值,這個概率值反映了預測為正類的可能性:概率越大,可能性越大。
Sigmoid 函數的表達式和圖形如下所示:
其中 s 是模型上一層的輸出,Sigmoid 函數有這樣的特點:s = 0 時,g(s) = 0.5;s >> 0 時, g ≈ 1,s << 0 時,g ≈ 0。顯然,g(s) 將前一級的線性輸出映射到 [0,1] 之間的數值概率上。這里的 g(s) 就是交叉熵公式中的模型預測輸出 。
我們說了,預測輸出即 Sigmoid 函數的輸出表征了當前樣本標簽為 1 的概率:
很明顯,當前樣本標簽為 0 的概率就可以表達成:
重點來了,如果我們從極大似然性的角度出發,把上面兩種情況整合到一起:
也即,當真實樣本標簽 y = 0 時,上面式子第一項就為 1,概率等式轉化為:
當真實樣本標簽 y = 1 時,上面式子第二項就為 1,概率等式轉化為:
兩種情況下概率表達式跟之前的完全一致,只不過我們把兩種情況整合在一起了。重點看一下整合之后的概率表達式,我們希望的是概率 P(y|x) 越大越好。首先,我們對 P(y|x) 引入 log 函數,因為 log 運算并不會影響函數本身的單調性。則有:
我們希望 log P(y|x) 越大越好,反過來,只要 log P(y|x) 的負值 -log P(y|x) 越小就行了。那我們就可以引入損失函數,且令 Loss = -log P(y|x)即可。則得到損失函數為:
非常簡單,我們已經推導出了單個樣本的損失函數,是如果是計算 N 個樣本的總的損失函數,只要將 N 個 Loss 疊加起來就可以了:
這樣,我們已經完整地實現了交叉熵損失函數的推導過程。
2. 交叉熵損失函數的直觀理解
可能會有讀者說,我已經知道了交叉熵損失函數的推導過程。但是能不能從更直觀的角度去理解這個表達式呢?而不是僅僅記住這個公式。好問題!接下來,我們從圖形的角度,分析交叉熵函數,加深大家的理解。
首先,還是寫出單個樣本的交叉熵損失函數:
我們知道,當 y = 1 時:
這時候,L 與預測輸出的關系如下圖所示:
看了 L 的圖形,簡單明了!橫坐標是預測輸出,縱坐標是交叉熵損失函數 L。顯然,預測輸出越接近真實樣本標簽 1,損失函數 L 越小;預測輸出越接近 0,L 越大。因此,函數的變化趨勢完全符合實際需要的情況。當 y = 0 時:
這時候,L 與預測輸出的關系如下圖所示:
同樣,預測輸出越接近真實樣本標簽 0,損失函數 L 越小;預測函數越接近 1,L 越大。函數的變化趨勢也完全符合實際需要的情況。
從上面兩種圖,可以幫助我們對交叉熵損失函數有更直觀的理解。無論真實樣本標簽 y 是 0 還是 1,L 都表征了預測輸出與 y 的差距。
另外,重點提一點的是,從圖形中我們可以發現:預測輸出與 y 差得越多,L 的值越大,也就是說對當前模型的 “ 懲罰 ” 越大,而且是非線性增大,是一種類似指數增長的級別。這是由 log 函數本身的特性所決定的。這樣的好處是模型會傾向于讓預測輸出更接近真實樣本標簽 y。
3. 交叉熵損失函數的其它形式
什么?交叉熵損失函數還有其它形式?沒錯!我剛才介紹的是一個典型的形式。接下來我將從另一個角度推導新的交叉熵損失函數。
這種形式下假設真實樣本的標簽為 +1 和 -1,分別表示正類和負類。有個已知的知識點是Sigmoid 函數具有如下性質:
這個性質我們先放在這,待會有用。
好了,我們之前說了 y = +1 時,下列等式成立:
如果 y = -1 時,并引入 Sigmoid 函數的性質,下列等式成立:
重點來了,因為 y 取值為 +1 或 -1,可以把 y 值帶入,將上面兩個式子整合到一起:
接下來,同樣引入 log 函數,得到:
要讓概率最大,反過來,只要其負數最小即可。那么就可以定義相應的損失函數為:
還記得 Sigmoid 函數的表達式吧?將 g(ys) 帶入:
好咯,L 就是我要推導的交叉熵損失函數。如果是 N 個樣本,其交叉熵損失函數為:
接下來,我們從圖形化直觀角度來看。當 y = +1 時:
這時候,L 與上一層得分函數 s 的關系如下圖所示:
橫坐標是 s,縱坐標是 L。顯然,s 越接近真實樣本標簽 1,損失函數 L 越小;s 越接近 -1,L 越大。另一方面,當 y = -1 時:
這時候,L 與上一層得分函數 s 的關系如下圖所示:
同樣,s 越接近真實樣本標簽 -1,損失函數 L 越小;s 越接近 +1,L 越大。
4. 總結
本文主要介紹了交叉熵損失函數的數學原理和推導過程,也從不同角度介紹了交叉熵損失函數的兩種形式。第一種形式在實際應用中更加常見,例如神經網絡等復雜模型;第二種多用于簡單的邏輯回歸模型。
作者:To_2020_1_4
鏈接:http://www.lxweimin.com/p/b07f4cd32ba6
來源:簡書