PyTorch-13 生成對抗網絡(DCGAN)教程

要閱讀帶插圖的教程,請前往 http://studyai.com/pytorch-1.4/beginner/dcgan_faces_tutorial.html

本教程將通過一個示例介紹DCGANs。我們將訓練一個生成對抗網絡(generative adversarial network, GAN), 在給它展示許多名流的照片之后,產生新的名人。這里的大部分代碼都來自 pytorch/examples 的實現(xiàn), 本文檔將詳細解釋實現(xiàn),并闡明該模型是如何工作的和為什么工作的。但別擔心,不需要事先知道GANs, 但它可能需要第一次花一些時間來推理在表象的下面真正發(fā)生了什么。此外,為了時間,有一個或兩個GPU可能是個好事兒。 讓我們從頭開始。

生成對抗網絡

什么是 GAN?

GANS是一個框架,它教授DL模型以捕獲訓練數據的分布,這樣我們就可以從相同的分布生成新的數據。 GANs 是由伊恩·古德費羅于2014年發(fā)明的,并首次在論文 Generative Adversarial Nets 中進行了描述。它們由兩種不同的模型組成,一種是生成器(generator),另一種是判別器(discriminator)。 生成器的工作是生成看起來像訓練圖像的“假”圖像。判別器的工作是查看圖像并輸出它是真實的訓練圖像還是來自生成器的假圖像。 在訓練過程中,生成器不斷地試圖通過生成越來越好的偽圖像來勝過判別器,而判別器正在努力成為一名更好的偵探, 并正確地對真假圖像進行分類。這個游戲的均衡是當生成器生成看起來像是直接來自訓練數據的完美假象時, 判別器總是以50%的信心猜測生成器輸出是真是假的。

現(xiàn)在,讓我們從判別器開始,在整個教程中定義一些要使用的符號。假設 x
是表示圖像的數據。 D(x) 是判別器網絡,它輸出 x 來自訓練數據而不是生成器的(標量)概率。這里, 由于我們處理的是圖像,D(x) 的輸入是HWC大小為3x64x64的圖像。 直覺上,當 x 來自訓練數據時, D(x) 應該是高的, 當 x 來自生成器時,D(x) 應該是低的。 D(x)

也可以看作是一種傳統(tǒng)的二元分類器。

對于生成器的表示法,設 z
是從標準正態(tài)分布中采樣的潛在空間向量(latent space vector)。 G(z) 表示生成函數,它將潛在向量 z 映射到數據空間。 G 的目標是估計訓練數據的分布 (pdata) ,從而從估計出的分布(pg

)中生成假樣本。

因此, D(G(z))
是生成器 G 輸出的圖像為真實圖像的概率(標量)。 正如 古德費羅的論文, 所描述的那樣, D 和 G 玩了一個極小極大的博弈(minimax game),其中 D 試圖最大化它正確地分類真圖像和假圖像的概率(logD(x)),G 試圖最小化 D 預測其輸出是假的的概率 (log(1?D(G(x)))

) 。文中給出了GAN損失函數:
minGmaxDV(D,G)=Ex~pdata(x)[logD(x)]+Ez~pz(z)[log(1?D(G(x)))]

理論上,這個極小極大博弈的解是 在 pg=pdata

時,判別器只能隨機猜測輸入是真還是假。 然而,GANS的收斂理論仍在積極研究之中,而在現(xiàn)實中,模型并不總是訓練到這一點。

什么又是 DCGAN?

DCGAN是上述GANs的直接擴展,只是它在鑒別器和生成器中分別顯式地使用卷積和卷積轉置層。 它首先由Radford在文章 Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks 提出了一種基于深層卷積生成對抗網絡的無監(jiān)督表示學習方法。 判別器由跨步卷積層(strided convolution layers )、 批歸一化層(batch norm layers) 和 LeakyReLU 激活函數構成。 輸入是3x64x64圖像,輸出是 輸入來自真實數據分布的 標量概率。 生成器由卷積轉置層(convolutional-transpose)、 批歸一化層和 ReLU 激活層組成。 輸入是從標準正態(tài)分布中提取的潛在矢量(latent vector) z

