How should we handle messy data? PyTorch provides utility modules for programmers, including image processing for computer vision. When training your own models, these tools make data preprocessing very convenient.
Classic networks such as LeNet, AlexNet, VGG, and ResNet each have their own image preprocessing pipelines.
This topic surveys PyTorch's image preprocessing modules; with an understanding of them, plus some image annotation tools, we can build our own image training sets to train models.
The main focus here is the use of samplers, which, combined with tools such as cross-validation, enable better training.
- PyTorch (torchvision) provides many built-in datasets:
- MNIST
- Fashion-MNIST
- KMNIST
- EMNIST
- QMNIST
- FakeData
- COCO
- LSUN
- ImageFolder
- DatasetFolder
- ImageNet
- CIFAR
- STL10
- SVHN
- PhotoTour
- SBU
- Flickr
- VOC
- Cityscapes
- SBD
- USPS
- Kinetics-400
- HMDB51
- UCF101
The MNIST dataset
The MNIST class
CLASS torchvision.datasets.MNIST(
    root,                  # directory where the dataset is stored
    train=True,            # True loads the training set, False loads the test set
    transform=None,        # transform function applied to the data (features)
    target_transform=None, # transform function applied to the targets (labels)
    download=False)        # True downloads the dataset from the internet into root (the full dataset is downloaded); otherwise the data is assumed to be present already and is read from root
from torchvision.datasets import MNIST
ds_mnist = MNIST(root="./datasets", download=True)
print(type(ds_mnist))
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./datasets\MNIST\raw\train-images-idx3-ubyte.gz
100.1%
Extracting ./datasets\MNIST\raw\train-images-idx3-ubyte.gz to ./datasets\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./datasets\MNIST\raw\train-labels-idx1-ubyte.gz
113.5%
Extracting ./datasets\MNIST\raw\train-labels-idx1-ubyte.gz to ./datasets\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./datasets\MNIST\raw\t10k-images-idx3-ubyte.gz
100.4%
Extracting ./datasets\MNIST\raw\t10k-images-idx3-ubyte.gz to ./datasets\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./datasets\MNIST\raw\t10k-labels-idx1-ubyte.gz
180.4%
Extracting ./datasets\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./datasets\MNIST\raw
Processing...
Done!
<class 'torchvision.datasets.mnist.MNIST'>
Using the MNIST class
- All dataset types derive from the Dataset type:
torch.utils.data.Dataset
- Datasets fall into two categories:
  - map-style datasets:
    torch.utils.data.Dataset
  - iterable-style datasets:
    torch.utils.data.IterableDataset
- MNIST is a map-style dataset
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, IterableDataset
ds_mnist = MNIST(root="./datasets", download=False)
if isinstance(ds_mnist, Dataset):
    print("is a map-style dataset")
if isinstance(ds_mnist, IterableDataset):
    print("is an iterable-style dataset")
else:
    print("is not an iterable-style dataset")
is a map-style dataset
is not an iterable-style dataset
- The inheritance hierarchy of MNIST is:
  - MNIST
    - torchvision.datasets.vision.VisionDataset
      - torch.utils.data.dataset.Dataset
        - builtins.object
- MNIST is a map-style dataset
- Traversal protocol:
  - get the length:
    __len__(self)
  - get an element by index:
    __getitem__(self, index)
%matplotlib inline
import matplotlib.pyplot as plt
from torchvision.datasets import MNIST
train_mnist = MNIST(root="./datasets", train=True, download=False)
print(len(train_mnist))
test_mnist = MNIST(root="./datasets", train=False, download=False)
print(len(test_mnist))
print(type(train_mnist[0]))
plt.imshow(train_mnist[0][0], cmap='gray')
60000
10000
<class 'tuple'>
<matplotlib.image.AxesImage at 0x1c2950b82e8>
A sample image from the PyTorch MNIST dataset
- Dataset attributes:
  - data-related attributes:
    - class_to_idx
    - processed_folder
    - raw_folder
    - test_data
    - test_labels
    - train_data
    - train_labels
  - resource-related attributes:
    - classes =
      ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', ...
    - resources =
      [('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyt...
    - test_file = 'test.pt'
    - training_file = 'training.pt'
from torchvision.datasets import MNIST
test_mnist = MNIST(root="./datasets", train=False, download=False)
print(test_mnist.class_to_idx)
print(test_mnist.processed_folder)
print(test_mnist.raw_folder)
print(test_mnist.test_data.shape)
print(test_mnist.test_labels.shape)
print(test_mnist.train_data.shape)
print(test_mnist.train_labels.shape)
print(test_mnist.classes)
print(test_mnist.resources)
print(test_mnist.test_file)
print(test_mnist.training_file)
{'0 - zero': 0, '1 - one': 1, '2 - two': 2, '3 - three': 3, '4 - four': 4, '5 - five': 5, '6 - six': 6, '7 - seven': 7, '8 - eight': 8, '9 - nine': 9}
./datasets\MNIST\processed
./datasets\MNIST\raw
torch.Size([10000, 28, 28])
torch.Size([10000])
torch.Size([10000, 28, 28])
torch.Size([10000])
['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
[('http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz', 'f68b3c2dcbeaaa9fbdd348bbdeb94873'), ('http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz', 'd53e105ee54ea40749a09fcbcd1e9432'), ('http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz', '9fb629c4189551a2d022fa330f9573f3'), ('http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz', 'ec29112dd5afa0611ce80d1b7f02629c')]
test.pt
training.pt
help(ds_mnist)
Help on MNIST in module torchvision.datasets.mnist object:
class MNIST(torchvision.datasets.vision.VisionDataset)
| `MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
|
| Args:
| root (string): Root directory of dataset where ``MNIST/processed/training.pt``
| and ``MNIST/processed/test.pt`` exist.
| train (bool, optional): If True, creates dataset from ``training.pt``,
| otherwise from ``test.pt``.
| download (bool, optional): If true, downloads the dataset from the internet
Map-style: the Dataset type
The Dataset abstract class
- supports the __add__ operation: +
- supports the __getitem__ operation: []
from torch.utils.data import Dataset
help(Dataset)
Help on class Dataset in module torch.utils.data.dataset:
class Dataset(builtins.object)
| An abstract class representing a :class:`Dataset`.
|
|
| __add__(self, other)
|
| __getitem__(self, index)
|
| ----------------------------------------------------------------------
The + operation
- Merges two datasets
- Supported via the __add__ operator: +
- The returned type is ConcatDataset
from torchvision.datasets import MNIST
train_mnist = MNIST(root="./datasets", train=True, download=False)
test_mnist = MNIST(root="./datasets", train=False, download=False)
data = train_mnist + test_mnist
print(type(data), len(data))
<class 'torch.utils.data.dataset.ConcatDataset'> 70000
The ConcatDataset class
- Merges datasets supplied as a list (constructor pattern)
- Provides a static method cumsum(sequence) that computes cumulative dataset sizes
- Merge datasets given as a sequence:
from torchvision.datasets import MNIST
train_mnist = MNIST(root="./datasets", train=True, download=False)
test_mnist = MNIST(root="./datasets", train=False, download=False)
data = train_mnist + test_mnist
# can also be built directly with the constructor
from torch.utils.data.dataset import ConcatDataset
all_data = ConcatDataset([train_mnist, test_mnist])
print(type(all_data), len(all_data))
help(data)
<class 'torch.utils.data.dataset.ConcatDataset'> 70000
Help on ConcatDataset in module torch.utils.data.dataset object:
class ConcatDataset(Dataset)
| Dataset as a concatenation of multiple datasets
|
| __getitem__(self, idx)
| __init__(self, datasets)
| Initialize self. See help(type(self)) for accurate signature.
| __len__(self)
| ----------------------------------------------------------------------
| Static methods defined here:
| cumsum(sequence)
| ----------------------------------------------------------------------
| Data descriptors defined here:
| cummulative_sizes
| __add__(self, other)
- The cumsum function and the cumulative_sizes attribute
  - The old name cummulative_sizes has been renamed to cumulative_sizes
from torchvision.datasets import MNIST
from torch.utils.data.dataset import ConcatDataset
train_mnist = MNIST(root="./datasets", train=True, download=False)
test_mnist = MNIST(root="./datasets", train=False, download=False)
all_data = ConcatDataset([train_mnist, test_mnist])
print(all_data.cumulative_sizes)
# ----------------------------------
list_sum = ConcatDataset.cumsum([train_mnist , test_mnist])
print(list_sum)
[60000, 70000]
[60000, 70000]
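ConcatDataset uses these cumulative sizes to resolve a global index into a (dataset, local index) pair. Below is a minimal sketch of that lookup, assuming the all_data and test_mnist objects from above; it mirrors the bisect-based resolution that ConcatDataset.__getitem__ performs internally.
import bisect
idx = 60000  # global index of the first test_mnist element inside all_data
ds_idx = bisect.bisect_right(all_data.cumulative_sizes, idx)  # which dataset: 1 (the second)
local_idx = idx - all_data.cumulative_sizes[ds_idx - 1] if ds_idx > 0 else idx  # index inside it: 0
print(ds_idx, local_idx)                           # 1 0
print(all_data[idx][1], test_mnist[local_idx][1])  # same label from both views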
Iterable-style: the IterableDataset class
A class that implements __iter__(self); it also implements __add__(self, other) for dataset concatenation.
- This class is a specification: users subclass it to implement their own iterable datasets.
- The point of the specification is that instances can be consumed by DataLoader.
from torch.utils.data import IterableDataset
class RangeDataset(IterableDataset):
def __init__(self, start, end):
self.start=start
self.end=end
def __iter__(self):
return iter(range(self.start, self.end))
ds_range = RangeDataset(1,5)
lst = [x for x in ds_range]
print(lst)
print(list(ds_range))
print(enumerate(ds_range))
[1, 2, 3, 4]
[1, 2, 3, 4]
<enumerate object at 0x000001C28D518558>
DataLoader and Dataset
The DataLoader class
- What DataLoader does
  - It is mainly used to batch and dispatch data from a dataset; DataLoader works on top of Dataset and IterableDataset.
CLASS torch.utils.data.DataLoader(
    dataset,                       # the dataset to dispatch from
    batch_size=1,                  # number of samples per batch
    shuffle=False,                 # whether to shuffle the data
    sampler=None,                  # the sampler to use
    batch_sampler=None,            # the batch sampler to use
    num_workers=0,                 # number of subprocesses used for data loading
    collate_fn=None,               # function that merges a list of samples into a batch
    pin_memory=False,              # copy tensors into pinned (page-locked) memory before returning them
    drop_last=False,               # whether to drop the last incomplete batch
    timeout=0,                     # timeout for collecting a batch (0: wait until loading completes)
    worker_init_fn=None,           # callback invoked on each worker before it starts loading data
    multiprocessing_context=None)  # multiprocessing context; usually None, i.e. the default context of the current process
Typical usage
- Typical usage involves the following common parameters:
  - dataset: required
  - batch_size: number of samples per batch
  - shuffle: whether to shuffle (randomly permute) the data
  - drop_last: whether to drop the leftover samples that do not fill a full batch
- Usage pattern:
  - Iterate to fetch data; the fetched data are Tensors.
- Example 1: an iterable dataset
  - Each loop iteration yields a single element.
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader
class RangeDataset(IterableDataset):
def __init__(self, start, end):
self.start=start
self.end=end
def __iter__(self):
return iter(range(self.start, self.end))
ds_range = RangeDataset(1,5)
loader = DataLoader(ds_range)
for item in loader:
print(item)
tensor([1])
tensor([2])
tensor([3])
tensor([4])
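DataLoader can also batch an iterable dataset. A short sketch, reusing the RangeDataset and loader setup from above with the default collate function:
loader = DataLoader(ds_range, batch_size=2)
for item in loader:
    print(item)
# expected: tensor([1, 2]) then tensor([3, 4])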
- Example 2: a map-style dataset
  - Each loop iteration yields a tuple of size 2.
- Note:
  - MNIST loads samples as PIL.Image.Image, which must be converted with a transform function; here we use ToTensor, which converts an Image into a Tensor.
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
train_mnist = MNIST(root="./datasets", train=True, download=False,transform=ToTensor())
loader = DataLoader(train_mnist, batch_size=10000, shuffle=True, drop_last=False)
for d, t in loader:  # data and labels
print(d.shape, t.shape)
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
Using samplers
- PyTorch provides the abstract class Sampler for data sampling, so users can customize how indices are iterated and returned.
- The class specifies two interface functions:
  - __iter__(): returns an iterator over dataset indices.
  - __len__(): returns the number of indices the iterator yields.
The sampler parameter
- Define a sampler
  - A sampler must return an iterator over integers.
  - __len__ plays no role here.
from torch.utils.data import Sampler
class MySampler(Sampler):
def __init__(self):
pass
def __iter__(self):
return iter(range(0,4))
def __len__(self):
return None
- Use the sampler
  - The sampler parameter conflicts with the shuffle parameter.
  - The sampler is responsible for drawing data from the original dataset (by default, everything is sampled).
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
train_mnist = MNIST(root="./datasets", train=True, download=False, transform=ToTensor())
sampler = MySampler()
loader = DataLoader(train_mnist, sampler=sampler, batch_size=100, drop_last=False)
for d, t in loader:
print(d.shape, t.shape)
print("-------------------------------")
loader = DataLoader(train_mnist, sampler=sampler, batch_size=2, drop_last=False)
for d, t in loader:
print(d.shape, t.shape)
torch.Size([4, 1, 28, 28]) torch.Size([4])
-------------------------------
torch.Size([2, 1, 28, 28]) torch.Size([2])
torch.Size([2, 1, 28, 28]) torch.Size([2])
The batch_sampler parameter
- Define a batch sampler
  - It returns an iterator whose elements are batches of indices (here, iterators over indices).
from torch.utils.data import Sampler
class MyBatchSampler(Sampler):
def __init__(self):
pass
def __iter__(self):
        return iter([iter(range(0, 4)), iter(range(4, 10))])  # a generator works too
def __len__(self):
return None
- Use the batch sampler
  - The batch_sampler parameter conflicts with the following parameters and cannot be combined with them:
    - batch_size, shuffle, sampler, and drop_last
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
train_mnist = MNIST(root="./datasets", train=True, download=False, transform=ToTensor())
sampler = MyBatchSampler()
loader = DataLoader(train_mnist, batch_sampler=sampler)
for d, t in loader:
print(d.shape, t.shape)
torch.Size([4, 1, 28, 28]) torch.Size([4])
torch.Size([6, 1, 28, 28]) torch.Size([6])
Other samplers
- To standardize samplers, PyTorch provides a family of concrete sampler classes:
  - SequentialSampler: sequential sampling
  - RandomSampler: random sampling
  - SubsetRandomSampler: random sampling from a subset
  - WeightedRandomSampler: weighted random sampling
  - BatchSampler: batch sampling
  - DistributedSampler: distributed sampling
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.utils.data import SequentialSampler
from torch.utils.data import RandomSampler
from torch.utils.data import SubsetRandomSampler
from torch.utils.data import WeightedRandomSampler
from torch.utils.data import BatchSampler
from torch.utils.data import DistributedSampler
test_mnist = MNIST(root="./datasets", train=False, download=False)
- Example: the SequentialSampler
  - Note that the constructor takes a dataset, which is used mainly for its length to build the sequential sampler. The order is 0, 1, 2, ..., with the same length as the dataset.
  - Constructor:
    __init__(self, data_source)
ds = [1, 3, 10]
s_sampler = SequentialSampler(ds)  # uses len(ds) to define the traversal order
for x in s_sampler:
    print(x)
print("------------------------")
loader = DataLoader(train_mnist, sampler=s_sampler, batch_size=100, drop_last=False)  # samples the first 3 items, not the items at positions 1, 3, 10
for d, t in loader:
print(d.shape, t.shape)
0
1
2
torch.Size([3, 1, 28, 28]) torch.Size([3])
- Implementation:
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
- RandomSampler
  - Constructor:
    __init__(self, data_source, replacement=False, num_samples=None)
    - data_source: the dataset; the generated random indices never exceed the dataset's length.
    - replacement=True: num_samples is used as the number of samples; with False, the dataset length is used instead.
    - num_samples is only used when replacement=True; with replacement=False it is meaningless and even raises an exception.
- Implementation:
@property
def num_samples(self):
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self):
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
def __len__(self):
return self.num_samples
- Usage
  - The returned random indices never exceed the dataset's length.
ds = [1, 3, 10]  # the random indices generated below fall in [0, len(ds))
r_sampler = RandomSampler(ds, replacement=True, num_samples=10)  # usage with replacement=True
r_sampler = RandomSampler(ds, replacement=False)  # uses len(ds) as the number of samples
for x in r_sampler:
print(x)
print("------------------------")
loader = DataLoader(train_mnist, sampler=r_sampler, batch_size=100, drop_last=False)  # samples the first 3 items, not the items at positions 1, 3, 10
for d, t in loader:
print(d.shape, t.shape)
2
0
1
------------------------
torch.Size([3, 1, 28, 28]) torch.Size([3])
- The SubsetRandomSampler
  - indices parameter: a sequence of indices (which will be shuffled).
  - Source code:
    - The subset indices are specified directly by hand and returned in random order.
class SubsetRandomSampler(Sampler):
r"""Samples elements randomly from a given list of indices, without replacement.
Arguments:
indices (sequence): a sequence of indices
"""
def __init__(self, indices):
self.indices = indices
def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))  # shuffles the indices; randperm produces a random permutation
def __len__(self):
return len(self.indices)
- Usage example
ds = [1, 3, 10, 5000]
sr_sampler = SubsetRandomSampler(ds)  # samples the items at indices 1, 3, 10, 5000, returned in shuffled order
for x in sr_sampler:
print(x)
print("------------------------")
loader = DataLoader(train_mnist, sampler=sr_sampler, batch_size=100, drop_last=False)
for d, t in loader:
print(d.shape, t.shape)
3
5000
10
1
------------------------
torch.Size([4, 1, 28, 28]) torch.Size([4])
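Because SubsetRandomSampler draws only from the indices you hand it, it is a convenient building block for train/validation splits (and, by rotating the index subsets, for cross-validation, the focus mentioned at the start). A minimal sketch, assuming the train_mnist dataset from above:
import torch
n = len(train_mnist)
perm = torch.randperm(n).tolist()  # shuffled indices 0 .. n-1
split = int(0.8 * n)               # 80% train / 20% validation
train_sampler = SubsetRandomSampler(perm[:split])
valid_sampler = SubsetRandomSampler(perm[split:])
train_loader = DataLoader(train_mnist, sampler=train_sampler, batch_size=100)
valid_loader = DataLoader(train_mnist, sampler=valid_sampler, batch_size=100)
print(len(train_sampler), len(valid_sampler))  # 48000 12000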
- WeightedRandomSampler
  - weights (sequence): the weights; they need not sum to 1.
  - num_samples (int): total number of samples to draw.
  - replacement (bool): whether to sample with replacement.
- Sampling with replacement (replacement=True)
w = [0.4, 0.8, 0.2, 0.6, 0.9]  # indices with higher weights are more likely to be drawn; note: indices are what is drawn, the weights only express the probability of drawing them
wr_sampler = WeightedRandomSampler(w, 10, True)  # draw 10 samples; when the last argument is False, the second argument must be <= the number of weights
for x in wr_sampler:
print(x)
print("------------------------")
loader = DataLoader(train_mnist, sampler=wr_sampler, batch_size=100, drop_last=False)
for d, t in loader:
print(d.shape, t.shape)
4
3
4
1
3
3
4
1
4
1
------------------------
torch.Size([10, 1, 28, 28]) torch.Size([10])
- Sampling without replacement
  - Note: the number of samples drawn must be at most the number of weights.
w = [0.4, 0.8, 0.2, 0.6, 0.9]  # indices with higher weights are more likely to be drawn
wr_sampler = WeightedRandomSampler(w, 4, False)  # draw 4 samples
for x in wr_sampler:
print(x)
print("------------------------")
loader = DataLoader(train_mnist, sampler=wr_sampler, batch_size=100, drop_last=False)
for d, t in loader:
print(d.shape, t.shape)
3
0
4
1
------------------------
torch.Size([4, 1, 28, 28]) torch.Size([4])
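A typical application of WeightedRandomSampler is balancing imbalanced classes: give each sample a weight inversely proportional to its class frequency, so rare classes are drawn about as often as common ones. A hedged sketch, assuming the train_mnist dataset from above and the targets attribute of newer torchvision versions:
import torch
targets = train_mnist.targets                         # tensor of labels, shape [60000] (assumed attribute)
class_counts = torch.bincount(targets)                # number of samples per class
sample_weights = 1.0 / class_counts[targets].float()  # rarer class -> larger weight
balanced_sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
balanced_loader = DataLoader(train_mnist, sampler=balanced_sampler, batch_size=100)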
- BatchSampler
  - This one is used as a batch sampler (via batch_sampler), not as a per-sample sampler. Its parameters are:
    - sampler (Sampler): produces the collection of sampled indices.
    - batch_size (int): the batch size.
    - drop_last (bool): whether to drop the leftover indices that do not fill a batch.
w = [0.4, 0.8, 0.2, 0.6, 0.9]
wr_sampler = WeightedRandomSampler(w, 4, False)
b_sampler = BatchSampler(wr_sampler, batch_size=3, drop_last=False)
for x in b_sampler:
print(x)
print("------------------------")
loader = DataLoader(train_mnist, batch_sampler=b_sampler)
for d, t in loader:
print(d.shape, t.shape)
[1, 4, 3]
[0]
------------------------
torch.Size([3, 1, 28, 28]) torch.Size([3])
torch.Size([1, 1, 28, 28]) torch.Size([1])
- DistributedSampler
  - Used for distributed training.
    DistributedSampler(
        dataset,
        num_replicas=None,  # number of processes
        rank=None,          # rank of the current process
        shuffle=True)
  - Since it requires the distributed package, it is not demonstrated live here.
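Still, for reference, here is a hedged sketch of the typical wiring; it assumes torch.distributed.init_process_group(...) has already been called in each process and is not runnable as a standalone snippet:
from torch.utils.data import DataLoader, DistributedSampler
# assumes torch.distributed.init_process_group(...) has already been called in each process
d_sampler = DistributedSampler(train_mnist)  # num_replicas and rank default to the current process group
loader = DataLoader(train_mnist, sampler=d_sampler, batch_size=100)
for epoch in range(2):
    d_sampler.set_epoch(epoch)  # makes the shuffle differ between epochs
    for d, t in loader:
        pass  # training step goes here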
Using workers
- Setting the num_workers parameter starts multiple subprocesses to load the dataset.
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
train_mnist = MNIST(root="./datasets", train=True, download=False,transform=ToTensor())
loader = DataLoader(train_mnist, batch_size=10000, shuffle=True, drop_last=False, num_workers=2)
for d, t in loader:  # data and labels
print(d.shape, t.shape)
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
torch.Size([10000, 1, 28, 28]) torch.Size([10000])
- worker_init_fn specifies per-worker processing and initialization
  - Used to set up the state of each worker process;
  - The function is passed a worker id (an integer in [0, num_workers), not the OS pid).
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
train_mnist = MNIST(root="./datasets", train=True, download=False,transform=ToTensor())
def worker_fn(w_id):
    print(f"worker: {w_id}")
loader = DataLoader(train_mnist, batch_size=10000, shuffle=True, drop_last=False, num_workers=3, worker_init_fn=worker_fn)
for d, t in loader:  # data and labels
print(d.shape, t.shape)
---------------------------------------------------------------------------
Empty Traceback (most recent call last)
c:\program files\python36\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout)
760 try:
--> 761 data = self._data_queue.get(timeout=timeout)
762 return (True, data)
RuntimeError: DataLoader worker (pid(s) 13772, 14988, 21548) exited unexpectedly
- Explanation:
  - The error above is a platform issue: this only runs on non-Windows platforms.
  - From the comments in the source code:
    - Windows does not support the SIGCHLD signal, which is the typical mechanism for reaping dead (zombie) child processes.
# This raises a `RuntimeError` if any worker died expectedly. This error
# can come from either the SIGCHLD handler in `_utils/signal_handling.py`
# (only for non-Windows platforms), or the manual check below on errors
# and timeouts.
#
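A common workaround on Windows: because workers are started with the spawn method, the loading loop must run from a script under an if __name__ == "__main__" guard (an interactive session cannot pickle notebook-defined functions into the workers). A hedged sketch, assuming a hypothetical script file mnist_workers.py:
# mnist_workers.py -- hypothetical script name
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

def worker_fn(w_id):
    print(f"worker: {w_id}")

if __name__ == "__main__":  # required on Windows, where workers start via spawn
    train_mnist = MNIST(root="./datasets", train=True, download=False, transform=ToTensor())
    loader = DataLoader(train_mnist, batch_size=10000, shuffle=True,
                        num_workers=3, worker_init_fn=worker_fn)
    for d, t in loader:
        print(d.shape, t.shape)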
- The collate_fn collate function:
  - This function is passed a batch argument (the list of samples to merge).
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
train_mnist = MNIST(root="./datasets", train=True, download=False,transform=ToTensor())
def collate_fn(batch):
    print("collate:", type(batch))
    return batch
loader = DataLoader(train_mnist, batch_size=10000, shuffle=True, drop_last=False, num_workers=3, collate_fn=collate_fn)
for d, t in loader:  # data and labels
print(d.shape, t.shape)
---------------------------------------------------------------------------
Empty Traceback (most recent call last)
c:\program files\python36\lib\site-packages\torch\utils\data\dataloader.py in _try_get_data(self, timeout)
772 if len(failed_workers) > 0:
773 pids_str = ', '.join(str(w.pid) for w in failed_workers)
--> 774 raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
775 if isinstance(e, queue.Empty):
776 return (False, None)
RuntimeError: DataLoader worker (pid(s) 9988, 13420, 8472) exited unexpectedly
- This error is likewise caused by the Windows platform.
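To actually see collate_fn run, avoid worker subprocesses (num_workers=0, the default) and build the batch by hand. A minimal sketch, assuming the train_mnist dataset from above; the hypothetical my_collate_fn uses torch.stack to merge the image tensors and torch.tensor to collect the labels, mimicking what the default collate function does:
import torch
from torch.utils.data import DataLoader

def my_collate_fn(batch):  # batch: a list of (image_tensor, label) pairs
    images = torch.stack([item[0] for item in batch])   # -> [N, 1, 28, 28]
    labels = torch.tensor([item[1] for item in batch])  # -> [N]
    return images, labels

loader = DataLoader(train_mnist, batch_size=10000, shuffle=True, collate_fn=my_collate_fn)
for d, t in loader:
    print(d.shape, t.shape)  # expected: torch.Size([10000, 1, 28, 28]) torch.Size([10000])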