RNN教程之-2 LSTM實戰

本文由清華大學碩士大神金天撰寫,歡迎大家轉載,不過請保留這段版權信息,對本文內容有疑問歡迎聯系作者微信:jintianiloveu探討,多謝合作~

UPDATE:
2019-02-20: 為大家推薦一個學習人工智能算法的好地方,奇異AI算法平臺,海量人工智能算法每周更新,你不需要學,哪怕跑一跑都能學個八九不離十,里面的算法都是原創的:http://strangeai.pro
2017-4-11: 這篇是之前寫的文章,關于時間序列的更新版本在這里, 稍后會開源所有代碼。

前言

說出來你們不敢相信,剛才碼了半天的字,一個側滑媽的全沒了,都怪這Mac的觸摸板太敏感沃日。好吧,不浪費時間了,前言一般都是廢話,這個教程要解決的是一個LSTM的實戰問題,很多人問我RNN是啥,有什么卵用,你可以看看我之前寫的博客可以入門,但是如果你想實際操作代碼,那么慢慢看這篇文章。本文章所有代碼和數據集在我的Github Repository下載。

問題

給你一個數據集,只有一列數據,這是一個關于時間序列的數據,從這個時間序列中預測未來一年某航空公司的客運流量。

首先我們數據預覽一下,用pandas讀取數據,這里我們只需要使用后一列真實數據,如果你下載了數據,數據大概長這樣:

      time       passengers
0    1949-01         112
1    1949-02         118
2    1949-03         132
3    1949-04         129
4    1949-05         121
5    1949-06         135
6    1949-07         148
7    1949-08         148
8    1949-09         136
9    1949-10         119
...    ...          ....

第一列是時間,第二列是客流量,為了看出這個我們要預測的客流量隨時間的變化趨勢,本大神教大家如何把趨勢圖畫出來,接下來就非常牛逼了。用下面的代碼來畫圖:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

df = pd.read_csv('international-airline-passengers.csv', sep=',')
df = df.set_index('time')
df['passengers'].plot()
plt.show()

這時候我們可以看到如下的趨勢圖:

figure_1.png

可以看出,我們的數據存在一定的周期性,這個周期性并不是一個重復出現某個值,而是趨勢的增長過程有一定的規律性,這個我們人肉眼就能看得出來,但是實際上計算機要識別這種規律就有一定的難度了,這時候就需要使用我們的LSTM大法。
好的,數據已經預覽完了,接下來我們得思考一下怎么預測,怎么把數據處理為LSTM網絡需要的格式。

LSTM數據預處理

這個過程非常重要,這也是很多水平不高的博客或者文章中沒有具體闡述而導致普通讀者不知道毛意思的過程,其實我可以這樣簡單的敘述,LSTM你不要以為各種時間序列搞的暈頭轉向,其實本質它還是神經網絡,與普通的神經網絡沒有任何區別。我們接下來就用幾行小代碼把數據處理為我們需要的類似于神經網絡輸入的二維數據。
首先我們確確實實需要的只是一列數據:

df = pd.read_csv(file_name, sep=',', usecols=[1])
data_all = np.array(df).astype(float)
print(data_all)

輸出是:

[[ 112.]
 [ 118.]
 [ 132.]
 [ 129.]
 [ 121.]
 [ 135.]
 [ 148.]
 [ 148.]
 [ 136.]
 [ 119.]
 [ 104.]
 [ 118.]
 [ 115.]
....
]

非常好,現在我們已經把我們需要的數據摳出來了,繼續上面處理:

data = []
for i in range(len(data_all) - sequence_length - 1):
    data.append(data_all[i: i + sequence_length + 1])
reshaped_data = np.array(data).astype('float64')
print(reshaped_data)

這時候你會發現好像結果看不懂,不知道是什么數據,如果你data_all處理時加ravel()(用來把數據最里面的中括號去掉),即:

df = pd.read_csv(file_name, sep=',', usecols=[1])
data_all = np.array(df).ravel().astype(float)
print(data_all)

那么數據輸出一目了然:

[[ 112.  118.  132. ...,  136.  119.  104.]
 [ 118.  132.  129. ...,  119.  104.  118.]
 [ 132.  129.  121. ...,  104.  118.  115.]
 ..., 
 [ 362.  405.  417. ...,  622.  606.  508.]
 [ 405.  417.  391. ...,  606.  508.  461.]
 [ 417.  391.  419. ...,  508.  461.  390.]]

