GAN

#coding:utf-8

from tensorflow.examples.tutorials.mnist import input_data #連接遠程服務器,下載數據集
import tensorflow as tf
import numpy as np
from skimage.io import imsave
import os #獲取路徑
import shutil
import sys
import six

#############
#定義模型參數
#############
# 圖片尺寸
img_height= 28
img_width = 28
img_size = img_height * img_width#784

###################
#模型保存與加載參數
###################
to_train = True
to_restore = False #為了判斷是否加載模型
output_path = 'D:\Python\GANTest\output'# 保存地址


###################
#網絡及訓練參數
###################
#生成模型的輸入是100,隨機變量的維度是100,
# 最后要變成784維,所以隱層是增的,第一個隱層是150,第二個是300,這是生成模型
# 判別模型要倒著來,輸入是784,第一層300,第二層150,最后的輸出是1維,映射到一個值,用0或者1全宇判別
z_size = 100
h1_size = 150
h2_size = 300
# 一次訓練256個,因為圖像像素也不高,可以一次多訓練一點
batch_size = 256
# 深度學習一般都是1000,實際上收斂的時候,一般小于1000
max_epoch = 1000

###################
#DNN網絡構建(多層感知機)
###################
# 生成模型做的: 就是給一個隨機變量,返回一個生成的數值x_generate,和一套參數
def 生成模型(z_prior):
    # 從輸入層向第一隱層的w,應該是100行,150列

    # truncated_normal:從截斷的正態分布中輸出隨機值。
    # shape表示生成張量的維度,mean是均值,stddev是標準差。這個函數產生正太分布,均值和標準差自己設定。
    # 這是一個截斷的產生正太分布的函數,就是說產生正太分布的值如果與均值的差值大于兩倍的標準差,那就重新生成
    w1 = tf.Variable(tf.truncated_normal([z_size, h1_size], \
                                         # 方差 stddev = 0.1,均值不用寫,均值默認是0
                                         stddev = 0.1), name = 'g_w1', \
                                         dtype=tf.float32)
    # 注意此時是:zeros
    b1 = tf.Variable(tf.zeros([h1_size]), name = 'g_b1', dtype = tf.float32)
    # 公式上是Wx+b,但矩陣運算時,x放前面
    h1 = tf.nn.relu(tf.matmul(z_prior, w1) + b1)

    # 從第一個隱層到第二個隱層
    w2 = tf.Variable(tf.truncated_normal([h1_size, h2_size], \
                                         stddev = 0.1), name = 'g_w2', \
                                         dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h2_size]), name = 'g_b2', dtype = tf.float32)

    # 注意此時是h1,也就是 隱層1激活之后的結果值
    h2 = tf.nn.relu(tf.matmul(h1, w2) + b2)

    w3 = tf.Variable(tf.truncated_normal([h2_size, img_size], \
                                         stddev = 0.1), name = 'g_w3', \
                                         dtype=tf.float32)
    b3 = tf.Variable(tf.zeros([img_size]), name = 'g_b3', dtype = tf.float32)
    # h3也能很大,也可能很小,如果是做分類,就用softmax激活,這里不用分類,選tanh激活
    h3 = tf.matmul(h2, w3) + b3
    x_generate = tf.nn.tanh(h3)#fake,生成模型,假的
    g_params = [w1, b1, w2, b2, w3, b3]#生成模型的參數,更新的參數,需要保存
    return x_generate, g_params

# x_data: 就是通過input_data下載的數據集,在后面的會話中給他喂,先寫一個形參x_data就是y值
# x_generate:傳兩個
#  keep_prob:Dropout率,意思是每個元素被保留的概率,那么 keep_prob:1就是所有元素全部保留的意思。
# 一般在大量數據訓練時,為了防止過擬合,添加Dropout層,設置一個0~1之間的小數
def 判別模型(x_data, x_generate, keep_prob):
    # 這兩個值是并行輸入的
    # tf.concat:tensorflow中用來拼接張量的函數。 tf.concat()拼接的張量只會改變一個維度,其他維度是保存不變的。
    # 比如兩個shape為[2,3]的矩陣拼接,要么通過axis=0變成[4,3],要么通過axis=1變成[2,6]。改變的維度索引對應axis的值。
    # x_data(y值)和x_generate(y^值)都是一維的,所以拼接成兩行的向量,但是他們不是相加,一起進網絡,一起出網絡
    # x_in代表判別模型的輸入
    x_in = tf.concat([x_data, x_generate], 0)
    w1 = tf.Variable(tf.truncated_normal([img_size, h2_size], \
                                         stddev = 0.1), name = 'd_w1', \
                                         dtype=tf.float32)
    b1 = tf.Variable(tf.zeros([h2_size]), name = 'd_b1', dtype = tf.float32)
    # 這里比之前可以加一個dropout率
    h1 = tf.nn.dropout(tf.nn.relu(tf.matmul(x_in, w1) + b1),keep_prob)

    w2 = tf.Variable(tf.truncated_normal([h2_size, h1_size], \
                                         stddev = 0.1), name = 'd_w2', \
                                         dtype=tf.float32)
    b2 = tf.Variable(tf.zeros([h1_size]), name = 'd_b2', dtype = tf.float32)
    h2 = tf.nn.dropout(tf.nn.relu(tf.matmul(h1, w2) + b2),keep_prob)
    # w3只有一個值,注意b和后面的網絡深度相同,因為前一層的b連的是后一層的cell
    w3 = tf.Variable(tf.truncated_normal([h1_size, 1], stddev = 0.1), \
                                        name = 'd_w3', dtype = tf.float32)
    b3 = tf.Variable(tf.zeros([1]), name = 'd_b3', dtype = tf.float32)
    h3 = tf.matmul(h2, w3) + b3
    print(h3)
    # 兩維數組,上面一行是1*256(因為一次訓練256個)
    # tf.slice(input_, begin, size):
    # “input_”是你輸入的tensor,就是被切的那個。
    # “begin”是每一個維度的起始位置,這個下面詳細說。
    # “size”相當于問每個維度拿幾個元素出來。
    # 需要把y和y^分開,因為之前是黏在一起的,然后算損失
    #  為什么[batch_size, -1]是[256, 784]?累加?全連接層

    y_data = tf.nn.sigmoid(tf.slice(h3, [0, 0], [batch_size, -1]), name = None) #y
    # 二分類直接用sigmoid去激活
    # [batch_size, 0]中batch_size:從第256開始切,[-1, -1]:第一個-1,切剩下所有的
    y_generated = tf.nn.sigmoid(tf.slice(h3, [batch_size, 0], [-1, -1]), name = None) #y^
    d_params = [w1, b1, w2, b2, w3, b3]#判別模型的參數
    return y_data, y_generated, d_params

