PyTorch教程-5:詳解PyTorch中加載數(shù)據(jù)的方法--Dataset、Dataloader、Sampler、collate_fn等

筆者PyTorch的全部簡單教程請?jiān)L問:http://www.lxweimin.com/nb/48831659

PyTorch教程-5:詳解PyTorch中加載數(shù)據(jù)的方法--Dataset、Dataloader、Sampler、collate_fn等

數(shù)據(jù)讀取是所有訓(xùn)練模型任務(wù)中最基礎(chǔ)最重要的一步,PyTorch為數(shù)據(jù)集的讀取、加載和使用提供了很好的機(jī)制,使得數(shù)據(jù)加載的工作變得異常簡單而且具有非常高的定制性。

Dataset、Dataloader、Sampler的關(guān)系

PyTorch中對于數(shù)據(jù)集的處理有三個非常重要的類:DatasetDataloaderSampler,它們均是 torch.utils.data 包下的模塊(類)。它們的關(guān)系可以這樣理解:

  • Dataset是數(shù)據(jù)集的類,主要用于定義數(shù)據(jù)集
  • Sampler是采樣器的類,用于定義從數(shù)據(jù)集中選出數(shù)據(jù)的規(guī)則,比如是隨機(jī)取數(shù)據(jù)還是按照順序取等等
  • Dataloader是數(shù)據(jù)的加載類,它是對于DatasetSampler的進(jìn)一步包裝,即其實(shí)DatasetSampler會作為參數(shù)傳遞給Dataloader,用于實(shí)際讀取數(shù)據(jù),可以理解為它是這個工作的真正實(shí)踐者,而DatasetSampler負(fù)責(zé)定義。我們訓(xùn)練、測試所獲得的數(shù)據(jù)也是Dataloader直接給我們的。

總的來說Dataset定義了整個數(shù)據(jù)集,Sampler提供了取數(shù)據(jù)的機(jī)制,最后由Dataloader取完成取數(shù)據(jù)的任務(wù)。

本篇以一個最簡單的例子為例,比如有一個文件夾(data-folder)中存儲訓(xùn)練的數(shù)據(jù)(一共30張圖片:0.png 到 29.png),他們對應(yīng)的標(biāo)簽被寫在了一個labels.txt文件中,第n行對應(yīng)n-1.png的標(biāo)簽,是一個三分類問題,即0、1和2三種標(biāo)簽(虛構(gòu)的數(shù)據(jù)集,不具有任何意義)。目錄結(jié)構(gòu)如下:

|--- Project
   |--- main.py
   |--- labels.txt
   |--- data-folder
      |--- 0.png
      |--- 1.png
      |--- ……
      |--- 29.png

Dataset

Dataset 位于 torch.utils.data 下,我們通過定義繼承自這個類的子類來自定義數(shù)據(jù)集。它有兩個最重要的方法需要重寫,實(shí)際上它們都是類的特殊方法:

  • __getitem__(self, index):傳入?yún)?shù)index為下標(biāo),返回?cái)?shù)據(jù)集中對應(yīng)下標(biāo)的數(shù)據(jù)組(數(shù)據(jù)和標(biāo)簽)
  • __len__(self):返回?cái)?shù)據(jù)集的大小

簡單說,重寫了這兩個方法的繼承自 Dataset 的類都可以作為數(shù)據(jù)集的定義類使用,即一個Dataset類的必要結(jié)構(gòu):

class Dataset(torch.utils.data.Dataset):
    def __init__(self, filepath=None,dataLen=None):
        pass
        
    def __getitem__(self, index):
        pass

    def __len__(self):
        pass

如下就是我們的例子的加載實(shí)例,其中的 image2tensor 使用了torchvision.transforms 完成了一個簡單的從PIL.Image 格式的圖片到 tensor 的轉(zhuǎn)換,可以先不必在意,后面會詳細(xì)地講到 transforms 這個超級重要的工具:

from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision import transforms


class MyDataset(Dataset):
    def __init__(self, images_folder_path, labels_file_path):
        self.images_folder_path = images_folder_path

        with open(labels_file_path, 'r') as file:
            self.labels = list(map(int, file.read().splitlines()))

    def __getitem__(self, item):
        image = Image.open(os.path.join(self.images_folder_path, "{}.png".format(item)))
        image = self.image2tensor(image)
        label = self.labels[item]

        return (image, label)

    def __len__(self):
        return len(self.labels)

    def image2tensor(self, image):
        """
        transform PIL.Image to tensor
        :param image: image in PIL.Image format
        :return: image in tensor format
        """
        transform = transforms.Compose([
            transforms.ToTensor()
        ])
        image = image.convert('RGB')
        return transform(image)


