TensorFlow HOWTO 4.2 Multilayer Perceptron Regression (Time Series)

In this tutorial, we use a multilayer perceptron (MLP) to forecast a time series. This is a regression problem.

Steps

Import the required packages.

import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

Load and preprocess the data. We use the international airline passengers dataset; since it is not bundled with any existing library, we need to download it first.
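If you do not have the file handy, one way to fetch it is sketched below. The mirror URL is my assumption, not part of the original tutorial; any CSV with the monthly passenger counts in its second column works with the read_csv call that follows.

import urllib.request

# NOTE: this mirror URL is an assumption; substitute your own copy if it moves
url = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/airline-passengers.csv'
urllib.request.urlretrieve(url, 'international-airline-passengers.csv')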

# keep only the passenger-count column, drop missing rows, flatten to 1D
ts = pd.read_csv('international-airline-passengers.csv', usecols=[1], header=0).dropna().values.ravel()

之后,我們需要將其轉換為結構化數據集。我知道時間序列有很多實用的特征,但是這篇教程中,為了展示 MLP 的強大,我僅僅使用最簡單的特征,也就是乘客數的歷史值,并根據歷史值來預測當前值。為此,我們需要一個窗口大小,也就是幾個歷史值與當前值有關。

wnd_sz = 5
ds = []
# slide a length-wnd_sz window over the series; each window is one sample
for i in range(0, len(ts) - wnd_sz + 1):
    ds.append(ts[i:i + wnd_sz])
ds = np.asarray(ds)

x_ = ds[:, 0:wnd_sz - 1]  # first wnd_sz - 1 columns: historical values (features)
y_ = ds[:, [wnd_sz - 1]]  # last column: current value (label)
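As a quick sanity check (a minimal sketch, not part of the original tutorial), the same windowing applied to a toy series makes the layout obvious: each row holds wnd_sz consecutive values, the first wnd_sz - 1 columns become the features, and the last column becomes the label.

# toy series 0..9 instead of the passenger counts
demo = np.arange(10)
demo_ds = np.asarray([demo[i:i + wnd_sz] for i in range(len(demo) - wnd_sz + 1)])
print(demo_ds[:2])
# [[0 1 2 3 4]
#  [1 2 3 4 5]]
print(demo_ds[:2, 0:wnd_sz - 1], demo_ds[:2, [wnd_sz - 1]])
# features [[0 1 2 3], [1 2 3 4]], labels [[4], [5]]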

之后是訓練集和測試集的劃分。為時間序列劃分訓練集和測試集的時候,絕對不能打亂,而是應該把前一部分當做訓練集,后一部分當做測試集。因為在時間序列中,未來值依賴歷史值,而歷史值不依賴未來值,這樣可以盡可能避免在訓練中使用測試集的信息。

train_size = int(len(x_) * 0.7)
x_train = x_[:train_size]
y_train = y_[:train_size]
x_test = x_[train_size:]
y_test = y_[train_size:]
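A quick check (my addition, not in the original) that the split preserved temporal order, i.e., stacking the two parts back together reproduces the full feature and label matrices:

assert np.array_equal(np.vstack([x_train, x_test]), x_)
assert np.array_equal(np.vstack([y_train, y_test]), y_)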

Define the hyperparameters. Time series models overfit easily; to guard against this, it is best not to set the number of epochs too high.

| Variable | Meaning |
| --- | --- |
| n_input | number of features per sample |
| n_epoch | number of epochs |
| n_hidden1 | number of units in hidden layer 1 |
| n_hidden2 | number of units in hidden layer 2 |
| lr | learning rate |

n_input = wnd_sz - 1
n_hidden1 = 8
n_hidden2 = 8
n_epoch = 10000
lr = 0.05

Build the model. Note that the hidden layers use ELU, currently among the best-performing activations. Since this is a regression problem and the labels take positive values, ReLU would be the best choice for the output layer, but here I use the identity f(x) = x.
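For reference, ELU is defined as follows (tf.nn.elu uses $\alpha = 1$):

$$
\mathrm{ELU}(x) = \begin{cases} x, & x > 0 \\ \alpha \, (e^x - 1), & x \le 0 \end{cases}
$$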

| Variable | Meaning |
| --- | --- |
| x | input |
| y | true labels |
| w_l{1,2,3} | weights of layer {1,2,3} |
| b_l{1,2,3} | biases of layer {1,2,3} |
| z_l{1,2,3} | intermediate value of layer {1,2,3}: a linear transform of the previous layer's output |
| a_l{1,2,3} | output of layer {1,2,3}; a_l3 is the model output |

x = tf.placeholder(tf.float64, [None, n_input])
y = tf.placeholder(tf.float64, [None, 1])
w_l1 = tf.Variable(np.random.rand(n_input, n_hidden1))
b_l1 = tf.Variable(np.random.rand(1, n_hidden1))
w_l2 = tf.Variable(np.random.rand(n_hidden1, n_hidden2))
b_l2 = tf.Variable(np.random.rand(1, n_hidden2))
w_l3 = tf.Variable(np.random.rand(n_hidden2, 1))
b_l3 = tf.Variable(np.random.rand(1, 1))
z_l1 = x @ w_l1 + b_l1
a_l1 = tf.nn.elu(z_l1)
z_l2 = a_l1 @ w_l2 + b_l2
a_l2 = tf.nn.elu(z_l2)
z_l3 = a_l2 @ w_l3 + b_l3
a_l3 = z_l3  # identity activation f(x) = x on the output layer
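As an aside, for readers on TensorFlow 2.x, a roughly equivalent model in tf.keras might look like the sketch below. This is my own translation of the architecture, not part of the original tutorial.

# TF 2.x sketch of the same architecture (assumes tf.keras is available)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(n_hidden1, activation='elu', input_shape=(n_input,)),
    tf.keras.layers.Dense(n_hidden2, activation='elu'),
    tf.keras.layers.Dense(1),  # linear output, as with a_l3 = z_l3 above
])
model.compile(optimizer=tf.keras.optimizers.Adam(lr), loss='mse')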

Define the MSE loss, the optimization op, and the R² metric.

| Variable | Meaning |
| --- | --- |
| loss | loss |
| op | optimization op |
| r_sqr | R² |

loss = tf.reduce_mean((a_l3 - y) ** 2)
op = tf.train.AdamOptimizer(lr).minimize(loss)

y_mean = tf.reduce_mean(y)
r_sqr = 1 - tf.reduce_sum((y - a_l3) ** 2) / tf.reduce_sum((y - y_mean) ** 2)
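In symbols, r_sqr is the coefficient of determination

$$
R^2 = 1 - \frac{\sum_i (y_i - \hat{y}_i)^2}{\sum_i (y_i - \bar{y})^2},
$$

where $\hat{y}_i$ is the model output a_l3 and $\bar{y}$ is the mean of the true labels. R² can be negative when the model fits worse than always predicting the mean, which is exactly what the early epochs in the output below show.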

Train the model on the training set.

losses = []
r_sqrs = []

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    for e in range(n_epoch):
        _, loss_ = sess.run([op, loss], feed_dict={x: x_train, y: y_train})
        losses.append(loss_)

Compute R² on the test set.

        r_sqr_ = sess.run(r_sqr, feed_dict={x: x_test, y: y_test})
        r_sqrs.append(r_sqr_)

Every hundred epochs, print the loss and the metric.

        if e % 100 == 0:
            print(f'epoch: {e}, loss: {loss_}, r_sqr: {r_sqr_}')

得到模型對訓練特征和測試特征的預測值。

    y_train_pred = sess.run(a_l3, feed_dict={x: x_train})
    y_test_pred = sess.run(a_l3, feed_dict={x: x_test})

Output:

epoch: 0, loss: 59209399.053257026, r_sqr: -17520.903006130215
epoch: 100, loss: 54125.98862726741, r_sqr: -28.30371839204463
epoch: 200, loss: 48165.48221823986, r_sqr: -25.13646606476775
epoch: 300, loss: 25826.1223418781, r_sqr: -12.89535810028511
epoch: 400, loss: 1596.701326728818, r_sqr: -0.2350739792412242
epoch: 500, loss: 1396.8836047513207, r_sqr: -0.19979831972491247
epoch: 600, loss: 1386.2307618333675, r_sqr: -0.18952804825771152
epoch: 700, loss: 1374.6194509485028, r_sqr: -0.17864308160044695
epoch: 800, loss: 1362.1306530753875, r_sqr: -0.1669310907644168
epoch: 900, loss: 1348.837516403113, r_sqr: -0.15445861855695942
epoch: 1000, loss: 1334.8048545363076, r_sqr: -0.14128485137041857
epoch: 1100, loss: 1320.0909505177317, r_sqr: -0.1274628487494167
epoch: 1200, loss: 1304.7487050247062, r_sqr: -0.11304061847816937
epoch: 1300, loss: 1288.8264056578596, r_sqr: -0.09806183013209635
epoch: 1400, loss: 1272.368278888685, r_sqr: -0.08256632457099822
epoch: 1500, loss: 1255.4149176209135, r_sqr: -0.06659050598517702
epoch: 1600, loss: 1238.0036386736738, r_sqr: -0.05016766677562812
epoch: 1700, loss: 1220.1688168734415, r_sqr: -0.033328289031115066
epoch: 1800, loss: 1201.942219268364, r_sqr: -0.016100344309701198
epoch: 1900, loss: 1183.3533635535823, r_sqr: 0.001490385551745188
epoch: 2000, loss: 1164.4299150007857, r_sqr: 0.019419953232036713
epoch: 2100, loss: 1145.1981380968932, r_sqr: 0.037665840800028216
epoch: 2200, loss: 1125.683416643312, r_sqr: 0.056206491416253
epoch: 2300, loss: 1105.9108537794123, r_sqr: 0.07502076650340628
epoch: 2400, loss: 1085.9059630721106, r_sqr: 0.0940873169623373
epoch: 2500, loss: 1065.6954608224203, r_sqr: 0.11338385761808756
epoch: 2600, loss: 1045.3081603639205, r_sqr: 0.13288634237228225
epoch: 2700, loss: 1024.7759657787278, r_sqr: 0.15256803995690105
epoch: 2800, loss: 1004.1349451931811, r_sqr: 0.172398525823234
epoch: 2900, loss: 983.4264457196117, r_sqr: 0.1923426217757903
epoch: 3000, loss: 962.6981814611527, r_sqr: 0.2123593431286731
epoch: 3100, loss: 942.0051856419402, r_sqr: 0.23240095064001043
epoch: 3200, loss: 921.4104707924716, r_sqr: 0.25241224904623527
epoch: 3300, loss: 900.9855546712687, r_sqr: 0.2723303372744077
epoch: 3400, loss: 880.8157626399897, r_sqr: 0.2920819860142483
epoch: 3500, loss: 860.9797390814788, r_sqr: 0.3115862647284954
epoch: 3600, loss: 841.5508166258288, r_sqr: 0.3307714419429332
epoch: 3700, loss: 822.65464452708, r_sqr: 0.34954174287751905
epoch: 3800, loss: 804.3510636509582, r_sqr: 0.36784090179299245
epoch: 3900, loss: 786.698489755282, r_sqr: 0.385608999078319
epoch: 4000, loss: 769.742814493071, r_sqr: 0.4028012203684326
epoch: 4100, loss: 753.5066274222577, r_sqr: 0.41939554512386545
epoch: 4200, loss: 737.987317032155, r_sqr: 0.435393667355777
epoch: 4300, loss: 723.1589061501688, r_sqr: 0.450820383097843
epoch: 4400, loss: 708.9775872199175, r_sqr: 0.4657183502307326
epoch: 4500, loss: 695.3902622132391, r_sqr: 0.48014080374835966
epoch: 4600, loss: 682.3445164530003, r_sqr: 0.49414259666687754
epoch: 4700, loss: 669.7979767738486, r_sqr: 0.5077712354635172
epoch: 4800, loss: 657.7254031658086, r_sqr: 0.5210596011864659
epoch: 4900, loss: 646.1225385082785, r_sqr: 0.5340213253103482
epoch: 5000, loss: 635.0063312881094, r_sqr: 0.5466495174400207
epoch: 5100, loss: 624.413372450103, r_sqr: 0.5589148722286865
epoch: 5200, loss: 614.3949181844811, r_sqr: 0.5707707632698614
epoch: 5300, loss: 605.0111648523978, r_sqr: 0.5821553829261457
epoch: 5400, loss: 596.3251407668536, r_sqr: 0.5929950345771348
epoch: 5500, loss: 588.394934666788, r_sqr: 0.6032110372771016
epoch: 5600, loss: 581.2665165995329, r_sqr: 0.6127246745768582
epoch: 5700, loss: 574.966974465473, r_sqr: 0.6214638760543536
epoch: 5800, loss: 569.4991301790525, r_sqr: 0.6293697644027529
epoch: 5900, loss: 564.8386101443393, r_sqr: 0.6364025085215417
epoch: 6000, loss: 560.9342529396988, r_sqr: 0.6425455981413335
epoch: 6100, loss: 557.7121333454768, r_sqr: 0.6478080528894717
epoch: 6200, loss: 555.082871816442, r_sqr: 0.6522228049088226
epoch: 6300, loss: 552.9504878905169, r_sqr: 0.655844018250897
epoch: 6400, loss: 551.2218559344465, r_sqr: 0.6587415604098865
epoch: 6500, loss: 549.8130443197065, r_sqr: 0.6609962191562182
epoch: 6600, loss: 548.6541928338876, r_sqr: 0.6626926870925417
epoch: 6700, loss: 547.6907665851817, r_sqr: 0.6639154520465458
epoch: 6800, loss: 546.8824436922644, r_sqr: 0.6647451519232287
epoch: 6900, loss: 546.2005216521292, r_sqr: 0.6652561458635722
epoch: 7000, loss: 545.624816174369, r_sqr: 0.6655151008532998
epoch: 7100, loss: 545.1407754175443, r_sqr: 0.6655803536739032
epoch: 7200, loss: 544.737161207175, r_sqr: 0.6655015768007708
epoch: 7300, loss: 544.4043734455902, r_sqr: 0.6653207475362519
epoch: 7400, loss: 544.1341742901118, r_sqr: 0.6650726389557293
epoch: 7500, loss: 543.9183617756419, r_sqr: 0.6647841774748728
epoch: 7600, loss: 543.7491045438804, r_sqr: 0.6644772215677328
epoch: 7700, loss: 543.6189281327097, r_sqr: 0.664168178900878
epoch: 7800, loss: 543.5208538865398, r_sqr: 0.6638692273468367
epoch: 7900, loss: 543.448546441615, r_sqr: 0.6635888904468247
epoch: 8000, loss: 543.3964235871173, r_sqr: 0.6633326568412115
epoch: 8100, loss: 543.3597144375057, r_sqr: 0.6631036869053042
epoch: 8200, loss: 543.334470694699, r_sqr: 0.6629032437625964
epoch: 8300, loss: 543.3175272915516, r_sqr: 0.6627311295544881
epoch: 8400, loss: 543.3064268470205, r_sqr: 0.6625860646477272
epoch: 8500, loss: 543.2993218444616, r_sqr: 0.6624660124009056
epoch: 8600, loss: 543.2948676666707, r_sqr: 0.6623684566822827
epoch: 8700, loss: 543.2921173712779, r_sqr: 0.66229063392118
epoch: 8800, loss: 543.2904259948474, r_sqr: 0.6622297208034926
epoch: 8900, loss: 543.2893689519885, r_sqr: 0.6621829791500562
epoch: 9000, loss: 543.2886762658412, r_sqr: 0.6621478605931548
epoch: 9100, loss: 543.2881822202421, r_sqr: 0.6621220749401975
epoch: 9200, loss: 543.2877886467712, r_sqr: 0.6621036272136598
epoch: 9300, loss: 543.2874393833906, r_sqr: 0.6620908290593257
epoch: 9400, loss: 543.2871033165793, r_sqr: 0.662082291939509
epoch: 9500, loss: 543.2867636136157, r_sqr: 0.6620769002407012
epoch: 9600, loss: 543.2864111990842, r_sqr: 0.6620737858271186
epoch: 9700, loss: 543.2860410035256, r_sqr: 0.6620722889010859
epoch: 9800, loss: 543.285649892272, r_sqr: 0.6620719225552976
epoch: 9900, loss: 543.2852355789513, r_sqr: 0.662072338159642

Plot the time series and its predicted values.

plt.figure()
plt.plot(ts, label='Original')
# the first n_input points of ts get no prediction, since each
# prediction needs n_input historical values before it
y_train_pred = np.concatenate([
    [np.nan] * n_input,
    y_train_pred.ravel()
])
# test predictions additionally start after the train_size training rows
y_test_pred = np.concatenate([
    [np.nan] * (n_input + train_size),
    y_test_pred.ravel()
])
plt.plot(y_train_pred, label='y_train_pred')
plt.plot(y_test_pred, label='y_test_pred')
plt.legend()
plt.show()
[figure: original series with y_train_pred and y_test_pred overlaid]

Plot the loss on the training set.

plt.figure()
plt.plot(losses)
plt.title('Loss on Training Set')
plt.xlabel('#epoch')
plt.ylabel('MSE')
plt.show()
[figure: Loss on Training Set, MSE vs. epoch]

Plot R² on the test set.

plt.figure()
plt.plot(r_sqrs)
plt.title('$R^2$ on Testing Set')
plt.xlabel('#epoch')
plt.ylabel('$R^2$')
plt.show()
[figure: $R^2$ on Testing Set vs. epoch]

Further Reading

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 230,825評論 6 546
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 99,814評論 3 429
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 178,980評論 0 384
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 64,064評論 1 319
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 72,779評論 6 414
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 56,109評論 1 330
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 44,099評論 3 450
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 43,287評論 0 291
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 49,799評論 1 338
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 41,515評論 3 361
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 43,750評論 1 375
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 39,221評論 5 365
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,933評論 3 351
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 35,327評論 0 28
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 36,667評論 1 296
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 52,492評論 3 400
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 48,703評論 2 380

推薦閱讀更多精彩內容