標簽平滑 Label Smoothing 詳解及 pytorch tensorflow實現

定義

標簽平滑(Label smoothing),像L1、L2和dropout一樣,是機器學習領域的一種正則化方法,通常用于分類問題,目的是防止模型在訓練時過于自信地預測標簽,改善泛化能力差的問題。

背景

對于分類問題,我們通常認為訓練數據中標簽向量的目標類別概率應為1,非目標類別概率應為0。傳統的one-hot編碼的標簽向量yi為,

yi={1,i=target0,i≠target

在訓練網絡時,最小化損失函數H(y,p)=?K∑iyilogpi,其中pi由對模型倒數第二層輸出的logits向量z應用Softmax函數計算得到,

pi=exp(zi)∑Kjexp(zj)

傳統one-hot編碼標簽的網絡學習過程中,鼓勵模型預測為目標類別的概率趨近1,非目標類別的概率趨近0,即最終預測的logits向量(logits向量經過softmax后輸出的就是預測的所有類別的概率分布)中目標類別zi的值會趨于無窮大,使得模型向預測正確與錯誤標簽的logit差值無限增大的方向學習,而過大的logit差值會使模型缺乏適應性,對它的預測過于自信。

在訓練數據不足以覆蓋所有情況下,這就會導致網絡過擬合,泛化能力差,而且實際上有些標注數據不一定準確,這時候使用交叉熵損失函數作為目標函數也不一定是最優的了。

數學定義

label smoothing結合了均勻分布,用更新的標簽向量^yi來替換傳統的ont-hot編碼的標簽向量yhat

^yi=yhot(1?α)+α/K

其中K為多分類的類別總個數,αα是一個較小的超參數(一般取0.1),即

^yi={1?α,i=targetα/K,i≠target

這樣,標簽平滑后的分布就相當于往真實分布中加入了噪聲,避免模型對于正確標簽過于自信,使得預測正負樣本的輸出值差別不那么大,從而避免過擬合,提高模型的泛化能力。

效果

NIPS 2019上的這篇論文<u style="box-sizing: border-box; list-style: inherit;">When Does Label Smoothing Help?</u>用實驗說明了為什么Label smoothing可以work,指出標簽平滑可以讓分類之間的cluster更加緊湊,增加類間距離,減少類內距離,提高泛化性,同時還能提高Model Calibration(模型對于預測值的confidences和accuracies之間aligned的程度)。但是在模型蒸餾中使用Label smoothing會導致性能下降。

從標簽平滑的定義我們可以看出,它鼓勵神經網絡選擇正確的類,并且正確類和其余錯誤的類的差別是一致的。與之不同的是,如果我們使用硬目標,則會允許不同的錯誤類之間有很大不同。基于此論文作者提出了一個結論:標簽平滑鼓勵倒數第二層激活函數之后的結果靠近正確的類的模板,并且同樣的遠離錯誤類的模板。

作者設計了一個可視化的方案來證明這件事情,具體方案為:(1)挑選3個類;(2)選取通過這三個類的模板的標準正交基的平面;(3)將倒數第二層激活函數之后的結果映射到該平面。作者做了4組實驗,第一組實驗為在CIFAR-10/AlexNet(數據集/模型)上面“飛機”、“汽車”和“鳥”三類的結果,可視化結果如下所示:

[圖片上傳失敗...(image-aef0e1-1644756672006)]

從中我們可以看出,加了標簽平滑之后(后兩張圖),每個類聚的更緊了,而且和其余類的距離大致一致。第二組實驗為在CIFAR-100/ResNet-56(數據集/模型)上的實驗結果,三個類分別為“河貍”、“海豚”與“水獺”,我們可以得到類似的結果:

[圖片上傳失敗...(image-d61e63-1644756672006)]

在第三組實驗中,作者測試了在ImageNet/Inception-v4(數據集/模型)上的表現,三個類分別為“貓鼬”、“鯉魚”和“切刀肉”,結果如下:

[圖片上傳失敗...(image-c18b24-1644756672006)]

因為ImageNet有很多細粒度的分類,可以用來測試比較相似的類之間的關系。作者在第四組實驗中選擇的三個類分別為“玩具貴賓犬”、“ 迷你貴賓犬”和“鯉魚”,可以看出前兩個類是很相似的,最后一個差別比較大的類在圖中用藍色表示,結果如下:

[圖片上傳失敗...(image-fb65ba-1644756672006)]

可以看出在使用硬目標的情況下,兩個相似的類彼此比較靠近。但是標簽平滑強制要求每個示例與所有剩余類的模板之間的距離相等,這就導致了后兩張圖中兩個類距離較遠,這在一定程度上造成了信息的損失。

代碼實現

pytorch部分代碼

class LabelSmoothing(nn.Module):
    def __init__(self, size, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False)
        #self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing#if i=y的公式
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
    
    def forward(self, x, target):
        """
        x表示輸入 (N,M)N個樣本,M表示總類數,每一個類的概率log P
        target表示label(M,)
        """
        assert x.size(1) == self.size
        true_dist = x.data.clone()#先深復制過來
        #print true_dist
        true_dist.fill_(self.smoothing / (self.size - 1))#otherwise的公式
        #print true_dist
        #變成one-hot編碼,1表示按列填充,
        #target.data.unsqueeze(1)表示索引,confidence表示填充的數字
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        self.true_dist = true_dist

        return self.criterion(x, Variable(true_dist, requires_grad=False))
        
loss_function = LabelSmoothing(num_labels, 0.1)

tensorflow代碼實現

def smoothing_cross_entropy(logits,labels,vocab_size,confidence):
  with tf.name_scope("smoothing_cross_entropy", values=[logits, labels]):
    # Low confidence is given to all non-true labels, uniformly.
    low_confidence = (1.0 - confidence) / to_float(vocab_size - 1)

    # Normalizing constant is the best cross-entropy value with soft targets.
    # We subtract it just for readability, makes no difference on learning.
    normalizing = -(
        confidence * tf.log(confidence) + to_float(vocab_size - 1) *
        low_confidence * tf.log(low_confidence + 1e-20))

    soft_targets = tf.one_hot(
          tf.cast(labels, tf.int32),
          depth=vocab_size,
          on_value=confidence,
          off_value=low_confidence)
    xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=soft_targets)
    return xentropy - normalizing
  1. https://www.jiqizhixin.com/articles/2019-07-09-7
  2. https://www.cnblogs.com/irvingluo/p/13873699.html
  3. https://proceedings.neurips.cc/paper/2019/file/f1748d6b0fd9d439f71450117eba2725-Paper.pdf
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容