【教程】Tensorflow vs PyTorch —— 神經(jīng)網(wǎng)絡(luò)的搭建和訓(xùn)練

brown and white stripe textile

image from unsplash.com by @wolfgang_hasselmann

上一篇文章,我們用 Tensorflow 和 PyTorch 分別完成了函數(shù)自動(dòng)求導(dǎo)以及參數(shù)手動(dòng)和自動(dòng)優(yōu)化的任務(wù),這篇文章我們就通過經(jīng)典的 MNSIT 手寫數(shù)字識(shí)別數(shù)據(jù)集,對(duì)比一下,如何使用兩個(gè)框架建立訓(xùn)練全鏈接的神經(jīng)網(wǎng)絡(luò),對(duì)手寫數(shù)字進(jìn)行分類.

獲取文章代碼請(qǐng)關(guān)注微信公眾號(hào)"tensor_torch" 二維碼見文末

1. 數(shù)據(jù)導(dǎo)入

像 MNIST 這樣經(jīng)典的數(shù)據(jù)集 Tensorflow 和 PyTorch 都能直接下載,并提供非常方便快捷的加載工具.

  1. Tensorflow 用 tf.keras.datasets.mnist.load_data()加載數(shù)據(jù),數(shù)據(jù)為 numpy.ndarray格式 .

  2. PyTorch 從 torchvison.datasets.MNIST 中加載,數(shù)據(jù)格式為 Image,無(wú)法直接使用,需要設(shè)置 transform = transforms.ToTensor() 轉(zhuǎn)換成張量數(shù)據(jù);這里的 transform 不僅可以轉(zhuǎn)換數(shù)據(jù)格式, 如果傳入transform.Compose() 可以通過 list 傳入更多轉(zhuǎn)換的參數(shù),比如代碼中就將數(shù)據(jù)同時(shí)進(jìn)行了normalize 的處理.

  3. Tensorflow 中可以通過tf.data.Dataset.from_tensor_slices() 構(gòu)建數(shù)據(jù)集對(duì)象.并通過 .map 自定義的preprocess函數(shù),對(duì)數(shù)據(jù)進(jìn)行預(yù)處理.還可以直接使用.shuffle().batch()對(duì)數(shù)據(jù)進(jìn)行打散和批處理.

  4. PyTorch 中使用torch.utils.data.DataLoader 構(gòu)建數(shù)據(jù)集對(duì)象,完成數(shù)據(jù) 創(chuàng)建batch 批處理,以及對(duì)數(shù)據(jù)進(jìn)行打散(Shuffle)

  5. 注意處理后數(shù)據(jù)的 shape, Tensorflow 中 image shape: [b, 28, 28], label shape: [b], PyTorch image shape: [b, 1,28, 28], label shape: [b]

  6. PyTorch 的 DataLoader 可以設(shè)置訓(xùn)練數(shù)據(jù)的 Train = False 避免在測(cè)試數(shù)據(jù)庫(kù)中對(duì)數(shù)據(jù)進(jìn)行訓(xùn)練,而 Tensorflow 就只能在搭建網(wǎng)絡(luò)的時(shí)候才能聲明了.

# ------------------------Tensorflow -----------------------------
(x, y),(x_test, y_test) = keras.datasets.mnist.load_data()

ds_train = tf.data.Dataset.from_tensor_slices((x,y))
ds_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))

def preprocess(x, y):
  x = (tf.cast(x, tf.float32)/255)-0.1307
  y = tf.cast(y, tf.int32)
#   y = tf.one_hot(y,depth=10)   
  return x, y

ds_train = ds_train.map(preprocess).shuffle(1000).batch(batch_size)
ds_test = ds_test.map(preprocess).shuffle(1000).batch(batch_size)

# ------------------------PyTorch --------------------------------

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True

2. 手動(dòng)搭網(wǎng)

2.1 參數(shù)初始化

我們首先介紹如何手動(dòng)搭建全鏈接的神經(jīng)網(wǎng)絡(luò),這里的難點(diǎn)是參數(shù)的初始化和管理.我們的模型有三層全鏈接的神經(jīng)網(wǎng)絡(luò),所以我們需要初始化3組 w 和 b.注意每一組的shape:

網(wǎng)絡(luò):[b, 786] -> [b, 200] -> [b, 100] -> [b, 10]

w1: [786, 200], b1: [200],

w2: [200,100], b2: [100],

w3: [100,10], b3:[10]

