PyTorch獲取自帶模型的中間特征圖

本文首發(fā)自【簡書】用戶【西北小生_】的博客,轉(zhuǎn)載請注明出處!

PyTorch之HOOK——獲取神經(jīng)網(wǎng)絡(luò)特征和梯度的有效工具記錄了PyTorch獲取卷積神經(jīng)網(wǎng)絡(luò)特征圖和梯度的方法,但由于舉例簡單,并不能直接應(yīng)用于用其它構(gòu)造方法構(gòu)造的神經(jīng)網(wǎng)絡(luò)模型。為解決更具一般性的問題,現(xiàn)針對PyTorch自帶的vgg16模型,記錄獲取其任一層輸出特征圖的方法,讀者閱讀完本文,可以自行應(yīng)用于獲取更多的模型(resnet, densenet,mobilenet等)特征圖和梯度。

首先導(dǎo)入需要的包:

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import models
from torchvision.transforms.functional import normalize, resize, to_tensor, to_pil_image

本文以pytorch自帶的vgg16模型為例,故需要從torchvision.models加載vgg16模型:

model = models.vgg16_bn(pretrained=True).eval()

我們打印一下vgg16_bn看一下它的結(jié)構(gòu):

In [3]: print(model)
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (16): ReLU(inplace=True)
    (17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (19): ReLU(inplace=True)
    (20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (26): ReLU(inplace=True)
    (27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (29): ReLU(inplace=True)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (32): ReLU(inplace=True)
    (33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (36): ReLU(inplace=True)
    (37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (39): ReLU(inplace=True)
    (40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (42): ReLU(inplace=True)
    (43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

可以看出pytorch自帶的vgg16_bn模型是由features, avgpool, classifier 三個children構(gòu)成的,features和classifier又由Sequential()組成。

如果我們想獲取features中任意一層的輸出特征圖或梯度,該如何獲取呢?

那我們就取男人最愛的18這個數(shù)字吧(手動狗頭)

假設(shè)要獲取features中的編號為18的層輸出的特征圖,我們首先要建立hook函數(shù),然后對該層用register_forward_hook()函數(shù)進(jìn)行注冊(對這一部分不熟悉的可以看我的博客):

# 建立一個全局變量的字典,將特征圖放在其中
feature_map = {}

# 構(gòu)建hook函數(shù)
def forward_hook(module, inp, outp):
    feature_map['features18'] = outp

其實這一方法的難點就在于如何獲取相應(yīng)的層,并對其注冊hook,我介紹一種簡單的做法,那就是torch.nn.Module.children()方法。(torch.nn.Module.children()介紹請看我的這篇博客

pytorch自帶的vgg16_bn模型有3個children:features, avgpool, classifier 。為了獲取features中的第18層,我們需要先獲取features,再用索引的方式獲取第18層(或者任一層):

features = list(model.children())[0]

hook_layer = features[18]

這樣就獲取了features中的第18層,在vgg16_bn中,第18層是BN層,我們可以打印hook_layer變量驗證一下:

In [11]: hook_layer
Out[11]: BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

可以看到,我們確實已經(jīng)獲取到了第18層。只要看懂了這里,獲取其它層都是一樣簡單的。

接下來就是對第18層進(jìn)行hook注冊:

hook_layer.register_forward_hook(forward_hook)

到這里,只需要加載一幅圖像,輸入模型model,對其進(jìn)行前向傳播一次,第18層的輸出特征圖就會出現(xiàn)在全局字典變量feature_map中了。

加載一幅圖像(本例中的圖像來源于ImageNet):

imgdir = 'ILSVRC2012_val_00000003.JPEG'
origin_img = Image.open(imgdir)
img_tensor = normalize(to_tensor(resize(origin_img, (224, 224))),
                           [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
input_img = torch.unsqueeze(img_tensor, 0)

將圖像輸入模型中,進(jìn)行一次前向傳播:

with torch.no_grad():
    score = model(input_img)

這時第18層輸出的特征圖就已經(jīng)出現(xiàn)在feature_map中了,我們查看一下它的尺寸:

In [23]: feature_map['features18'].shape
Out[23]: torch.Size([1, 256, 56, 56])

可以看到我們已經(jīng)獲取了第18層的輸出特征圖,后續(xù)處理就不一一贅述。等以后有時間了寫一個CAM系列講hook的應(yīng)用。

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請聯(lián)系作者
平臺聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務(wù)。

推薦閱讀更多精彩內(nèi)容