是的,沒有錯!一列數據經過我們這樣不處理就可以作為LSTM網絡的輸入數據了,而且和神經網絡沒有什么兩樣!!牛逼吧?牛逼快去哥的Github Repo給個star,喊你們寢室的菜市場的大爺大媽都來贊!越多越好,快,哥的大牛之路就靠你們了!
然而這還是只是開始。。接下來要做的就是把數據切分為訓練集和測試集:

split = 0.8
np.random.shuffle(reshaped_data)
x = reshaped_data[:, :-1]
y = reshaped_data[:, -1]
split_boundary = int(reshaped_data.shape[0] * split)
train_x = x[: split_boundary]
test_x = x[split_boundary:]

train_y = y[: split_boundary]
test_y = y[split_boundary:]

這些步驟相信聰明的你一點看得懂,我就不多廢話了,我要說明的幾點是,你運行時直接運行Github上的腳本代碼,如果報錯請私信我微信jintianiloveu,我在代碼中把過程包裝成了函數所以文章中的代碼可能不太一樣。在實際代碼中數據是需要歸一化的,這個你應該知道,如何歸一化代碼中也有。

搭建LSTM模型

好,接下來是最牛逼的部分,也是本文章的核心內容(但實際內容并不多),數據有了,我們就得研究研究LSTM這個東東,不管理論上吹得多么牛逼,我只看它能不能解決問題,不管黑貓白貓,能抓到老鼠的就是好貓,像我們這樣不搞偽學術注重經濟效益的商人來說,這點尤為重要。搭建LSTM模型,我比較推薦使用keras,快速簡單高效,分分鐘,但是犧牲的是靈活性,不過話又說回來,真正的靈活性也是可以發揮的,只是要修改底層的東西那就有點麻煩了,我們反正是用它來解決問題的,更基礎的部分我們就不研究了,以后有時間再慢慢深入。
在keras 的官方文檔中,說了LSTM是整個Recurrent層實現的一個具體類,它需要的輸入數據維度是:

形如(samples,timesteps,input_dim)的3D張量

發現沒有,我們上面處理完了數據的格式就是(samples,timesteps)這個time_step是我們采用的時間窗口,把一個時間序列當成一條長鏈,我們固定一個一定長度的窗口對這個長鏈進行采用,最終就得到上面的那個二維數據,那么我們缺少的是input_dim這個維度,實際上這個input_dim就是我們的那一列數據的數據,我們現在處理的是一列也有可能是很多列,一系列與時間有關的數據需要我們去預測,或者文本處理中會遇到。我們先不管那么多,先把數據處理為LSTM需要的格式:

train_x = np.reshape(train_x, (train_x.shape[0], train_x.shape[1], 1))
test_x = np.reshape(test_x, (test_x.shape[0], test_x.shape[1], 1))

好的,這時候數據就是我們需要的啦。接下來搭建模型:

# input_dim是輸入的train_x的最后一個維度,train_x的維度為(n_samples, time_steps, input_dim)
model = Sequential()
model.add(LSTM(input_dim=1, output_dim=50, return_sequences=True))
model.add(LSTM(100, return_sequences=False))
model.add(Dense(output_dim=1))
model.add(Activation('linear'))
model.compile(loss='mse', optimizer='rmsprop')

看到沒,這個LSTM非常簡單!!甚至跟輸入的數據格式沒有任何關系,只要輸入數據的維度是1,就不需要修改模型的任何參數就可以把數據輸入進去進行訓練!
我們這里使用了兩個LSTM進行疊加,第二個LSTM第一個參數指的是輸入的維度,這和第一個LSTM的輸出維度并不一樣,這也是LSTM比較“隨意”的地方。最后一層采用了線性層。

結果

預測的結果如下圖所示:


result.png

這個結果還是非常牛逼啊,要知道我們的數據是打亂過得噢,也就是說泛化能力非常不錯,厲害了word LSTM!
筒子們,本系列教程到此結束,歡迎再次登錄老司機的飛船。。。。如果有不懂的私信我,想引起我的注意快去Github上給我star!!!

系列文章結尾安利:Python深度學習基地群 216912253一個即談理想又談技術的技術人聚集地,歡迎加入。

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 229,836評論 6 540
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 99,275評論 3 428
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 177,904評論 0 383
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,633評論 1 317
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 72,368評論 6 410
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 55,736評論 1 328
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,740評論 3 446
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 42,919評論 0 289
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 49,481評論 1 335
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 41,235評論 3 358
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 43,427評論 1 374
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,968評論 5 363
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,656評論 3 348
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 35,055評論 0 28
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 36,348評論 1 294
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 52,160評論 3 398
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 48,380評論 2 379

推薦閱讀更多精彩內容