DataLoader: Important Parameters and Internal Mechanisms

@[TOC]

1. PyTorch data input

Dataset is responsible for producing the data; DataLoader is responsible for batching (batch_size), sampling (sampler), and transferring the data.
PyTorch version: 1.0.1

1.1 Dataset

Subclass torch.utils.data.Dataset and implement two methods (a minimal sketch follows the list):

  • def __len__(self): the total number of samples
  • def __getitem__(self, index): fetch one sample by its index
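
A minimal sketch, assuming a made-up ToyDataset filled with random tensors (the class name and contents are purely illustrative):

    import torch
    from torch.utils.data import Dataset

    class ToyDataset(Dataset):
        # hypothetical dataset: n random samples with binary labels
        def __init__(self, n=100):
            self.data = torch.randn(n, 3)             # fake input features
            self.target = torch.randint(0, 2, (n,))   # fake labels

        def __len__(self):
            # total number of samples
            return self.data.size(0)

        def __getitem__(self, index):
            # fetch one (input_data, target) pair by index
            return self.data[index], self.target[index]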

1.2 DataLoader

Pass the Dataset in as an argument and construct a torch.utils.data.DataLoader object.
The remaining DataLoader parameters are covered below.
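
For example (reusing the hypothetical ToyDataset sketched above; the parameter values are arbitrary):

    from torch.utils.data import DataLoader

    loader = DataLoader(ToyDataset(), batch_size=8, shuffle=True, num_workers=2)

    for input_data, target in loader:
        # input_data has shape [8, 3], target has shape [8]
        # (the final batch may be smaller unless drop_last=True)
        pass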

2. Summary of DataLoader parameters

  • dataset (Dataset):
    the dataset to load data from

  • batch_size (int, optional):
    how many samples per batch

  • shuffle (bool, optional):
    reshuffle the data at the start of every epoch

  • sampler (Sampler, optional):
    a custom strategy for drawing samples from the dataset; if this is specified, shuffle must be False

  • batch_sampler (Sampler, optional):
    like sampler, but returns a whole batch of indices at a time; note that once this is specified, batch_size, shuffle, sampler and drop_last can no longer be given (they are mutually exclusive)

  • num_workers (int, optional):
    how many subprocesses handle data loading; 0 means all data is loaded in the main process (default: 0)

  • collate_fn (callable, optional):
    the function that merges a list of samples into a mini-batch

  • pin_memory (bool, optional):
    if True, the data loader copies tensors into CUDA pinned memory before returning them

  • drop_last (bool, optional):
    if True, the last incomplete batch is dropped: for example, with batch_size 64 and only 100 samples per epoch, the remaining 36 samples are simply discarded during training;
    if False (the default), loading continues as normal and the last batch is just smaller

  • timeout (numeric, optional):
    if positive, the timeout for collecting a batch from the worker processes; should always be non-negative (default: 0)

  • worker_init_fn (callable, optional):
    per-worker initialization function; if not None, this will be called on each
    worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading (default: None). A hedged sketch follows this list.
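
As promised above, a hedged sketch of a worker_init_fn that seeds NumPy differently in every worker (torch.initial_seed() already differs per worker; deriving the NumPy seed from it is just one common choice, and seed_numpy_in_worker is a made-up name):

    import numpy as np
    import torch
    from torch.utils.data import DataLoader

    def seed_numpy_in_worker(worker_id):
        # runs inside each worker subprocess, after PyTorch seeding and
        # before any data is loaded
        np.random.seed(torch.initial_seed() % 2**32)

    loader = DataLoader(ToyDataset(), batch_size=8, num_workers=4,
                        worker_init_fn=seed_numpy_in_worker)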

2.1 sampler: distributed training needs DistributedSampler

train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
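
A hedged sketch of how this sampler is typically used (it assumes the process group has already been initialized with torch.distributed.init_process_group, and num_epochs is a placeholder). Because the sampler shuffles internally, shuffle is left at False:

    import torch
    from torch.utils.data import DataLoader

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    train_loader = DataLoader(train_data, batch_size=64,
                              shuffle=False, sampler=train_sampler)

    for epoch in range(num_epochs):
        train_sampler.set_epoch(epoch)  # changes the shuffling order every epoch
        for input_data, target in train_loader:
            ...  # training step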

The relevant code in the DataLoader constructor:

        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)  ## random sampling when shuffle=True
                else:
                    sampler = SequentialSampler(dataset)  ## otherwise sequential sampling
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler  

batch_sampler is a wrapper around sampler and lets you customize how each batch of data is built. The relevant source of the default BatchSampler:

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)      ## iterate over the sampler collecting indices; yield once batch_size is reached
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch
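
For illustration, a hedged sketch that passes a BatchSampler explicitly to DataLoader (a custom subclass could override __iter__ above to group indices however it likes; ToyDataset is the hypothetical dataset from earlier):

    from torch.utils.data import DataLoader
    from torch.utils.data.sampler import SequentialSampler, BatchSampler

    dataset = ToyDataset()
    batch_sampler = BatchSampler(SequentialSampler(dataset),
                                 batch_size=4, drop_last=False)
    # batch_size, shuffle, sampler and drop_last must NOT be passed here,
    # since they are mutually exclusive with batch_sampler
    loader = DataLoader(dataset, batch_sampler=batch_sampler)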

2.2 collate_fn: reassembling the data within a batch

For example, in cirtorch the data is split into input_data and target.
Since the Dataset's __getitem__ returns two values, input_data and target, without this function each batch would look like [batch_size, 2 (input_data then target), ...]; after this function it becomes ([batch_size, ...], [batch_size, ...]), where the first element holds all the input_data and the second all the targets.
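
A hedged sketch of a collate_fn doing this kind of regrouping (the default collate already stacks tensor samples like this; a custom one mostly matters for variable-length or non-tensor fields, and my_collate is a made-up name):

    import torch
    from torch.utils.data import DataLoader

    def my_collate(samples):
        # samples: a list of (input_data, target) tuples, one per sample
        inputs = torch.stack([s[0] for s in samples], dim=0)   # [batch_size, ...]
        targets = torch.stack([s[1] for s in samples], dim=0)  # [batch_size]
        return inputs, targets

    loader = DataLoader(ToyDataset(), batch_size=8, collate_fn=my_collate)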

2.3 pin_memory=True: more efficient CPU-to-GPU data transfer

pin_memory allocates page-locked (pinned) memory in host RAM, i.e. memory that cannot be swapped out. With the default allocation, data can be paged out to swap, so when the CUDA driver transfers it from host memory to GPU memory via DMA it has to copy twice (first into a temporary, invisible pinned buffer, then on to device memory). Setting pin_memory=True can therefore roughly double CPU-to-GPU transfer efficiency (when calling .cuda() or .to(device)). See "CPU and GPU memory interaction" for details.
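
A hedged sketch of the typical pairing of pin_memory with non-blocking host-to-device copies (non_blocking=True only has an effect when the source tensor lives in pinned memory; ToyDataset is the hypothetical dataset from earlier):

    import torch
    from torch.utils.data import DataLoader

    loader = DataLoader(ToyDataset(), batch_size=64,
                        num_workers=4, pin_memory=True)

    device = torch.device('cuda')
    for input_data, target in loader:
        # pinned host memory lets the copy overlap with computation
        input_data = input_data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)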

[Aside] Memlock (memory locking) in Elasticsearch can likewise reserve a fixed-size, non-swappable region of memory.

3. DataLoader parallelism

    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
    
  • Based on multiprocessing (multiple worker processes)
  • Each worker subprocess communicates through two main queues (multiprocessing.Queue()): index_queue holds the indices to process, worker_result_queue holds the resulting batch data to return.
  • Each worker produces one batch of data at a time
  • Before returning a batch, the loader enqueues the indices of a following batch
  • Worker subprocess initialization in the constructor:
            self.index_queues = []
            self.workers = []
            for i in range(self.num_workers):
                index_queue = multiprocessing.Queue() # 1. one queue per worker subprocess holding the indices it should process
                index_queue.cancel_join_thread()
                w = multiprocessing.Process(
                    target=_utils.worker._worker_loop, # the loop function each worker subprocess keeps running
                    args=(self.dataset, index_queue,
                          self.worker_result_queue, self.done_event, # 2. self.worker_result_queue: a queue shared by all workers for returning batch data
                          self.collate_fn, base_seed + i,
                          self.worker_init_fn, i))
                w.daemon = True
                # NB: Process.start() actually take some time as it needs to
                #     start a process and pass the arguments over via a pipe.
                #     Therefore, we only add a worker to self.workers list after
                #     it started, so that we do not call .join() if program dies
                #     before it starts, and __del__ tries to join but will get:
                #     AssertionError: can only join a started process.
                w.start()
                self.index_queues.append(index_queue)
                self.workers.append(w)

3.1 index_queue: indices of the data to be processed

Each worker has its own index_queue (dataloader.py#L544-L552).
Each worker takes the indices it should process from its index_queue (dataloader.py#L124).
Before the dataloader outputs a batch, it first puts another set of indices into an index_queue; see the _process_next_batch function:

    def _process_next_batch(self, batch):
        self.rcvd_idx += 1
        self._put_indices()  ## first enqueue the indices of the next batch
        if isinstance(batch, ExceptionWrapper):
            raise batch.exc_type(batch.exc_msg)
        return batch         ## then return the current batch

_put_indices puts indices into the index_queues of the different workers in turn (dataloader.py#L644-L652).

The full dataloader __next__ function:

    def __next__(self):
        if self.num_workers == 0:  # same-process loading
            indices = next(self.sample_iter)  # may raise StopIteration
            batch = self.collate_fn([self.dataset[i] for i in indices])
            if self.pin_memory:
                batch = pin_memory_batch(batch)
            return batch

        # check if the next sample has already been generated
        if self.rcvd_idx in self.reorder_dict:
            batch = self.reorder_dict.pop(self.rcvd_idx)
            return self._process_next_batch(batch) ## 5. this batch was already received earlier; return it directly

        if self.batches_outstanding == 0:
            self._shutdown_workers()
            raise StopIteration

        while True:  ## 1. loop until the batch with the expected index is fetched, then return
            assert (not self.shutdown and self.batches_outstanding > 0)
            idx, batch = self._get_batch()  ## 2. fetch data from worker_result_queue
            self.batches_outstanding -= 1
            if idx != self.rcvd_idx:
                # store out-of-order samples
                self.reorder_dict[idx] = batch  ## 3. wrong index, stash it for now
                continue
            return self._process_next_batch(batch) ## 4. internally enqueues the next batch's indices before returning this batch's data

3.2 worker_result_queue: returning results

Each worker keeps running the loop _worker_loop, into which worker_result_queue is passed as the data_queue argument (dataloader.py#L544-L552); see below:

    def _worker_loop(dataset, index_queue, data_queue, done_event, collate_fn, seed, init_fn, worker_id):
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
        # logic of this function.

        try:
            global _use_shared_memory
            _use_shared_memory = True

            # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
            # module's handlers are executed after Python returns from C low-level
            # handlers, likely when the same fatal signal happened again already.
            # https://docs.python.org/3/library/signal.html Sec. 18.8.1.1
            _set_worker_signal_handlers()

            torch.set_num_threads(1)
            random.seed(seed)
            torch.manual_seed(seed)

            data_queue.cancel_join_thread()

            if init_fn is not None:
                init_fn(worker_id)

            watchdog = ManagerWatchdog()

            while watchdog.is_alive():
                try:
                    r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) ## fetch the indices to process from index_queue
                except queue.Empty:
                    continue
                if r is None:
                    # Received the final signal
                    assert done_event.is_set()
                    return
                elif done_event.is_set():
                    # Done event is set. But I haven't received the final signal
                    # (None) yet. I will keep continuing until get it, and skip the
                    # processing steps.
                    continue
                idx, batch_indices = r
                try:
                    samples = collate_fn([dataset[i] for i in batch_indices]) ## 1. fetch the sample data by index
                except Exception:
                    # It is important that we don't store exc_info in a variable,
                    # see NOTE [ Python Traceback Reference Cycle Problem ]
                    data_queue.put((idx, ExceptionWrapper(sys.exc_info())))
                else: ## 2. no exception was raised, so put the samples into the result queue
                    data_queue.put((idx, samples))
                    del samples
        except KeyboardInterrupt:
            # Main process will raise KeyboardInterrupt anyways.
            pass
