https://arxiv.org/abs/2003.10580v4
我們提出了元偽標簽,這是一種半監督學習方法,在ImageNet上實現了90.2%的最新top-1準確率,比現有的最新水平提高了1.6%。與偽標簽一樣,元偽標簽有一個教師網絡,用于在未標記的數據上生成偽標簽,以教授學生網絡。然而,與教師固定的偽標簽不同,元偽標簽中的教師不斷地根據學生在標簽數據集上的表現反饋進行調整。因此,教師可以生成更好的偽標簽來教學生。代碼在https://github.com/google-research/google-research/tree/master/meta_pseudo_labels
1.導言
偽標簽或自訓練方法[57、81、55、36]已成功應用于改進許多計算機視覺任務中的最新模型,如圖像分類(如[79、77])、目標檢測和語義分割(如[89、51])。偽標簽方法通過一對網絡工作,一個作為教師,一個作為學生。教師在未標記的圖像上生成偽標簽。然后將這些偽標記圖像與標記圖像結合起來訓練學生。由于大量的偽標記數據以及數據增強等正則化方法的使用,學生學會了比老師更好[77]。
盡管偽標簽方法的性能很好,但它們有一個主要缺點:如果偽標簽不準確,學生將從不準確的數據中學習。因此,學生可能不會比老師取得顯著的進步。這個缺點也被稱為偽標記中的確認偏差問題[2]。
在本文中,我們設計了一個系統的機制,讓教師通過觀察其偽標簽對學生的影響來糾正偏見。具體來說,我們提出了元偽標簽,它利用學生的反饋來通知教師生成更好的偽標簽。在我們的實現中,反饋信號是學生在標記數據集上的表現。此反饋信號用作獎勵,在學生學習的整個過程中培訓教師。總之,元偽標簽的教師和學生是并行培訓的:(1)學生從教師注釋的偽標簽數據的小批量中學習,(2)教師從獎勵信號中學習學生在從標簽數據集中提取的小批量中的表現。
我們使用元偽標簽進行實驗,使用ImageNet[56]數據集作為標記數據,使用JFT-300M數據集[26,60]作為未標記數據。我們使用元偽標簽訓練了一對高效的網絡,一個是教師,一個是學生。由此產生的學生網絡在ImageNet ILSVRC 2012驗證集[56]上達到了90.2%的最高精度,比之前88.6%的記錄[16]高出1.6%。該學生模型還推廣到ImageNet ReaL測試集[6],如表1所示。在CIFAR10-4K、SVHN-1K和,ImageNet-10%還表明,元偽標簽的性能優于最近提出的一系列其他方法,如FixMatch[58]和無監督數據增強[76]。
2.元偽標簽
圖1概述了偽標簽和元偽標簽之間的對比。主要區別在于,在元偽標簽中,教師在標記的數據集上收到學生表現的反饋。
3.小規模實驗
在本節中,我們將介紹我們在小尺度下對元偽標簽的實證研究。我們首先研究了反饋在簡單TwoMoon數據集上的元偽標簽中的作用[7]。這項研究直觀地說明了元偽標簽的行為和好處。然后,我們在標準基準(如CIFAR-10-4K、SVHN-1K和ImageNet-10%)上將元偽標簽與最先進的半監督學習方法進行比較。我們通過在標準ResNet-50體系結構上使用完整的ImageNet數據集進行實驗來結束本節。
3.1. 雙月實驗
為了理解反饋在元偽標簽中的作用,我們在簡單和經典的TwoMoon數據集上進行了一個實驗[7]。TwoMoon數據集的2D特性使我們能夠可視化元偽標簽相對于監督學習和偽標簽的行為。
數據集。
在這個實驗中,我們生成了我們自己版本的TwoMoon數據集。在我們的版本中,有2000個示例形成兩個集群,每個集群有1000個示例。僅標記了6個示例,每個集群3個示例,而其余示例未標記。要求半監督學習算法使用這6個標記示例和聚類假設將兩個聚類劃分為正確的類。
培訓細節。
我們的模型結構是一個前饋全連接神經網絡,有兩個隱藏層,每個層有8個單元。在每一層上都使用了sigmoid非線性。在元偽標簽中,教師和學生都有這種結構,但有獨立的權重。所有網絡都使用SGD進行訓練,使用0.1的恒定學習率。網絡的權值初始化為-0.1和0.1之間的均勻分布。我們不應用任何正則化。
結果。
我們隨機生成TwoMoon數據集幾次,并重復三種方法:監督學習、偽標簽和元偽標簽。我們觀察到,元偽標簽比監督學習和偽標簽具有更高的找到正確分類器的成功率。圖2展示了我們實驗的典型結果,其中紅色和綠色區域對應于分類器的決策。從圖中可以看出,監督學習發現了一個錯誤的分類器,該分類器對標記的實例進行了正確分類,但未能利用聚類假設來分離兩個“衛星”。偽標簽使用監督學習中的壞分類器,因此在未標記的數據上接收不正確的偽標簽。因此,偽標簽會找到一個分類器,該分類器會對一半的數據(包括一些標記的實例)進行錯誤分類。另一方面,元偽標簽則使用學生模型在標記實例上丟失的反饋來調整教師以生成更好的偽標簽。因此,元偽標簽為這個數據集找到了一個好的分類器。換句話說,在本實驗中,元偽標簽可以解決偽標簽的確認偏差問題[2]。
3.2. CIFAR-10-4K、SVHN-1K和ImageNet-10%實驗
數據集。
我們考慮了三個標準基準:CIOFE-10-4K、SvHN-1K和IMANETET-10%,它們在文獻中被廣泛使用,以相當基準半監督學習算法。這些基準是通過將訓練集的一小部分保留為標記數據,而將其余部分用作未標記數據而創建的。對于CIFAR-10[34],4000個標記的示例保留為標記數據,而41000個示例用作未標記數據。CIFAR-10的測試集是標準的,由10000個示例組成。對于SVHN[46],1000個示例用作標記數據,而大約603000個示例用作未標記數據。SVHN的測試集也是標準的,有26032個示例。最后,對于ImageNet[56],128000個示例用作標記數據,約占整個ImageNet訓練集的10%,而剩余的128萬個示例用作未標記數據。ImageNet的測試集是具有50000個示例的標準ILSVRC 2012版本。對于CIFAR-10和SVHN,我們使用32x32的圖像分辨率,對于ImageNet,我們使用224x224的圖像分辨率。
培訓細節。
在我們的實驗中,我們的老師和學生共享相同的架構,但有獨立的權重。對于CIFAR-10-4K和SVHN-1K,我們使用了WideResNet-28-2[84],它有145萬個參數。對于ImageNet,我們使用一個ResNet-50[24],它有2550萬個參數。這些體系結構也被該領域以前的工作所普遍使用。在培訓教師和學生的元偽標簽培訓階段,我們對所有模型使用先前工作中的默認超參數,但RandAugment[13]中的一些修改除外,我們在附錄C.2中對此進行了詳細說明。附錄C.4中報告了所有超參數。在使用元偽標簽對教師和學生進行培訓之后,我們在標記的數據集上對學生進行微調。對于這個微調階段,我們使用固定學習率為10的SGD?5,批量大小為512,針對ImageNet-10%運行2000個步驟,針對CIFAR-10和SVHN運行1000個步驟。由于所有三個數據集的標記示例數量有限,因此我們不使用任何heldout驗證集。相反,我們在最后一個檢查點返回模型。
基線。
為了確保公平比較,我們僅將元偽標簽與使用相同體系結構的方法進行比較,而不與使用更大體系結構的方法進行比較,例如,對于CIFAR-10和SVHN[5,4,72,76],使用更大的體系結構的方法,如Biger-WideResNet-28-2和PyramidNet+ShakeDrop;對于ImageNet-10%[25, 23, 10, 8, 9]. 我們也不會將元偽標簽與培訓程序進行比較,培訓程序包括自我蒸餾或從更大的老師那里蒸餾[8,9]。我們在基線上實施這些限制,因為眾所周知,更大的體系結構和蒸餾可以改進任何方法,可能包括元偽標簽。
我們直接將元偽標簽與兩個基線進行比較:全數據集監督學習和無監督數據增強(UDA[76])。完整數據集的監督學習代表了凈空,因為它不公平地使用了所有標記的數據(例如,對于CIFAR10,它使用了所有50000個標記的示例)。我們還與UDA進行比較,因為我們的元偽標簽實現在培訓教師時使用了UDA。這兩個基線使用相同的實驗協議,因此確保公平比較。我們遵循[48]的train/eval/test拆分,并使用相同數量的資源來調整基線和元偽標簽的超參數。更多詳情見附錄C。
其他基線。
除了這兩個基線外,我們還將一系列其他半監督基線分為兩類:標簽傳播和自監督。由于這些方法不共享相同的受控環境,因此與它們的比較不是直接的,應該按照[48]的建議進行語境化。比較元偽標簽和其他基線的更多受控實驗見附錄D。
結果。
表2顯示了我們使用元偽標簽與其他方法進行比較的結果。結果表明,在嚴格公平比較的情況下(如[48]所述),元偽標簽顯著優于UDA。有趣的是,在CIFAR-10-4K上,元偽標簽甚至超過了整個數據集上的凈空監督學習。在ImageNet-10%上,元偽標簽在前1名的準確率方面比UDA教師高出5%以上,從68.07%提高到73.89%。對于ImageNet來說,這樣的相對改進非常重要。
與現有最先進的方法相比。
與以往文獻報道的結果相比,元偽標簽在所有三個數據集(CIFAR-10-4K、SVHN-1K和ImageNet-10%)的相同模型體系結構中取得了最好的精度。在CIFAR-10-4K和SVHN-1K上,與最高報告基線相比,元偽標簽導致了近10%的相對誤差降低[58]。在ImageNet-10%上,元偽標簽比SimCLR[8,9]的精度高出2.19%。
雖然在這些數據集上存在更好的結果,但據我們所知,這些結果都是通過更大的模型、更強的正則化技術或額外的蒸餾程序獲得的。例如,CIFAR10-4K的最佳報告精度為97.3%[76],但該精度是通過一個金字塔網實現的,該金字塔網的參數比我們的WideResNet-28-2多17倍,并使用復雜的振動降正則化[80]。另一方面,通過SimCLRv2[9]使用自蒸餾訓練階段和ResNet-152×3(其參數比我們的ResNet-50多32倍),ImageNet-10%的最高報告精度為80.9%,蒸餾也可以應用于元偽標記,以進一步改進我們的結果。
3.3. ResNet-50實驗
先前的實驗表明,在CIFAR-10-4K、SVHN-1K和ImageNet-10%上,元偽標簽優于其他半監督學習方法。在本實驗中,我們對整個ImageNet數據集上的元偽標簽以及來自JFT數據集的未標記圖像進行基準測試。本實驗的目的是在我們對EfficientNet進行更大規模的實驗之前,驗證元偽標簽在廣泛使用的ResNet-50體系結構[24]上是否工作良好(第4節)。
數據集。
如前所述,我們使用來自ImageNet數據集的所有標記示例進行實驗。我們從ImageNet數據集中保留25000個示例,用于超參數調整和模型選擇。我們的測試集是ILSVRC 2012驗證集。此外,我們從JFT數據集中獲取了1280萬張未標記的圖像。為了獲得這1280萬張未標記的圖像,我們首先在整個ImageNet訓練集上訓練一個ResNet-50,然后使用得到的ResNet-50為JFT數據集中的圖像分配類別概率。然后,我們為1000類ImageNet中的每一類選擇12800張概率最高的圖像。這一選擇產生了1280萬張圖像。我們還確保使用的1280萬張圖像中沒有一張與ImageNet的ILSVRC 2012驗證集重疊。UDA[76]和嘈雜的學生[77]使用了過濾額外未標記數據的程序。
實施細節。
我們實現了與第3.2節相同的元偽標簽,但我們使用了更大的批量和更多的訓練步驟,因為本實驗的數據集要大得多。具體來說,對于學生和教師,我們對標記圖像使用4096的批量大小,對未標記圖像使用32768的批量大小。我們在未標記的數據集上訓練500000個步驟,相當于大約160個紀元。在ImageNet+JFT上訓練元偽標簽階段后,我們在ImageNet上微調生成的學生10000 SGD步驟,使用10的固定學習率?4.使用512 TPUv2核,我們的培訓過程大約需要2天。
基線。
我們將元偽標簽與兩組基線進行比較。第一組包含有監督學習方法和數據增強或正則化方法,如AutoAugment[12]、DropBlock[18]和CutMix[83]。這些基線代表了ResNet-50上最先進的監督學習方法。第二組基線包括三種最新的半監督學習方法,它們利用來自ImageNet的標記訓練圖像和其他地方的未標記圖像。具體而言,十億規模的半監督學習[79]使用來自YFCC100數據集的未標記數據[65],而UDA[76]和嘈雜學生[77]都使用JFT作為未標記數據,如元偽標簽。與第3.2節類似,我們僅將Meta偽標簽與使用ResNet-50和未經蒸餾獲得的結果進行比較。
結果。
表3給出了結果。從表中可以看出,元偽標簽將ResNet-50的頂級精度從76.9%提高到83.2%,這對于ImageNet來說是一個很大的改進空間,優于UDA和Noised Student。元偽標簽在頂級精度方面也優于十億規模的SSL[68,79]。這尤其令人印象深刻,因為數十億規模的SSL在Instagram的弱監督圖像上預先訓練他們的ResNet-50。
4.大規模實驗:突破ImageNet精度極限
在本節中,我們將擴展元偽標簽,以便在大型模型和大型數據集上進行訓練,以提高ImageNet精度的極限。具體來說,我們使用EfficientNet-L2體系結構,因為它比Resnet具有更高的容量。EfficientNet-L2也被嘈雜的學生使用[77],在ImageNet上實現了88.4%的頂級精度。
數據集。
在本實驗中,我們使用整個ImageNet訓練集作為標記數據,并使用JFT數據集作為未標記數據。JFT數據集有3億張圖像,然后由有噪聲的學生使用置信閾值和上采樣將圖像過濾到1.3億張[77]。我們使用的是和吵鬧的學生一樣的1.3億張圖片。
模型架構。
我們使用EfficientNetL2進行實驗,因為它在ImageNet[77]上具有最先進的性能,沒有額外的標記數據。除了使用512x512而不是475x475的訓練圖像分辨率外,我們對有噪聲的學生使用相同的超參數。我們增加了輸入圖像的分辨率,以與我們在下一段中討論的模型并行實現兼容。除了EfficientNet-L2之外,我們還使用了一個較小的模型進行實驗,該模型的深度與EfficientNet-B6[63]相同,但寬度因子從2.1增加到了5.0。該模型稱為EfficientNet-B6-Wide,具有3.9億個參數。對于EfficientNet-B6-Wide,我們采用EfficientNet-L2的所有超參數。我們發現EfficientNet-B6-Wide的性能幾乎與EfficientNet-L2相同,但編譯和訓練速度更快。
模型并行性。
由于我們網絡的內存占用,為教師和學生保留兩個這樣的網絡內存將大大超過我們加速器的可用內存。因此,我們設計了一個混合的modeldata并行框架來運行元偽標簽。具體而言,我們的培訓過程在2048個TPUv3核心集群上運行。我們將這些核心劃分為128個相同的副本,以標準的數據并行性和同步的梯度運行。在2048/128=16核上運行的每個副本中,我們實現了兩種類型的模型并行性。首先,將每個分辨率為512x512的輸入圖像沿寬度維度分割為16塊大小相等的512x32塊,并分配到16個核進行處理。請注意,我們選擇輸入分辨率為512x512,因為512與Noised Student使用的分辨率475x475接近,并且512將網絡中間輸出的維度保持為16整除。其次,每個權重張量也被平均分割為16個部分,分配給16個核。我們在XLA分片框架中實現了我們的混合數據模型并行性[37]。通過這種并行性,我們可以將2048個標記圖像和16384個未標記圖像的批量大小放入每個訓練步驟中。我們總共對模型進行100萬步的訓練,對于EfficientNet-L2大約需要11天,對于EfficientNet-B6-Wide大約需要10天。在完成元偽標簽訓練階段后,我們對標記數據集上的模型進行了20000個步驟的微調。詳細的微調程序見附錄C.4。
結果。
我們的結果如表4所示。從表中可以看出,元偽標簽在ImageNet上達到了90.2%的top-1精度,這是該數據集的最新技術。這一結果比使用嘈雜的學生[77]和FixRes[69,70]訓練的相同效率的ET-L2體系結構要好1.8%。元偽標簽的性能也優于BiT-L[33]的最新結果和Vision Transformer[14]的先前狀態。這里重要的對比是,Bit-L和Vision TRANSFORM都對來自JFT的3億個標記圖像進行預訓練,而我們的方法只使用來自該數據集的未標記圖像。在這種精度水平下,與最近的收益相比,我們在[16]上的收益為1.6%,這是一個非常顯著的改進幅度。例如,視覺轉換器[14]在嘈雜的學生+定影器上的增益僅為0.05%,而定影器在嘈雜的學生上的增益僅為0.1%。
最后,為了驗證我們的模型并不是簡單地過度適合ImageNet ILSVRC 2012驗證集,我們在ImageNet ReaL測試集上對其進行了測試[6]。在這個測試集上,我們的模型也運行良好,達到91.02%Precision@1這比視覺變壓器[14]好0.4%。這一差距也比視覺轉換器和吵鬧學生之間的差距大,后者僅為0.17%。
元偽標簽的精簡版本。
考慮到元偽標簽昂貴的訓練成本,我們設計了一個精簡版的元偽標簽,稱為精簡元偽標簽。我們在附錄E中描述了此lite版本,在附錄E中,我們使用EfficientNet-B7在ImageNet ILSRVC 2012驗證集上實現了86.9%的top-1精度。為了避免使用JFT等專有數據,我們使用ImageNet訓練集作為標記數據,使用YFCC100M數據集[65]作為未標記數據。減少元偽標簽使我們能夠實現元偽標簽的反饋機制,同時避免在內存中保留兩個網絡。
5.有關工程
偽標簽。
偽標簽方法也稱為自訓練,是一種簡單的半監督學習(SSL)方法,已成功應用于改善許多任務的最新技術,如:圖像分類[79,77]、對象檢測、語義分割[89]、機器翻譯[22]和語音識別[31,49]。傳統的偽標簽方法在學生學習過程中讓經過預培訓的教師保持固定,當偽標簽不準確時,會導致確認偏差[2]。與普通的偽標簽不同,元偽標簽繼續調整教師,以提高學生在標記數據集上的表現。這種額外的調整允許教師生成更好的偽標簽來教學生,如我們的實驗所示。
其他SSL方法。
其他典型的SSL方法通常通過優化目標函數來訓練單個模型,該目標函數結合了標記數據的監督損失和未標記數據的無監督損失。監督損失通常是在標記數據上計算的交叉熵。同時,無監督損失通常是自監督損失或標簽傳播損失。自我監督損失通常會鼓勵模型建立關于圖像的常識,例如繪畫[50],解決拼圖[47],預測旋轉角度[19],對比預測[25,10,8,9,38],或引導潛在空間[21]。另一方面,標簽傳播損失通常強制要求模型對數據的某些轉換保持不變,例如數據增強、對抗性攻擊或潛在空間中的接近[35、64、44、5、76、30、71、58、32、51、20]。元偽標簽與前面提到的SSL方法有兩個顯著的區別。首先,使用元偽標簽的學生從不直接從標簽數據學習,這有助于避免過度擬合,特別是在標簽數據有限的情況下。其次,元偽標簽中的教師從學生在標簽數據上的表現中接收到的信號是利用標簽數據的一種新方法。
知識提煉和標簽平滑。
元偽標簽中的教師使用其對未標記數據的softmax預測來教導學生。這些softmax預測通常被稱為軟標簽,已在知識提煉文獻中廣泛使用[26、17、86]。在蒸餾工作范圍之外,人工設計的軟標簽(如標簽平滑[45]和溫度銳化或阻尼[76,77])也被證明可以提高模型的泛化能力。這兩種方法都可以看作是調整訓練示例的標簽,以改進優化和泛化。與其他SSL方法類似,這些調整不會收到本文中提出的學生表現的任何反饋。附錄D.2中給出了比較元偽標簽和標簽平滑的實驗。
雙層優化算法。
我們在方法名稱中使用元,因為我們從學生反饋中得出教師更新規則的技術是基于元學習文獻中經常出現的雙層優化問題。已經提出了類似的雙層優化問題來優化模型的學習過程,例如學習學習速率計劃[3],設計架構[40],糾正錯誤的訓練標簽[88],生成訓練示例[59],以及重新加權訓練數據[73,74,54,53]。元偽標簽在這項工作中使用相同的雙層優化技術,從學生的反饋中得出教師的梯度。元偽標簽與這些方法的區別在于,元偽標簽采用雙層優化技術來改進教師模型生成的偽標簽。
6.結論
在本文中,我們提出了半監督學習的元偽標簽方法。元偽標簽的關鍵是教師從學生的反饋中學習,以最有助于學生學習的方式生成偽標簽。元偽標簽中的學習過程包括兩個主要更新:基于教師生成的偽標簽數據更新學生和基于學生表現更新教師。在標準低資源基準(如CIFAR-10-4K、SVHN-1K和ImageNet-10%)上的實驗表明,元偽標簽優于許多現有的半監督學習方法。Meta偽標簽也可以很好地擴展到大型問題,在ImageNet上達到90.2%的top-1精度,這比以前的最先進技術[16]要好1.6%。一致的收益證實了學生對教師反饋的好處。
致謝
作者希望感謝Rohan Anil、Frank Chen和Wang Tao在運行我們的實驗過程中對許多技術問題的幫助。我們還感謝David Berthelot、Nicholas Carlini、Sylvain Gelly、Geoff Hinton、Mohammad Norouzi和Colin Raffel對論文早期草稿的評論,以及Google Brain團隊中的其他人在整個漫長項目中的支持。Jaime Carbonell還建議我們消除Resnet模型ImageNet的數據加載瓶頸。當我們沒有足夠的備用TPU用于我們的ResNet工作時,他的建議幫助很大。他將被深深記住。