對(duì) BEGAN 損失函數(shù) 的理解

具體的對(duì)BEGAN的原理和特點(diǎn)的講解,就不展開(kāi)了,具體可以參考這兩篇文章

深度學(xué)習(xí)【43】BEGAN
BEGAN解讀

但是其對(duì)BEGAN損失函數(shù)的解釋?zhuān)矣X(jué)得有點(diǎn)太理論化,不太好理解,以下是對(duì)這個(gè)對(duì)抗生成網(wǎng)絡(luò)的損失函數(shù)的理解

BEGAN生成的圖片質(zhì)量和多樣性有了很大的提升

GAN的損失函數(shù)

首先看一下CGAN的判別器和生成器的損失函數(shù)

CGAN中,生成器是通過(guò)一些條件,來(lái)生成需要的圖像,而判別器是通過(guò)輸入真假圖片和條件,來(lái)輸出判斷的結(jié)果

gen_label = generator(image=train_picture, gf_dim=64, reuse=False, name='generator')  # 得到生成器的輸出
dis_real = discriminator(image=train_picture, targets=train_label, df_dim=64, reuse=False, name="discriminator")  # 判別器返回的對(duì)真實(shí)標(biāo)簽的判別結(jié)果
dis_fake = discriminator(image=train_picture, targets=gen_label, df_dim=64, reuse=True, name="discriminator")  # 判別器返回的對(duì)生成(虛假的)標(biāo)簽判別結(jié)果

gen_loss = tf.reduce_mean(-tf.log(dis_fake))  # 計(jì)算生成器的loss
dis_loss = tf.reduce_mean(-(tf.log(dis_real) + tf.log(1 - dis_fake)))  # 計(jì)算判別器的loss

可以看到,判別器的目的是為了讓真的圖片的判別結(jié)果盡量靠近1,讓生成的假的圖像的判別結(jié)果盡量靠近0

而生成器的目的是為了讓假的圖像盡量像真的,也就是讓假的圖片的判別結(jié)果盡量靠近1

BEGAN的結(jié)構(gòu)

BEGAN中,判別器并不是直觀的輸出判別的結(jié)果,而是一個(gè)自編碼器,用圖表示的話,如下圖

自編碼器也就是把一個(gè)圖片壓縮成一個(gè)小的向量,再把這個(gè)小的向量重新解壓成一張圖片。

那為什么要進(jìn)行這個(gè)過(guò)程呢,因?yàn)樵诓粩嘤?xùn)練的這個(gè)過(guò)程中,這個(gè)自編碼器就相當(dāng)于變成了一個(gè)特征提取器。

形象的說(shuō),在這個(gè)過(guò)程中訓(xùn)練了一個(gè)壓縮和解壓的方法。這個(gè)自編碼器(特征提取器)會(huì)根據(jù)圖片的特征進(jìn)行壓縮和解壓

  • 假如輸入了一張人臉的圖片,那在壓縮時(shí),會(huì)保存一些重要的特征,比如五官和發(fā)型。將這些特征抽象出來(lái)保存到一個(gè)小的向量中,再進(jìn)行解壓。解壓的過(guò)程也是一樣,把這些抽象的人臉特征再表達(dá)出來(lái)

  • 但是假如輸入了一張亂七八糟的非人臉圖。那么在壓縮的時(shí)候,找不到人臉的有效特征,就會(huì)亂七八糟的壓縮,得到的向量也沒(méi)有什么價(jià)值,最后解壓出來(lái)的結(jié)果也不會(huì)和原來(lái)輸入的圖像有什么關(guān)聯(lián)

而生成器就相當(dāng)于用了半個(gè)判別器,它只用到了解壓的過(guò)程,也就是把一個(gè)帶有特征的向量,生成一個(gè)人臉數(shù)據(jù),自己畫(huà)了一個(gè)簡(jiǎn)單的圖

BEGAN的簡(jiǎn)單結(jié)構(gòu)圖

(其實(shí)我并不喜歡把這個(gè)結(jié)構(gòu)看成是生成器-判別器的結(jié)構(gòu),而更喜歡看作是生成器-特征提取器的結(jié)構(gòu))

