PyTorch: loading and saving models (rough notes, recorded without formatting for now)

  1. Define the network architecture

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
# _DenseBlock and _Transition are the dense-block / transition-layer helpers
# defined next to DenseNet in torchvision.models.densenet
from torchvision.models.densenet import _DenseBlock, _Transition

class DenseNet(nn.Module):
    r"""Densenet-BC model class, based on
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
    Args:
        growth_rate (int) - how many filters to add each layer (`k` in paper)
        block_config (list of 4 ints) - how many layers in each pooling block
        num_init_features (int) - the number of filters to learn in the first convolution layer
        bn_size (int) - multiplicative factor for number of bottleneck layers
          (i.e. bn_size * k features in the bottleneck layer)
        drop_rate (float) - dropout rate after each dense layer
        num_classes (int) - number of classification classes
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000):

        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
                                bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2)
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # global average pooling over the 7x7 final feature map (assumes 224x224 inputs)
        out = F.avg_pool2d(out, kernel_size=7).view(features.size(0), -1)
        out = self.classifier(out)
        return out
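
A quick smoke test of the architecture with the default (densenet121-style) hyperparameters, assuming the 224x224 inputs that the forward pass above expects (illustrative only):

net = DenseNet()
x = torch.randn(1, 3, 224, 224)
print(net(x).shape)  # torch.Size([1, 1000])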
  2. Instantiate a model from the architecture:
net = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24))  # the densenet161 configuration
  3. Load the pretrained parameters
net.load_state_dict(torch.load('/home/wei.fan/.torch/models/densenet161-17b70270.pth'))
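
If the checkpoint was saved on GPU and you are loading on a CPU-only machine (or vice versa), torch.load accepts a map_location argument that remaps the storages. A minimal sketch, reusing the checkpoint path from step 3:

state_dict = torch.load('/home/wei.fan/.torch/models/densenet161-17b70270.pth',
                        map_location='cpu')  # force all tensors onto the CPU
net.load_state_dict(state_dict)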

  4. Train the model

import torch.optim as optim

num_ftrs = net.classifier.in_features      # the model defined above is `net`, so query it directly
net.classifier = nn.Linear(num_ftrs, 100)  # resize the final layer (here: 100 classes)
net = net.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # keep the optimizer separate; do not overwrite `net`
net = train_net()  # user-defined training function; a sketch follows below
torch.save(net.state_dict(), 'net_params.pkl')  # save only the model parameters (state_dict)
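
train_net() above is left user-defined; the following is a minimal sketch of such a training loop, assuming a DataLoader named train_loader plus the criterion and optimizer from step 4 (these names and the signature are illustrative, not part of the original notes):

def train_net(net, criterion, optimizer, train_loader, num_epochs=5):
    net.train()  # enable training-mode behaviour (dropout, batch-norm updates)
    for epoch in range(num_epochs):
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.cuda(), labels.cuda()
            optimizer.zero_grad()                  # clear gradients from the previous step
            loss = criterion(net(inputs), labels)  # forward pass + loss
            loss.backward()                        # backpropagate
            optimizer.step()                       # SGD update
            running_loss += loss.item()
        print('epoch %d, loss %.4f' % (epoch + 1, running_loss / len(train_loader)))
    return net

To restore the saved parameters later, rebuild the same architecture first (including the resized classifier) and then load the state_dict:

net = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24))
net.classifier = nn.Linear(net.classifier.in_features, 100)  # must match the saved shape
net.load_state_dict(torch.load('net_params.pkl'))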