#########################################################
#定義圖片展示的函數(即將生成模型的輸出相片繪制成圖片保存)
#########################################################
# 繪制一個8*8的格子,每個格子保存16張圖片,每張圖片就是28*28=784的圖片
def 展示結果保存(batch_res, fname, grid_size=(8, 8), grid_pad=5):#show_result
    '''圖片保存'''
    batch_res = 0.5 * batch_res.reshape((batch_res.shape[0], img_height, img_width)) + 0.5
    img_h, img_w = batch_res.shape[1], batch_res.shape[2]
    grid_h = img_h * grid_size[0] + grid_pad * (grid_size[0] - 1)
    grid_w = img_w * grid_size[1] + grid_pad * (grid_size[1] - 1)
    img_grid = np.zeros((grid_h, grid_w), dtype=np.uint8)
    for i, res in enumerate(batch_res):
        if i >= grid_size[0] * grid_size[1]:
            break
        img = (res) * 255
        img = img.astype(np.uint8)
        row = (i // grid_size[0]) * (img_h + grid_pad)
        col = (i % grid_size[1]) * (img_w + grid_pad)
        img_grid[row:row + img_h, col:col + img_w] = img
    imsave(fname, img_grid)

######################
#開始訓練
######################
def 開始訓練():
    # input_data是網上下載的,one_hot打開,one_hot指的是標簽
    mnist = input_data.read_data_sets('MNIST_data', one_hot = True)
    # [batch_size, z_size]=[256,100], z_prior:輸入給生成模型的,隨機變量
    z_prior = tf.placeholder(tf.float32, [batch_size, z_size], name = 'z_prior')
    # [batch_size, img_size]=[256,784]   x_data:真實值
    x_data = tf.placeholder(tf.float32, [batch_size, img_size], name = 'x_data')
    keep_prob = tf.placeholder(tf.float32, name = 'keep_prob')
    # 步數,假設255步,trainable:不是訓練的步長,是測試的步長
    global_step = tf.Variable(255, name = 'global_step', trainable = False)

    x_generate, g_params = 生成模型(z_prior)
    y_data, y_generated, d_params = 判別模型(x_data, x_generate, keep_prob)

    ##########
    #構建損失,兩個損失,先更新d_loss,注意損失函數的區別
    ##########
    d_loss = -(tf.log(y_data)+tf.log(1-y_generated))
    g_loss = -tf.log(y_generated)
    # 優化器,要優化兩次
    optimizer = tf.train.AdamOptimizer(0.0001)
    # d_params兩個模型中的參數列表,run這兩個模型就好了:d_train,g_train
    d_train = optimizer.minimize(d_loss, var_list = d_params)
    g_train = optimizer.minimize(g_loss, var_list = g_params)
    # 之前的寫法也可以
    init = tf.initialize_all_variables()
    #####################
    #保存與加載模型
    #####################
    saver = tf.train.Saver()
    sess = tf.Session()
    # 先run(init),初始化變量
    sess.run(init)
    if to_restore:#加載模型
        chkpt_fname = tf.train.latest_checkpoint(output_path)
        # 把模型及上面的參數加載到會話中來
        saver.restore(sess, chkpt_fname)
    elif os.path.exists(output_path):
        # 存在,則刪掉,重新創建
        shutil.rmtree(output_path)
        os.mkdir(output_path)
    # os.path.exists(output_path)
    # # 存在,則刪掉,重新創建
    # shutil.rmtree(output_path)
    # os.mkdir(output_path)

    # 創建隨機變量 0, 1,均值為0,方差為1,(batch_size, z_size) 256行*100維
    # z_sample_val = np.random.normal(0, 1, size = (batch_size, z_size)).astype(tf.float32)

    # 60000整個數據集有60000張圖片
    steps = 60000//batch_size

    #for i in range(sess.run(global_step), max_epoch):
    for i in range(max_epoch):

        for j in range(steps):

            print("epoch序次:%d, 迭代次數:%d"%(i, j))

            ############
            #傳入實參
            ############
            # 切batch_size個
            x_value, _ = mnist.train.next_batch(batch_size)
            # 歸一化處理,圖像一般都這樣處理,*2-1,讓像素分布在0的兩邊
            x_value = 2*x_value.astype(np.float32) - 1
            # x_value:真實值。z_value
            z_value = np.random.normal(0, 1, size = (batch_size, z_size)).astype(np.float32)

            # 先run判別模型
            sess.run(d_train, feed_dict = {x_data:x_value, z_prior:z_value, keep_prob:np.sum(0.7).astype(np.float32)})

            sess.run(g_train, feed_dict = {x_data:x_value, z_prior:z_value, keep_prob:np.sum(0.7).astype(np.float32)})
            # 每運行一遍,再運行x_generate
            x_gen_val = sess.run(x_generate, feed_dict={z_prior:z_value})

            展示結果保存(x_gen_val, 'D:\Python\GANTest\output_sample\sample{0}.jpg'.format(i))
            ################
            #保存模型
            ################
            sess.run(tf.assign(global_step, i+1))
            saver.save(sess, os.path.join(output_path, 'model'), global_step = global_step)

if __name__ =='__main__':

    開始訓練()

最后編輯于
?著作權歸作者所有,轉載或內容合作請聯系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發布,文章內容僅代表作者本人觀點,簡書系信息發布平臺,僅提供信息存儲服務。
  • 序言:七十年代末,一起剝皮案震驚了整個濱河市,隨后出現的幾起案子,更是在濱河造成了極大的恐慌,老刑警劉巖,帶你破解...
    沈念sama閱讀 228,983評論 6 537
  • 序言:濱河連續發生了三起死亡事件,死亡現場離奇詭異,居然都是意外死亡,警方通過查閱死者的電腦和手機,發現死者居然都...
    沈念sama閱讀 98,772評論 3 422
  • 文/潘曉璐 我一進店門,熙熙樓的掌柜王于貴愁眉苦臉地迎上來,“玉大人,你說我怎么就攤上這事。” “怎么了?”我有些...
    開封第一講書人閱讀 176,947評論 0 381
  • 文/不壞的土叔 我叫張陵,是天一觀的道長。 經常有香客問我,道長,這世上最難降的妖魔是什么? 我笑而不...
    開封第一講書人閱讀 63,201評論 1 315
  • 正文 為了忘掉前任,我火速辦了婚禮,結果婚禮上,老公的妹妹穿的比我還像新娘。我一直安慰自己,他們只是感情好,可當我...
    茶點故事閱讀 71,960評論 6 410
  • 文/花漫 我一把揭開白布。 她就那樣靜靜地躺著,像睡著了一般。 火紅的嫁衣襯著肌膚如雪。 梳的紋絲不亂的頭發上,一...
    開封第一講書人閱讀 55,350評論 1 324
  • 那天,我揣著相機與錄音,去河邊找鬼。 笑死,一個胖子當著我的面吹牛,可吹牛的內容都是我干的。 我是一名探鬼主播,決...
    沈念sama閱讀 43,406評論 3 444
  • 文/蒼蘭香墨 我猛地睜開眼,長吁一口氣:“原來是場噩夢啊……” “哼!你這毒婦竟也來了?” 一聲冷哼從身側響起,我...
    開封第一講書人閱讀 42,549評論 0 289
  • 序言:老撾萬榮一對情侶失蹤,失蹤者是張志新(化名)和其女友劉穎,沒想到半個月后,有當地人在樹林里發現了一具尸體,經...
    沈念sama閱讀 49,104評論 1 335
  • 正文 獨居荒郊野嶺守林人離奇死亡,尸身上長有42處帶血的膿包…… 初始之章·張勛 以下內容為張勛視角 年9月15日...
    茶點故事閱讀 40,914評論 3 356
  • 正文 我和宋清朗相戀三年,在試婚紗的時候發現自己被綠了。 大學時的朋友給我發了我未婚夫和他白月光在一起吃飯的照片。...
    茶點故事閱讀 43,089評論 1 371
  • 序言:一個原本活蹦亂跳的男人離奇死亡,死狀恐怖,靈堂內的尸體忽然破棺而出,到底是詐尸還是另有隱情,我是刑警寧澤,帶...
    沈念sama閱讀 38,647評論 5 362
  • 正文 年R本政府宣布,位于F島的核電站,受9級特大地震影響,放射性物質發生泄漏。R本人自食惡果不足惜,卻給世界環境...
    茶點故事閱讀 44,340評論 3 347
  • 文/蒙蒙 一、第九天 我趴在偏房一處隱蔽的房頂上張望。 院中可真熱鬧,春花似錦、人聲如沸。這莊子的主人今日做“春日...
    開封第一講書人閱讀 34,753評論 0 28
  • 文/蒼蘭香墨 我抬頭看了看天上的太陽。三九已至,卻和暖如春,著一層夾襖步出監牢的瞬間,已是汗流浹背。 一陣腳步聲響...
    開封第一講書人閱讀 36,007評論 1 289
  • 我被黑心中介騙來泰國打工, 沒想到剛下飛機就差點兒被人妖公主榨干…… 1. 我叫王不留,地道東北人。 一個月前我還...
    沈念sama閱讀 51,834評論 3 395
  • 正文 我出身青樓,卻偏偏與公主長得像,于是被迫代替她去往敵國和親。 傳聞我的和親對象是個殘疾皇子,可洞房花燭夜當晚...
    茶點故事閱讀 48,106評論 2 375