訓練WGAN的時候,有幾個方面可以調參:
? a. 調節Generator loss中GAN loss的權重。 G loss和Gan loss在一個尺度上或者G loss比Gan loss大一個尺度。但是千萬不能讓Gan loss占主導地位, 這樣整個網絡權重會被帶偏。
? b. 調節Generator和Discrimnator的訓練次數比。一般來說,Discrimnator要訓練的比Genenrator多。比如訓練五次Discrimnator,再訓練一次Genenrator(WGAN論文 是這么干的)。
? c. 調節learning rate,這個學習速率不能過大。一般要比Genenrator的速率小一點。
? d. Optimizer的選擇不能用基于動量法的,如Adam和momentum。可使用RMSProp或者SGD。
? e. Discrimnator的結構可以改變。如果用WGAN,判別器的最后一層需要去掉sigmoid。但是用原始的GAN,需要用sigmoid,因為其loss function里面需要取log,所以值必須在[0,1]。這里用的是鄧煒的critic模型當作判別器。之前twitter的論文里面的判別器即使去掉了sigmoid也不好訓練。
? f. Generator loss的誤差曲線走向。因為Generator的loss定義為:
?? G_loss = -tf.reduce_mean(D_fake)
? ? Generator_loss = gen_loss + lamda*G_loss
其中gen_loss為Generator的loss,G_loss為Discrimnator的loss,目標是使Generator_loss不斷變小。所以理想的Generator loss的誤差曲線應該是不斷往0靠的下降的拋物線。
? g. Discrimnator loss的誤差曲線走向。因為Discrimnator的loss定義為:
? ? ? D_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)
這個是一個和Generator抗衡的loss。目標就是使判別器分不清哪個是生成器的輸出哪個是真實的label。所以理想的Discrimnator loss的誤差曲線應該是最終在0附近振蕩,即傻傻分不清。換言之,就是判別器有50%的概率判斷你是真的,50%概率判斷你是假的。
? h. 之前的想法是就算判別器不訓練,那么它判斷這個圖片是真是假的概率都是50%,那D_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)不就已經在0附近了嗎?
其實不是這樣的。如果是wgan的話,判別器的輸出是一個負無窮到正無窮的數值,那么要讓它對兩個不同的輸入產生相似的輸出是很難的。同理,對于gan的話,判別器的輸出是介于[0,1]之間的,產生兩個相似的輸出也是很困難的。如果判別器的輸出是0或者1的話,那就是上面說的情況。所以,網絡要經過學習,使得 輸出盡可能相似,那就達到了傻傻分不清的狀態了。