Why transforms?
一般情況下收集到的圖像樣本在尺寸,亮度等方面存在差異,在深度學習中,我們希望樣本分布是獨立同分布的,因此需要對樣本進行歸一化預處理。
有時候只能獲取到少量的樣本數據,獲取數量較多的樣本不容易。但是樣本數量太少訓練的模型精度會比較低,為了解決這樣的問題,往往需要做數據增加data arguement, 數據增加的途徑就是通過一些變換達到目的。
pytorch中的transforms
在pytorch中,transforms位于 torchvision.transforms
包中
image.png
主要包含的變換:
類型 | 作用 |
---|---|
Transforms on PIL Image | 對PIL.Image圖像進行變換 |
Transforms on torch.*Tensor | 對torch.Tensor進行變換 |
Conversion Transforms | |
Generic Transforms | 一些通用的變換 |
Functional Transforms | 函數 |
開發/實驗環境
- windows10 64bit
- Anaconda3
- python3.7
- pytorch1.0
API 參考
Reference
pytorch官網文檔
image.png
實踐
一、對PIL Image進行變換(Transforms on PIL Image)
- Summary
類型 | 說明 |
---|---|
CenterCrop(size) |
中心裁剪 |
FiveCrop(size) |
4個角+中心裁剪 = 5, 返回多張圖像 |
Grayscale(num_output_channels = 1) |
灰度化 |
Pad(padding, fill=o,padding_mode='constant) |
圖像邊沿加pad |
RandomAffine(degrees,translate,scale,shear,resample,fillcolor) |
隨進放射變換 |
RandomApply(..) |
對圖像隨機應用變換 |
RandomCrop(..) |
隨機位置裁剪 |
RandomGrayscale(..) | |
Resize(size) |
對圖像進行尺寸縮放 |
- 實驗
import numpy as np
from torchvision.transforms import transforms
from PIL import Image
# 準備好實驗的圖像,一個彩色32bit圖像
IMG_PATH = './data/lena_rgb.jpg'
img = Image.open(IMG_PATH)
# -----------------類型轉換---------------------------------------
#transforms1 = transforms.Compose([transforms.ToTensor()])
#img1 = transforms1(img)
#print('img1 = ', img1)
# ---------------Tensor上的操作---------------------------------
#transforms2 = transforms.Compose([transforms.Normalize(mean=(0.5, 0.5, #0.5), std=(0.5, 0.5, 0.5))])
#img2 = transforms2(img1)
#print('img2 = ', img2)
# ---------------PIL.Image上的操作---------------------------------
transforms3 = transforms.Compose([transforms.Resize(256)])
img3 = transforms3(img)
print('img3 = ', img3)
img3.show()
transforms4 = transforms.Compose([transforms.CenterCrop(256)])
img4 = transforms4(img)
print('img4 = ', img4)
img4.show()
transforms5 = transforms.Compose([transforms.RandomCrop(224, padding=0)])
img5 = transforms5(img)
print('img5 = ', img5)
img5.show()
transforms6 = transforms.Compose([transforms.Grayscale(num_output_channels=1)])
img6 = transforms6(img)
print('img6 = ', img6)
img6.show()
transforms7 = transforms.Compose([transforms.ColorJitter()])
img7 = transforms7(img)
img7.show()