# ------------------------Tensorflow -----------------------------
w1 = tf.Variable(tf.random.uniform([28*28, 200]))
b1 = tf.Variable(tf.zeros([200]))

w2 = tf.Variable(tf.random.uniform([200, 100]))
b2 = tf.Variable(tf.zeros([100]))

w3 = tf.Variable(tf.random.uniform([100, 10]))
b3 = tf.Variable(tf.zeros([10]))
# ------------------------PyTorch --------------------------------
w1 = torch.rand(28*28, 200 , requires_grad=True)
b1 = torch.zeros(200, requires_grad=True)

w2 = torch.rand(200, 100, requires_grad=True)
b2 = torch.zeros(100, requires_grad=True)

w3 = torch.rand(100, 10, requires_grad=True)
b3 = torch.zeros(10, requires_grad=True)

2.2 搭建網(wǎng)絡(luò)

這里我們均采用自定義函數(shù)的方式來搭建網(wǎng)絡(luò),這個(gè)部分兩個(gè)框架沒有太大區(qū)別.我們手動(dòng)定義了三層神經(jīng)網(wǎng)絡(luò),前兩層包含 relu 激活函數(shù),最后一層沒有使用激活函數(shù).

# ------------------------Tensorflow -----------------------------
# forward func
def model(x):
    x = tf.nn.relu(x@w1 + b1)
    x = tf.nn.relu(x@w2 + b2)
    x = x@w3 + b3
        
    return x
# ------------------------PyTorch --------------------------------
# forward func
def forward(x):
    x = F.relu(x@w1 + b1)
    x = F.relu(x@w2 + b2)
    x = x@w3 + b3
        
    return x

2.3 訓(xùn)練網(wǎng)絡(luò)

該部分與前文中介紹的自動(dòng)求導(dǎo),參數(shù)優(yōu)化的部分一致,按照套路進(jìn)行就行了,需注意以下幾點(diǎn).

  1. 對(duì)于全鏈接網(wǎng)絡(luò)首先需要對(duì)數(shù)據(jù)打平,Tensorflow 和 PyTorch 都可以用 reshape 方法實(shí)現(xiàn).
  2. 為了與 PyTorch 中torch.nn.CrossEntropyLoss()交叉熵的方法一致,Tensorflow 中并未對(duì)label 進(jìn)行 One-Hot 編碼,所以使用了tf.losses.sparse_categorical_crossentropy() 方法計(jì)算交叉熵.
# ------------------------Tensorflow -----------------------------
optimizer = tf.optimizers.Adam(learning_rate)

for epoch in range(epochs):
    
    for step, (x, y) in enumerate(ds_train):
        x = tf.reshape(x, [-1, 28*28])
        with tf.GradientTape() as tape:            
            logits = model(x)
            
            losses = tf.losses.sparse_categorical_crossentropy(y,logits,from_logits=True)
            loss = tf.reduce_mean(losses)
            
        grads = tape.gradient(loss, [w1,b1,w2,b2,w3,b3])
        
        optimizer.apply_gradients(zip(grads, [w1,b1,w2,b2,w3,b3]))
# ------------------------PyTorch --------------------------------
optimizer = torch.optim.Adam([w1,b1,w2,b2,w3,b3],
                            lr=learning_rate)
criteon = torch.nn.CrossEntropyLoss()

