keras_regression

問題:神經網絡可以用來模擬回歸問題 (regression),例如給下面一組數據,用一條線來對數據進行擬合,并可以預測新輸入 x 的輸出值。

Paste_Image.png

用 Keras 構建回歸神經網絡的步驟:
1.導入模塊并創建數據
2.建立模型
3.激活模型
4.訓練模型
5.驗證模型
6.可視化結果


Demo.py

#導入模塊
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential  #models.Sequential,用來一層一層一層的去建立神經層;
from keras.layers import Dense   #layers.Dense 意思是這個神經層是全連接層

# 創建數據
X = np.linspace(-1, 1, 200)
np.random.shuffle(X)# 數據隨機化
Y = 0.5 * X + 2 + np.random.normal(0, 0.05, (200,))# 創建數據及參數, 并加入噪聲
# 繪制數據
plt.scatter(X, Y)
plt.show()
# 分為訓練數據和測試數據
X_train, Y_train = X[:160], Y[:160]   # train 前 160 data points
X_test, Y_test = X[160:], Y[160:]   # test 后 40 data points

# 使用keras創建神經網絡
# Sequential是指一層層堆疊的神經網絡
# Dense是指全連接層

#建立模型
model = Sequential()# 用 Sequential 建立 model
model.add(Dense(units = 1, input_dim = 1))#model.add 添加神經層,添加的是 Dense 全連接神經層。
#Dense參數有兩個,一個是輸入數據和輸出數據的維度,本代碼的例子中 x 和 y 是一維的。
#如果需要添加下一個神經層的時候,不用再定義輸入的緯度,因為它默認就把前一層的輸出作為當前層的輸入。
#在這個例子里,只需要一層就夠了。

# 激活模型 
#選擇損失函數和優化方法
model.compile(loss = 'mse', optimizer = 'sgd')#誤差函數用的是 mse 均方誤差;優化器用的是 sgd 隨機梯度下降法

print '----Training----'
# 訓練過程
for step in range(501):
    # 進行訓練, 返回損失(代價)函數
    cost = model.train_on_batch(X_train, Y_train)#訓練的時候用 model.train_on_batch 一批一批的訓練 X_train, Y_train。默認的返回值是 cost
    if step % 100 == 0:#每100步輸出一下結果
        print 'loss: ', cost
        
#檢驗模型
print '----Testing----'    
# 訓練結束進行測試
cost = model.evaluate(X_test, Y_test, batch_size = 40)
print 'test loss: ', cost

# 獲取參數
W, b = model.layers[0].get_weights()#weights 和 biases 是取在模型的第一層 model.layers[0] 學習到的參數
print 'Weights: ',W
print 'Biases: ', b

#可視化結果
# plotting the prediction
Y_pred = model.predict(X_test)
plt.scatter(X_test, Y_test)
plt.plot(X_test, Y_pred)
plt.show()

結果:

Paste_Image.png
----Training----
----Training----
loss:  4.05043315887
loss:  0.0760689899325
loss:  0.00436494173482
loss:  0.00265302229673
loss:  0.00251104100607
loss:  0.00248079258017
----Testing----
40/40 [==============================] - 0s
test loss:  0.00255159125663
Weights:  [[ 0.49018186]]
Biases:  [ 2.00758481]
Paste_Image.png
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容