閱讀筆記 - The Devil in Linear Transformer

來源:https://www.researchgate.net/publication/364419868_The_Devil_in_Linear_Transformer
代碼:https://github.com/OpenNLPLab/Transnormer


這篇文章的目的是優化線性transformer,線性transformer相對于標準transformer能夠將計算復雜度從 O(N^2C) 降到O(NC^2). 但線性transformer 相對于標準transformer 往往存在著較明顯的指標gap。作者分析認為原因有兩點:

  • unbounded gradients。無邊界梯度,會導致模型在訓練時不穩定,收斂不好;
  • attention dilution。注意力稀釋,transformer在lower level時應該更關注局部特征,而higher level更關注全局特征,但線性transformer中的attention往往weight 更均勻化,不能聚焦在local區域上,因此稱為attention稀釋。
    針對于上述兩點,作者提出了NormAttention和DiagAttention兩個模塊,形成NormFormer的結構。

1.The devil in linear attention

我們首先來看一下作者分析的線性transformer存在的兩點缺陷的結論是怎么來的。

1.1 Unbounded gradients

在標準的attention結構中
O = \text{softmax}(QK^T/\sqrt{D})V, ~~ Q=XW_Q, K=XW_K, V=XW_V
正是這里的QK^T 帶來的O(N^C)的計算復雜度。而為了解決這個問題目前主要包含兩類: 基于pattern的方法和基于kernel的方法。
基于pattern的方式主要是通過一些先驗篩選key或query,降低計算復雜度;而基于kernel的方法則是本文提到的線性transformer,通過核函數去取代softmax,從而能夠通過矩陣乘法結合律降低計算復雜度。
那么來看一下計算attention時,vanilla和linear transformer的統一形式:
p_{ij} = \frac{f(s_{ij})}{\sum_{k-1}^n f(s_{ik})}
對于vanilla transformer而言, s_{ij} = q_i^Tk_j/\sqrtcn69k16, ~~ f(x) = \text{exp}(x), 對于linear transformer可以表示為 s_{ij} = \phi(q_i)\phi(k_j)^T,~~f(x)=x. 于是可以比較一下兩者的梯度:
vanilla attention: \frac{\partial p_{ij}}{\partial s_{ik}} = \frac{f'(s_{ik})}{f(s_{ik})}\big(1_{j=k}p_{ij} - p_{ij}p_{ik}\big), 這里推理的時候注意湊p_{ij}, p_{ik}
f'(x) = \text{exp}(x) = f(x) \\ \frac{\partial p_{ij}}{\partial s_{ik}} = 1_{j=k}p_{ij} - p_{ij}p_{ik} \\ = \begin{cases} p_{ik} - p_{ij}p_{ik}\in [0, 1/4], &j=k \\ - p_{ij}p_{ik}\in [-1/4, 0],& j\neq k\end{cases}
這里推理的時候只有p_{ik} = p_{ij} 時邊界值成立,所以最終
\Big \vert \frac{\partial p_{ij}}{\partial s_{ik}}\Big\vert \le \frac{1}{4}

linear attention: 線性attention的關鍵在于f'(x) = 1, 因此
f'(x) =1 \\ \frac{\partial p_{ij}}{\partial s_{ik}} = \frac{1}{s_{ik}} \big(1_{j=k}p_{ij} - p_{ij}p_{ik}\big) \\ = \frac{1}{s_{ik}}\begin{cases} p_{ik} - p_{ij}p_{ik}, &j=k \\ - p_{ij}p_{ik},& j\neq k\end{cases} 即,\Big \vert \frac{\partial p_{ij}}{\partial s_{ik}}\Big\vert \le \frac{1}{4|s_{ik}|}.
因為s_{ik} = \phi(q_i)\phi(q_k)^T 大小是不確定的,所以相當于linear attention的梯度是無邊界的。這就會導致收斂不穩定,收斂難度大等問題。

1.2 Attention dilution

注意力稀釋方面,作者直接評估了不同level上,每一個query在鄰域內的其他query上的attention的權重占比,這里需要注意的是,query之間是有序的,即對于NLP或者featmap而言,是有固定結構的,才可以這么評估。l(i, r, N)表示第i個query在其rN個鄰域query上的attention之和,可以看下圖,a圖中transformer和linear transformer相比,顯然linear transformer的聚集度要小很多。這就是所謂的注意力稀釋。

image.png

2. architecture

針對于1中的兩個問題,有針對性的設計了兩個模塊。

2.1 NormAttention.

作者提出的解決方案
O = Q(K^TV) \\ O_{norm} = \text{XNorm}(Q(K^TV)),
這里的XNorm 可以是Layernorm,也可以是 RMSNorm。注意這里的Q,和K是有激活函數的,公式沒寫,但圖中畫了。
\text{RMSNorm}(x) = \frac{x}{\sqrt{\sigma^2 + \epsilon}} \\ \sigma^2 = \sum_{i=1}^d x_i^2 /d , \epsilon > 0,
文章證明這個做法梯度是有上界的。附錄的證明過程有點復雜。

2.2 DiagAttention

這個模塊其實就是一種基于pattern的attention,將query按距離劃分不重疊的window,每個window內進行 attention的計算。奇怪的是 這里的attention使用的都是vanilla attention。

下圖是文章方法TransNormer的結構:


image.png

3. 實驗

實驗都是在NLP上做的,不大了解,因此不做分析,這里只看下消融實驗的結論。

image.png

table8. 表明早期的stage應當更關注局部特征,而后期的stage則應該更關注全局信息。
table9. 早期適合使用blockattn,后期適合使用normattn
table10. FFN中作者對比了FFN和GLU的結果,發現GLU效果會更好一些。
image.png

table11.表明diagattn中的window的大小,這個其實有有點說不通,如果DiagAttn使用的linear attention, block size越大不是attention 稀釋的越嚴重嗎? 這個地方DiagAttn使用的應該都是vanilla attention,包括softmax attention和ReLA attention.

4. 結論

本文提出的norm attention其實在很多其他方法中都見過,而且所謂的diag attention使用的還是vanilla attention,并沒有把linear attention應用到diag block里,感覺不是很充實。值得學習的是本文中提出的梯度分析的方法。

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容