VAE教程

好久沒(méi)看的VAE又不太記得了,重新梳理一下思路,在S同學(xué)的指導(dǎo)下,又有了一些新的理解。之前寫過(guò)一篇關(guān)于VAE的入門教程,但是感覺(jué)還不夠簡(jiǎn)練,刪掉重新寫一個(gè)哈哈哈。這篇文章主要是從一個(gè)熟悉machine learning但是對(duì)于VAE一點(diǎn)都不懂的視角,進(jìn)行寫作的,并不涉及很多復(fù)雜的理論知識(shí),輔助理解為主。另外,上次那篇文章是先給結(jié)論,在慢慢講細(xì)節(jié)。這篇文章將會(huì)循序漸進(jìn)逐漸推導(dǎo)出VAE的各種Trick存在的必要性。

VAE 入門

我們首先要明確一點(diǎn),VAE是一個(gè)生成式的模型,什么是生成式的模型?簡(jiǎn)單來(lái)說(shuō),就是可以用來(lái)生成數(shù)據(jù)的模型。怎么樣才能生成數(shù)據(jù)呢?就是我們是知道數(shù)據(jù)的分布P(X)的?有了這個(gè)分布之后,我們就可以從這個(gè)分布中采樣,獲得新的數(shù)據(jù)。

這個(gè)思路好像很簡(jiǎn)單啊,但是問(wèn)題是這個(gè)P(X)是怎么得到的。有很多方法啊,其中包含這樣兩大類:1. 基于概率的,如MCMC, Variational Inference等。以及2. 基于機(jī)器學(xué)習(xí)的。前面提到過(guò),本文主要面向的是對(duì)于機(jī)器學(xué)習(xí)比較熟悉的人,所以這里就對(duì)概率方法不多說(shuō)了,主要講一下機(jī)器學(xué)習(xí)的思路。

其實(shí)機(jī)器學(xué)習(xí)來(lái)解決這種問(wèn)題的思路是很清晰的,大多數(shù)的機(jī)器學(xué)習(xí)問(wèn)題都有這樣一個(gè)思路。我們想要優(yōu)化某個(gè)目標(biāo)O,我們先對(duì)這個(gè)問(wèn)題建個(gè)模型,模型可以表示為某個(gè)數(shù)學(xué)表達(dá)式f(x;\theta) ,其中\theta是參數(shù)。我們用數(shù)據(jù)去訓(xùn)練這個(gè)模型,然后根據(jù)目標(biāo)O,去調(diào)整我們的參數(shù)\theta,我們希望訓(xùn)練結(jié)束的時(shí)候,能夠找到一組最優(yōu)的\theta。對(duì)應(yīng)到我們這個(gè)生成式的問(wèn)題,我們希望能夠生成一個(gè)新的數(shù)據(jù)X,那么我們構(gòu)造一個(gè)模型f(x;\theta) ,我們的目標(biāo)呢就是這個(gè)生成的數(shù)據(jù)越真越好,就是在眾多的\theta中,我們希望能夠找到一個(gè)最好的\theta,能夠讓這個(gè)數(shù)據(jù)存在的概率P(X;\theta)越大越好。有了目標(biāo),我們就能計(jì)算出損失函數(shù)L,然后就是利用梯度下降,逐步調(diào)整參數(shù),最終找到最優(yōu)。

前文這個(gè)過(guò)程好像很熟悉,但是存在幾個(gè)問(wèn)題:

  1. 我們建模f(x;\theta)其實(shí)是根據(jù)我們的assumption來(lái)的,我們的模型結(jié)構(gòu),初始參數(shù)設(shè)置都是根據(jù)我們的assumption來(lái)的, 但是我們的assumption有可能是錯(cuò)的,而且很有可能是錯(cuò)的。因此引入過(guò)多或者過(guò)強(qiáng)的assumption都會(huì)導(dǎo)致我們的模型效果很差。
  2. 因?yàn)槲覀兊腶ssumption和真實(shí)數(shù)據(jù)分布存在偏差,相應(yīng)的,我們?cè)趦?yōu)化的過(guò)程中,很容易陷入到局部最優(yōu)中。
  3. 如果我們直接采用建模的方式來(lái)解決生成式問(wèn)題,那么我們通常需要構(gòu)造一個(gè)相對(duì)復(fù)雜的模型,或者說(shuō)參數(shù)很多的模型,來(lái)獲得較大的Capacity。這樣就導(dǎo)致我們優(yōu)化的過(guò)程非常的耗時(shí)(Computationally Expensive)。

