PyTorch--數據讀取和操作

Pytorch 中比較重要的是對數據的處理,其中,進行數據讀取的一般有三個類:

  • Dataset
  • DataLoader

其中,這兩個是一個依次封裝的關系:“Dataset被封裝進DataLoader,DataLoader再被封裝進DataLoaderIter

Dataset

Dataset位于torch.utils.data.Dataset,每當我們自定義類MyDataset必須要繼承它并實現其兩個成員函數:

  • __len__()
  • __getitem__()
import torch
from torch.utils.data import Dataset
import pandas as pd

# 定義自己的類
class MyDataset(Dataset):

    # 初始化
    def __init__(self, file_name):
        # 讀入數據
        self.data = pd.read_csv(file_name, sep='\t', usecols=['Phrase', 'Sentiment'])

    # 返回df的長度
    def __len__(self):
        return len(self.data)

    # 獲取第idx+1列的數據
    def __getitem__(self, idx):
        return self.data.iloc[idx].Phrase, self.data.iloc[idx].Sentiment

# 通過實例化對象來訪問該類
# 假設同目錄下存在名為train.tsv的文件
ds = MyDataset('../datasets/train.tsv')
print(ds.data.head()) # 頭數據
print(ds.data.iloc[1]) # 按行索引獲取數據

# 結果
                                              Phrase  Sentiment
0  A series of escapades demonstrating the adage ...          1
1  A series of escapades demonstrating the adage ...          2
2                                           A series          2
3                                                  A          2
4                                             series          2
Phrase       A series of escapades demonstrating the adage ...
Sentiment                                                    2
Name: 1, dtype: object

DataLoader

DataLoader位于torch.utils.data.DataLoader, 為我們提供了對Dataset的讀取操作

# 僅僅列舉了常用的幾個參數
torch.nn.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
  • dataset: 上面所實現的自定義類Dataset
  • batch_size : 默認為1,每次讀取的batch的大小
  • shuffle : 默認為False, 是否對數據進行shuffle操作(簡單理解成將數據集打亂)
  • num_works : 默認為0,表示在加載數據的時候每次使用子進程的數量,即簡單的多線程預讀數據的方法

DataLoader返回的是一個迭代器,我們通過這個迭代器來獲取數據

Dataloder的目的是將給定的n個數據, 經過Dataloader操作后, 在每一次調用時調用一個小batch, 如:

  • 給出的是: (5000,28,28) , 表示有5000個樣本,每個樣本的size為(28,28)
  • 經過Dataloader處理后, 一次得到的是(100,28,28)(假設batch_size大小為100), 表示本次取出100個樣本, 每個樣本的size為(28,28)
# 連接上面的Dataset實現代碼

from torch.utils.data import DataLoader

dl = DataLoader(ds, batch_size=10, shuffle=True, num_workers=2)

通過迭代器來分次獲取數據

dl_data = iter(dl)
print(next(dl_data))

# 結果
[('thematic ironies', 'whimsical and relevant today', "director George Hickenlooper 's approach to the material is too upbeat", 'direct-to-video\\/DVD category', 'Four Feathers', 'may well be the only one laughing at his own joke', 'the end credits', "What sets Ms. Birot 's film apart from others in the genre", 'overcoming-obstacles', 'homage pokepie hat , but as a character'), tensor([2, 3, 2, 1, 2, 1, 2, 3, 2, 2])]

或,直接通過for循環進行遍歷輸出

for i, data in enumerate(dl):
    print(i, data)
    # 這里只循環一次,所以用break
    break

#結果
0 [('huge action sequence', ', characterization , poignancy , and intelligence', 'potentially incredibly twisting mystery', 'felt disrespected', 'a rather bland', 'the character dramas', 'a key strength', "'s never dull and always looks good", "the Queen 's", 'uncompromising knowledge'), tensor([3, 3, 3, 1, 1, 2, 3, 3, 2, 3])]
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容