@[TOC]
一、pytorch數(shù)據(jù)輸入
Dataset負(fù)責(zé)生產(chǎn)數(shù)據(jù),DataLoader負(fù)責(zé)數(shù)據(jù)的分批(batch_size)、采樣(sampler)、傳輸
Pytorch版本:1.0.1
1. Dataset
繼承torch.utils.data.Dataset,實(shí)現(xiàn)兩個(gè)函數(shù)即可:
- def len(self) 數(shù)據(jù)總數(shù)
- def getitem(self, index) 根據(jù)下標(biāo)獲取其中一條數(shù)據(jù)
2. DataLoader
將Dataset作為參數(shù),構(gòu)造一個(gè)torch.utils.data.DataLoader對(duì)象即可。
DataLoader其他參數(shù)見(jiàn)下文。
二、Dataloader參數(shù)匯總
dataset(Dataset):
傳入的數(shù)據(jù)集batch_size(int, optional):
每個(gè)batch有多少個(gè)樣本shuffle(bool, optional):
在每個(gè)epoch開(kāi)始的時(shí)候,對(duì)數(shù)據(jù)進(jìn)行重新打亂sampler(Sampler, optional):
自定義從數(shù)據(jù)集中取樣本的策略,如果指定這個(gè)參數(shù),那么shuffle必須為Falsebatch_sampler(Sampler, optional):
與sampler類(lèi)似,但是一次只返回一個(gè)batch的indices(索引),需要注意的是,一旦指定了這個(gè)參數(shù),那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)num_workers (int, optional):
這個(gè)參數(shù)決定了有幾個(gè)進(jìn)程來(lái)處理data loading。0意味著所有的數(shù)據(jù)都會(huì)被load進(jìn)主進(jìn)程。(默認(rèn)為0)collate_fn (callable, optional):
將一個(gè)list的sample組成一個(gè)mini-batch的函數(shù)pin_memory (bool, optional):
如果設(shè)置為T(mén)rue,那么data loader將會(huì)在返回它們之前,將tensors拷貝到CUDA中的固定內(nèi)存(CUDA pinned memory)中.drop_last (bool, optional):
如果設(shè)置為T(mén)rue:這個(gè)是對(duì)最后的未完成的batch來(lái)說(shuō)的,比如你的batch_size設(shè)置為64,而一個(gè)epoch只有100個(gè)樣本,那么訓(xùn)練的時(shí)候后面的36個(gè)就被扔掉了…
如果為False(默認(rèn)),那么會(huì)繼續(xù)正常執(zhí)行,只是最后的batch_size會(huì)小一點(diǎn)。timeout(numeric, optional):
如果是正數(shù),表明等待從worker進(jìn)程中收集一個(gè)batch等待的時(shí)間,若超出設(shè)定的時(shí)間還沒(méi)有收集到,那就不收集這個(gè)內(nèi)容了。這個(gè)numeric應(yīng)總是大于等于0。默認(rèn)為0worker_init_fn (callable, optional):
每個(gè)worker初始化函數(shù) If not None, this will be called on each
worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
2.1 sampler:分布式訓(xùn)練需DistributedSampler
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
DataLoader構(gòu)造函數(shù)中相關(guān)代碼:
if batch_sampler is None:
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset) ##如果shuffer就隨機(jī)
else:
sampler = SequentialSampler(dataset) ##否則順序采樣
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
self.sampler = sampler
self.batch_sampler = batch_sampler
batch_sampler是sampler的封裝,可自定義批次數(shù)據(jù)的構(gòu)造。默認(rèn)BatchSampler相關(guān)源碼:
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx) ##遍歷sampler獲取數(shù)據(jù),滿(mǎn)batch_size就yield
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
2.2 collate_fn:將batch的數(shù)據(jù)重新組裝
例如cirtorch中將數(shù)據(jù)拆成input_data和target兩個(gè)數(shù)據(jù)。
因Dataset中g(shù)et_item返回input_data和target兩個(gè)值,如果不用該函數(shù),每個(gè)batch的數(shù)據(jù)應(yīng)該是[batch_size,2(先input_data再target),,,],經(jīng)過(guò)該函數(shù)將變成([batch_size,,,],[batch_size,,]),第一個(gè)數(shù)據(jù)全是input_data,第二個(gè)數(shù)據(jù)全是target。
2.3 pin_memory=True:提高數(shù)據(jù)從cpu到gpu傳輸效率
pin_memory可在cpu主存(內(nèi)存)中分配不可交換到swap(緩存)的內(nèi)存。。默認(rèn)內(nèi)存分配中的數(shù)據(jù)都可交換到swap中,那CUDA驅(qū)動(dòng)會(huì)通過(guò)DRAM機(jī)制將數(shù)據(jù)從內(nèi)存?zhèn)鞯紾PU顯存時(shí)會(huì)復(fù)制2次(先復(fù)制到一臨時(shí)不可見(jiàn)pinned固定內(nèi)存,再往顯存中復(fù)制),因此pin_memory=True可提高約2倍cpu到gpu傳輸效率(.cuda()或 .to(device)的時(shí)候)。相見(jiàn)CPU和GPU內(nèi)存交互。
【拓展】Elasticsearch中的Memlock(內(nèi)存鎖定)可申請(qǐng)固定大小且不可交換內(nèi)存空間。
三、DataLoader的并行
# Our data model looks like this (queues are indicated with curly brackets):
#
# main process ||
# | ||
# {index_queue} ||
# | ||
# worker processes || DATA
# | ||
# {worker_result_queue} || FLOW
# | ||
# pin_memory_thread of main process || DIRECTION
# | ||
# {data_queue} ||
# | ||
# data output \/
#
# P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
# `pin_memory=False`.
- 基于multiprocessing多進(jìn)程
- 每個(gè)子進(jìn)程的輸入輸出,通過(guò)兩個(gè)主要的隊(duì)列(multiprocessing.Queue()): index_queue要處理的下標(biāo)、worker_result_queue要返回的下標(biāo)。
- 每個(gè)worker一次產(chǎn)生一個(gè)batch的數(shù)據(jù)
- 返回batch數(shù)據(jù)前放入下一個(gè)批次數(shù)據(jù)下標(biāo)
- 構(gòu)造函數(shù)子進(jìn)程初始化:
self.index_queues = []
self.workers = []
for i in range(self.num_workers):
index_queue = multiprocessing.Queue() # 1.每個(gè)子進(jìn)程一個(gè)隊(duì)列放要處理的下標(biāo)
index_queue.cancel_join_thread()
w = multiprocessing.Process(
target=_utils.worker._worker_loop, # 每個(gè)子進(jìn)程循環(huán)執(zhí)行的函數(shù)
args=(self.dataset, index_queue,
self.worker_result_queue, self.done_event, #2.self.worker_result_queue 多子進(jìn)程公用要返回batch數(shù)據(jù)的隊(duì)列
self.collate_fn, base_seed + i,
self.worker_init_fn, i))
w.daemon = True
# NB: Process.start() actually take some time as it needs to
# start a process and pass the arguments over via a pipe.
# Therefore, we only add a worker to self.workers list after
# it started, so that we do not call .join() if program dies
# before it starts, and __del__ tries to join but will get:
# AssertionError: can only join a started process.
w.start()
self.index_queues.append(index_queue)
self.workers.append(w)
3.1 index_queue 要處理的數(shù)據(jù)下標(biāo)
每個(gè)worker有一個(gè)index_queue dataloader.py#L544-L552
每個(gè)worker從index_queue取要處理的下標(biāo) dataloader.py#L124
dataloader輸出一次數(shù)據(jù)前先往index_queue中放一次下標(biāo), _process_next_batch函數(shù):
def _process_next_batch(self, batch):
self.rcvd_idx += 1
self._put_indices() ## 先放下一批數(shù)據(jù)下標(biāo)
if isinstance(batch, ExceptionWrapper):
raise batch.exc_type(batch.exc_msg)
return batch ## 再返回該批數(shù)據(jù)
_put_indices依次往不同worker所屬的index_queue中放 dataloader.py#L644-L652
完整的dataloader next函數(shù):
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = pin_memory_batch(batch)
return batch
# check if the next sample has already been generated
if self.rcvd_idx in self.reorder_dict:
batch = self.reorder_dict.pop(self.rcvd_idx)
return self._process_next_batch(batch) ## 5. 之前以及取出來(lái)該下標(biāo)數(shù)據(jù),直接返回
if self.batches_outstanding == 0:
self._shutdown_workers()
raise StopIteration
while True: ## 1.直到取的數(shù)據(jù)下標(biāo)正確才return
assert (not self.shutdown and self.batches_outstanding > 0)
idx, batch = self._get_batch() ## 2.從worker_result_queue中獲取數(shù)據(jù)
self.batches_outstanding -= 1
if idx != self.rcvd_idx:
# store out-of-order samples
self.reorder_dict[idx] = batch ## 3.下標(biāo)不對(duì)先存一下
continue
return self._process_next_batch(batch) ## 4.內(nèi)部先放下一批數(shù)據(jù)下標(biāo)再返回batch數(shù)據(jù)
3.2 worker_result_queue 返回結(jié)果
每個(gè)worker一直在執(zhí)行的循環(huán)_worker_loop,其中worker_result_queue作為_(kāi)worker_loop函數(shù)的data_queue傳入(dataloader.py#L544-L552),相見(jiàn):
def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
# logic of this function.
try:
global _use_shared_memory
_use_shared_memory = True
# Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
# module's handlers are executed after Python returns from C low-level
# handlers, likely when the same fatal signal happened again already.
# https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
_set_worker_signal_handlers()
torch.set_num_threads(1)
random.seed(seed)
torch.manual_seed(seed)
data_queue.cancel_join_thread()
if init_fn is not None:
init_fn(worker_id)
watchdog = ManagerWatchdog()
while watchdog.is_alive():
try:
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) ##從index_queue中獲取要處理的下標(biāo)
except queue.Empty:
continue
if r is None:
# Received the final signal
assert done_event.is_set()
return
elif done_event.is_set():
# Done event is set. But I haven't received the final signal
# (None) yet. I will keep continuing until get it, and skip the
# processing steps.
continue
idx, batch_indices = r
try:
samples = collate_fn([dataset[i] for i in batch_indices]) ##1.根據(jù)下標(biāo)取樣本數(shù)據(jù)
except Exception:
# It is important that we don't store exc_info in a variable,
# see NOTE [ Python Traceback Reference Cycle Problem ]
data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
else: ## 2. 沒(méi)有拋異常就將樣本數(shù)據(jù)放入結(jié)果返回隊(duì)列
data_queue.put((idx, samples))
del samples
except KeyboardInterrupt:
# Main process will raise KeyboardInterrupt anyways.
pass