BEGAN的損失函數(shù)

一般生成對(duì)抗網(wǎng)絡(luò)的兩個(gè)結(jié)構(gòu)(生成器、判別器)的損失函數(shù)都直接來(lái)源于判別器的輸出結(jié)果,那現(xiàn)在這個(gè)判別器是一個(gè)特征提取器時(shí),要怎么設(shè)計(jì)損失函數(shù)呢?

首先,一開(kāi)始特征提取器的效果并不好,還學(xué)不到完全的人臉特征。所以我們需要讓特征提取器盡量準(zhǔn)確的提取特征,也就是讓真圖像經(jīng)過(guò)特征提取器的壓縮和解壓之后,盡可能的形成和原圖相似的圖片(使用L1_loss)

于是有了第一個(gè)損失函數(shù),下面的input_real是真實(shí)的圖像,d_model_real是真實(shí)圖像經(jīng)過(guò)判別器(特征提取器)重構(gòu)后的圖像

d_real = tf.reduce_mean(tf.abs(input_real - d_model_real))  # 使用了L1_loss
d_loss = d_real

這個(gè)d_loss就是判別器(特征提取器)的損失函數(shù)原型。通過(guò)讓重構(gòu)后的圖像和原圖像盡量相似,來(lái)讓特征提取器能夠更加準(zhǔn)確的提取人臉特征并進(jìn)行重構(gòu)。

因?yàn)檫@里是為了訓(xùn)練提取人臉特征的能力,所以我們主要是應(yīng)用了真實(shí)人臉的重構(gòu)數(shù)據(jù),因?yàn)橐婚_(kāi)始假的人臉還非常假,所以其重構(gòu)的結(jié)構(gòu)不能作為判斷條件

那生成器的Loss呢?

在論文中,生成器的Loss是下述代碼中的g_loss,其中

d_fake = tf.reduce_mean(tf.abs(g_model_fake - d_model_fake))
g_loss = d_fake

很明顯,它是通過(guò)讓“人造人臉” 經(jīng)過(guò)特征提取器后 重構(gòu)的人臉 更加接近于 原來(lái)輸入的“人造人臉”

因?yàn)檫@個(gè)特征提取器是一個(gè)“人臉特征提取器”,它可以很好的把人臉中的特征提取出來(lái),壓縮再解壓,進(jìn)行還原。假如現(xiàn)在有一張圖像,通過(guò)這個(gè)特征提取器后,重構(gòu)的圖像和原圖像很類(lèi)似,那這個(gè)圖像一定就是人臉圖像!這也就是為什么可以用上述損失函數(shù)來(lái)作為生成器的損失函數(shù),讓生成器生成的人臉越來(lái)越逼真

但是論文中,還有另一個(gè)參數(shù),叫k_t

剛才上面的判別器的損失函數(shù),只是原型,其完整的損失函數(shù)如下

d_real = tf.reduce_mean(tf.abs(input_real - d_model_real))
d_fake = tf.reduce_mean(tf.abs(g_model_fake - d_model_fake))
d_loss = d_real - k_t * d_fake

也就是說(shuō),判別器不僅僅是用到了真實(shí)圖像的重構(gòu)結(jié)果,也用到了假的圖像的重構(gòu)結(jié)果

假的圖像一開(kāi)始生成的肯定不是人臉,為什么可以把它作為特征提取器的訓(xùn)練內(nèi)容呢?

在這里,這個(gè)k_t的變量是不斷變化的,其一開(kāi)始是一個(gè)很小的小數(shù),比如0.0001,所以在一開(kāi)始生成器的結(jié)果還不是很好時(shí),幾乎就是真實(shí)圖像在起作用。當(dāng)訓(xùn)練了一段時(shí)間之后,k_t就會(huì)慢慢增加,差不多到2000次時(shí),k_t增加到了0.04

也就是說(shuō),在訓(xùn)練了一段時(shí)間之后,隨著生成的人臉圖像越來(lái)越好。人造的人臉圖像也會(huì)被加入到特征提取器的訓(xùn)練當(dāng)中。

那為什么,是d_real - k_t * d_fake 而不是 d_real + k_t * d_fake

如果是用加號(hào),那么我們是希望d_reald_fake都越小越好。也就是說(shuō)我們希望真假圖像在重構(gòu)后,和原圖一模一樣

但是實(shí)際上這并不是我們想要的,而且這種訓(xùn)練方式會(huì)很難,在將一張圖片進(jìn)行壓縮后,必定會(huì)損失某些信息,完全無(wú)損的還原,幾乎是不可能的

所以在使用上述的——“讓真圖的重構(gòu)結(jié)果和真圖更相近的方法”訓(xùn)練判別器(特征提取器)時(shí),在收斂到一定的層次之后就無(wú)法繼續(xù)收斂

我們需要的是:當(dāng)真實(shí)人臉經(jīng)過(guò) 提取特征重構(gòu) 之后,在眉頭增加了一顆痣,那么這個(gè)重構(gòu)誤差就是這顆痣;當(dāng)生成的人臉經(jīng)過(guò) 提取特征重構(gòu) 之后,也在眉頭增加了一顆痣,那么重構(gòu)誤差也是一顆痣,他們的重構(gòu)誤差是接近的

所以當(dāng)出現(xiàn)上述這種情況時(shí),這個(gè)特征提取器就是一個(gè)非常好的特征提取器,也就是讓真假圖像的 重構(gòu)誤差的分布相似,以訓(xùn)練整個(gè)特征提取器提取的特征越來(lái)越好

判別器的損失函數(shù)總結(jié)

生成器在一開(kāi)始生成圖像的質(zhì)量并不高時(shí),k_t值很小,可以當(dāng)做只是在讓真實(shí)人臉重構(gòu)后盡量像真實(shí)人臉,以達(dá)到快速收斂的效果。在收斂到一定程度之后,k_t的值已經(jīng)是一個(gè)不可忽略的值了,而這時(shí)候僅適用真實(shí)圖像重構(gòu)的相似度也已經(jīng)很難優(yōu)化了,所以使用兩個(gè)重構(gòu)的誤差來(lái)進(jìn)行優(yōu)化,達(dá)到進(jìn)一步收斂的效果

所以,生成器和判別器的損失函數(shù)以及具體k_t值的變化過(guò)程如下

其中:
L(x) = L(x, D) = L1\underline{ }loss(x-D(x))
是真實(shí)圖像的重構(gòu)誤差
L(G(z_D)) = L(G(z_D), D) = L1\underline{ }loss(G(z_D)-D(G(z_D)))
是生成的圖像的重構(gòu)誤差

另一種理解方式

另一種理解方式,假設(shè)真假圖像經(jīng)過(guò)判別器(特征提取器)重構(gòu)后得到的和原圖的誤差分布分別為A'和B'

那么判別器D的目的就是讓A'和B'的盡可能的相差大一點(diǎn),也就是讓特征提取器可以更好的分辨真假圖像

然后生成器G的目的則是讓假圖像經(jīng)過(guò)判別器 與原圖的誤差分布B'盡可能小,也就是讓A'和B'盡可能相差的小一點(diǎn)

在論文中證明了A'和B'都是滿足正態(tài)分布的,所以A'和B'的距離可以用
Wasserstein距離來(lái)得到


大概的概括一下原理就是:根據(jù)Wasserstein距離,來(lái)匹配自編碼器的損失分布。并采用神經(jīng)網(wǎng)絡(luò)結(jié)構(gòu),在訓(xùn)練中添加額外的均衡過(guò)程,來(lái)平衡生成器和判別器(具體見(jiàn)論文和代碼)

當(dāng)然以上內(nèi)容只是個(gè)人對(duì)該網(wǎng)絡(luò)結(jié)構(gòu)以及損失函數(shù)的理解而已,具體的推導(dǎo)過(guò)程可以見(jiàn)其他闡述原理的文章,或者是直接看 BEGAN論文

有一篇文章也不錯(cuò) BEGAN論文閱讀

最后編輯于
?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

推薦閱讀更多精彩內(nèi)容