,輸出是 3x64x64 的RGB圖像。 跨步卷積轉置層(strided conv-transpose layers)允許將潛在矢量(latent vector)變換為具有與圖像相同的shape。 作者還就如何設置優(yōu)化器、如何計算損失函數以及如何初始化模型的權重等方面給出了一些提示,這些都將在后面的章節(jié)中加以說明。

from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seem for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

輸入

我們先來定義一些輸入:

dataroot - dataset 文件夾根目錄的路徑。我們將在下一節(jié)中更多地討論數據集。
workers - 用于用 DataLoader 加載數據的工作線程數。
batch_size - 訓練中使用的批次大小。DCGAN 使用的批次大小為128。
image_size - 用于訓練的圖像的空間大小。此實現(xiàn)默認為64x64。 如果需要另一個尺寸,則必須改變D和G的結構。有關更多細節(jié),請參閱 這里 。
nc - 輸入圖像的顏色通道數. 彩色圖像是3通道的。
nz - 潛在向量(latent vector)的長度
ngf - 與通過生成器進行的特征映射的深度有關。
ndf - 設置通過鑒別器傳播的特征映射的深度。
num_epochs - 要運行的訓練回合(epoch)數。長期的訓練可能會帶來更好的效果,但也需要更長的時間。
lr - 用于訓練的學習率. 就像在 DCGAN 論文中建議的, 這個參數設為 0.0002 。
beta1 - Adam 優(yōu)化器的beta1超參數。 就像在 DCGAN 論文中建議的, 這個參數設為 0.5 。
ngpu - 可用的 GPUs 數量。 如果沒有GPU, 代碼將會在 CPU 模式下運行。 如果有多個GPU,那就可以加速計算了。
# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

數據

在本教程中,我們將使用 Celeb-A Faces 數據集, 該數據集可以在鏈接的站點上下載,也可以在GoogleDrive中下載。dataset將作為一個名為 img_align_celeba.zip 的文件下載。 下載完后,創(chuàng)建一個名為 celeba 的目錄,并將zip文件解壓縮到該目錄中。 然后,將此筆記本的 dataroot 輸入設置為您剛剛創(chuàng)建的renarba目錄。由此產生的目錄結構應該是:

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...

這是一個重要的步驟,因為我們將使用 ImageFolder 類,它需要在dataset的根文件夾中有子目錄。 現(xiàn)在,我們可以創(chuàng)建 dataset ,dataloader ,設置設備運行,并最終可視化一些訓練數據。

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

實現(xiàn)

在設置了輸入參數并準備好數據集之后,我們現(xiàn)在可以進入實現(xiàn)了。我們將從wigthts初始化策略開始, 然后詳細討論生成器、判別器、損失函數和訓練循環(huán)。
權重初始化

從DCGAN的文獻中,作者指出所有模型的權重都應從均值=0,stdev=0.2的正態(tài)分布中隨機初始化。 權值函數以初始化模型作為輸入,并重新初始化所有卷積、卷積-轉置和批處理歸一化層,以滿足這一標準。 該函數在初始化后立即應用于模型。

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

生成器(Generator)

生成器 G

被設計用于將潛在空間矢量(z)映射到數據空間。由于我們的數據是圖像, 將 z 轉換為數據空間意味著最終創(chuàng)建一個與訓練圖像(即3x64x64)相同大小的RGB圖像。 在實踐中,這是通過一系列strided 2d convolutional transpose layers 來實現(xiàn)的, 每個層與一個2d batch norm layer和一個relu activation層配對。 生成器的輸出送入到一個tanh函數,將其輸出值壓縮在 [?1,1]

的范圍。 值得注意的是batch norm functions是在conv-transpose layers之后的, 因為這是DCGAN論文的一個關鍵貢獻。這些層有助于訓練期間的梯度流。 DCGAN文章中給出的生成器的結構如下所示。
dcgan_generator

注意,我們在輸入部分(nz, ngf, 和 nc) 中設置的輸入如何影響代碼中的生成器體系結構。 nz 是 z 輸入向量的長度, ngf 與通過生成器傳播的特征圖的大小有關, nc 是輸出圖像中的通道數(對于RGB圖像設置為3)。下面是生成器的代碼。

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

現(xiàn)在,我們可以實例化生成器并應用 weights_init 函數。 查看打印的模型,看看生成器對象是如何構造的。

# 創(chuàng)建生成器對象
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# 應用 weights_init 函數 來隨機初始化 所有權重到 mean=0, stdev=0.2.
netG.apply(weights_init)

# 打印輸出模型
print(netG)

判別器(Discriminator)

如上所述,判別器 D
是一種兩類分類網絡,它以圖像為輸入,輸出 輸入圖像為真(而不是假)的標量概率。 這里,D 接受一個 3x64x64 輸入圖像,通過一系列Conv2d、BatchNorm2d和LeakyReLU層處理它, 并通過 sigmoid 激活函數輸出最終的概率。如果有必要的話,可以用更多的層來擴展這個體系結構, 但是使用strided convolution、BatchNorm和LeakyReLU是很有意義的。DCGAN的論文提到, 使用strided convolution而不是pooling來降采樣是一種很好的做法, 因為它讓網絡學習自己的池化函數。此外,batch norm 和leaky relu函數促進了健康的梯度流, 這對于 G 和 D

的學習過程都是至關重要的。

Discriminator Code

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

現(xiàn)在,和生成器一樣,我們可以創(chuàng)建判別器,應用 weights_init 函數,并打印模型的結構。

# 創(chuàng)建 Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# 應用 weights_init 函數,隨機初始化所有權重到 mean=0, stdev=0.2.
netD.apply(weights_init)

# 打印輸出模型
print(netD)

損失函數和優(yōu)化器

當 D
和 G

設置好以后, 我們可以指定它們如何通過損失函數和優(yōu)化器學習。 我們將使用二值交叉熵損失(Binary Cross Entropy loss (BCELoss)) 函數,在 PyTorch 中是如下定義的:
?(x,y)=L={l1,…,lN}?,ln=?[yn?logxn+(1?yn)?log(1?xn)]

注意這個函數提供目標函數中的兩個對數組件的計算 (i.e. log(D(x))
和 log(1?D(G(z)))) 。 我們可以使用 y 指定 BCE 等式的哪一部分將被計算。 這將在訓練過程中完成,稍后會講到。但是理解我們如何通過 改變 y

的值(i.e. GT labels) 去選擇我們想要計算的損失函數的一部分是非常重要的。

接下來,我們將真標簽定義為1,假標簽定義為0。這些標簽將用于計算 D
和 G 的損失, 這也是在原始GAN文章中使用的約定。最后,我們建立了兩個分開的優(yōu)化器,一個用于 D , 一個用于 G 。正如DCGAN論文所指出的,兩者都是Adam優(yōu)化器,其學習速率為0.0002,Beta1=0.5。 為了跟蹤生成器的學習過程,我們將從高斯分布(即固定噪聲)中生成固定批次的潛在向量(latent vectors)。 在訓練循環(huán)中,我們將周期性地將這個固定的噪聲輸入到 G

中。在迭代過程中,我們將看到圖像從噪聲中形成。

# 初始化 BCELoss 函數
criterion = nn.BCELoss()

# 創(chuàng)建一批 latent vectors 用于可視化生成器的進度過程
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# 為在訓練過程中的真假標簽建立約定
real_label = 1
fake_label = 0

# 為 G 和 D 設置 Adam optimizers
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

訓練

最后,現(xiàn)在我們已經定義了GAN框架的所有部分,我們可以對其進行訓練。請注意, 訓練GANs是一種藝術,因為不正確的超參數設置會導致模式崩潰, 而對錯誤的原因幾乎沒有解釋。在這里,我們將密切遵循古德費羅論文中的算法1, 同時遵循在 ganhacks 中顯示的一些最佳實踐。 也就是說,我們將“為真假圖像構造不同的小批量”圖像, 并調整G的目標函數,使 logD(G(z))

最大化。訓練分為兩個主要部分。 第1部分更新判別器,第2部分更新生成器。

**Part 1 - 訓練判別器(Discriminator) **