上述三個(gè)問(wèn)題的存在,讓我們對(duì)直接建模這個(gè)思路產(chǎn)生了動(dòng)搖,至少直接建模并不適用于所有的場(chǎng)景。所以我們?cè)谥苯咏5幕A(chǔ)上做出修改,引入了隱變量的概念。我們認(rèn)為數(shù)據(jù)的生成是受到隱變量z的影響的。比如手寫數(shù)字生成的任務(wù),我們?cè)谏蓴?shù)字的時(shí)候,會(huì)首先考慮,我們要生成的是數(shù)字幾啊,因?yàn)槲覀冎挥?0種數(shù)字可以生成,這個(gè)數(shù)字幾就是我們的隱變量z。有了這個(gè)隱變量,我們就不再是漫天生成數(shù)字了,我們只有10個(gè)方向去生成,這大大的縮小了我們的生成空間,降低了計(jì)算量。

Pasted image 20230127200838.png

引入隱變量,用數(shù)學(xué)公式可以表示為:
P(X)=\int P(X|z)P(z)dz=E_z(X|z)
模型的學(xué)習(xí)過(guò)程也因?yàn)殡[變量的引入發(fā)生了改變。我們最終的目標(biāo)還是要計(jì)算P(X),我們對(duì)隱變量的概率建個(gè)模P(X|z)=f(X,z;\theta),參數(shù)是\theta。我們要找到能讓P(X)值最大的參數(shù)\theta。但是現(xiàn)在是要把隱變量的所有可能取值都找到,然后求一個(gè)上面這樣的積分,來(lái)確定P(X)的值,從而通過(guò)比較不同的\theta對(duì)應(yīng)的P(X)的值來(lái)確定哪個(gè)\theta才是最合適的。是不是覺(jué)得天衣無(wú)縫?對(duì)于手寫數(shù)字來(lái)說(shuō),我們的隱變量的取值只有10個(gè),所以這個(gè)積分就退化成了只有10項(xiàng)的求和。但是對(duì)于很多其他問(wèn)題,這個(gè)隱變量的取值就有可能變得非常多,又不太好做了。所以我們的做法是不去計(jì)算積分了,我們做了一步近似。我們就采樣一個(gè)隱變量,我們希望挑出來(lái)的參數(shù)\theta能夠在這一個(gè)隱變量上表現(xiàn)好就行了。What?這樣的近似是不是差的太多?這樣真的呆膠布嗎?答案是肯定的。我們雖然用單個(gè)變量代替了積分,或者說(shuō),代替了期望值,但是我們機(jī)器學(xué)習(xí)的過(guò)程是在不斷的迭代的隨機(jī)過(guò)程(Stochastic process)。簡(jiǎn)單來(lái)說(shuō),就是我們會(huì)找一個(gè)又一個(gè)的sample,重復(fù)的進(jìn)行優(yōu)化,理論上講依然能夠得到最優(yōu)解(可以參考機(jī)器學(xué)習(xí)的學(xué)習(xí)理論)。總而言之,就是這樣用單樣本代替期望是可行的。

這個(gè)過(guò)程是不是聽(tīng)起來(lái)又是很合理?但是存在一個(gè)問(wèn)題,我們說(shuō)想要去采樣隱變量z,但是從什么分布里采樣?從P(z)?可以嗎?可以。但是我們?cè)偎伎家粋€(gè)問(wèn)題,真的所有的隱變量都是平等的嗎?回到手寫數(shù)字的例子,如果我們現(xiàn)在要生成的是數(shù)字7,那么隱變量如果是0,8這種帶圓圈的概率是不是不大。假如我們采樣到了一個(gè)隱變量代表的是數(shù)字0,那么這一次采樣是不是相當(dāng)于浪費(fèi)了?你本來(lái)就不怎么能指導(dǎo)我做這一次生成呀。所以,為了減少這樣的無(wú)效采樣,從而進(jìn)一步的降低計(jì)算量,我們并不是從P(X)中采樣的,而是從P(z|X)中采樣的。
z=sample(P(z|X))
好了,現(xiàn)在又有了一個(gè)新問(wèn)題,這個(gè)P(z|X)我們知道嘛?答案是不知道,不知道怎么辦?不知道那就去求?用什么樣的方法去求?用機(jī)器學(xué)習(xí)的去求,和前面對(duì)P(X|z)建模一樣,我們?cè)谶@里對(duì)P(z|X)建模為Q(z|X),然后求個(gè)Loss,再梯度下降去優(yōu)化它。常見(jiàn)的做法是把Q建模為一個(gè)正態(tài)分布:
Q(z|X)=N(\mu,\sigma^2I)
講到這里VAE的主體框架已經(jīng)出來(lái)了:

