在1889年,梵高畫了這個美麗的藝術品:星月夜。如今,我的GAN模型只使用20%的標簽數據,學會了畫MNIST數字!它是怎么實現的?讓我們動手做做看。
大多數深度學習分類器需要大量的標簽樣本才能很好地泛化,但獲取這些數據是的過程往往很艱難。為了解決這個限制,半監督學習被提出,它是利用少量標記數據和大量未標記數據的分類技術。許多機器學習研究人員發現,將未標記數據與少量標記數據結合使用時,可以顯著提高學習準確性。在半監督學習中,GAN(生成式對抗網絡)表現出了很大的潛力,其中分類器可以用很少的標簽數據取得良好的表現。
GAN的背景
GAN是深度生成模型的一種。它們特別有趣,因為它們沒有明確表示數據所在空間的概率分布。而是通過從中抽取樣本,提供了一些不直接與這種概率分布不直接相關的方法。
普通GAN架構
GAN的基本原理是在兩個“玩家”之間建立一場比賽:
生成器(G):取隨機噪聲z作為輸入并輸出圖像x。它的參數被調整以讓它產生的假圖像從判別器中獲得高分。
判別器(D):獲取圖像X作為輸入,并輸出一個反映了它對于這是否是真實圖像的信心得分。它的參數被調整為:當有真實圖像饋送時反饋高分,并且發生器饋送假圖像時會反饋低分。
現在,讓我們來稍微討論一下GAN最重要的應用之一,半監督學習。
直覺
普通判別器架構只有一個輸出神經元用于分類R / F概率(對/錯)。我們同時訓練兩個網絡并在訓練完成后丟棄判別器,因為它僅用于改進發生器。
對于半監督任務,除了R / F神經元之外,判別器現在將具有10個用于MNIST數字分類的神經元。而且,這次他們的角色會改變,我們可以在訓練后丟棄生成器,其唯一目標是生成未標記的數據以提高判別器的性能。
現在判別器成為了11個類的分類器,其中1個神經元(R / F神經元)代表假數據輸出,另外10個代表具有類的實際數據。你必須牢記以下幾點:
當來自數據集的真的無監督(或者說標簽)數據被饋送時,要保證R / F神經元輸出標簽= 0
當來自發生器的假的無監督數據被饋送時,要保證R / F神經元輸出標簽= 1
當真實有監督數據被饋送時,要保證R / F輸出標簽= 0并且相應的標簽輸出= 1
不同數據來源的組合將有助于判別器的分類更精確。
架構
現在我們動手進行編碼。
判別器
下面的架構與DCGAN?論文中提出的架構類似。我們使用跨卷積(strided convolutions)來減少特征向量的維度,而不是任何池化層,并且為所有層應用一系列的leaky_relu,dropout和BN來穩定學習。輸入層和最后一層中BN被舍棄(為了特征匹配)。最后,我們執行全局平均池化(Global Average Pooling)以取得特征向量空間維度上的平均值。這可以將張量維度壓縮為單個值。在扁平化了特征之后,為了多類輸出增加一個11個類的稠密層和softmax激活函數。
01def?discriminator(x, dropout_rate=?0., is_training=?True, reuse=?False):
02???# input x -> n+1 classes
03???with tf.variable_scope('Discriminator', reuse=?reuse):
04?????# x = ?*64*64*1
05?????
06?????#Layer 1
07?????conv1=?tf.layers.conv2d(x,128, kernel_size=?[4,4], strides=?[2,2],
08?????????????????????????????padding=?'same', activation=?tf.nn.leaky_relu, name=?'conv1')# ?*32*32*128
09?????#No batch-norm for input layer
10?????dropout1=?tf.nn.dropout(conv1, dropout_rate)
11?????
12?????#Layer2
13?????conv2=?tf.layers.conv2d(dropout1,256, kernel_size=?[4,4], strides=?[2,2],
14?????????????????????????????padding=?'same', activation=?tf.nn.leaky_relu, name=?'conv2')# ?*16*16*256
15?????batch2=?tf.layers.batch_normalization(conv2, training=?is_training)
16?????dropout2=?tf.nn.dropout(batch2, dropout_rate)
17?????
18?????#Layer3
19?????conv3=?tf.layers.conv2d(dropout2,512, kernel_size=?[4,4], strides=?[4,4],
20?????????????????????????????padding=?'same', activation=?tf.nn.leaky_relu, name=?'conv3')# ?*4*4*512
21?????batch3=?tf.layers.batch_normalization(conv3, training=?is_training)
22?????dropout3=?tf.nn.dropout(batch3, dropout_rate)
23???????
24?????# Layer 4
25?????conv4=?tf.layers.conv2d(dropout3,1024, kernel_size=[3,3], strides=[1,1],
26??????????????????????????????padding='valid',activation=?tf.nn.leaky_relu, name='conv4')# ?*2*2*1024
27?????# No batch-norm as this layer's op will be used in feature matching loss
28?????# No dropout as feature matching needs to be definite on logits
29?
30?????# Layer 5
31?????# Note: Applying Global average pooling???????
32?????flatten=?tf.reduce_mean(conv4, axis=?[1,2])
33?????logits_D=?tf.layers.dense(flatten, (1?+?num_classes))
34?????out_D=?tf.nn.softmax(logits_D)????
35???return?flatten,logits_D,out_D
發生器
發生器架構旨在模仿判別器的空間輸出。使用部分跨卷積來增加表示的空間維度。噪聲的四維張量的輸入z被饋送,它經過轉置卷積,relu,BN(輸出層除外)和dropout操作。最后,tanh激活將輸出圖像映射到(-1,1)范圍內。
01def?generator(z, dropout_rate=?0., is_training=?True, reuse=?False):
02????# input latent z -> image x
03????with tf.variable_scope('Generator', reuse=?reuse):
04??????#Layer 1
05??????deconv1=?tf.layers.conv2d_transpose(z,512, kernel_size=?[4,4],
06?????????????????????????????????????????strides=?[1,1], padding=?'valid',
07????????????????????????????????????????activation=?tf.nn.relu, name=?'deconv1')# ?*4*4*512
08??????batch1=?tf.layers.batch_normalization(deconv1, training=?is_training)
09??????dropout1=?tf.nn.dropout(batch1, dropout_rate)
10??????
11??????#Layer 2
12??????deconv2=?tf.layers.conv2d_transpose(dropout1,256, kernel_size=?[4,4],
13?????????????????????????????????????????strides=?[4,4], padding=?'same',
14????????????????????????????????????????activation=?tf.nn.relu, name=?'deconv2')# ?*16*16*256
15??????batch2=?tf.layers.batch_normalization(deconv2, training=?is_training)
16??????dropout2=?tf.nn.dropout(batch2, dropout_rate)
17????????
18??????#Layer 3
19??????deconv3=?tf.layers.conv2d_transpose(dropout2,128, kernel_size=?[4,4],
20?????????????????????????????????????????strides=?[2,2], padding=?'same',
21????????????????????????????????????????activation=?tf.nn.relu, name=?'deconv3')# ?*32*32*256
22??????batch3=?tf.layers.batch_normalization(deconv3, training=?is_training)
23??????dropout3=?tf.nn.dropout(batch3, dropout_rate)
24??????
25??????#Output layer
26??????deconv4=?tf.layers.conv2d_transpose(dropout3,1, kernel_size=?[4,4],
27????????????????????????????????????????strides=?[2,2], padding=?'same',
28????????????????????????????????????????activation=?None, name=?'deconv4')# ?*64*64*1
29??????out=?tf.nn.tanh(deconv4)
30????return?out
模型損失
我們首先通過將實際標簽附加為零來準備整個批次的擴展標簽。這樣做是為了在標記數據饋送時,R / F神經元的輸出為0。未標記數據的判別器損失可以被認為是一個二元sigmoid損失,通過將R / F神經元輸出為1聲明假圖像,而真實圖像輸出為0。
01### Discriminator loss ###
02???# Supervised loss -> which class the real data belongs to???
03???temp=?tf.nn.softmax_cross_entropy_with_logits_v2(logits=?D_real_logit,
04?????????????????????????????????????????????????labels=?extended_label)
05???# Labeled_mask and temp are of same size = batch_size where temp is softmax cross_entropy calculated over whole batch
06?
07???D_L_Supervised=?tf.reduce_sum(tf.multiply(temp,labeled_mask))/?tf.reduce_sum(labeled_mask)
08???
09???# Multiplying temp with labeled_mask gives supervised loss on labeled_mask
10???# data only, calculating mean by dividing by no of labeled samples
11???
12???# Unsupervised loss -> R/F???
13???D_L_RealUnsupervised=?tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
14???????????logits=?D_real_logit[:,0], labels=?tf.zeros_like(D_real_logit[:,0], dtype=tf.float32)))
15???
16???D_L_FakeUnsupervised=?tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
17???????????logits=?D_fake_logit[:,0], labels=?tf.ones_like(D_fake_logit[:,0], dtype=tf.float32)))
18???
19???D_L=?D_L_Supervised+?D_L_RealUnsupervised+?D_L_FakeUnsupervised
發生器損失是fake_image損失與特征匹配損失的組合,前者錯誤的將R / F神經元輸出斷言為0,后者懲罰訓練數據上一組特征的平均值與生成樣本中這組特征的平均值之間的平均絕對誤差。
01?????????????### Generator loss ###???????????????
02# G_L_1 -> Fake data wanna be real
03?
04G_L_1=?tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
05????????logits=?D_fake_logit[:,0],labels=?tf.zeros_like(D_fake_logit[:,0], dtype=tf.float32)))
06?
07# G_L_2 -> Feature matching
08data_moments=?tf.reduce_mean(D_real_features, axis=?0)
09sample_moments=?tf.reduce_mean(D_fake_features, axis=?0)
10G_L_2=?tf.reduce_mean(tf.square(data_moments-sample_moments))
11?
12G_L=?G_L_1+?G_L_2
訓練
訓練圖像從[batch_size,28,28,1]調整為[batch_size,64,64,1]以適應發生器和判別器架構。計算損失,準確性和生成樣本,并觀察每個周期的改進。
01for?epochin?range(epochs):
02??train_accuracies, train_D_losses, train_G_losses=?[], [], []
03??for?itin?range(no_of_batches):
04??
05??batch=?mnist_data.train.next_batch(batch_size, shuffle=?False)
06??# batch[0] has shape: batch_size*28*28*1????????
07??batch_reshaped=?tf.image.resize_images(batch[0], [64,64]).eval()
08??# Reshaping the whole batch into batch_size*64*64*1 for disc/gen architecture
09??batch_z=?np.random.normal(0,1, (batch_size,1,1, latent))
10??mask=?get_labeled_mask(labeled_rate, batch_size)
11????????????????
12??train_feed_dict=?{x : scale(batch_reshaped), z : batch_z,
13??????????????????????????????label : batch[1], labeled_mask : mask,
14???????????????????????????????dropout_rate :0.7, is_training :True}
15??#The label provided in dict are one hot encoded in 10 classes
16????????????????
17??D_optimizer.run(feed_dict=?train_feed_dict)
18??G_optimizer.run(feed_dict=?train_feed_dict)
19????????????????
20??train_D_loss=?D_L.eval(feed_dict=?train_feed_dict)
21??train_G_loss=?G_L.eval(feed_dict=?train_feed_dict)
22??train_accuracy=?accuracy.eval(feed_dict=?train_feed_dict)
23??????????
24??train_D_losses.append(train_D_loss)
25??train_G_losses.append(train_G_loss)
26??train_accuracies.append(train_accuracy)
27??????????
28??tr_GL=?np.mean(train_G_losses)
29??tr_DL=?np.mean(train_D_losses)
30??tr_acc=?np.mean(train_accuracies)??????
31??
32??print?('After epoch: '+?str(epoch+1)+?' Generator loss: '
33???????????????????????+?str(tr_GL)+?' Discriminator loss: '?+?str(tr_DL)+?' Accuracy: '?+?str(tr_acc))
34????????
35??gen_samples=?fake_data.eval(feed_dict=?{z : np.random.normal(0,1, (25,1,1, latent)), dropout_rate :0.7, is_training :False})
36??# Dont train batch-norm while plotting => is_training = False
37??test_images=?tf.image.resize_images(gen_samples, [64,64]).eval()
38??show_result(test_images, (epoch+?1), show=?True, save=?False, path=?'')
結論
由于GPU的限制,訓練已完成5個周期和20%的 labeled_rate。想要獲得更好的結果,建議使用較小的label_rate的訓練更多周期。
完整代碼:https://github.com/raghav64/SemiSuper_GAN/blob/master/SSGAN.py
訓練結果
本文為編譯作品,轉載請注明出處。更多內容關注微信公眾號:atyun_com