myDataset = MyDataset("./data-folder/", "./labels.txt")

Dataloader

DataloaderDataset(和Sampler等)打包,完成最后對數(shù)據(jù)的讀取的執(zhí)行工作,一般不需要自己定義或者重寫一個Dataloader的類(或子類),直接使用即可,通過傳入?yún)?shù)定制Dataloader,定制化的功能應(yīng)該在Dataset(和Sampler等)中完成了。

Dataloader的完整簽名見:https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

Dataloader的一些常用參數(shù)

Dataloader的一些重要的參數(shù)如下,除了第一個dataset參數(shù)外,其他均為可選參數(shù)

  • dataset(第一個參數(shù),必須的參數(shù)):一個Dataset的實(shí)例,即傳入的數(shù)據(jù)集(或者其他可迭代對象)
  • batch_size:整數(shù)值,每個batch的樣本數(shù)量,即batch大小,默認(rèn)為1
  • shuffle:bool值,如果設(shè)置為True,則在每個epoch開始的時(shí)候,會對數(shù)據(jù)集的數(shù)據(jù)進(jìn)行重新排序,默認(rèn)False
  • sampler:傳入一個自定義的Sampler實(shí)例,定義從數(shù)據(jù)集中取樣本的策略,Sampler每次返回一個索引,默認(rèn)為None
  • batch_sampler:也是傳入一個自定義的Sampler實(shí)例,但是與sampler參數(shù)不同的是,它接收的Sampler是一次返回一個batch的索引,默認(rèn)為None
  • num_workers:整數(shù)值,定義有幾個進(jìn)程來處理數(shù)據(jù)。0意味著所有的數(shù)據(jù)都會被加載進(jìn)主進(jìn)程,默認(rèn)0
  • collate_fn:傳入一個函數(shù),它的作用是將一個batch的樣本打包成一個大的tensortensor的第一維就是這些樣本,如果沒有特殊需求可以保持默認(rèn)即可(后邊會詳細(xì)介紹
  • pin_memory:bool值,如果為True,那么將加載的數(shù)據(jù)拷貝到CUDA中的固定內(nèi)存中。
  • drop_last:bool值,如果為True,則對最后的一個batch來說,如果不足batch_size個樣本了就舍棄,如果為False,也會繼續(xù)正常執(zhí)行,只是最后的一個batch可能會小一點(diǎn)(剩多少算多少),默認(rèn)False
  • timeout:如果是正數(shù),表明等待從加載一個batch等待的時(shí)間,若超出設(shè)定的時(shí)間還沒有加載完,就放棄這個batch,如果是0,表示不設(shè)置限制時(shí)間。默認(rèn)為0

Dataloader參數(shù)之間的互斥

值得注意的是,Dataloader的參數(shù)之間存在互斥的情況,主要針對自己定義的采樣器:

  • sampler:如果自行指定了sampler參數(shù),則shuffle必須保持默認(rèn)值,即False
  • batch_sampler:如果自行指定了batch_sampler參數(shù),則 batch_sizeshufflesamplerdrop_last 都必須保持默認(rèn)值

如果沒有指定自己是采樣器,那么默認(rèn)的情況下(即samplerbatch_sampler均為None的情況下),dataloader的采樣策略是如何的呢:

  • sampler
    • shuffle = Truesampler采用 RandomSampler,即隨機(jī)采樣
    • shuffle = Flasesampler采用 SequentialSampler,即按照順序采樣
  • batch_sampler:采用 BatchSampler,即根據(jù) batch_size 進(jìn)行batch采樣

上面提到的 RandomSamplerSequentialSamplerBatchSampler都是PyTorch自己實(shí)現(xiàn)的,且它們都是Sampler的子類,后邊會詳述

Dataloader的實(shí)例

下面我們繼續(xù)我們的例子,定義Dataloader的實(shí)例,從我們定義的 myDataset 數(shù)據(jù)集中加載數(shù)據(jù),每一個batch大小為8。并且我們使用了一個循環(huán)來驗(yàn)證其工作的情況:

from torch.utils.data import DataLoader

myDataloader = DataLoader(myDataset, batch_size=8)

for epoch in range(2):
    for data in myDataloader:
        images, labels = data[0], data[1]
        print(len(images))
        print(labels)
        # train your module

8
tensor([0, 1, 1, 1, 2, 0, 1, 2])
8
tensor([0, 2, 1, 1, 1, 1, 2, 0])
8
tensor([1, 0, 0, 0, 0, 1, 1, 0])
6
tensor([2, 0, 1, 1, 1, 2])
8
tensor([0, 1, 1, 1, 2, 0, 1, 2])
8
tensor([0, 2, 1, 1, 1, 1, 2, 0])
8
tensor([1, 0, 0, 0, 0, 1, 1, 0])
6
tensor([2, 0, 1, 1, 1, 2])

Sampler

Sampler類是一個很抽象的父類,其主要用于設(shè)置從一個序列中返回樣本的規(guī)則,即采樣的規(guī)則。Sampler是一個可迭代對象,使用step方法可以返回下一個迭代后的結(jié)果,因此其主要的類方法就是 __iter__ 方法,定義了迭代后返回的內(nèi)容。其父類的代碼如下(PyTorch 1.7):

class Sampler(Generic[T_co]):
    def __init__(self, data_source: Optional[Sized]) -> None:
        pass

    def __iter__(self) -> Iterator[T_co]:
        raise NotImplementedError

從上述代碼可見,其實(shí)Sampler父類并沒有給出__iter__ 的具體定義,因此,如果我們要定義自己的采樣器,就要編寫繼承自Sampler的子類,并且重寫__iter__ 方法給出迭代返回樣本的邏輯。

但是,正如上文提到的,Dataloader中的samplerbatch_sampler參數(shù)默認(rèn)情況下使用的那些采樣器(RandomSamplerSequentialSamplerBatchSampler)一樣,PyTorch自己實(shí)現(xiàn)了很多Sampler的子類,這些采樣器其實(shí)可以完成大部分功能,所以本節(jié)主要關(guān)注一些Sampler子類以及他們的用法,而不過多地討論如何自己實(shí)現(xiàn)一個Sampler。

SequentialSampler

SequentialSampler就是一個按照順序進(jìn)行采樣的采樣器,接收一個數(shù)據(jù)集做參數(shù)(實(shí)際上任何可迭代對象都可),按照順序?qū)ζ溥M(jìn)行采樣:

