【GiantPandaCV引言】 知識回顧(KR)發現學生網絡深層可以通過利用教師網絡淺層特征進行學習,基于此提出了回顧機制,包括ABF和HCL兩個模塊,可以在很多分類任務上得到一致性的提升。
摘要
知識蒸餾通過將知識從教師網絡傳遞到學生網絡,但是之前的方法主要關注提出特征變換和實施相同層的特征。
知識回顧Knowledge Review選擇研究教師與學生網絡之間不同層之間的路徑鏈接。
簡單來說就是研究教師網絡向學生網絡傳遞知識的鏈接方式。
代碼在:https://github.com/Jia-Research-Lab/ReviewKD
KD簡單回顧
KD最初的蒸餾對象是logits層,也即最經典的Hinton的那篇Knowledge Distillation,讓學生網絡和教師網絡的logits KL散度盡可能小。
隨后FitNets出現開始蒸餾中間層,一般通過使用MSE Loss讓學生網絡和教師網絡特征圖盡可能接近。
Attention Transfer進一步發展了FitNets,提出使用注意力圖來作為引導知識的傳遞。
PKT(Probabilistic knowledge transfer for deep representation learning)將知識作為概率分布進行建模。
Contrastive representation Distillation(CRD)引入對比學習來進行知識遷移。
以上方法主要關注于知識遷移的形式以及選擇不同的loss function,但KR關注于如何選擇教師網絡和學生網絡的鏈接,一下圖為例:
(a-c)都是傳統的知識蒸餾方法,通常都是相同層的信息進行引導,(d)代表KR的蒸餾方式,可以使用教師網絡淺層特征來作為學生網絡深層特征的監督,并發現學生網絡深層特征可以從教師網絡的淺層學習到知識。
教師網絡淺層到深層分別對應的知識抽象程度不斷提高,學習難度也進行了提升,所以學生網絡如果能在初期學習到教師網絡淺層的知識會對整體有幫助。
KR認為淺層的知識可以作為舊知識,并進行不斷回顧,溫故知新。如何從教師網絡中提取多尺度信息是本文待解決的關鍵:
提出了Attention based fusion(ABF) 進行特征fusion
提出了Hierarchical context loss(HCL) 增強模型的學習能力。
Knowledge Review
形式化描述
X是輸入圖像,S代表學生網絡,其中代表學生網絡各個層的組成。
Ys代表X經過整個網絡以后的輸出。代表各個層中間層輸出。
那么單層知識蒸餾可以表示為:
M代表一個轉換,從而讓Fs和Ft的特征圖相匹配。D代表衡量兩者分布的距離函數。
同理多層知識蒸餾表示為:
以上公式是學生和教師網絡層層對應,那么單層KR表示方式為:
與之前不同的是,這里計算的是從j=1 to i 代表第i層學生網絡的學習需要用到從第1到i層所有知識。
同理,多層的KR表示為:
Fusion方式設計
已經確定了KR的形式,即學生每一層回顧教師網絡的所有靠前的層,那么最簡單的方法是:
直接縮放學生網絡最后一層feature,讓其形狀和教師網絡進行匹配,這樣可以簡單使用一個卷積層配合插值層完成形狀的匹配過程。這種方式是讓學生網絡更接近教師網絡。
這張圖表示擴展了學生網絡所有層對應的處理方式,也即按照第一張圖的處理方式進行形狀匹配。
這種處理方式可能并不是最優的,因為會導致stage之間出現巨大的差異性,同時處理過程也非常復雜,帶來了額外的計算代價。
為了讓整個過程更加可行,提出了Attention based fusion , 這樣整體蒸餾變為:
如果引入了fusion的模塊,那整體流程就變為下圖所示:
但是為了更高的效率,再對其進行改進:
可以發現,這個過程將fusion的中間結果進行了利用,即, 這樣循環從后往前進行迭代,就可以得到最終的loss。
具體來說,ABF的設計如下(a)所示,采用了注意力機制融合特征,具體來說中間的1x1 conv對兩個level的feature提取綜合空間注意力特征圖,然后再進行特征重標定,可以看做SKNet的空間注意力版本。
而HCL Hierarchical context loss 這里對分別來自于學生網絡和教師網絡的特征進行了空間池化金字塔的處理,L2 距離用于衡量兩者之間的距離。
KR認為這種方式可以捕獲不同level的語義信息,可以在不同的抽象等級提取信息。
實驗
實驗部分主要關注消融實驗:
第一個是使用不同stage的結果:
藍色的值代表比baseline 69.1更好,紅色代表要比baseline更差。通過上述結果可以發現使用教師網絡淺層知識來監督學生網絡深層知識是有效的。
第二個是各個模塊的作用:
源碼
主要關注ABF, HCL的實現:
ABF實現:
class ABF(nn.Module):
def __init__(self, in_channel, mid_channel, out_channel, fuse):
super(ABF, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channel, mid_channel, kernel_size=1, bias=False),
nn.BatchNorm2d(mid_channel),
)
self.conv2 = nn.Sequential(
nn.Conv2d(mid_channel, out_channel,kernel_size=3,stride=1,padding=1,bias=False),
nn.BatchNorm2d(out_channel),
)
if fuse:
self.att_conv = nn.Sequential(
nn.Conv2d(mid_channel*2, 2, kernel_size=1),
nn.Sigmoid(),
)
else:
self.att_conv = None
nn.init.kaiming_uniform_(self.conv1[0].weight, a=1) # pyre-ignore
nn.init.kaiming_uniform_(self.conv2[0].weight, a=1) # pyre-ignore
def forward(self, x, y=None, shape=None, out_shape=None):
n,_,h,w = x.shape
# transform student features
x = self.conv1(x)
if self.att_conv is not None:
# upsample residual features
y = F.interpolate(y, (shape,shape), mode="nearest")
# fusion
z = torch.cat([x, y], dim=1)
z = self.att_conv(z)
x = (x * z[:,0].view(n,1,h,w) + y * z[:,1].view(n,1,h,w))
# output
if x.shape[-1] != out_shape:
x = F.interpolate(x, (out_shape, out_shape), mode="nearest")
y = self.conv2(x)
return y, x
HCL實現:
def hcl(fstudent, fteacher):
# 兩個都是list,存各個stage對象
loss_all = 0.0
for fs, ft in zip(fstudent, fteacher):
n,c,h,w = fs.shape
loss = F.mse_loss(fs, ft, reduction='mean')
cnt = 1.0
tot = 1.0
for l in [4,2,1]:
if l >=h:
continue
tmpfs = F.adaptive_avg_pool2d(fs, (l,l))
tmpft = F.adaptive_avg_pool2d(ft, (l,l))
cnt /= 2.0
loss += F.mse_loss(tmpfs, tmpft, reduction='mean') * cnt
tot += cnt
loss = loss / tot
loss_all = loss_all + loss
return loss_all
ReviewKD實現:
class ReviewKD(nn.Module):
def __init__(
self, student, in_channels, out_channels, shapes, out_shapes,
):
super(ReviewKD, self).__init__()
self.student = student
self.shapes = shapes
self.out_shapes = shapes if out_shapes is None else out_shapes
abfs = nn.ModuleList()
mid_channel = min(512, in_channels[-1])
for idx, in_channel in enumerate(in_channels):
abfs.append(ABF(in_channel, mid_channel, out_channels[idx], idx < len(in_channels)-1))
self.abfs = abfs[::-1]
self.to('cuda')
def forward(self, x):
student_features = self.student(x,is_feat=True)
logit = student_features[1]
x = student_features[0][::-1]
results = []
out_features, res_features = self.abfs[0](x[0], out_shape=self.out_shapes[0])
results.append(out_features)
for features, abf, shape, out_shape in zip(x[1:], self.abfs[1:], self.shapes[1:], self.out_shapes[1:]):
out_features, res_features = abf(features, res_features, shape, out_shape)
results.insert(0, out_features)
return results, logit
參考
https://zhuanlan.zhihu.com/p/363994781