人生苦短我用GAN
首先聲明一下,本教程面向入門吃瓜群眾,大牛可以繞道,閑話不多說,先方一波廣告。(高級GAN玩法),怎么說,我越來越感覺到人工智能正在迎來生成模型的時代,以前海量數據訓練模型的辦法有點揠苗助長,看似效果很好,實際上機器什么卵都沒有學到(至少從遷移性上看缺少一點味道,不過就圖片領域來說另當別論,在CV領域監督學習還是相當成功)。
但是問題來了,GAN這么屌這么牛逼,我怎么搞?怎么入門?誰帶我?慌了!
莫慌,50行代碼你就可以成為無監督學習大牛
我最討厭那些,嘴里一堆算法,算法實現不出來的人。因為我喜歡看到結果啊!尤其是一些教程,就是將論文,雞巴論文獎那么多有什么用?你碼代碼給我看啊,我不知道數據是什么,不知道輸入維度是什么,輸出什么,里面到底發生了什么變化我怎么學?這就有點像,典型的在沙漠里教你釣魚,在我看來,論文應該是最后才去看的東西。但是問題在于,你要有一個入門的教程啊。我想這是一個鴻溝,科研里面,理論和動手的鴻溝。
這篇教程就是引路人了。歡迎加入生成模型隊伍。這個教程會一直保持更新,因為科技每天變幻莫測,同時我還會加入很多新內容,改進一些在以后看來是錯誤的說法。
首先,我們廢話不多說了,直接show you the code:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
from scipy import stats
def generate_real_data_distribution(n_dim, num_samples):
all_data = []
for i in range(num_samples):
x = np.random.uniform(0, 8, n_dim)
y = stats.lognorm.pdf(x, 0.6)
all_data.append(y)
all_data = np.array(all_data)
print('generated data shape: ', all_data.shape)
return all_data
def batch_inputs(all_data, batch_size=6):
assert isinstance(all_data, np.ndarray), 'all_data must be numpy array'
batch_x = all_data[np.random.randint(all_data.shape[0], size=batch_size)]
return Variable(torch.from_numpy(batch_x).float())
def main():
# 給generator的噪音維數
n_noise_dim = 30
# 真實數據的維度
n_real_data_dim = 256
num_samples = 666
lr_g = 0.001
lr_d = 0.03
batch_size = 6
epochs = 1000
real_data = generate_real_data_distribution(n_real_data_dim, num_samples=num_samples)
print('sample from real data: \n', real_data[: 10])
g_net = nn.Sequential(
nn.Linear(n_noise_dim, 128),
nn.ReLU(),
nn.Linear(128, n_real_data_dim)
)
d_net = nn.Sequential(
nn.Linear(n_real_data_dim, 128),
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid()
)
opt_d = torch.optim.Adam(d_net.parameters(), lr=lr_d)
opt_g = torch.optim.Adam(g_net.parameters(), lr=lr_g)
for epoch in range(epochs):
for i in range(num_samples // batch_size):
batch_x = batch_inputs(real_data, batch_size)
batch_noise = Variable(torch.randn(batch_size, n_noise_dim))
g_data = g_net(batch_noise)
# 用G判斷兩個輸出分別多大概率是來自真正的畫家
prob_fake = d_net(g_data)
prob_real = d_net(batch_x)
# 很顯然,mean里面的這部分是一個負值,如果想整體loss變小,必須要變成正直,加一個負號,否則會越來越大
d_loss = -torch.mean(torch.log(prob_real) + torch.log(1 - prob_fake))
# 而g的loss要使得discriminator的prob_fake盡可能小,這樣才能騙過它,因此也要加一個負號
g_loss = -torch.mean(torch.log(prob_fake))
opt_d.zero_grad()
d_loss.backward(retain_variables=True)
opt_d.step()
opt_g.zero_grad()
g_loss.backward(retain_variables=True)
opt_g.step()
print('Epoch: {}, batch: {}, d_loss: {}, g_loss: {}'.format(epoch, i, d_loss.data.numpy()[0],
g_loss.data.numpy()[0]))
if __name__ == '__main__':
main()
這些代碼,總共,也就是90行,核心代碼50行,基本上,比你寫一個其他程序都端,什么紅黑算法,什么排序之類的。我個人比較喜歡簡約,我很多時候不喜歡太雞巴隆昌的代碼。
直接開始訓練吧
這個GAN很簡單,三部分:
- real data生成,這個real data我們怎么去模擬呢?注意這里用的數據是二維的,不是圖片,圖片是三維的,二維你可以看成是csv,或者是序列,在這里面我們每一行,也就是一個樣本,是sample自某個分布的數據,這里用的分布式lognorm;
- d_net 和 g_net,這里兩個net都是非常小,小到爆炸,這如果要是用tensorflow寫就有點蛋疼了,我選擇PyTorch,一目了然;
- loss,loss在GAN中非常重要,是接下來的重點。
OK,一陣復制粘貼,你就可以訓練一個GAN,這個GAN用來做什么?就是你隨機輸入一個噪音,生成模型將會生成一個和lognorm分布一樣的數據。也就是說,生成模型學到了lognrom分布。這能說明什么?神經網絡學到了概率!用到圖片里面就是,他知道哪個顏色快可能是什么東西,這也是現在的CycleGAN, DiscoGAN的原理。
我吃飯去了
未完待續...
來了
繼續剛才的,好像我寫的文章沒有人看啊,傷感。自己寫自己看吧,哎,我騷味改了一下代碼,loss函數部分,之前的寫錯了,我偷一張圖把。
這個是公式,原始GAN論文里面給的公式,但是毫無疑問,正如很多人說的那樣,GAN很容易漂移:
Epoch: 47, batch: 66, d_loss: 0.7026655673980713, g_loss: 2.0336945056915283
Epoch: 47, batch: 67, d_loss: 0.41225430369377136, g_loss: 2.1994106769561768
Epoch: 47, batch: 68, d_loss: 0.674636960029602, g_loss: 1.5774009227752686
Epoch: 47, batch: 69, d_loss: 0.5779278874397278, g_loss: 2.2797725200653076
Epoch: 47, batch: 70, d_loss: 0.4029145836830139, g_loss: 2.200833559036255
Epoch: 47, batch: 71, d_loss: 0.7264774441719055, g_loss: 1.5658557415008545
Epoch: 47, batch: 72, d_loss: 0.46858924627304077, g_loss: 2.355680227279663
Epoch: 47, batch: 73, d_loss: 0.6716371774673462, g_loss: 1.7127293348312378
Epoch: 47, batch: 74, d_loss: 0.7237206101417542, g_loss: 1.4458404779434204
Epoch: 47, batch: 75, d_loss: 0.9684935212135315, g_loss: 1.943861961364746
Epoch: 47, batch: 76, d_loss: 0.4705852270126343, g_loss: 2.439894199371338
Epoch: 47, batch: 77, d_loss: 0.4989328980445862, g_loss: 1.5290288925170898
Epoch: 47, batch: 78, d_loss: 0.44530192017555237, g_loss: 2.9254989624023438
Epoch: 47, batch: 79, d_loss: 0.6329593658447266, g_loss: 1.7527830600738525
Epoch: 47, batch: 80, d_loss: 0.42348209023475647, g_loss: 1.856258749961853
Epoch: 47, batch: 81, d_loss: 0.5396828651428223, g_loss: 2.268836498260498
Epoch: 47, batch: 82, d_loss: 0.9727945923805237, g_loss: 1.0528483390808105
Epoch: 47, batch: 83, d_loss: 0.7551510334014893, g_loss: 1.508225917816162
Epoch: 47, batch: 84, d_loss: 2.4204068183898926, g_loss: 1.5375216007232666
Epoch: 47, batch: 85, d_loss: 1.517686128616333, g_loss: 0.6334291100502014
Epoch: 47, batch: 86, d_loss: inf, g_loss: 0.7990849614143372
Epoch: 47, batch: 87, d_loss: nan, g_loss: nan
Epoch: 47, batch: 88, d_loss: nan, g_loss: nan
Epoch: 47, batch: 89, d_loss: nan, g_loss: nan
Epoch: 47, batch: 90, d_loss: nan, g_loss: nan
Epoch: 47, batch: 91, d_loss: nan, g_loss: nan
你如果train一下的話會發現,到一定程度就會nan,這個nan我就無法理解了,按道理來說,從loss來看我們定義的來自以log,如果為無窮那么應該是log(0)了,但是我們的discriminator出來的函數是sigmoid啊,sigmoid不可能為0,只看是0-1且不包括閉區間。這個問題比較玄學。
既然nan的話,我也不深究是因為啥了,總之這個重點在于loss,因為后面GAN的變種基本上都是在loss的訓練形式上。
GAN 生成mnist
我們現在玩一下mnist把。
交流
我見了一個GAN群,加我微信讓我拉進來。jintianiloveu, 順便下載一個我做的app吧,內側中,專門用來看美女圖片的,你懂得。。傳送門