我們用Q采樣出來(lái)一個(gè)隱變量z,然后我們根據(jù)這個(gè)隱變量,利用P(X|z)生成新的圖片\hat{x}。我們從這個(gè)過(guò)程中,計(jì)算損失值L,通過(guò)梯度下降的方式,不斷優(yōu)化QP(X|z)的參數(shù),從而我們能夠生成越來(lái)越好的圖片

這個(gè)框架到目前為止已經(jīng)可以說(shuō)是相對(duì)很完整了,里面呢有兩個(gè)函數(shù)需要去優(yōu)化:QP(X|z)。這兩個(gè)函數(shù)我們都用神經(jīng)網(wǎng)絡(luò)去建模,但是我們依然需要做一件事,就是去定義一個(gè)損失函數(shù)L。我們這樣思考一下,我們定義損失函數(shù),是為了能夠讓QP(X|z)更好。我們首先來(lái)考慮Q,如何讓Q變得更好?我們回憶一下,Q是我們定義出來(lái)用來(lái)估計(jì)分布P(z|X)的,最好的Q當(dāng)然就是能夠跟P(z|X)一毛一樣啦。那么我們很自然的就想要把目標(biāo)函數(shù),或者說(shuō)損失值定義成這兩個(gè)分布之間的差距了。而計(jì)算分布差距,最常用的metric之一就是KL 散度。所以,我們想要讓這個(gè)公式最小化:
KL(Q(z|X)\;||\;P(z|X))
有人可能要說(shuō)啦:這,怎么最小化?我們不知道這個(gè)P(z|X)是啥我們才估計(jì)的呀。沒(méi)錯(cuò),不過(guò)我們可以試著把這個(gè)公式變個(gè)形式,試試看:
$$
\begin{aligned}
KL(Q(z|X);||;P(z|X)) & =E(logQ(z|X)-logP(z|X))\
&=E(logQ(z|X))-E(logP(z|X))\
&=E(logQ(z|X))-E(log(\frac{P(X|z)P(z)}{P(X)}))\
&=E(logQ(z|X))-E(logP(X|z))-E(logP(z))+E(logP(X))\
&=E(logQ(z|X))-E(logP(X|z))-E(logP(z))+logP(X)\
&=KL(logQ(z|X)||P(z))-E(logP(X|z))+logP(X)

\end{aligned}
交換一下公式的項(xiàng),我們得到:
logP(X)-KL(Q(z|X);||;P(z|X))=E(logP(X|z))-KL(logQ(z|X)||P(z))
$$
上面這個(gè)公式的轉(zhuǎn)換過(guò)程中,沒(méi)有用到什么很特別的技巧,主要就是貝葉斯公式套進(jìn)去了一下,我就不多說(shuō)了。重點(diǎn)來(lái)看一下最終的公式形式,我們發(fā)現(xiàn)公式左邊,恰好是我們想要優(yōu)化的目標(biāo),當(dāng)我們讓左邊最大化的時(shí)候,不僅QP(z|X)越來(lái)越接近,我們的P(X)也越來(lái)越大,也就是我們的P(X|z)函數(shù)的參數(shù)也越來(lái)越好。一石二鳥(niǎo)!本來(lái)我們只是在考慮Q的問(wèn)題,現(xiàn)在連帶著把P(X|z)的問(wèn)題也解決了。我們看一下右邊是我們能計(jì)算的東西嗎?答案是肯定的,第一項(xiàng)E(logP(X|z))是我們建模的函數(shù),直接可以得到結(jié)果。第二項(xiàng)KL(logQ(z|X)||P(z))中的Q是我們建模的函數(shù),P(z)呢是隱變量的分布,我們可以把這個(gè)分布定義為一個(gè)標(biāo)準(zhǔn)正態(tài)分布N(0,I),因?yàn)閺倪@個(gè)標(biāo)準(zhǔn)正態(tài)分布,理論上講我們可以映射到任意的輸出空間上(當(dāng)然標(biāo)準(zhǔn)正態(tài)分布也只是一個(gè)選項(xiàng),很多別的分布都是可以的)。所以公式右邊的兩項(xiàng)都是可以求的,我們就可以把這個(gè)作為我們的優(yōu)化目標(biāo)。至此,我們的計(jì)算過(guò)程算是完善了,可以用下面這樣一張圖表示

