本文翻譯自這個博客,已獲得作者授權(quán),翻譯的不好還請指教。
生成對抗網(wǎng)絡(luò)(Generative Adversarial Nets)是一種非常流行的神經(jīng)網(wǎng)絡(luò)。它首先由Ian Goodfellow在2014年NIPS大會上發(fā)表。它點燃了神經(jīng)網(wǎng)絡(luò)中對抗學(xué)習(xí)的興趣,這可以從論文被引用的次數(shù)中證明。一時之間,冒出來許多不同的GAN:DCGAN,Sequence-GAN,LSTM-GAN等。在NIPS 2016,甚至將會有一整場的對抗學(xué)習(xí)討論會。
現(xiàn)在代碼可以從https://github.com/wiseodd/generative-models得到。
首先讓我們回顧一下論文的要點。之后我們會用TensorFlow來實現(xiàn)GAN,數(shù)據(jù)集是MNIST。
Generative Adversarial Net
讓我們舉一個假幣制造商和警察的例子。假幣制造商和警察各自的目標(biāo)是什么呢?
- 一個成功的假幣制造商會想盡方法騙過警察,使警察分不清假幣與真幣。
- 一個合格的警察會盡力分辨出假幣和真幣。
現(xiàn)在就產(chǎn)生了沖突。這種情況可以認(rèn)為是博弈論中的最大最小游戲。這個過程稱為對抗過程。
GAN是對抗過程的一個特例,它的組成(警察和假幣制造商)是神經(jīng)網(wǎng)絡(luò)。第一個網(wǎng)絡(luò)試圖生成數(shù)據(jù),第二個網(wǎng)絡(luò)試圖分辨出真實數(shù)據(jù)和第一個網(wǎng)絡(luò)生成的偽造數(shù)據(jù)。第二個網(wǎng)絡(luò)會輸出表示真實數(shù)據(jù)概率的張量[0,1]。
在GAN中,第一個網(wǎng)絡(luò)稱為生成器G(Z),第二個網(wǎng)絡(luò)稱為判別器D(X)。
在平衡點,也就是最大最小游戲中的最優(yōu)點,第一個網(wǎng)絡(luò)會生成真實數(shù)據(jù),第二個網(wǎng)絡(luò)輸出的概率會是0.5,因為第一個網(wǎng)絡(luò)生成了真實數(shù)據(jù)。
我們不禁會想“為什么要訓(xùn)練GAN呢?”,這是因為數(shù)據(jù)分布P_{data}可能非常復(fù)雜,很難推斷。所以使用對抗網(wǎng)絡(luò)可以從布P_{data}中生成樣本而不用處理討厭的概率分布問題。
GAN實現(xiàn)
根據(jù)GAN的定義我們需要兩個網(wǎng)絡(luò)。可以是任意的網(wǎng)絡(luò),比如卷積網(wǎng)絡(luò)或僅僅是兩層的感知器網(wǎng)絡(luò)。首先我們使用簡單的兩層感知器網(wǎng)絡(luò)。
# 判別器
X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
D_W1 = tf.Variable(xavier_init([784, 128]), name='D_W1')
D_b1 = tf.Variable(tf.zeros(shape=[128]), name='D_b1')
D_W2 = tf.Variable(xavier_init([128, 1]), name='D_W2')
D_b2 = tf.Variable(tf.zeros(shape=[1]), name='D_b2')
theta_D = [D_W1, D_W2, D_b1, D_b2]
# 生成器
Z = tf.placeholder(tf.float32, shape=[None, 100], name='Z')
G_W1 = tf.Variable(xavier_init([100, 128]), name='G_W1')
G_b1 = tf.Variable(tf.zeros(shape=[128]), name='G_b1')
G_W2 = tf.Variable(xavier_init([128, 784]), name='G_W2')
G_b2 = tf.Variable(tf.zeros(shape=[784]), name='G_b2')
theta_G = [G_W1, G_W2, G_b1, G_b2]
def generator(z):
G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob)
return G_prob
def discriminator(x):
D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
D_logit = tf.matmul(D_h1, D_W2) + D_b2
D_prob = tf.nn.sigmoid(D_logit)
return D_prob, D_logit
上面,generator(z)接受一個100維的矢量,輸出一個784維的矢量(MNIST圖片大小(28x28))。
discriminator(x)接受MNIST圖片作為輸入,返回一個代表真實圖片概率的張量。
現(xiàn)在讓我們解釋一下GAN的對抗過程。下面是論文中的訓(xùn)練算法:
G_sample = generator(Z)
D_real, D_logit_real = discriminator(X)
D_fake, D_logit_fake = discriminator(G_sample)
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))
上面我們對損失函數(shù)取負(fù)是因為它們需要最大化,而TensorFlow的優(yōu)化器只能進(jìn)行最小化。
另外根據(jù)論文的建議,最好最大化tf.reduce_mean(tf.log(D_fake)),而不是最小化tf.reduce_mean(1-tf.log(D_fake))。
接下來我們來訓(xùn)練網(wǎng)絡(luò)。
# 只更新 D(X)的參數(shù), 所以 var_list = theta_D
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
# 只更新 G(X)的參數(shù), 所以 var_list = theta_G
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
def sample_Z(m, n):
'''Uniform prior for G(Z)'''
return np.random.uniform(-1., 1., size=[m, n])
for it in range(1000000):
X_mb, _ = mnist.train.next_batch(mb_size)
_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})
這樣我們就完成了!我們可以看一看訓(xùn)練過程
剛開始我們使用隨機噪聲作為輸入,隨著訓(xùn)練的進(jìn)行,G(Z)開始越來越趨近P_{data}。
替代的損失函數(shù)
我們可以使用不同的方法來表示D_loss和G_loss。
讓我們跟隨自己的直覺。這個方法根據(jù)Brandon Amos’ blog.
讓我們想一想,discriminator(x)試圖將所有的輸出變?yōu)?,也就是我們想最大化真實數(shù)據(jù)的概率。而discriminator(G_sample)試圖將所有的輸出變?yōu)?,即D(G(Z))希望最小化偽造數(shù)據(jù)的概率。
那么generator(z)呢?它當(dāng)然想最大化偽造數(shù)據(jù)的概率!它與D(G(Z))正相反!
因此,代碼可以寫成:
# 另外的損失:
# -------------------
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logit_real, tf.ones_like(D_logit_real)))
D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logit_fake, tf.zeros_like(D_logit_fake)))
D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(D_logit_fake, tf.ones_like(D_logit_fake)))
這里我們使用的是Logistic Loss。改變損失函數(shù)不會影響到GAN的訓(xùn)練。