FAST.AI 圖像分類實踐
計算機視覺是深度學(xué)習(xí)中最常見的應(yīng)用領(lǐng)域,其中主要有:圖像分類、圖像生成,對象檢測、目標(biāo)跟蹤、語義分割、實例分割等。FAST.AI 作為一款基于 PyTorch 開發(fā)的快速深度學(xué)習(xí)工具,自然也就包含有大量更便利的圖像處理模塊和方法。接下來,我們將以最常見的圖像分類為例,使用 FAST.AI 進行實踐。
圖像分類
圖像分類是最為常見的一項深度學(xué)習(xí)任務(wù),一般情況下,完成該類任務(wù)會有 3 個重要步驟。
首先,我們需要對原始數(shù)據(jù)進行處理,將圖像數(shù)據(jù)轉(zhuǎn)換為深度學(xué)習(xí)工具能夠支持的張量數(shù)據(jù)。這一步驟往往就是制作相應(yīng)的數(shù)據(jù)加載器。當(dāng)然,F(xiàn)AST.AI 也有自己對應(yīng)的數(shù)據(jù)加載器 DataBunch 對象,這部分內(nèi)容已在前面章節(jié)完成學(xué)習(xí)。
接下來,就是構(gòu)建深度神經(jīng)網(wǎng)絡(luò)模型。圖像處理相關(guān)的任務(wù),大部分都會使用卷積神經(jīng)網(wǎng)絡(luò)模型。卷積神經(jīng)網(wǎng)絡(luò)是一種非常擅長解決計算機視覺任務(wù)的神經(jīng)網(wǎng)絡(luò)模型。當(dāng)然,無論是 PyTorch,還是 TensorFlow,構(gòu)建一個神經(jīng)網(wǎng)絡(luò)模型的難度不高,我們往往只需要調(diào)用相應(yīng)深度神經(jīng)網(wǎng)絡(luò)框架完成層堆疊即可。
最后,就是神經(jīng)網(wǎng)絡(luò)訓(xùn)練的部分。這部分代碼一般是最為復(fù)雜的,我們需要對數(shù)據(jù)進行適當(dāng)?shù)靥幚恚哉_的方式輸入到神經(jīng)網(wǎng)絡(luò)。最后,對神經(jīng)網(wǎng)絡(luò)的輸出進行處理和評估。神經(jīng)網(wǎng)絡(luò)訓(xùn)練的部分需要有一定的構(gòu)建經(jīng)驗才能完成,尤其是在 PyTorch 的應(yīng)用過程中,相對于 TensorFlow 更為復(fù)雜。
FAST.AI 基于 PyTorch 開發(fā),實際上我認(rèn)為其最大的改進之處就是優(yōu)化了 PyTorch 訓(xùn)練神經(jīng)網(wǎng)絡(luò)復(fù)雜的過程。接下來,我們將通過一個圖像分類示例,來學(xué)習(xí)使用 FAST.AI 完成一個完整的圖像分類任務(wù)。
數(shù)據(jù)處理
接下來,我們選擇前面接觸過的 MNIST 數(shù)據(jù)集進行演示,MNIST 是一個 10 個類別的圖像分類任務(wù),數(shù)據(jù)體積較小,非常適合作為工具使用方法的示例數(shù)據(jù)。首先,我們加載數(shù)據(jù),并構(gòu)建 DataBunch 對象,這部分內(nèi)容實際上已經(jīng)學(xué)習(xí)過了。
from fastai.datasets import untar_data, URLs, download_data
from fastai.vision import ImageDataBunch
# 因原數(shù)據(jù)集下載較慢,從藍橋云課服務(wù)器下載數(shù)據(jù),本次實驗時無需此行代碼
download_data("https://labfile.oss.aliyuncs.com/courses/1445/mnist_png")
mnist_path = untar_data(URLs.MNIST)
mnist_data = ImageDataBunch.from_folder(mnist_path, 'training', 'testing')
mnist_data
模型構(gòu)建
構(gòu)建完數(shù)據(jù)加載器 DataBunch 之后,接下來就可以開始構(gòu)建模型了。一般情況下,構(gòu)建一個圖像分類模型有 2 種思路,分別是從頭構(gòu)建和遷移學(xué)習(xí)。從頭構(gòu)建,即意味著由你自己設(shè)計模型的結(jié)構(gòu)和參數(shù)。而遷移學(xué)習(xí)則是利用一些在經(jīng)典神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu)上預(yù)訓(xùn)練的模型進行學(xué)習(xí)。
首先,我們選擇從頭構(gòu)建模型。FAST.AI 提供了一個非常友好的接口 fastai.vision.simple_cnn
來快速實現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)的構(gòu)建。該 API 包含 4 個參數(shù):
actns
:定義卷積模塊的數(shù)量和輸入輸出大小。
kernel_szs
:定義卷積核大小,默認(rèn)為 3。
strides
:定義卷積步長大小,默認(rèn)為 2。
bn
:是否包含批量歸一化操作,布爾類型。
接下來,我們就調(diào)用該接口來快速定義一個卷積神經(jīng)網(wǎng)絡(luò)。
from fastai.vision import simple_cnn
model = simple_cnn(actns=(3, 16, 16, 10))
model
如上所示,我們只是定義了卷積神經(jīng)網(wǎng)絡(luò)包含的卷積模塊數(shù)量和輸入輸出大小。該參數(shù)主要注意輸入和輸出尺寸,其中,(3, 16, 16, 10) 表示有 4 個卷積模塊。因為前面的 DataBunch 對象尺寸為 (3, 28, 28),即為 3 個通道圖像,所以第一層卷積操作的輸入尺寸為 3。由于是 10 分類問題,所以最后一個數(shù)字是 10。中間層的尺寸可以自定義,我們選擇了 16。fastai.vision.simple_cnn 最終會自動構(gòu)建為 PyTorch 支持的 Sequential 順序模型。
訓(xùn)練評估
有了模型之后,我們就可以開始第三步,也就是訓(xùn)練過程。FAST.AI 的模型訓(xùn)練過程會用到其核心類 fastai.vision.Learner。最簡單的情況下,我們只需要將數(shù)據(jù) DataBunch,模型和評估指標(biāo)傳入,即可開始訓(xùn)練。
from fastai.vision import Learner, accuracy
# 傳入數(shù)據(jù),模型和準(zhǔn)確度評估指標(biāo)
learner = Learner(mnist_data, model, metrics=[accuracy])
learner
如上所示,我們定義的 Learner 選擇了 accuracy 準(zhǔn)確度作為評估指標(biāo)。你可以通過 Learner 的輸出看到其他相關(guān)的默認(rèn)參數(shù)設(shè)置。例如優(yōu)化器 opt_func 選擇了 Adam,損失函數(shù) loss_func 選擇了交叉熵。
加下來,我們可以調(diào)用 Learner 完成最終的訓(xùn)練,訓(xùn)練方法為 Learner.fit,傳入迭代次數(shù) Epoch 即可。
learner.fit(1) # 數(shù)據(jù)集上訓(xùn)練迭代 1 次
最終,Learner 會打印出最終的訓(xùn)練損失,驗證損失,準(zhǔn)確度和訓(xùn)練所用時長。至此,我們就使用 FAST.AI 完成了一次針對 MNIST 的圖像分類過程。我們整理上面的完整代碼如下:
mnist_path = untar_data(URLs.MNIST)
mnist_data = ImageDataBunch.from_folder(mnist_path, 'training', 'testing')
model = simple_cnn(actns=(3, 16, 16, 10))
learner = Learner(mnist_data, model, metrics=[accuracy])
learner.fit(1)
你可以看出,使用 FAST.AI 完成 MNSIT 分類我們只使用了 5 行代碼,而相比于 PyTorch 需要的數(shù)十行代碼和復(fù)雜的構(gòu)建過程,F(xiàn)AST.AI 中的 FAST 是顯而易見的。
遷移學(xué)習(xí)
上面,我們從頭構(gòu)建了一個卷積神經(jīng)網(wǎng)絡(luò)并針對 MNSIT 進行了訓(xùn)練。由于 MNSIT 數(shù)據(jù)集本身就質(zhì)量較高,背景純凈,數(shù)據(jù)規(guī)范,所以最終準(zhǔn)確度還是不錯的。如果你將上方 Learner 訓(xùn)練迭代次數(shù)調(diào)至 3~5 次,準(zhǔn)確度還會有一定的提升,并最終超過 90%。但對于一些復(fù)雜的任務(wù),尤其是樣本數(shù)據(jù)不規(guī)范的情況下,從頭開始訓(xùn)練并不是一個很明智的選擇。所以,很多時候我們會使用預(yù)訓(xùn)練模型做遷移學(xué)習(xí)。
遷移學(xué)習(xí)是一種站在巨人的肩膀上的訓(xùn)練方法。我們可以沿用一些經(jīng)典神經(jīng)網(wǎng)絡(luò)在大型數(shù)據(jù)集訓(xùn)練好的模型,使用自定義數(shù)據(jù)集繼續(xù)更新其中部分層的權(quán)重。最終,可以在較少的時間下取得不錯的訓(xùn)練效果。
FAST.AI 提供的預(yù)訓(xùn)練模型大部分直接來自于 PyTorch,你可以通過 此頁面 瀏覽這些模型。接下來,我們以 ResNet18 為例,針對上面的 MNIST 數(shù)據(jù)完成一次遷移學(xué)習(xí)過程。ResNet18 是 ResNet 精簡結(jié)構(gòu)在 ImageNet 數(shù)據(jù)集上得到的預(yù)訓(xùn)練模型,首先載入該模型并查看結(jié)構(gòu)。
from fastai.vision import models
models.resnet18()
可以看出,相比于我們之前自行搭建的 CNN 結(jié)構(gòu),ResNet18 要復(fù)雜很多。解析來的訓(xùn)練過程需要利用 fastai.vision.cnn_learner 類來構(gòu)建 Learner,這一點也與上面有所不同。你只需要記住,如果是從頭開始就使用 fastai.vision.Learner,如果是遷移學(xué)習(xí)就使用 fastai.vision.cnn_learner 即可。
from fastai.vision import cnn_learner
# 構(gòu)建基于 ResNet18 的 Learner 學(xué)習(xí)器
learner = cnn_learner(mnist_data, models.resnet18, metrics=[accuracy])
learner.fit(1) # 訓(xùn)練迭代 1 次
你可以看到 Learner 會自動下載 ResNet18 的 .pth 預(yù)訓(xùn)練權(quán)重文件,然后開始訓(xùn)練迭代過程。訓(xùn)練過程相對于上方會更長一些,原因是模型復(fù)雜度更高。最終,使用 ResNet18 完成 1 次迭代的準(zhǔn)確度,應(yīng)該會比上方我們自定義的模型高一些。
所以,當(dāng)我們使用 FAST.AI 執(zhí)行遷移學(xué)習(xí)時,代碼可以進一步精簡至 4 行。這對于使用 PyTorch 和 TensorFlow 是不可想象的簡單,也體現(xiàn)了高階 API 的優(yōu)勢。
mnist_path = untar_data(URLs.MNIST)
mnist_data = ImageDataBunch.from_folder(mnist_path, 'training', 'testing')
learner = Learner(mnist_data, models.resnet18, metrics=[accuracy])
learner.fit(1)
雖然我們的準(zhǔn)確度已經(jīng)達到了 90% 以上,但模型仍然對部分驗證數(shù)據(jù)無法準(zhǔn)確區(qū)分。接下來,我們可以通過 FAST.AI 提供的 fastai.vision.ClassificationInterpretation
方法來對結(jié)果進行進一步分析。
from fastai.vision import ClassificationInterpretation
# 載入學(xué)習(xí)器
interp = ClassificationInterpretation.from_learner(learner)
interp
首先,我們可以輸出那些被分類器預(yù)測錯誤的樣本進行觀察。直接通過 interp.plot_top_losses
方法輸出損失最大的 9 個驗證樣本,并比對它們本來的標(biāo)簽和預(yù)測結(jié)果。
interp.plot_top_losses(9, figsize=(9, 9))
上面依次輸出了預(yù)測標(biāo)簽,真實標(biāo)簽,損失和預(yù)測概率。你可以看到,部分樣本的確人眼都很難完成辨識,當(dāng)然也有一些人眼可辨識樣本被錯誤分類。
除了比對圖像,F(xiàn)AST.AI 還提供了一個非常方便的方法 interp.plot_confusion_matrix。通過該方法,我們可以直接繪制出真實標(biāo)簽和預(yù)測標(biāo)簽之間的混淆矩陣。
interp.plot_confusion_matrix(figsize=(5, 5), dpi=100)
混淆矩陣展示了真實標(biāo)簽和預(yù)測標(biāo)簽對應(yīng)樣本的數(shù)量。可以看出,0-9 這 10 類樣本在分布上沒有明顯的傾斜。你也可以進一步看出,到底哪些樣本更容易被預(yù)測錯誤,以及被錯誤預(yù)測的標(biāo)簽結(jié)果。
數(shù)據(jù)擴增
數(shù)據(jù)在神經(jīng)網(wǎng)絡(luò)訓(xùn)練過程中伴有很大的左右,如果符合要求的數(shù)據(jù)越多,往往訓(xùn)練的結(jié)果也更好。所以,很多時候我們會對現(xiàn)有數(shù)據(jù)進行一些旋轉(zhuǎn)、變換、鏡像、歸一化等操作。這些操作不僅可以在一定程度上起到數(shù)據(jù)擴增的效果,能夠?qū)δP陀?xùn)練帶來一些幫助。
FAST.AI 提供了一個非常方便的函數(shù) fastai.vision.get_transforms
用于對圖像進行變換,該函數(shù)的主要參數(shù)有:
do_flip
:如果為 True
,則以 0.5 的概率應(yīng)用隨機翻轉(zhuǎn)。
flip_vert
:應(yīng)用水平翻轉(zhuǎn)。如果 do_flip=True
時,則可以垂直翻轉(zhuǎn)圖像或旋轉(zhuǎn) 90 度。
max_rotate
:如果不為 None
,則在 -max_rotate
和 max_rotate
度之間隨機旋轉(zhuǎn),概率為 p_affine
。
max_zoom
:如果不是 1 或小于 1,則在 1 之前進行隨機縮放,并以 p_affine
概率應(yīng)用 max_zoom
。
max_lighting
:如果不為 None
,則以 max_lighting
概率 p_lighting
施加由 max_lighting
控制的隨機噪聲和對比度變化。
max_warp
:如果不是 None
,則以概率 p_affine
施加 -max_warp
和 maw_warp
之間的隨機對稱扭曲。
p_affine
:應(yīng)用每個仿射變換和對稱扭曲的概率。
p_lighting
:應(yīng)用每個照明變換的概率。
xtra_tfms
:您想要應(yīng)用的其他變換的列表。
接下來,通過一個直觀的例子來演示數(shù)據(jù)變換擴增的效果。我們讀取 MNIST 訓(xùn)練數(shù)據(jù)中第一個樣本:
img, label = mnist_data.train_ds[0]
img.show(title=f'{label}')
然后,我們嘗試對該數(shù)據(jù)進行隨機旋轉(zhuǎn)變換操作。為了更加方便地演示旋轉(zhuǎn)后的效果,這里定義一個輔助繪圖函數(shù) plots_f
。
from fastai.vision import get_transforms
from matplotlib import pyplot as plt
%matplotlib inline
# 輔助繪圖函數(shù),參考自 FAST.AI 官方文檔
def plots_f(rows, cols, width, height, **kwargs):
[img.apply_tfms(tfms[0], **kwargs).show(ax=ax) for i, ax in enumerate(plt.subplots(
rows, cols, figsize=(width, height))[1].flatten())]
接下來,定義變換操作并應(yīng)用繪圖。
# 定義變換操作,最大 [-25, 25] 度之間的隨機旋轉(zhuǎn)
tfms = get_transforms(max_rotate=25)
# 繪制樣本變換后圖像
plots_f(2, 4, 12, 6, size=224)
可以看到,樣本被執(zhí)行了 -25 度到 25 度之間的隨機旋轉(zhuǎn)操作。不過,上面的示例有一定的缺陷。因為對于手寫字符,較大幅度的旋轉(zhuǎn)或鏡像圖像會嚴(yán)重影響樣本所反映的內(nèi)容,甚至變成完全不是數(shù)字的樣子。所以,對于 MNIST 這類數(shù)據(jù),我們往往只能應(yīng)用小幅度的旋轉(zhuǎn)、添加噪聲等變換,以避免對樣本本身含義的影響。但是,對于如下所示的動物圖像,更大幅度的變換對數(shù)據(jù)集擴增更有意義。
fastai.vision.get_transforms
操作一般會直接添加至 DataBunch 對象生成過程中,這樣就可以將變換操作應(yīng)用于樣本數(shù)據(jù)。
# 示例,制作 DataBunch 對象時添加 get_transforms 操作
tfms = get_transforms(max_rotate=25)
tfms_data = ImageDataBunch.from_folder(mnist_path, 'training', 'testing', ds_tfms=tfms)
tfms_data.show_batch(rows=3, figsize=(5,5))
CIFAR10 圖像分類挑戰(zhàn)
前面的挑戰(zhàn)中,我們已經(jīng)熟悉了 CIFAR10 數(shù)據(jù)集,并將其處理成 FAST.AI 支持的 DataBunch 對象。本次挑戰(zhàn)中,我們同樣需讀取 CIFAR10 數(shù)據(jù)集,并添加針對數(shù)據(jù)集變換的預(yù)處理過程。
接下來,請將 CIFAR10 數(shù)據(jù)集處理成 DataBunch 對象。挑戰(zhàn)要求,將 train 文件夾中數(shù)據(jù)分離 20% 作為驗證集,剩下數(shù)據(jù)作為訓(xùn)練集。test 文件夾下數(shù)據(jù)作為測試集。同時,加入 get_transforms 變換,應(yīng)用[?30,30] 度之間的隨機旋轉(zhuǎn)變換。
from fastai.datasets import untar_data, URLs, download_data
from fastai.vision import ImageDataBunch, get_transforms
download_data("http://labfile.oss.aliyuncs.com/courses/1445/cifar10")
data_path = untar_data(URLs.CIFAR)
# 針對數(shù)據(jù)集變換
tfms = get_transforms(max_rotate=30)
data_bunch = ImageDataBunch.from_folder(data_path, train='train', test='test',
valid_pct=0.2, ds_tfms=tfms)
接下來,請使用 FAST.AI 提供的建模方法,應(yīng)用卷積神經(jīng)網(wǎng)絡(luò)對 CIFAR10 完成分類和評估。你可以自由選擇「從頭開始訓(xùn)練」或「遷移學(xué)習(xí)方法」。遷移學(xué)習(xí)所使用的預(yù)訓(xùn)練模型也可以通過閱讀官方文檔自由選擇。
挑戰(zhàn)最終要求,驗證集上的分類準(zhǔn)確度不得低于 70%。由于訓(xùn)練時間較長,你可以在恰當(dāng)?shù)臅r候中止訓(xùn)練。
from fastai.vision import models, cnn_learner, accuracy
models.resnet18()
# 構(gòu)建基于 ResNet18 的 Learner 學(xué)習(xí)器
learner = cnn_learner(data_bunch, models.resnet18, metrics=[accuracy])
learner.fit(15) # 訓(xùn)練迭代 15 次
僅供參考,accuracy 最終大于 70% 即可。