知識蒸餾 Distilling the knowledge

1 、什么是知識?

通常認為,知識是模型學(xué)習(xí)到的參數(shù)(比如卷積的權(quán)重)

2 、什么是蒸餾?

將知識從大模型(教師模型)轉(zhuǎn)移到更適合部署的小模型(學(xué)生模型)

Distilling the knowledge in a Neural Network

知識蒸餾主要思想:

Student Model 學(xué)生模型模仿 Teacher Model 教師模型,二者相互競爭,直到學(xué)生模型可以與教師模型持平甚至卓越的表現(xiàn);(使用的數(shù)據(jù)是相同的)

The generic teacher-student framework for knowledge distillation.

知識蒸餾的算法:

主要由:1)知識 Knwoledge,2)蒸餾算法 Distillate,3)師生架構(gòu)組成(見上圖)

The schematic illustrations of sources of response-based knowledge, feature-based knowledge and relation-based knowledge in a deep teacher network.

通過上圖可知Knowledge 知識的形式主要有三種

1 、 Response-Based Knowledge?

主要指Teacher-Model 教師模型的最后一層———輸出層的特征。

主要思想是讓 Student Model 學(xué)生模型直接學(xué)習(xí)教師模型的預(yù)測結(jié)果(Knowledge)。

最簡單有效的模型壓縮方法,即:老師學(xué)習(xí)好,把結(jié)論直接告訴學(xué)生就 OK

假設(shè)張量z_{t} 為教師模型輸出,張量z_{s} 為學(xué)生模型輸出,Response-Based Knowledge的蒸餾形式可以被描述為:L_{ResD} (z_{t} ,z_{s} )? =?L_{R} (z_{t} ,z_{s} )

The generic response-based knowledge distillation

通過學(xué)習(xí)流程圖,可知相同的數(shù)據(jù),有個 Teacher模型(上圖紅色代表老師)和 Student 模型(綠色代表學(xué)生),會把 Teacher 輸出的特征給到 Student 模型去學(xué)習(xí),會拿出最后一層的特征,然后通過Distialltion Loss,讓學(xué)生最后一層的特征去學(xué)習(xí)老師輸出的特征。老師的輸出特征應(yīng)該是比較固定的而學(xué)生是不太固定的需要去學(xué)習(xí)的,于是通過一個損失函數(shù)去模擬、去減少、去學(xué)習(xí)是的兩個的 Logits 越小越好

2、Feature-Based Knowledge

深度神經(jīng)網(wǎng)絡(luò)善于學(xué)習(xí)到不同層級的表征,因此中間層和輸出層的都可以被用做知識來訓(xùn)練學(xué)生模型,中間層學(xué)習(xí)知識的Feature-Based Knowledge 對于Response-Based Knowledge 是一個很好的補充,其主要思想是將教師和學(xué)生的特征激活進行關(guān)聯(lián)起來,F(xiàn)eature-Based Knowledge 的知識轉(zhuǎn)移的蒸餾損失可表示為:

L_{FeaD} (f_{t} (x),f_{s} (x)) =L_{F} (Φ_{t} (f_{t} (x)),Φ_{s} (f_{s} (x)))

The generic feature-based knowledge distillation

通過學(xué)習(xí)流程圖可知:Distillation Loss 是建立在 Teacher Model 和 Student Model 的中間層,通過中間層去建立連接關(guān)系。

這種算法的好處 Teacher 網(wǎng)絡(luò)可以為 Student 網(wǎng)絡(luò)提供大量的、有用的參考信息。但如何有效的從教師模型中選擇提示層,從學(xué)生模型中選擇引導(dǎo)層,仍有待進一步研究。

缺點:由于提示層和引導(dǎo)層大小存在明顯差異,如何正確匹配教師和學(xué)生的特征也需要探討。

3、Relation-Based Knowledge

基于Feature-Based Knowledge 和Response-Based Knowledge 中都使用了教師模型中特定層中特征的輸出。基于關(guān)系的知識進一步探索了不同層或數(shù)據(jù)樣本之間的關(guān)系。一般情況下,基于特征圖關(guān)系的關(guān)系知識的蒸餾損失可以表示為:

L_{RelD}(f_{t} ,f_{s} ) =?L_{R^1 } (Ψ_{t} (\hat{f} _{t} ,\check{f} _{t} ), Ψ_{s} (\hat{f} _{s} ,\check{f} _{s} ))

The generic instance relation-based knowledge distillation

