來源:https://www.researchgate.net/publication/364419868_The_Devil_in_Linear_Transformer
代碼:https://github.com/OpenNLPLab/Transnormer
這篇文章的目的是優化線性transformer,線性transformer相對于標準transformer能夠將計算復雜度從 降到
. 但線性transformer 相對于標準transformer 往往存在著較明顯的指標gap。作者分析認為原因有兩點:
- unbounded gradients。無邊界梯度,會導致模型在訓練時不穩定,收斂不好;
- attention dilution。注意力稀釋,transformer在lower level時應該更關注局部特征,而higher level更關注全局特征,但線性transformer中的attention往往weight 更均勻化,不能聚焦在local區域上,因此稱為attention稀釋。
針對于上述兩點,作者提出了NormAttention和DiagAttention兩個模塊,形成NormFormer的結構。
1.The devil in linear attention
我們首先來看一下作者分析的線性transformer存在的兩點缺陷的結論是怎么來的。
1.1 Unbounded gradients
在標準的attention結構中
正是這里的 帶來的
的計算復雜度。而為了解決這個問題目前主要包含兩類: 基于pattern的方法和基于kernel的方法。
基于pattern的方式主要是通過一些先驗篩選key或query,降低計算復雜度;而基于kernel的方法則是本文提到的線性transformer,通過核函數去取代softmax,從而能夠通過矩陣乘法結合律降低計算復雜度。
那么來看一下計算attention時,vanilla和linear transformer的統一形式:
對于vanilla transformer而言, , 對于linear transformer可以表示為
. 于是可以比較一下兩者的梯度:
vanilla attention: , 這里推理的時候注意湊
這里推理的時候只有 時邊界值成立,所以最終
linear attention: 線性attention的關鍵在于, 因此
即,
.
因為 大小是不確定的,所以相當于linear attention的梯度是無邊界的。這就會導致收斂不穩定,收斂難度大等問題。
1.2 Attention dilution
注意力稀釋方面,作者直接評估了不同level上,每一個query在鄰域內的其他query上的attention的權重占比,這里需要注意的是,query之間是有序的,即對于NLP或者featmap而言,是有固定結構的,才可以這么評估。表示第i個query在其
個鄰域query上的attention之和,可以看下圖,a圖中transformer和linear transformer相比,顯然linear transformer的聚集度要小很多。這就是所謂的注意力稀釋。
2. architecture
針對于1中的兩個問題,有針對性的設計了兩個模塊。
2.1 NormAttention.
作者提出的解決方案
,
這里的XNorm 可以是Layernorm,也可以是 RMSNorm。注意這里的Q,和K是有激活函數的,公式沒寫,但圖中畫了。
文章證明這個做法梯度是有上界的。附錄的證明過程有點復雜。
2.2 DiagAttention
這個模塊其實就是一種基于pattern的attention,將query按距離劃分不重疊的window,每個window內進行 attention的計算。奇怪的是 這里的attention使用的都是vanilla attention。
下圖是文章方法TransNormer的結構:
3. 實驗
實驗都是在NLP上做的,不大了解,因此不做分析,這里只看下消融實驗的結論。
table8. 表明早期的stage應當更關注局部特征,而后期的stage則應該更關注全局信息。
table9. 早期適合使用blockattn,后期適合使用normattn
table10. FFN中作者對比了FFN和GLU的結果,發現GLU效果會更好一些。
table11.表明diagattn中的window的大小,這個其實有有點說不通,如果DiagAttn使用的linear attention, block size越大不是attention 稀釋的越嚴重嗎? 這個地方DiagAttn使用的應該都是vanilla attention,包括softmax attention和ReLA attention.
4. 結論
本文提出的norm attention其實在很多其他方法中都見過,而且所謂的diag attention使用的還是vanilla attention,并沒有把linear attention應用到diag block里,感覺不是很充實。值得學習的是本文中提出的梯度分析的方法。