翻譯自https://blog.evjang.com/2018/01/nf2.html
原作者:Eric Jang
譯者:尹肖貽
0. 交代故事
我在下面的教程里教你干一件很酷的事兒:
在教程的上半部分,你學(xué)到使用標(biāo)準(zhǔn)化流,把高斯分布這類地攤貨分布,“變形”為高大上的概率分布。你還親自使用PReLU搭建了小型可逆神經(jīng)網(wǎng)絡(luò),實現(xiàn)了一個簡單的鏈?zhǔn)蕉S仿射雙射函數(shù)。
不過,上次的全連接網(wǎng)絡(luò)只用了兩層隱含層,搭建的流弱爆了。更糟的是,非線性激活函數(shù)是單調(diào)的、分段線性的。這樣一來,網(wǎng)絡(luò)只能在原點(diǎn)附近,稍微地扭捏數(shù)據(jù)分布的形狀【譯者按:1)原文用的是manifold流形,對于非從業(yè)者來說,只需知道數(shù)據(jù)在空間中形成一定形狀就可以了2)因為激活函數(shù)在原點(diǎn)附近是非線性的】。這個流想要實現(xiàn)更高級的變換,就完全不給力。比如把各向同性的高斯變成下面雙模式的“雙月”數(shù)據(jù)。
幸好最近的研究發(fā)明了幾個更強(qiáng)大的標(biāo)準(zhǔn)化流。下面我們就去探索其中的幾個技術(shù)。
1. 標(biāo)準(zhǔn)化流之 自回歸(autoregressive)模型和MAF
自回歸概率密度估計技術(shù),比如 WaveNet和PixelRNN,是用來學(xué)習(xí)聯(lián)合概率密度的。該技術(shù)將聯(lián)合概率密度,分解為以為條件概率密度的乘積,這里的
依賴前
個數(shù)據(jù)的取值:
1.1 學(xué)習(xí)階段
上式中的條件概率密度大都有可學(xué)習(xí)的參數(shù)。例如,常見的選擇是單變量高斯分布,自回歸的概率密度就有兩個參數(shù),均值和標(biāo)準(zhǔn)差。這兩個參數(shù)取決于先前的變量
。
這種方法基于一個簡單粗暴的假設(shè):每個變量都依賴于先于其出現(xiàn)的變量,而不是后來的變量。拍腦門就能想明白這不(全)是真的:頂上的像素和底下的像素能有啥因果關(guān)系?不過(令許多研究者驚訝的是),這東西產(chǎn)生出的圖像效果居然還不錯!
1.2 采樣階段
采樣階段,我們從標(biāo)準(zhǔn)高斯采D個"噪聲變量"
,進(jìn)而依次地生成所有樣本:
自回歸采樣技術(shù)可以這么理解,把(從標(biāo)準(zhǔn)高斯分布)采樣得到的噪聲變量轉(zhuǎn)換為一個新的分布形式,或者說新的分布形式是標(biāo)準(zhǔn)高斯分布的轉(zhuǎn)換分布(TransformedDistribution)。
1.3 搭建一個流
有了這個認(rèn)識,我們可以堆疊多個自回歸模型,搭建一個標(biāo)準(zhǔn)化流。這樣做的好處是,你可以更改流中任意的一個雙射函數(shù)的輸入順序,這樣就可以在某一環(huán)節(jié)(因為不合時宜的排序)效果不佳,而在另一個環(huán)節(jié)彌補(bǔ)。
掩模自回歸流(MAF)實現(xiàn)了一個條件高斯自回歸模型。這里對于任何一個分布:
灰色的單元是當(dāng)前待計算的單元,藍(lán)色的單元表示其依賴的單元。和通過計算傳遞給網(wǎng)絡(luò)(分別是涂上了洋紅色、橘色的圓圈)。即使變換只有尺度變換和平移變換,這些變換也錯綜復(fù)雜地依賴先前的變量。對于第一個單元,其和不依賴于x和u,而設(shè)置為可學(xué)習(xí)的變量。
注意,之所以如此這般地設(shè)計變換函數(shù),是因為計算的時候,不必要再計算和的逆。由于變換僅有尺度變換和平移變換,我們只需要逆運(yùn)算這兩個變換:。雙射函數(shù)的前傳和反傳,僅僅取決于和,這樣我們就可以在和的網(wǎng)絡(luò)內(nèi)部使用非線性激活函數(shù)(如ReLU)和非方形矩陣操作。
MAF模型的梯度反傳可以這樣來評估密度函數(shù):
distribution.log_prob(bijector.inverse(x)) + bijector.inverse_log_det_jacobian(x))
2. 時間復(fù)雜度和MADE
自回歸模型和MAF訓(xùn)練較快,因為計算下面D個似然概率
可以利用GPU的并行技術(shù),一次性地用D個線程計算。這里假設(shè)計算任務(wù),如在計算SIMD向量化用CPU或GPU并行計算時那樣,資源沒有上限。
從另一方面看,自回歸采樣較慢,因為你必須等到都算出來,才能計算
。如此只能利用一條線程采樣D個序列,而無法發(fā)揮并行計算的優(yōu)勢。
另一個問題是,在并行的反傳計算時,我們要不要使用兩個不同的(具有不同的輸入長度的)網(wǎng)絡(luò),來計算和
?這樣做不夠高效,尤其考慮到D個網(wǎng)絡(luò)學(xué)到的表示(learned representation)需要共享(只要不違背自回歸的依賴關(guān)系)。在“密度估計的掩模自編碼器”(MADE)一文中,作者給這個問題提供了一個不錯的解決方案:使用一個網(wǎng)絡(luò),把所有的
和
同時計算出來,而用來掩模的權(quán)重保證了自回歸的有效性。
這一技巧也大大簡化了反傳的計算,從所有的x,只需一步即可計算出所有的u(D個輸入,D個輸出)。這就比處理D個不同的網(wǎng)絡(luò)容易的多((D+1)*D/2個輸入,D個輸出)
3.Inverse Autoregressive Flow (IAF)
可逆自回歸流(IAF)的非線性平移和尺度變換操作,其輸入是前一時刻的噪聲變量,而不是輸入前一時刻的數(shù)據(jù)采樣。
前傳(采樣)速度很快:對于所有的,并行D個線程可以一次算出來。IAF也可以使用MADE框架高效地實現(xiàn)并行。
不過,要是你想估計某一新數(shù)據(jù)附近的概率密度,就需要恢復(fù)所有的u,過程很長:首先計算,然后逐次計算。好在計算這一趟以后,IAF生成的數(shù)據(jù)的(log)概率直接就得到了。因為我們已知u的情況下,不需要再通過反向計算x得到u。
細(xì)心的你可能已經(jīng)察覺到了,要是把底下標(biāo)注,頂層標(biāo)注,IAF和MAF的反傳計算沒有差別。對應(yīng)而言,IAF的反傳計算恰好也就是MAF(把x和u交換一下)。所以在tensorflow實現(xiàn)的時候,雙射函數(shù)的類沒有區(qū)別,而且逆轉(zhuǎn)特征來交換正反傳的操作很容易:
iaf_bijector = tfb.Invert(maf_bijector)
IAF和MAF有互補(bǔ)的計算考慮——MAF訓(xùn)練快采樣慢,IFA訓(xùn)練慢采樣快。對于訓(xùn)練網(wǎng)絡(luò),需要做的密度估計操作更多,采樣操作較少,所以通常在學(xué)習(xí)密度的任務(wù)中選擇MAF更為合適。
4.平行聲波網(wǎng)絡(luò)(Parallel Wavenet)
顯而易見的問題是,IAF和MAF能否合二為一,達(dá)到最優(yōu)性能?比如快速的訓(xùn)練和采樣。
答案是:當(dāng)然能!DeepMind實驗室公布的平行聲波網(wǎng)絡(luò)正是這么做的:MAF快速訓(xùn)練一個產(chǎn)生式網(wǎng)絡(luò),IAF在這個網(wǎng)絡(luò)的“啟發(fā)”下最大化采樣的似然概率。回顧一下IAF,計算外來數(shù)據(jù)的概率密度很慢(比如訓(xùn)練集中的數(shù)據(jù)),但是計算采樣數(shù)據(jù)的密度很容易,這樣就省了反傳計算。只要約束“學(xué)生”IAF和“老師”MAF的概率分布差異,就可以完成訓(xùn)練。
這項研究在標(biāo)準(zhǔn)化流的領(lǐng)域里非常重要——最終實現(xiàn)的效果是,實時的音頻合成的速度比采樣的方法快20倍,而且已經(jīng)在谷歌助手之類的產(chǎn)品上線了。
5.NICE and Real-NVP
最后請看Real-NVP,你可以認(rèn)為是IAF雙射函數(shù)的特殊情況。
訓(xùn)練時,在NVP的“配對層”(coupling layer),依次地調(diào)整標(biāo)號。像IAF一樣,
的尺度變換和平移變換取決于
的值。所不同的是,
都只依賴于
, 所以單次的網(wǎng)絡(luò)前傳就可以得到
和
。【譯者按:這里可能的意思是,算出來x以后,求個逆就得到
和
】
對于,這些單元“放行”了所有信號,并被設(shè)置為
。所以可以認(rèn)為Real-NVP是一種特殊的MAF雙射(當(dāng)
)。
因為尺度平移變換后,整個層的統(tǒng)計量可以從或者一次前傳得到,NVP可以在一次前傳和后傳完成所有的計算(也就是說,采樣和估計都很快)。MADE架構(gòu)就不再需要了。
從實驗效果來看,Real-NVP比MAF或IAF都要差,使用同樣多的層數(shù)NVP在我的例子中(比如SIGGRAPH形)也更差。Real-NVP和IAF在二維圖形的情況下幾乎是等價的。唯一的區(qū)別是在第一個單元,IAF通過尺度平移變換得到,不依賴于,而Real-NVP對第一個單元不作處理。
Real-NVP是NICE雙射的后續(xù)工作。NICE只有平移變換并假定。因為NICE不做尺度變換,ILDJ始終是個常數(shù)!
6.批正則化的雙射函數(shù)(batch normalization bijector)
Real-NVP 的論文中提到許多新奇的分布,其中之一是批正則化的雙射函數(shù),用來穩(wěn)定訓(xùn)練過程。在傳統(tǒng)的觀念中,批正則化應(yīng)用在訓(xùn)練神經(jīng)網(wǎng)絡(luò)的場景中,前傳數(shù)據(jù)服從統(tǒng)計意義上的中心集中、方差是對角陣的高斯分布,批數(shù)據(jù)是正態(tài)分布的統(tǒng)計特性(running mean,running variance)通過指數(shù)移動平均值積累。測試時,積累的統(tǒng)計數(shù)據(jù)用來做標(biāo)準(zhǔn)化處理。
在標(biāo)準(zhǔn)化流中,訓(xùn)練時批正則化用在雙射函數(shù)的逆的計算中bijector.inverse,在測試時積累的統(tǒng)計數(shù)據(jù)要去標(biāo)準(zhǔn)化(bijector.forward)。具體而言,批正則化雙射函數(shù)這樣實現(xiàn):
反傳:
1.計算現(xiàn)有變量x的均值和方差
2.更新當(dāng)前的均值和方差
3.用當(dāng)前的均值方差批正則化當(dāng)前數(shù)據(jù)。
前傳:
1.用當(dāng)前的均值方差去標(biāo)準(zhǔn)化,得到數(shù)據(jù)的分布。
歸功于TF雙射函數(shù),這一過程可用這些代碼來實現(xiàn):
class BatchNorm(tfb.Bijector):
def __init__(self, eps=1e-5, decay=0.95, validate_args=False, name="batch_norm"):
super(BatchNorm, self).__init__(
event_ndims=1, validate_args=validate_args, name=name)
self._vars_created = False
self.eps = eps
self.decay = decay
def _create_vars(self, x):
n = x.get_shape().as_list()[1]
with tf.variable_scope(self.name):
self.beta = tf.get_variable('beta', [1, n], dtype=DTYPE)
self.gamma = tf.get_variable('gamma', [1, n], dtype=DTYPE)
self.train_m = tf.get_variable(
'mean', [1, n], dtype=DTYPE, trainable=False)
self.train_v = tf.get_variable(
'var', [1, n], dtype=DTYPE, initializer=tf.ones_initializer, trainable=False)
self._vars_created = True
def _forward(self, u):
if not self._vars_created:
self._create_vars(u)
return (u - self.beta) * tf.exp(-self.gamma) * tf.sqrt(self.train_v + self.eps) + self.train_m
def _inverse(self, x):
# Eq 22. Called during training of a normalizing flow.
if not self._vars_created:
self._create_vars(x)
# statistics of current minibatch
m, v = tf.nn.moments(x, axes=[0], keep_dims=True)
# update train statistics via exponential moving average
update_train_m = tf.assign_sub(
self.train_m, self.decay * (self.train_m - m))
update_train_v = tf.assign_sub(
self.train_v, self.decay * (self.train_v - v))
# normalize using current minibatch statistics, followed by BN scale and shift
with tf.control_dependencies([update_train_m, update_train_v]):
return (x - m) * 1. / tf.sqrt(v + self.eps) * tf.exp(self.gamma) + self.beta
def _inverse_log_det_jacobian(self, x):
# at training time, the log_det_jacobian is computed from statistics of the
# current minibatch.
if not self._vars_created:
self._create_vars(x)
_, v = tf.nn.moments(x, axes=[0], keep_dims=True)
abs_log_det_J_inv = tf.reduce_sum(
self.gamma - .5 * tf.log(v + self.eps))
return abs_log_det_J_inv
ILDJ的數(shù)學(xué)形式可方便地從逆函數(shù)的對數(shù)形式導(dǎo)出(參考單變量的例子)。
7.代碼在此
歸功于JoshDillon和谷歌貝葉斯流研究組的努力,在MADE架構(gòu)下的Masked Autoregressive Flow 來實現(xiàn)u的快速訓(xùn)練的代碼已經(jīng)寫好了。
我創(chuàng)建了一個復(fù)雜的二維分布,是利用這個混合腳本點(diǎn)云“SIGGRAPH”的形狀。我們搭建了數(shù)據(jù)庫、雙射函數(shù)、轉(zhuǎn)換分布,方法和教程一中非常類似,所以就不重復(fù)了,你可以參看這個Jupyter notebook 文件。這個代碼可以通過MAF、IAF、Real-NVP實現(xiàn),使用或不使用批正則化,重構(gòu)雙月形和“SIGGRAPH”形。
一個細(xì)節(jié)很容易被忽略,如果你忘記重置變量順序,這個代碼就不能正常工作。這是因為不重置順序的話,所有的自回歸分解就只能學(xué)到。幸好,Tensorflow有一個換序雙射函數(shù),可以專門實現(xiàn)這一功能。
這里可視化了流的學(xué)習(xí)過程,從一而終。看到它,我想起一臺太妃糖攪拌機(jī)
8.教程總結(jié)
Tensorflow讓標(biāo)準(zhǔn)化流的實現(xiàn)變得簡單,自動收集所有雅各比矩陣行列式的操作讓代碼變得簡單易讀。如果你要選擇一個實現(xiàn)標(biāo)準(zhǔn)化流的方案,要同時考慮前傳和后傳的快捷性,以及平衡流的表達(dá)能力和ILJD的計算速度。
在該教程的第一部分,我介紹了標(biāo)準(zhǔn)化流的動機(jī):我們需要更強(qiáng)大的分布函數(shù),用于強(qiáng)化學(xué)習(xí)和產(chǎn)生式模型。從大背景來看,在變分推斷和隱式密度估計模型成功應(yīng)用的今天,能夠追蹤變換的標(biāo)準(zhǔn)化流模型是否是人工智能應(yīng)用(如機(jī)器人、架構(gòu)估計)中最適合的模型,尚不明確。即便如此,標(biāo)準(zhǔn)化流仍是你工具箱中非常好用的工具,它們已經(jīng)在現(xiàn)實應(yīng)用中發(fā)揮作用,比如谷歌助手團(tuán)隊開發(fā)的實時的產(chǎn)生音頻的模型。
即便像標(biāo)準(zhǔn)化流這樣的顯式的密度估計模型能方便地通過最大似然法訓(xùn)練,這些模型的作用遠(yuǎn)不止于此。它們和VAE、GAN這類模型是互補(bǔ)的。任何模型中的高斯分布直接就換上一個標(biāo)準(zhǔn)化流,比如VAE的先驗分布和GAN的隱變量。舉個例子,這篇論文用標(biāo)準(zhǔn)化流靈活選擇變分先驗,Tensorflow分布的論文提到VAE中標(biāo)準(zhǔn)化流做PixelCNN的解碼器。平行聲波網(wǎng)絡(luò)通過最小化KL散度訓(xùn)練一個IAF“學(xué)生”模型。
標(biāo)準(zhǔn)化流最具啟發(fā)的性質(zhì)之一是每一步計算都是可逆的(即具有函數(shù)的定義清楚的逆)。這就意味著,如果我們想做一次反傳計算時,不必在前傳時預(yù)先存儲激活值,而是重新計算一遍前傳激活值。(存儲這些值有可能代價很大。)在任務(wù)分配的過程很長的情況下,我們可以使用計算的可逆性來“恢復(fù)”過去的每一個選擇狀態(tài),從而限制了計算時的內(nèi)存消耗。這個主意在論文 RevNets中有所體現(xiàn),正是受NICE雙射函數(shù)運(yùn)算可逆的啟發(fā)。我回憶起電影Memento中的情節(jié),主角不能存貯記憶了,于是通過可逆計算來記住事情。
謝謝閱讀!
Code on Github
9.致謝
非常感謝 Dustin Tran, Luke Metz, Jonathan Shen, Katherine Lee, Samy Bengio 等人預(yù)先檢查閱讀了該教程。
10.引用文獻(xiàn)和推薦閱讀
- 這篇博客的內(nèi)容和格式都受到Masked Autoregressive Flow for Density Estimation
一文很深的影響。該論文寫的很好,我最初理解這一研究主題,或多或少起源于此。快去讀一下吧! - 較早涉及NFS的論文有: https://math.nyu.edu/faculty/tabak/publications/Tabak-Turner.pdf and https://arxiv.org/pdf/1302.5125.pdf 和 https://arxiv.org/abs/1505.05770。
- Laurent Dinh的演講以及與推特皮質(zhì)細(xì)胞(Twitter Cortex)研究者的討論。一些簡明扼要的觀點(diǎn)來源于此。
- 用PyMC包實現(xiàn)標(biāo)準(zhǔn)化流教程。
- 還有一些工作我至今還一知半解,那就是弄清從標(biāo)準(zhǔn)化流到Langevin Flow(郎之萬流) 以及 Hamiltonian Flow(哈密頓流)之間的關(guān)系。研究標(biāo)準(zhǔn)化流中的雙射函數(shù)的文獻(xiàn)汗牛充棟,其中一支研究到連續(xù)時間流(Continous-Time Flow,顯然其變換表達(dá)的能力更強(qiáng)大。
11.譯者補(bǔ)充與提問解答
我很喜歡這個教程,清晰易懂,所以把它全文翻譯成中文。考慮到中英文表達(dá)習(xí)慣的差異,譯文并不完全是原文的原句(原詞)對應(yīng),但意思(大概)還原了原文的意思。
這里補(bǔ)充幾個問題與其解答,也歡迎讀者向我(或原作者)提問!
問題1:這篇論文Parallel WaveNet:Fast High-Fidelity Speech Synthesis涉及到了并行化,MAF作為teacher模型,IVF作為student模型,我始終不是很理解這個方法的原理。
回答1:
這個問題拆解為兩部分。第一,為什么要讓MAF做教師,IAF做學(xué)生;第二,并行化的訓(xùn)練是怎樣完成的。
============================================
第一,原文提到了MAF訓(xùn)練快采樣慢,IAF訓(xùn)練慢采樣快。所謂訓(xùn)練就是從數(shù)據(jù)算出模型參數(shù)的過程,采樣就是從模型參數(shù)產(chǎn)生新數(shù)據(jù)的過程。
+++++++++++++++++++++++
對于MAF,訓(xùn)練的時候,把所有的數(shù)據(jù)送進(jìn)兩個網(wǎng)絡(luò),和
。在允許并行的情況下,一步就可以算出來所有的
和
,結(jié)合采樣得到的所有的
,就可以算出
。(注意這里
和
是超參,不是從
計算出來的)然后就可以去訓(xùn)練
和
的參數(shù),這個過程非常快。
采樣的時候,我們手里拿著所有和
的參數(shù),并
和
,從
,
,
,開始算,
,
,
...這樣一直算到
算完,這個過程是串行的而不能做到并行。
注意,采樣階段,這個式子里,
其實就是
。
+++++++++++++++++++++++
對于IAF,訓(xùn)練的時候,從超參和
開始計算得到
,把它們送進(jìn)兩個網(wǎng)絡(luò),
和
,
,
。繼續(xù)計算
。其余的以此類推
。用所有的
、
和
,就可以算出
。然后就可以去訓(xùn)練
和
的參數(shù)。這個過程不能并行。
采樣的時候,我們手里拿著所有和
的參數(shù),采樣基礎(chǔ)分部產(chǎn)生
,然后一步就可以算出所有的
和
,于是
。這個過程是非常快。
+++++++++++++++++++++++
============================================
Parallel WaveNet訓(xùn)練過程是這樣的:
1.訓(xùn)練一個MAF,得到合理的、
。詳細(xì)步驟見上文。
2.采樣基礎(chǔ)網(wǎng)絡(luò),得到的隨機(jī)噪聲,送進(jìn)IAF里得到
(注意這個IAF沒有訓(xùn)練好),和
。
3.給定測試數(shù)據(jù),利用MAF的反向傳遞,計算(log)似然概率
。
4.計算KL距離,作為損失。
============================================
要徹底理解這個框架,要對最大似然法和EM算法有個大概了解。我貼一張圖,從知乎拷貝的。侵權(quán)就刪: