PyTorch教程-7:PyTorch中保存與加載tensor和模型詳解

筆者PyTorch的全部簡單教程請訪問:http://www.lxweimin.com/nb/48831659

PyTorch教程-7:PyTorch中保存與加載tensor和模型詳解

保存和讀取Tensor

PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save()方法保存張量,使用torch.load()來讀取張量:

x = torch.rand(4,5)
torch.save(x, "./myTensor.pt")

y = torch.load("./myTensor.pt")
print(y)

tensor([[0.9363, 0.2292, 0.1612, 0.9558, 0.9414],
        [0.3649, 0.9622, 0.3547, 0.5772, 0.7575],
        [0.7005, 0.8115, 0.6132, 0.6640, 0.1173],
        [0.6999, 0.1023, 0.8544, 0.7708, 0.1254]])

當然,saveload方法也適用于其他數據類型,比如list、tuple、dict等:

a = {'a':torch.rand(2,2), 'b':torch.rand(3,4)}
torch.save(a, "./myDict.pth")

b = torch.load("./myDict.pth")
print(b)

{'a': tensor([[0.9356, 0.0240],
        [0.6004, 0.3923]]), 'b': tensor([[0.0222, 0.1799, 0.9172, 0.8159],
        [0.3749, 0.6689, 0.4796, 0.5772],
        [0.5016, 0.5279, 0.5109, 0.0592]])}

保存Tensor的純數據

PyTorch中,使用 torch.save 保存的不僅有其中的數據,還包括一些它的信息,包括它與其它數據(可能存在)的關系,這一點是很有趣的。

This is an implementation detail that may change in the future, but it typically saves space and lets PyTorch easily reconstruct the view relationships between the loaded tensors.

詳細的原文可以參考:https://pytorch.org/docs/stable/notes/serialization.html#saving-and-loading-tensors-preserves-views

這里結合例子給出一個簡單的解釋。

x = torch.arange(20)
y = x[:5]

torch.save([x,y], "./myTensor.pth")
x_, y_ = torch.load("././myTensor.pth")

y_ *= 100

print(x_)

tensor([  0, 100, 200, 300, 400,   5,   6,   7,   8,   9,  10,  11,  12,  13, 14,  15,  16,  17,  18,  19])

比如在上邊的例子中,我們看到yx的一個前五位的切片,當我們同時保存xy后,它們的切片關系也被保存了下來,再將他們加載出來,它們之間依然保留著這個關系,因此可以看到,我們將加載出來的 y_ 乘以100后,x_ 也跟著變化了。

如果不想保留他們的關系,其實也很簡單,再保存y之前使用 clone 方法保存一個只有數據的“克隆體”,這樣就能只保存數據而不保留關系:

x = torch.arange(20)
y = x[:5]

torch.save([x,y.clone()], "./myTensor.pth")
x_, y_ = torch.load("././myTensor.pth")

y_ *= 100

print(x_)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

當我們只保存y而不同時保存x會怎樣呢?這樣的話確實可以避免如上的情況,即不會再在讀取數據后保留他們的關系,但是實際上有一個不容易被看到的影響存在,那就是保存的數據所占用的空間會和其“父親”級別的數據一樣大

x = torch.arange(1000)
y = x[:5]

torch.save(y, "./myTensor1.pth")
torch.save(y.clone(), "./myTensor2.pth")

y1_ = torch.load("./myTensor1.pth")
y2_ = torch.load("./myTensor2.pth")

print(y1_.storage().size())
print(y2_.storage().size())

1000
5

如果你去觀察他們保存的文件,會發現占用的空間確實存在很大的差距:

myTensor1.pth      9KB
myTensor2.pth      1KB

綜上所述,對于一些“被關系”的數據來說,如果不想保留他們的關系,最好使用 clone 來保存其“純數據”

保存與加載模型

保存與加載state_dict

這是一種較為推薦的保存方法,即只保存模型的參數,保存的模型文件會較小,而且比較靈活。但是當加載時,需要先實例化一個模型,然后通過加載將參數賦給這個模型的實例,也就是說加載之前使用者需要知道模型的結構。

  • 保存:
    torch.save(model.state_dict(), PATH)
    
  • 加載:
    model = TheModelClass(*args, **kwargs)
    model.load_state_dict(torch.load(PATH))
    model.eval()
    

