Sequence to Sequence學習最早由Bengio在2014年的論文中提出。
這篇文章主要是提供了一種嶄新的RNN Encoder-Decoder算法,并且將其應用于機器翻譯中。
這種算法也是現在谷歌已經應用于線上機器翻譯的算法,翻譯質量基本達到、甚至超越人類水平。
所謂的RNN Encoder-Decoder結構,簡單的來說就是算法包含兩部分,一個負責對輸入的信息進行Encoding,將輸入轉換為向量形式。
然后由Decoder對這個向量進行解碼,還原為輸出序列。
而RNN Encoder-Decoder結構就是編碼器與解碼器都是使用RNN算法,一般為LSTM。
LSTM的優勢在于處理序列,它可以將上文包含的信息保存在隱藏狀態中,這樣就提高了算法對于上下文的理解能力。
Encoder與Decoder各自可以算是單獨的模型,一般是一層或多層的LSTM。
LSTM
LSTM是Long-short Term Memory的縮寫,是RNN算法中的一種。
它很好的抑制了原始RNN算法中的梯度消失彌散(Vanishing Gradient)問題。
一個LSTM神經元(Cell)可以接收兩個信息,其中一個是序列的某一位輸入,另一個是上一輪的隱藏狀態。
而一個LSTM神經元也會產生兩個信息,一個是當前輪的輸出,另一個是當前輪的隱藏狀態。
假設我們輸入序列長度為2
,輸出序列長度也為2
,流程如下:
圖中畫了兩個LSTM神經元,不過實際上只有一個,只是它要處理不同時序(t)的信息。
從序列,到序列
以機器翻譯為例,假設我們要將How are you
翻譯為你好嗎
,模型要做的事情如下圖:
上圖中,LSTM Encoder是一個LSTM神經元,Decoder是另一個,Encoder自身運行了3
次,Decoder運行了4
次。
可以看出,Encoder的輸出會被拋棄,我們只需要保留隱藏狀態(即圖中EN狀態)作為下一次ENCODER的狀態輸入。
Encoder的最后一輪輸出狀態會與Decoder的輸入組合在一起,共同作為Decoder的輸入。
而Decoder的輸出會被保留,當做下一次的的輸入。注意,這是在說預測時時的情況,一般在訓練時一般會用真正正確的輸出序列內容,而預測時會用上一輪Decoder的輸出。
給Decoder的第一個輸入是<S>
,這是我們指定的一個特殊字符,它用來告訴Decoder,你該開始輸出信息了。
而最末尾的<E>
也是我們指定的特殊字符,它告訴我們,句子已經要結束了,不用再運行了。
偽數學
從更高層的角度來看算法,整個模型也無非是一種從輸入到輸出的函數映射。
我們已知的輸入數據是How are you
,我們希望的輸出是你好啊
,
模型學習了下面這些函數映射,組成了一個單射函數:
{ How, are, you, < S > } ---> {你}
{ How, are, you, < S >, 你 } ---> {好}
{ How, are, you, < S >, 你, 好 } ---> {嗎}
{ How, are, you, < S >, 你, 好, 嗎 } ---> {< E >}
為什么這么麻煩?
我們說,本質上RNN Encoder Decoder模型也是一種函數映射的學習,
那么我們能不能用其他模型學習這樣的映射關系?
理論上是可以的,但是實際上傳統機器學習模型很難學習這樣多的映射信息,算法所需要的VC維度太高,
而且很難如RNN模型一樣,很好的保留序列的上下文信息(例如語序),使得模型的訓練非常困難。
應用
Sequence to Sequence模型已經被谷歌成功應用于機器翻譯上。
而理論上任意的序列到序列的有監督問題都可以用這種模型。
- 古詩生成,輸入上一句,輸出下一句
- 對聯生成,輸入上聯,輸出下聯
- 有標注的分詞訓練,輸入一句話,輸出分詞序列
- 有標注的命名實體識別訓練
- 輸入前10天的股價,輸出后10天的股價
- 對話機器人,輸入用戶對話,輸出機器人的回答
當然對于這些問題,實踐中能否有效,模型的具體結構與參數,都是有待研究的。
Trick
雖然LSTM能避免梯度彌散問題,但是不能對抗梯度爆炸問題(Exploding Gradient)。
為了對抗梯度爆炸,一般會對梯度進行裁剪。
梯度剪裁的方法一般有兩種,一種是當梯度的某個維度絕對值大于某個上限的時候,就剪裁為上限。
另一種是梯度的L2范數大于上限后,讓梯度除以范數,避免過大。
Bengio的原文中用的另一個trick是他們的輸入序列是反向輸入的,也就是說實際輸入模型的順序并不是
How are you
而是you are How
。至于為什么這樣效果更好,還是一個迷。
參考
Bengio 2014 https://arxiv.org/pdf/1406.1078.pdf
Tensorflow seq2seq tutorial https://www.tensorflow.org/tutorials/seq2seq/