PyTorch基本用法(六)——快速搭建網絡

文章作者:Tyan
博客:noahsnail.com ?|? CSDN ?|? 簡書

本文主要是關于PyTorch的一些用法。

import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

from torch.autograd import Variable

# 許多沒解釋的東西可以去查文檔, 文檔中都有, 已查過
# pytorch文檔: http://pytorch.org/docs/master/index.html
# matplotlib文檔: https://matplotlib.org/

# 隨機算法的生成種子
torch.manual_seed(1)

# 生成數據
n_data = torch.ones(100, 2)


# 類別一的數據
x0 = torch.normal(2 * n_data, 1)
# 類別一的標簽
y0 = torch.zeros(100)

# 類別二的數據
x1 = torch.normal(-2 * n_data, 1)
# 類別二的標簽
y1 = torch.ones(100)

# x0, x1連接起來, 按維度0連接, 并指定數據的類型
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)
# y0, y1連接, 由于只有一維, 因此沒有指定維度, torch中標簽類型必須為LongTensor
y = torch.cat((y0, y1), ).type(torch.LongTensor)


# x,y 轉為變量, torch只支持變量的訓練, 因為Variable中有grad
x, y = Variable(x), Variable(y)
# 繪制數據散點圖
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c = y.data.numpy(), s = 100, lw = 0, cmap = 'RdYlGn')
plt.show()
png
# 快速搭建分類網絡
net = torch.nn.Sequential(
    torch.nn.Linear(2, 10),
    torch.nn.ReLU(),
    torch.nn.Linear(10, 2))
print(net)
Sequential (
  (0): Linear (2 -> 10)
  (1): ReLU ()
  (2): Linear (10 -> 2)
)
# 定義優化方法
optimizer = torch.optim.SGD(net.parameters(), lr = 0.02)
# 定義損失函數
loss_func = torch.nn.CrossEntropyLoss()

plt.ion()

# 訓練過程
for i in xrange(100):
    prediction = net(x)
    loss = loss_func(prediction, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if i % 2 == 0:
        plt.cla()
        # 獲取概率最大的類別的索引
        prediction = torch.max(F.softmax(prediction), 1)[1]
        # 將輸出結果變為一維
        pred_y = prediction.data.numpy().squeeze()
        target_y = y.data.numpy()
        plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c = pred_y, s = 100, lw = 0, cmap = 'RdYlGn')
        # 計算準確率
        accuracy = sum(pred_y == target_y) / 200.0
        plt.text(1.5, -4, 'Accuracy=%.2f' % accuracy, fontdict = {'size': 10, 'color':  'red'})
        plt.pause(0.1)

plt.ioff()
plt.show()
png

參考資料

  1. https://www.youtube.com/user/MorvanZhou
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容