論文 | NeurIPS2020 CrossTransformers:spatially-aware few-shot transfer

一 寫在前面

未經允許,不得轉載,謝謝~~~

嘿,好久不見,我要開始慢慢恢復科研論文筆記的更新啦~

今天分享的文章是做小樣本圖像識別的。

主要信息:

二 主要內容

2.1 相關背景

小樣本圖像識別的方法從整體上來看大概可以分成兩個階段:

  1. representation learning:獲取到一個比較好的圖像特征提取器;
  2. classifier:通過比對query images和support images進行query image的標簽預測

文章首先總結了現有方法的共同點:

  1. 在representation leanring的學習上具有的一個共同點就是都會使用訓練圖片的類別標簽做一個監督學習;
  2. 在classifier學習階段具有的一個共同點是會將query和support圖像之間的整體特征進行比較,例如ProtoNet就是將query的特征與support set中每個類中心的特征進行比較。

2.2 本文工作

文章首先支持現有方法的不足:

  1. 完全依靠類別標簽進行特征學習的方式會導致只能學習到跟類別相關的信息,而忽略其他更加通用的特征表示;
  2. 在做圖像比較的時候,圖像中的一些重要objects和scenes通常是local的,直接用整體特征進行比較的效果不一定是最好的;

相對應地,文章提出了從兩個方面進行優化:

  1. 針對第一個問題,提出引入自監督學習的方法SimCLR來獲取更加通用的圖像特征表示;
  2. 針對第二個問題,提出基于Transformer的新結構CrossTransformer,希望能夠進行local信息的圖像匹配;

三 方法介紹

文章是基于ProtoNet結構的,所以首先介紹下ProtoNet, 然后分別介紹以上兩點novelty。

3.1 ProtoNet

ProtoNet算的上是小樣本圖像識別領域最flagship的工作了,這里只做個簡單的介紹。

N-way-K-shot
給定一堆帶標簽可供參考的Support Images,具體表示為有N個類別,每個類別有K張帶標注的圖像,以及一個等待被分類的query image (query image的類別一定屬于N個類別),我們需要根據support images預測出query image的類別標簽。

key idea:
Protonet的想法非常直接但有效。即對每張圖像都先用神經網絡得到一個特征表示,然后對support set中每個類別c的所有特征取一個平均,作為這個類別的類中心。最后比較query feature跟各個類中心之間的距離,取最近的一個類別作為預測結果。

3.2 SSL with SimCLR

這里的想法也比較直接,就是覺得自監督學習得到的特征表示不僅對semantic敏感,而且對屬于相同類別的不同圖片也具有區分度,可以理解為只用class informaction進行監督學習得到的特征是class-level的,SSL學習到的是instance-level的,因此作者認為SSL學習到的特征泛化性會更好。

具體的做法也比較簡單。為了區分原來的episode和現在用自監督的episode, 分別用MD-categorization episode以及SimCLR episode來表示它們。在訓練的過程中隨機轉化50%的MD-categorization episode為SimCLR episode, 對SimCLR episode用SimCLR中的方法進行增強,然后對query image也進行增強,最后用各自對應的loss function進行優化。

:( 這邊的具體細節感覺只看文章還不是特別清楚,可能需要感興趣的同學可以自己看看他們的code

3.3 CrossTransformers

這部分都是基于Transformer構建的,如果之前完全不了解的話或許是會比較困難的,建議看看原文:https://arxiv.org/abs/1706.03762, 或者推薦一個我個人最推薦的blog:https://zhuanlan.zhihu.com/p/48508221

文章的主要框架圖如下圖所示:

第一張是文章原圖,第二張是我在原圖的基礎上把各個重要部分對應的數據維度標注上去以及補充了額外內容的圖,可以對照著看。

文章原圖
帶標記圖

主要的pipeline包括以下幾步:

  1. 首先看輸入,給定最左邊的一個query image x_q, 以及最上面的support set中類別為c的幾個圖像{x_1^c, x_2^c, ...}, 網絡的目的是要獲取到一個query-specific的類中心(不再是原始ProtoNet版本中直接取平均的方法)
  2. 首先注意到不管是對于query還是對于support images,都是先用一個\phi()得到圖像的特征表示,這里文章中用的是ResNet,并且去掉了最后一個pooling層,所以得到的特征維度為R^{H`^ \times W^` \times D}
  3. 接下來就是基于query,key,value的attention操作。這里的query是指query image,而key和value都是指support sets。理解這一點對理解整個attention還挺重要的。

網絡圖中的query heads,key heads都是將輸入特征從D維度映射到d_k維度,而value heads將輸入特征從D維度映射到d_v維度。

具體地,(建議對著圖看)

  • query heads將query特征從R^{H^` \times W^` \times D}維度映射到R^{H^` \times W^` \times d_k}維度(圖中shi黃色的框框);
  • key heads將support特征從R^{H^` \times W^` \times D}維度映射到R^{H^` \times W^` \times d_k}維度(圖中亮黃色的框框,左右兩個表示的是一樣的意思,看第一個就行了);
  • value heads將support特征從R^{H^` \times W^` \times D}映射到R^{H^` \times W^` \times d_v}維度(圖中紅色框框,也看其中一個就可)。
  1. 然后就是計算query和key之間的attention,我們還是只看一個query(shi黃色框)和一個support圖像特征(第一個亮黃色框框),經過映射之后兩個的特征維度都是R^{H^` \times W^` \times d_k},對于query中任意一個位置p和support中的任意一個位置m,特征維度都是d_k, 通過向量點乘的方法可以得到這2個點之間的attention值,圖中小黑點在的位置。對每個HxW中的點都計算一次attention,最終就會得到一張query和一張support的attention mapa_1^c, 當然還做了一個softmax操作得到更新后的attention map\tilde{a_1^c}。對suppport中的多張圖采取同樣的操作就會得到多張attention map。

  2. 最后就是利用這些attention maps對support set中不同圖像的vaule特征進行加權平均。這部分操作可以理解為,對于<query, support image i>, 對于HxW中的任意一個位置,都用其第i張attention map的值乘上對應第i個紅色框框位置的value,最后把不同support images的結果值進行相加得到最終query-aligned prototype的特征表示,其維度為R^{H^` \times W^` \times d_v}

  3. 到這里為止我們獲取到了query-aligned prototype R^{H^` \times W^` \times d_v}。 但是要做小樣本預測到這里還沒有完全完整,我把第二張圖中把剩下的部分補上了。對于query image,其實也用value head做了一個映射,得到一個query image的value 特征表示,其維度為R^{H^` \times W^` \times d_v}, 跟prototype的維度是一樣的,這樣就可以比較這兩者之間的距離,進而進行label預測了.

五 寫在最后

我在寫這個blog的時候,盡量避免了公式的出現,但可能有些地方解釋的還是有些不好理解,尤其是crossTransformer部分涉及的符號略多,大家見諒啦。

這篇文章暫時介紹到這里,最后打個不那么相關的廣告,我們做小樣本視頻分類的工作(AMeFu-Net)近期開源了,link: https://github.com/lovelyqian/AMeFu-Net,歡迎大家關注~

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。