筆者PyTorch的全部簡(jiǎn)單教程請(qǐng)?jiān)L問(wèn):http://www.lxweimin.com/nb/48831659
PyTorch教程-2:PyTorch中反向傳播的梯度追蹤與計(jì)算
基本原理
torch.Tensor
類具有一個(gè)屬性 requires_grad
用以表示該tensor是否需要計(jì)算梯度,如果該屬性設(shè)置為True
的話,表示這個(gè)張量需要計(jì)算梯度,計(jì)算梯度的張量會(huì)跟蹤其在后續(xù)的所有運(yùn)算,當(dāng)我們完成計(jì)算后需要反向傳播(back propagation)計(jì)算梯度時(shí),使用 .backward()
即可自動(dòng)計(jì)算梯度。當(dāng)然,對(duì)于一些我們不需要一直跟蹤記錄運(yùn)算的tensor,也可以取消這一操作,尤其是在對(duì)模型進(jìn)行驗(yàn)證的時(shí)候,不會(huì)對(duì)變量再做反向傳播,所以自然不需要再進(jìn)行追蹤,從而減少運(yùn)算。
追蹤計(jì)算歷史
一個(gè)tensor的 requires_grad
屬性決定了這個(gè)tensor是否被追蹤運(yùn)算,對(duì)其主要的操作方式:
- 查看/返回該屬性:
tensor.requires_grad
- 定義該屬性的值:在創(chuàng)建一個(gè)tensor時(shí)顯式地聲明
requires_grad
變量為True
(默認(rèn)為False
) - 更改該屬性的值:使用
tensor.requires_grad_()
改變其值
a=torch.rand(2,2,requires_grad=True)
print(a.requires_grad)
a.requires_grad_(False)
print(a.requires_grad)
True
False
每當(dāng)對(duì)于requires_grad
為True
的tensor進(jìn)行一些運(yùn)算時(shí)(除了用戶直接賦值、創(chuàng)建等操作),這些操作都會(huì)保存在變量的 grad_fn
屬性中,該屬性返回一個(gè)操作,即是上一個(gè)作用在這個(gè)變量上的操作:
x=torch.ones(2,2,requires_grad=True)
print(x.grad_fn)
y = x+2
print(y.grad_fn)
z = y*y*3
print(z.grad_fn)
out = z.mean()
print(out.grad_fn)
None
<AddBackward0 object at 0x0000026EF919AC88>
<MulBackward0 object at 0x0000026EF919AC88>
<MeanBackward0 object at 0x0000026EBACDE688>
如果需要繼續(xù)往前得到連續(xù)的操作,對(duì)grad_fu
使用 next_functions
即可獲得其上一步的操作(next_functions
返回一個(gè)多層的tuple,真正的操作記錄對(duì)象要經(jīng)過(guò)兩層的[0]
索引:
x=torch.ones(2,2,requires_grad=True)
y = x+2
z = y*y*3
out = z.mean()
print(out.grad_fn)
print(out.grad_fn.next_functions[0][0])
print(out.grad_fn.next_functions[0][0].next_functions[0][0])
<MeanBackward0 object at 0x0000026EF873D5C8>
<MulBackward0 object at 0x0000026EF9354A08>
<MulBackward0 object at 0x0000026EF873D5C8>
梯度計(jì)算
對(duì)于requires_grad
為True
的tensor,在某一層運(yùn)算結(jié)果的tensor上調(diào)用 backward()
方法,即可計(jì)算它對(duì)于原始tensor的梯度。比如 out=mean(z(y(x))
這樣一個(gè)三層的運(yùn)算作用后,使用 out.backward()
方法就可以對(duì)x進(jìn)行求導(dǎo)(有條件的),完成求導(dǎo)后,梯度會(huì)存儲(chǔ)在x
這個(gè)tensor的grad
屬性下,每個(gè)tensor都有grad
屬性,用于記錄高層運(yùn)算對(duì)它求導(dǎo)的梯度值。剛剛說(shuō)到的有條件是指這個(gè)被求導(dǎo)的變量需要是一個(gè)只包含一個(gè)標(biāo)量的tensor。
- 對(duì)于結(jié)果只包含一個(gè)標(biāo)量的tensor:使用
tensor.backward()
反向傳播,對(duì)底層的求導(dǎo)結(jié)果,梯度會(huì)存儲(chǔ)在最底層tensor的grad
屬性中 - 對(duì)于結(jié)果是一個(gè)向量/矩陣的tensor,使用
tensor.backward()
時(shí)需要傳入作為反向傳播的參數(shù)來(lái)計(jì)算Jacobian矩陣
x=torch.ones(2,2,requires_grad=True)
y = x+2
z = y*y*3
out = z.mean()
out.backward()
print(x.grad)
y.backward(x)
print(x.grad)
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
tensor([[5.5000, 5.5000],
[5.5000, 5.5000]])
取消追蹤運(yùn)算
有時(shí)我們并不需要追蹤梯度,將requires_grad
設(shè)置為False
即可,但是由于有時(shí)候有些tensor需要在模型訓(xùn)練時(shí)計(jì)算梯度,在模型驗(yàn)證時(shí)不計(jì)算梯度,我們不希望直接對(duì)tensor的requires_grad
屬性做更改,所以需要更好、更方便的設(shè)置方法:
-
使用
with torch.no_grad()
語(yǔ)句塊,放在這個(gè)語(yǔ)句塊下的所有tensor操作(不影響tensor本身)都不會(huì)被跟蹤運(yùn)算:x=torch.ones(2,2,requires_grad=True) print((x**2).requires_grad) with torch.no_grad(): print(x.requires_grad) print((x**2).requires_grad)
True True False
-
使用
tensor.detach()
方法獲得一個(gè)跟原tensor值一樣但是不會(huì)被記錄運(yùn)算的tensor(不改變?cè)瓉?lái)的tensor屬性):x=torch.ones(2,2,requires_grad=True) print(x.requires_grad) y = x.detach() print(x.requires_grad) print(y.requires_grad)
True True False