【GiantPandaCV導語】Deep Mutual Learning是Knowledge Distillation的外延,經(jīng)過測試(代碼來自Knowledge-Distillation-Zoo), Deep Mutual Learning性能確實超出了原始KD很多,所以本文分析這篇CVPR2018年被接受的論文。同時PPOCRv2中也提到了DML,并提出了CML,取得效果顯著。
引言
首先感謝:https://github.com/AberHu/Knowledge-Distillation-Zoo
筆者在這個基礎上進行測試,測試了在CIFAR10數(shù)據(jù)集上的結果。
學生網(wǎng)絡resnet20:92.29% 教師網(wǎng)絡resnet110:94.31%
這里只展示幾個感興趣的算法結果帶來的收益:
logits(mimic learning via regressing logits): + 0.78
ST(soft target): + 0.16
OFD(overhaul of feature distillation): +0.45
AT(attention transfer): +0.71
NST(neural selective transfer): +0.38
RKD(relational knowledge distillation): +0.65
AFD(attention feature distillation): +0.18
DML(deep mutual learning): + 2.24 (ps: 這里教師網(wǎng)絡已經(jīng)訓練好了,與DML不同)
DML也是傳統(tǒng)知識蒸餾的擴展,其目標也是將大型模型壓縮為小的模型。但是不同于傳統(tǒng)知識蒸餾的單向蒸餾(教師→學生),DML認為可以讓學生互相學習(雙向蒸餾),在整個訓練的過程中互相學習,通過這種方式可以提升模型的性能。
DML通過實驗證明在沒有先驗強大的教師網(wǎng)絡的情況下,僅通過學生網(wǎng)絡之間的互相學習也可以超過傳統(tǒng)的KD。
如果傳統(tǒng)的知識蒸餾是由教師網(wǎng)絡指導學生網(wǎng)絡,那么DML就是讓兩個學生互幫互助,互相學習。
DML
小型的網(wǎng)絡通常有與大網(wǎng)絡相同的表示能力,但是訓練起來比大網(wǎng)絡更加困難。那么先訓練一個大型的網(wǎng)絡,然后通過使用模型剪枝、知識蒸餾等方法就可以讓小型模型的性能提升,甚至超過大型模型。
以知識蒸餾為例,通常需要先訓練一個大而寬的教師網(wǎng)絡,然后讓小的學生網(wǎng)絡來模仿教師網(wǎng)絡。通過這種方式相比直接從hard label學習,可以降低學習的難度,這樣學生網(wǎng)絡甚至可以比教師網(wǎng)絡更強。
Deep Mutual Learning則是讓兩個小的學生網(wǎng)絡同時學習,對于每個單獨的網(wǎng)絡來說,會有針對hard label的分類損失函數(shù),還有模仿另外的學生網(wǎng)絡的損失函數(shù),用于對齊學生網(wǎng)絡的類別后驗。
這種方式一般會產(chǎn)生這樣的疑問,兩個隨機初始化的學生網(wǎng)絡最初階段性能都很差的情況,這樣相互模仿可能會導致性能更差,或者性能停滯不前(the blind lead the blind)。
文章中這樣進行解釋:
每個學生主要是倍傳統(tǒng)的有監(jiān)督學習損失函數(shù)影響,這意味著學生網(wǎng)絡的性能大體會是增長趨勢,這意味著他們的表現(xiàn)通常會提高,他們不能作為一個群體任意地漂移到群體思維。(原文: they cannot drift arbitrarily into groupthink as a cohort.)
在監(jiān)督信號下,所有的網(wǎng)絡都會朝著預測正確label的方向發(fā)展,但是不同的網(wǎng)絡在初始化值不同,他們會學到不同的表征,因此他們對下一類最有可能的概率的估計是不同的。
在Mutual Learning中,學生群體可以有效匯集下一個最后可能的類別估計,為每個訓練實例找到最有可能的類別,同時根據(jù)他們互學習對象增加每個學生的后驗熵,有助于網(wǎng)絡收斂到更平坦的極小值,從而帶來更好的泛華能力和魯棒性。
Why Deep Nets Generalise 有關網(wǎng)絡泛化性能的討論認為:在深度神經(jīng)網(wǎng)絡中,有很多解法(參數(shù)組合)可以使得訓練錯誤為0,其中一些在比較loss landscape平坦處參數(shù)可以比其他narrow位置的泛華性能更好,所以小的干擾不會徹底改變預測的效果;
DML通過實驗發(fā)現(xiàn):(1)訓練過程損失可以接近于0 。(2)在擾動下對loss的變動接受能力更強。(3)給出的class置信度不會過于高。總體來說就是:DML并沒有幫助我們找到更好的訓練損失最小值,而是幫助我們找到更廣泛/更穩(wěn)健的最小值,更好地對測試數(shù)據(jù)進行泛華。
DML具有的特點是:
適合于各種網(wǎng)絡架構,由大小網(wǎng)絡混合組成的異構的網(wǎng)絡也可以進行相互學習(因為只學習logits)
效能會隨著隊列中網(wǎng)絡數(shù)量的增加而增加,即互學習對象增多的時候,性能會有一定的提升。
有利于半監(jiān)督學習,因為其在標記和未標記數(shù)據(jù)上都激活了模仿?lián)p失。
雖然DML的重點是得到某一個有效的網(wǎng)絡,整個隊列中的網(wǎng)絡可以作為模型集成的對象進行集成。
DML中使用到了KL Divergence衡量兩者之間的差距:
P1和P2代表兩者的邏輯層輸出,那么對于每個網(wǎng)絡來說,他們需要學習的損失函數(shù)為:
其中代表傳統(tǒng)的分類損失函數(shù),比如交叉熵損失函數(shù)。
可以發(fā)現(xiàn)KL divergence是非對稱的,那么對兩個網(wǎng)絡來說,學習到的會有所不同,所以可以使用堆成的Jensen-Shannon Divergence Loss作為替代:
更新過程的偽代碼:
更多的互學習對象
給定K個互學習網(wǎng)絡,, 那么目標函數(shù)變?yōu)椋?/p>
將模仿信息變?yōu)槠渌W習網(wǎng)絡的KL divergence的均值。
擴展到半監(jiān)督學習
在訓練半監(jiān)督的時候,我們對于有標簽數(shù)據(jù)只使用交叉熵損失函數(shù),對于所有訓練數(shù)據(jù)(包括有標簽和無標簽)的計算KL Divergence 損失。
這是因為KL Divergence loss的計算天然的不需要真實標簽,因此有助于半監(jiān)督的學習。
實驗結果
幾個網(wǎng)絡的參數(shù)情況:
在CIFAR10和CIFAR100上訓練效果
在Reid數(shù)據(jù)集Market-1501上也進行了測試:
發(fā)現(xiàn)互學習目標越多,性能呈上升趨勢:
結論
本文提出了一種簡單而普遍適用的方法來提高深度神經(jīng)網(wǎng)絡的性能,方法是在一個隊列中通過對等和相互蒸餾進行訓練。
通過這種方法,可以獲得緊湊的網(wǎng)絡,其性能優(yōu)于那些從強大但靜態(tài)的教師中提煉出來的網(wǎng)絡。
DML的一個應用是獲得緊湊、快速和有效的網(wǎng)絡。文章還表明,這種方法也有希望提高大型強大網(wǎng)絡的性能,并且以這種方式訓練的網(wǎng)絡隊列可以作為一個集成來進一步提高性能。