通過上圖可知:相同的數(shù)據(jù),有個 Teacher模型和 Student 模型,Distillation Loss 就不僅僅是學(xué)習(xí)網(wǎng)絡(luò)模型中間的特征還有最后一層的特征信息,它還會學(xué)習(xí)數(shù)據(jù)樣本和網(wǎng)絡(luò)模型層之間的關(guān)系

Knowledge Distillation: A Survey

知識蒸餾可以劃分為:1)Offline Distillation 2)Online Distillation 3)Self-Distillation

Different distillations. The red color for “pre-trained”means networks are learned before distillation and the yellowcolor for “to be trained” means networks are learned duringdistillation

紅色代表:預(yù)訓(xùn)練的模型

黃色代表:將要去訓(xùn)練的模型

1)Offline Distillation 通俗講:指知識淵博的教師向?qū)W生傳授知識

大多數(shù)蒸餾采用Offline Distillation,蒸餾過程被分為兩個階段

1.1、蒸餾前 Teacher 模型預(yù)訓(xùn)練?

1.2 、蒸餾算法遷移知識

因此Offline Distillation主要側(cè)重于知識遷移部分

通常采用單向知識轉(zhuǎn)移和兩階段訓(xùn)練過程。在步驟 1 中需要 Teacher 模型參數(shù)量比較大,訓(xùn)練時間比較長,這種方式對學(xué)生模型的蒸餾比較高效。

Tips:這種訓(xùn)練模式下的學(xué)生模型往往過度依賴于教師模型

2)Online Distillation 通俗講:指教師和學(xué)生共同學(xué)習(xí)知識

主要針對參數(shù)量大、精度性能好的教師模型不可獲得的情況。教師模型和學(xué)生模型同時更新,整個知識蒸餾算法是一種有效的端到端可訓(xùn)練方案(教師模型和學(xué)生模型一起去學(xué)習(xí))

Tips:現(xiàn)有的Online Distillation 往往難以獲得在線環(huán)境下參數(shù)量大、精度性能好的教師模型

3)Self-Distillation 通俗講:指學(xué)生自己學(xué)習(xí)知識

教師模型和學(xué)生模型使用相同的網(wǎng)絡(luò)結(jié)構(gòu)(自學(xué)習(xí)),同樣采用端到端可訓(xùn)練方案,屬于Online Distillation 的一種特例

知識蒸餾的過程

The specific architecture of the benchmark knowledge distillation

分成 5小步:

1 、把數(shù)據(jù)喂養(yǎng)到教師網(wǎng)絡(luò)去訓(xùn)練,通過升溫的 Softmax(T=t),得到 soft targets1

2 、把數(shù)據(jù)喂養(yǎng)到學(xué)生網(wǎng)絡(luò)去訓(xùn)練,通過升溫的 Softmax(T=t),得到 soft targets2(與步驟1是同溫的)

3、通過 1 、 2 兩步之后有兩個結(jié)果,對這兩個結(jié)果來運用一下,算一下就可以得到蒸餾損失 Distillation loss

4 、同樣把數(shù)據(jù)喂養(yǎng)到學(xué)生網(wǎng)絡(luò)去訓(xùn)練,通過正常的(未升溫) Softmax(T=1),得到 soft targets3

5、通過soft targets3和Ground Truth Label(正確標簽) ,這兩個值再計算一下,就可以的得到一個學(xué)生損失(Student loss)

整個過程會涉及兩個損失

蒸餾損失distillation loss 和 學(xué)生損失 student loss,這兩個損失

蒸餾損失

輸入:相同溫度下,學(xué)生模型和教師模型的 soft targets

常用:KL 散度?

作用:讓學(xué)生網(wǎng)絡(luò)的類別輸出預(yù)測分布盡可能擬合教師網(wǎng)絡(luò)輸出預(yù)測分布(通俗講:讓學(xué)生去學(xué)老師的一些行為)

學(xué)生損失

輸入:T=1 時,學(xué)生模型的 soft targets 和正確標簽

常用:交叉熵損失

作用:減少教師網(wǎng)絡(luò)中的錯誤信息被蒸餾到學(xué)生網(wǎng)絡(luò)中


蒸餾損失和學(xué)生損失,兩個損失函數(shù)是獨立的,怎么建立聯(lián)系?以及知識蒸餾整個過程的關(guān)鍵點有哪些?下一篇將詳細介紹經(jīng)典的知識蒸餾算法

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

推薦閱讀更多精彩內(nèi)容