PyTorch使用總覽
原文鏈接:https://blog.csdn.net/u014380165/article/details/79222243
參考:PyTorch學習之路(level1)——訓練一個圖像分類模型、PyTorch學習之路(level2)——自定義數據讀取、PyTorch源碼解讀之torchvision.transforms、PyTorch源碼解讀之torch.utils.data.DataLoader、PyTorch源碼解讀之torchvision.models
- PyTorch的官方github地址:https://github.com/pytorch/pytorch
- PyTorch官方文檔:http://pytorch.org/docs/0.3.0/、https://pytorch.org/docs/stable/index.html
一、數據讀取
官方代碼庫中有一個接口例子:torchvision.ImageFolder -- 針對的數據存放方式是每個文件夾包含一個類的圖像,但往往實際應用中可能你的數據不是這樣維護的,此時需要自定義一個數據讀取接口(使用PyTorch中數據讀取基類:torch.utils.data.Dataset)
數據讀取接口
class customData(data.Dataset):
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
"""
提供數據地址(data path)、每一文件所屬的類別(label),and other Info wanted(transform\loader\...) --> self.(attributes)
:param root(string): Root directory path.
:param transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
:param target_transform(callable, optional): A function/transform that takes in the target and transforms it.
:param loader (callable, optional): A function to load an image given its path.
the data loader where the images are arranged in this way: ::
root/class_1_xxx.png
root/class_2_xxx.png
...
root/class_n_xxx.png # 此例中,文件名包含label信息,__init__中可不需要額外提供
"""
self.dataset = [os.path.join(root, npy_data) for npy_data in os.listdir(root)] # 整個數據集(圖像)文件的路徑
self.transform = transform # (optional)
self.target_transform = target_transform # (optional)
self.loader = loader # (optional)
def __getitem__(self, index):
"""
:return 相應index的data && label
"""
data = np.load(self.dataset[index])
if self.transform is not None: # (optional)
img = self.transform(img)
if self.target_transform is not None: # (optional)
target = self.target_transform(target)
label_txt = self.dataset[index].split('/')[-1][:2] # (class_n)_xxxx.npy → (class_n)
if label_txt == 'class_1':
label = 0
elif label_txt == 'class_2':
label = 1
else:
raise RuntimeError('Now only support class_1 vs class_2.')
return data, label
def __len__(self):
"""
:return 數據集數量
"""
return len(self.dataset)
上述提到的transforms數據預處理,可以通過torchvision.transforms接口來實現。具體請看博客:PyTorch源碼解讀之torchvision.transforms
接口調用
root_dir = r'xxxxxxxx'
image_datasets = {x: customData(root=root_dir+x) for x in ['train', 'val', 'test']}
返回的image_datasets(自定義數據讀取接口)就和用torchvision.datasets.ImageFolder類(官方提供的數據讀取接口)返回的數據類型一樣
數據迭代器封裝
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
for x in ['train', 'valid', 'test']}
torch.utils.data.DataLoader接口將每個batch的圖像數據和標簽都分別封裝成Tensor,方便以batch進行模型批訓練,具體可以參考博客: PyTorch源碼解讀之torch.utils.data.DataLoader
至此,從圖像和標簽文件就生成了Tensor類型的數據迭代器,后續僅需將Tensor對象用torch.autograd.Variable接口封裝成Variable類型
(比如train_data=torch.autograd.Variable(train_data),如果要在gpu上運行則是:train_data=torch.autograd.Variable(train_data.cuda()))就可以作為模型的輸入
二、網絡構建
PyTorch框架中提供了一些方便使用的網絡結構及預訓練模型接口:torchvision.models,具體可以看博客:PyTorch源碼解讀之torchvision.models。該接口可以直接導入指定的網絡結構,并且可以選擇是否用預訓練模型初始化導入的網絡結構。示例如下:
import torchvision
model = torchvision.models.resnet50(pretrained=True) # 導入resnet50的預訓練模型
那么如何自定義網絡結構呢?在PyTorch中,構建網絡結構的類都是基于torch.nn.Module
這個基類進行的,也就是說所有網絡結構的構建都可以通過繼承該類來實現,包括torchvision.models接口中的模型實現類也是繼承這個基類進行重寫的。自定義網絡結構可以參考:1、https://github.com/miraclewkf/MobileNetV2-PyTorch。該項目中的MobileNetV2.py腳本自定義了網絡結構。2、https://github.com/miraclewkf/SENet-PyTorch。該項目中的se_resnet.py和se_resnext.py腳本分別自定義了不同的網絡結構。
如果要用某預訓練模型為自定義的網絡結構進行參數初始化,可以用torch.load接口導入預訓練模型,然后調用自定義的網絡結構對象的load_state_dict方式進行參數初始化,具體可以看https://github.com/miraclewkf/MobileNetV2-PyTorch項目中的train.py腳本中if args.resume條件語句(如下所示)。
if args.resume:
if os.path.isfile(args.resume):
print(("=> loading checkpoint '{}'".format(args.resume)))
checkpoint = torch.load(args.resume)
base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.state_dict().items())}
model.load_state_dict(base_dict)
else:
print(("=> no checkpoint found at '{}'".format(args.resume)))
三、其他設置
優化函數通過torch.optim包實現,比如torch.optim.SGD()接口表示隨機梯度下降。更多優化函數可以看官方文檔:http://pytorch.org/docs/0.3.0/optim.html。
學習率策略通過torch.optim.lr_scheduler接口實現,比如torch.optim.lr_scheduler.StepLR()接口表示按指定epoch數減少學習率。更多學習率變化策略可以看官方文檔:http://pytorch.org/docs/0.3.0/optim.html。
損失函數通過torch.nn包實現,比如torch.nn.CrossEntropyLoss()接口
表示交叉熵等。
多GPU訓練通過torch.nn.DataParallel接口
實現,比如:model = torch.nn.DataParallel(model, device_ids=[0,1])表示在gpu0和1上訓練模型。
模塊解讀
torch.utils.data.DataLoader
將數據讀取接口的輸入按照batch size封裝成Tensor,后續只需要再包裝成Variable即可作為模型的輸入,因此該接口有承上啟下的作用
源碼地址:https://github.com/pytorch/pytorch/blob/master/torch/utils/data/dataloader.py
示例:
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True, num_workers=4)
for x in ['train', 'valid', 'test']}
- dataset (Dataset): dataset from which to load the data.
- batch_size (int, optional): how many samples per batch to load (default: 1).
- shuffle (bool, optional): set to True to have the data reshuffled at every epoch (default: False).
- num_workers (int, optional): how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
- ... ...
從torch.utils.data.DataLoader類生成的對象中取數據:
train_data=torch.utils.data.DataLoader(...)
for i, (input, target) in enumerate(train_data):
# ...
pass
此時,調用DataLoader類的__iter__
方法 ??:
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
使用隊列queue對象,完成多線程調度;通過迭代器iter,完成batch更替(詳情讀源碼)
torchvision.transforms
基本上PyTorch中的data augmentation操作都可以通過該接口實現,包含resize、crop等常見的data augmentation操作
示例:
import torchvision
import torch
train_augmentation = torchvision.transforms.Compose([torchvision.transforms.Resize(256),
torchvision.transforms.RandomCrop(224),
torchvision.transofrms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
torch vision.Normalize([0.485, 0.456, -.406],[0.229, 0.224, 0.225])
])
class custom_dataread(torch.utils.data.Dataset): # 數據讀取接口
def __init__():
...
def __getitem__():
# use self.transform for input image
def __len__():
...
train_loader = torch.utils.data.DataLoader( # 數據迭代器
custom_dataread(transform=train_augmentation),
batch_size = batch_size, shuffle = True,
num_workers = workers, pin_memory = True)
這里定義了resize、crop、normalize等數據預處理操作,并最終作為數據讀取類custom_dataread的一個參數傳入,可以在內部方法__getitem__
中實現數據增強操作。
源碼地址:transformas.py --- 定義各種data augmentation的類、functional.py --- 提供transformas.py中所需功能函數的實現
Compose類:Composes several transforms together. 對輸入圖像img逐次應用輸入的[transform_1, transform_2, ...]操作
ToTensor類:Convert a
PIL Image
ornumpy.ndarray
to tensor. 要強調的是在做數據歸一化之前必須要把PIL Image轉成Tensor,而其他resize或crop操作則不需要.ToPILImage類:Convert a
tensor
or anndarray
toPIL Image
.Normalize類:Normalize a tensor image with mean and standard deviation.一般都會對輸入數據做歸一化操作
Resize類:Resize the input PIL
Image
to the given size. 幾乎都要用到,這里輸入可以是int,此時表示將輸入圖像的短邊resize到這個int數,長邊則根據對應比例調整,圖像的長寬比不變。如果輸入是個(h,w)的序列,h和w都是int,則直接將輸入圖像resize到這個(h,w)尺寸,相當于force resize,所以一般最后圖像的長寬比會變化,也就是圖像內容被拉長或縮短。若輸入是PIL Image,則將調用Image的各種方法;若輸入是Tensor,則對應函數基本是在調用Tensor的各種方法。CenterCrop類:Crops the given PIL Image at the center. 一般數據增強不會采用這個,因為當size固定的時候,在相同輸入圖像的情況下,N次CenterCrop的結果都是一樣的
RandomCrop類:Crop the given PIL Image at a random location. 相較CenterCrop,隨機裁剪更常用
RandomResizedCrop類:Crop the given PIL Image to random size and aspect ratio. 根據隨機生成的scale、aspect ratio(縮放比例、長寬比)、中心點裁剪原圖,(為可正常訓練)再縮放為輸入的size大小
RandomHorizontalFlip類:Horizontally flip the given PIL Image randomly with a given probability. 隨機的圖像水平翻轉,通俗講就是圖像的左右對調,較常用。 probability of the image being flipped. Default value is 0.5 (水平翻轉的概率是0.5)
RandomVerticalFlip類:Vertically flip the given PIL Image randomly with a given probability. 隨機的圖像豎直翻轉,通俗講就是圖像的上下對調,較常用。probability of the image being flipped. Default value is 0.5(豎直翻轉的概率是0.5)
FiveCrop類:Crop the given PIL Image into four corners and the central crop. 曾在TSN算法的看到過這種用法。
TenCrop類:Crop the given PIL Image into four corners and the central crop plus the flipped version of
these (horizontal flipping is used by default) 將輸入圖像進行水平或豎直翻轉,然后再進行FiveCrop操作;加上原始的FiveCrop操作,這樣一張輸入圖像就能得到10張crop結果。LinearTransformation類:Transform a tensor image with a square transformation matrix and a mean_vector computed offline. 用一個變換矩陣去乘輸入圖像得到輸出結果。
ColorJitter類:Randomly change the brightness, contrast, saturation and hue (即亮度,對比度,飽和度和色調)of an image,可以根據注釋來合理設置這4個參數。(較常用)
RandomRotation類:隨機旋轉輸入圖像,具體參數可以看注釋,在F.rotate()中主要是調用PIL Image的rotate方法。(較常用)
Grayscale類:用來將輸入圖像轉成灰度圖的,這里根據參數num_output_channels的不同有兩種轉換方式
RandomGrayscale類:Randomly convert image to grayscale with a probability of p (default 0.1).