本文首發于渣畫質的救贖——基于GAN的超分辨率方案
歡迎關注專欄深度學習下的計算機視覺
本次給大家介紹2篇文章——SRGAN[3]和ESRGAN[5],基于生成對抗網絡的超分辨率方案。
我自己不是研究GAN或者超分的,而這個工作至少對我或者其它計算機視覺領域的研究者有兩個幫助:
計算機視覺里總會遇到小物體、模糊物體,不妨試試把超分網絡加入進去,如[1,2]。
所有輸出一張圖片的任務都可以嘗試用GAN去解決,比如各種質量改善:圖像去噪,去霧,去模糊;顯著性檢測,圖像分割…… 而GAN用作超分重建算是GAN所有任務中最簡單又最容易拓展的一個。
下面進入正題:
1. 圖像放大與基于CNN的超分方案
所有使用電腦的人都會用到圖片放大的操作,而把圖片放大超過其分辨率后,圖片就會模糊。傳統放大圖片都是采用插值的方法,最常用的有最近鄰插值(Nearest-neighbor)、雙線性插值(Bilinear)、雙立方插值(bicubic)等,這些方法的速度非常快,因此被廣泛采用。
基于卷積神經網絡/深度學習超分方案的開山之作叫做SRCNN(Super-Resolution Convolutional Neural Network),由港中大多媒體實驗室在2015年提出。
基于卷積神經網絡的方案很容易理解:傳統插值方法可以看做把像素復制放大倍數后,用某種固定的卷積核去卷積;而基于卷積神經網絡的超分方法無非是去學習這個卷積核,根據構建出的超分圖像與真實的高分辨率圖像(Ground Truth)的差距去更新網絡參數。
這種方法的損失函數一般都采用均方誤差MSE——即構建出來的超分圖像與真實超分圖像逐像素差的平方:
(1)
是網絡參數,
是低分辨率圖像,
是重建出來的高分辨率圖像,
是真實的高分辨率圖像,
分別是圖片數量、圖片寬和高,都可以看成常數。
這種方法存在一個問題:盡管在客觀評價指標——MSE(均方誤差)、PSNR(峰值信噪比)上成績很好,但總是傾向于生成過于平滑的圖像,如下圖所示:
左邊是基于優化MSE的深度學習方法,中間是下面要介紹的SRGAN,右邊是真實的超分圖片。
2. SRGAN
2.1 SRGAN對比GAN
SRGAN率先把GAN引入到超分辨領域。了解GAN的朋友可以快速理解SRGAN:與標準的GAN結構相比,SRGAN生成器的輸入不再是噪聲,而是低分辨率圖像;而判別器結構跟普通的GAN沒有什么區別。
2.2 SRGAN的損失函數
與先前的基于深度學習的超分方法相比,SRGAN只有一個明顯的變化: (生成器的)Loss函數不再單是對構建出來圖片與真實高分辨率圖片求均方誤差,而是加上對構建出圖片的特征圖與真實高分辨率圖片的特征圖求均方誤差。
作者定義了一個內容損失(Content loss),為原MSE與特征圖MSE加權和[4]:
(2)
即上面的公式(1),
的具體做法是分別得到真實圖片與構建圖片在VGG19下的某個層的特征圖(tensorlayer的實現是用所有的特征圖),在特征圖下求MSE,這部分loss被稱為VGG loss:
? ? ? (3)
公式(2)與公式(1)相比,無非是多了個 。
指的是第i個maxpooling層前的第j個卷積的特征圖。
我這里用的是TensorLayer的實現,和原論文有些不同:原文并沒有提到公式(2),只說了,沒說要加上
...?
的權重是個超參,
是TensorLayer的設置,設置到這么小應該是因為用了所有的特征圖。這個超參不要亂動,不知道是前輩們調了多久調出來的……
至于 為什么能提升效果不好證明,我的理解是使用特征圖計算MSE相對于直接計算MSE約束變弱了。直接計算MSE約束太強,生成的圖片有很低的MSE但過于平滑,降低要求后,MSE變高,但人眼看上去反而覺得效果變好了。
整個生成器的損失函數被作者稱為“感知損失(Perceptual loss)”,除了內容損失以外還要加上一個GAN原有的對抗損失:
? ? (4)
是判別器對于生成超分圖片的輸出,-log(x)在(0,1)上是個單調遞減的函數。生成器希望
的值越大越好,也就是
越小越好,因此梯度更新的時候需要最小化
。
最終生成器的損失函數為:
? ? ? (5)
該部分代碼[4]如下:
g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))
mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True)
vgg_loss = 2e-6 * tl.cost.mean_squared_error(feature_fake, feature_real, is_mean=True)
g_loss = mse_loss + vgg_loss + g_gan_loss
代碼中的logits_real和logits_fake分別判別器對是真實高分圖片、GAN生成的高分圖片的輸出。fake_patchs, hr_patchs分別是生成器的輸出、真實的高分圖片。feature_fake、feature_real是構建的圖片、真實圖片在VGG網絡中的特征圖:
feature_fake = VGG((fake_patchs+1)/2.) # the pre-trained VGG uses the input range of [0, 1]
feature_real = VGG((hr_patchs+1)/2.)
至于判別器的損失,沒有什么特別之處:
d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
d_loss = d_loss1 + d_loss2
到這里SRGAN就差不多講完了,至于生成器和判別器使用了什么樣的CNN不是本文重點,大家看一眼就行:
注意生成器里面使用的是殘差結構,并使用了BN,這兩個點是ESRGAN改進的地方。
3. ESRGAN
ESRGAN[5]這篇論文中的是ECCV2018的workshop,沒中ECCV應該是因為這篇文章中絕大部分改進都是直接使用別人的方法,但這并不代表這篇論文不夠出色。這篇論文是港中大多媒體實驗室拿到超分比賽冠軍的模型,打比賽當然是追求效果好了。該論文在Google Scholar上的引用已經有200多次,其GitHub項目有2000多個Star!
ESRGAN的整體框架和SRGAN保持一致,相比SRGAN,ESRGAN有4處改進。
3.1 改進一:用Dense Block替換Residual Block,并去掉BN層
如題,如圖……
去掉BN并加上Dense Block效果為什么好?作者對該問題的答案并沒有給出很好的解釋,這是因為,作者寫這篇文章的時候,[6,7]這兩項研究還沒有出世,所以作者的解釋不用看了,讓我來解釋吧…
3.1.1 為什么要去掉BN?
推薦這篇博客https://zhuanlan.zhihu.com/p/43200897 這里直接引用如下:
對于有些像素級圖片生成任務來說,BN效果不佳;
對于圖片分類等任務,只要能夠找出關鍵特征,就能正確分類,這算是一種粗粒度的任務,在這種情形下通常BN是有積極效果的。但是對于有些輸入輸出都是圖片的像素級別圖片生成任務,比如圖片風格轉換等應用場景,使用BN會帶來負面效果,這很可能是因為在Mini-Batch內多張無關的圖片之間計算統計量,弱化了單張圖片本身特有的一些細節信息。
以及這篇博客http://www.pianshen.com/article/2449328261/
以圖像超分辨率來說,網絡輸出的圖像在色彩、對比度、亮度上要求和輸入一致,改變的僅僅是分辨率和一些細節。而Batch Norm,對圖像來說類似于一種對比度的拉伸,任何圖像經過Batch Norm后,其色彩的分布都會被歸一化。也就是說,它破壞了圖像原本的對比度信息,所以Batch Norm的加入反而影響了網絡輸出的質量。ResNet可以用BN,但也僅僅是在殘差塊當中使用。
3.1.2 為什么要使用Dense Block?
? ? ? 論文[6] How does batch normalization help optimization指出,BN的作用是網絡更容易優化,不容易陷入局部極小值。ESRGAN去掉了BN,可以猜想,如果保持原有的Residual Block結構,網絡會變得非常難易訓練,而且很容易陷入局部極小值導致結果不好。論文[7] Visualizing the loss landscape of neural nets可視化了一些網絡的解空間:
可以看到,DenseNet的解空間非常平滑,也就是說,DenseNet相比其他網絡要容易訓練的多,Dense Block和BN提升網絡性能的原因是相同的!(劃重點!!!)
BN有副作用所以去掉了BN,所以要拿Dense Block來彌補!
推薦下我的這篇博客,https://zhuanlan.zhihu.com/p/86886887 對這個問題有更詳盡的說明。
3.2 改進二:改進對抗損失函數——使用Relativistic GAN
Relativistic GAN[8]改進了GAN的損失函數:
表示判別器的輸出,
是sigmoid函數。
是對一個batch的數據求平均,可以不用看。上圖的意思就是:原來的判別器希望【輸入真實圖片】后輸出趨近于1,現在判別器希望【輸入真實圖片-輸入生成圖片】后輸出趨近于1;原來的判別器希望【輸入生成圖片】后輸出趨近于0,現在判別器希望【輸入生成圖片-輸入真實圖片】后的輸出趨近于0。
判別器該部分的損失為:
? (6)
生成器該部分的損失為:
(7)
下面我推導一下(6,7)兩個公式是怎么來的:
Ian Goodfellow 在GAN開篇之作給出的公式:
(8)
可以改寫為:
(9)
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (10)
(10)的另一種形式:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (11)
(x表示真實圖片,z表示輸入G網絡的噪聲)
將公式(9)的換成
,
換成
就變成了公式(6)。公式(11)比公式(7)和(8)少一項是因為公式(8)中的第1項和生成器G無關。而Relativistic GAN,這里變成了相關的,把公式(11)改寫為:
? ? ? ? ? ? ? (12)
再次將公式(12)的 換成
,
換成
就變成了公式(7)。
3.3 改進三:改進生成器——使用relu激活前的特征圖計算損失
作者解釋有兩個原因:
激活后的特征圖變的非常稀疏,丟失了很多信息。
使用激活后的特征圖會造成重建圖片在亮度上的不連續。
此時,就可以求出生成器的損失函數:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (13)
其中,是重建圖片與真實高清圖片逐像素差的絕對值,也就是把MSE換成絕對值。
感知損失就是前面公式(3),只是把激活前的特征圖換成激活后的特征圖。
順便說一句,SRGAN和ESRGAN給損失函數起的名字不同!我不知道ESRGAN作者沒延續SRGAN的名字是寫錯了還是有意為之,反正ESRGAN給loss起的名字比SRGAN合理!SRGAN中的Perceptual Loss是指整個生成器的損失函數,而ESRGAN是指由特征圖計算出來的Loss;SRGAN中的Content Loss是指原圖的MSE+其特征圖的MSE,而ESRGAN是指由真實高清圖與重建圖直接計算的L1 loss。
來個表格對比一下兩篇論文的命名:
3.4 改進四:使用網絡插值(network interpolation)方法平衡解決客觀評價指標與主觀評價指標的矛盾
基于GAN的方法有一個缺點,經常會生成奇怪的紋理,而非GAN的方法總是缺失細節,能不能把兩種方法生成的圖片加權相加呢?這就是所謂的Network Interpolation。
訓練一個非GAN的網絡,在這個網絡的基礎上fine-tuning出GAN的生成器,然后把兩個網絡的參數加權相加:
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? (14)
這個公式非常好理解了兩個網絡的參數相加。其實等價于兩種方法的輸出分別相加。
通過 這個參數可以對生成圖片的平滑程度進行調控,這點還是很爽的。
到這里ESRGAN也介紹完了,最后強烈推薦官方代碼
https://github.com/xinntao/ESRGAN?github.com
我git clone完在pytorch環境下可以直接運行,測試下效果:
左,中,右分別為低分辨率圖像,ESRGAN生成的圖像,原始高清圖像。可以看到中間的圖生成了很多紋理,看起來比右邊的原圖還要清晰。。。
代碼解析如果有空的話我會補一下(我好忙啊...)
[1] Noh, Junhyug, et al. "Better to Follow, Follow to Be Better: Towards Precise Supervision of Feature Super-Resolution for Small Object Detection." Proceedings of the IEEE International Conference on Computer Vision. 2019.
[2] Li, Jianan, et al. "Perceptual generative adversarial networks for small object detection." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2017.
[3] Ledig, Christian, et al. "Photo-realistic single image super-resolution using a generative adversarialnetwork." Proceedings of the IEEE conference on computer vision and pattern recognition. 2017.
[4] tensorlayer/srgan [https://github.com/tensorlayer/srgan]
[5] Wang, Xintao, et al. "Esrgan: Enhanced super-resolution generative adversarial networks." Proceedings of the European Conference on Computer Vision (ECCV). 2018.
[6] Santurkar, Shibani, et al. "How does batch normalization help optimization?." Advances in Neural Information Processing Systems. 2018.
[7] Li, Hao, et al. "Visualizing the loss landscape of neural nets." Advances in Neural Information Processing Systems. 2018.
[8] Jolicoeur-Martineau, Alexia. "The relativistic discriminator: a key element missing from standard GAN." arXiv preprint arXiv:1807.00734 (2018).