GAN 的 keras 實現

本文結構:

  • 什么是 GAN?
  • 優點?
  • keras 例子?

什么是 GAN?

GAN,全稱為 Generative Adversarial Nets,直譯為生成式對抗網絡,是一種非監督式模型。

一種應用是生成在原始數據集中不存在的但是卻比較合理的數據,還可以拓展一張圖片,生成下一幀影像,由簡單幾筆生成一幅畫:

模型:

主要有兩部分:

The Generative Model:通過輸入任意隨機數據,嘗試生成一些真實的東西(曲線,圖像,聲音,文本,...)

The Discriminative Model:試圖判定哪些是虛假的數據,來減小對真實數據的誤報。


優點:

Markov chains are never needed
避免了計算復雜度特別高的過程,直接進行采樣和推斷,應用效率相應提高。

a wide variety of functions can be incorporated into the model
針對不同的任務就可以設計不同類型的損失函數。

can represent very sharp, even degenerate distributions


Keras 例子:

任務:生成 sin 曲線。

%matplotlib inline
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm_notebook as tqdm
from keras.models import Model
from keras.layers import Input, Reshape
from keras.layers.core import Dense, Activation, Dropout, Flatten
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling1D, Conv1D
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam, SGD
from keras.callbacks import TensorBoard

1. Generative model:

輸入:noise data
輸出:嘗試生成真實的 sin 數據

def get_generative(G_in, dense_dim=200, out_dim=50, lr=1e-3):
    x = Dense(dense_dim)(G_in)
    x = Activation('tanh')(x)
    G_out = Dense(out_dim, activation='tanh')(x)
    G = Model(G_in, G_out)
    opt = SGD(lr=lr)
    G.compile(loss='binary_crossentropy', optimizer=opt)
    return G, G_out

2. Discriminative model:

輸出:識別此數據是真實的,還是由 Generative model 生成的

def get_discriminative(D_in, lr=1e-3, drate=.25, n_channels=50, conv_sz=5, leak=.2):
    x = Reshape((-1, 1))(D_in)
    x = Conv1D(n_channels, conv_sz, activation='relu')(x)
    x = Dropout(drate)(x)
    x = Flatten()(x)
    x = Dense(n_channels)(x)
    D_out = Dense(2, activation='sigmoid')(x)
    D = Model(D_in, D_out)
    dopt = Adam(lr=lr)
    D.compile(loss='binary_crossentropy', optimizer=dopt)
    return D, D_out

3. chain the two models into a GAN:

set_trainability 的作用是每次訓練 generator 時要凍住 discriminator。

def set_trainability(model, trainable=False):
    model.trainable = trainable
    for layer in model.layers:
        layer.trainable = trainable
        
def make_gan(GAN_in, G, D):
    set_trainability(D, False)
    x = G(GAN_in)
    GAN_out = D(x)
    GAN = Model(GAN_in, GAN_out)
    GAN.compile(loss='binary_crossentropy', optimizer=G.optimizer)
    return GAN, GAN_out

4. Training:

交替訓練 discriminator 和 chained GAN,在訓練 chained GAN 時要凍住 discriminator 的參數:

def sample_noise(G, noise_dim=10, n_samples=10000):
    X = np.random.uniform(0, 1, size=[n_samples, noise_dim])
    y = np.zeros((n_samples, 2))
    y[:, 1] = 1
    return X, y

def train(GAN, G, D, epochs=500, n_samples=10000, noise_dim=10, batch_size=32, verbose=False, v_freq=50):
    d_loss = []
    g_loss = []
    e_range = range(epochs)
    if verbose:
        e_range = tqdm(e_range)
    for epoch in e_range:
        X, y = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim)
        set_trainability(D, True)
        d_loss.append(D.train_on_batch(X, y))
        
        X, y = sample_noise(G, n_samples=n_samples, noise_dim=noise_dim)
        set_trainability(D, False)
        g_loss.append(GAN.train_on_batch(X, y))
        if verbose and (epoch + 1) % v_freq == 0:
            print("Epoch #{}: Generative Loss: {}, Discriminative Loss: {}".format(epoch + 1, g_loss[-1], d_loss[-1]))
    return d_loss, g_loss

d_loss, g_loss = train(GAN, G, D, verbose=True)

5. Results:

N_VIEWED_SAMPLES = 2
data_and_gen, _ = sample_data_and_gen(G, n_samples=N_VIEWED_SAMPLES)
pd.DataFrame(np.transpose(data_and_gen[N_VIEWED_SAMPLES:])).rolling(5).mean()[5:].plot()

學習資料:
https://arxiv.org/pdf/1406.2661.pdf
http://www.rricard.me/machine/learning/generative/adversarial/networks/2017/04/05/gans-part1.html
http://www.rricard.me/machine/learning/generative/adversarial/networks/keras/tensorflow/2017/04/05/gans-part2.html


推薦閱讀 歷史技術博文鏈接匯總
http://www.lxweimin.com/p/28f02bb59fe5
也許可以找到你想要的:
[入門問題][TensorFlow][深度學習][強化學習][神經網絡][機器學習][自然語言處理][聊天機器人]

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容