本文是CIKM20上華為發(fā)表的一篇關(guān)于使用知識蒸餾來提升點(diǎn)擊率預(yù)估效果的論文,論文標(biāo)題是《Ensembled CTR Prediction via Knowledge Distillation》,下載地址為:https://dl.acm.org/doi/pdf/10.1145/3340531.3412704
1、背景
當(dāng)前對于點(diǎn)擊率預(yù)估的研究大致集中在兩方面,一種是嘗試更為復(fù)雜的網(wǎng)絡(luò)結(jié)構(gòu)來更好的捕捉特征之間的交叉信息以及用戶的動態(tài)行為信息,如引入卷積神經(jīng)網(wǎng)絡(luò)、循環(huán)神經(jīng)網(wǎng)絡(luò)、注意力機(jī)制和圖神經(jīng)網(wǎng)絡(luò)等;另一種趨勢是沿用Wide & Deep的思路,嘗試將多個子模塊進(jìn)行融合,如DeepFM、DCN、XDeepFM、AutoInt等。盡管這些研究帶來了點(diǎn)擊率預(yù)估效果的提升,但是隨著模型結(jié)構(gòu)的復(fù)雜,在實(shí)際工業(yè)界使用這些模型耗時會越來越高,往往難以真正在線上進(jìn)行部署。
那么如何既能保持多模型融合的效果,同時能夠使得模型更加輕量化呢?知識蒸餾的方式是一種不錯的選擇。
2、多教師網(wǎng)絡(luò)知識蒸餾
2.1 單教師網(wǎng)絡(luò)知識蒸餾
首先來看下單教師網(wǎng)絡(luò)知識蒸餾的框架:
可以看到,同樣的特征分別輸入到teacher網(wǎng)絡(luò)和student網(wǎng)絡(luò)中,得到各自的輸出,那么teacher網(wǎng)絡(luò)和student網(wǎng)絡(luò)的損失分別為:
teacher網(wǎng)絡(luò)的損失只有交叉熵?fù)p失,而student網(wǎng)絡(luò)的損失除包含交叉熵?fù)p失外,還包括一項(xiàng)蒸餾損失。蒸餾損失通常有兩種計(jì)算方式:soft label和hint regression。
soft label
soft label的計(jì)算公式如下:
這里為什么要使用soft label呢?在計(jì)算softmax之前,對所有值乘上一個大于1的數(shù),會起到sharp的作用,使得預(yù)測概率最高的那一類更加接近于1,而除以一個大于1的數(shù),則會起到soft的作用,使得類別的預(yù)測概率更加接近。將這樣的信息傳遞給student網(wǎng)絡(luò),可以提供額外的信息,例如下圖中假設(shè)soft label是預(yù)測為1的概率是0.7,7是0.2,9是0.1,那么student可以學(xué)到不同類別之間的隱藏關(guān)系,比如1和7可能是更接近的,1和9也是比較接近的。
上述圖片出自課程:https://www.bilibili.com/video/BV1SC4y1h7HB?p=7
hint regression
hint regression的目的是引導(dǎo)student網(wǎng)絡(luò)學(xué)習(xí)teacher網(wǎng)絡(luò)的中間層表示。這里VT代表選擇的teacher網(wǎng)絡(luò)的中間層表示,VS代表student網(wǎng)絡(luò)中被指導(dǎo)的層的輸出。通過矩陣W進(jìn)行變換,期望二者的距離越近越好,此時蒸餾損失表示為:
2.2 多教師網(wǎng)絡(luò)知識蒸餾
模型融合能夠有效提升CTR預(yù)估的效果,但會帶來耗時的增加。因此,可以通過知識蒸餾的方式,讓student網(wǎng)絡(luò)從多個模型中進(jìn)行學(xué)習(xí),來達(dá)到近似或比模型融合更佳的效果。因此,論文提出了多教師網(wǎng)絡(luò)知識蒸餾,其結(jié)構(gòu)如下圖所示:
這里的主要問題是,多teacher網(wǎng)絡(luò)如何向student網(wǎng)絡(luò)傳遞知識?最簡單的方式就是對所有teacher的輸出進(jìn)行平均。這種做法實(shí)現(xiàn)簡單,但是不同的teacher的模型結(jié)構(gòu)和訓(xùn)練框架都不同,所能夠提供的知識的重要程度也是不同的,如果有一個效果較差的teacher網(wǎng)絡(luò),可能會影響到student網(wǎng)絡(luò)的學(xué)習(xí)。因此可以對不同teacher網(wǎng)絡(luò)的知識進(jìn)行加權(quán):
權(quán)重的學(xué)習(xí)通過一個gate網(wǎng)絡(luò)得到,計(jì)算方式如下:
2.3 網(wǎng)絡(luò)訓(xùn)練
知識蒸餾一般有兩種訓(xùn)練方式,pre- train方式和co-train方式。pre- train方式是預(yù)先訓(xùn)練teacher網(wǎng)絡(luò),然后再訓(xùn)練student網(wǎng)絡(luò);co-train方式則是通過上述介紹的損失對teacher網(wǎng)絡(luò)和student網(wǎng)絡(luò)進(jìn)行聯(lián)合訓(xùn)練。co-train方式往往訓(xùn)練速度更快,但所需的GPU資源也會更多。后續(xù)實(shí)驗(yàn)部分也會看一下這兩種實(shí)驗(yàn)的效果對比。
3、實(shí)驗(yàn)結(jié)果
最后簡單看下實(shí)驗(yàn)結(jié)果部分。首先看下單教師網(wǎng)絡(luò)知識蒸餾的結(jié)果,可以看到,無論選擇DeepFM、DCN或是xDeepFM作為teacher網(wǎng)絡(luò),均是使用soft label和pre-train方式得到了最優(yōu)的效果:
再看下多教師網(wǎng)絡(luò)知識蒸餾的效果,其中3T(M)代表DeepFM、DCN和xDeepFM三個網(wǎng)絡(luò)作為teacher網(wǎng)絡(luò),6T(M)則是每種模型使用不同的隨機(jī)因子,訓(xùn)練兩遍。3T(D)代表使用同一模型,將同一數(shù)據(jù)集切分成3份訓(xùn)練集進(jìn)行訓(xùn)練,6T(D)則是數(shù)據(jù)集切分成6份進(jìn)行訓(xùn)練:
從上表可以得到如下的結(jié)論:
1)隨著1T->3T->6T,teacher網(wǎng)絡(luò)和student網(wǎng)絡(luò)的效果都是越來越好的
2)3T(D)/6T(D)的效果好于3T(M)/6T(M),這可能是由于選擇的Teacher是3種模型中最好的模型導(dǎo)致的
3)student網(wǎng)絡(luò)的效果反而比teacher網(wǎng)絡(luò)更好,一種解釋是,student網(wǎng)絡(luò)不僅學(xué)習(xí)了teacher網(wǎng)絡(luò)的經(jīng)驗(yàn),同時student網(wǎng)絡(luò)結(jié)構(gòu)相對簡單,保持了更好的泛化性能。
好了,本文就到這里了,論文對于知識蒸餾這一知識點(diǎn)的總結(jié)以及實(shí)驗(yàn)部分都是值得一看的。