1 、什么是知識?
通常認為,知識是模型學(xué)習(xí)到的參數(shù)(比如卷積的權(quán)重)
2 、什么是蒸餾?
將知識從大模型(教師模型)轉(zhuǎn)移到更適合部署的小模型(學(xué)生模型)
Distilling the knowledge in a Neural Network
知識蒸餾主要思想:
Student Model 學(xué)生模型模仿 Teacher Model 教師模型,二者相互競爭,直到學(xué)生模型可以與教師模型持平甚至卓越的表現(xiàn);(使用的數(shù)據(jù)是相同的)
知識蒸餾的算法:
主要由:1)知識 Knwoledge,2)蒸餾算法 Distillate,3)師生架構(gòu)組成(見上圖)
通過上圖可知Knowledge 知識的形式主要有三種:
1 、 Response-Based Knowledge?
主要指Teacher-Model 教師模型的最后一層———輸出層的特征。
主要思想是讓 Student Model 學(xué)生模型直接學(xué)習(xí)教師模型的預(yù)測結(jié)果(Knowledge)。
最簡單有效的模型壓縮方法,即:老師學(xué)習(xí)好,把結(jié)論直接告訴學(xué)生就 OK
假設(shè)張量為教師模型輸出,張量
為學(xué)生模型輸出,Response-Based Knowledge的蒸餾形式可以被描述為:
(
,
)? =?
(
,
)
通過學(xué)習(xí)流程圖,可知相同的數(shù)據(jù),有個 Teacher模型(上圖紅色代表老師)和 Student 模型(綠色代表學(xué)生),會把 Teacher 輸出的特征給到 Student 模型去學(xué)習(xí),會拿出最后一層的特征,然后通過Distialltion Loss,讓學(xué)生最后一層的特征去學(xué)習(xí)老師輸出的特征。老師的輸出特征應(yīng)該是比較固定的而學(xué)生是不太固定的需要去學(xué)習(xí)的,于是通過一個損失函數(shù)去模擬、去減少、去學(xué)習(xí)是的兩個的 Logits 越小越好
2、Feature-Based Knowledge
深度神經(jīng)網(wǎng)絡(luò)善于學(xué)習(xí)到不同層級的表征,因此中間層和輸出層的都可以被用做知識來訓(xùn)練學(xué)生模型,中間層學(xué)習(xí)知識的Feature-Based Knowledge 對于Response-Based Knowledge 是一個很好的補充,其主要思想是將教師和學(xué)生的特征激活進行關(guān)聯(lián)起來,F(xiàn)eature-Based Knowledge 的知識轉(zhuǎn)移的蒸餾損失可表示為:
(
(x),
(x)) =
(
(
(x)),
(
(x)))
通過學(xué)習(xí)流程圖可知:Distillation Loss 是建立在 Teacher Model 和 Student Model 的中間層,通過中間層去建立連接關(guān)系。
這種算法的好處 Teacher 網(wǎng)絡(luò)可以為 Student 網(wǎng)絡(luò)提供大量的、有用的參考信息。但如何有效的從教師模型中選擇提示層,從學(xué)生模型中選擇引導(dǎo)層,仍有待進一步研究。
缺點:由于提示層和引導(dǎo)層大小存在明顯差異,如何正確匹配教師和學(xué)生的特征也需要探討。
3、Relation-Based Knowledge
基于Feature-Based Knowledge 和Response-Based Knowledge 中都使用了教師模型中特定層中特征的輸出。基于關(guān)系的知識進一步探索了不同層或數(shù)據(jù)樣本之間的關(guān)系。一般情況下,基于特征圖關(guān)系的關(guān)系知識的蒸餾損失可以表示為:
(
,
) =?
(
(
,
),
(
,
))
通過上圖可知:相同的數(shù)據(jù),有個 Teacher模型和 Student 模型,Distillation Loss 就不僅僅是學(xué)習(xí)網(wǎng)絡(luò)模型中間的特征還有最后一層的特征信息,它還會學(xué)習(xí)數(shù)據(jù)樣本和網(wǎng)絡(luò)模型層之間的關(guān)系
Knowledge Distillation: A Survey
知識蒸餾可以劃分為:1)Offline Distillation 2)Online Distillation 3)Self-Distillation
紅色代表:預(yù)訓(xùn)練的模型
黃色代表:將要去訓(xùn)練的模型
1)Offline Distillation 通俗講:指知識淵博的教師向?qū)W生傳授知識
大多數(shù)蒸餾采用Offline Distillation,蒸餾過程被分為兩個階段
1.1、蒸餾前 Teacher 模型預(yù)訓(xùn)練?
1.2 、蒸餾算法遷移知識
因此Offline Distillation主要側(cè)重于知識遷移部分
通常采用單向知識轉(zhuǎn)移和兩階段訓(xùn)練過程。在步驟 1 中需要 Teacher 模型參數(shù)量比較大,訓(xùn)練時間比較長,這種方式對學(xué)生模型的蒸餾比較高效。
Tips:這種訓(xùn)練模式下的學(xué)生模型往往過度依賴于教師模型
2)Online Distillation 通俗講:指教師和學(xué)生共同學(xué)習(xí)知識
主要針對參數(shù)量大、精度性能好的教師模型不可獲得的情況。教師模型和學(xué)生模型同時更新,整個知識蒸餾算法是一種有效的端到端可訓(xùn)練方案(教師模型和學(xué)生模型一起去學(xué)習(xí))
Tips:現(xiàn)有的Online Distillation 往往難以獲得在線環(huán)境下參數(shù)量大、精度性能好的教師模型
3)Self-Distillation 通俗講:指學(xué)生自己學(xué)習(xí)知識
教師模型和學(xué)生模型使用相同的網(wǎng)絡(luò)結(jié)構(gòu)(自學(xué)習(xí)),同樣采用端到端可訓(xùn)練方案,屬于Online Distillation 的一種特例
知識蒸餾的過程
分成 5小步:
1 、把數(shù)據(jù)喂養(yǎng)到教師網(wǎng)絡(luò)去訓(xùn)練,通過升溫的 Softmax(T=t),得到 soft targets1
2 、把數(shù)據(jù)喂養(yǎng)到學(xué)生網(wǎng)絡(luò)去訓(xùn)練,通過升溫的 Softmax(T=t),得到 soft targets2(與步驟1是同溫的)
3、通過 1 、 2 兩步之后有兩個結(jié)果,對這兩個結(jié)果來運用一下,算一下就可以得到蒸餾損失 Distillation loss
4 、同樣把數(shù)據(jù)喂養(yǎng)到學(xué)生網(wǎng)絡(luò)去訓(xùn)練,通過正常的(未升溫) Softmax(T=1),得到 soft targets3
5、通過soft targets3和Ground Truth Label(正確標簽) ,這兩個值再計算一下,就可以的得到一個學(xué)生損失(Student loss)
整個過程會涉及兩個損失
蒸餾損失distillation loss 和 學(xué)生損失 student loss,這兩個損失
蒸餾損失
輸入:相同溫度下,學(xué)生模型和教師模型的 soft targets
常用:KL 散度?
作用:讓學(xué)生網(wǎng)絡(luò)的類別輸出預(yù)測分布盡可能擬合教師網(wǎng)絡(luò)輸出預(yù)測分布(通俗講:讓學(xué)生去學(xué)老師的一些行為)
學(xué)生損失
輸入:T=1 時,學(xué)生模型的 soft targets 和正確標簽
常用:交叉熵損失
作用:減少教師網(wǎng)絡(luò)中的錯誤信息被蒸餾到學(xué)生網(wǎng)絡(luò)中
蒸餾損失和學(xué)生損失,兩個損失函數(shù)是獨立的,怎么建立聯(lián)系?以及知識蒸餾整個過程的關(guān)鍵點有哪些?下一篇將詳細介紹經(jīng)典的知識蒸餾算法