比較重要的點是:

  • 保存模型時調用 state_dict() 獲取模型的參數,而不保存結構
  • 加載模型時需要預先實例化一個對應的網絡,比如net=MyNet(),這也就意味著,使用者需要預先有MyNet這個類,如果他/她不知道這個網絡的類定義或者結構,這種只保存參數的方法將無法使用
  • 加載模型使用 load_state_dict 方法,其參數不是文件路徑,而是 torch.load(PATH)
  • 如果加載出來的模型用于驗證,不要忘了使用 model.eval() 方法,它會丟棄 dropout、normalization 等層,因為這些層不能在inference的時候使用,否則得到的推斷結果不一致。

一個例子:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()

        # convolution layers
        self.conv1 = nn.Conv2d(1,6,3)
        self.conv2 = nn.Conv2d(6,16,3)

        # fully-connection layers
        self.fc1 = nn.Linear(16*6*6,120)
        self.fc2 = nn.Linear(120,84)
        self.fc3 = nn.Linear(84,10)

    def forward(self,x):
        # max pooling over convolution layers
        x = F.max_pool2d(F.relu(self.conv1(x)),2)
        x = F.max_pool2d(F.relu(self.conv2(x)),2)

        # fully-connected layers followed by activation functions
        x = x.view(-1,16*6*6)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        # final fully-connected without activation functon
        x = self.fc3(x)

        return x

net = Net()

torch.save(net.state_dict(), "./myModel.pth")

loaded_net = Net()
loaded_net.load_state_dict(torch.load("./myModel.pth"))
loaded_net.eval()

Net(
  (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=576, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

保存與加載整個模型

這種方式不僅保存、加載模型的數據,也包括模型的結構一并存儲,存儲的文件會較大,好處是加載時不需要提前知道模型的結構,解來即用。實際上這與上文提到的保存Tensor是一致的。

  • 保存:
    torch.save(model, PATH)
    
  • 加載:
    model = torch.load(PATH)
    model.eval()
    

同樣的,如果加載的模型用于inference,則需要使用 model.eval()

保存與加載模型與其他信息

有時我們不僅要保存模型,還要連帶保存一些其他的信息。比如在訓練過程中保存一些 checkpoint,往往除了模型,還要保存它的epoch、loss、optimizer等信息,以便于加載后對這些 checkpoint 繼續訓練等操作;或者再比如,有時候需要將多個模型一起打包保存等。這些其實也很簡單,正如我們上文提到的,torch.save 可以保存dict、list、tuple等多種數據結構,所以一個字典可以很完美的解決這個問題,比如一個簡單的例子:

# saving
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

# loading
model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

跨設備存儲與加載

跨設備的情況指對于一些數據的保存、加載在不同的設備上,比如一個在CPU上,一個在GPU上的情況,大致可以分為如下幾種情況:

從CPU保存,加載到CPU

實際上,這就是默認的情況,我們上文提到的所有內容都沒有關心設備的問題,因此也就適應于這種情況。

從CPU保存,加載到GPU

  • 保存:依舊使用默認的方法
  • 加載:有兩種可選的方式
    • 使用 torch.load() 函數的 map_location 參數指定加載后的數據保存的設備
    • 對于加載后的模型使用 to() 函數發送到設備
torch.save(net.state_dict(), PATH)

device = torch.device("cuda")

loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))
# or
loaded_net.to(device)

從GPU保存,加載到CPU

  • 保存:依舊使用默認的方法
  • 加載:只能使用 torch.load() 函數的 map_location 參數指定加載后的數據保存的設備
torch.save(net.state_dict(), PATH)

device = torch.device("cuda")

loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))

從GPU保存,加載到GPU

  • 保存:依舊使用默認的方法
  • 加載:只能使用 對于加載后的模型進行 to() 函數發送到設備
torch.save(net.state_dict(), PATH)

device = torch.device("cuda")

loaded_net = Net()
loaded_net.to(device)
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容