from torch.utils.data import SequentialSampler

pseudo_dataset = list(range(10))
for data in SequentialSampler(pseudo_dataset):
    print(data, end=" ")

0 1 2 3 4 5 6 7 8 9 

RandmSampler

RandomSampler 即一個隨機(jī)采樣器,返回隨機(jī)采樣的值,第一個參數(shù)依然是一個數(shù)據(jù)集(或可迭代對象)。還有一組參數(shù)如下:

  • replacement:bool值,默認(rèn)是False,設(shè)置為True時(shí)表示可以采出重復(fù)的樣本
  • num_samples:只有在replacement設(shè)置為True的時(shí)候才能設(shè)置此參數(shù),表示要采出樣本的個數(shù),默認(rèn)為數(shù)據(jù)集的總長度。有時(shí)候由于replacementTrue的原因?qū)е轮貜?fù)數(shù)據(jù)被采樣,導(dǎo)致有些數(shù)據(jù)被采不到,所以往往會設(shè)置一個比較大的值
from torch.utils.data import RandomSampler

pseudo_dataset = list(range(10))

randomSampler1 = RandomSampler(pseudo_dataset)
randomSampler2 = RandomSampler(pseudo_dataset, replacement=True, num_samples=20)

print("for random sampler #1: ")
for data in randomSampler1:
    print(data, end=" ")

print("\n\nfor random sampler #2: ")
for data in randomSampler2:
    print(data, end=" ")

for random sampler #1: 
4 5 2 9 3 0 6 8 7 1 

for random sampler #2: 
4 9 0 6 9 3 1 6 1 8 5 0 2 7 2 8 6 4 0 6 

SubsetRandomSampler

SubsetRandomSampler 可以設(shè)置子集的隨機(jī)采樣,多用于將數(shù)據(jù)集分成多個集合,比如訓(xùn)練集和驗(yàn)證集的時(shí)候使用:

from torch.utils.data import SubsetRandomSampler

pseudo_dataset = list(range(10))

subRandomSampler1 = SubsetRandomSampler(pseudo_dataset[:7])
subRandomSampler2 = SubsetRandomSampler(pseudo_dataset[7:])

