好久沒(méi)看的VAE又不太記得了,重新梳理一下思路,在S同學(xué)的指導(dǎo)下,又有了一些新的理解。之前寫過(guò)一篇關(guān)于VAE的入門教程,但是感覺(jué)還不夠簡(jiǎn)練,刪掉重新寫一個(gè)哈哈哈。這篇文章主要是從一個(gè)熟悉machine learning但是對(duì)于VAE一點(diǎn)都不懂的視角,進(jìn)行寫作的,并不涉及很多復(fù)雜的理論知識(shí),輔助理解為主。另外,上次那篇文章是先給結(jié)論,在慢慢講細(xì)節(jié)。這篇文章將會(huì)循序漸進(jìn)逐漸推導(dǎo)出VAE的各種Trick存在的必要性。
VAE 入門
我們首先要明確一點(diǎn),VAE是一個(gè)生成式的模型,什么是生成式的模型?簡(jiǎn)單來(lái)說(shuō),就是可以用來(lái)生成數(shù)據(jù)的模型。怎么樣才能生成數(shù)據(jù)呢?就是我們是知道數(shù)據(jù)的分布的?有了這個(gè)分布之后,我們就可以從這個(gè)分布中采樣,獲得新的數(shù)據(jù)。
這個(gè)思路好像很簡(jiǎn)單啊,但是問(wèn)題是這個(gè)是怎么得到的。有很多方法啊,其中包含這樣兩大類:1. 基于概率的,如MCMC, Variational Inference等。以及2. 基于機(jī)器學(xué)習(xí)的。前面提到過(guò),本文主要面向的是對(duì)于機(jī)器學(xué)習(xí)比較熟悉的人,所以這里就對(duì)概率方法不多說(shuō)了,主要講一下機(jī)器學(xué)習(xí)的思路。
其實(shí)機(jī)器學(xué)習(xí)來(lái)解決這種問(wèn)題的思路是很清晰的,大多數(shù)的機(jī)器學(xué)習(xí)問(wèn)題都有這樣一個(gè)思路。我們想要優(yōu)化某個(gè)目標(biāo)O,我們先對(duì)這個(gè)問(wèn)題建個(gè)模型,模型可以表示為某個(gè)數(shù)學(xué)表達(dá)式 ,其中
是參數(shù)。我們用數(shù)據(jù)去訓(xùn)練這個(gè)模型,然后根據(jù)目標(biāo)O,去調(diào)整我們的參數(shù)
,我們希望訓(xùn)練結(jié)束的時(shí)候,能夠找到一組最優(yōu)的
。對(duì)應(yīng)到我們這個(gè)生成式的問(wèn)題,我們希望能夠生成一個(gè)新的數(shù)據(jù)
,那么我們構(gòu)造一個(gè)模型
,我們的目標(biāo)呢就是這個(gè)生成的數(shù)據(jù)越真越好,就是在眾多的
中,我們希望能夠找到一個(gè)最好的
,能夠讓這個(gè)數(shù)據(jù)存在的概率
越大越好。有了目標(biāo),我們就能計(jì)算出損失函數(shù)
,然后就是利用梯度下降,逐步調(diào)整參數(shù),最終找到最優(yōu)。
前文這個(gè)過(guò)程好像很熟悉,但是存在幾個(gè)問(wèn)題:
- 我們建模
其實(shí)是根據(jù)我們的assumption來(lái)的,我們的模型結(jié)構(gòu),初始參數(shù)設(shè)置都是根據(jù)我們的assumption來(lái)的, 但是我們的assumption有可能是錯(cuò)的,而且很有可能是錯(cuò)的。因此引入過(guò)多或者過(guò)強(qiáng)的assumption都會(huì)導(dǎo)致我們的模型效果很差。
- 因?yàn)槲覀兊腶ssumption和真實(shí)數(shù)據(jù)分布存在偏差,相應(yīng)的,我們?cè)趦?yōu)化的過(guò)程中,很容易陷入到局部最優(yōu)中。
- 如果我們直接采用建模的方式來(lái)解決生成式問(wèn)題,那么我們通常需要構(gòu)造一個(gè)相對(duì)復(fù)雜的模型,或者說(shuō)參數(shù)很多的模型,來(lái)獲得較大的Capacity。這樣就導(dǎo)致我們優(yōu)化的過(guò)程非常的耗時(shí)(Computationally Expensive)。
上述三個(gè)問(wèn)題的存在,讓我們對(duì)直接建模這個(gè)思路產(chǎn)生了動(dòng)搖,至少直接建模并不適用于所有的場(chǎng)景。所以我們?cè)谥苯咏5幕A(chǔ)上做出修改,引入了隱變量的概念。我們認(rèn)為數(shù)據(jù)的生成是受到隱變量的影響的。比如手寫數(shù)字生成的任務(wù),我們?cè)谏蓴?shù)字的時(shí)候,會(huì)首先考慮,我們要生成的是數(shù)字幾啊,因?yàn)槲覀冎挥?0種數(shù)字可以生成,這個(gè)數(shù)字幾就是我們的隱變量
。有了這個(gè)隱變量,我們就不再是漫天生成數(shù)字了,我們只有10個(gè)方向去生成,這大大的縮小了我們的生成空間,降低了計(jì)算量。
引入隱變量,用數(shù)學(xué)公式可以表示為:
模型的學(xué)習(xí)過(guò)程也因?yàn)殡[變量的引入發(fā)生了改變。我們最終的目標(biāo)還是要計(jì)算,我們對(duì)隱變量的概率建個(gè)模
,參數(shù)是
。我們要找到能讓
值最大的參數(shù)
。但是現(xiàn)在是要把隱變量的所有可能取值都找到,然后求一個(gè)上面這樣的積分,來(lái)確定
的值,從而通過(guò)比較不同的
對(duì)應(yīng)的
的值來(lái)確定哪個(gè)
才是最合適的。是不是覺(jué)得天衣無(wú)縫?對(duì)于手寫數(shù)字來(lái)說(shuō),我們的隱變量的取值只有10個(gè),所以這個(gè)積分就退化成了只有10項(xiàng)的求和。但是對(duì)于很多其他問(wèn)題,這個(gè)隱變量的取值就有可能變得非常多,又不太好做了。所以我們的做法是不去計(jì)算積分了,我們做了一步近似。我們就采樣一個(gè)隱變量,我們希望挑出來(lái)的參數(shù)
能夠在這一個(gè)隱變量上表現(xiàn)好就行了。What?這樣的近似是不是差的太多?這樣真的呆膠布嗎?答案是肯定的。我們雖然用單個(gè)變量代替了積分,或者說(shuō),代替了期望值,但是我們機(jī)器學(xué)習(xí)的過(guò)程是在不斷的迭代的隨機(jī)過(guò)程(Stochastic process)。簡(jiǎn)單來(lái)說(shuō),就是我們會(huì)找一個(gè)又一個(gè)的sample,重復(fù)的進(jìn)行優(yōu)化,理論上講依然能夠得到最優(yōu)解(可以參考機(jī)器學(xué)習(xí)的學(xué)習(xí)理論)。總而言之,就是這樣用單樣本代替期望是可行的。
這個(gè)過(guò)程是不是聽(tīng)起來(lái)又是很合理?但是存在一個(gè)問(wèn)題,我們說(shuō)想要去采樣隱變量,但是從什么分布里采樣?從
?可以嗎?可以。但是我們?cè)偎伎家粋€(gè)問(wèn)題,真的所有的隱變量都是平等的嗎?回到手寫數(shù)字的例子,如果我們現(xiàn)在要生成的是數(shù)字7,那么隱變量如果是0,8這種帶圓圈的概率是不是不大。假如我們采樣到了一個(gè)隱變量代表的是數(shù)字0,那么這一次采樣是不是相當(dāng)于浪費(fèi)了?你本來(lái)就不怎么能指導(dǎo)我做這一次生成呀。所以,為了減少這樣的無(wú)效采樣,從而進(jìn)一步的降低計(jì)算量,我們并不是從
中采樣的,而是從
中采樣的。
好了,現(xiàn)在又有了一個(gè)新問(wèn)題,這個(gè)我們知道嘛?答案是不知道,不知道怎么辦?不知道那就去求?用什么樣的方法去求?用機(jī)器學(xué)習(xí)的去求,和前面對(duì)
建模一樣,我們?cè)谶@里對(duì)
建模為
,然后求個(gè)
,再梯度下降去優(yōu)化它。常見(jiàn)的做法是把
建模為一個(gè)正態(tài)分布:
講到這里VAE的主體框架已經(jīng)出來(lái)了:
我們用Q采樣出來(lái)一個(gè)隱變量
,然后我們根據(jù)這個(gè)隱變量,利用
生成新的圖片
。我們從這個(gè)過(guò)程中,計(jì)算損失值
,通過(guò)梯度下降的方式,不斷優(yōu)化
和
的參數(shù),從而我們能夠生成越來(lái)越好的圖片
這個(gè)框架到目前為止已經(jīng)可以說(shuō)是相對(duì)很完整了,里面呢有兩個(gè)函數(shù)需要去優(yōu)化:和
。這兩個(gè)函數(shù)我們都用神經(jīng)網(wǎng)絡(luò)去建模,但是我們依然需要做一件事,就是去定義一個(gè)損失函數(shù)
。我們這樣思考一下,我們定義損失函數(shù),是為了能夠讓
和
更好。我們首先來(lái)考慮
,如何讓
變得更好?我們回憶一下,
是我們定義出來(lái)用來(lái)估計(jì)分布
的,最好的
當(dāng)然就是能夠跟
一毛一樣啦。那么我們很自然的就想要把目標(biāo)函數(shù),或者說(shuō)損失值定義成這兩個(gè)分布之間的差距了。而計(jì)算分布差距,最常用的metric之一就是KL 散度。所以,我們想要讓這個(gè)公式最小化:
有人可能要說(shuō)啦:這,怎么最小化?我們不知道這個(gè)是啥我們才估計(jì)的呀。沒(méi)錯(cuò),不過(guò)我們可以試著把這個(gè)公式變個(gè)形式,試試看:
$$
\begin{aligned}
KL(Q(z|X);||;P(z|X)) & =E(logQ(z|X)-logP(z|X))\
&=E(logQ(z|X))-E(logP(z|X))\
&=E(logQ(z|X))-E(log(\frac{P(X|z)P(z)}{P(X)}))\
&=E(logQ(z|X))-E(logP(X|z))-E(logP(z))+E(logP(X))\
&=E(logQ(z|X))-E(logP(X|z))-E(logP(z))+logP(X)\
&=KL(logQ(z|X)||P(z))-E(logP(X|z))+logP(X)
\end{aligned}
logP(X)-KL(Q(z|X);||;P(z|X))=E(logP(X|z))-KL(logQ(z|X)||P(z))
$$
上面這個(gè)公式的轉(zhuǎn)換過(guò)程中,沒(méi)有用到什么很特別的技巧,主要就是貝葉斯公式套進(jìn)去了一下,我就不多說(shuō)了。重點(diǎn)來(lái)看一下最終的公式形式,我們發(fā)現(xiàn)公式左邊,恰好是我們想要優(yōu)化的目標(biāo),當(dāng)我們讓左邊最大化的時(shí)候,不僅和
越來(lái)越接近,我們的
也越來(lái)越大,也就是我們的
函數(shù)的參數(shù)也越來(lái)越好。一石二鳥(niǎo)!本來(lái)我們只是在考慮
的問(wèn)題,現(xiàn)在連帶著把
的問(wèn)題也解決了。我們看一下右邊是我們能計(jì)算的東西嗎?答案是肯定的,第一項(xiàng)
是我們建模的函數(shù),直接可以得到結(jié)果。第二項(xiàng)
中的
是我們建模的函數(shù),
呢是隱變量的分布,我們可以把這個(gè)分布定義為一個(gè)標(biāo)準(zhǔn)正態(tài)分布
,因?yàn)閺倪@個(gè)標(biāo)準(zhǔn)正態(tài)分布,理論上講我們可以映射到任意的輸出空間上(當(dāng)然標(biāo)準(zhǔn)正態(tài)分布也只是一個(gè)選項(xiàng),很多別的分布都是可以的)。所以公式右邊的兩項(xiàng)都是可以求的,我們就可以把這個(gè)作為我們的優(yōu)化目標(biāo)。至此,我們的計(jì)算過(guò)程算是完善了,可以用下面這樣一張圖表示
我們首先有一個(gè)輸入
,對(duì)應(yīng)到我們的例子里就是一張手寫數(shù)字圖片。我們將這個(gè)
輸入到自己定義的函數(shù)
中,因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=Q" alt="Q" mathimg="1">是個(gè)正態(tài)分布函數(shù),所以我們的做法是用兩個(gè)神經(jīng)網(wǎng)絡(luò)(管他們叫encoder)分別去計(jì)算期望
和方差
。有了這個(gè)分布之后,我們就可以采樣出來(lái)一個(gè)隱變量的樣本
,然后用這個(gè)樣本在通過(guò)神經(jīng)網(wǎng)絡(luò)
(管他叫decoder)去生成新的數(shù)據(jù)樣本
。在這個(gè)計(jì)算過(guò)程中,我們?cè)趀ncoder里計(jì)算了一個(gè)損失函數(shù)
,在decoder里計(jì)算了一個(gè)損失函數(shù)
(在數(shù)據(jù)為正態(tài)分布的時(shí)候等價(jià)于
)。
上面描述的過(guò)程更完整了,不過(guò)還有一個(gè)問(wèn)題,就是中間采樣的那一步。我們知道,神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)依賴于梯度下降,梯度下降就要求整個(gè)損失函數(shù)的梯度鏈條是存在的,或者說(shuō)參數(shù)是可導(dǎo)的。我們看decoder的這個(gè)損失函數(shù),在計(jì)算,除了涉及分布
以外,還依賴于隱變量
,隱變量又是從分布
中采樣出來(lái)的, 所以我們?cè)趯?duì)decoder的損失函數(shù)進(jìn)行梯度下降的時(shí)候,是要對(duì)
的參數(shù)也梯度下降的。但是因?yàn)橹虚g這一步采樣,我們的梯度斷掉了。采樣還怎么知道是什么梯度?所以這里用到了一個(gè)小trick: Reparamterization。簡(jiǎn)單來(lái)說(shuō)就是我們不再是采樣了,而是看做按照分布的期望和方差,加上一些小噪音,生成出來(lái)的樣本。也就是說(shuō):
其中這個(gè)是一個(gè)隨機(jī)噪聲,我們可以從一個(gè)標(biāo)準(zhǔn)正態(tài)分布中采樣得到。
這種做法非常 好理解對(duì)吧,我們的每個(gè)樣本都可以看作是這個(gè)樣本服從分布的期望,根據(jù)方差進(jìn)行波動(dòng)的結(jié)果。通過(guò)這種變換,原來(lái)斷掉的梯度鏈恢復(fù)啦,我們的梯度下降終于能夠進(jìn)行下去了。修正后的計(jì)算過(guò)程如下圖。
以上就是VAE的主要內(nèi)容,這里額外說(shuō)明一點(diǎn):從decoder的損失函數(shù)看,我們希望Q在估計(jì)隱變量的分布,而隱變量的分布就是一個(gè)標(biāo)準(zhǔn)正態(tài)分布,所以我們?cè)趯?shí)際生成的過(guò)程中,不需要用到encoder,只需要從標(biāo)準(zhǔn)正態(tài)分布里隨便采樣一個(gè)隱變量就能進(jìn)行生成了。
與自編碼模型Autoencoder比較
很多人可能很熟悉自編碼模型,自編碼模型英文叫Autoencoder。而我們這個(gè)VAE呢,叫Variational Autoencoder,聽(tīng)起來(lái)好像關(guān)系很大,但是其實(shí)關(guān)系真的不是很大。只不過(guò)我們這個(gè)VAE呢也像Autoencoder一樣有一個(gè)encoder,一個(gè)decoder。但是Autoencoder對(duì)于隱變量沒(méi)有什么限制,它的過(guò)程就很簡(jiǎn)單,就是從輸入計(jì)算一個(gè)隱變量
,然后再把
映射到一個(gè)新的
,損失函數(shù)只有一個(gè),就是比較
和
計(jì)算一個(gè)重構(gòu)損失。這樣做呢并沒(méi)有很好的利用隱變量。但是它最大的缺點(diǎn)是生成出來(lái)的東西是和輸入的
高度相關(guān)的,并不能生成出來(lái)什么很新奇的玩意,所以Autoencoder一般只能用來(lái)做做降噪什么的,并不能真正用來(lái)做生成。但是VAE就不一樣了,前面我們講過(guò),VAE在生成階段是完全拋開(kāi)了encoder的,隱變量是從標(biāo)準(zhǔn)正態(tài)分布里隨隨便便采樣出來(lái)的,這就擺脫了對(duì)輸入的依賴,想怎么生成就怎么生成。
代碼實(shí)現(xiàn)
"""
@Time : 26/01/2023 @Software: PyCharm @File : model.py
"""
import torchvision
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
import torch.nn.functional as F
class Encoder(nn.Module):
def __init__(self, hidden_dim=2):
super(Encoder, self).__init__()
self.linear1 = nn.Linear(28 * 28, 512)
self.linear2 = nn.Linear(512, hidden_dim)
def forward(self, x):
"""x:[N,1,28,28]"""
x = torch.flatten(x, start_dim=1) # [N,764]
x = self.linear1(x) # [N, 512]
x = F.relu(x)
return self.linear2(x) # [N,2]
class Decoder(nn.Module):
def __init__(self, hidden_dim):
super(Decoder, self).__init__()
self.linear1 = nn.Linear(hidden_dim, 512)
self.linear2 = nn.Linear(512, 28 * 28)
def forward(self, x):
"""x:[N,2]"""
hidden = self.linear1(x) # [N, 512]
hidden = torch.relu(hidden)
hidden = self.linear2(hidden) # [N,764]
hidden = torch.sigmoid(hidden)
return torch.reshape(hidden, (-1, 1, 28, 28))
class AutoEncoder(nn.Module):
def __init__(self, hidden_dim=2):
super(AutoEncoder, self).__init__()
self.name = "ae"
self.encoder = Encoder(hidden_dim)
self.decoder = Decoder(hidden_dim)
def forward(self, x):
return self.decoder(self.encoder(x))
class VAEEncoder(nn.Module):
def __init__(self, hidden_dim=2):
super(VAEEncoder, self).__init__()
self.linear1 = nn.Linear(28 * 28, 512)
self.linear2 = nn.Linear(512, hidden_dim)
self.linear3 = nn.Linear(512, hidden_dim)
self.noise_dist = torch.distributions.Normal(0, 1)
self.kl = 0
def forward(self, x):
x = torch.flatten(x, start_dim=1)
x = self.linear1(x)
x = torch.relu(x)
mu = self.linear2(x)
sigma = torch.exp(self.linear3(x))
hidden = mu + self.noise_dist.sample(mu.shape) * sigma
self.kl = (sigma ** 2 + mu ** 2 - torch.log(sigma) - 1 / 2).sum()
return hidden
class VAE(nn.Module):
def __init__(self, hidden_dim=2):
super(VAE, self).__init__()
self.name = "vae"
self.encoder = VAEEncoder(hidden_dim=hidden_dim)
self.decoder = Decoder(hidden_dim)
self.kl = 0
def forward(self, x):
hidden = self.encoder(x)
self.kl = self.encoder.kl
return self.decoder(hidden)
if __name__ == '__main__':
dataset = torchvision.datasets.MNIST("data", transform=torchvision.transforms.ToTensor(), download=True)
print(dataset[0][0].shape)
核心的VAE 模型代碼實(shí)現(xiàn)我貼在這里,其余代碼已經(jīng)上傳到Github。