回想一下,訓練判別器的目標是最大化將給定的輸入正確分類為真或假的概率。 我們希望“通過提升判別器的隨機梯度來更新判別器”。 實際上,我們希望最大化 log(D(x))+log(1?D(G(z)))
。 由于來自于ganhacks 的separate mini-batch的建議, 我們將用兩個步驟來實現(xiàn)上述最大化的計算過程。首先從訓練集構造一批真實樣本,前向通過 D , 計算損失(log(D(x))) ,然后計算后傳梯度。 其次,用當前生成器構造一批假樣本,通過 D 向前傳遞該批樣本, 計算損失 (log(1?D(G(z)))

) ,并用反向傳遞累積梯度。 現(xiàn)在,有了全真和全假批次樣本中積累的梯度,我們再調用判別器的優(yōu)化器進行一步優(yōu)化。

**Part 2 - 訓練生成器(Generator) **

正如在最初的論文中所述,我們希望通過最小化 log(1?D(G(z)))
來訓練生成器,以產生更好的假樣本。 正如前面提到的,Goodfellow沒有提供足夠的梯度,特別是在學習過程的早期。作為修正, 我們希望最大化 log(D(G(z))) 。在代碼中,我們通過以下方法實現(xiàn)了這一點: 用第1部分的判別器對生成器的輸出進行分類,使用真標簽作為GroundTruth計算G的損失, ,隨后在向后傳遞中計算G的梯度,最后用優(yōu)化器的 step 方法更新G的參數。 使用真標簽作為GT標簽用于損失函數的計算似乎有違直覺,但這允許我們使用BCELoss的 log(x) 部分 (而不是 log(1?x)

部分),這正是我們想要的。

最后,我們將做一些統(tǒng)計報告,并在每個epoch結束時,我們將把固定批次噪聲推到生成器中 以可視化地跟蹤G的訓練進度。所報告的訓練統(tǒng)計數字如下:

Loss_D - 判別器損失,是所有真批次和所有假批次樣本上的損失之和 (log(D(x))+log(D(G(z)))

)。
Loss_G - 生成器損失,用 log(D(G(z)))

計算。
D(x) - 所有批次的真樣本上判別器的平均輸出(跨batch)。這個值應該開始接近1,然后當G變得更好時,理論上收斂到0.5。想想這是為什么。
D(G(z)) - 所有批次的假樣本上判別器的平均輸出。這個值應該開始接近0,后面隨著生成器越來越好就收斂到0.5。想想這是為什么。

Note: 這一步可能會花點時間, 這取決于你要運行多少個epoch以及如果你從數據集移除一些數據。

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

結果

最后,讓我們來看看我們是如何做到的。在這里,我們將看到三個不同的結果。 首先,我們將看到D和G在訓練中的損失是如何變化的。第二,我們將在每個epoch的固定噪聲批次上可視化G的輸出。 第三,我們將看到一批真數據,旁邊是一批來自G的假數據。

Loss versus training iteration

下面是迭代過程中 D 與 G 的損失對比圖。

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

G的進度的可視化

記住,在每個訓練回合(epoch)之后,我們是如何將generator的輸出保存在固定噪聲批次上的。 現(xiàn)在,我們可以用動畫來可視化G的訓練進度。按“播放”按鈕啟動動畫。

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

真圖像(Real Images) vs. 假圖像(Fake Images)

最后, 讓我們看看真圖像和假圖像吧!

# 從 dataloader 中抓取一個批次的真圖像
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# 繪制最后一個epoch的假圖像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

下一步去哪里

我們的旅程已經到了盡頭,但是有幾個地方你可以從這里去。你可以:

訓練更長的時間看看得到的結果有多好
修改此模型讓其接收不同的數據集 和 可能改變的圖像大小與模型架構
檢查其他一些很酷的 GAN 項目 這里 。
創(chuàng)建一個 GANs 讓它產生 音樂
最后編輯于
?著作權歸作者所有,轉載或內容合作請聯(lián)系作者
平臺聲明:文章內容(如有圖片或視頻亦包括在內)由作者上傳并發(fā)布,文章內容僅代表作者本人觀點,簡書系信息發(fā)布平臺,僅提供信息存儲服務。

推薦閱讀更多精彩內容