Pasted image 20230127204653.png

我們首先有一個(gè)輸入X,對(duì)應(yīng)到我們的例子里就是一張手寫數(shù)字圖片。我們將這個(gè)X輸入到自己定義的函數(shù)Q中,因?yàn)?img class="math-inline" src="https://math.jianshu.com/math?formula=Q" alt="Q" mathimg="1">是個(gè)正態(tài)分布函數(shù),所以我們的做法是用兩個(gè)神經(jīng)網(wǎng)絡(luò)(管他們叫encoder)分別去計(jì)算期望\mu和方差\sigma^2 。有了這個(gè)分布之后,我們就可以采樣出來(lái)一個(gè)隱變量的樣本z,然后用這個(gè)樣本在通過(guò)神經(jīng)網(wǎng)絡(luò)P(X|z) (管他叫decoder)去生成新的數(shù)據(jù)樣本f(z)。在這個(gè)計(jì)算過(guò)程中,我們?cè)趀ncoder里計(jì)算了一個(gè)損失函數(shù)KL(Q||P(z)),在decoder里計(jì)算了一個(gè)損失函數(shù)||x-f(z)||^2(在數(shù)據(jù)為正態(tài)分布的時(shí)候等價(jià)于-logP(X|z))。

上面描述的過(guò)程更完整了,不過(guò)還有一個(gè)問(wèn)題,就是中間采樣的那一步。我們知道,神經(jīng)網(wǎng)絡(luò)的學(xué)習(xí)依賴于梯度下降,梯度下降就要求整個(gè)損失函數(shù)的梯度鏈條是存在的,或者說(shuō)參數(shù)是可導(dǎo)的。我們看decoder的這個(gè)損失函數(shù),在計(jì)算||x-f(z)||^2,除了涉及分布P(X|z)以外,還依賴于隱變量z,隱變量又是從分布Q中采樣出來(lái)的, 所以我們?cè)趯?duì)decoder的損失函數(shù)進(jìn)行梯度下降的時(shí)候,是要對(duì)Q的參數(shù)也梯度下降的。但是因?yàn)橹虚g這一步采樣,我們的梯度斷掉了。采樣還怎么知道是什么梯度?所以這里用到了一個(gè)小trick: Reparamterization。簡(jiǎn)單來(lái)說(shuō)就是我們不再是采樣了,而是看做按照分布的期望和方差,加上一些小噪音,生成出來(lái)的樣本。也就是說(shuō):
z=\mu+e\sigma
其中這個(gè)e是一個(gè)隨機(jī)噪聲,我們可以從一個(gè)標(biāo)準(zhǔn)正態(tài)分布中采樣得到。
e\sim N(0,1)
這種做法非常 好理解對(duì)吧,我們的每個(gè)樣本都可以看作是這個(gè)樣本服從分布的期望,根據(jù)方差進(jìn)行波動(dòng)的結(jié)果。通過(guò)這種變換,原來(lái)斷掉的梯度鏈恢復(fù)啦,我們的梯度下降終于能夠進(jìn)行下去了。修正后的計(jì)算過(guò)程如下圖。

Pasted image 20230128180833.png

以上就是VAE的主要內(nèi)容,這里額外說(shuō)明一點(diǎn):從decoder的損失函數(shù)看,我們希望Q在估計(jì)隱變量的分布,而隱變量的分布就是一個(gè)標(biāo)準(zhǔn)正態(tài)分布,所以我們?cè)趯?shí)際生成的過(guò)程中,不需要用到encoder,只需要從標(biāo)準(zhǔn)正態(tài)分布里隨便采樣一個(gè)隱變量z就能進(jìn)行生成了。

與自編碼模型Autoencoder比較