for epoch in range(epochs):
    
    for step, (x, y) in enumerate(train_loader):
        x = x.reshape(-1,28*28)
        
        logits = forward(x)
        loss = criteon(logits, y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

3. 高級(jí) API 搭建網(wǎng)絡(luò)

手動(dòng)搭建網(wǎng)絡(luò)的好處是,都是采用最底層的方式,整個(gè)過程透明可控.但是壞處就是需要手動(dòng)管理每一個(gè)參數(shù),網(wǎng)絡(luò)一旦復(fù)雜起來就容易出錯(cuò).

Tensorflow 和 PyTorch 均可采用創(chuàng)建模型對(duì)象(Class)的方式創(chuàng)建神經(jīng)網(wǎng)絡(luò)模型.

  1. Tensorflow 繼承 tf.keras.Model對(duì)象,PyTorch 繼承 torch.nn.Module對(duì)象.
  2. Tensorflow 模型對(duì)象中,前向傳播調(diào)用 call() 函數(shù),PyTorch 調(diào)用 forward() 函數(shù).
  3. 在訓(xùn)練過程中僅需將手動(dòng)搭網(wǎng)的函數(shù)替換成初始化后的網(wǎng)絡(luò)模型對(duì)象即可.
# ------------------------Tensorflow -----------------------------
class FC_model(keras.Model):
    def __init__(self):
        super().__init__()
    
        self.model = keras.Sequential(
            [layers.Dense(200),
            layers.ReLU(),
            layers.Dense(100),
            layers.ReLU(),
            layers.Dense(10)]
            )
    
    def call(self,x):
        x = self.model(x)
        
        return x
    
model = FC_model()
# ------------------------PyTorch --------------------------------
class FC_NN(nn.Module):
    def __init__(self):
        super().__init__()
    
        self.model = nn.Sequential(
            nn.Linear(28*28, 200),
            nn.ReLU(inplace=True),
            nn.Linear(200, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100,10)
            )
    
    def forward(self, x):
        x = self.model(x)
        
        return x

network = FC_NN().to(device)  

4. 使用 GPU 加速訓(xùn)練

如果訓(xùn)練環(huán)境支持 GPU ,Tensorflow 和 PyTorch 均可以調(diào)用 GPU 加速計(jì)算.Tensorflow 如果使用的是 Tensorflow-gpu 版本,我們無(wú)需任何操作,直接就是調(diào)用的GPU進(jìn)行計(jì)算.

對(duì)于 PyTorch ,需要?jiǎng)?chuàng)建 device = torch.device('cuda:0')并將網(wǎng)絡(luò)和參數(shù)搬到這個(gè) device 上進(jìn)行計(jì)算.

...
device = torch.device('cuda:0')
network = FC_NN().to(device)  

criteon = torch.nn.CrossEntropyLoss().to(device)
...

for epoch in range(epochs):
...        
        x, y = x.to(device), y.to(device)
...

5. 模型測(cè)試

模型訓(xùn)練好了之后需要使用驗(yàn)證數(shù)據(jù)集進(jìn)行測(cè)試。這里我們簡(jiǎn)單的采用正確率(accuracy)來對(duì)模型進(jìn)行驗(yàn)證

正確率 = 預(yù)測(cè)正確的樣本數(shù) / 所有樣本數(shù)

代碼看起來比較繁瑣,不過就是以下幾個(gè)步驟:

  1. 將所有驗(yàn)證數(shù)據(jù)帶入訓(xùn)練好的模型中,給出預(yù)測(cè)值。
  2. 將預(yù)測(cè)值與實(shí)際值進(jìn)行比較。
  3. 累加預(yù)測(cè)正確的樣本數(shù)和總樣本數(shù)。
  4. 用上面的公式算出正確率

實(shí)際上 tensorflow 可以調(diào)用tf.keras.metrics 這個(gè)在之前的文章中已經(jīng)提到,這里就不贅述了。

# ------------------------Tensorflow -----------------------------
if(step%100==0):
            print("epoch:{}, step:{} loss:{}".
                  format(epoch, step, loss.numpy()))         
            
#             test accuracy: 
            total_correct = 0
            total_num = 0
            
            for x_test, y_test in ds_test:
                x_test = tf.reshape(x_test, [-1, 28*28])
                y_pred = tf.argmax(model(x_test),axis=1)
                y_pred = tf.cast(y_pred, tf.int32)
                correct = tf.cast((y_pred == y_test), tf.int32)
                correct = tf.reduce_sum(correct)
                
                total_correct += int(correct)
                total_num += x_test.shape[0]
        
            
            accuracy = total_correct/total_num
            print('accuracy: ', accuracy)
# ------------------------PyTorch --------------------------------
        if(step%100 == 0):
            print("epoch:{}, step:{}, loss:{}".
                  format(epoch, step, loss.item()))
        
#             test accuracy
            total_correct = 0
            total_num = 0    

            for x_test, y_test in test_loader:
                    x_test = x_test.reshape(-1,28*28)
                    x_test, y_test = x_test.to(device), y_test.to(device)

                    y_pred = network(x_test)
                    y_pred = torch.argmax(y_pred, dim = 1)
                    correct = y_pred == y_test
                    correct = correct.sum()

                    total_correct += correct
                    total_num += x_test.shape[0]

            acc = total_correct.float()/total_num
            print("accuracy: ", acc.item())