print("for subset random sampler #1: ")
for data in subRandomSampler1:
    print(data, end=" ")

print("\n\nfor subset random sampler #2: ")
for data in subRandomSampler2:
    print(data, end=" ")

for subset random sampler #1: 
0 4 6 5 3 2 1 

for subset random sampler #2: 
7 8 9 

WeightedRandomSampler

WeightedRandomSamplerRandomSampler的參數(shù)一致,但是不在傳入一個dataset,第一個參數(shù)變成了weights,只接收一個一定長度的list作為 weights 參數(shù),表示采樣的權(quán)重,采樣時(shí)會根據(jù)權(quán)重隨機(jī)從 list(range(len(weights))) 中采樣,即WeightedRandomSampler并不需要傳入樣本集,而是只在一個根據(jù)weights長度創(chuàng)建的數(shù)組中采樣,所以采樣的結(jié)果可能需要進(jìn)一步處理才能使用。weights的所有元素之和不需要為1

from torch.utils.data import WeightedRandomSampler

pseudo_dataset = list(range(10))
weights = [1,1,1,1,1,10,10,10,10,10]

weightedRandomSampler = WeightedRandomSampler(weights, replacement=True, num_samples=20)

for data in weightedRandomSampler:
    print(data, end=" ")

7 8 7 7 9 7 8 9 8 7 5 5 9 9 6 5 8 9 6 5 

BatchSampler

以上的四個Sampler在每次迭代都只返回一個索引,而BatchSampler的作用是對上述這類返回一個索引的采樣器進(jìn)行包裝,按照設(shè)定的batch size返回一組索引,因其他的參數(shù)和上述的有些不同:

  • sampler:一個Sampler對象(或者一個可迭代對象)
  • batch_size:batch的大小
  • drop_last:是否丟棄最后一個可能不足batch size大小的數(shù)據(jù)
from torch.utils.data import BatchSampler
pseudo_dataset = list(range(10))

batchSampler1 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=False)
batchSampler2 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=True)

print("for batch sampler #1: ")
for data in batchSampler1:
    print(data, end=" ")

print("\n\nfor batch sampler #2: ")
for data in batchSampler2:
    print(data, end=" ")

for batch sampler #1: 
[0, 1, 2] [3, 4, 5] [6, 7, 8] [9] 

for batch sampler #2: 
[0, 1, 2] [3, 4, 5] [6, 7, 8] 

collate_fn參數(shù)

Dataloader其實(shí)還有一個比較重要的參數(shù)是 collate_fn,它接收一個 callable 對象,比如一個函數(shù),它的作用是將每次迭代出來的數(shù)據(jù)打包成batch。

舉個例子,如果我們在Dataloader中設(shè)置了batch_size8,實(shí)際上,從Dataloader所讀取的數(shù)據(jù)集Dataset中取出數(shù)據(jù)時(shí)得到的是單獨(dú)的數(shù)據(jù),比如我們的例子中,每次采樣得到一個tuple:(image, label),因此collate_fn 的作用就有了,它負(fù)責(zé)包裝batch,即每從數(shù)據(jù)集中抽出8個這樣的tuple,它負(fù)責(zé)把8個(image, label)包裝成一個list: [images, labels],這個list有兩個元素,每一個是一個tensor,比如第一個元素,實(shí)際上是一個 8×size(image) 的tensor,即給原來的數(shù)據(jù)增加了一維,也就是最前邊的batch的維度,labels也同理。

有時(shí)候我們可能會需要實(shí)現(xiàn)自己的包裝邏輯,所以需要自定義一個函數(shù)來完成定制化的如上的內(nèi)容,只要將該函數(shù)名傳遞給collate_fn參數(shù)即可。

PyTorch集成的數(shù)據(jù)集

實(shí)際上,PyTorch提供了很多常用數(shù)據(jù)集的接口,如果使用這些數(shù)據(jù)集的話,可以直接使用對應(yīng)的包加載,會方便很多,比如:

當(dāng)然PyTorch也可以配合其他包來獲得數(shù)據(jù)以及對數(shù)據(jù)進(jìn)行處理,比如:

  • 對于視覺方面,配合Pillow、OpenCV等
  • 對于音頻處理方面,配合scipy、librosa等
  • 對于文本處理方面,配合Cython、NLTK、SpaCy等
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

推薦閱讀更多精彩內(nèi)容