很多人可能很熟悉自編碼模型,自編碼模型英文叫Autoencoder。而我們這個(gè)VAE呢,叫Variational Autoencoder,聽(tīng)起來(lái)好像關(guān)系很大,但是其實(shí)關(guān)系真的不是很大。只不過(guò)我們這個(gè)VAE呢也像Autoencoder一樣有一個(gè)encoder,一個(gè)decoder。但是Autoencoder對(duì)于隱變量沒(méi)有什么限制,它的過(guò)程就很簡(jiǎn)單,就是從輸入x計(jì)算一個(gè)隱變量z,然后再把z映射到一個(gè)新的\hat{x},損失函數(shù)只有一個(gè),就是比較x\hat{x}計(jì)算一個(gè)重構(gòu)損失。這樣做呢并沒(méi)有很好的利用隱變量。但是它最大的缺點(diǎn)是生成出來(lái)的東西是和輸入的x高度相關(guān)的,并不能生成出來(lái)什么很新奇的玩意,所以Autoencoder一般只能用來(lái)做做降噪什么的,并不能真正用來(lái)做生成。但是VAE就不一樣了,前面我們講過(guò),VAE在生成階段是完全拋開(kāi)了encoder的,隱變量是從標(biāo)準(zhǔn)正態(tài)分布里隨隨便便采樣出來(lái)的,這就擺脫了對(duì)輸入的依賴,想怎么生成就怎么生成。

代碼實(shí)現(xiàn)

"""  
    @Time    : 26/01/2023    @Software: PyCharm    @File    : model.py  
"""  
import torchvision  
import matplotlib.pyplot as plt  
import torch.nn as nn  
import torch  
import torch.nn.functional as F  
  
  
class Encoder(nn.Module):  
    def __init__(self, hidden_dim=2):  
        super(Encoder, self).__init__()  
        self.linear1 = nn.Linear(28 * 28, 512)  
        self.linear2 = nn.Linear(512, hidden_dim)  
  
    def forward(self, x):  
        """x:[N,1,28,28]"""  
        x = torch.flatten(x, start_dim=1)  # [N,764]  
        x = self.linear1(x)  # [N, 512]  
        x = F.relu(x)  
        return self.linear2(x)  # [N,2]  
  
  
class Decoder(nn.Module):  
    def __init__(self, hidden_dim):  
        super(Decoder, self).__init__()  
        self.linear1 = nn.Linear(hidden_dim, 512)  
        self.linear2 = nn.Linear(512, 28 * 28)  
  
    def forward(self, x):  
        """x:[N,2]"""  
        hidden = self.linear1(x)  # [N, 512]  
        hidden = torch.relu(hidden)  
        hidden = self.linear2(hidden)  # [N,764]  
        hidden = torch.sigmoid(hidden)  
        return torch.reshape(hidden, (-1, 1, 28, 28))  
  
  
class AutoEncoder(nn.Module):  
    def __init__(self, hidden_dim=2):  
        super(AutoEncoder, self).__init__()  
        self.name = "ae"  
        self.encoder = Encoder(hidden_dim)  
        self.decoder = Decoder(hidden_dim)  
  
    def forward(self, x):  
        return self.decoder(self.encoder(x))  
  
  
class VAEEncoder(nn.Module):  
    def __init__(self, hidden_dim=2):  
        super(VAEEncoder, self).__init__()  
        self.linear1 = nn.Linear(28 * 28, 512)  
        self.linear2 = nn.Linear(512, hidden_dim)  
        self.linear3 = nn.Linear(512, hidden_dim)  
        self.noise_dist = torch.distributions.Normal(0, 1)  
        self.kl = 0  
  
    def forward(self, x):  
        x = torch.flatten(x, start_dim=1)  
        x = self.linear1(x)  
        x = torch.relu(x)  
        mu = self.linear2(x)  
        sigma = torch.exp(self.linear3(x))  
        hidden = mu + self.noise_dist.sample(mu.shape) * sigma  
        self.kl = (sigma ** 2 + mu ** 2 - torch.log(sigma) - 1 / 2).sum()  
        return hidden  
  
  
class VAE(nn.Module):  
    def __init__(self, hidden_dim=2):  
        super(VAE, self).__init__()  
        self.name = "vae"  
        self.encoder = VAEEncoder(hidden_dim=hidden_dim)  
        self.decoder = Decoder(hidden_dim)  
        self.kl = 0  
  
    def forward(self, x):  
        hidden = self.encoder(x)  
        self.kl = self.encoder.kl  
        return self.decoder(hidden)  
  
  
if __name__ == '__main__':  
    dataset = torchvision.datasets.MNIST("data", transform=torchvision.transforms.ToTensor(), download=True)  
    print(dataset[0][0].shape)

核心的VAE 模型代碼實(shí)現(xiàn)我貼在這里,其余代碼已經(jīng)上傳到Github

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

推薦閱讀更多精彩內(nèi)容