相關(guān)文章
【教程】Tensorflow vs PyTorch —— 自動(dòng)求導(dǎo)
【教程】Tensorflow vs PyTorch —— 數(shù)學(xué)運(yùn)算
【教程】Tensorflow vs PyTorch —— 張量的基本操作
Tensorflow 2 vs PyTorch 對(duì)比學(xué)習(xí)教程開啟
Tensorflow 2.0 --- ResNet 實(shí)戰(zhàn) CIFAR100 數(shù)據(jù)集
Tensorflow2.0——可視化工具tensorboard
Tensorflow2.0-數(shù)據(jù)加載和預(yù)處理
Tensorflow 2.0 快速入門 —— 引入Keras 自定義模型
Tensorflow 2.0 快速入門 —— 自動(dòng)求導(dǎo)與線性回歸
Tensorflow 2.0 輕松實(shí)現(xiàn)遷移學(xué)習(xí)
Tensorflow入門——Eager模式像原生Python一樣簡(jiǎn)潔優(yōu)雅
Tensorflow 2.0 —— 與 Keras 的深度融合


最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。
  • 序言:七十年代末,一起剝皮案震驚了整個(gè)濱河市,隨后出現(xiàn)的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 228,936評(píng)論 6 535
  • 序言:濱河連續(xù)發(fā)生了三起死亡事件,死亡現(xiàn)場(chǎng)離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機(jī),發(fā)現(xiàn)死者居然都...
    沈念sama閱讀 98,744評(píng)論 3 421
  • 文/潘曉璐 我一進(jìn)店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事?!?“怎么了?”我有些...
    開封第一講書人閱讀 176,879評(píng)論 0 381
  • 文/不壞的土叔 我叫張陵,是天一觀的道長(zhǎng)。 經(jīng)常有香客問我,道長(zhǎng),這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,181評(píng)論 1 315
  • 正文 為了忘掉前任,我火速辦了婚禮,結(jié)果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當(dāng)我...
    茶點(diǎn)故事閱讀 71,935評(píng)論 6 410
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發(fā)上,一...
    開封第一講書人閱讀 55,325評(píng)論 1 324
  • 那天,我揣著相機(jī)與錄音,去河邊找鬼。 笑死,一個(gè)胖子當(dāng)著我的面吹牛,可吹牛的內(nèi)容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,384評(píng)論 3 443
  • 文/蒼蘭香墨 我猛地睜開眼,長(zhǎng)吁一口氣:“原來是場(chǎng)噩夢(mèng)啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側(cè)響起,我...
    開封第一講書人閱讀 42,534評(píng)論 0 289
  • 序言:老撾萬(wàn)榮一對(duì)情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個(gè)月后,有當(dāng)?shù)厝嗽跇淞掷锇l(fā)現(xiàn)了一具尸體,經(jīng)...
    沈念sama閱讀 49,084評(píng)論 1 335
  • 正文 獨(dú)居荒郊野嶺守林人離奇死亡,尸身上長(zhǎng)有42處帶血的膿包…… 初始之章·張勛 以下內(nèi)容為張勛視角 年9月15日...
    茶點(diǎn)故事閱讀 40,892評(píng)論 3 356
  • 正文 我和宋清朗相戀三年,在試婚紗的時(shí)候發(fā)現(xiàn)自己被綠了。 大學(xué)時(shí)的朋友給我發(fā)了我未婚夫和他白月光在一起吃飯的照片。...
    茶點(diǎn)故事閱讀 43,067評(píng)論 1 371
  • 序言:一個(gè)原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內(nèi)的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,623評(píng)論 5 362
  • 正文 年R本政府宣布,位于F島的核電站,受9級(jí)特大地震影響,放射性物質(zhì)發(fā)生泄漏。R本人自食惡果不足惜,卻給世界環(huán)境...
    茶點(diǎn)故事閱讀 44,322評(píng)論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,735評(píng)論 0 27
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽(yáng)。三九已至,卻和暖如春,著一層夾襖步出監(jiān)牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 35,990評(píng)論 1 289
  • 我被黑心中介騙來泰國(guó)打工, 沒想到剛下飛機(jī)就差點(diǎn)兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個(gè)月前我還...
    沈念sama閱讀 51,800評(píng)論 3 395
  • 正文 我出身青樓,卻偏偏與公主長(zhǎng)得像,于是被迫代替她去往敵國(guó)和親。 傳聞我的和親對(duì)象是個(gè)殘疾皇子,可洞房花燭夜當(dāng)晚...
    茶點(diǎn)故事閱讀 48,084評(píng)論 2 375