筆者PyTorch的全部簡單教程請?jiān)L問:http://www.lxweimin.com/nb/48831659
PyTorch教程-5:詳解PyTorch中加載數(shù)據(jù)的方法--Dataset、Dataloader、Sampler、collate_fn等
數(shù)據(jù)讀取是所有訓(xùn)練模型任務(wù)中最基礎(chǔ)最重要的一步,PyTorch為數(shù)據(jù)集的讀取、加載和使用提供了很好的機(jī)制,使得數(shù)據(jù)加載的工作變得異常簡單而且具有非常高的定制性。
Dataset、Dataloader、Sampler的關(guān)系
PyTorch中對于數(shù)據(jù)集的處理有三個非常重要的類:Dataset
、Dataloader
、Sampler
,它們均是 torch.utils.data
包下的模塊(類)。它們的關(guān)系可以這樣理解:
-
Dataset
是數(shù)據(jù)集的類,主要用于定義數(shù)據(jù)集 -
Sampler
是采樣器的類,用于定義從數(shù)據(jù)集中選出數(shù)據(jù)的規(guī)則,比如是隨機(jī)取數(shù)據(jù)還是按照順序取等等 -
Dataloader
是數(shù)據(jù)的加載類,它是對于Dataset
和Sampler
的進(jìn)一步包裝,即其實(shí)Dataset
和Sampler
會作為參數(shù)傳遞給Dataloader
,用于實(shí)際讀取數(shù)據(jù),可以理解為它是這個工作的真正實(shí)踐者,而Dataset
和Sampler
則負(fù)責(zé)定義。我們訓(xùn)練、測試所獲得的數(shù)據(jù)也是Dataloader
直接給我們的。
總的來說,Dataset
定義了整個數(shù)據(jù)集,Sampler
提供了取數(shù)據(jù)的機(jī)制,最后由Dataloader
取完成取數(shù)據(jù)的任務(wù)。
本篇以一個最簡單的例子為例,比如有一個文件夾(data-folder)中存儲訓(xùn)練的數(shù)據(jù)(一共30張圖片:0.png 到 29.png),他們對應(yīng)的標(biāo)簽被寫在了一個labels.txt文件中,第n行對應(yīng)n-1.png的標(biāo)簽,是一個三分類問題,即0、1和2三種標(biāo)簽(虛構(gòu)的數(shù)據(jù)集,不具有任何意義)。目錄結(jié)構(gòu)如下:
|--- Project
|--- main.py
|--- labels.txt
|--- data-folder
|--- 0.png
|--- 1.png
|--- ……
|--- 29.png
Dataset
Dataset
位于 torch.utils.data
下,我們通過定義繼承自這個類的子類來自定義數(shù)據(jù)集。它有兩個最重要的方法需要重寫,實(shí)際上它們都是類的特殊方法:
-
__getitem__(self, index)
:傳入?yún)?shù)index
為下標(biāo),返回?cái)?shù)據(jù)集中對應(yīng)下標(biāo)的數(shù)據(jù)組(數(shù)據(jù)和標(biāo)簽) -
__len__(self)
:返回?cái)?shù)據(jù)集的大小
簡單說,重寫了這兩個方法的繼承自 Dataset
的類都可以作為數(shù)據(jù)集的定義類使用,即一個Dataset
類的必要結(jié)構(gòu):
class Dataset(torch.utils.data.Dataset):
def __init__(self, filepath=None,dataLen=None):
pass
def __getitem__(self, index):
pass
def __len__(self):
pass
如下就是我們的例子的加載實(shí)例,其中的 image2tensor
使用了torchvision.transforms
完成了一個簡單的從PIL.Image
格式的圖片到 tensor
的轉(zhuǎn)換,可以先不必在意,后面會詳細(xì)地講到 transforms
這個超級重要的工具:
from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision import transforms
class MyDataset(Dataset):
def __init__(self, images_folder_path, labels_file_path):
self.images_folder_path = images_folder_path
with open(labels_file_path, 'r') as file:
self.labels = list(map(int, file.read().splitlines()))
def __getitem__(self, item):
image = Image.open(os.path.join(self.images_folder_path, "{}.png".format(item)))
image = self.image2tensor(image)
label = self.labels[item]
return (image, label)
def __len__(self):
return len(self.labels)
def image2tensor(self, image):
"""
transform PIL.Image to tensor
:param image: image in PIL.Image format
:return: image in tensor format
"""
transform = transforms.Compose([
transforms.ToTensor()
])
image = image.convert('RGB')
return transform(image)
myDataset = MyDataset("./data-folder/", "./labels.txt")
Dataloader
Dataloader
對Dataset
(和Sampler
等)打包,完成最后對數(shù)據(jù)的讀取的執(zhí)行工作,一般不需要自己定義或者重寫一個Dataloader
的類(或子類),直接使用即可,通過傳入?yún)?shù)定制Dataloader
,定制化的功能應(yīng)該在Dataset
(和Sampler
等)中完成了。
Dataloader
的完整簽名見:https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
Dataloader的一些常用參數(shù)
Dataloader
的一些重要的參數(shù)如下,除了第一個dataset
參數(shù)外,其他均為可選參數(shù):
-
dataset
(第一個參數(shù),必須的參數(shù)):一個Dataset
的實(shí)例,即傳入的數(shù)據(jù)集(或者其他可迭代對象) -
batch_size
:整數(shù)值,每個batch的樣本數(shù)量,即batch大小,默認(rèn)為1 -
shuffle
:bool值,如果設(shè)置為True
,則在每個epoch開始的時(shí)候,會對數(shù)據(jù)集的數(shù)據(jù)進(jìn)行重新排序,默認(rèn)False
-
sampler
:傳入一個自定義的Sampler
實(shí)例,定義從數(shù)據(jù)集中取樣本的策略,Sampler
每次返回一個索引,默認(rèn)為None
-
batch_sampler
:也是傳入一個自定義的Sampler
實(shí)例,但是與sampler
參數(shù)不同的是,它接收的Sampler
是一次返回一個batch的索引,默認(rèn)為None
-
num_workers
:整數(shù)值,定義有幾個進(jìn)程來處理數(shù)據(jù)。0
意味著所有的數(shù)據(jù)都會被加載進(jìn)主進(jìn)程,默認(rèn)0
-
collate_fn
:傳入一個函數(shù),它的作用是將一個batch
的樣本打包成一個大的tensor
,tensor
的第一維就是這些樣本,如果沒有特殊需求可以保持默認(rèn)即可(后邊會詳細(xì)介紹) -
pin_memory
:bool值,如果為True
,那么將加載的數(shù)據(jù)拷貝到CUDA中的固定內(nèi)存中。 -
drop_last
:bool值,如果為True
,則對最后的一個batch來說,如果不足batch_size
個樣本了就舍棄,如果為False
,也會繼續(xù)正常執(zhí)行,只是最后的一個batch可能會小一點(diǎn)(剩多少算多少),默認(rèn)False
-
timeout
:如果是正數(shù),表明等待從加載一個batch等待的時(shí)間,若超出設(shè)定的時(shí)間還沒有加載完,就放棄這個batch,如果是0
,表示不設(shè)置限制時(shí)間。默認(rèn)為0
Dataloader參數(shù)之間的互斥
值得注意的是,Dataloader
的參數(shù)之間存在互斥的情況,主要針對自己定義的采樣器:
-
sampler
:如果自行指定了sampler
參數(shù),則shuffle
必須保持默認(rèn)值,即False
-
batch_sampler
:如果自行指定了batch_sampler
參數(shù),則batch_size
、shuffle
、sampler
、drop_last
都必須保持默認(rèn)值
如果沒有指定自己是采樣器,那么默認(rèn)的情況下(即sampler
和batch_sampler
均為None
的情況下),dataloader
的采樣策略是如何的呢:
-
sampler
:-
shuffle = True
:sampler
采用RandomSampler
,即隨機(jī)采樣 -
shuffle = Flase
:sampler
采用SequentialSampler
,即按照順序采樣
-
-
batch_sampler
:采用BatchSampler
,即根據(jù)batch_size
進(jìn)行batch采樣
上面提到的 RandomSampler
、SequentialSampler
和BatchSampler
都是PyTorch自己實(shí)現(xiàn)的,且它們都是Sampler
的子類,后邊會詳述。
Dataloader的實(shí)例
下面我們繼續(xù)我們的例子,定義Dataloader
的實(shí)例,從我們定義的 myDataset
數(shù)據(jù)集中加載數(shù)據(jù),每一個batch大小為8
。并且我們使用了一個循環(huán)來驗(yàn)證其工作的情況:
from torch.utils.data import DataLoader
myDataloader = DataLoader(myDataset, batch_size=8)
for epoch in range(2):
for data in myDataloader:
images, labels = data[0], data[1]
print(len(images))
print(labels)
# train your module
8
tensor([0, 1, 1, 1, 2, 0, 1, 2])
8
tensor([0, 2, 1, 1, 1, 1, 2, 0])
8
tensor([1, 0, 0, 0, 0, 1, 1, 0])
6
tensor([2, 0, 1, 1, 1, 2])
8
tensor([0, 1, 1, 1, 2, 0, 1, 2])
8
tensor([0, 2, 1, 1, 1, 1, 2, 0])
8
tensor([1, 0, 0, 0, 0, 1, 1, 0])
6
tensor([2, 0, 1, 1, 1, 2])
Sampler
Sampler
類是一個很抽象的父類,其主要用于設(shè)置從一個序列中返回樣本的規(guī)則,即采樣的規(guī)則。Sampler
是一個可迭代對象,使用step
方法可以返回下一個迭代后的結(jié)果,因此其主要的類方法就是 __iter__
方法,定義了迭代后返回的內(nèi)容。其父類的代碼如下(PyTorch 1.7):
class Sampler(Generic[T_co]):
def __init__(self, data_source: Optional[Sized]) -> None:
pass
def __iter__(self) -> Iterator[T_co]:
raise NotImplementedError
從上述代碼可見,其實(shí)Sampler
父類并沒有給出__iter__
的具體定義,因此,如果我們要定義自己的采樣器,就要編寫繼承自Sampler
的子類,并且重寫__iter__
方法給出迭代返回樣本的邏輯。
但是,正如上文提到的,Dataloader
中的sampler
和batch_sampler
參數(shù)默認(rèn)情況下使用的那些采樣器(RandomSampler
、SequentialSampler
和BatchSampler
)一樣,PyTorch自己實(shí)現(xiàn)了很多Sampler
的子類,這些采樣器其實(shí)可以完成大部分功能,所以本節(jié)主要關(guān)注一些Sampler
的子類以及他們的用法,而不過多地討論如何自己實(shí)現(xiàn)一個Sample
r。
SequentialSampler
SequentialSampler
就是一個按照順序進(jìn)行采樣的采樣器,接收一個數(shù)據(jù)集做參數(shù)(實(shí)際上任何可迭代對象都可),按照順序?qū)ζ溥M(jìn)行采樣:
from torch.utils.data import SequentialSampler
pseudo_dataset = list(range(10))
for data in SequentialSampler(pseudo_dataset):
print(data, end=" ")
0 1 2 3 4 5 6 7 8 9
RandmSampler
RandomSampler
即一個隨機(jī)采樣器,返回隨機(jī)采樣的值,第一個參數(shù)依然是一個數(shù)據(jù)集(或可迭代對象)。還有一組參數(shù)如下:
-
replacement
:bool值,默認(rèn)是False
,設(shè)置為True
時(shí)表示可以采出重復(fù)的樣本 -
num_samples
:只有在replacement
設(shè)置為True
的時(shí)候才能設(shè)置此參數(shù),表示要采出樣本的個數(shù),默認(rèn)為數(shù)據(jù)集的總長度。有時(shí)候由于replacement
置True
的原因?qū)е轮貜?fù)數(shù)據(jù)被采樣,導(dǎo)致有些數(shù)據(jù)被采不到,所以往往會設(shè)置一個比較大的值
from torch.utils.data import RandomSampler
pseudo_dataset = list(range(10))
randomSampler1 = RandomSampler(pseudo_dataset)
randomSampler2 = RandomSampler(pseudo_dataset, replacement=True, num_samples=20)
print("for random sampler #1: ")
for data in randomSampler1:
print(data, end=" ")
print("\n\nfor random sampler #2: ")
for data in randomSampler2:
print(data, end=" ")
for random sampler #1:
4 5 2 9 3 0 6 8 7 1
for random sampler #2:
4 9 0 6 9 3 1 6 1 8 5 0 2 7 2 8 6 4 0 6
SubsetRandomSampler
SubsetRandomSampler
可以設(shè)置子集的隨機(jī)采樣,多用于將數(shù)據(jù)集分成多個集合,比如訓(xùn)練集和驗(yàn)證集的時(shí)候使用:
from torch.utils.data import SubsetRandomSampler
pseudo_dataset = list(range(10))
subRandomSampler1 = SubsetRandomSampler(pseudo_dataset[:7])
subRandomSampler2 = SubsetRandomSampler(pseudo_dataset[7:])
print("for subset random sampler #1: ")
for data in subRandomSampler1:
print(data, end=" ")
print("\n\nfor subset random sampler #2: ")
for data in subRandomSampler2:
print(data, end=" ")
for subset random sampler #1:
0 4 6 5 3 2 1
for subset random sampler #2:
7 8 9
WeightedRandomSampler
WeightedRandomSampler
和RandomSampler
的參數(shù)一致,但是不在傳入一個dataset
,第一個參數(shù)變成了weights
,只接收一個一定長度的list作為 weights
參數(shù),表示采樣的權(quán)重,采樣時(shí)會根據(jù)權(quán)重隨機(jī)從 list(range(len(weights)))
中采樣,即WeightedRandomSampler
并不需要傳入樣本集,而是只在一個根據(jù)weights
長度創(chuàng)建的數(shù)組中采樣,所以采樣的結(jié)果可能需要進(jìn)一步處理才能使用。weights
的所有元素之和不需要為1
。
from torch.utils.data import WeightedRandomSampler
pseudo_dataset = list(range(10))
weights = [1,1,1,1,1,10,10,10,10,10]
weightedRandomSampler = WeightedRandomSampler(weights, replacement=True, num_samples=20)
for data in weightedRandomSampler:
print(data, end=" ")
7 8 7 7 9 7 8 9 8 7 5 5 9 9 6 5 8 9 6 5
BatchSampler
以上的四個Sampler
在每次迭代都只返回一個索引,而BatchSampler
的作用是對上述這類返回一個索引的采樣器進(jìn)行包裝,按照設(shè)定的batch size返回一組索引,因其他的參數(shù)和上述的有些不同:
-
sampler
:一個Sampler
對象(或者一個可迭代對象) -
batch_size
:batch的大小 -
drop_last
:是否丟棄最后一個可能不足batch size大小的數(shù)據(jù)
from torch.utils.data import BatchSampler
pseudo_dataset = list(range(10))
batchSampler1 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=False)
batchSampler2 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=True)
print("for batch sampler #1: ")
for data in batchSampler1:
print(data, end=" ")
print("\n\nfor batch sampler #2: ")
for data in batchSampler2:
print(data, end=" ")
for batch sampler #1:
[0, 1, 2] [3, 4, 5] [6, 7, 8] [9]
for batch sampler #2:
[0, 1, 2] [3, 4, 5] [6, 7, 8]
collate_fn參數(shù)
Dataloader
其實(shí)還有一個比較重要的參數(shù)是 collate_fn
,它接收一個 callable
對象,比如一個函數(shù),它的作用是將每次迭代出來的數(shù)據(jù)打包成batch。
舉個例子,如果我們在Dataloader
中設(shè)置了batch_size
為8
,實(shí)際上,從Dataloader
所讀取的數(shù)據(jù)集Dataset
中取出數(shù)據(jù)時(shí)得到的是單獨(dú)的數(shù)據(jù),比如我們的例子中,每次采樣得到一個tuple:(image, label)
,因此collate_fn
的作用就有了,它負(fù)責(zé)包裝batch,即每從數(shù)據(jù)集中抽出8個這樣的tuple,它負(fù)責(zé)把8個(image, label)
包裝成一個list: [images, labels]
,這個list有兩個元素,每一個是一個tensor,比如第一個元素,實(shí)際上是一個 8×size(image) 的tensor,即給原來的數(shù)據(jù)增加了一維,也就是最前邊的batch的維度,labels
也同理。
有時(shí)候我們可能會需要實(shí)現(xiàn)自己的包裝邏輯,所以需要自定義一個函數(shù)來完成定制化的如上的內(nèi)容,只要將該函數(shù)名傳遞給collate_fn
參數(shù)即可。
PyTorch集成的數(shù)據(jù)集
實(shí)際上,PyTorch提供了很多常用數(shù)據(jù)集的接口,如果使用這些數(shù)據(jù)集的話,可以直接使用對應(yīng)的包加載,會方便很多,比如:
-
torchvision.datasets
就提供了很多視覺方向的數(shù)據(jù)集:https://pytorch.org/docs/stable/torchvision/datasets.html?highlight=torchvision%20datasets -
torchtext
則提供了很多文本處理方向的數(shù)據(jù)集 -
torchaudio
提供了很多音頻處理方向的數(shù)據(jù)集
等等
當(dāng)然PyTorch也可以配合其他包來獲得數(shù)據(jù)以及對數(shù)據(jù)進(jìn)行處理,比如:
- 對于視覺方面,配合Pillow、OpenCV等
- 對于音頻處理方面,配合scipy、librosa等
- 對于文本處理方面,配合Cython、NLTK、SpaCy等