寫此文原因
網上其實有不少關于pytorch自定義數據集的tutorial,但是之所以要寫這個,是因為我發現他們并沒有結合一兩個的神經網絡來講解。所以我覺得再寫一個tutorial講解關于如何讀取任意的數據集,并且讓某個網絡訓練該數據集還是有必要的。
在初學pytorch的時候,我們一般使用的是pytorch自帶的一些數據集,比如 (代碼參考1)
from torchvision.datasets.mnist import MNIST
...
data_train = MNIST('./data/mnist',
download=True,
transform=transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor()]))
....
data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
引入MNIST數據集。最初始的訓練網絡是Lenet-5識別MNIST里面的數字。這就導致當你面對很多JPG, PNG的格式的torchvision.datasets
里沒有的圖像時,不知道怎么讀取他們。這篇文章會帶領大家讀取自定義的數據集并訓練他們。
最后的lenet5代碼自定義數據集的實現請在我的github下載
https://github.com/zhaozhongch/Pytorch_Lenet5_CustomDataset
內容
下面我們從網上下載PNG格式的MNIST數據集。
git clone https://github.com/myleott/mnist_png.git
cd mnist_png
tar -xvf mnist_png.tar.gz #解壓文件夾
解壓之后在minst_png/mnist_png
文件夾里你會看到testing
和training
兩個文件夾,進入testing
你會看到10個文件夾分別儲存數字為0~9的圖片。下面我們簡單實現Lenet-5網絡來識別圖片中的數字。
Lenet5網絡如下圖
途中范例給的輸入圖片是32X32,實際我們上面的PNG圖片大小是28X28,網絡其他結構依次減小即可。
輸入圖片1通道28X28,輸入給第一層
第一層卷積層,卷積核大小5X5,輸出圖像6通道,24X24,卷積之后接激勵函數ReLU
第二層池化層,使用平均池化,池化核大小2X2,輸出圖像6通道,12X12
第三層卷積層,卷積核大小還是5X5,輸出圖像16通道,大小8X8。之后再接ReLu
第四層再接2X2池化。輸出16通道,4X4大小圖片。
第五層全連接層,先把16X4X4的圖片"展平"為線性向量,再通過線性變換把圖片"展平"為120維的變量,接ReLu
第六層再把120維降為84維,接ReLu
第七層再降為10維(對應0~9 10種數字可能性)輸出。
講解Lenet5并不是本文的重點,所以簡單的說了上面的網絡結構后我們就給出網絡實現,對于細節不熟悉的新手可以參考文章1。
根據上面的網絡結構,網絡在pytorch中的實現如下
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1,6,5)
self.pool = nn.AvgPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.linear1 = nn.Linear(16*4*4, 120)
self.linear2 = nn.Linear(120,84)
self.linear3 = nn.Linear(84,10)
def forward(self,x):
x = self.conv1(x)
x = F.relu(x)
x = self.pool(x)
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*4*4)
x = F.relu(self.linear1(x))
x = F.relu(self.linear2(x))
x = self.linear3(x)
return x
理論上來說是很簡單的。
那么針對網絡對輸入數據的要求,我們應該怎么把最開始下載的一堆圖片輸入進去呢?這就要用到pytorch里的Dataset
類了。
你需要定義一個類,繼承Dataset
類,然后類里必須包含3個函數__init__
,__len__
,__getitem__
,具體結構如下
class ReadDataset(Dataset):
def __init__(self, 參數...):
def __len__(self, 參數...):
...
return 數據長度
def __getitem__(self, 參數...):
...
return 字典
__len__
需要返回一個表示數據長度的整型量,__getitem__
需要返回一個字典。ReadDataset
這個類名是自定義的,繼承了Dataset
即可。
接下來的過程,我們先簡單過一遍得到結果,再回看為什么這么做。
為了處理MNIST dataset,我們先把training
文件夾里的圖像label讀取進來
data_length = 60000
data_label = [-1] * data_length
prev_dir = './mnist_png/mnist_png/training/'
after_dir = '.png'
for id in range(10):
id_string = str(id)
for filename in glob(prev_dir + id_string +'/*.png'):
position = filename.replace(prev_dir+id_string+'/', '')
position = position.replace(after_dir, '')
data_label[int(position)] = id
這幾行代碼的作用,是把training
文件夾里的10個文件夾里的共計60000張圖片放入到data_label
里。舉個例子,圖片編號為21的圖,包含的數字是0(在training
文件夾的0
文件夾里),那么data_label[21] = 0
。
接下來定義繼承了Dataset
類的ReadDataset
類,具體如下。
class ReadDataset(Dataset):
def __init__(self, imgs_dir, data_label):
self.imgs_dir = imgs_dir
self.ids = data_label
def __len__(self):
return len(self.ids)
def __getitem__(self, i):
idx = self.ids[i]
imgs_file = self.imgs_dir+ str(idx) + '/' + str(i) + '.png'
img = Image.open(imgs_file).convert('L')
img = np.array(img)
img = img.reshape(1,28,28)
if img.max() > 1:
img = img / 255
return {'image': torch.from_numpy(img), 'label': torch.tensor(idx)}
可以看到,構造函數__init__
里我們有兩個參數,一個是imgs_dir
,圖像地址,另一個是我們之前創建的列表data_label
,賦值給self.ids
. __len__()
僅僅是返回了data_label
的長度。
有趣的是__getitem__
函數,我們看到這個函數的參數是i
,傳入了i
之后,我們首先根據ids
找到它對應的圖像里所標識的數字,繼而根據
imgs_file = self.imgs_dir+ str(idx) + '/' + str(i) + '.png'
img = Image.open(imgs_file).convert('L')
找到圖像并轉化為黑白。之后再轉化為np,再reshape。原圖像讀進來本來是28X28,但是根據網絡的要求,輸入需要是圖像通道數X圖像尺寸,黑白圖片通道為1,所以我們reshape為1X28X28。最后圖像的像素點的灰度值歸一化到0到1.因為我們要使用cross entropy代價函數來訓練,根據官網,要求cross entropy的矩陣輸入的值為0到1。返回的內容格式必須是字典,我們這兒字典的內容圖像和圖像內對應的數字(label)是
{'image': torch.from_numpy(img), 'label': torch.tensor(idx)}
這個getitem函數如果調用,最終達到的目的就是,假如我在代碼中輸入A = __getitem__(0)
,我就應該能得到0.png
對應的那張圖像,獲取圖像的方式就是A['image'],獲取圖像是數字幾的方式是A['label']
。
有了上面的內容作為鋪墊,我們看看主函數里讀取數據的具體操作。首先有下面一行內容
prev_dir = './mnist_png/mnist_png/testing/'
...
all_data = ReadDataset(prev_dir, data_label)
我們把prev_dir
和之前得到的data_label
作為參數傳入了ReadDataset
并返回了all_data
。有的人可能說,誒,我沒看到ReadDataset
有返回值呀。這是因為這些寫在了Dataset
這個類里,不然繼承它干什么呢。隨后,我們把這個返回值賦值給DataLoader
,就可以定義從torchvision
里自帶的MNIST dataset一樣的操作了。
test_loader = DataLoader(all_data, batch_size=batch_size, shuffle=True, num_workers=4)
定義好batch_size,num_workers,代價函數這些之后,我們就可以在訓練的時候使用返回值test_loader
了。
with torch.no_grad():
for data in test_loader:
images = data['image']
labels = data['label']
...
我們可以看到其實我們并沒有顯式地調用__getitem__
函數,而是通過data遍歷test_loader
, data會自動根據ReadDataset
里ids的長度,從1到ids.length
來批量讀取圖像。如果你設置了batch_size
等于4,那么for循環的第一次循環,會調用__getitem__
四次,data['image']
就會返回__getitem__
,return {'image':...}
中image
所對應的內容。
設置代價函數這些不是本文的內容,就不細講了。具體的可參見github代碼。
可能大家看了上面的例子還是有些不明不白,因為雖然ReadDataset
這個類的內容就是定義三個函數,但是這三個函數具體的內容是什么,就需要根據實際情況確定了。我們上面的數據集的圖像是分別儲存在0~9個文件夾中,其他的數據可能不是這么儲存的,就需要想新的辦法獲得那個data_label
列表。但是你的最終目的是很明白的,
1:getitem所返回的內容,需要能輸入到網絡里,比如我們的
images = data['image']
...
outputs = net(images.float())
2: 根據0到ids
的長度的indx,能遍歷你想要使用的所有圖像。
假想你顯式調用__getitem__(0)
,你需要能獲得名字為0.png
或者0.jpg
之類的圖像的內容。
說這些不如多看兩個例子再自己實踐一下。上面的lenet5的例子之外,我在github里分別分開寫了CPU的方法和GPU的方法,當然其實就一兩行代碼的事兒。不過考慮到這還是屬于接近新手范疇的tutorial,就分開寫了。
另外我還在github代碼里提供了稍微復雜的網絡UNET的實現,UNET是用來做語義分割的網絡,不熟悉的同學可以自行看下語義分割是什么blabla。在UNET的這個網絡里,我同樣是讀取的自定義的數據集而不是使用torchvision.dataset
里帶的數據集。代碼放于github
https://github.com/zhaozhongch/Pytorch_UNET_MultiObjects
當下用得最多的pytroch的UNET的實現還只是一個物體分割的(參考此處
),我順便拔高了一下實現多個物體的語義分割了,不過最后語義分割的效果圖不是非常好,因為懶得花時間去仔細fine tune了。但我相信作為tutorial級別的代碼,我覺得跑一遍熟悉一下網絡結構怎么定義,怎么自定義數據集,已經很夠了。覺得還不錯的可以給這倆倉庫點個小star哈哈。
關于這兩個網絡實現或者其他內容不懂的同學歡迎私信。